[loco] TensorConstantPad shape and type inference (#8043)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Mon, 14 Oct 2019 08:04:34 +0000 (17:04 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 14 Oct 2019 08:04:34 +0000 (17:04 +0900)
* [loco] TensorConstantPad shape and type inference

This commit support shape and type inference of TensorConstantPad.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* format patch.

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/TypeInference.cpp

index e555de6..947ee04 100644 (file)
@@ -649,6 +649,23 @@ public:
 
     return loco::NodeShape{output_feature_shape};
   }
+
+  // CASE: TensorConstantPad
+  loco::NodeShape visit(const loco::TensorConstantPad *node) final
+  {
+    auto const tensor_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+    auto padding = node->padding();
+
+    loco::TensorShape out_shape;
+    out_shape.rank(tensor_shape.rank());
+    for (uint32_t axis = 0; axis < out_shape.rank(); ++axis)
+    {
+      out_shape.dim(axis) =
+          tensor_shape.dim(axis).value() + padding->front(axis) + padding->back(axis);
+    }
+
+    return loco::NodeShape{out_shape};
+  }
 };
 
 } // namespace
index 1dc963f..0bc4f6b 100644 (file)
@@ -148,6 +148,10 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::Tanh *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
+  loco::DataType visit(const loco::TensorConstantPad *node)
+  {
+    return loco::dtype_get(node->input());
+  }
   loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
   loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }