[exo-tflite] Adding test assert into TFLTypeInferenceRuleTest (#7284)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 9 Sep 2019 04:32:36 +0000 (13:32 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 9 Sep 2019 04:32:36 +0000 (13:32 +0900)
This adds more asserts into the test body of TFLTypeInferenceRuleTest.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp

index 57bc203..9eef2f6 100644 (file)
  */
 
 #include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
 #include "Dialect/Service/TFLTypeInferenceRule.h"
 
 #include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
 
 #include <stdex/Memory.h>
 
@@ -29,6 +32,9 @@ TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu)
   auto g = loco::make_graph();
 
   auto pull_node = g->nodes()->create<loco::Pull>();
+  {
+    pull_node->dtype(loco::DataType::S32);
+  }
 
   auto tfl_node = g->nodes()->create<locoex::TFLRelu>();
   tfl_node->features(pull_node);
@@ -50,5 +56,18 @@ TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu)
   // pre-check
   ASSERT_FALSE(loco::dtype_known(tfl_node));
 
-  // TODO Rewrite code below using new TFLShapeInferenceRule
+  // type inference
+  locoex::TFLTypeInferenceRule tfl_rule;
+  loco::CanonicalTypeInferenceRule canon_rule;
+  loco::MultiDialectTypeInferenceRule rules;
+
+  rules.bind(loco::CanonicalDialect::get(), &canon_rule);
+  rules.bind(locoex::TFLDialect::get(), &tfl_rule);
+
+  loco::apply(&rules).to(g.get());
+
+  // Verify
+  ASSERT_TRUE(loco::dtype_known(tfl_node));
+  auto type = loco::dtype_get(tfl_node);
+  ASSERT_EQ(type, loco::DataType::S32);
 }