From 9a978db7d190aa2abeaf05ca54c1e654df4dfd61 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 5 Nov 2019 15:35:46 +0900 Subject: [PATCH] [exo] Type & shape inference for TFLMean (#8739) This commit introduces type and shape inference for TFLMean Signed-off-by: Cheongyo Bahk --- .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 56 ++++++++++++++++++++++ .../src/Dialect/Service/TFLTypeInferenceRule.cpp | 2 + 2 files changed, 58 insertions(+) diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp index dd17ca4..34a8f63 100644 --- a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -360,6 +360,62 @@ public: return infer_pool_2d_shape(node); } + loco::NodeShape visit(const locoex::TFLMean *node) final + { + const loco::DataType S32 = loco::DataType::S32; + + auto input_shape = loco::shape_get(node->input()).as(); + auto reduction_indices = dynamic_cast(node->reduction_indices()); + + { // Exceptions + // TODO support non-const case + EXO_ASSERT(reduction_indices, "Only support constant reduction_indices"); + // TODO support other data type + EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int 32"); + } + + std::vector reduction_values; + + for (uint32_t i = 0; i < reduction_indices->size(); ++i) + { + int32_t axis = reduction_indices->at(i); + if (axis < 0) + axis += input_shape.rank(); + if (not(0 <= axis and axis < static_cast(input_shape.rank()))) + EXO_THROW("Invalid reduction axis for MEAN"); + reduction_values.push_back(axis); + } + + loco::TensorShape output_shape; + + if (node->keep_dims()) + { + output_shape.rank(input_shape.rank()); + for (uint32_t i = 0; i < input_shape.rank(); ++i) + output_shape.dim(i) = input_shape.dim(i); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + output_shape.dim(reduction_values.at(i)) = 1; + } + else + { + std::vector check_reduce(input_shape.rank(), false); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + check_reduce.at(reduction_values.at(i)) = true; + + uint32_t reduce_cnt = 0; + for (uint32_t i = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i)) + ++reduce_cnt; + + output_shape.rank(input_shape.rank() - reduce_cnt); + for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i) == false) + output_shape.dim(j++) = i; + } + + return loco::NodeShape{output_shape}; + } + loco::NodeShape visit(const locoex::TFLMul *node) final { auto x_shape = loco::shape_get(node->x()).as(); diff --git a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp index 62970fa..8184520 100644 --- a/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp +++ b/compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -77,6 +77,8 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitorvalue()); } + loco::DataType visit(const locoex::TFLMean *node) final { return loco::dtype_get(node->input()); } + loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); } loco::DataType visit(const locoex::TFLRelu *node) final -- 2.7.4