[nnc] Refactor serialization (#1733)
authorDenis Maksimenko/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <d.maksimenko@partner.samsung.com>
Mon, 8 Oct 2018 15:27:29 +0000 (18:27 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Mon, 8 Oct 2018 15:27:29 +0000 (18:27 +0300)
This PR:
* Addresses issue #1681: de/serialization systems are changed to stare tensor data as raw bytes sequences, which reduces some amount of code now and, potentially, more later
* Changes codestyle in related files a bit to correspond with our future standarts
* Changes file structure of related files: serializer is now moved from a separate folder into modelIR folder.
* Fixes #1701

Signed-off-by: Denis Maksimenko <d.maksimenko@partner.samsung.com>
contrib/nnc/core/CMakeLists.txt
contrib/nnc/core/modelIR/Deserializer.cpp [new file with mode: 0644]
contrib/nnc/core/modelIR/Serializer.cpp [moved from contrib/nnc/core/serialize/Serializer.cpp with 60% similarity]
contrib/nnc/core/modelIR/proto/model_ir.proto [moved from contrib/nnc/core/serialize/proto/model_ir.proto with 79% similarity]
contrib/nnc/core/serialize/Deserializer.cpp [deleted file]
contrib/nnc/include/core/modelIR/Deserializer.h [moved from contrib/nnc/include/core/serialize/Deserializer.h with 100% similarity]
contrib/nnc/include/core/modelIR/Serializer.h [moved from contrib/nnc/include/core/serialize/Serializer.h with 100% similarity]
contrib/nnc/unittests/core/deserializer.cpp
contrib/nnc/unittests/core/serializer.cpp

index fb5610f..b23bd1d 100644 (file)
@@ -5,7 +5,7 @@ file(GLOB_RECURSE SERIALIZER_SOURCES "serialize/*.cpp")
 set(GENERATED_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
 Protobuf_Generate(MODEL_IR_PROTO
                 ${GENERATED_OUTPUT_DIR}
-                ${CMAKE_CURRENT_SOURCE_DIR}/serialize/proto
+                modelIR/proto
                 model_ir.proto)
 
 add_nncc_library(model_ir_proto STATIC ${MODEL_IR_PROTO_SOURCES})
diff --git a/contrib/nnc/core/modelIR/Deserializer.cpp b/contrib/nnc/core/modelIR/Deserializer.cpp
new file mode 100644 (file)
index 0000000..b91efef
--- /dev/null
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/modelIR/Deserializer.h"
+#include "model_ir.pb.h"
+
+#include "core/modelIR/ShapeRange.h"
+
+namespace nnc
+{
+namespace mir
+{
+
+//
+// Shape Deserialization
+//
+
+static Shape deserializeFromMessage(const proto::TensorShapeProto& object_as_message)
+{
+  Shape res;
+  auto rank = (uint32_t) object_as_message.dims_size();
+  res.resize((uint32_t) rank);
+  for (uint32_t i = 0; i < rank; i++)
+  {
+    res.dim(i) = (uint32_t) object_as_message.dims(i);
+  }
+
+  return res;
+}
+
+template <>
+Shape Deserializer<Shape>::deserializeFromStream (std::istream& stream)
+{
+  proto::TensorShapeProto object_as_message;
+
+  object_as_message.ParseFromIstream(&stream);
+
+  return deserializeFromMessage(object_as_message);
+}
+
+template <>
+Shape Deserializer<Shape>::deserializeFromString (const std::string& bytes)
+{
+  proto::TensorShapeProto object_as_message;
+
+  object_as_message.ParseFromString(bytes);
+
+  return deserializeFromMessage(object_as_message);
+}
+
+//
+// Tensor Deserialization
+//
+
+static TensorVariant deserializeFromMessage(const proto::TensorProto& object_as_message)
+{
+  Shape shape = deserializeFromMessage(object_as_message.shape());
+
+  proto::DataType dt = object_as_message.dtype();
+
+  const std::string& tensor_content = object_as_message.tensor_content();
+  size_t raw_data_size = tensor_content.size();
+  auto raw_data = new char[raw_data_size];
+  tensor_content.copy(raw_data, raw_data_size);
+
+  TensorVariant::DTYPE tv_dtype;
+  size_t element_size;
+
+  switch (dt)
+  {
+  case proto::DataType::DT_INT32:
+    element_size = sizeof(int32_t);
+    tv_dtype = TensorVariant::DTYPE::INT;
+    break;
+  case proto::DataType::DT_FLOAT:
+    element_size = sizeof(float);
+    tv_dtype = TensorVariant::DTYPE::FLOAT;
+    break;
+  case proto::DataType::DT_DOUBLE :
+    element_size = sizeof(double);
+    tv_dtype = TensorVariant::DTYPE::FLOAT;
+    break;
+  default:
+    throw std::logic_error("Deserializer<TensorVariant>: received unsupported data type");
+  }
+  assert(raw_data_size / element_size == num_elements(shape));
+  std::shared_ptr<char> data(raw_data, std::default_delete<char[]>());
+  return TensorVariant(shape, data, tv_dtype, element_size);
+}
+
+template <>
+TensorVariant Deserializer<TensorVariant>::deserializeFromStream (std::istream& stream)
+{
+  proto::TensorProto object_as_message;
+
+  object_as_message.ParseFromIstream(&stream);
+
+  return deserializeFromMessage(object_as_message);
+}
+
+template <>
+TensorVariant Deserializer<TensorVariant>::deserializeFromString (const std::string& bytes)
+{
+  proto::TensorProto object_as_message;
+
+  object_as_message.ParseFromString(bytes);
+
+  return deserializeFromMessage(object_as_message);
+}
+
+} // namespace mir
+} // namespace nnc
similarity index 60%
rename from contrib/nnc/core/serialize/Serializer.cpp
rename to contrib/nnc/core/modelIR/Serializer.cpp
index 41453dd..b61b95c 100644 (file)
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "core/serialize/Serializer.h"
+#include "core/modelIR/Serializer.h"
 #include "model_ir.pb.h"
 
 #include "core/modelIR/ShapeRange.h"
@@ -44,55 +44,53 @@ std::string Serializer<Shape>::getSerializedObject (const Shape& shape)
   return shapeProto.SerializeAsString();
 }
 
-void setShapeToTensorProto(proto::TensorProto& tensorProto, const Shape& shape)
+static void setShapeToTensorProto(proto::TensorProto& tensor_proto, const Shape& shape)
 {
-  Serializer<Shape> shapeSerializer;
-  tensorProto.mutable_shape()->ParseFromString( shapeSerializer.getSerializedObject( shape ) );
+  Serializer<Shape> shape_serializer;
+  tensor_proto.mutable_shape()->ParseFromString( shape_serializer.getSerializedObject( shape ) );
 }
 
-template <>
-std::string Serializer<Tensor<int> >::getSerializedObject (const Tensor<int>& tensor)
+template <typename T, proto::DataType dtype>
+static proto::TensorProto serializeTensorContent(const Tensor<T>& tensor)
 {
-  proto::TensorProto tensorProto;
-  setShapeToTensorProto(tensorProto, tensor.getShape());
+  proto::TensorProto tp;
+  setShapeToTensorProto(tp, tensor.getShape());
+
+  tp.set_dtype(dtype);
 
-  tensorProto.set_dtype(proto::DT_INT32);
+  size_t data_size = num_elements(tensor.getShape());
+  auto tensor_data = new T[data_size];
+  size_t i = 0;
   ShapeRange shapeRange(tensor.getShape());
   for (auto& idx : shapeRange) {
-    tensorProto.add_int_val(tensor.at(idx));
+    tensor_data[i] = tensor.at(idx);
+    i++;
   }
 
-  return tensorProto.SerializeAsString();
+  size_t raw_data_size = data_size * sizeof(T);
+  std::string raw_data((char*) tensor_data, raw_data_size);
+  delete[] tensor_data;
+  tp.set_tensor_content(raw_data);
+
+  return tp;
 }
 
 template <>
-std::string Serializer<Tensor<float> >::getSerializedObject (const Tensor<float>& tensor)
+std::string Serializer<Tensor<int> >::getSerializedObject (const Tensor<int>& tensor)
 {
-  proto::TensorProto tensorProto;
-  setShapeToTensorProto(tensorProto, tensor.getShape());
-
-  tensorProto.set_dtype(proto::DataType::DT_FLOAT);
-  ShapeRange shapeRange(tensor.getShape());
-  for (auto& idx : shapeRange) {
-    tensorProto.add_float_val(tensor.at(idx));
-  }
+  return serializeTensorContent<int, proto::DT_INT32>(tensor).SerializeAsString();
+}
 
-  return tensorProto.SerializeAsString();
+template <>
+std::string Serializer<Tensor<float> >::getSerializedObject (const Tensor<float>& tensor)
+{
+  return serializeTensorContent<float, proto::DT_FLOAT>(tensor).SerializeAsString();
 }
 
 template <>
 std::string Serializer<Tensor<double> >::getSerializedObject (const Tensor<double>& tensor)
 {
-  proto::TensorProto tensorProto;
-  setShapeToTensorProto(tensorProto, tensor.getShape());
-
-  tensorProto.set_dtype(proto::DataType::DT_DOUBLE);
-  ShapeRange shapeRange(tensor.getShape());
-  for (auto& idx : shapeRange) {
-    tensorProto.add_double_val(tensor.at(idx));
-  }
-
-  return tensorProto.SerializeAsString();
+  return serializeTensorContent<double, proto::DT_DOUBLE>(tensor).SerializeAsString();
 }
 
 } // namespace mir
@@ -61,18 +61,8 @@ message TensorProto {
     // Tensor name
     optional string name = 3;
 
-    // Tensor data. Not using oneof to avoid writing messages for each data type.
-    // TODO: consider using raw tensor content
-    //bytes tensor_content = 4;
-
-    // DT_FLOAT.
-    repeated float float_val = 5 [packed = true];
-
-    // DT_DOUBLE.
-    repeated double double_val = 6 [packed = true];
-
-    // DT_INT32.
-    repeated int32 int_val = 7 [packed = true];
+    // Raw tensor data stored as string of bytes
+    optional bytes tensor_content = 4;
 };
 
 
diff --git a/contrib/nnc/core/serialize/Deserializer.cpp b/contrib/nnc/core/serialize/Deserializer.cpp
deleted file mode 100644 (file)
index 1a21fa1..0000000
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "core/serialize/Deserializer.h"
-#include "model_ir.pb.h"
-
-#include "core/modelIR/ShapeRange.h"
-
-namespace nnc
-{
-namespace mir
-{
-
-//
-// Shape Deserialization
-//
-
-Shape deserializeFromMessage(const proto::TensorShapeProto& objectAsMessage)
-{
-  Shape res;
-  auto rank = (uint32_t) objectAsMessage.dims_size();
-  res.resize((uint32_t) rank);
-  for (uint32_t i = 0; i < rank; i++)
-  {
-    res.dim(i) = (uint32_t) objectAsMessage.dims(i);
-  }
-
-  return res;
-}
-
-template <>
-Shape Deserializer<Shape>::deserializeFromStream (std::istream& stream)
-{
-  proto::TensorShapeProto objectAsMessage;
-
-  objectAsMessage.ParseFromIstream(&stream);
-
-  return deserializeFromMessage(objectAsMessage);
-}
-
-template <>
-Shape Deserializer<Shape>::deserializeFromString (const std::string& bytes)
-{
-  proto::TensorShapeProto objectAsMessage;
-
-  objectAsMessage.ParseFromString(bytes);
-
-  return deserializeFromMessage(objectAsMessage);
-}
-
-//
-// Tensor Deserialization
-//
-
-TensorVariant deserializeFromMessage(const proto::TensorProto& objectAsMessage)
-{
-  Shape shape = deserializeFromMessage(objectAsMessage.shape());
-
-  size_t data_size;
-  proto::DataType dt = objectAsMessage.dtype();
-  switch (dt)
-  {
-  case proto::DataType::DT_INT32:
-    {
-      data_size = objectAsMessage.int_val_size();
-      assert(data_size == num_elements(shape));
-
-      auto raw_data = new int32_t[data_size];
-      for (size_t i = 0; i < data_size; i++) {
-        raw_data[i] = objectAsMessage.int_val( (int) i);
-      }
-
-      std::shared_ptr<int32_t> data(raw_data, std::default_delete<int32_t[]>());
-      return TensorVariant(shape, data, TensorVariant::DTYPE::INT);
-    }
-  case proto::DataType::DT_FLOAT:
-    {
-      data_size = objectAsMessage.float_val_size();
-      assert(data_size == num_elements(shape));
-
-      auto raw_data = new float[data_size];
-      for (size_t i = 0; i < data_size; i++)
-      {
-        raw_data[i] = objectAsMessage.float_val( (int) i);
-      }
-
-      std::shared_ptr<float> data(raw_data, std::default_delete<float[]>());
-      return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
-    }
-  case proto::DataType::DT_DOUBLE :
-    {
-      data_size = objectAsMessage.double_val_size();
-      assert(data_size == num_elements(shape));
-
-      auto raw_data = new double[data_size];
-      for (size_t i = 0; i < data_size; i++) {
-        raw_data[i] = objectAsMessage.double_val( (int) i);
-      }
-
-      std::shared_ptr<double> data(raw_data, std::default_delete<double[]>());
-      return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
-    }
-  default:
-    {
-      throw std::logic_error("Deserializer: received unsupported data type");
-    }
-  }
-}
-
-template <>
-TensorVariant Deserializer<TensorVariant>::deserializeFromStream (std::istream& stream)
-{
-  proto::TensorProto objectAsMessage;
-
-  objectAsMessage.ParseFromIstream(&stream);
-
-  return deserializeFromMessage(objectAsMessage);
-}
-
-template <>
-TensorVariant Deserializer<TensorVariant>::deserializeFromString (const std::string& bytes)
-{
-  proto::TensorProto objectAsMessage;
-
-  objectAsMessage.ParseFromString(bytes);
-
-  return deserializeFromMessage(objectAsMessage);
-}
-
-} // namespace mir
-} // namespace nnc
index 8e4cde6..f4421bb 100644 (file)
@@ -16,7 +16,7 @@
 
 #include <gtest/gtest.h>
 
-#include "core/serialize/Deserializer.h"
+#include "core/modelIR/Deserializer.h"
 #include "core/modelIR/ShapeRange.h"
 #include "core/modelIR/Tensor.h"
 
@@ -24,250 +24,257 @@ using namespace nnc::mir;
 
 const double EPS = 0.0000001;
 
-static void checkShape(const Shape& shape, const proto::TensorShapeProto& protoShape)
+static void checkShape(const Shape& shape, const proto::TensorShapeProto& proto_shape)
 {
-  ASSERT_EQ(shape.rank(), protoShape.dims_size());
+  ASSERT_EQ(shape.rank(), proto_shape.dims_size());
   for (int i = 0; i < shape.rank(); i++) {
-    ASSERT_EQ(shape.dim(i), protoShape.dims(i));
+    ASSERT_EQ(shape.dim(i), proto_shape.dims(i));
   }
 }
 
-static void checkIntTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+template <typename T>
+static void checkTensorContent(const Tensor<T>& tensor, const proto::TensorProto& proto_tensor)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_INT32);
-  Tensor<int> tensor(tensorV);
-  Shape shape = tensor.getShape();
-  ShapeRange range(shape);
+  ShapeRange range(tensor.getShape());
+  auto data = (T*) proto_tensor.tensor_content().c_str();
   int i = 0;
   for (auto& idx : range) {
-    ASSERT_EQ(tensor.at(idx), protoTensor.int_val(i++));
+    ASSERT_EQ(tensor.at(idx), data[i++]);
   }
 }
 
-static void checkFloatTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+static void checkIntTensor(const Tensor<int> &tensor, const proto::TensorProto &proto_tensor)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_FLOAT);
-  Tensor<float> tensor(tensorV);
-  ShapeRange range(tensor.getShape());
-  int i = 0;
-  for (auto& idx : range) {
-    ASSERT_NEAR(tensor.at(idx), protoTensor.float_val(i++), EPS);
-  }
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_INT32);
+  checkTensorContent<int>(tensor, proto_tensor);
 }
 
-static void checkDoubleTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+static void checkFloatTensor(const Tensor<float> &tensor, const proto::TensorProto &proto_tensor)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_DOUBLE);
-  Tensor<double> tensor(tensorV);
-  ShapeRange range(tensor.getShape());
-  int i = 0;
-  for (auto& idx : range) {
-    ASSERT_NEAR(tensor.at(idx), protoTensor.double_val(i++), EPS);
-  }
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_FLOAT);
+  checkTensorContent<float>(tensor, proto_tensor);
+}
+
+static void checkDoubleTensor(const Tensor<double> &tensor, const proto::TensorProto &proto_tensor)
+{
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_DOUBLE);
+  checkTensorContent<double>(tensor, proto_tensor);
 }
 
 
 TEST(Deserializer, ShapeDeserializationTest) {
   Deserializer<Shape> deserializer;
 
-  proto::TensorShapeProto protoShape;
+  proto::TensorShapeProto proto_shape;
   std::string serializedShape;
   Shape shape;
-  protoShape.SerializeToString(&serializedShape);
+  proto_shape.SerializeToString(&serializedShape);
   shape = deserializer.deserializeFromString(serializedShape);
-  checkShape(shape, protoShape);
+  checkShape(shape, proto_shape);
 
-  protoShape.add_dims(5);
-  protoShape.SerializeToString(&serializedShape);
+  proto_shape.add_dims(5);
+  proto_shape.SerializeToString(&serializedShape);
   shape = deserializer.deserializeFromString(serializedShape);
-  checkShape(shape, protoShape);
+  checkShape(shape, proto_shape);
 
-  protoShape.add_dims(2);
-  protoShape.add_dims(4);
-  protoShape.SerializeToString(&serializedShape);
+  proto_shape.add_dims(2);
+  proto_shape.add_dims(4);
+  proto_shape.SerializeToString(&serializedShape);
   shape = deserializer.deserializeFromString(serializedShape);
-  checkShape(shape, protoShape);
+  checkShape(shape, proto_shape);
 
-  protoShape.clear_dims();
-  protoShape.add_dims(1);
-  protoShape.add_dims(1);
-  protoShape.add_dims(1);
-  protoShape.add_dims(1);
-  protoShape.SerializeToString(&serializedShape);
+  proto_shape.clear_dims();
+  proto_shape.add_dims(1);
+  proto_shape.add_dims(1);
+  proto_shape.add_dims(1);
+  proto_shape.add_dims(1);
+  proto_shape.SerializeToString(&serializedShape);
   shape = deserializer.deserializeFromString(serializedShape);
-  checkShape(shape, protoShape);
+  checkShape(shape, proto_shape);
 }
 
 TEST(Deserializer, IntTensorDeserializationTest) {
   Deserializer<TensorVariant> deserializer;
   int tmp = 0;
 
-  proto::TensorProto protoTensor;
-  protoTensor.set_dtype(proto::DataType::DT_INT32);
-  proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+  std::vector<int> values;
+  proto::TensorProto proto_tensor;
+  proto_tensor.set_dtype(proto::DataType::DT_INT32);
+  proto::TensorShapeProto* proto_shapePtr = proto_tensor.mutable_shape();
   std::string serializedTensor;
 
   Shape shape_1{3};
   for (auto& idx : ShapeRange(shape_1))
   {
-    protoTensor.add_int_val(tmp++);
+    values.push_back(tmp++);
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(int) * num_elements(shape_1)));
   for (uint32_t i = 0; i < shape_1.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_1.dim(i));
+    proto_shapePtr->add_dims(shape_1.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_1, protoTensor.shape());
-  checkIntTensorContent(tensor_1, protoTensor);
+  checkShape(shape_1, proto_tensor.shape());
+  checkIntTensor(Tensor<int>(tensor_1), proto_tensor);
 
   Shape shape_2{3, 4, 5};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_int_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_2))
   {
-    protoTensor.add_int_val(tmp--);
+    values.push_back(tmp--);
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(int) * num_elements(shape_2)));
   for (uint32_t i = 0; i < shape_2.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_2.dim(i));
+    proto_shapePtr->add_dims(shape_2.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_2, protoTensor.shape());
-  checkIntTensorContent(tensor_2, protoTensor);
+  checkShape(shape_2, proto_tensor.shape());
+  checkIntTensor(Tensor<int>(tensor_2), proto_tensor);
 
   Shape shape_3{1, 1, 1, 1, 1};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_int_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_3))
   {
-    protoTensor.add_int_val(tmp++);
+    values.push_back(tmp++);
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(int) * num_elements(shape_3)));
   for (uint32_t i = 0; i < shape_3.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_3.dim(i));
+    proto_shapePtr->add_dims(shape_3.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_3, protoTensor.shape());
-  checkIntTensorContent(tensor_3, protoTensor);
+  checkShape(shape_3, proto_tensor.shape());
+  checkIntTensor(Tensor<int>(tensor_3), proto_tensor);
 }
 
 TEST(Deserializer, FloatTensorDeserializationTest) {
   Deserializer<TensorVariant> deserializer;
   float tmp = 1.0f;
 
-  proto::TensorProto protoTensor;
-  protoTensor.set_dtype(proto::DataType::DT_FLOAT);
-  proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+  std::vector<float> values;
+  proto::TensorProto proto_tensor;
+  proto_tensor.set_dtype(proto::DataType::DT_FLOAT);
+  proto::TensorShapeProto* proto_shapePtr = proto_tensor.mutable_shape();
   std::string serializedTensor;
 
   Shape shape_1{3};
   for (auto& idx : ShapeRange(shape_1))
   {
-    protoTensor.add_float_val(tmp);
+    values.push_back(tmp);
     tmp += 7.3f;
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(float) * num_elements(shape_1)));
   for (uint32_t i = 0; i < shape_1.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_1.dim(i));
+    proto_shapePtr->add_dims(shape_1.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_1, protoTensor.shape());
-  checkFloatTensorContent(tensor_1, protoTensor);
+  checkShape(shape_1, proto_tensor.shape());
+  checkFloatTensor(Tensor<float>(tensor_1), proto_tensor);
 
   Shape shape_2{3, 4, 5};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_float_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_2))
   {
-    protoTensor.add_float_val(tmp);
+    values.push_back(tmp);
     tmp *= -1.32f;
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(float) * num_elements(shape_2)));
   for (uint32_t i = 0; i < shape_2.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_2.dim(i));
+    proto_shapePtr->add_dims(shape_2.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_2, protoTensor.shape());
-  checkFloatTensorContent(tensor_2, protoTensor);
+  checkShape(shape_2, proto_tensor.shape());
+  checkFloatTensor(Tensor<float>(tensor_2), proto_tensor);
 
   Shape shape_3{1, 1, 1, 1, 1};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_float_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_3))
   {
     tmp /= 2;
-    protoTensor.add_float_val(tmp);
+    values.push_back(tmp);
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(float) * num_elements(shape_3)));
   for (uint32_t i = 0; i < shape_3.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_3.dim(i));
+    proto_shapePtr->add_dims(shape_3.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_3, protoTensor.shape());
-  checkFloatTensorContent(tensor_3, protoTensor);
+  checkShape(shape_3, proto_tensor.shape());
+  checkFloatTensor(Tensor<float>(tensor_3), proto_tensor);
 }
 
 TEST(Deserializer, DoubleTensorDeserializationTest) {
   Deserializer<TensorVariant> deserializer;
   double tmp = 1.0f;
 
-  proto::TensorProto protoTensor;
-  protoTensor.set_dtype(proto::DataType::DT_DOUBLE);
-  proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+  std::vector<double> values;
+  proto::TensorProto proto_tensor;
+  proto_tensor.set_dtype(proto::DataType::DT_DOUBLE);
+  proto::TensorShapeProto* proto_shapePtr = proto_tensor.mutable_shape();
   std::string serializedTensor;
 
   Shape shape_1{3};
   for (auto& idx : ShapeRange(shape_1))
   {
-    protoTensor.add_double_val(tmp);
+    values.push_back(tmp);
     tmp += 7.3f;
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(double) * num_elements(shape_1)));
   for (uint32_t i = 0; i < shape_1.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_1.dim(i));
+    proto_shapePtr->add_dims(shape_1.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_1, protoTensor.shape());
-  checkDoubleTensorContent(tensor_1, protoTensor);
+  checkShape(shape_1, proto_tensor.shape());
+  checkDoubleTensor(Tensor<double>(tensor_1), proto_tensor);
 
   Shape shape_2{3, 4, 5};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_double_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_2))
   {
-    protoTensor.add_double_val(tmp);
+    values.push_back(tmp);
     tmp *= -1.32f;
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(double) * num_elements(shape_2)));
   for (uint32_t i = 0; i < shape_2.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_2.dim(i));
+    proto_shapePtr->add_dims(shape_2.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_2, protoTensor.shape());
-  checkDoubleTensorContent(tensor_2, protoTensor);
+  checkShape(shape_2, proto_tensor.shape());
+  checkDoubleTensor(Tensor<double>(tensor_2), proto_tensor);
 
   Shape shape_3{1, 1, 1, 1, 1};
-  protoShapePtr->clear_dims();
-  protoTensor.clear_double_val();
+  values.clear();
+  proto_shapePtr->clear_dims();
   for (auto& idx : ShapeRange(shape_3))
   {
     tmp /= 2;
-    protoTensor.add_double_val(tmp);
+    values.push_back(tmp);
   }
+  proto_tensor.set_tensor_content(std::string((char*) values.data(), sizeof(double) * num_elements(shape_3)));
   for (uint32_t i = 0; i < shape_3.rank(); i++)
   {
-    protoShapePtr->add_dims(shape_3.dim(i));
+    proto_shapePtr->add_dims(shape_3.dim(i));
   }
-  protoTensor.SerializeToString(&serializedTensor);
+  proto_tensor.SerializeToString(&serializedTensor);
   TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
-  checkShape(shape_3, protoTensor.shape());
-  checkDoubleTensorContent(tensor_3, protoTensor);
+  checkShape(shape_3, proto_tensor.shape());
+  checkDoubleTensor(Tensor<double>(tensor_3), proto_tensor);
 }
index bbba716..5c2655c 100644 (file)
 #include <gtest/gtest.h>
 #include <cmath>
 
-#include "core/serialize/Serializer.h"
+#include "core/modelIR/Serializer.h"
 #include "core/modelIR/ShapeRange.h"
 
 using namespace nnc::mir;
 
 const double EPS = 0.0000001;
 
-static void checkShape(const Shape& shape, const proto::TensorShapeProto& protoShape)
+static void checkShape(const Shape& shape, const proto::TensorShapeProto& proto_shape)
 {
-  ASSERT_EQ(shape.rank(), protoShape.dims_size());
+  ASSERT_EQ(shape.rank(), proto_shape.dims_size());
   for (int i = 0; i < shape.rank(); i++) {
-    ASSERT_EQ(shape.dim(i), protoShape.dims(i));
+    ASSERT_EQ(shape.dim(i), proto_shape.dims(i));
   }
 }
 
-static TensorVariant allocateIntTensor(const Shape &shape)
-{
-  size_t data_size = 1;
-  for (uint32_t i = 0; i < shape.rank(); ++i)
-  {
-    data_size *= shape.dim(i);
-  }
-
-  auto od = new int[data_size];
-
-  std::shared_ptr<int> data(od, std::default_delete<int>());
-  TensorVariant t(shape, data, TensorVariant::DTYPE::INT);
-
-  return t;
-}
-
-static void checkIntTensorContent(const Tensor<int>& tensor, const proto::TensorProto& protoTensor)
+template <typename T>
+static void checkTensorContent(const Tensor<T>& tensor, const proto::TensorProto& proto_tensor)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_INT32);
   ShapeRange range(tensor.getShape());
+  auto data = (T*) proto_tensor.tensor_content().c_str();
   int i = 0;
   for (auto& idx : range) {
-    ASSERT_EQ(tensor.at(idx), protoTensor.int_val(i++));
+    ASSERT_EQ(tensor.at(idx), data[i++]);
   }
 }
 
-static TensorVariant allocateFloatTensor(const Shape &shape)
+template <typename T>
+static std::shared_ptr<T> allocateTensorContent(const Shape &shape)
 {
   size_t data_size = 1;
   for (uint32_t i = 0; i < shape.rank(); ++i)
@@ -66,48 +52,47 @@ static TensorVariant allocateFloatTensor(const Shape &shape)
     data_size *= shape.dim(i);
   }
 
-  auto od = new float[data_size];
+  auto od = new T[data_size];
 
-  std::shared_ptr<float> data(od, std::default_delete<float>());
-  TensorVariant t(shape, data, TensorVariant::DTYPE::FLOAT);
+  std::shared_ptr<T> data(od, std::default_delete<T[]>());
 
-  return t;
+  return data;
 }
 
-static void checkFloatTensorContent(const Tensor<float>& tensor, const proto::TensorProto& protoTensor)
+static TensorVariant allocateIntTensor(const Shape &shape)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_FLOAT);
-  ShapeRange range(tensor.getShape());
-  int i = 0;
-  for (auto& idx : range) {
-    ASSERT_TRUE(fabsf(tensor.at(idx) - protoTensor.float_val(i++)) < EPS);
-  }
+  std::shared_ptr<int> data = allocateTensorContent<int>(shape);
+  return TensorVariant(shape, data, TensorVariant::DTYPE::INT);
 }
 
-static TensorVariant allocateDoubleTensor(const Shape &shape)
+static void checkIntTensor(const Tensor<int>& tensor, const proto::TensorProto& proto_tensor)
 {
-  size_t data_size = 1;
-  for (uint32_t i = 0; i < shape.rank(); ++i)
-  {
-    data_size *= shape.dim(i);
-  }
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_INT32);
+  checkTensorContent<int>(tensor, proto_tensor);
+}
 
-  auto od = new double[data_size];
+static TensorVariant allocateFloatTensor(const Shape &shape)
+{
+  std::shared_ptr<float> data = allocateTensorContent<float>(shape);
+  return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
+}
 
-  std::shared_ptr<double> data(od, std::default_delete<double>());
-  TensorVariant t(shape, data, TensorVariant::DTYPE::FLOAT);
+static void checkFloatTensor(const Tensor<float>& tensor, const proto::TensorProto& proto_tensor)
+{
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_FLOAT);
+  checkTensorContent<float>(tensor, proto_tensor);
+}
 
-  return t;
+static TensorVariant allocateDoubleTensor(const Shape &shape)
+{
+  std::shared_ptr<double> data = allocateTensorContent<double>(shape);
+  return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
 }
 
-static void checkDoubleTensorContent(const Tensor<double>& tensor, const proto::TensorProto& protoTensor)
+static void checkDoubleTensor(const Tensor<double>& tensor, const proto::TensorProto& proto_tensor)
 {
-  ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_DOUBLE);
-  ShapeRange range(tensor.getShape());
-  int i = 0;
-  for (auto& idx : range) {
-    ASSERT_TRUE(fabs(tensor.at(idx) - protoTensor.double_val(i++)) < EPS);
-  }
+  ASSERT_EQ(proto_tensor.dtype(), proto::DataType::DT_DOUBLE);
+  checkTensorContent<double>(tensor, proto_tensor);
 }
 
 
@@ -116,39 +101,39 @@ TEST(Serializer, ShapeSerializationTest) {
 
   Shape shape_0{};
   std::string serializedShape_0 = serializer.getSerializedObject(shape_0);
-  proto::TensorShapeProto protoShape_0;
-  protoShape_0.ParseFromString(serializedShape_0);
-  checkShape(shape_0, protoShape_0);
+  proto::TensorShapeProto proto_shape_0;
+  proto_shape_0.ParseFromString(serializedShape_0);
+  checkShape(shape_0, proto_shape_0);
 
   Shape shape_1{1};
   std::string serializedShape_1 = serializer.getSerializedObject(shape_1);
-  proto::TensorShapeProto protoShape_1;
-  protoShape_1.ParseFromString(serializedShape_1);
-  checkShape(shape_1, protoShape_1);
+  proto::TensorShapeProto proto_shape_1;
+  proto_shape_1.ParseFromString(serializedShape_1);
+  checkShape(shape_1, proto_shape_1);
 
   Shape shape_2{5};
   std::string serializedShape_2 = serializer.getSerializedObject(shape_2);
-  proto::TensorShapeProto protoShape_2;
-  protoShape_2.ParseFromString(serializedShape_2);
-  checkShape(shape_2, protoShape_2);
+  proto::TensorShapeProto proto_shape_2;
+  proto_shape_2.ParseFromString(serializedShape_2);
+  checkShape(shape_2, proto_shape_2);
 
   Shape shape_3{2, 4};
   std::string serializedShape_3 = serializer.getSerializedObject(shape_3);
-  proto::TensorShapeProto protoShape_3;
-  protoShape_3.ParseFromString(serializedShape_3);
-  checkShape(shape_3, protoShape_3);
+  proto::TensorShapeProto proto_shape_3;
+  proto_shape_3.ParseFromString(serializedShape_3);
+  checkShape(shape_3, proto_shape_3);
 
   Shape shape_4{1, 1, 1, 1};
   std::string serializedShape_4 = serializer.getSerializedObject(shape_4);
-  proto::TensorShapeProto protoShape_4;
-  protoShape_4.ParseFromString(serializedShape_4);
-  checkShape(shape_4, protoShape_4);
+  proto::TensorShapeProto proto_shape_4;
+  proto_shape_4.ParseFromString(serializedShape_4);
+  checkShape(shape_4, proto_shape_4);
 
   Shape shape_5{1, 2, 3, 4, 5};
   std::string serializedShape_5 = serializer.getSerializedObject(shape_5);
-  proto::TensorShapeProto protoShape_5;
-  protoShape_5.ParseFromString(serializedShape_5);
-  checkShape(shape_5, protoShape_5);
+  proto::TensorShapeProto proto_shape_5;
+  proto_shape_5.ParseFromString(serializedShape_5);
+  checkShape(shape_5, proto_shape_5);
 }
 
 TEST(Serializer, IntTensorSerializationTest) {
@@ -162,10 +147,10 @@ TEST(Serializer, IntTensorSerializationTest) {
     tensor_1.at(idx) = tmp++;
   }
   std::string serializedTensor_1 = serializer.getSerializedObject(tensor_1);
-  proto::TensorProto protoTensor_1;
-  protoTensor_1.ParseFromString(serializedTensor_1);
-  checkShape(shape_1, protoTensor_1.shape());
-  checkIntTensorContent(tensor_1, protoTensor_1);
+  proto::TensorProto proto_tensor_1;
+  proto_tensor_1.ParseFromString(serializedTensor_1);
+  checkShape(shape_1, proto_tensor_1.shape());
+  checkIntTensor(tensor_1, proto_tensor_1);
 
   Shape shape_2{3, 4, 5};
   TensorVariant tv_2(allocateIntTensor(shape_2));
@@ -174,10 +159,10 @@ TEST(Serializer, IntTensorSerializationTest) {
     tensor_2.at(idx) = tmp--;
   }
   std::string serializedTensor_2 = serializer.getSerializedObject(tensor_2);
-  proto::TensorProto protoTensor_2;
-  protoTensor_2.ParseFromString(serializedTensor_2);
-  checkShape(shape_2, protoTensor_2.shape());
-  checkIntTensorContent(tensor_2, protoTensor_2);
+  proto::TensorProto proto_tensor_2;
+  proto_tensor_2.ParseFromString(serializedTensor_2);
+  checkShape(shape_2, proto_tensor_2.shape());
+  checkIntTensor(tensor_2, proto_tensor_2);
 
   Shape shape_3{1, 1, 1, 1, 1};
   TensorVariant tv_3(allocateIntTensor(shape_3));
@@ -186,10 +171,10 @@ TEST(Serializer, IntTensorSerializationTest) {
     tensor_3.at(idx) = tmp++;
   }
   std::string serializedTensor_3 = serializer.getSerializedObject(tensor_3);
-  proto::TensorProto protoTensor_3;
-  protoTensor_3.ParseFromString(serializedTensor_3);
-  checkShape(shape_3, protoTensor_3.shape());
-  checkIntTensorContent(tensor_3, protoTensor_3);
+  proto::TensorProto proto_tensor_3;
+  proto_tensor_3.ParseFromString(serializedTensor_3);
+  checkShape(shape_3, proto_tensor_3.shape());
+  checkIntTensor(tensor_3, proto_tensor_3);
 }
 
 TEST(Serializer, FloatTensorSerializationTest) {
@@ -204,10 +189,10 @@ TEST(Serializer, FloatTensorSerializationTest) {
     tmp += 10.3f;
   }
   std::string serializedTensor_1 = serializer.getSerializedObject(tensor_1);
-  proto::TensorProto protoTensor_1;
-  protoTensor_1.ParseFromString(serializedTensor_1);
-  checkShape(shape_1, protoTensor_1.shape());
-  checkFloatTensorContent(tensor_1, protoTensor_1);
+  proto::TensorProto proto_tensor_1;
+  proto_tensor_1.ParseFromString(serializedTensor_1);
+  checkShape(shape_1, proto_tensor_1.shape());
+  checkFloatTensor(tensor_1, proto_tensor_1);
 
   Shape shape_2{3, 4, 5};
   TensorVariant tv_2(allocateFloatTensor(shape_2));
@@ -217,10 +202,10 @@ TEST(Serializer, FloatTensorSerializationTest) {
     tmp *= -1.21f;
   }
   std::string serializedTensor_2 = serializer.getSerializedObject(tensor_2);
-  proto::TensorProto protoTensor_2;
-  protoTensor_2.ParseFromString(serializedTensor_2);
-  checkShape(shape_2, protoTensor_2.shape());
-  checkFloatTensorContent(tensor_2, protoTensor_2);
+  proto::TensorProto proto_tensor_2;
+  proto_tensor_2.ParseFromString(serializedTensor_2);
+  checkShape(shape_2, proto_tensor_2.shape());
+  checkFloatTensor(tensor_2, proto_tensor_2);
 
   Shape shape_3{1, 1, 1, 1, 1};
   TensorVariant tv_3(allocateFloatTensor(shape_3));
@@ -230,10 +215,10 @@ TEST(Serializer, FloatTensorSerializationTest) {
     tensor_3.at(idx) = tmp;
   }
   std::string serializedTensor_3 = serializer.getSerializedObject(tensor_3);
-  proto::TensorProto protoTensor_3;
-  protoTensor_3.ParseFromString(serializedTensor_3);
-  checkShape(shape_3, protoTensor_3.shape());
-  checkFloatTensorContent(tensor_3, protoTensor_3);
+  proto::TensorProto proto_tensor_3;
+  proto_tensor_3.ParseFromString(serializedTensor_3);
+  checkShape(shape_3, proto_tensor_3.shape());
+  checkFloatTensor(tensor_3, proto_tensor_3);
 }
 
 TEST(Serializer, DoubleTensorSerializationTest) {
@@ -248,10 +233,10 @@ TEST(Serializer, DoubleTensorSerializationTest) {
     tmp += 10.3f;
   }
   std::string serializedTensor_1 = serializer.getSerializedObject(tensor_1);
-  proto::TensorProto protoTensor_1;
-  protoTensor_1.ParseFromString(serializedTensor_1);
-  checkShape(shape_1, protoTensor_1.shape());
-  checkDoubleTensorContent(tensor_1, protoTensor_1);
+  proto::TensorProto proto_tensor_1;
+  proto_tensor_1.ParseFromString(serializedTensor_1);
+  checkShape(shape_1, proto_tensor_1.shape());
+  checkDoubleTensor(tensor_1, proto_tensor_1);
 
   Shape shape_2{3, 4, 5};
   TensorVariant tv_2(allocateDoubleTensor(shape_2));
@@ -261,10 +246,10 @@ TEST(Serializer, DoubleTensorSerializationTest) {
     tmp *= -1.21f;
   }
   std::string serializedTensor_2 = serializer.getSerializedObject(tensor_2);
-  proto::TensorProto protoTensor_2;
-  protoTensor_2.ParseFromString(serializedTensor_2);
-  checkShape(shape_2, protoTensor_2.shape());
-  checkDoubleTensorContent(tensor_2, protoTensor_2);
+  proto::TensorProto proto_tensor_2;
+  proto_tensor_2.ParseFromString(serializedTensor_2);
+  checkShape(shape_2, proto_tensor_2.shape());
+  checkDoubleTensor(tensor_2, proto_tensor_2);
 
   Shape shape_3{1, 1, 1, 1, 1};
   TensorVariant tv_3(allocateDoubleTensor(shape_3));
@@ -274,8 +259,8 @@ TEST(Serializer, DoubleTensorSerializationTest) {
     tensor_3.at(idx) = tmp;
   }
   std::string serializedTensor_3 = serializer.getSerializedObject(tensor_3);
-  proto::TensorProto protoTensor_3;
-  protoTensor_3.ParseFromString(serializedTensor_3);
-  checkShape(shape_3, protoTensor_3.shape());
-  checkDoubleTensorContent(tensor_3, protoTensor_3);
+  proto::TensorProto proto_tensor_3;
+  proto_tensor_3.ParseFromString(serializedTensor_3);
+  checkShape(shape_3, proto_tensor_3.shape());
+  checkDoubleTensor(tensor_3, proto_tensor_3);
 }