Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / exo / src / Dialect / Service / TFLShapeInferenceRule.test.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TestGraph.h"
18
19 #include "Dialect/IR/TFLNodes.h"
20 #include "Dialect/IR/TFLDialect.h"
21 #include "Dialect/Service/TFLShapeInferenceRule.h"
22
23 #include <loco.h>
24 #include <loco/IR/CanonicalDialect.h>
25 #include <loco/Service/ShapeInference.h>
26 #include <loco/Service/CanonicalShapeInferenceRule.h>
27 #include <loco/Service/MultiDialectShapeInferenceRule.h>
28
29 #include <stdex/Memory.h>
30
31 #include <gtest/gtest.h>
32
33 TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
34 {
35   // Create a simple network
36   exo::test::TestGraph graph;
37   auto tfl_node = graph.append<locoex::TFLRelu>(graph.pull);
38   graph.complete(tfl_node);
39
40   // set shape
41   {
42     graph.pull->rank(2);
43     graph.pull->dim(0) = 3;
44     graph.pull->dim(1) = 4;
45   }
46
47   // pre-check
48   ASSERT_FALSE(loco::shape_known(tfl_node));
49
50   // shape inference
51   locoex::TFLShapeInferenceRule tfl_rule;
52   loco::CanonicalShapeInferenceRule canonical_rule;
53   loco::MultiDialectShapeInferenceRule rules;
54
55   rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
56       .bind(locoex::TFLDialect::get(), &tfl_rule);
57
58   loco::apply(&rules).to(graph.g.get());
59
60   // Verify
61   {
62     ASSERT_TRUE(loco::shape_known(tfl_node));
63     ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
64
65     auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
66     ASSERT_EQ(2, shape.rank());
67     ASSERT_EQ(3, shape.dim(0));
68     ASSERT_EQ(4, shape.dim(1));
69   }
70 }
71
72 // based on the case shown in
73 // https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
74 TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
75 {
76   exo::test::TestGraph graph;
77   auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
78   graph.complete();
79
80   auto pull = graph.pull;
81   {
82     pull->shape({1, 4, 3, 1});
83   }
84   // setting TFLAveragePool2D
85   {
86     tfl_node->filter()->h(2);
87     tfl_node->filter()->w(2);
88     tfl_node->stride()->h(2);
89     tfl_node->stride()->w(2);
90     tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
91     tfl_node->padding(locoex::Padding::VALID);
92   }
93   ASSERT_FALSE(loco::shape_known(tfl_node));
94
95   // shape inference
96   locoex::TFLShapeInferenceRule tfl_rule;
97   loco::CanonicalShapeInferenceRule canonical_rule;
98   loco::MultiDialectShapeInferenceRule rules;
99
100   rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
101       .bind(locoex::TFLDialect::get(), &tfl_rule);
102
103   loco::apply(&rules).to(graph.g.get());
104
105   // Verify
106   {
107     ASSERT_TRUE(loco::shape_known(tfl_node));
108     ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
109
110     auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
111     ASSERT_EQ(4, shape.rank());
112     ASSERT_EQ(1, shape.dim(0).value());
113     ASSERT_EQ(2, shape.dim(1).value());
114     ASSERT_EQ(1, shape.dim(2).value());
115     ASSERT_EQ(1, shape.dim(3).value());
116   }
117 }
118
119 TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
120 {
121   exo::test::TestGraph graph;
122   auto tfl_node = graph.append<locoex::TFLAveragePool2D>(graph.pull);
123   graph.complete();
124
125   auto pull = graph.pull;
126   {
127     pull->shape({1, 4, 3, 1});
128   }
129
130   // setting TFLAveragePool2D
131   {
132     tfl_node->filter()->h(2);
133     tfl_node->filter()->w(2);
134     tfl_node->stride()->h(2);
135     tfl_node->stride()->w(2);
136     tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
137     tfl_node->padding(locoex::Padding::SAME);
138   }
139
140   ASSERT_FALSE(loco::shape_known(tfl_node));
141
142   // shape inference
143   locoex::TFLShapeInferenceRule tfl_rule;
144   loco::CanonicalShapeInferenceRule canonical_rule;
145   loco::MultiDialectShapeInferenceRule rules;
146
147   rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
148       .bind(locoex::TFLDialect::get(), &tfl_rule);
149
150   loco::apply(&rules).to(graph.g.get());
151
152   // Verify
153   {
154     ASSERT_TRUE(loco::shape_known(tfl_node));
155     ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
156
157     auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
158     ASSERT_EQ(4, shape.rank());
159     ASSERT_EQ(1, shape.dim(0).value());
160     ASSERT_EQ(2, shape.dim(1).value());
161     ASSERT_EQ(2, shape.dim(2).value());
162     ASSERT_EQ(1, shape.dim(3).value());
163   }
164 }
165
166 /**
167  * @note Function to test: Shape inference of two different input shapes
168  *
169  *       Rank expansion to higher input side
170  *          x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
171  *       Do output shape inference like numpy
172  *          x(2,1,5) + y(1,3,5) --> output(2,3,5)
173  *       For each axis, dim value should be same OR one of them should be 1
174  */
175 TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different)
176 {
177   auto g = loco::make_graph();
178
179   auto x_node = g->nodes()->create<loco::Pull>();
180   {
181     x_node->rank(3);
182     x_node->dim(0) = 2;
183     x_node->dim(1) = 1;
184     x_node->dim(2) = 5;
185   }
186   auto y_node = g->nodes()->create<loco::Pull>();
187   {
188     y_node->rank(2);
189     y_node->dim(0) = 3;
190     y_node->dim(1) = 5;
191   }
192   auto tfl_node = g->nodes()->create<locoex::TFLAdd>();
193   {
194     tfl_node->x(x_node);
195     tfl_node->y(y_node);
196   }
197   auto push_node = g->nodes()->create<loco::Push>();
198   {
199     push_node->from(tfl_node);
200   }
201
202   auto x_input = g->inputs()->create();
203   {
204     x_input->name("x");
205     loco::link(x_input, x_node);
206   }
207   auto y_input = g->inputs()->create();
208   {
209     y_input->name("y");
210     loco::link(y_input, y_node);
211   }
212   auto output = g->outputs()->create();
213   {
214     output->name("output");
215     loco::link(output, push_node);
216   }
217
218   // pre-check
219   ASSERT_FALSE(loco::shape_known(tfl_node));
220
221   exo::ShapeInferencePass pass;
222   while (pass.run(g.get()) == true)
223   {
224     ;
225   }
226
227   // Verify
228   {
229     ASSERT_TRUE(loco::shape_known(tfl_node));
230     ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(tfl_node).domain());
231
232     auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
233     ASSERT_EQ(3, shape.rank());
234     ASSERT_EQ(2, shape.dim(0));
235     ASSERT_EQ(3, shape.dim(1));
236     ASSERT_EQ(5, shape.dim(2));
237   }
238 }
239
240 TEST(TFLShapeInferenceRuleTest, TFLTranspose_simple)
241 {
242   exo::test::ExampleGraph<exo::test::ExampleGraphType::TFLTranspose> g;
243
244   g.pull->rank(4);
245   g.pull->dim(0) = 10;
246   g.pull->dim(1) = 20;
247   g.pull->dim(2) = 30;
248   g.pull->dim(3) = 40;
249
250   g.const_perm->dtype(loco::DataType::S32);
251   g.const_perm->rank(1);
252   g.const_perm->dim(0) = 4;
253   g.const_perm->size<loco::DataType::S32>(4);
254   g.const_perm->at<loco::DataType::S32>(0) = 2;
255   g.const_perm->at<loco::DataType::S32>(1) = 3;
256   g.const_perm->at<loco::DataType::S32>(2) = 0;
257   g.const_perm->at<loco::DataType::S32>(3) = 1;
258
259   // pre-check
260   ASSERT_FALSE(loco::shape_known(g.tfl_transpose));
261
262   exo::ShapeInferencePass pass;
263   while (pass.run(g.graph()) == true)
264     ;
265
266   // Verify
267   {
268     ASSERT_TRUE(loco::shape_known(g.tfl_transpose));
269
270     auto shape = loco::shape_get(g.tfl_transpose).as<loco::TensorShape>();
271     ASSERT_EQ(4, shape.rank());
272     ASSERT_EQ(30, shape.dim(0));
273     ASSERT_EQ(40, shape.dim(1));
274     ASSERT_EQ(10, shape.dim(2));
275     ASSERT_EQ(20, shape.dim(3));
276   }
277 }