return _ir_pass.FuseOps(expr, opt_level)
-def combine_parallel_conv2d(expr):
- """Fold multiple conv2d into one.
+def combine_parallel_conv2d(expr, min_num_branches=3):
+ """Combine multiple conv2d into one.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
+ min_num_branches : int
+        The minimum number of parallel branches required for the
+        transformation to be applied.
+
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression
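+
+    Examples
+    --------
+    .. code-block:: python
+
+        # A minimal usage sketch (mirroring the tests below): run type
+        # inference first, then combine parallel conv2d branches that
+        # share the same input.
+        expr = relay.ir_pass.infer_type(expr)
+        expr = relay.ir_pass.combine_parallel_conv2d(expr, min_num_branches=3)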
"""
- return _ir_pass.CombineParallelConv2D(expr)
+ return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
def alter_op_layout(expr):
}
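+// Layout inference for strided_slice. When alter_op_layout assigns a new
+// input layout (e.g. NCHW -> NCHW4c), the begin/end indices on any axis
+// split by the new layout are divided by the split factor. Unsupported
+// cases (an already-split original layout, non-unit strides, or bounds not
+// divisible by the factor) return Layout::Undef(), forcing a fallback
+// layout_transform to the original layout.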
+Array<Array<Layout> > StridedSliceInferCorrectLayout(
+ const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<Array<IndexExpr>>& old_in_shapes) {
+ CHECK(old_in_layouts.defined());
+ CHECK_EQ(old_in_layouts.size(), 1);
+ CHECK(old_in_shapes.defined());
+ CHECK_EQ(old_in_shapes.size(), 1);
+
+ auto layout = old_in_layouts[0];
+ if (layout.defined() && new_in_layouts.defined()) {
+ CHECK_EQ(new_in_layouts.size(), 1);
+ auto new_layout = new_in_layouts[0];
+ auto shape = old_in_shapes[0];
+
+    // NOTE: the "const" qualifier is discarded so that begin/end can be
+    // rewritten in place for the new layout.
+    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+
+ Array<Integer> new_begin, new_end;
+
+ for (size_t i = 0; i < params->begin.size(); i++) {
+ const LayoutAxis& axis = layout[i];
+ if (!axis.IsPrimal()) {
+        // an original layout that contains split axes is not supported
+ return {{Layout::Undef()}, {Layout::Undef()}};
+ }
+ auto factor = new_layout.FactorOf(axis);
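+      // FactorOf returns -1 when `axis` is not split by the new layout;
+      // the slice indices on this axis can be kept unchanged.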
+ if (factor == -1) {
+ new_begin.push_back(params->begin[i]);
+ new_end.push_back(params->end[i]);
+ } else {
+ if (params->strides.defined() && i < params->strides.size()) {
+ auto stride = params->strides[i];
+          // non-unit strides are not supported
+ if (stride.defined() && stride->value != 1) {
+ return {{Layout::Undef()}, {Layout::Undef()}};
+ }
+ }
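+        // an undefined begin defaults to 0; an undefined end defaults to
+        // the full extent of the axis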
+ int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
+ int64_t end = params->end[i].defined() ? params->end[i]->value :
+ shape[i].as<IntImm>()->value;
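+        // the slice bounds must be divisible by the split factor, otherwise
+        // the slice cannot be expressed in the new layout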
+ if (begin % factor || end % factor) {
+          // fall back to the original layout
+ return {{Layout::Undef()}, {Layout::Undef()}};
+ }
+ new_begin.push_back(tvm::Integer(begin / factor));
+ new_end.push_back(tvm::Integer(end / factor));
+ }
+ }
+ layout = new_layout;
+ params->begin = new_begin;
+ params->end = new_end;
+ }
+ return {{layout}, {layout}};
+}
+
+
// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr MakeStridedSlice(Expr data,
Array<Integer> begin,
.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
.add_type_rel("StridedSlice", StridedSliceRel)
.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
// relay.split
class ParallelConv2DCombiner {
public:
+  explicit ParallelConv2DCombiner(uint64_t min_num_branches)
+      : min_num_branches_(min_num_branches) {}
+
Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
- if (group.size() < 2) continue;
+ if (group.size() < min_num_branches_) {
+ continue;
+ }
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
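+  /*! \brief The minimum number of parallel branches required to combine. */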
+ uint64_t min_num_branches_;
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
}
};
-Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
+/*! \brief Combine parallel conv2d if the number of branches >= min_num_branches */
+Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
+ return ParallelConv2DCombiner(min_num_branches).Combine(expr);
+}
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = CombineParallelConv2D(args[0]);
+ *ret = CombineParallelConv2D(args[0], args[1]);
});
} // namespace relay
assert(alpha_equal(a, b))
+def test_alter_layout_strided_slice():
+ """Test rewriting strided_slice during alter_iop_layout"""
+ def before():
+ x = relay.var("x", shape=(1, 32, 28, 28))
+ weight = relay.var('weight', shape=(32, 32, 3, 3))
+ y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
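+        # slice channels [16, 32) of the conv2d output (NCHW layout)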
+ y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+ y = relay.Function(free_vars(y), y)
+ return y
+
+ @register_alter_op_layout("nn.conv2d", level=109)
+ def alter_conv2d(attrs, inputs, tinfos):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs['data_layout'] = 'NCHW4c'
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ def expected():
+ x = relay.var("x", shape=(1, 32, 28, 28))
+ weight = relay.var("weight")
+ x = relay.layout_transform(x, "NCHW", "NCHW4c")
+ y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
+ data_layout="NCHW4c")
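+        # NCHW4c splits the channel axis by 4, so the slice bounds are
+        # rescaled: begin 16/4 = 4, end 32/4 = 8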
+ y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+ y = relay.layout_transform(y, "NCHW4c", "NCHW")
+ y = relay.Function(free_vars(y), y)
+ return y
+
+ a = before()
+ a = infer_type(a)
+ a = canonicalize_ops(a)
+ a = infer_type(a)
+
+ a = alter_op_layout(a)
+ a = infer_type(a)
+
+ b = expected()
+ b = infer_type(b)
+
+ assert(alpha_equal(a, b))
+
+
if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
+ test_alter_layout_strided_slice()
y_before = before(x, w1, w2, w3, w4)
y = relay.ir_pass.infer_type(y_before)
- y = relay.ir_pass.combine_parallel_conv2d(y)
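+    # lower the threshold from the default of 3, since this test graph has
+    # only two parallel branches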
+ y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
y_expected = relay.ir_pass.infer_type(y_expected)
bias = relay.var("bias", shape=(channels2, 1, 1))
y_before = before(x, w1, w2, scale1, scale2, bias)
y = relay.ir_pass.infer_type(y_before)
- y = relay.ir_pass.combine_parallel_conv2d(y)
+ y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
scale2 = relay.var("scale2", shape=(1,))
y_before = before(x, w1, w2, scale1, scale2)
y = relay.ir_pass.infer_type(y_before)
- y = relay.ir_pass.combine_parallel_conv2d(y)
+ y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
y_expected = relay.ir_pass.infer_type(y_expected)
w = relay.var("w", shape=(out_c, in_c, 1, 1))
y_before = before(x, w, repeat)
y = relay.ir_pass.infer_type(y_before)
- y = relay.ir_pass.combine_parallel_conv2d(y)
+ y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
y = relay.ir_pass.infer_type(y)
y_expected = expected(x, w, out_c, repeat)
y_expected = relay.ir_pass.infer_type(y_expected)