// CASE: TensorSoftmax
// Softmax is element-wise over its input tensor, so the output shape is the input shape.
loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); }
+ // CASE: TensorTranspose
+ // Output rank equals input rank; each output axis takes the size of the
+ // input axis selected by the node's permutation.
+ loco::NodeShape visit(const loco::TensorTranspose *node) final
+ {
+   auto const in_shape = node_shape(node->input()).as<loco::TensorShape>();
+   assert(in_shape.rank() == node->perm()->size());
+
+   loco::TensorShape out_shape;
+   out_shape.rank(in_shape.rank());
+
+   for (uint32_t axis = 0; axis < out_shape.rank(); ++axis)
+   {
+     out_shape.dim(axis) = in_shape.dim(node->perm()->axis(axis));
+   }
+
+   return loco::NodeShape{out_shape};
+ }
+
// CASE: TransposedConv2D
loco::NodeShape visit(const loco::TransposedConv2D *node) final
{
ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
}
+TEST(CanonicalShapeInferenceRuleTest, tensor_transpose)
+{
+  // Build a sample network: Pull -> TensorTranspose -> Push
+  GraphTestcase<GraphCode::TensorTranspose> testcase;
+
+  testcase.pull_node->shape({10, 20, 30, 40});
+
+  // Permute the axes as (2, 3, 0, 1)
+  testcase.transpose_node->perm()->size(4);
+  testcase.transpose_node->perm()->axis(0) = 2;
+  testcase.transpose_node->perm()->axis(1) = 3;
+  testcase.transpose_node->perm()->axis(2) = 0;
+  testcase.transpose_node->perm()->axis(3) = 1;
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+
+  auto const output_shape = loco::shape_get(testcase.push_node).as<loco::TensorShape>();
+  ASSERT_EQ(output_shape.rank(), 4);
+  ASSERT_EQ(output_shape.dim(0), 30);
+  ASSERT_EQ(output_shape.dim(1), 40);
+  ASSERT_EQ(output_shape.dim(2), 10);
+  ASSERT_EQ(output_shape.dim(3), 20);
+}
+
namespace
{
MaxPool2D,
TensorBroadcast,
TensorConcat,
+ TensorTranspose,
FixedReshape,
};
std::unique_ptr<loco::Graph> _graph;
};
+template <> class GraphTestcase<GraphCode::TensorTranspose> final
+{
+public:
+  // Builds a sample network: Pull("input") -> TensorTranspose -> Push("output")
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    _graph = make_graph();
+
+    // Register one graph-level input and one graph-level output
+    auto graph_input = _graph->inputs()->create();
+    graph_input->name("input");
+
+    auto graph_output = _graph->outputs()->create();
+    graph_output->name("output");
+
+    // Create the node chain
+    pull_node = _graph->nodes()->create<Pull>();
+    pull_node->index(0);
+
+    transpose_node = _graph->nodes()->create<TensorTranspose>();
+    transpose_node->input(pull_node);
+
+    push_node = _graph->nodes()->create<Push>();
+    push_node->index(0);
+    push_node->from(transpose_node);
+
+    // Associate the graph input/output with the pull/push nodes
+    link(graph_input, pull_node);
+    link(graph_output, push_node);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::TensorTranspose *transpose_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
#endif // __GRAPH_TESTCASE_H__
// Type inference: each of these ops forwards its input's data type unchanged.
loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorTranspose *node) { return loco::dtype_get(node->input()); }
// TransposedConv2D takes its dtype from the input feature map (ifm), not a generic input.
loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); }
};