[moco-tf] Update shape inference in Canonicalizers (#6732)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 20 Aug 2019 09:12:28 +0000 (18:12 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 20 Aug 2019 09:12:28 +0000 (18:12 +0900)
This will update shape inference to use loco service in Canonicalizers of Add,Mul,Realdiv and Sub nodes

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp

index 1daeb1a..6fc1fc2 100644 (file)
 
 #include "AddCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
 #include <stdex/Memory.h>
 
 namespace
@@ -56,12 +57,16 @@ bool canonicalize_add(loco::Graph *graph, moco::tf::TFAdd *node)
   auto node_A = node->x();
   auto node_B = node->y();
 
-  auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
-  assert(x_shapedata != nullptr);
-  auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
-  assert(y_shapedata != nullptr);
+  if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+  {
+    // Wait for shape inference
+    return false;
+  }
+
+  const auto &x_shape = loco::shape_get(node_A);
+  const auto &y_shape = loco::shape_get(node_B);
 
-  if (!(*x_shapedata == *y_shapedata))
+  if (!(x_shape == y_shape))
   {
     // TODO support broadcast Add
     return false;
index e0448b4..b440209 100644 (file)
 
 #include "MulCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
 #include <stdex/Memory.h>
 
 namespace
@@ -56,12 +57,16 @@ bool canonicalize_mul(loco::Graph *graph, moco::tf::TFMul *node)
   auto node_A = node->x();
   auto node_B = node->y();
 
-  auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
-  assert(x_shapedata != nullptr);
-  auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
-  assert(y_shapedata != nullptr);
+  if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+  {
+    // Wait for shape inference
+    return false;
+  }
+
+  const auto &x_shape = loco::shape_get(node_A);
+  const auto &y_shape = loco::shape_get(node_B);
 
-  if (!(*x_shapedata == *y_shapedata))
+  if (!(x_shape == y_shape))
   {
     // TODO support broadcast Mul
     return false;
index 72b3ce6..e0e4c5e 100644 (file)
 
 #include "RealDivCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
 #include <stdex/Memory.h>
 
 namespace
@@ -56,12 +57,16 @@ bool canonicalize_div(loco::Graph *graph, moco::tf::TFRealDiv *node)
   auto node_A = node->x();
   auto node_B = node->y();
 
-  auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
-  assert(x_shapedata != nullptr);
-  auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
-  assert(y_shapedata != nullptr);
+  if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+  {
+    // Wait for shape inference
+    return false;
+  }
+
+  const auto &x_shape = loco::shape_get(node_A);
+  const auto &y_shape = loco::shape_get(node_B);
 
-  if (!(*x_shapedata == *y_shapedata))
+  if (!(x_shape == y_shape))
   {
     // TODO support broadcast Div
     return false;
index 316d9fb..10bfcf6 100644 (file)
 
 #include "SubCanonicalizer.h"
 
-#include "Annotations/ShapeInferenceData.h"
-
 #include "Dialect/TFDialect.h"
 #include "Dialect/TFNodes.h"
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
 #include <stdex/Memory.h>
 
 namespace
@@ -56,12 +57,16 @@ bool canonicalize_sub(loco::Graph *graph, moco::tf::TFSub *node)
   auto node_A = node->x();
   auto node_B = node->y();
 
-  auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
-  assert(x_shapedata != nullptr);
-  auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
-  assert(y_shapedata != nullptr);
+  if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+  {
+    // Wait for shape inference
+    return false;
+  }
+
+  const auto &x_shape = loco::shape_get(node_A);
+  const auto &y_shape = loco::shape_get(node_B);
 
-  if (!(*x_shapedata == *y_shapedata))
+  if (!(x_shape == y_shape))
   {
     // TODO support broadcast Sub
     return false;