/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "TestGraph.h"
19 #include "Dialect/IR/TFLNodes.h"
20 #include "Dialect/IR/TFLDialect.h"
21 #include "Dialect/Service/TFLShapeInferenceRule.h"
24 #include <loco/IR/CanonicalDialect.h>
25 #include <loco/Service/ShapeInference.h>
26 #include <loco/Service/CanonicalShapeInferenceRule.h>
27 #include <loco/Service/MultiDialectShapeInferenceRule.h>
29 #include <stdex/Memory.h>
31 #include <gtest/gtest.h>
33 TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
35 // Create a simple network
36 exo::test::TestGraph graph;
37 auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull);
38 graph.complete(tfl_node);
43 graph.pull->dim(0) = 3;
44 graph.pull->dim(1) = 4;
48 ASSERT_FALSE(loco::shape_known(tfl_node));
51 locoex::TFLShapeInferenceRule tfl_rule;
52 loco::CanonicalShapeInferenceRule canonical_rule;
53 loco::MultiDialectShapeInferenceRule rules;
55 rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
56 .bind(locoex::TFLDialect::get(), &tfl_rule);
58 loco::apply(&rules).to(graph.g.get());
62 ASSERT_TRUE(loco::shape_known(tfl_node));
63 ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
65 auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
66 ASSERT_EQ(2, shape.rank());
67 ASSERT_EQ(3, shape.dim(0));
68 ASSERT_EQ(4, shape.dim(1));
72 // based on the case shown in
73 // https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
74 TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
76 exo::test::TestGraph graph;
77 auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
80 auto pull = graph.pull;
82 pull->shape({1, 4, 3, 1});
84 // setting TFLAveragePool2D
86 tfl_node->filter()->h(2);
87 tfl_node->filter()->w(2);
88 tfl_node->stride()->h(2);
89 tfl_node->stride()->w(2);
90 tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
91 tfl_node->padding(locoex::Padding::VALID);
93 ASSERT_FALSE(loco::shape_known(tfl_node));
96 locoex::TFLShapeInferenceRule tfl_rule;
97 loco::CanonicalShapeInferenceRule canonical_rule;
98 loco::MultiDialectShapeInferenceRule rules;
100 rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
101 .bind(locoex::TFLDialect::get(), &tfl_rule);
103 loco::apply(&rules).to(graph.g.get());
107 ASSERT_TRUE(loco::shape_known(tfl_node));
108 ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
110 auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
111 ASSERT_EQ(4, shape.rank());
112 ASSERT_EQ(1, shape.dim(0).value());
113 ASSERT_EQ(2, shape.dim(1).value());
114 ASSERT_EQ(1, shape.dim(2).value());
115 ASSERT_EQ(1, shape.dim(3).value());
119 TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
121 exo::test::TestGraph graph;
122 auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
125 auto pull = graph.pull;
127 pull->shape({1, 4, 3, 1});
130 // setting TFLAveragePool2D
132 tfl_node->filter()->h(2);
133 tfl_node->filter()->w(2);
134 tfl_node->stride()->h(2);
135 tfl_node->stride()->w(2);
136 tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
137 tfl_node->padding(locoex::Padding::SAME);
140 ASSERT_FALSE(loco::shape_known(tfl_node));
143 locoex::TFLShapeInferenceRule tfl_rule;
144 loco::CanonicalShapeInferenceRule canonical_rule;
145 loco::MultiDialectShapeInferenceRule rules;
147 rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
148 .bind(locoex::TFLDialect::get(), &tfl_rule);
150 loco::apply(&rules).to(graph.g.get());
154 ASSERT_TRUE(loco::shape_known(tfl_node));
155 ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
157 auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
158 ASSERT_EQ(4, shape.rank());
159 ASSERT_EQ(1, shape.dim(0).value());
160 ASSERT_EQ(2, shape.dim(1).value());
161 ASSERT_EQ(2, shape.dim(2).value());
162 ASSERT_EQ(1, shape.dim(3).value());
167 * @note Function to test: Shape inference of two different input shapes
169 * Rank expansion to higher input side
170 * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
171 * Do output shape inference like numpy
172 * x(2,1,5) + y(1,3,5) --> output(2,3,5)
173 * For each axis, dim value should be same OR one of them should be 1
175 TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different)
177 auto g = loco::make_graph();
179 auto x_node = g->nodes()->create<loco::Pull>();
186 auto y_node = g->nodes()->create<loco::Pull>();
192 auto tfl_node = g->nodes()->create<locoex::TFLAdd>();
197 auto push_node = g->nodes()->create<loco::Push>();
199 push_node->from(tfl_node);
202 auto x_input = g->inputs()->create();
205 loco::link(x_input, x_node);
207 auto y_input = g->inputs()->create();
210 loco::link(y_input, y_node);
212 auto output = g->outputs()->create();
214 output->name("output");
215 loco::link(output, push_node);
219 ASSERT_FALSE(loco::shape_known(tfl_node));
221 exo::ShapeInferencePass pass;
222 while (pass.run(g.get()) == true)
229 ASSERT_TRUE(loco::shape_known(tfl_node));
230 ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
232 auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
233 ASSERT_EQ(3, shape.rank());
234 ASSERT_EQ(2, shape.dim(0));
235 ASSERT_EQ(3, shape.dim(1));
236 ASSERT_EQ(5, shape.dim(2));
240 TEST(TFLShapeInferenceRuleTest, TFLTranspose_simple)
242 exo::test::ExampleGraph<exo::test::ExampleGraphType::TFLTranspose> g;
250 g.const_perm->dtype(loco::DataType::S32);
251 g.const_perm->rank(1);
252 g.const_perm->dim(0) = 4;
253 g.const_perm->size<loco::DataType::S32>(4);
254 g.const_perm->at<loco::DataType::S32>(0) = 2;
255 g.const_perm->at<loco::DataType::S32>(1) = 3;
256 g.const_perm->at<loco::DataType::S32>(2) = 0;
257 g.const_perm->at<loco::DataType::S32>(3) = 1;
260 ASSERT_FALSE(loco::shape_known(g.tfl_transpose));
262 exo::ShapeInferencePass pass;
263 while (pass.run(g.graph()) == true)
268 ASSERT_TRUE(loco::shape_known(g.tfl_transpose));
270 auto shape = loco::shape_get(g.tfl_transpose).as<loco::TensorShape>();
271 ASSERT_EQ(4, shape.rank());
272 ASSERT_EQ(30, shape.dim(0));
273 ASSERT_EQ(40, shape.dim(1));
274 ASSERT_EQ(10, shape.dim(2));
275 ASSERT_EQ(20, shape.dim(3));