From 123a40774c0a05311b37e55857fcea6042b1bb3f Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 12 Dec 2019 14:33:57 -0800 Subject: [PATCH] [Hybrid][Fix] Fix hybrid script to support array of tensors (#4494) * [Fix][Hybrid] Fix hybrid script to support array of tensors * add test case * clean up * trigger ci --- python/tvm/hybrid/parser.py | 10 +++++++-- tests/python/unittest/test_hybrid_script.py | 32 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 44db999..816a0e1 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -647,9 +647,15 @@ def source_to_op(src, args, symbols, closure_vars): parser = parse_python(src, args, symbols, closure_vars) input_tensors = [] + def get_input_tensors(arg): + if isinstance(arg, Tensor): + input_tensors.append(arg) + elif isinstance(arg, Array): + for i in arg: + get_input_tensors(i) + for i in args: - if isinstance(i, Tensor): - input_tensors.append(i) + get_input_tensors(i) op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, parser.outputs, parser.parsed_body) res = [op.output(i) for i in range(len(parser.outputs))] diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 3e93719..1f101a1 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -789,6 +789,37 @@ def test_capture(): func, ins, outs = run_and_check(add_something, [a]) run_and_check(func, ins, outs=outs) +def test_array_inputs(): + @script + def sum_array(inputs): + out = output_tensor((10,), inputs[0].dtype) + n = len(inputs) + for i in range(10): + for j in const_range(n): + out[i] += inputs[j][i] + return out + n = 5 + inputs = [] + for i in range(n): + inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32')) + + out = sum_array(tvm.convert(inputs)) + assert len(out.op.inputs) == n + + sch = tvm.create_schedule(out.op) + mod = tvm.build(sch, inputs + [out], target='llvm') + assert mod + + input_nd = [] + out_ref = numpy.zeros((10,)) + for _ in range(n): + arr = numpy.random.uniform(size=(10,)).astype('float32') + input_nd.append(tvm.nd.array(arr)) + out_ref += arr + out_nd = tvm.nd.array(numpy.zeros((10,), 'float32')) + mod(*input_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_ref) + if __name__ == "__main__": test_outer_product() test_fanout() @@ -807,5 +838,6 @@ if __name__ == "__main__": test_const_range() test_schedule() test_capture() + test_array_inputs() # TODO: # test_inplace() -- 2.7.4