#include "mir2loco.h"
#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
#include <stdex/Memory.h>
+#include <cstring>
+
namespace mir2loco
{
namespace
decode_node->decoder(std::move(dec));
return decode_node;
}
+
+loco::DataType DTYPE2DataType(const mir::DTYPE &dtype)
+{
+ switch (dtype)
+ {
+ case mir::DTYPE::UNKNOWN:
+ return loco::DataType::Unknown;
+ case mir::DTYPE::FLOAT32:
+ return loco::DataType::FLOAT32;
+ case mir::DTYPE::FLOAT64:
+ return loco::DataType::FLOAT64;
+ case mir::DTYPE::INT32:
+ return loco::DataType::S32;
+ case mir::DTYPE::INT64:
+ return loco::DataType::S64;
+ default:
+ break;
+ }
+ throw std::runtime_error("Unsupported dtype");
+}
} // namespace
void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
_mir2loco_map.emplace(&op, last_concat);
}
-void Transformer::visit(mir::ops::ConstantOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::ConstantOp &op)
+{
+ auto const_node = _loco_graph->nodes()->create<loco::ConstGen>();
+ // Not set Input
+ // Set Shape
+ const auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, const_node);
+ // Copy value
+ const auto &value = op.getValue();
+ const_node->dtype(DTYPE2DataType(value.getDataType()));
+ // TODO Support other data types
+ assert(const_node->dtype() == loco::DataType::FLOAT32);
+ const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+ // TODO Change that when loco support other DataTypeImpl
+ float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+ char *loco_ptr = reinterpret_cast<char *>(&const_float);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+ // Add to map
+ _mir2loco_map.emplace(&op, const_node);
+}
void Transformer::visit(mir::ops::Conv2DOp &op) { throw std::runtime_error("NYI"); }
#include "mir2loco.h"
#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
#include "mir/ops/ReluOp.h"
#include "mir/ops/ReshapeOp.h"
ASSERT_EQ(reshape_node->dim(1), 8);
ASSERT_EQ(reshape_node->dim(2), 81);
}
+
+TEST_F(TestTransformer_mir2loco, Const_Float_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape shape = mir::Shape({2, 3});
+ const float data[] = {5.9, 6.7, 5.32, 54.11231, 43.2444, 3.409};
+ auto mir_tensor = mir::TensorVariant(mir::DTYPE::FLOAT32, shape, (const void *)data);
+ auto *constant = mir_graph.create<mir::ops::ConstantOp>("constant", mir_tensor);
+ auto *output = mir_graph.create<mir::ops::OutputOp>("output", constant->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+ loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_graph->nodes()->at(0));
+ loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(1));
+
+ ASSERT_NE(const_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+ ASSERT_EQ(push_node->from(), const_node);
+ // Shape check
+ ASSERT_EQ(const_node->rank(), 2);
+ ASSERT_EQ(const_node->dim(0), 2);
+ ASSERT_EQ(const_node->dim(1), 3);
+
+ for (int i = 0; i < 6; i++)
+ ASSERT_FLOAT_EQ(const_node->at<loco::DataType::FLOAT32>(i), data[i]);
+}