[Relay] Enhance relay.split(), allow splitted dim to be dynamic (#6289)
authorlixiaoquan <radioheads@163.com>
Wed, 2 Sep 2020 06:04:37 +0000 (14:04 +0800)
committerGitHub <noreply@github.com>
Wed, 2 Sep 2020 06:04:37 +0000 (23:04 -0700)
* [Relay] Enhance relay.split(), allow splitted dim to be dynamic

* Add assert in shape function

* Fix CI

include/tvm/topi/transform.h
python/tvm/relay/op/_transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_any.py

index eb69fc5..b09b035 100644 (file)
@@ -481,26 +481,29 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
                            std::string name = "T_split", std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
   CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
 
-  auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
-  std::vector<int> begin_ids;
+  auto src_axis_size = x->shape[axis];
+  std::vector<PrimExpr> begin_ids;
   begin_ids.push_back(0);
 
-  for (Integer idx : split_indices) {
-    int val = static_cast<int>(idx->value);
-    CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
-    begin_ids.push_back(val);
+  for (auto idx : split_indices) {
+    auto idx_node = idx.as<IntImmNode>();
+    auto back_node = begin_ids.back().as<IntImmNode>();
+    if (idx_node && back_node) {
+      CHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
+    }
+    begin_ids.push_back(idx);
   }
 
   Array<Array<PrimExpr> > out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
-    int out_axis_size;
+    PrimExpr out_axis_size;
     if (i == begin_ids.size() - 1) {
       out_axis_size = src_axis_size - begin_ids[i];
     } else {
@@ -668,15 +671,18 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
   }
   CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
 
-  auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
+  auto src_axis_size = x->shape[axis];
 
   CHECK_GT(num_sections, 0) << "Slice count must be > 0";
-  CHECK_EQ(src_axis_size % num_sections, 0)
-      << "num_sections must be an integer factor of the size of axis " << axis << " ("
-      << src_axis_size << ")";
 
-  Array<Integer> split_indices;
-  auto seg_size = src_axis_size / num_sections;
+  if (auto node = src_axis_size.as<IntImmNode>()) {
+    CHECK_EQ(node->value % num_sections, 0)
+        << "num_sections must be an integer factor of the size of axis " << axis << " ("
+        << node->value << ")";
+  }
+
+  Array<PrimExpr> split_indices;
+  auto seg_size = indexdiv(src_axis_size, num_sections);
   for (int i = 0; i < num_sections; ++i) {
     // region at index 0 is added by split()
     if (i != 0) {
index b562233..937c36e 100644 (file)
@@ -634,6 +634,8 @@ def _split_shape_func(data_shape, index, indices_or_sections, axis):
     if len(indices_or_sections) == 1:
         for i in const_range(data_shape.shape[0]):
             if i == axis:
+                assert data_shape[axis] % indices_or_sections[0] == 0, \
+                    "num_sections must be an integer factor of the size of axis"
                 out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
             else:
                 out[i] = data_shape[i]
@@ -658,8 +660,12 @@ def split_shape_func(attrs, inputs, _):
     """
     if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
         indices_or_sections = get_const_int(attrs.indices_or_sections)
+        assert indices_or_sections > 0, "Slice count must be > 0"
     else:
-        indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+        indices_or_sections = list(get_const_tuple(attrs.indices_or_sections))
+        assert sorted(indices_or_sections)[0] > 0 and \
+               indices_or_sections == sorted(indices_or_sections), \
+            "split_indices must be sorted"
 
     axis = get_const_int(attrs.axis)
 
index 1e223b7..9f67b7f 100644 (file)
@@ -2397,7 +2397,10 @@ Array<te::Tensor> SplitCompute(const Attrs& attrs, const Array<te::Tensor>& inpu
     int64_t num_sections = sections->value;
     return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)};
   } else {
-    auto indices = Downcast<Array<Integer>>(param->indices_or_sections);
+    Array<PrimExpr> indices;
+    for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) {
+      indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
+    }
     return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)};
   }
 }
index a84020d..7f37336 100644 (file)
@@ -487,7 +487,9 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r
 
 def test_any_split():
     verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+    verify_any_split((relay.Any(), relay.Any()), 2, 1, (9, 4), [(9, 2), (9, 2)])
     verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+    verify_any_split((relay.Any(), relay.Any()), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
 
 def test_any_batch_flatten():
     mod = tvm.IRModule()