}
}
+// === Local helper functions ===
+// Get dimensions from Tensor proto
+static std::vector<int64_t> DimsFromTensorProto(const TensorProto& proto) {
+ // Copy the repeated dims field straight into a vector via the range ctor.
+ return std::vector<int64_t>(proto.dims().begin(), proto.dims().end());
+}
+
+// Get number of elements from Tensor proto
+// (product of all entries in dims; 1 for a zero-rank proto).
+static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) {
+ int64_t element_count = 1;
+ for (int i = 0; i < tensor_proto.dims().size(); ++i) {
+ element_count *= tensor_proto.dims(i);
+ }
+ return element_count;
+}
+
+// Get data type from Tensor proto.
+// If the proto carries an explicit data_type, map it to a TypeMeta directly;
+// otherwise the element type is recovered by deserializing the first nested
+// blob stored in string_data and reading its meta.
+// NOTE(review): the UNDEFINED branch reads string_data(0) without checking
+// that string_data is non-empty -- presumably callers guarantee this; confirm.
+static TypeMeta GetDataType(const TensorProto& tensor_proto) {
+ TypeMeta dtype;
+ if (tensor_proto.data_type() != TensorProto_DataType_UNDEFINED) {
+ dtype = DataTypeToTypeMeta(tensor_proto.data_type());
+ } else {
+ Blob temp_blob;
+ DeserializeBlob(tensor_proto.string_data(0), &temp_blob);
+ dtype = temp_blob.meta();
+ }
+ return dtype;
+}
+
+// Get TensorOptions from Tensor proto
+// Assumes TensorProto is not empty
+static at::TensorOptions TensorOptionsFromProto(
+ const TensorProto& tensor_proto) {
+ const auto target_device = OptionToDevice(tensor_proto.device_detail());
+ return at::dtype(GetDataType(tensor_proto)).device(target_device);
+}
+
+// Build an execution context matching the device recorded in the proto.
+static std::unique_ptr<BaseContext> ContextFromProto(
+ const TensorProto& tensor_proto) {
+ return CreateContext(OptionToDevice(tensor_proto.device_detail()));
+}
+
+// === End of local helper functions ===
+
+// Create a Tensor with the size, dtype and device described by the proto's
+// metadata, without filling in any element data.
+Tensor EmptyTensorFromProto(const TensorProto& tensor_proto) {
+ // Switch to the proto's device before allocating.
+ auto context = ContextFromProto(tensor_proto);
+ context->SwitchToDevice(0);
+ if (NumelFromTensorProto(tensor_proto) == 0 &&
+ tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
+ // Legacy: an empty tensor serialized without a dtype gets a {0}-sized
+ // float tensor on the proto's device.
+ // TODO: remove when serialization of dtype uninitialized tensor is removed
+ return caffe2::empty(
+ {0},
+ at::dtype<float>().device(
+ OptionToDevice(tensor_proto.device_detail())));
+ } else {
+ return caffe2::empty(
+ DimsFromTensorProto(tensor_proto),
+ TensorOptionsFromProto(tensor_proto));
+ }
+}
+
void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
 auto tensor_proto = blob_proto.tensor();
- Deserialize(
- tensor_proto,
- BlobGetMutableTensor(
- blob,
- static_cast<DeviceType>(tensor_proto.device_detail().device_type())))
+ // Switch to the proto's device before touching the blob's tensor; the
+ // context is only needed for the switch, DeserializeToTensor makes its own.
+ auto context = ContextFromProto(tensor_proto);
+ context->SwitchToDevice(0);
+ if (NumelFromTensorProto(tensor_proto) == 0 &&
+ tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
+ // Legacy dtype-less empty tensor: allocate a {0}-sized float tensor.
+ // TODO: remove after empty Tensor serialization is forbidden
+ VLOG(1) << "Deserializing an empty Tensor.";
+ BlobGetMutableTensor(
+ blob,
+ {0},
+ at::dtype<float>().device(
+ OptionToDevice(tensor_proto.device_detail())));
+ } else {
+ // Size and type the blob's tensor from the proto, then fill it in.
+ DeserializeToTensor(
+ tensor_proto,
+ BlobGetMutableTensor(
+ blob,
+ DimsFromTensorProto(tensor_proto),
+ TensorOptionsFromProto(tensor_proto)));
+ }
}
-void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
+void TensorDeserializer::DeserializeToTensor(
+ const TensorProto& tensor_proto,
+ Tensor* tensor) {
+ // Precondition: caller has already allocated storage and set the dtype
+ // (e.g. via EmptyTensorFromProto); this function only fills in data.
+ CAFFE_ENFORCE(
+ tensor->storage_initialized() && tensor->dtype_initialized(),
+ "Tensor must be initialized before passed into Deserialize function.");
// We create a local context for deserializing. Since Caffe2 contexts are
// usually lightweight, this should not involve too much overhead.
- auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
+ auto uniq_ptr = ContextFromProto(tensor_proto);
+ // since CopyFromProtoAsIs accepts BaseContext*
auto context = uniq_ptr.get();
context->SwitchToDevice(0);
- vector<int64_t> dims;
- for (const int64_t d : proto.dims()) {
- dims.push_back(d);
- }
- tensor->Resize(dims);
+ // A proto may describe only a segment [chunkBegin, chunkEnd) of the full
+ // tensor; without a segment it covers the whole tensor.
int64_t chunkBegin = 0;
auto chunkEnd = tensor->numel();
- if (proto.has_segment()) {
- chunkBegin = proto.segment().begin();
- chunkEnd = proto.segment().end();
+ if (tensor_proto.has_segment()) {
+ chunkBegin = tensor_proto.segment().begin();
+ chunkEnd = tensor_proto.segment().end();
}
CAFFE_ENFORCE(
0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->numel(),
tensor->numel());
auto chunkSize = chunkEnd - chunkBegin;
+ // Copy chunkSize elements from the proto field that matches the data type.
- switch (proto.data_type()) {
+ switch (tensor_proto.data_type()) {
case TensorProto_DataType_FLOAT:
detail::CopyFromProtoAsIs(
chunkSize,
- proto.float_data(),
+ tensor_proto.float_data(),
tensor->template mutable_data<float>() + chunkBegin,
context);
break;
case TensorProto_DataType_INT32:
detail::CopyFromProtoAsIs(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<int>() + chunkBegin,
context);
break;
// Since BYTE stores the data in a string field instead of a repreated
// field we will have it special cased.
CAFFE_ENFORCE_EQ(
- chunkSize, proto.byte_data().size(), "Incorrect proto field size.");
+ chunkSize,
+ tensor_proto.byte_data().size(),
+ "Incorrect proto field size.");
context->template CopyToCPU<uint8_t>(
chunkSize,
- reinterpret_cast<const uint8_t*>(proto.byte_data().data()),
+ reinterpret_cast<const uint8_t*>(tensor_proto.byte_data().data()),
tensor->template mutable_data<uint8_t>() + chunkBegin);
break;
case TensorProto_DataType_STRING:
{
string* content = tensor->template mutable_data<string>();
for (int i = 0; i < chunkSize; ++i) {
- content[i + chunkBegin] = proto.string_data(i);
+ content[i + chunkBegin] = tensor_proto.string_data(i);
}
}
break;
case TensorProto_DataType_BOOL:
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<bool>() + chunkBegin,
context);
break;
case TensorProto_DataType_UINT8:
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<uint8_t>() + chunkBegin,
context);
break;
case TensorProto_DataType_INT8:
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<int8_t>() + chunkBegin,
context);
break;
case TensorProto_DataType_UINT16:
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<uint16_t>() + chunkBegin,
context);
break;
case TensorProto_DataType_INT16:
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
tensor->template mutable_data<int16_t>() + chunkBegin,
context);
break;
case TensorProto_DataType_INT64:
detail::CopyFromProtoAsIs(
chunkSize,
- proto.int64_data(),
+ tensor_proto.int64_data(),
tensor->template mutable_data<int64_t>() + chunkBegin,
context);
break;
case TensorProto_DataType_FLOAT16:
- if (proto.has_byte_data()) {
+ if (tensor_proto.has_byte_data()) {
const int kValue = 1;
CAFFE_ENFORCE_EQ(
reinterpret_cast<const char*>(&kValue)[0],
"is not written yet.");
CAFFE_ENFORCE_EQ(
2 * chunkSize,
- proto.byte_data().size(),
+ tensor_proto.byte_data().size(),
"Incorrect proto field size.");
context->template CopyToCPU<at::Half>(
chunkSize,
- reinterpret_cast<const at::Half*>(proto.byte_data().data()),
+ reinterpret_cast<const at::Half*>(tensor_proto.byte_data().data()),
tensor->template mutable_data<at::Half>() + chunkBegin);
} else {
// Backward compatibility with models which used int32_data field
detail::CopyFromProtoWithCast(
chunkSize,
- proto.int32_data(),
+ tensor_proto.int32_data(),
reinterpret_cast<uint16_t*>(
tensor->template mutable_data<at::Half>()) +
chunkBegin,
case TensorProto_DataType_DOUBLE:
detail::CopyFromProtoAsIs(
chunkSize,
- proto.double_data(),
+ tensor_proto.double_data(),
tensor->template mutable_data<double>() + chunkBegin,
context);
break;
Blob temp_blob;
void* raw_ptr = nullptr;
for (int i = 0; i < chunkSize; ++i) {
- DeserializeBlob(proto.string_data(i), &temp_blob);
+ DeserializeBlob(tensor_proto.string_data(i), &temp_blob);
if (i == 0) {
raw_ptr = tensor->raw_mutable_data(temp_blob.meta());
}
context->FinishDeviceComputation();
}
+// Convenience overload: allocate a correctly sized and typed Tensor from the
+// proto's metadata, fill it in, and return it by value.
+Tensor TensorDeserializer::Deserialize(const TensorProto& tensor_proto) {
+ Tensor deserialized = EmptyTensorFromProto(tensor_proto);
+ DeserializeToTensor(tensor_proto, &deserialized);
+ return deserialized;
+}
+
////////////////////////////////////////////////////////////////////////////////
// Serialization Helpers
////////////////////////////////////////////////////////////////////////////////
CAFFE2_API void DeserializeBlob(const string& content, Blob* result);
CAFFE2_API void DeserializeBlob(const BlobProto& proto, Blob* result);
+/*
+ * Get an empty Tensor from the TensorProto given the meta data in proto (data
+ * type and size of the Tensor) without actually filling in the data.
+ *
+ * We need this function because we want to construct a fully initialized Tensor
+ * in the beginning instead of keeping partially initialized Tensor around the
+ * process. Consider the case when we have a Tensor that is split into multiple
+ * protos during serialization, in deserialization, we have to fill the Tensor
+ * in multiple calls to Deserialize, therefore we need to create a new Tensor
+ * with the correct size and data type before the call to Deserialize, because
+ * otherwise we will have to check whether the function call is the first call
+ * to initialize the underlying Tensor, which makes the function stateful and
+ * complicated.
+ *
+ * The legacy code got around this problem by passing in a partially
+ * initialized Tensor and use Resize and mutable_data to set the correct size,
+ * data type and allocate memory for the Tensor, so the state is encoded in
+ * these function calls. e.g. mutable_data will allocate memory on the first
+ * call and it will return a pointer to the allocated memory on later calls.
+ */
+CAFFE2_API Tensor EmptyTensorFromProto(const TensorProto& proto);
+
/**
* @brief TensorSerializer is the serializer for Tensors.
*
class CAFFE2_API TensorDeserializer : public BlobDeserializerBase {
 public:
 void Deserialize(const BlobProto& proto, Blob* blob) override;
- void Deserialize(const TensorProto& proto, Tensor* tensor);
+
+ /* A Tensor may be split into multiple protos, each carrying segment
+ * begin/end information; Deserialize must then be called once per proto,
+ * and each call fills in only its own segment of the Tensor. The caller
+ * therefore passes in an existing Tensor pointer rather than receiving a
+ * new Tensor every time.
+ *
+ * Precondition: the Tensor must be initialized.
+ */
+ void DeserializeToTensor(const TensorProto& proto, Tensor* tensor);
+
+ /* Deserialize the proto and return a new Tensor.
+ * This is a utility function that combines EmptyTensorFromProto and
+ * DeserializeToTensor(const TensorProto&, Tensor*).
+ */
+ Tensor Deserialize(const TensorProto& proto);
};
////////////////////////////////////////////////////////////////////////////////
if (protos.protos(i).has_device_detail()) {
protos.mutable_protos(i)->clear_device_detail();
}
- deserializer.Deserialize(
- protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i], CPU));
+ BlobSetTensor(
+ &prefetched_blobs_[i], deserializer.Deserialize(protos.protos(i)));
+ // deserializer.Deserialize(
+ // protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i],
+ // CPU));
}
} else {
- vector<Tensor> temp_tensors;
- for (int i = 0; i < OutputSize(); ++i) {
- temp_tensors.emplace_back(CPU);
- }
for (int item_id = 0; item_id < batch_size_; ++item_id) {
reader.Read(&key_, &value_);
TensorProtos protos;
CAFFE_ENFORCE(protos.ParseFromString(value_));
CAFFE_ENFORCE(protos.protos_size() == OutputSize());
- if (!shape_inferred_) {
- // First, set the shape of all the blobs.
- for (int i = 0; i < protos.protos_size(); ++i) {
- vector<int> dims(
- protos.protos(i).dims().begin(), protos.protos(i).dims().end());
- dims.insert(dims.begin(), batch_size_);
- BlobGetMutableTensor(&prefetched_blobs_[i], CPU)->Resize(dims);
- }
- }
+ // Note: shape_inferred_ is ignored, we'll always get dimensions from
+ // proto
for (int i = 0; i < protos.protos_size(); ++i) {
+ // dst holds the whole batch: per-item dims with batch_size_ prepended.
+ vector<int64_t> dims(
+ protos.protos(i).dims().begin(), protos.protos(i).dims().end());
+ dims.insert(dims.begin(), batch_size_);
- TensorCPU* dst = BlobGetMutableTensor(&prefetched_blobs_[i], CPU);
- TensorCPU& src = temp_tensors[i];
if (protos.protos(i).has_device_detail()) {
protos.mutable_protos(i)->clear_device_detail();
}
- deserializer.Deserialize(protos.protos(i), &src);
+ // Deserialize one item into a temporary, then copy it into its slot.
+ Tensor src = deserializer.Deserialize(protos.protos(i));
+ Tensor* dst = BlobGetMutableTensor(
+ &prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU));
DCHECK_EQ(src.numel() * batch_size_, dst->numel());
this->context_.CopyItemsSameDevice(
src.dtype(),