Remove partially initialized Tensor in Deserialization (#14197)
authorJerry Zhang <jerryzh@fb.com>
Tue, 11 Dec 2018 01:13:51 +0000 (17:13 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 01:17:29 +0000 (17:17 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14197

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13642

Previously we pass in a patially initialized Tensor to Deserialize and it will fill
it with the result of deserialization of a tensor proto. Now we want it to return
a Tensor directly since it's just a shared pointer to TensorImpl.

Reviewed By: dzhulgakov

Differential Revision: D12874357

fbshipit-source-id: 12b80a763375da23cfa64a74d6bc186d8d03b94f

binaries/benchmark_helper.cc
caffe2/core/blob_serialization.cc
caffe2/core/blob_serialization.h
caffe2/core/blob_test.cc
caffe2/operators/map_ops.h
caffe2/operators/tensor_protos_db_input.h

index 1020dc1..ab53a51 100644 (file)
@@ -231,7 +231,7 @@ void fillInputBlob(
       // int total_size = tensor_proto->float_data_size();
       caffe2::TensorCPU* tensor =
           new caffe2::TensorCPU(dims, caffe2::DeviceType::CPU);
-      serializer.Deserialize(*tensor_proto, tensor);
+      serializer.DeserializeToTensor(*tensor_proto, tensor);
       blob->Reset(tensor);
     }
     // todo: for other types
index 35e52b3..13eab5b 100644 (file)
@@ -387,32 +387,112 @@ void DeserializeBlob(const BlobProto& blob_proto, Blob* result) {
   }
 }
 
+// === Local helper functions ===
+// Get dimensions from Tensor proto
+static std::vector<int64_t> DimsFromTensorProto(const TensorProto& proto) {
+  std::vector<int64_t> dims;
+  for (const int64_t d : proto.dims()) {
+    dims.push_back(d);
+  }
+  return dims;
+}
+
+// Get number of elements from Tensor proto
+static int64_t NumelFromTensorProto(const TensorProto& tensor_proto) {
+  int64_t numel = 1;
+  for (const int64_t d : tensor_proto.dims()) {
+    numel *= d;
+  }
+  return numel;
+}
+
+// Get data type from Tensor proto
+static TypeMeta GetDataType(const TensorProto& tensor_proto) {
+  TypeMeta dtype;
+  if (tensor_proto.data_type() != TensorProto_DataType_UNDEFINED) {
+    dtype = DataTypeToTypeMeta(tensor_proto.data_type());
+  } else {
+    Blob temp_blob;
+    DeserializeBlob(tensor_proto.string_data(0), &temp_blob);
+    dtype = temp_blob.meta();
+  }
+  return dtype;
+}
+
+// Get TensorOptions from Tensor proto
+// Assumes TensorProto is not empty
+static at::TensorOptions TensorOptionsFromProto(
+    const TensorProto& tensor_proto) {
+  return at::dtype(GetDataType(tensor_proto))
+      .device(OptionToDevice(tensor_proto.device_detail()));
+}
+
+static std::unique_ptr<BaseContext> ContextFromProto(
+    const TensorProto& tensor_proto) {
+  auto device = OptionToDevice(tensor_proto.device_detail());
+  return CreateContext(device);
+}
+
+// === Local helper functions ===
+
+Tensor EmptyTensorFromProto(const TensorProto& tensor_proto) {
+  auto context = ContextFromProto(tensor_proto);
+  context->SwitchToDevice(0);
+  if (NumelFromTensorProto(tensor_proto) == 0 &&
+      tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
+    // TODO: remove when serialization of dtype uninitialized tensor is removed
+    return caffe2::empty(
+        {0},
+        at::dtype<float>().device(
+            OptionToDevice(tensor_proto.device_detail())));
+  } else {
+    return caffe2::empty(
+        DimsFromTensorProto(tensor_proto),
+        TensorOptionsFromProto(tensor_proto));
+  }
+}
+
 void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
   auto tensor_proto = blob_proto.tensor();
-  Deserialize(
-      tensor_proto,
-      BlobGetMutableTensor(
-          blob,
-          static_cast<DeviceType>(tensor_proto.device_detail().device_type())));
+  auto context = ContextFromProto(tensor_proto);
+  context->SwitchToDevice(0);
+  if (NumelFromTensorProto(tensor_proto) == 0 &&
+      tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
+    // TODO: remove after empty Tensor serialization is forbidden
+    VLOG(1) << "Deseriralizing an empty Tensor.";
+    BlobGetMutableTensor(
+        blob,
+        {0},
+        at::dtype<float>().device(
+            OptionToDevice(tensor_proto.device_detail())));
+  } else {
+    DeserializeToTensor(
+        tensor_proto,
+        BlobGetMutableTensor(
+            blob,
+            DimsFromTensorProto(tensor_proto),
+            TensorOptionsFromProto(tensor_proto)));
+  }
 }
 
-void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
+void TensorDeserializer::DeserializeToTensor(
+    const TensorProto& tensor_proto,
+    Tensor* tensor) {
+  CAFFE_ENFORCE(
+      tensor->storage_initialized() && tensor->dtype_initialized(),
+      "Tensor must be initialized before passed into Deserialize function.");
   // We create a local context for deserializing. Since Caffe2 contexts are
   // usually lightweight, this should not involve too much overhead.
-  auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
+  auto uniq_ptr = ContextFromProto(tensor_proto);
+  // since CopyFromProtoAsIs accepts BaseContext*
   auto context = uniq_ptr.get();
   context->SwitchToDevice(0);
-  vector<int64_t> dims;
-  for (const int64_t d : proto.dims()) {
-    dims.push_back(d);
-  }
-  tensor->Resize(dims);
 
   int64_t chunkBegin = 0;
   auto chunkEnd = tensor->numel();
-  if (proto.has_segment()) {
-    chunkBegin = proto.segment().begin();
-    chunkEnd = proto.segment().end();
+  if (tensor_proto.has_segment()) {
+    chunkBegin = tensor_proto.segment().begin();
+    chunkEnd = tensor_proto.segment().end();
   }
   CAFFE_ENFORCE(
       0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->numel(),
@@ -424,18 +504,18 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
       tensor->numel());
   auto chunkSize = chunkEnd - chunkBegin;
 
-  switch (proto.data_type()) {
+  switch (tensor_proto.data_type()) {
     case TensorProto_DataType_FLOAT:
       detail::CopyFromProtoAsIs(
           chunkSize,
-          proto.float_data(),
+          tensor_proto.float_data(),
           tensor->template mutable_data<float>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_INT32:
       detail::CopyFromProtoAsIs(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<int>() + chunkBegin,
           context);
       break;
@@ -443,10 +523,12 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
       // Since BYTE stores the data in a string field instead of a repreated
       // field we will have it special cased.
       CAFFE_ENFORCE_EQ(
-          chunkSize, proto.byte_data().size(), "Incorrect proto field size.");
+          chunkSize,
+          tensor_proto.byte_data().size(),
+          "Incorrect proto field size.");
       context->template CopyToCPU<uint8_t>(
           chunkSize,
-          reinterpret_cast<const uint8_t*>(proto.byte_data().data()),
+          reinterpret_cast<const uint8_t*>(tensor_proto.byte_data().data()),
           tensor->template mutable_data<uint8_t>() + chunkBegin);
       break;
     case TensorProto_DataType_STRING:
@@ -454,54 +536,54 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
       {
         string* content = tensor->template mutable_data<string>();
         for (int i = 0; i < chunkSize; ++i) {
-          content[i + chunkBegin] = proto.string_data(i);
+          content[i + chunkBegin] = tensor_proto.string_data(i);
         }
       }
       break;
     case TensorProto_DataType_BOOL:
       detail::CopyFromProtoWithCast(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<bool>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_UINT8:
       detail::CopyFromProtoWithCast(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<uint8_t>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_INT8:
       detail::CopyFromProtoWithCast(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<int8_t>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_UINT16:
       detail::CopyFromProtoWithCast(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<uint16_t>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_INT16:
       detail::CopyFromProtoWithCast(
           chunkSize,
-          proto.int32_data(),
+          tensor_proto.int32_data(),
           tensor->template mutable_data<int16_t>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_INT64:
       detail::CopyFromProtoAsIs(
           chunkSize,
-          proto.int64_data(),
+          tensor_proto.int64_data(),
           tensor->template mutable_data<int64_t>() + chunkBegin,
           context);
       break;
     case TensorProto_DataType_FLOAT16:
-      if (proto.has_byte_data()) {
+      if (tensor_proto.has_byte_data()) {
         const int kValue = 1;
         CAFFE_ENFORCE_EQ(
             reinterpret_cast<const char*>(&kValue)[0],
@@ -510,17 +592,17 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
             "is not written yet.");
         CAFFE_ENFORCE_EQ(
             2 * chunkSize,
-            proto.byte_data().size(),
+            tensor_proto.byte_data().size(),
             "Incorrect proto field size.");
         context->template CopyToCPU<at::Half>(
             chunkSize,
-            reinterpret_cast<const at::Half*>(proto.byte_data().data()),
+            reinterpret_cast<const at::Half*>(tensor_proto.byte_data().data()),
             tensor->template mutable_data<at::Half>() + chunkBegin);
       } else {
         // Backward compatibility with models which used int32_data field
         detail::CopyFromProtoWithCast(
             chunkSize,
-            proto.int32_data(),
+            tensor_proto.int32_data(),
             reinterpret_cast<uint16_t*>(
                 tensor->template mutable_data<at::Half>()) +
                 chunkBegin,
@@ -530,7 +612,7 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
     case TensorProto_DataType_DOUBLE:
       detail::CopyFromProtoAsIs(
           chunkSize,
-          proto.double_data(),
+          tensor_proto.double_data(),
           tensor->template mutable_data<double>() + chunkBegin,
           context);
       break;
@@ -538,7 +620,7 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
       Blob temp_blob;
       void* raw_ptr = nullptr;
       for (int i = 0; i < chunkSize; ++i) {
-        DeserializeBlob(proto.string_data(i), &temp_blob);
+        DeserializeBlob(tensor_proto.string_data(i), &temp_blob);
         if (i == 0) {
           raw_ptr = tensor->raw_mutable_data(temp_blob.meta());
         }
@@ -554,6 +636,12 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
   context->FinishDeviceComputation();
 }
 
+Tensor TensorDeserializer::Deserialize(const TensorProto& tensor_proto) {
+  auto tensor = EmptyTensorFromProto(tensor_proto);
+  DeserializeToTensor(tensor_proto, &tensor);
+  return tensor;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Serialization Helpers
 ////////////////////////////////////////////////////////////////////////////////
index 910d4da..f61331e 100644 (file)
@@ -55,6 +55,28 @@ CAFFE2_API string SerializeBlob(const Blob& blob, const string& name);
 CAFFE2_API void DeserializeBlob(const string& content, Blob* result);
 CAFFE2_API void DeserializeBlob(const BlobProto& proto, Blob* result);
 
+/*
+ * Get an empty Tensor from the TensorProto given the meta data in proto (data
+ * type and size of the Tensor) without actually filling in the data.
+ *
+ * We need this function because we want to construct a fully initialized Tensor
+ * in the beginning instead of keeping partially initialized Tensor around the
+ * process. Consider the case when we have a Tensor that is split into multiple
+ * protos during serialization, in deserialization, we have to fill the Tensor
+ * in multiple calls to Deserialize, therefore we need to create a new Tensor
+ * with the correct size and data type before the call to Deserialize, because
+ * otherwise we will have to check whether the function call is the first call
+ * to initialize the underlying Tensor, which makes the function stateful and
+ * complicated.
+ *
+ * The legacy code get away with this problem by passing in a partially
+ * initialized Tensor and use Resize and mutable_data to set the correct size,
+ * data type and allocate memory for the Tensor, so the state is encoded in
+ * these function calls. e.g. mutable_data will allocate memory on the first
+ * call and it will return a pointer to the allocated memory on later calls.
+ */
+CAFFE2_API Tensor EmptyTensorFromProto(const TensorProto& proto);
+
 /**
  * @brief TensorSerializer is the serializer for Tensors.
  *
@@ -106,7 +128,22 @@ class CAFFE2_API TensorSerializer : public BlobSerializerBase {
 class CAFFE2_API TensorDeserializer : public BlobDeserializerBase {
  public:
   void Deserialize(const BlobProto& proto, Blob* blob) override;
-  void Deserialize(const TensorProto& proto, Tensor* tensor);
+
+  /* There are cases when a Tensor is split into multiple protos and
+   * we have to call Deserialize multiple times to get the complete deserialized
+   * Tensor, each call will fill part of the Tensor given the segment begin and
+   * end information in proto, therefore we have to pass in the Tensor pointer
+   * rather than create a new Tensor everytime.
+   *
+   * Precondition: Tensor must be initialized
+   */
+  void DeserializeToTensor(const TensorProto& proto, Tensor* tensor);
+
+  /* Deserialize the proto and return a new Tensor
+   * This is a utility function that combines EmptyTensorFromProto and
+   * Deserialize(const TensorProto&, Tensor*);
+   */
+  Tensor Deserialize(const TensorProto& proto);
 };
 
 ////////////////////////////////////////////////////////////////////////////////
index 106cea5..290e310 100644 (file)
@@ -1142,10 +1142,12 @@ TEST(TensorSerialization, MistakenlySerializingDtypeUninitializedTensor) {
   LOG(INFO) << "serialized proto: " << b.DebugString();
 
   Blob new_blob;
+  // Deserializing an empty Tensor gives a {0}-dim, float CPU Tensor
   DeserializeBlob(output, &new_blob);
   const Tensor& new_tensor = new_blob.Get<Tensor>();
   LOG(INFO) << "tensor " << new_tensor.DebugString();
-  EXPECT_FALSE(new_tensor.dtype_initialized());
+  EXPECT_TRUE(new_tensor.dtype_initialized());
+  LOG(INFO) << "dtype:" << new_tensor.dtype();
   EXPECT_EQ(0, new_tensor.numel());
   EXPECT_EQ(1, new_tensor.dim());
 }
index a43891d..966968f 100644 (file)
@@ -245,9 +245,8 @@ class MapDeserializer : public BlobDeserializerBase {
         tensor_protos.ParseFromString(proto.content()),
         "Fail to parse TensorProtos");
     TensorDeserializer deser;
-    Tensor key_tensor(CPU), value_tensor(CPU);
-    deser.Deserialize(tensor_protos.protos(0), &key_tensor);
-    deser.Deserialize(tensor_protos.protos(1), &value_tensor);
+    Tensor key_tensor = deser.Deserialize(tensor_protos.protos(0));
+    Tensor value_tensor = deser.Deserialize(tensor_protos.protos(1));
     auto* key_data = key_tensor.data<KEY_T>();
     auto* value_data = value_tensor.data<VALUE_T>();
 
index 8ba2ff5..d07763c 100644 (file)
@@ -55,35 +55,30 @@ bool TensorProtosDBInput<Context>::Prefetch() {
       if (protos.protos(i).has_device_detail()) {
         protos.mutable_protos(i)->clear_device_detail();
       }
-      deserializer.Deserialize(
-          protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i], CPU));
+      BlobSetTensor(
+          &prefetched_blobs_[i], deserializer.Deserialize(protos.protos(i)));
+      // deserializer.Deserialize(
+      //     protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i],
+      //     CPU));
     }
   } else {
-    vector<Tensor> temp_tensors;
-    for (int i = 0; i < OutputSize(); ++i) {
-      temp_tensors.emplace_back(CPU);
-    }
     for (int item_id = 0; item_id < batch_size_; ++item_id) {
       reader.Read(&key_, &value_);
       TensorProtos protos;
       CAFFE_ENFORCE(protos.ParseFromString(value_));
       CAFFE_ENFORCE(protos.protos_size() == OutputSize());
-      if (!shape_inferred_) {
-        // First, set the shape of all the blobs.
-        for (int i = 0; i < protos.protos_size(); ++i) {
-          vector<int> dims(
-              protos.protos(i).dims().begin(), protos.protos(i).dims().end());
-          dims.insert(dims.begin(), batch_size_);
-          BlobGetMutableTensor(&prefetched_blobs_[i], CPU)->Resize(dims);
-        }
-      }
+      // Note: shape_inferred_ is ignored, we'll always get dimensions from
+      // proto
       for (int i = 0; i < protos.protos_size(); ++i) {
-        TensorCPU* dst = BlobGetMutableTensor(&prefetched_blobs_[i], CPU);
-        TensorCPU& src = temp_tensors[i];
+        vector<int64_t> dims(
+            protos.protos(i).dims().begin(), protos.protos(i).dims().end());
+        dims.insert(dims.begin(), batch_size_);
         if (protos.protos(i).has_device_detail()) {
           protos.mutable_protos(i)->clear_device_detail();
         }
-        deserializer.Deserialize(protos.protos(i), &src);
+        Tensor src = deserializer.Deserialize(protos.protos(i));
+        Tensor* dst = BlobGetMutableTensor(
+            &prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU));
         DCHECK_EQ(src.numel() * batch_size_, dst->numel());
         this->context_.CopyItemsSameDevice(
             src.dtype(),