From 0eb89efb7180f7d59358497703adfb155a5906ea Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 6 Sep 2019 17:57:22 +0900 Subject: [PATCH] [moco-tf] Fix SquaredDifferenceCanonicalizer (#7245) This will fix SquaredDifferenceCanonicalizer to check input nodes shape Signed-off-by: SaeHie Park --- .../SquaredDifferenceCanonicalizer.cpp | 24 +++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp index 390cae0..4eb7a72 100644 --- a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp @@ -21,6 +21,9 @@ #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include +#include + #include namespace @@ -47,12 +50,27 @@ bool canonicalize_sqdiff(loco::Graph *graph, moco::tf::TFSquaredDifference *node * A and B are drawn multiple times to simplify the diagram */ - auto sub_node = graph->nodes()->create(); - auto mul_node = graph->nodes()->create(); - auto node_A = node->x(); auto node_B = node->y(); + if (!loco::shape_known(node_A) || !loco::shape_known(node_B)) + { + // Wait for shape inference + return false; + } + + const auto &x_shape = loco::shape_get(node_A); + const auto &y_shape = loco::shape_get(node_B); + + if (!(x_shape == y_shape)) + { + // TODO support broadcast + return false; + } + + auto sub_node = graph->nodes()->create(); + auto mul_node = graph->nodes()->create(); + // update connections sub_node->lhs(node_A); sub_node->rhs(node_B); -- 2.7.4