[mir2loco] Support Const operation (#5794)
authorПавел Ильютченко/AI Tools Lab /SRR/Engineer/삼성전자 <p.iliutchenk@samsung.com>
Thu, 1 Aug 2019 17:01:13 +0000 (20:01 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 1 Aug 2019 17:01:13 +0000 (20:01 +0300)
* Support Const operation transformation
* Append test with Const op containing float values

Signed-off-by: Pavel Iliutchenko <p.iliutchenk@samsung.com>
compiler/mir2loco/src/mir2loco.cpp
compiler/mir2loco/src/mir2loco.test.cpp

index b2031b7..17e9d22 100644 (file)
 #include "mir2loco.h"
 
 #include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
 
 #include <stdex/Memory.h>
 
+#include <cstring>
+
 namespace mir2loco
 {
 namespace
@@ -95,6 +98,26 @@ loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph)
   decode_node->decoder(std::move(dec));
   return decode_node;
 }
+
+loco::DataType DTYPE2DataType(const mir::DTYPE &dtype)
+{
+  switch (dtype)
+  {
+    case mir::DTYPE::UNKNOWN:
+      return loco::DataType::Unknown;
+    case mir::DTYPE::FLOAT32:
+      return loco::DataType::FLOAT32;
+    case mir::DTYPE::FLOAT64:
+      return loco::DataType::FLOAT64;
+    case mir::DTYPE::INT32:
+      return loco::DataType::S32;
+    case mir::DTYPE::INT64:
+      return loco::DataType::S64;
+    default:
+      break;
+  }
+  throw std::runtime_error("Unsupported dtype");
+}
 } // namespace
 
 void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
@@ -135,7 +158,27 @@ void Transformer::visit(mir::ops::ConcatOp &op)
   _mir2loco_map.emplace(&op, last_concat);
 }
 
-void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::ConstantOp &op)
+{
+  auto const_node = _loco_graph->nodes()->create<loco::ConstGen>();
+  // Not set Input
+  // Set Shape
+  const auto &out_shape = op.getOutputShape(0);
+  setupShape(out_shape, const_node);
+  // Copy value
+  const auto &value = op.getValue();
+  const_node->dtype(DTYPE2DataType(value.getDataType()));
+  // TODO Support other data types
+  assert(const_node->dtype() == loco::DataType::FLOAT32);
+  const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+  // TODO Change that when loco support other DataTypeImpl
+  float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+  char *loco_ptr = reinterpret_cast<char *>(&const_float);
+  char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+  std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+  // Add to map
+  _mir2loco_map.emplace(&op, const_node);
+}
 
 void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); }
 
index b274644..dcc85e1 100644 (file)
@@ -17,6 +17,7 @@
 #include "mir2loco.h"
 
 #include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
 #include "mir/ops/PoolOp.h"
 #include "mir/ops/ReluOp.h"
 #include "mir/ops/ReshapeOp.h"
@@ -248,3 +249,30 @@ TEST_F(TestTransformer_mir2loco, Reshape_Test)
   ASSERT_EQ(reshape_node->dim(1), 8);
   ASSERT_EQ(reshape_node->dim(2), 81);
 }
+
+TEST_F(TestTransformer_mir2loco, Const_Float_Test)
+{
+  mir::Graph mir_graph;
+
+  mir::Shape shape = mir::Shape({2, 3});
+  const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409};
+  auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data);
+  auto *constant = mir_graph.create<mir::ops::ConstantOp>("constant", mir_tensor);
+  auto *output = mir_graph.create<mir::ops::OutputOp>("output", constant->getOutput(0));
+
+  mir2loco::Transformer transformer;
+  auto loco_graph = transformer.transform(&mir_graph);
+  loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_graph->nodes()->at(0));
+  loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(1));
+
+  ASSERT_NE(const_node, nullptr);
+  ASSERT_NE(push_node, nullptr);
+  ASSERT_EQ(push_node->from(), const_node);
+  // Shape check
+  ASSERT_EQ(const_node->rank(), 2);
+  ASSERT_EQ(const_node->dim(0), 2);
+  ASSERT_EQ(const_node->dim(1), 3);
+
+  for (int i = 0; i < 6; i++)
+    ASSERT_FLOAT_EQ(const_node->at<loco::DataType::FLOAT32>(i), data[i]);
+}