* [loco] TensorConstantPad shape and type inference
This commit support shape and type inference of TensorConstantPad.
Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* format patch.
return loco::NodeShape{output_feature_shape};
}
+
+ // CASE: TensorConstantPad
+ loco::NodeShape visit(const loco::TensorConstantPad *node) final
+ {
+ auto const tensor_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto padding = node->padding();
+
+ loco::TensorShape out_shape;
+ out_shape.rank(tensor_shape.rank());
+ for (uint32_t axis = 0; axis < out_shape.rank(); ++axis)
+ {
+ out_shape.dim(axis) =
+ tensor_shape.dim(axis).value() + padding->front(axis) + padding->back(axis);
+ }
+
+ return loco::NodeShape{out_shape};
+ }
};
} // namespace
loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::Tanh *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::TensorConstantPad *node)
+ {
+ return loco::dtype_get(node->input());
+ }
loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }