[loco] Type and Shape inference for FilterDecode (#7804)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Fri, 27 Sep 2019 07:50:52 +0000 (16:50 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 27 Sep 2019 07:50:52 +0000 (16:50 +0900)
This adds Type and Shape inference for FilterDecode.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/TypeInference.cpp

index 8d37c22..e555de6 100644 (file)
@@ -418,6 +418,13 @@ public:
     return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
   }
 
+  // CASE: FilterDecode
+  loco::NodeShape visit(const loco::FilterDecode *node) final
+  {
+    auto input_filter_shape = node_shape(node->input()).as<loco::FilterShape>();
+    return loco::NodeShape{node->decoder()->shape(input_filter_shape)};
+  }
+
   // CASE: FilterEncode
   loco::NodeShape visit(const loco::FilterEncode *node) final
   {
index f3bb998..1dc963f 100644 (file)
@@ -135,6 +135,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::FeatureBiasAdd *node) { return loco::dtype_get(node->value()); }
   loco::DataType visit(const loco::FeatureDecode *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::FeatureEncode *node) { return loco::dtype_get(node->input()); }
+  loco::DataType visit(const loco::FilterDecode *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); }