[exo-tflite] Test of broadcast algorithm (#7471)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 16 Sep 2019 22:01:23 +0000 (07:01 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 16 Sep 2019 22:01:23 +0000 (07:01 +0900)
This will add a test of broadcast algorithm with TFLAdd

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp

index eca72e6..bdacaf0 100644 (file)
@@ -20,6 +20,8 @@
 #include "Dialect/IR/TFLDialect.h"
 #include "Dialect/Service/TFLShapeInferenceRule.h"
 
+#include "Conversion/ShapeInferencePass.h"
+
 #include <loco.h>
 #include <loco/IR/CanonicalDialect.h>
 #include <loco/Service/ShapeInference.h>
@@ -171,3 +173,77 @@ TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
     ASSERT_EQ(shape.dim(3).value(), 1);
   }
 }
+
+/**
+ * @note Function to test: Shape inference of two different input shapes
+ *
+ *       Rank expansion to higher input side
+ *          x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
+ *       Do output shape inference like numpy
+ *          x(2,1,5) + y(1,3,5) --> output(2,3,5)
+ *       For each axis, dim value should be same OR one of them should be 1
+ */
+TEST(TFLShapeInferenceRuleTest, TFAdd_shapeinf_different)
+{
+  auto g = loco::make_graph();
+
+  auto x_node = g->nodes()->create<loco::Pull>();
+  {
+    x_node->rank(3);
+    x_node->dim(0) = 2;
+    x_node->dim(1) = 1;
+    x_node->dim(2) = 5;
+  }
+  auto y_node = g->nodes()->create<loco::Pull>();
+  {
+    y_node->rank(2);
+    y_node->dim(0) = 3;
+    y_node->dim(1) = 5;
+  }
+  auto tfl_node = g->nodes()->create<locoex::TFLAdd>();
+  {
+    tfl_node->x(x_node);
+    tfl_node->y(y_node);
+  }
+  auto push_node = g->nodes()->create<loco::Push>();
+  {
+    push_node->from(tfl_node);
+  }
+
+  auto x_input = g->inputs()->create();
+  {
+    x_input->name("x");
+    loco::link(x_input, x_node);
+  }
+  auto y_input = g->inputs()->create();
+  {
+    y_input->name("y");
+    loco::link(y_input, y_node);
+  }
+  auto output = g->outputs()->create();
+  {
+    output->name("output");
+    loco::link(output, push_node);
+  }
+
+  // pre-check
+  ASSERT_FALSE(loco::shape_known(tfl_node));
+
+  exo::ShapeInferencePass pass;
+  while (pass.run(g.get()) == true)
+  {
+    ;
+  }
+
+  // Verify
+  {
+    ASSERT_TRUE(loco::shape_known(tfl_node));
+    ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+    auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+    ASSERT_EQ(shape.rank(), 3);
+    ASSERT_EQ(shape.dim(0), 2);
+    ASSERT_EQ(shape.dim(1), 3);
+    ASSERT_EQ(shape.dim(2), 5);
+  }
+}