[exo-tflite] Revise shape inf test (#7144)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 4 Sep 2019 01:18:14 +0000 (10:18 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 4 Sep 2019 01:18:14 +0000 (10:18 +0900)
This will revise shape inference test as network has TFL and Canonical dialects

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp

index 3bb24f9..7337b4b 100644 (file)
  */
 
 #include "Dialect/IR/TFLNodes.h"
+#include "Dialect/IR/TFLDialect.h"
 #include "Dialect/Service/TFLShapeAnnot.h"
 #include "Dialect/Service/TFLShapeInferenceRule.h"
 
 #include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
 #include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <loco/Service/MultiDialectShapeInferenceRule.h>
 
 #include <stdex/Memory.h>
 
@@ -68,7 +72,13 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
   tfl_node->annot<locoex::TFLShapeAnnot>(std::move(shape_annot));
 
   locoex::TFLShapeInferenceRule tfl_rule;
-  loco::apply(&tfl_rule).to(g.get());
+  loco::CanonicalShapeInferenceRule canonical_rule;
+  loco::MultiDialectShapeInferenceRule rules;
+
+  rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+      .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+  loco::apply(&rules).to(g.get());
 
   // Verify
   auto check_shape = [](locoex::TFLRelu *tfl_node) {
@@ -84,7 +94,7 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
   check_shape(tfl_node);
 
   // step 2.
-  loco::apply(&tfl_rule).to(g.get());
+  loco::apply(&rules).to(g.get());
 
   check_shape(tfl_node);
 }