--- /dev/null
+#include "core/serialize/Deserializer.h"
+#include "model_ir.pb.h"
+
+#include "core/modelIR/ShapeRange.h"
+
+namespace nncc {
+namespace contrib {
+namespace core {
+
+using namespace nncc::contrib::core::data;
+using namespace nncc::contrib::core::ADT;
+
+//
+// Shape Deserialization
+//
+
+Shape deserializeFromMessage(const proto::TensorShapeProto& objectAsMessage)
+{
+ Shape res;
+ auto rank = (uint32_t) objectAsMessage.dims_size();
+ res.resize((uint32_t) rank);
+ for (uint32_t i = 0; i < rank; i++)
+ {
+ res.dim(i) = (uint32_t) objectAsMessage.dims(i);
+ }
+
+ return res;
+}
+
+template <>
+Shape Deserializer<Shape>::deserializeFromStream (std::istream& stream)
+{
+ proto::TensorShapeProto objectAsMessage;
+
+ objectAsMessage.ParseFromIstream(&stream);
+
+ return deserializeFromMessage(objectAsMessage);
+}
+
+template <>
+Shape Deserializer<Shape>::deserializeFromString (const std::string& bytes)
+{
+ proto::TensorShapeProto objectAsMessage;
+
+ objectAsMessage.ParseFromString(bytes);
+
+ return deserializeFromMessage(objectAsMessage);
+}
+
+//
+// Tensor Deserialization
+//
+
+TensorVariant deserializeFromMessage(const proto::TensorProto& objectAsMessage)
+{
+ Shape shape = deserializeFromMessage(objectAsMessage.shape());
+
+ size_t data_size;
+ proto::DataType dt = objectAsMessage.dtype();
+ switch (dt)
+ {
+ case proto::DataType::DT_INT32:
+ {
+ data_size = objectAsMessage.int_val_size();
+ assert(data_size == num_elements(shape));
+
+ auto raw_data = new int32_t[data_size];
+ for (size_t i = 0; i < data_size; i++) {
+ raw_data[i] = objectAsMessage.int_val( (int) i);
+ }
+
+ std::shared_ptr<int32_t> data(raw_data, std::default_delete<int32_t[]>());
+ return TensorVariant(shape, data, TensorVariant::DTYPE::INT);
+ }
+ case proto::DataType::DT_FLOAT:
+ {
+ data_size = objectAsMessage.float_val_size();
+ assert(data_size == num_elements(shape));
+
+ auto raw_data = new float[data_size];
+ for (size_t i = 0; i < data_size; i++)
+ {
+ raw_data[i] = objectAsMessage.float_val( (int) i);
+ }
+
+ std::shared_ptr<float> data(raw_data, std::default_delete<float[]>());
+ return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
+ }
+ case proto::DataType::DT_DOUBLE :
+ {
+ data_size = objectAsMessage.double_val_size();
+ assert(data_size == num_elements(shape));
+
+ auto raw_data = new double[data_size];
+ for (size_t i = 0; i < data_size; i++) {
+ raw_data[i] = objectAsMessage.double_val( (int) i);
+ }
+
+ std::shared_ptr<double> data(raw_data, std::default_delete<double[]>());
+ return TensorVariant(shape, data, TensorVariant::DTYPE::FLOAT);
+ }
+ default:
+ {
+ throw std::logic_error("Deserializer: received unsupported data type");
+ }
+ }
+}
+
+template <>
+TensorVariant Deserializer<TensorVariant>::deserializeFromStream (std::istream& stream)
+{
+ proto::TensorProto objectAsMessage;
+
+ objectAsMessage.ParseFromIstream(&stream);
+
+ return deserializeFromMessage(objectAsMessage);
+}
+
+template <>
+TensorVariant Deserializer<TensorVariant>::deserializeFromString (const std::string& bytes)
+{
+ proto::TensorProto objectAsMessage;
+
+ objectAsMessage.ParseFromString(bytes);
+
+ return deserializeFromMessage(objectAsMessage);
+}
+
+} // namespace core
+} // namespace contrib
+} // namespace nncc
--- /dev/null
+#include <gtest/gtest.h>
+
+#include "core/serialize/Deserializer.h"
+#include "core/modelIR/ShapeRange.h"
+#include "core/modelIR/Tensor.h"
+
+using namespace nncc::contrib::core;
+using namespace nncc::contrib::core::data;
+using namespace nncc::contrib::core::ADT;
+
+const double EPS = 0.0000001;
+
+static void checkShape(const Shape& shape, const proto::TensorShapeProto& protoShape)
+{
+ ASSERT_EQ(shape.rank(), protoShape.dims_size());
+ for (int i = 0; i < shape.rank(); i++) {
+ ASSERT_EQ(shape.dim(i), protoShape.dims(i));
+ }
+}
+
+static void checkIntTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+{
+ ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_INT32);
+ Tensor<int> tensor(tensorV);
+ Shape shape = tensor.getShape();
+ ShapeRange range(shape);
+ int i = 0;
+ for (auto& idx : range) {
+ ASSERT_EQ(tensor.at(idx), protoTensor.int_val(i++));
+ }
+}
+
+static void checkFloatTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+{
+ ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_FLOAT);
+ Tensor<float> tensor(tensorV);
+ ShapeRange range(tensor.getShape());
+ int i = 0;
+ for (auto& idx : range) {
+ ASSERT_NEAR(tensor.at(idx), protoTensor.float_val(i++), EPS);
+ }
+}
+
+static void checkDoubleTensorContent(const TensorVariant& tensorV, const proto::TensorProto& protoTensor)
+{
+ ASSERT_EQ(protoTensor.dtype(), proto::DataType::DT_DOUBLE);
+ Tensor<double> tensor(tensorV);
+ ShapeRange range(tensor.getShape());
+ int i = 0;
+ for (auto& idx : range) {
+ ASSERT_NEAR(tensor.at(idx), protoTensor.double_val(i++), EPS);
+ }
+}
+
+
+TEST(Deserializer, ShapeDeserializationTest) {
+ Deserializer<Shape> deserializer;
+
+ proto::TensorShapeProto protoShape;
+ std::string serializedShape;
+ Shape shape;
+ protoShape.SerializeToString(&serializedShape);
+ shape = deserializer.deserializeFromString(serializedShape);
+ checkShape(shape, protoShape);
+
+ protoShape.add_dims(5);
+ protoShape.SerializeToString(&serializedShape);
+ shape = deserializer.deserializeFromString(serializedShape);
+ checkShape(shape, protoShape);
+
+ protoShape.add_dims(2);
+ protoShape.add_dims(4);
+ protoShape.SerializeToString(&serializedShape);
+ shape = deserializer.deserializeFromString(serializedShape);
+ checkShape(shape, protoShape);
+
+ protoShape.clear_dims();
+ protoShape.add_dims(1);
+ protoShape.add_dims(1);
+ protoShape.add_dims(1);
+ protoShape.add_dims(1);
+ protoShape.SerializeToString(&serializedShape);
+ shape = deserializer.deserializeFromString(serializedShape);
+ checkShape(shape, protoShape);
+}
+
+TEST(Deserializer, IntTensorDeserializationTest) {
+ Deserializer<TensorVariant> deserializer;
+ int tmp = 0;
+
+ proto::TensorProto protoTensor;
+ protoTensor.set_dtype(proto::DataType::DT_INT32);
+ proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+ std::string serializedTensor;
+
+ Shape shape_1{3};
+ for (auto& idx : ShapeRange(shape_1))
+ {
+ protoTensor.add_int_val(tmp++);
+ }
+ for (uint32_t i = 0; i < shape_1.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_1.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_1, protoTensor.shape());
+ checkIntTensorContent(tensor_1, protoTensor);
+
+ Shape shape_2{3, 4, 5};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_int_val();
+ for (auto& idx : ShapeRange(shape_2))
+ {
+ protoTensor.add_int_val(tmp--);
+ }
+ for (uint32_t i = 0; i < shape_2.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_2.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_2, protoTensor.shape());
+ checkIntTensorContent(tensor_2, protoTensor);
+
+ Shape shape_3{1, 1, 1, 1, 1};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_int_val();
+ for (auto& idx : ShapeRange(shape_3))
+ {
+ protoTensor.add_int_val(tmp++);
+ }
+ for (uint32_t i = 0; i < shape_3.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_3.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_3, protoTensor.shape());
+ checkIntTensorContent(tensor_3, protoTensor);
+}
+
+TEST(Deserializer, FloatTensorDeserializationTest) {
+ Deserializer<TensorVariant> deserializer;
+ float tmp = 1.0f;
+
+ proto::TensorProto protoTensor;
+ protoTensor.set_dtype(proto::DataType::DT_FLOAT);
+ proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+ std::string serializedTensor;
+
+ Shape shape_1{3};
+ for (auto& idx : ShapeRange(shape_1))
+ {
+ protoTensor.add_float_val(tmp);
+ tmp += 7.3f;
+ }
+ for (uint32_t i = 0; i < shape_1.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_1.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_1, protoTensor.shape());
+ checkFloatTensorContent(tensor_1, protoTensor);
+
+ Shape shape_2{3, 4, 5};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_float_val();
+ for (auto& idx : ShapeRange(shape_2))
+ {
+ protoTensor.add_float_val(tmp);
+ tmp *= -1.32f;
+ }
+ for (uint32_t i = 0; i < shape_2.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_2.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_2, protoTensor.shape());
+ checkFloatTensorContent(tensor_2, protoTensor);
+
+ Shape shape_3{1, 1, 1, 1, 1};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_float_val();
+ for (auto& idx : ShapeRange(shape_3))
+ {
+ tmp /= 2;
+ protoTensor.add_float_val(tmp);
+ }
+ for (uint32_t i = 0; i < shape_3.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_3.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_3, protoTensor.shape());
+ checkFloatTensorContent(tensor_3, protoTensor);
+}
+
+TEST(Deserializer, DoubleTensorDeserializationTest) {
+ Deserializer<TensorVariant> deserializer;
+ double tmp = 1.0f;
+
+ proto::TensorProto protoTensor;
+ protoTensor.set_dtype(proto::DataType::DT_DOUBLE);
+ proto::TensorShapeProto* protoShapePtr = protoTensor.mutable_shape();
+ std::string serializedTensor;
+
+ Shape shape_1{3};
+ for (auto& idx : ShapeRange(shape_1))
+ {
+ protoTensor.add_double_val(tmp);
+ tmp += 7.3f;
+ }
+ for (uint32_t i = 0; i < shape_1.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_1.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_1 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_1, protoTensor.shape());
+ checkDoubleTensorContent(tensor_1, protoTensor);
+
+ Shape shape_2{3, 4, 5};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_double_val();
+ for (auto& idx : ShapeRange(shape_2))
+ {
+ protoTensor.add_double_val(tmp);
+ tmp *= -1.32f;
+ }
+ for (uint32_t i = 0; i < shape_2.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_2.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_2 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_2, protoTensor.shape());
+ checkDoubleTensorContent(tensor_2, protoTensor);
+
+ Shape shape_3{1, 1, 1, 1, 1};
+ protoShapePtr->clear_dims();
+ protoTensor.clear_double_val();
+ for (auto& idx : ShapeRange(shape_3))
+ {
+ tmp /= 2;
+ protoTensor.add_double_val(tmp);
+ }
+ for (uint32_t i = 0; i < shape_3.rank(); i++)
+ {
+ protoShapePtr->add_dims(shape_3.dim(i));
+ }
+ protoTensor.SerializeToString(&serializedTensor);
+ TensorVariant tensor_3 = deserializer.deserializeFromString(serializedTensor);
+ checkShape(shape_3, protoTensor.shape());
+ checkDoubleTensorContent(tensor_3, protoTensor);
+}