From 6ed242eadebee61e7da2ad468519661acfe31b84 Mon Sep 17 00:00:00 2001
From: lixiaoquan
Date: Wed, 2 Sep 2020 14:04:37 +0800
Subject: [PATCH] [Relay] Enhance relay.split(), allow splitted dim to be dynamic (#6289)

* [Relay] Enhance relay.split(), allow splitted dim to be dynamic

* Add assert in shape function

* Fix CI
---
 include/tvm/topi/transform.h      | 34 ++++++++++++++++++++--------------
 python/tvm/relay/op/_transform.py |  8 +++++++-
 src/relay/op/tensor/transform.cc  |  5 ++++-
 tests/python/relay/test_any.py    |  2 ++
 4 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index eb69fc5..b09b035 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -481,26 +481,29 @@ inline Tensor stack(const Array<Tensor>& inputs, int axis = 0, std::string name
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
                            std::string name = "T_split", std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(x->shape.size());
   }
   CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
 
-  auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
-  std::vector<int> begin_ids;
+  auto src_axis_size = x->shape[axis];
+  std::vector<PrimExpr> begin_ids;
   begin_ids.push_back(0);
 
-  for (Integer idx : split_indices) {
-    int val = static_cast<int>(idx->value);
-    CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
-    begin_ids.push_back(val);
+  for (auto idx : split_indices) {
+    auto idx_node = idx.as<IntImmNode>();
+    auto back_node = begin_ids.back().as<IntImmNode>();
+    if (idx_node && back_node) {
+      CHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
+    }
+    begin_ids.push_back(idx);
   }
 
   Array<Array<PrimExpr> > out_shapes;
   for (size_t i = 0; i < begin_ids.size(); ++i) {
-    int out_axis_size;
+    PrimExpr out_axis_size;
     if (i == begin_ids.size() - 1) {
       out_axis_size = src_axis_size - begin_ids[i];
     } else {
@@ -668,15 +671,18 @@ inline Array<Tensor> split_sections(const Tensor& x, int num_sections, int axis,
   }
   CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
 
-  auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
+  auto src_axis_size = x->shape[axis];
   CHECK_GT(num_sections, 0) << "Slice count must be > 0";
-  CHECK_EQ(src_axis_size % num_sections, 0)
-      << "num_sections must be an integer factor of the size of axis " << axis << " ("
-      << src_axis_size << ")";
 
-  Array<Integer> split_indices;
-  auto seg_size = src_axis_size / num_sections;
+  if (auto node = src_axis_size.as<IntImmNode>()) {
+    CHECK_EQ(node->value % num_sections, 0)
+        << "num_sections must be an integer factor of the size of axis " << axis << " ("
+        << node->value << ")";
+  }
+
+  Array<PrimExpr> split_indices;
+  auto seg_size = indexdiv(src_axis_size, num_sections);
   for (int i = 0; i < num_sections; ++i) {
     // region at index 0 is added by split()
     if (i != 0) {
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index b562233..937c36e 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -634,6 +634,8 @@ def _split_shape_func(data_shape, index, indices_or_sections, axis):
     if len(indices_or_sections) == 1:
         for i in const_range(data_shape.shape[0]):
             if i == axis:
+                assert data_shape[axis] % indices_or_sections[0] == 0, \
+                    "num_sections must be an integer factor of the size of axis"
                 out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
             else:
                 out[i] = data_shape[i]
@@ -658,8 +660,12 @@ def split_shape_func(attrs, inputs, _):
     """
     if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
         indices_or_sections = get_const_int(attrs.indices_or_sections)
+        assert indices_or_sections > 0, "Slice count must be > 0"
     else:
-        indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+        indices_or_sections = list(get_const_tuple(attrs.indices_or_sections))
+        assert sorted(indices_or_sections)[0] > 0 and \
+            indices_or_sections == sorted(indices_or_sections), \
+            "split_indices must be sorted"
 
     axis = get_const_int(attrs.axis)
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 1e223b7..9f67b7f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2397,7 +2397,10 @@ Array<te::Tensor> SplitCompute(const Attrs& attrs, const Array<te::Tensor>& inpu
     int64_t num_sections = sections->value;
     return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)};
   } else {
-    auto indices = Downcast<Array<Integer> >(param->indices_or_sections);
+    Array<PrimExpr> indices;
+    for (auto i : Downcast<Array<ObjectRef> >(param->indices_or_sections)) {
+      indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
+    }
     return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)};
   }
 }
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index a84020d..7f37336 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -487,7 +487,9 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r
 
 def test_any_split():
     verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+    verify_any_split((relay.Any(), relay.Any()), 2, 1, (9, 4), [(9, 2), (9, 2)])
     verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+    verify_any_split((relay.Any(), relay.Any()), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
 
 def test_any_batch_flatten():
     mod = tvm.IRModule()
-- 
2.7.4