[exo] Type & shape inference for TFLMean (#8739)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Tue, 5 Nov 2019 06:35:46 +0000 (15:35 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 5 Nov 2019 06:35:46 +0000 (15:35 +0900)
This commit introduces type and shape inference for TFLMean

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
compiler/exo/src/Dialect/Service/TFLTypeInferenceRule.cpp

index dd17ca4..34a8f63 100644 (file)
@@ -360,6 +360,62 @@ public:
     return infer_pool_2d_shape(node);
   }
 
+  loco::NodeShape visit(const locoex::TFLMean *node) final
+  {
+    const loco::DataType S32 = loco::DataType::S32;
+
+    auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+    auto reduction_indices = dynamic_cast<locoex::TFLConst *>(node->reduction_indices());
+
+    { // Exceptions
+      // TODO support non-const case
+      EXO_ASSERT(reduction_indices, "Only support constant reduction_indices");
+      // TODO support other data type
+      EXO_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
+    }
+
+    std::vector<int32_t> reduction_values;
+
+    for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
+    {
+      int32_t axis = reduction_indices->at<S32>(i);
+      if (axis < 0)
+        axis += input_shape.rank();
+      if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
+        EXO_THROW("Invalid reduction axis for MEAN");
+      reduction_values.push_back(axis);
+    }
+
+    loco::TensorShape output_shape;
+
+    if (node->keep_dims())
+    {
+      output_shape.rank(input_shape.rank());
+      for (uint32_t i = 0; i < input_shape.rank(); ++i)
+        output_shape.dim(i) = input_shape.dim(i);
+      for (uint32_t i = 0; i < reduction_values.size(); ++i)
+        output_shape.dim(reduction_values.at(i)) = 1;
+    }
+    else
+    {
+      std::vector<bool> check_reduce(input_shape.rank(), false);
+      for (uint32_t i = 0; i < reduction_values.size(); ++i)
+        check_reduce.at(reduction_values.at(i)) = true;
+
+      uint32_t reduce_cnt = 0;
+      for (uint32_t i = 0; i < check_reduce.size(); ++i)
+        if (check_reduce.at(i))
+          ++reduce_cnt;
+
+      output_shape.rank(input_shape.rank() - reduce_cnt);
+      for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+        if (check_reduce.at(i) == false)
+          output_shape.dim(j++) = i;
+    }
+
+    return loco::NodeShape{output_shape};
+  }
+
   loco::NodeShape visit(const locoex::TFLMul *node) final
   {
     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
index 62970fa..8184520 100644 (file)
@@ -77,6 +77,8 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy
     return loco::dtype_get(node->value());
   }
 
+  loco::DataType visit(const locoex::TFLMean *node) final { return loco::dtype_get(node->input()); }
+
   loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); }
 
   loco::DataType visit(const locoex::TFLRelu *node) final