[moco-tf] Type inference for TFSoftmax (#6700)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 20 Aug 2019 04:01:29 +0000 (13:01 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 20 Aug 2019 04:01:29 +0000 (13:01 +0900)
This will add type inference for TFSoftmax node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp

index d8e823e..ca07fa2 100644 (file)
@@ -52,6 +52,7 @@ struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitor<loco::DataTyp
 
   loco::DataType visit(const TFShape *node) { return node->dtype(); }
 
+  loco::DataType visit(const TFSoftmax *node) { return dtype_get(node->logits()); }
   loco::DataType visit(const TFSqrt *node) { return dtype_get(node->x()); }
   loco::DataType visit(const TFSquaredDifference *node) { return dtype_get(node->x()); }
   loco::DataType visit(const TFSqueeze *node) { return dtype_get(node->input()); }