[exo-tflite] shape & type inference for TFLConv2D (#7835)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 30 Sep 2019 23:21:46 +0000 (08:21 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 30 Sep 2019 23:21:46 +0000 (08:21 +0900)
This adds shape inference and type inference for TFLConv2D.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp

index 3fd0298..7a892db 100644 (file)
@@ -226,7 +226,51 @@ public:
     return loco::NodeShape{shape};
   }
 
-  // TODO TFLConv2D
+  loco::NodeShape visit(const locoex::TFLConv2D *node) final
+  {
+    auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
+    auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
+
+    assert(ifm_shape.rank() == 4);
+    assert(ker_shape.rank() == 4);
+    assert(ifm_shape.dim(3) == ker_shape.dim(3));
+
+    uint32_t input_height = ifm_shape.dim(1).value();
+    uint32_t input_width = ifm_shape.dim(2).value();
+    uint32_t stride_height = node->stride()->h();
+    uint32_t stride_width = node->stride()->w();
+    uint32_t ker_height = ker_shape.dim(1).value();
+    uint32_t ker_width = ker_shape.dim(2).value();
+    uint32_t dilation_height = 1;
+    uint32_t dilation_width = 1;
+    uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
+    uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
+
+    uint32_t output_height;
+    uint32_t output_width;
+
+    if (node->padding() == locoex::Padding::VALID)
+    {
+      output_height = (input_height + stride_height - effective_ker_height) / stride_height;
+      output_width = (input_width + stride_width - effective_ker_width) / stride_width;
+    }
+    else if (node->padding() == locoex::Padding::SAME)
+    {
+      output_height = (input_height + stride_height - 1) / stride_height;
+      output_width = (input_width + stride_width - 1) / stride_width;
+    }
+    else
+      EXO_ASSERT(false, "Wrong padding type");
+
+    loco::TensorShape ofm_shape;
+    ofm_shape.rank(4);
+    ofm_shape.dim(0) = ifm_shape.dim(0);
+    ofm_shape.dim(1) = output_height;
+    ofm_shape.dim(2) = output_width;
+    ofm_shape.dim(3) = ker_shape.dim(0);
+
+    return loco::NodeShape{ofm_shape};
+  }
 
   // TODO TFLDepthwiseConv2D
 
index 95e5686..11198ed 100644 (file)
@@ -51,7 +51,10 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy
 
   loco::DataType visit(const locoex::TFLConst *node) final { return node->dtype(); }
 
-  // TODO TFLConv2D
+  loco::DataType visit(const locoex::TFLConv2D *node) final
+  {
+    return loco::dtype_get(node->input());
+  }
 
   // TODO TFLDepthwiseConv2D