/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "loco/Service/ShapeInference.h"
#include "GraphTestcase.h"

#include <vector>

#include <gtest/gtest.h>
24 // This test validates whether framework works as expected.
25 TEST(ShapeInferenceTest, framework)
27 // Mock-up Shape Inference Rule
28 struct SampleShapeInferenceRule final : public loco::ShapeInferenceRule
31 SampleShapeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes}
37 // Accept all the dialects
38 bool recognize(const loco::Dialect *) const final { return true; }
40 bool infer(const loco::Node *node, loco::NodeShape &shape) const final
42 // Record the order of inference
43 _nodes->emplace_back(node);
45 if (_nodes->size() != 1)
50 // Set the first node as Tensor<1>
51 loco::TensorShape tensor_shape;
54 tensor_shape.dim(0) = 4;
56 shape.set(tensor_shape);
62 std::vector<const loco::Node *> *_nodes;
65 GraphTestcase<GraphCode::Identity> testcase;
67 std::vector<const loco::Node *> nodes;
69 SampleShapeInferenceRule rule{&nodes};
71 loco::apply(&rule).to(testcase.graph());
73 // Framework SHOULD visit all the nodes
74 ASSERT_EQ(2, nodes.size());
75 // Framework SHOULD visit "pull" before "push"
76 ASSERT_EQ(testcase.pull_node, nodes.at(0));
77 ASSERT_EQ(testcase.push_node, nodes.at(1));
79 // Framework SHOULD make an annotation if "rule" returns TRUE
80 ASSERT_TRUE(loco::shape_known(testcase.pull_node));
81 ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(testcase.pull_node).domain());
82 ASSERT_EQ(1, loco::shape_get(testcase.pull_node).as<loco::TensorShape>().rank());
83 ASSERT_EQ(4, loco::shape_get(testcase.pull_node).as<loco::TensorShape>().dim(0));
85 // Framework SHOULD NOT make any annotation if "rule" returns FALSE
86 ASSERT_FALSE(loco::shape_known(testcase.push_node));