From fbb6dd28a94823d385c37509a55d693bb660cd98 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 8 Aug 2019 11:41:16 +0900 Subject: [PATCH] [loco] Infer the shape of bias-taking nodes (#6332) * [loco] Infer the shape of bias-taking nodes CanonicalShapeInferenceRule is now able to infer the shape of BiasEncode, FeatureBiasAdd, and TensorBiasAdd nodes. Signed-off-by: Jonghyun Park * Fix a typo --- .../src/Service/CanonicalShapeInferenceRule.cpp | 37 ++++++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index f865395..9bef64f 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -180,7 +180,19 @@ public: return loco::NodeShape{output_feature_shape}; } - // TODO Support BiasEncode + // CASE: BiasEncode + loco::NodeShape visit(const loco::BiasEncode *node) final + { + // The input of BiasEncode SHOULD BE a tensor! + assert(loco::shape_get(node->input()).domain() == loco::Domain::Tensor); + auto input_tensor_shape = loco::shape_get(node->input()).as(); + + loco::BiasShape output_bias_shape; + + output_bias_shape.length() = input_tensor_shape.dim(0); + + return loco::NodeShape{output_bias_shape}; + } // CASE: ConstGen loco::NodeShape visit(const loco::ConstGen *node) final @@ -252,7 +264,17 @@ public: } // TODO Support Forward - // TODO Support FeatureBiasAdd + + // CASE: FeatureBiasAdd + loco::NodeShape visit(const loco::FeatureBiasAdd *node) final + { + assert(loco::shape_get(node->value()).domain() == loco::Domain::Feature); + assert(loco::shape_get(node->bias()).domain() == loco::Domain::Bias); + + // Q. What to do when there is a mismatch between value's depth and bias's length? + + return loco::shape_get(node->value()); + } // CASE: FeatureDecode loco::NodeShape visit(const loco::FeatureDecode *node) final @@ -325,7 +347,16 @@ public: // CASE: ReLU6 loco::NodeShape visit(const loco::ReLU6 *node) final { return loco::shape_get(node->input()); } - // TODO Support TensorBiasAdd + // CASE: TensorBiasAdd + loco::NodeShape visit(const loco::TensorBiasAdd *node) final + { + assert(loco::shape_get(node->value()).domain() == loco::Domain::Tensor); + assert(loco::shape_get(node->bias()).domain() == loco::Domain::Bias); + + // Q. What to do when there is a mismatch between value's dim and bias's length? + + return loco::shape_get(node->value()); + } // CASE: TensorConcat loco::NodeShape visit(const loco::TensorConcat *node) -- 2.7.4