From 959cff1c786e0eb33b99007be66de61d2275d7a5 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Sat, 25 Jul 2020 23:16:06 +0800 Subject: [PATCH] [Relay] Fix interpreter for dyanmic shape input of ndarray_size (#6086) --- src/relay/backend/interpreter.cc | 14 ++------------ tests/python/relay/test_any.py | 22 ++++++++++++++++++++-- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 9a75c0a..08c5a7c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -213,11 +213,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: Interpreter(IRModule mod, DLContext context, Target target) - : mod_(mod), - context_(context), - target_(target), - debug_op_(Op::Get("debug")), - shape_of_op_(Op::Get("shape_of")) { + : mod_(mod), context_(context), target_(target), debug_op_(Op::Get("debug")) { engine_ = CompileEngine::Global(); } @@ -481,12 +477,7 @@ class Interpreter : public ExprFunctor, Array out_shapes; auto ret_type = func->body->checked_type(); - bool is_dyn = IsDynamic(func->checked_type()); - if (call_node->op == shape_of_op_) { - // The output shape of shape_of must be static since Relay doesn't support - // dynamic rank tensors. - is_dyn = false; - } + bool is_dyn = IsDynamic(ret_type); if (is_dyn) { CHECK(func->HasNonzeroAttr(attr::kPrimitive)); @@ -722,7 +713,6 @@ class Interpreter : public ExprFunctor, CompileEngine engine_; // Cache ops that need to be frequently used later to reduce lookup overhead. const Op& debug_op_; - const Op& shape_of_op_; }; TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index bf28ee1..0e8a328 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -814,7 +814,7 @@ def test_mixed_input_type(): assert result.asnumpy().shape == ref_out_shape, \ "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) -def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, +def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, layout, static_boxes, static_box_indices_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" @@ -872,6 +872,24 @@ def test_any_mirror_pad(): static_data_shape=(1, 256, 232, 232), ref_out_shape=(1, 256, 234, 234)) +def verify_any_ndarray_size(data_np_shape): + v = relay.var("v", shape=any_dims(len(data_np_shape)), dtype='float32') + n = relay.ndarray_size(v, dtype='int32') + mod = tvm.IRModule() + mod['main'] = relay.Function([v], n) + np_data = np.zeros(data_np_shape, dtype='float32') + ref_res = np.size(np_data) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(np_data) + tvm.testing.assert_allclose(result.asnumpy(), ref_res) + +def test_any_ndarray_size(): + verify_any_ndarray_size((2,)) + verify_any_ndarray_size((2, 2)) + verify_any_ndarray_size((1, 2, 3, 4)) + if __name__ == "__main__": test_any_full() test_any_full_like() @@ -908,4 +926,4 @@ if __name__ == "__main__": test_mixed_input_type() test_any_crop_and_resize() test_any_mirror_pad() - + test_any_ndarray_size() -- 2.7.4