[moco-tf] Fix SquaredDifferenceCanonicalizer (#7245)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 6 Sep 2019 08:57:22 +0000 (17:57 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 6 Sep 2019 08:57:22 +0000 (17:57 +0900)
This will fix SquaredDifferenceCanonicalizer to check input nodes shape

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp

index 390cae0..4eb7a72 100644 (file)
@@ -21,6 +21,9 @@
 #include "Dialect/TFNodeVisitor.h"
 #include "Dialect/TFNodeImpl.h"
 
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
 #include <stdex/Memory.h>
 
 namespace
@@ -47,12 +50,27 @@ bool canonicalize_sqdiff(loco::Graph *graph, moco::tf::TFSquaredDifference *node
    *                 A and B are drawn multiple times to simplify the diagram
    */
 
-  auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
-  auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
-
   auto node_A = node->x();
   auto node_B = node->y();
 
+  if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+  {
+    // Wait for shape inference
+    return false;
+  }
+
+  const auto &x_shape = loco::shape_get(node_A);
+  const auto &y_shape = loco::shape_get(node_B);
+
+  if (!(x_shape == y_shape))
+  {
+    // TODO support broadcast
+    return false;
+  }
+
+  auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
+  auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
+
   // update connections
   sub_node->lhs(node_A);
   sub_node->rhs(node_B);