From 6f43e89950067b7e0a0df0471123f86d40865357 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 8 Oct 2019 10:24:29 +0900 Subject: [PATCH] [moco-tf] ShapeInf for TFConcatV2 (#7947) * [moco-tf] ShapeInf for TFConcatV2 This will update shape inference to be done in TFShapeInferenceRule for TFConcatV2 node and not use ConcatData annotation Signed-off-by: SaeHie Park * apply comments * simplify assert check --- .../src/Canonicalization/ConcatV2Canonicalizer.cpp | 43 ++++-- .../moco-tf/src/Dialect/TFShapeInferenceRule.cpp | 103 +++++++++++++ .../moco-tf/src/Transforms/FixShapeTransform.cpp | 169 +-------------------- 3 files changed, 138 insertions(+), 177 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp index 7a045e8..14493f1 100644 --- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp @@ -18,8 +18,6 @@ #include "TFShapeInferenceHelper.h" #include "LogHelper.h" -#include "Annotations/ConcatData.h" - #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" @@ -27,6 +25,8 @@ #include +#include + #include namespace @@ -34,6 +34,17 @@ namespace using namespace moco::tf; +int32_t scala_value(moco::tf::TFConst *node) +{ + auto nodeshape = node_shape(node); + assert(node->dtype() == loco::DataType::S32); + + auto tensor_shape = nodeshape.as(); + assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1); + + return node->at(0); +} + bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node) { LOGGER(l); @@ -70,20 +81,34 @@ bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node) const int num_values = node->num_values(); assert(num_values >= 2); - // get axis value - auto concat_data = node->annot(); - assert(concat_data != nullptr); - auto axis_value = concat_data->axis(); + // get axis absolute value + auto value_a = node->values(0); + if (!loco::shape_known(value_a)) + return false; + + uint32_t node_rank = 0; + { + auto value_a_shape = node_shape(value_a); + assert(value_a_shape.domain() == loco::Domain::Tensor); - auto nodeshape = moco::tf::node_shape(node); - auto tensorshape = nodeshape.as(); - auto node_rank = tensorshape.rank(); + auto value_a_tensor_shape = value_a_shape.as(); + node_rank = value_a_tensor_shape.rank(); + } + int32_t axis_value = 0; + { + // axis should be TFConst + auto axis_node = node->axis(); + auto tfconst = dynamic_cast(axis_node); + assert(tfconst != nullptr); + axis_value = scala_value(tfconst); + } uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value; INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank(" << node_rank << ")"; + // Convert series of TensorConcat if num_values > 2 auto concat_node = graph->nodes()->create(); concat_node->lhs(node->values(0)); concat_node->rhs(node->values(1)); diff --git a/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp index e0bd6dd..531693a 100644 --- a/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp +++ b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp @@ -64,6 +64,38 @@ private: return sum_shape; } + bool valid_scala_value(moco::tf::TFConst *node) + { + auto nodeshape = node_shape(node); + if (nodeshape.domain() != loco::Domain::Tensor) + { + return false; + } + if (node->dtype() != loco::DataType::S32) + { + return false; + } + + auto tensor_shape = nodeshape.as(); + if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1)) + { + return false; + } + + return true; + } + + int32_t scala_value(moco::tf::TFConst *node) + { + auto nodeshape = node_shape(node); + assert(node->dtype() == loco::DataType::S32); + + auto tensor_shape = nodeshape.as(); + assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1); + + return node->at(0); + } + public: loco::NodeShape visit(const moco::tf::TFAdd *node) final { return binary_node_shape(node); } @@ -96,6 +128,77 @@ public: return value_shape; } + loco::NodeShape visit(const moco::tf::TFConcatV2 *node) final + { + // axis shape should be available + auto axis_node = node->axis(); + auto axis_shape = node_shape(axis_node); + assert(axis_shape.domain() != loco::Domain::Unknown); + + // check all input shapes and all ranks should be same + auto value_a = node->values(0); + auto value_a_shape = node_shape(value_a); + assert(value_a_shape.domain() == loco::Domain::Tensor); + auto value_a_tensor_shape = value_a_shape.as(); + uint32_t a_rank = value_a_tensor_shape.rank(); + + uint32_t num_values = node->num_values(); + for (uint32_t ni = 1; ni < num_values; ++ni) + { + auto value_b = node->values(ni); + auto value_b_shape = node_shape(value_b); + assert(value_b_shape.domain() == loco::Domain::Tensor); + auto value_b_tensor_shape = value_b_shape.as(); + uint32_t b_rank = value_b_tensor_shape.rank(); + assert(a_rank == b_rank); + } + + int32_t axis_value = 0; + bool axis_available = false; + { + // check for axis is TFConst + auto tfconst = dynamic_cast(axis_node); + if (tfconst != nullptr) + { + if (valid_scala_value(tfconst)) + { + axis_value = scala_value(tfconst); + axis_available = true; + } + } + } + assert(axis_available); + + uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value; + loco::TensorShape output_tensor_shape = value_a_tensor_shape; + + for (uint32_t index = 0; index < a_rank; ++index) + { + if (value_a_tensor_shape.dim(index).known()) + { + uint32_t dim = value_a_tensor_shape.dim(index).value(); + uint32_t dim_acc = dim; + + for (uint32_t ni = 1; ni < num_values; ++ni) + { + auto value_b = node->values(ni); + auto value_b_shape = node_shape(value_b); + assert(value_b_shape.domain() == loco::Domain::Tensor); + auto value_b_tensor_shape = value_b_shape.as(); + assert(value_b_tensor_shape.dim(index).known()); + if (index == axis_absolute) + dim_acc += value_b_tensor_shape.dim(index).value(); + else + assert(dim == value_b_tensor_shape.dim(index).value()); + } + output_tensor_shape.dim(index) = dim_acc; + } + else + output_tensor_shape.dim(index).unset(); + } + return loco::NodeShape(output_tensor_shape); + } + public: loco::NodeShape visit(const moco::tf::TFNode *node) final { diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index f6d0532..493be50 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -18,7 +18,6 @@ #include "TFShapeInferenceHelper.h" #include "LogHelper.h" -#include "Annotations/ConcatData.h" #include "Annotations/PadData.h" #include "Annotations/ShapeInferenceData.h" #include "Annotations/StrideData.h" @@ -438,173 +437,7 @@ bool fix_shape(moco::tf::TFBiasAdd *node) return copy_shapedata(value, node); } -template bool valid_scala_value(CONST_CLASS *node) -{ - LOGGER(l); - - loco::NodeShape nodeshape; - if (!node_shape(node, nodeshape)) - { - return false; - } - - if (node->dtype() != loco::DataType::S32) - { - INFO(l) << "valid_scala_value not S32"; - return false; - } - - auto tensor_shape = nodeshape.as(); - if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1)) - { - INFO(l) << "valid_scala_value rank not 0/1 : " << tensor_shape.rank(); - return false; - } - - return true; -} - -template int32_t scala_value(CONST_CLASS *node) -{ - loco::NodeShape nodeshape; - if (!node_shape(node, nodeshape)) - { - return false; - } - - assert(node->dtype() == loco::DataType::S32); - - auto tensor_shape = nodeshape.as(); - assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1); - - return node->template at(0); -} - -bool fix_shape(moco::tf::TFConcatV2 *node) -{ - LOGGER(l); - - if (shape_inference_done(node)) - { - INFO(l) << "Fix shape TFConcatV2 already done"; - return false; - } - // ConcatData should be null - assert(node->annot() == nullptr); - - // Check shape inference data are all ready - // Check shape rank are all same - auto value_a = node->values(0); - loco::NodeShape value_a_shape; - if (!node_shape(value_a, value_a_shape)) - { - // shape inference is not ready for this value - INFO(l) << "Fix shape TFConcatV2 value 0 shape_data not ready"; - return false; - } - assert(value_a_shape.domain() == loco::Domain::Tensor); - auto value_a_tensor_shape = value_a_shape.as(); - uint32_t a_rank = value_a_tensor_shape.rank(); - - uint32_t num_values = node->num_values(); - for (uint32_t ni = 1; ni < num_values; ++ni) - { - auto value_b = node->values(ni); - loco::NodeShape value_b_shape; - if (!node_shape(value_b, value_b_shape)) - { - // shape inference is not ready for this value - INFO(l) << "Fix shape TFConcatV2 value " << ni << " shape_data not ready"; - return false; - } - assert(value_b_shape.domain() == loco::Domain::Tensor); - auto value_b_tensor_shape = value_b_shape.as(); - uint32_t b_rank = value_b_tensor_shape.rank(); - assert(a_rank == b_rank); - } - - // check for axis - auto axis_node = node->axis(); - loco::NodeShape axis_shape; - if (!node_shape(axis_node, axis_shape)) - { - // shape inference is not ready for axis_node - INFO(l) << "Fix shape TFConcatV2 axis shape_data not ready"; - return false; - } - - int32_t axis_value = 0; - bool axis_available = false; - { - // check for axis is TFConst - auto tfconst = dynamic_cast(axis_node); - if (tfconst != nullptr) - { - if (valid_scala_value(tfconst)) - { - axis_value = scala_value(tfconst); - axis_available = true; - } - } - } - { - // check for axis is ConstGen - auto constgen = dynamic_cast(axis_node); - if (constgen != nullptr) - { - if (valid_scala_value(constgen)) - { - axis_value = scala_value(constgen); - axis_available = true; - } - } - } - if (!axis_available) - { - // we cannot find a valid axis value - INFO(l) << "Fix shape TFConcatV2 axis_available false"; - return false; - } - - auto concat_data = stdex::make_unique(axis_value); - node->annot(std::move(concat_data)); - - uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value; - - auto shape_data = stdex::make_unique(); - shape_data->rank(a_rank); - - for (uint32_t index = 0; index < a_rank; ++index) - { - if (value_a_tensor_shape.dim(index).known()) - { - uint32_t dim = value_a_tensor_shape.dim(index).value(); - if (index == axis_absolute) - { - uint32_t dim_acc = dim; - for (uint32_t ni = 1; ni < num_values; ++ni) - { - auto value_b = node->values(ni); - loco::NodeShape value_b_shape; - node_shape(value_b, value_b_shape); - assert(value_b_shape.domain() == loco::Domain::Tensor); - auto value_b_tensor_shape = value_b_shape.as(); - assert(value_b_tensor_shape.dim(index).known()); - dim_acc += value_b_tensor_shape.dim(index).value(); - } - dim = dim_acc; - } - shape_data->dim(index) = dim; - } - else - shape_data->dim(index).unset(); - } - node->annot(std::move(shape_data)); - - INFO(l) << "Fix TFConcat shape = " << node->annot(); - - return true; -} +bool fix_shape(moco::tf::TFConcatV2 *node) { return false; } bool fix_shape(moco::tf::TFConst *node) { -- 2.7.4