From 53e4c52cab8c6506f5e4298ba4130fa8d4effab8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=82=A8=EA=B6=81=EC=84=9D/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 14 Aug 2019 07:30:41 +0900 Subject: [PATCH] [loco] Introduce services for TensorSoftmax (#6564) This commit will introduce loco services for `TensorSoftmax` Signed-off-by: Seok NamKoong --- compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp | 6 ++++++ compiler/loco/src/Service/TypeInference.cpp | 1 + 2 files changed, 7 insertions(+) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index f5ca943..243693a 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -398,6 +398,12 @@ public: return loco::NodeShape{out_shape}; } + + // CASE: TensorSoftmax + loco::NodeShape visit(const loco::TensorSoftmax *node) final + { + return loco::shape_get(node->input()); + } }; } // namespace diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp index b238b4a..2776fda 100644 --- a/compiler/loco/src/Service/TypeInference.cpp +++ b/compiler/loco/src/Service/TypeInference.cpp @@ -121,6 +121,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitorinput()); } loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); } loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); } + loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); } }; } // namespace -- 2.7.4