From c0c2e3f270f08c620314b073fe64b18090e0cace Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 6 Sep 2019 16:27:46 +0900 Subject: [PATCH] [loco] Do NOT invoke shape_get directly (#7231) This commit replaces all the loco::shape_get calls in ForwardShapeInferenceAlgorithm implemention with "node_shape" method to facilitate refactoring. Signed-off-by: Jonghyun Park --- .../src/Service/CanonicalShapeInferenceRule.cpp | 63 +++++++++++----------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 4e6a466..4d84d1d 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -169,6 +169,9 @@ loco::NodeShape eltwise_binary_node_shape(const loco::Node *node) */ class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor { +private: + loco::NodeShape node_shape(const loco::Node *node) const { return loco::shape_get(node); } + public: // CASE: AvgPool2D loco::NodeShape visit(const loco::AvgPool2D *node) final @@ -179,7 +182,7 @@ public: infer_plane_shape.window(node->window()); infer_plane_shape.stride(node->stride()); - auto input_feature_shape = loco::shape_get(node->ifm()).as(); + auto input_feature_shape = node_shape(node->ifm()).as(); auto input_plane_shape = make_plane_shape(input_feature_shape); auto output_plane_shape = infer_plane_shape(input_plane_shape); auto output_feature_shape = input_feature_shape; // AvgPool2D does not change count/depth @@ -194,8 +197,8 @@ public: loco::NodeShape visit(const loco::BiasEncode *node) final { // The input of BiasEncode SHOULD BE a tensor! - assert(loco::shape_get(node->input()).domain() == loco::Domain::Tensor); - auto input_tensor_shape = loco::shape_get(node->input()).as(); + assert(node_shape(node->input()).domain() == loco::Domain::Tensor); + auto input_tensor_shape = node_shape(node->input()).as(); loco::BiasShape output_bias_shape; @@ -221,7 +224,7 @@ public: // CASE: Conv2D loco::NodeShape visit(const loco::Conv2D *node) final { - auto filter_shape = loco::shape_get(node->ker()).as(); + auto filter_shape = node_shape(node->ker()).as(); auto filter_window = window_of(filter_shape); PlaneInference infer_plane_shape; @@ -230,7 +233,7 @@ public: infer_plane_shape.window(&filter_window); infer_plane_shape.stride(node->stride()); - auto input_feature_shape = loco::shape_get(node->ifm()).as(); + auto input_feature_shape = node_shape(node->ifm()).as(); auto input_plane_shape = make_plane_shape(input_feature_shape); auto output_plane_shape = infer_plane_shape(input_plane_shape); @@ -249,7 +252,7 @@ public: // CASE: DepthwiseConv2D loco::NodeShape visit(const loco::DepthwiseConv2D *node) final { - auto depthwise_filter_shape = loco::shape_get(node->ker()).as(); + auto depthwise_filter_shape = node_shape(node->ker()).as(); auto dpethwise_filter_window = window_of(depthwise_filter_shape); PlaneInference infer_plane_shape; @@ -258,7 +261,7 @@ public: infer_plane_shape.window(&dpethwise_filter_window); infer_plane_shape.stride(node->stride()); - auto input_feature_shape = loco::shape_get(node->ifm()).as(); + auto input_feature_shape = node_shape(node->ifm()).as(); auto input_plane_shape = make_plane_shape(input_feature_shape); auto output_plane_shape = infer_plane_shape(input_plane_shape); @@ -278,7 +281,7 @@ public: // CASE: DepthwiseFilterEncode loco::NodeShape visit(const loco::DepthwiseFilterEncode *node) final { - auto input_tensor_shape = loco::shape_get(node->input()).as(); + auto input_tensor_shape = node_shape(node->input()).as(); return loco::NodeShape{node->encoder()->shape(input_tensor_shape)}; } @@ -301,10 +304,7 @@ public: } // CASE: EltwiseSqrt - loco::NodeShape visit(const loco::EltwiseSqrt *node) final - { - return loco::shape_get(node->input()); - } + loco::NodeShape visit(const loco::EltwiseSqrt *node) final { return node_shape(node->input()); } // CASE: EltwiseSub loco::NodeShape visit(const loco::EltwiseSub *node) final @@ -313,37 +313,37 @@ public: } // CASE: Forward - loco::NodeShape visit(const loco::Forward *node) final { return loco::shape_get(node->input()); } + loco::NodeShape visit(const loco::Forward *node) final { return node_shape(node->input()); } // CASE: FeatureBiasAdd loco::NodeShape visit(const loco::FeatureBiasAdd *node) final { - assert(loco::shape_get(node->value()).domain() == loco::Domain::Feature); - assert(loco::shape_get(node->bias()).domain() == loco::Domain::Bias); + assert(node_shape(node->value()).domain() == loco::Domain::Feature); + assert(node_shape(node->bias()).domain() == loco::Domain::Bias); // Q. What to do when there is a mismatch between value's depth and bias's length? - return loco::shape_get(node->value()); + return node_shape(node->value()); } // CASE: FeatureDecode loco::NodeShape visit(const loco::FeatureDecode *node) final { - auto input_node_shape = loco::shape_get(node->input()); + auto input_node_shape = node_shape(node->input()); return loco::NodeShape{node->decoder()->shape(input_node_shape.as())}; } // CASE: FeatureEncode loco::NodeShape visit(const loco::FeatureEncode *node) final { - auto input_node_shape = loco::shape_get(node->input()); + auto input_node_shape = node_shape(node->input()); return loco::NodeShape{node->encoder()->shape(input_node_shape.as())}; } // CASE: FilterEncode loco::NodeShape visit(const loco::FilterEncode *node) final { - auto input_tensor_shape = loco::shape_get(node->input()).as(); + auto input_tensor_shape = node_shape(node->input()).as(); return loco::NodeShape{node->encoder()->shape(input_tensor_shape)}; } @@ -370,7 +370,7 @@ public: infer_plane_shape.window(node->window()); infer_plane_shape.stride(node->stride()); - auto input_feature_shape = loco::shape_get(node->ifm()).as(); + auto input_feature_shape = node_shape(node->ifm()).as(); auto input_plane_shape = make_plane_shape(input_feature_shape); auto output_plane_shape = infer_plane_shape(input_plane_shape); auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth @@ -385,7 +385,7 @@ public: loco::NodeShape visit(const loco::Push *node) final { assert(loco::shape_known(node->from())); - return loco::shape_get(node->from()); + return node_shape(node->from()); } // CASE: Pull @@ -404,30 +404,30 @@ public: } // CASE: ReLU - loco::NodeShape visit(const loco::ReLU *node) final { return loco::shape_get(node->input()); } + loco::NodeShape visit(const loco::ReLU *node) final { return node_shape(node->input()); } // CASE: ReLU6 - loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); } + loco::NodeShape visit(const loco::ReLU6 *node) final { return node_shape(node->input()); } // CASE: Tanh - loco::NodeShape visit(const loco::Tanh *node) final { return loco::shape_get(node->input()); } + loco::NodeShape visit(const loco::Tanh *node) final { return node_shape(node->input()); } // CASE: TensorBiasAdd loco::NodeShape visit(const loco::TensorBiasAdd *node) final { - assert(loco::shape_get(node->value()).domain() == loco::Domain::Tensor); - assert(loco::shape_get(node->bias()).domain() == loco::Domain::Bias); + assert(node_shape(node->value()).domain() == loco::Domain::Tensor); + assert(node_shape(node->bias()).domain() == loco::Domain::Bias); // Q. What to do when there is a mismatch between value's dim and bias's length? - return loco::shape_get(node->value()); + return node_shape(node->value()); } // CASE: TensorConcat loco::NodeShape visit(const loco::TensorConcat *node) { - auto const lhs_shape = loco::shape_get(node->lhs()).as(); - auto const rhs_shape = loco::shape_get(node->rhs()).as(); + auto const lhs_shape = node_shape(node->lhs()).as(); + auto const rhs_shape = node_shape(node->rhs()).as(); assert(lhs_shape.rank() == rhs_shape.rank()); uint32_t const out_rank = lhs_shape.rank(); @@ -453,10 +453,7 @@ public: } // CASE: TensorSoftmax - loco::NodeShape visit(const loco::TensorSoftmax *node) final - { - return loco::shape_get(node->input()); - } + loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); } }; } // namespace -- 2.7.4