From a9ab8469dded078de9ffa75b4171feb9082fa842 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 30 Sep 2019 12:09:33 +0900 Subject: [PATCH] [moco-tf] Remove ShapeInferenceData in Canonicalizers (#7810) This will revise Canonicalizers not to use ShapeInferenceData and use node_shape instead Signed-off-by: SaeHie Park --- .../src/Canonicalization/AvgPoolCanonicalizer.cpp | 1 - .../src/Canonicalization/ConcatV2Canonicalizer.cpp | 7 +++--- .../DepthwiseConv2dNativeCanonicalizer.cpp | 1 - .../src/Canonicalization/RsqrtCanonicalizer.cpp | 27 +++++++++++----------- .../src/Canonicalization/SoftmaxCanonicalizer.cpp | 10 ++++---- .../src/Canonicalization/SqueezeCanonicalizer.cpp | 9 ++++---- 6 files changed, 25 insertions(+), 30 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp index b764ba8..73cfaa3 100644 --- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp @@ -18,7 +18,6 @@ #include "Annotations/PadData.h" #include "Annotations/StrideData.h" -#include "Annotations/ShapeInferenceData.h" #include "Annotations/WindowData.h" #include "Dialect/TFDialect.h" diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp index f5cf1ac..0125c69 100644 --- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp @@ -19,12 +19,12 @@ #include "LogHelper.h" #include "Annotations/ConcatData.h" -#include "Annotations/ShapeInferenceData.h" #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include "Dialect/TFShapeInferenceHelper.h" #include @@ -76,8 +76,9 @@ bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node) assert(concat_data != nullptr); auto axis_value = concat_data->axis(); - auto shapedata = node->annot(); - auto node_rank = shapedata->rank(); + auto nodeshape = moco::tf::node_shape(node); + auto tensorshape = nodeshape.as(); + auto node_rank = tensorshape.rank(); uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value; diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp index 75907f7..561e04e 100644 --- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp @@ -17,7 +17,6 @@ #include "DepthwiseConv2dNativeCanonicalizer.h" #include "Annotations/PadData.h" -#include "Annotations/ShapeInferenceData.h" #include "Annotations/StrideData.h" #include "Dialect/TFDialect.h" diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp index 98352f8..aafa8cf 100644 --- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp @@ -16,12 +16,11 @@ #include "RsqrtCanonicalizer.h" -#include "Annotations/ShapeInferenceData.h" - #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include "Dialect/TFShapeInferenceHelper.h" #include @@ -33,12 +32,11 @@ namespace { template -void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata, - T value); +void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value); template <> -void prepare_const_gen(loco::ConstGen *const_node, - const moco::tf::ShapeInferenceData *shapedata, float value) +void prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, + float value) { LOGGER(l); @@ -47,18 +45,18 @@ void prepare_const_gen(loco::ConstGen *const_node, auto dtype = loco::DataType::FLOAT32; const_node->dtype(dtype); - auto rank = shapedata->rank(); + auto rank = tensorshape.rank(); const_node->rank(rank); for (uint32_t r = 0; r < rank; ++r) { - if (shapedata->dim(r).known()) - const_node->dim(r) = shapedata->dim(r); + if (tensorshape.dim(r).known()) + const_node->dim(r) = tensorshape.dim(r); else throw std::runtime_error("Cannot handle unknown shape"); - assert(shapedata->dim(r).value() > 0); + assert(tensorshape.dim(r).value() > 0); - const_num_elements *= shapedata->dim(r).value(); + const_num_elements *= tensorshape.dim(r).value(); } INFO(l) << "prepare_const_gen : Elements = " << const_num_elements; @@ -91,13 +89,14 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node) * TFRsqrt is converted to 1 / EltwiseSqrt */ - auto rsqrt_shapedata = node->annot(); - if (rsqrt_shapedata == nullptr) + auto nodeshape = moco::tf::node_shape(node); + if (nodeshape.domain() == loco::Domain::Unknown) { // We need this shape information assert(false); // this shouldn't happen, let's add an alarm return false; } + auto tensorshape = nodeshape.as(); if (!loco::dtype_known(node)) { @@ -114,7 +113,7 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node) switch (dtype) { case loco::DataType::FLOAT32: - prepare_const_gen(const_node, rsqrt_shapedata, 1.0f); + prepare_const_gen(const_node, tensorshape, 1.0f); break; default: diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp index 5bee38b..466ade2 100644 --- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp @@ -16,12 +16,11 @@ #include "SoftmaxCanonicalizer.h" -#include "Annotations/ShapeInferenceData.h" - #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include "Dialect/TFShapeInferenceHelper.h" #include @@ -46,12 +45,11 @@ bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node) * In ---- TensorSoftmax ----- Out(s) */ - auto softmax_shape = node->annot(); - + auto nodeshape = moco::tf::node_shape(node); // Canonicalization into TensorSoftmax is valid when softmax has shape info - assert(softmax_shape); + assert(nodeshape.domain() != loco::Domain::Unknown); - auto softmax_tensor_shape = softmax_shape->tensor_shape(); + auto softmax_tensor_shape = nodeshape.as(); // Create loco node to replace auto softmax = graph->nodes()->create(); diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp index 12793cb..f2752bf 100644 --- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp @@ -16,12 +16,11 @@ #include "SqueezeCanonicalizer.h" -#include "Annotations/ShapeInferenceData.h" - #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include "Dialect/TFShapeInferenceHelper.h" #include @@ -46,12 +45,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *no * In ---- FixedReshape ----- Out(s) */ - auto squeeze_shape = node->annot(); + auto nodeshape = moco::tf::node_shape(node); // canonicalize into FixedReshape is valid when squeeze has shape info // TODO Support general Squeeze case - assert(squeeze_shape); + assert(nodeshape.domain() != loco::Domain::Unknown); - auto squeeze_tensor_shape = squeeze_shape->tensor_shape(); + auto squeeze_tensor_shape = nodeshape.as(); // Create loco node to replace auto reshape = graph->nodes()->create(); -- 2.7.4