*
* \return A Tensor whose op member is the split operation
*/
-inline Array<Tensor> split(const Tensor& x, Array<Integer> split_indices, int axis,
+inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int axis,
std::string name = "T_split", std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(x->shape.size());
}
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
- auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
- std::vector<int> begin_ids;
+ auto src_axis_size = x->shape[axis];
+ std::vector<PrimExpr> begin_ids;
begin_ids.push_back(0);
- for (Integer idx : split_indices) {
- int val = static_cast<int>(idx->value);
- CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted";
- begin_ids.push_back(val);
+ for (auto idx : split_indices) {
+ auto idx_node = idx.as<IntImmNode>();
+ auto back_node = begin_ids.back().as<IntImmNode>();
+ if (idx_node && back_node) {
+ CHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted";
+ }
+ begin_ids.push_back(idx);
}
Array<Array<PrimExpr> > out_shapes;
for (size_t i = 0; i < begin_ids.size(); ++i) {
- int out_axis_size;
+ PrimExpr out_axis_size;
if (i == begin_ids.size() - 1) {
out_axis_size = src_axis_size - begin_ids[i];
} else {
}
CHECK_LT(axis, x->shape.size()) << "axis out of bounds";
- auto src_axis_size = static_cast<int>(GetConstInt(x->shape[axis]));
+ auto src_axis_size = x->shape[axis];
CHECK_GT(num_sections, 0) << "Slice count must be > 0";
- CHECK_EQ(src_axis_size % num_sections, 0)
- << "num_sections must be an integer factor of the size of axis " << axis << " ("
- << src_axis_size << ")";
- Array<Integer> split_indices;
- auto seg_size = src_axis_size / num_sections;
+ if (auto node = src_axis_size.as<IntImmNode>()) {
+ CHECK_EQ(node->value % num_sections, 0)
+ << "num_sections must be an integer factor of the size of axis " << axis << " ("
+ << node->value << ")";
+ }
+
+ Array<PrimExpr> split_indices;
+ auto seg_size = indexdiv(src_axis_size, num_sections);
for (int i = 0; i < num_sections; ++i) {
// region at index 0 is added by split()
if (i != 0) {
if len(indices_or_sections) == 1:
for i in const_range(data_shape.shape[0]):
if i == axis:
+ assert data_shape[axis] % indices_or_sections[0] == 0, \
+ "num_sections must be an integer factor of the size of axis"
out[i] = ceil_div(data_shape[axis], indices_or_sections[0])
else:
out[i] = data_shape[i]
"""
if isinstance(attrs.indices_or_sections, (int, tvm.tir.IntImm)):
indices_or_sections = get_const_int(attrs.indices_or_sections)
+ assert indices_or_sections > 0, "Slice count must be > 0"
else:
- indices_or_sections = get_const_tuple(attrs.indices_or_sections)
+ indices_or_sections = list(get_const_tuple(attrs.indices_or_sections))
+ assert sorted(indices_or_sections)[0] > 0 and \
+ indices_or_sections == sorted(indices_or_sections), \
+ "split_indices must be sorted"
axis = get_const_int(attrs.axis)
int64_t num_sections = sections->value;
return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)};
} else {
- auto indices = Downcast<Array<Integer>>(param->indices_or_sections);
+ Array<PrimExpr> indices;
+ for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) {
+ indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
+ }
return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)};
}
}
def test_any_split():
verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
+ verify_any_split((relay.Any(), relay.Any()), 2, 1, (9, 4), [(9, 2), (9, 2)])
verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
+ verify_any_split((relay.Any(), relay.Any()), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
def test_any_batch_flatten():
mod = tvm.IRModule()