return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
}
- // TODO Support FixedReshape
+ // CASE: FixedReshape
+ loco::NodeShape visit(const loco::FixedReshape *node) final
+ {
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(node->rank());
+ for (uint32_t axis = 0; axis < node->rank(); ++axis)
+ {
+ tensor_shape.dim(axis) = node->dim(axis);
+ }
+
+ return loco::NodeShape{tensor_shape};
+ }
// CASE: MaxPool2D
loco::NodeShape visit(const loco::MaxPool2D *node) final
ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(1), 6);
ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(2), 3);
}
+
+TEST(CanonicalShapeInferenceRuleTest, fixed_reshape)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::FixedReshape> testcase;
+
+ testcase.pull_node->shape({6, 6});
+ testcase.reshape_node->shape({4, 9});
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.push_node));
+ ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 9);
+}