#include "mir2loco.h"
+#include "mir/ops/BiasAddOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
void Transformer::visit(mir::ops::BatchNormOp &op) { throw std::runtime_error("NYI"); }
-void Transformer::visit(mir::ops::BiasAddOp &op) { throw std::runtime_error("NYI"); }
+void Transformer::visit(mir::ops::BiasAddOp &op)
+{
+ // Set Input
+ auto input = op.getInput(0)->getProducer()->getNode();
+ auto bias = op.getInput(1)->getProducer()->getNode();
+ auto loco_input = _mir2loco_map.at(input);
+ auto loco_bias = _mir2loco_map.at(bias);
+ // Create BiasEncode
+ auto bias_encode = _loco_graph->nodes()->create<loco::BiasEncode>();
+ bias_encode->input(loco_bias);
+ // Set value and bias
+ auto bias_add_node = _loco_graph->nodes()->create<loco::TensorBiasAdd>();
+ bias_add_node->value(loco_input);
+ bias_add_node->bias(bias_encode);
+ // Set axis
+ bias_add_node->axis(3); // NHWC
+ // Not set Shape
+ // Add to map
+ _mir2loco_map.emplace(&op, bias_add_node);
+}
void Transformer::visit(mir::ops::CappedReluOp &op) { throw std::runtime_error("NYI"); }
#include "mir2loco.h"
+#include "mir/ops/BiasAddOp.h"
#include "mir/ops/ConcatOp.h"
#include "mir/ops/ConstantOp.h"
#include "mir/ops/PoolOp.h"
mir2loco::Transformer transformer;
auto loco_graph = transformer.transform(&mir_graph);
+
loco::ConstGen *const_node = dynamic_cast<loco::ConstGen *>(loco_graph->nodes()->at(0));
loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(1));
for (int i = 0; i < 6; i++)
ASSERT_FLOAT_EQ(const_node->at<loco::DataType::FLOAT32>(i), data[i]);
}
+
+TEST_F(TestTransformer_mir2loco, Bias_Add_Test)
+{
+ mir::Graph mir_graph;
+
+ mir::Shape input_shape{5, 6, 7, 3};
+ mir::Shape bias_shape{3};
+ auto *input = mir_graph.create<mir::ops::InputOp>("input", input_shape);
+ auto *bias = mir_graph.create<mir::ops::InputOp>("bias", bias_shape);
+ auto *bias_add =
+ mir_graph.create<mir::ops::BiasAddOp>("bias_add", input->getOutput(0), bias->getOutput(0));
+ auto *output = mir_graph.create<mir::ops::OutputOp>("output", bias_add->getOutput(0));
+
+ mir2loco::Transformer transformer;
+ auto loco_graph = transformer.transform(&mir_graph);
+
+ loco::Pull *pull1_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(0));
+ loco::Pull *pull2_node = dynamic_cast<loco::Pull *>(loco_graph->nodes()->at(1));
+ loco::BiasEncode *bias_enc_node = dynamic_cast<loco::BiasEncode *>(loco_graph->nodes()->at(2));
+ loco::TensorBiasAdd *bias_add_node =
+ dynamic_cast<loco::TensorBiasAdd *>(loco_graph->nodes()->at(3));
+ loco::Push *push_node = dynamic_cast<loco::Push *>(loco_graph->nodes()->at(4));
+
+ ASSERT_NE(pull1_node, nullptr);
+ ASSERT_NE(pull2_node, nullptr);
+ ASSERT_NE(bias_enc_node, nullptr);
+ ASSERT_NE(bias_add_node, nullptr);
+ ASSERT_NE(push_node, nullptr);
+
+ ASSERT_EQ(bias_enc_node->input(), pull2_node);
+ ASSERT_EQ(bias_add_node->value(), pull1_node);
+ ASSERT_EQ(bias_add_node->bias(), bias_enc_node);
+ ASSERT_EQ(push_node->from(), bias_add_node);
+ // Shape check
+ ASSERT_EQ(pull1_node->rank(), 4);
+ ASSERT_EQ(pull1_node->dim(0), 5);
+ ASSERT_EQ(pull1_node->dim(1), 6);
+ ASSERT_EQ(pull1_node->dim(2), 7);
+ ASSERT_EQ(pull1_node->dim(3), 3);
+
+ ASSERT_EQ(pull2_node->rank(), 1);
+ ASSERT_EQ(pull2_node->dim(0), 3);
+
+ ASSERT_EQ(bias_add_node->axis(), 3);
+}