[exo] TransposeConv type & shape inference (#8258)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Thu, 17 Oct 2019 05:28:38 +0000 (14:28 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 17 Oct 2019 05:28:38 +0000 (14:28 +0900)
This commit introduces type and shape inferences of TFLTransposeConv

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp

index e4b986f..06dc3e9 100644 (file)
@@ -396,6 +396,24 @@ public:
     else
       EXO_THROW("perm of TFLTranspose should be either ConstGen or TFLConst");
   }
+
+  loco::NodeShape visit(const locoex::TFLTransposeConv *node) final
+  {
+    // TransposeConv's output shape is written in its 'inputSizes' argument
+    auto input_sizes_const = dynamic_cast<locoex::TFLConst *>(node->inputSizes());
+    EXO_ASSERT(input_sizes_const, "Only support when TFLTransposeConv's inputSizes is TFLConst")
+    EXO_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
+    EXO_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
+               "Only support rank 1 with 4 entries")
+
+    loco::TensorShape shape;
+
+    shape.rank(4);
+    for (uint32_t axis = 0; axis < 4; ++axis)
+      shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
+
+    return loco::NodeShape{shape};
+  }
 };
 
 } // namespace
index 2c3bf64..d9bf695 100644 (file)
@@ -100,6 +100,11 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy
   {
     return loco::dtype_get(node->a());
   }
+
+  loco::DataType visit(const locoex::TFLTransposeConv *node) final
+  {
+    return loco::dtype_get(node->outBackprop());
+  }
 };
 
 } // namespace