From ae618b991c27f677b5c3b0f5f09a9ed61c25cb39 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 24 Oct 2019 13:26:58 +0900 Subject: [PATCH] [exo] TFLReshape inference for shape & type (#8430) * [exo] TFLReshape inference for shape & type This commit implements shape and type inference for TFLReshape Signed-off-by: Cheongyo Bahk * Fix typo --- .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 50 +++++++++++++++++++++- .../src/Dialect/Service/TFLTypeInferenceRule.cpp | 5 ++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp index dbefb04..4b39007 100644 --- a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -384,7 +384,55 @@ public: return loco::NodeShape{input_shape}; } - // TODO TFLReshape + /** + * @note TFLReshape has new shape info in two places: 2nd input and attribute. + * This shape inference forces both to exist, and match each other. + * When this condition satisfied, it return the inferred shape + * + * TODO Change this policy when not appropriate + */ + loco::NodeShape visit(const locoex::TFLReshape *node) final + { + const loco::DataType S32 = loco::DataType::S32; + + loco::TensorShape shape_by_input; + { + EXO_ASSERT(node->shape(), "2nd input shape() should not be nullptr"); + + // Only support node's shape() is TFLConst with S32 + // TODO support other node with other types + auto const_shape_node = dynamic_cast(node->shape()); + EXO_ASSERT(const_shape_node, "Only support TFLConst for shape of TFLReshape"); + EXO_ASSERT(const_shape_node->dtype() == S32, "Only support int32 TFLConst"); + + if (const_shape_node->rank() != 1) + EXO_THROW("Only support rank 1 TFLConst"); + + shape_by_input.rank(const_shape_node->dim(0).value()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + EXO_ASSERT(const_shape_node->at(axis) > 0, "Dimension should be > 0") + shape_by_input.dim(axis) = const_shape_node->at(axis); + } + } + + loco::TensorShape shape_by_attr; + { + shape_by_attr.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis) + { + EXO_ASSERT(node->newShape()->dim(axis) > 0, "Dimension should be > 0") + shape_by_attr.dim(axis) = node->newShape()->dim(axis); + } + } + + EXO_ASSERT(shape_by_input == shape_by_attr, + "Warning: Two new shape information mismatched for TFLReshape"); + + return loco::NodeShape{shape_by_input}; + } // TODO TFLSoftmax diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp index 4b527f2..5f068dc 100644 --- a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp +++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -89,7 +89,10 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitorfeatures()); } - // TODO TFLReshape + loco::DataType visit(const locoex::TFLReshape *node) final + { + return loco::dtype_get(node->tensor()); + } // TODO TFLSoftmax -- 2.7.4