// ReLU6 is elementwise, so it simply forwards its input's shape unchanged.
loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); }
// TODO Support TensorBiasAdd
- // TODO SUpport TensorConcat
+
+  // CASE: TensorConcat
+  //
+  // Output rank equals the (common) rank of both inputs.  The dimension
+  // along the concatenation axis is the sum of the two input dimensions;
+  // every other dimension must agree between lhs and rhs (checked by assert).
+  //
+  // NOTE "final" added for consistency with the other visit overloads in
+  // this visitor (e.g. the ReLU6 case above).
+  loco::NodeShape visit(const loco::TensorConcat *node) final
+  {
+    auto const lhs_shape = loco::shape_get(node->lhs()).as<loco::TensorShape>();
+    auto const rhs_shape = loco::shape_get(node->rhs()).as<loco::TensorShape>();
+
+    assert(lhs_shape.rank() == rhs_shape.rank());
+    uint32_t const out_rank = lhs_shape.rank();
+
+    loco::TensorShape out_shape;
+
+    out_shape.rank(out_rank);
+
+    for (uint32_t axis = 0; axis < out_rank; ++axis)
+    {
+      if (axis == node->axis())
+      {
+        // Concatenation axis: sum of both input extents
+        out_shape.dim(axis) = lhs_shape.dim(axis).value() + rhs_shape.dim(axis).value();
+      }
+      else
+      {
+        // All other axes must match; propagate the (possibly unknown) dimension
+        assert(lhs_shape.dim(axis) == rhs_shape.dim(axis));
+        out_shape.dim(axis) = lhs_shape.dim(axis);
+      }
+    }
+
+    return loco::NodeShape{out_shape};
+  }
};
} // namespace
ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4);
ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2);
}
+
+TEST(CanonicalShapeInferenceRuleTest, tensor_concat)
+{
+  using namespace loco;
+
+  // Build a sample network: concat(lhs : [1, 2, 3], rhs : [1, 4, 3]) along axis 1
+  GraphTestcase<GraphCode::TensorConcat> testcase;
+
+  testcase.lhs_node->shape({1, 2, 3});
+  testcase.rhs_node->shape({1, 4, 3});
+  testcase.concat_node->axis(1);
+
+  // Run shape inference over the whole graph
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify that the inferred output shape is [1, 6, 3]
+  ASSERT_TRUE(loco::shape_known(testcase.concat_node));
+  ASSERT_EQ(loco::shape_get(testcase.concat_node).domain(), loco::Domain::Tensor);
+
+  auto const concat_shape = loco::shape_get(testcase.concat_node).as<TensorShape>();
+
+  ASSERT_EQ(concat_shape.rank(), 3);
+  ASSERT_EQ(concat_shape.dim(0), 1);
+  ASSERT_EQ(concat_shape.dim(1), 6);
+  ASSERT_EQ(concat_shape.dim(2), 3);
+}
FeatureCodec,
AvgPool2D,
MaxPool2D,
+ TensorConcat,
};
template <GraphCode Code> class GraphTestcase;
std::unique_ptr<loco::Graph> _graph;
};
+// Testcase graph: Pull(lhs), Pull(rhs) -> TensorConcat -> Push
+template <> class GraphTestcase<GraphCode::TensorConcat> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    _graph = make_graph();
+
+    // Declare two graph-level inputs ("lhs" at index 0, "rhs" at index 1)
+    // and one graph-level output ("output")
+    auto input_lhs = _graph->inputs()->create();
+    auto input_rhs = _graph->inputs()->create();
+    auto output = _graph->outputs()->create();
+
+    input_lhs->name("lhs");
+    input_rhs->name("rhs");
+    output->name("output");
+
+    // Materialize the nodes and wire them up
+    lhs_node = _graph->nodes()->create<Pull>();
+    lhs_node->index(0);
+
+    rhs_node = _graph->nodes()->create<Pull>();
+    rhs_node->index(1);
+
+    concat_node = _graph->nodes()->create<TensorConcat>();
+    concat_node->lhs(lhs_node);
+    concat_node->rhs(rhs_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->index(0);
+    push_node->from(concat_node);
+
+    // Associate graph-level inputs/outputs with their Pull/Push nodes
+    loco::link(input_lhs, lhs_node);
+    loco::link(input_rhs, rhs_node);
+    loco::link(output, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *lhs_node = nullptr;
+  loco::Pull *rhs_node = nullptr;
+  loco::TensorConcat *concat_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
namespace
{