From 392ae7d79470ea9d94b339443c24ccac8ae4c991 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 5 Aug 2019 21:06:11 +0300 Subject: [PATCH] [mir2loco] Rename DTYPE to DataType (#6247) Replace deprecated `DTYPE` with `DataType`. Signed-off-by: Sergei Barannikov --- compiler/mir2loco/src/mir2loco.cpp | 18 +++++++++--------- compiler/mir2loco/src/mir2loco.test.cpp | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp index 6f70dd7..cea9a12 100644 --- a/compiler/mir2loco/src/mir2loco.cpp +++ b/compiler/mir2loco/src/mir2loco.cpp @@ -101,24 +101,24 @@ loco::FeatureDecode *createNHWCFeatureDecode(loco::Graph *graph) return decode_node; } -loco::DataType DTYPE2DataType(const mir::DTYPE &dtype) +loco::DataType ConvertDataType(mir::DataType data_type) { - switch (dtype) + switch (data_type) { - case mir::DTYPE::UNKNOWN: + case mir::DataType::UNKNOWN: return loco::DataType::Unknown; - case mir::DTYPE::FLOAT32: + case mir::DataType::FLOAT32: return loco::DataType::FLOAT32; - case mir::DTYPE::FLOAT64: + case mir::DataType::FLOAT64: return loco::DataType::FLOAT64; - case mir::DTYPE::INT32: + case mir::DataType::INT32: return loco::DataType::S32; - case mir::DTYPE::INT64: + case mir::DataType::INT64: return loco::DataType::S64; default: break; } - throw std::runtime_error("Unsupported dtype"); + throw std::runtime_error("Unsupported data type"); } } // namespace @@ -188,7 +188,7 @@ void Transformer::visit(mir::ops::ConstantOp &op) setupShape(out_shape, const_node); // Copy value const auto &value = op.getValue(); - const_node->dtype(DTYPE2DataType(value.getDataType())); + const_node->dtype(ConvertDataType(value.getDataType())); // TODO Support other data types assert(const_node->dtype() == loco::DataType::FLOAT32); const_node->size(out_shape.numElements()); diff --git a/compiler/mir2loco/src/mir2loco.test.cpp b/compiler/mir2loco/src/mir2loco.test.cpp index 077ee80..cd732be 100644 --- a/compiler/mir2loco/src/mir2loco.test.cpp +++ b/compiler/mir2loco/src/mir2loco.test.cpp @@ -258,7 +258,7 @@ TEST_F(TestTransformer_mir2loco, Const_Float_Test) mir::Shape shape = mir::Shape({2, 3}); const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409}; - auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data); + auto mir_tensor = mir::TensorVariant(mir::DataType::FLOAT32, shape, (const void *)data); auto *constant = mir_graph.create("constant", mir_tensor); auto *output = mir_graph.create("output", constant->getOutput(0)); @@ -333,7 +333,7 @@ TEST_F(TestTransformer_mir2loco, Conv2D_Test) auto *input = mir_graph.create("input", input_shape); mir::Shape shape = mir::Shape({2, 3, 1, 1}); const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409}; - auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data); + auto mir_tensor = mir::TensorVariant(mir::DataType::FLOAT32, shape, (const void *)data); auto *constant = mir_graph.create("constant", mir_tensor); auto *conv = mir_graph.create( "conv", input->getOutput(0), constant->getOutput(0), mir::Shape{2, 3}, -- 2.7.4