From 2bafe97434b3dd803631c7e9029c50e5995104c1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 1 Oct 2019 08:21:46 +0900 Subject: [PATCH] [exo-tflite] shape & type inference for TFLConv2D (#7835) This adds shape inference and type inference for TFLConv2D. Signed-off-by: Hyun Sik Yoon --- .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 46 +++++++++++++++++++++- .../src/Dialect/Service/TFLTypeInferenceRule.cpp | 5 ++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index 3fd0298..7a892db 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -226,7 +226,51 @@ public: return loco::NodeShape{shape}; } - // TODO TFLConv2D + loco::NodeShape visit(const locoex::TFLConv2D *node) final + { + auto ifm_shape = loco::shape_get(node->input()).as(); // in NHWC + auto ker_shape = loco::shape_get(node->filter()).as(); // in OHWI + + assert(ifm_shape.rank() == 4); + assert(ker_shape.rank() == 4); + assert(ifm_shape.dim(3) == ker_shape.dim(3)); + + uint32_t input_height = ifm_shape.dim(1).value(); + uint32_t input_width = ifm_shape.dim(2).value(); + uint32_t stride_height = node->stride()->h(); + uint32_t stride_width = node->stride()->w(); + uint32_t ker_height = ker_shape.dim(1).value(); + uint32_t ker_width = ker_shape.dim(2).value(); + uint32_t dilation_height = 1; + uint32_t dilation_width = 1; + uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1; + uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1; + + uint32_t output_height; + uint32_t output_width; + + if (node->padding() == locoex::Padding::VALID) + { + output_height = (input_height + stride_height - effective_ker_height) / stride_height; + output_width = (input_width + stride_width - effective_ker_width) / stride_width; + } + else if (node->padding() == locoex::Padding::SAME) + { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } + else + EXO_ASSERT(false, "Wrong padding type"); + + loco::TensorShape ofm_shape; + ofm_shape.rank(4); + ofm_shape.dim(0) = ifm_shape.dim(0); + ofm_shape.dim(1) = output_height; + ofm_shape.dim(2) = output_width; + ofm_shape.dim(3) = ker_shape.dim(0); + + return loco::NodeShape{ofm_shape}; + } // TODO TFLDepthwiseConv2D diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp index 95e5686..11198ed 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -51,7 +51,10 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitordtype(); } - // TODO TFLConv2D + loco::DataType visit(const locoex::TFLConv2D *node) final + { + return loco::dtype_get(node->input()); + } // TODO TFLDepthwiseConv2D -- 2.7.4