From d53495a7397d71c17550db9666070ce3cbf85892 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?=
=?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?=
=?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Thu, 1 Aug 2019 20:01:13 +0300
Subject: [PATCH] [mir2loco] Support Const operation (#5794)
* Support Const operation transformation
* Append test with Const op containing float values
Signed-off-by: Pavel Iliutchenko
---
compiler/mir2loco/src/mir2loco.cpp | 45 ++++++++++++++++++++++++++++++++-
compiler/mir2loco/src/mir2loco.test.cpp | 28 ++++++++++++++++++++
2 files changed, 72 insertions(+), 1 deletion(-)
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
index b2031b7..17e9d22 100644
--- a/compiler/mir2loco/src/mir2loco.cpp
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -17,12 +17,15 @@
#include "mir2loco.h"
#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
#include
+#include
+
namespace mir2loco
{
namespace
@@ -95,6 +98,26 @@ loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph)
decode_node->decoder(std::move(dec));
return decode_node;
}
+
+loco::DataType DTYPE2DataType(const mir::DTYPE &dtype)
+{
+ switch (dtype)
+ {
+ case mir::DTYPE::UNKNOWN:
+ return loco::DataType::Unknown;
+ case mir::DTYPE::FLOAT32:
+ return loco::DataType::FLOAT32;
+ case mir::DTYPE::FLOAT64:
+ return loco::DataType::FLOAT64;
+ case mir::DTYPE::INT32:
+ return loco::DataType::S32;
+ case mir::DTYPE::INT64:
+ return loco::DataType::S64;
+ default:
+ break;
+ }
+ throw std::runtime_error("Unsupported dtype");
+}
} // namespace
void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
@@ -135,7 +158,27 @@ void Transformer::visit(mir::ops::ConcatOp &op)
_mir2loco_map.emplace(&op, last_concat);
}
-void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::ConstantOp &op)
+{
+ auto const_node = _loco_graph->nodes()->create();
+ // Not set Input
+ // Set Shape
+ const auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, const_node);
+ // Copy value
+ const auto &value = op.getValue();
+ const_node->dtype(DTYPE2DataType(value.getDataType()));
+ // TODO Support other data types
+ assert(const_node->dtype() == loco::DataType::FLOAT32);
+ const_node->size(out_shape.numElements());
+ // TODO Change that when loco support other DataTypeImpl
+ float &const_float = const_node->at(0);
+ char *loco_ptr = reinterpret_cast(&const_float);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+ // Add to map
+ _mir2loco_map.emplace(&op, const_node);
+}
void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); }
diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp
index b274644..dcc85e1 100644
--- a/compiler/mir2loco/src/mir2loco.test.cpp
+++ b/compiler/mir2loco/src/mir2loco.test.cpp
@@ -17,6 +17,7 @@
#include "mir2loco.h"
#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
@@ -248,3 +249,30 @@ TEST_F(TestTransformer_mir2loco, Reshape_Test)
ASSERT_EQ(reshape_node->dim(1), 8);
ASSERT_EQ(reshape_node->dim(2), 81);
}
+
+TEST_F(TestTransformer_mir2loco, Const_Float_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape shape = mir::Shape({2, 3});
+ const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409};
+ auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data);
+ auto *constant = mir_graph.create("constant", mir_tensor);
+ auto *output = mir_graph.create("output", constant->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+ loco::ConstGen *const_node = dynamic_cast(loco_graph->nodes()->at(0));
+ loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(1));
+
+ ASSERT_NE(const_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(push_node->from(), const_node);
+ // Shape check
+ ASSERT_EQ(const_node->rank(), 2);
+ ASSERT_EQ(const_node->dim(0), 2);
+ ASSERT_EQ(const_node->dim(1), 3);
+
+ for (int i = 0; i < 6; i++)
+ ASSERT_FLOAT_EQ(const_node->at(i), data[i]);
+}
--
2.7.4