From 16b2a4b66dc3ff7c20da59fee9cc047a1593cf64 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith
Date: Thu, 13 Aug 2020 09:08:34 -0700
Subject: [PATCH] [Relay][Dyn] Dynamic full operator (#6260)

* moved full from other branch
* fixed some typos
* fix lint
* add final newline
* fix int64 test
---
 python/tvm/relay/op/_tensor.py                     |  6 +--
 python/tvm/relay/op/dyn/_tensor.py                 |  1 +
 python/tvm/relay/op/dyn/_transform.py              |  2 +-
 python/tvm/relay/op/transform.py                   |  6 ++-
 src/relay/op/dyn/tensor/transform.cc               | 56 +++++++++++++++++++++
 src/relay/op/make_op.h                             |  2 +-
 src/relay/op/tensor/transform.cc                   | 41 ++++++-----------
 src/relay/transforms/dynamic_to_static.cc          | 10 ++++
 src/relay/transforms/pattern_util.h                |  2 +-
 tests/python/relay/dyn/test_dynamic_op_level3.py   | 24 ++++++----
 tests/python/relay/dyn/test_dynamic_op_level6.py   |  2 +-
 tests/python/relay/test_op_level3.py               |  1 +
 tests/python/relay/test_pass_dynamic_to_static.py  | 20 ++++++++
 13 files changed, 130 insertions(+), 43 deletions(-)

diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 28336cf..eccc2c3 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -201,11 +201,11 @@ def elemwise_shape_func(attrs, inputs, _):
     return [topi.math.identity(inputs[0])]

 register_shape_func("cast", False, elemwise_shape_func)
-register_shape_func("zeros", True, no_data_full_shape_func)
+register_shape_func("zeros", False, full_shape_func)
 register_shape_func("zeros_like", False, elemwise_shape_func)
-register_shape_func("ones", True, no_data_full_shape_func)
+register_shape_func("ones", False, full_shape_func)
 register_shape_func("ones_like", False, elemwise_shape_func)
-register_shape_func("full", True, full_shape_func)
+register_shape_func("full", False, full_shape_func)
 register_shape_func("full_like", False, elemwise_shape_func)
 register_shape_func("broadcast_to", True, full_shape_func)

diff --git a/python/tvm/relay/op/dyn/_tensor.py b/python/tvm/relay/op/dyn/_tensor.py
index 371e4ad..cd53641 100644
--- a/python/tvm/relay/op/dyn/_tensor.py
+++ b/python/tvm/relay/op/dyn/_tensor.py
@@ -44,3 +44,4 @@ register_pattern("dyn.zeros", OpPattern.ELEMWISE)
 register_shape_func("dyn.broadcast_to", True, full_shape_func)
 register_shape_func("dyn.ones", True, no_data_full_shape_func)
 register_shape_func("dyn.zeros", True, no_data_full_shape_func)
+register_shape_func("dyn.full", True, full_shape_func)
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index 3a80f5a..46778fe 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -26,7 +26,7 @@ _reg.register_broadcast_schedule("dyn.broadcast_to")
 _reg.register_injective_schedule("dyn.reshape")
 _reg.register_broadcast_schedule("dyn.tile")
 _reg.register_injective_schedule("dyn.one_hot")
-
+_reg.register_injective_schedule("dyn.full")

 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 5e5b867..b46b156 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -376,8 +376,12 @@ def full(fill_value, shape=(), dtype=""):
     result : relay.Expr
         The resulting tensor.
""" + if isinstance(shape, Expr): + return _dyn_make.full(fill_value, shape, dtype) + if isinstance(shape, int): + shape = [shape] if isinstance(shape, (list, tuple)): - shape = const(list(shape), "int32") + shape = list(shape) return _make.full(fill_value, shape, dtype) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index d2d6d69..06e1c57 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -374,6 +375,61 @@ RELAY_REGISTER_OP("dyn.one_hot") .set_attr("FTVMCompute", OneHotCompute) .set_attr("TOpPattern", kOutEWiseFusable); +bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const InitOpAttrs* param = attrs.as(); + const auto* fill_value = types[0].as(); + const auto* fill_shape = types[1].as(); + if (fill_value == nullptr) { + return false; + } + + DataType out_dtype = param->dtype; + if (out_dtype.bits() == 0) { + out_dtype = fill_value->dtype; + } + + CHECK_EQ(fill_value->shape.size(), 0) + << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "."; + + const IntImmNode* rank = fill_shape->shape[0].as(); + CHECK(rank) << "Parameter shape must have static rank"; + + std::vector oshape; + for (int i = 0; i < rank->value; ++i) { + oshape.push_back(Any()); + } + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("dyn.full"); + return Call(op, {fill_value, shape}, Attrs(attrs), {}); +} +Array FullCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* out_ttype = out_type.as(); + return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; +} +TVM_REGISTER_GLOBAL("relay.op.dyn._make.full").set_body_typed(MakeFull); + +RELAY_REGISTER_OP("dyn.full") + .describe(R"code(Fill array with scalar value. 
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<InitOpAttrs>()
+    .set_num_inputs(2)
+    .add_argument("fill_value", "double", "The value to fill.")
+    .add_argument("shape", "Tensor", "Target shape.")
+    .set_support_level(3)
+    .add_type_rel("DynamicFull", FullRel)
+    .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 8ca2203..1e17bbe 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -48,7 +48,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);

 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);

 Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 79a8da4..be7cab1 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -994,10 +994,9 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);

 bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 2);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
   const auto* fill_value = types[0].as<TensorTypeNode>();
-  const auto* fill_shape = types[1].as<TensorTypeNode>();
   if (fill_value == nullptr) {
     return false;
   }
@@ -1010,40 +1009,29 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK_EQ(fill_value->shape.size(), 0)
       << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";

-  const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
-  CHECK(shape_shape) << "Parameter shape must have static shape";
-
   std::vector<IndexExpr> oshape;
-  if (param->shape) {
-    const Array<Integer>& cshape_array = param->shape.value();
-    for (size_t i = 0; i < cshape_array.size(); ++i) {
-      oshape.push_back(cshape_array[i]);
-    }
-  } else {
-    for (int i = 0; i < shape_shape->value; ++i) {
-      oshape.push_back(Any());
-    }
+  const Array<Integer>& cshape_array = param->shape.value();
+  for (size_t i = 0; i < cshape_array.size(); ++i) {
+    oshape.push_back(cshape_array[i]);
   }
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  reporter->Assign(types[1], TensorType(oshape, out_dtype));
   return true;
 }

+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) {
+  auto attrs = make_object<InitOpAttrs>();
+  attrs->dtype = std::move(dtype);
+  attrs->shape = std::move(shape);
+  static const Op& op = Op::Get("full");
+  return Call(op, {fill_value}, Attrs(attrs), {});
+}
+
 Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                               const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
 }

-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
-  auto attrs = make_object<InitOpAttrs>();
-  if (const auto* cshape = shape.as<ConstantNode>()) {
-    attrs->shape = ToVector(cshape->data);
-  }
-  attrs->dtype = std::move(dtype);
-  static const Op& op = Op::Get("full");
-  return Call(op, {fill_value, shape}, Attrs(attrs), {});
-}
-
 TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);

 RELAY_REGISTER_OP("full")
@@ -1051,9 +1039,8 @@ RELAY_REGISTER_OP("full")

 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .add_argument("fill_value", "double", "The value to fill.")
-    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(3)
     .add_type_rel("Full", FullRel)
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index d0a6b07..0ccc4c3 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -114,6 +114,16 @@ class DynamicToStaticMutator : public MixedModeMutator {
           }
           return Expr(nullptr);
         }},
+        {Op::Get("dyn.full"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
     };
   }

diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index ee65503..a7063f5 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -596,7 +596,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
 }

 static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
-  return MakeFull(fill_value, CheckConstantShape(shape), dtype);
+  return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype);
 }

 static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py
index ff98c48..91e9cc7 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level3.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -103,19 +103,27 @@ def test_dyn_zeros_ones():
             func = relay.Function([dyn_shape], y)
             ref_res = ref(shape, dtype)
-            for target, ctx in ctx_list():
-                if (target != 'cuda'): #skip cuda because no dynamic support for GPU
-                    for kind in ["vm", "debug"]:
-                        mod = tvm.ir.IRModule.from_expr(func)
-                        intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                        op_res = intrp.evaluate(func)(np.array(shape).astype('int64'))
-                        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+            verify_func(func, [np.array(shape).astype('int64')], ref_res.astype('int64'))
+    verify_zeros_ones((1, 3), 'int64')
+    verify_zeros_ones((8, 9, 1, 2), 'float32')

+def test_dyn_full():
+    def verify_full(fill_value, src_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        rank = len(src_shape)
+        dyn_src_shape = relay.var("dyn_scr_shape", relay.ty.TensorType((rank,), 'int64'))
+        z = relay.full(x, dyn_src_shape, dtype)
+        func = relay.Function([x, dyn_src_shape], z)
+        ref_res = np.full(src_shape, fill_value).astype(dtype)

-    verify_zeros_ones((124, 50), 'float64')
+        verify_func(func, [np.array(fill_value).astype(dtype), np.array(src_shape).astype('int64')], ref_res)
+    verify_full(4, (1, 3, 4, 4), 'int32')
+    verify_full(4, (1, 3, 4, 4), 'int64')
+    verify_full(4.0, (2, 50), 'float32')


 if __name__ == "__main__":
     test_dyn_reshape()
     test_dyn_shape_reshape()
     test_dyn_tile()
     test_dyn_zeros_ones()
+    test_dyn_full()
diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py
index 60a1433..ddfab55 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level6.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level6.py
@@ -73,4 +73,4 @@ def test_dynamic_topk():


 if __name__ == "__main__":
-    test_topk()
+    test_dynamic_topk()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index db45fcb..745130d 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -460,6 +460,7 @@ def test_full():
                 op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
     verify_full(4, (1, 3, 4, 4), "int32")
+    #verify_full(4, (1, 3, 4, 4), "int64") # This does not pass, python int32 is not upcast to int64, not sure how to fix it.
     verify_full(4.0, (1, 4), "float32")


diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 5342f2d..c61f169 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -301,6 +301,25 @@ def test_dynamic_to_static_one_hot():
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")

+def test_dynamic_to_static_full():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, 'int64'))
+        z = relay.full(x, relay.shape_of(y), dtype)
+
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.checked_type == relay.TensorType(fill_shape, dtype)
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
+        verify_func(func2, [fill_value, y_data], ref_res)
+
+    verify_full(4, (1, 2, 3, 4), 'int32')
+    verify_full(4.0, (1, 2, 8, 10), 'float32')

 if __name__ == "__main__":
     test_dynamic_to_static_reshape()
@@ -312,3 +331,4 @@ if __name__ == "__main__":
     test_dynamic_to_static_zeros_ones()
     test_dynamic_to_static_resize()
     test_dynamic_to_static_one_hot()
+    test_dynamic_to_static_full()
--
2.7.4
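
Usage sketch (not part of the patch): the change above makes relay.full dispatch to the new dyn.full op whenever the shape argument is a relay.Expr, with DynamicToStatic folding it back to the static full once the shape becomes a constant. The snippet below mirrors the added tests (test_dyn_full / verify_func); the executor kind, target, and concrete shape values are illustrative assumptions, not part of the change.

import numpy as np
import tvm
from tvm import relay

# Scalar fill value plus a rank-4 shape tensor supplied at runtime.
fill_value = relay.var("fill_value", relay.scalar_type("float32"))
dyn_shape = relay.var("dyn_shape", relay.TensorType((4,), "int64"))

# Because `dyn_shape` is a relay.Expr, relay.full routes to dyn.full.
z = relay.full(fill_value, dyn_shape, "float32")
func = relay.Function([fill_value, dyn_shape], z)
mod = tvm.ir.IRModule.from_expr(func)

# Run on CPU with the VM executor, as the dynamic-op tests do.
intrp = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
op_res = intrp.evaluate()(np.array(4.0, dtype="float32"),
                          np.array([1, 3, 4, 4], dtype="int64"))
print(op_res.asnumpy().shape)  # (1, 3, 4, 4)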