From d53495a7397d71c17550db9666070ce3cbf85892 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=9F=D0=B0=D0=B2=D0=B5=D0=BB=20=D0=98=D0=BB=D1=8C=D1=8E?= =?utf8?q?=D1=82=D1=87=D0=B5=D0=BD=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Eng?= =?utf8?q?ineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 1 Aug 2019 20:01:13 +0300 Subject: [PATCH] [mir2loco] Support Const operation (#5794) * Support Const operation transformation * Append test with Const op containing float values Signed-off-by: Pavel Iliutchenko --- compiler/mir2loco/src/mir2loco.cpp | 45 ++++++++++++++++++++++++++++++++- compiler/mir2loco/src/mir2loco.test.cpp | 28 ++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index b2031b7..17e9d22 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -17,12 +17,15 @@ #include "mir2loco.h" #include "mir/ops/ConcatOp.h" +#include "mir/ops/ConstantOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" #include "mir/ops/ReshapeOp.h" #include +#include + namespace mir2loco { namespace @@ -95,6 +98,26 @@ loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph) decode_node->decoder(std::move(dec)); return decode_node; } + +loco::DataType DTYPE2DataType(const mir::DTYPE &dtype) +{ + switch (dtype) + { + case mir::DTYPE::UNKNOWN: + return loco::DataType::Unknown; + case mir::DTYPE::FLOAT32: + return loco::DataType::FLOAT32; + case mir::DTYPE::FLOAT64: + return loco::DataType::FLOAT64; + case mir::DTYPE::INT32: + return loco::DataType::S32; + case mir::DTYPE::INT64: + return loco::DataType::S64; + default: + break; + } + throw std::runtime_error("Unsupported dtype"); +} } // namespace void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); } @@ -135,7 +158,27 @@ void Transformer::visit(mir::ops::ConcatOp &op) _mir2loco_map.emplace(&op, last_concat); } -void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); } +void Transformer::visit(mir::ops::ConstantOp &op) +{ + auto const_node = _loco_graph->nodes()->create(); + // Not set Input + // Set Shape + const auto &out_shape = op.getOutputShape(0); + setupShape(out_shape, const_node); + // Copy value + const auto &value = op.getValue(); + const_node->dtype(DTYPE2DataType(value.getDataType())); + // TODO Support other data types + assert(const_node->dtype() == loco::DataType::FLOAT32); + const_node->size(out_shape.numElements()); + // TODO Change that when loco support other DataTypeImpl + float &const_float = const_node->at(0); + char *loco_ptr = reinterpret_cast(&const_float); + char *mir_ptr = value.at(mir::Index(out_shape.rank())); + std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float)); + // Add to map + _mir2loco_map.emplace(&op, const_node); +} void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); } diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index b274644..dcc85e1 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -17,6 +17,7 @@ #include "mir2loco.h" #include "mir/ops/ConcatOp.h" +#include "mir/ops/ConstantOp.h" #include "mir/ops/PoolOp.h" #include "mir/ops/ReluOp.h" #include "mir/ops/ReshapeOp.h" @@ -248,3 +249,30 @@ TEST_F(TestTransformer_mir2loco, Reshape_Test) ASSERT_EQ(reshape_node->dim(1), 8); ASSERT_EQ(reshape_node->dim(2), 81); } + +TEST_F(TestTransformer_mir2loco, Const_Float_Test) +{ + mir::Graph mir_graph; + + mir::Shape shape = mir::Shape({2, 3}); + const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409}; + auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data); + auto *constant = mir_graph.create("constant", mir_tensor); + auto *output = mir_graph.create("output", constant->getOutput(0)); + + mir2loco::Transformer transformer; + auto loco_graph = transformer.transform(&mir_graph); + loco::ConstGen *const_node = dynamic_cast(loco_graph->nodes()->at(0)); + loco::Push *push_node = dynamic_cast(loco_graph->nodes()->at(1)); + + ASSERT_NE(const_node, nullptr); + ASSERT_NE(push_node, nullptr); + ASSERT_EQ(push_node->from(), const_node); + // Shape check + ASSERT_EQ(const_node->rank(), 2); + ASSERT_EQ(const_node->dim(0), 2); + ASSERT_EQ(const_node->dim(1), 3); + + for (int i = 0; i < 6; i++) + ASSERT_FLOAT_EQ(const_node->at(i), data[i]); +} -- 2.7.4