[Relay]Improve Shape Func handling for Tuple inputs (#5467)
authorYao Wang <kevinthesunwy@gmail.com>
Mon, 18 May 2020 19:27:32 +0000 (12:27 -0700)
committerGitHub <noreply@github.com>
Mon, 18 May 2020 19:27:32 +0000 (12:27 -0700)
* Improve Shape Func handling for Tuple inputs

* Fix lint

* Improve

* Fix build

src/relay/backend/compile_engine.cc
src/relay/op/memory/memory.cc
tests/python/relay/test_any.py

index 4834fdc..31293a9 100644 (file)
@@ -525,6 +525,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     return fields;
   }
 
+  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+    Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
+    Array<te::Tensor> out;
+    out.push_back(input_shapes[op->index]);
+    return out;
+  }
+
  private:
   /*! \brief String stream for function name */
   std::ostringstream readable_name_stream_;
index 76a3315..e5081ad 100644 (file)
@@ -368,12 +368,23 @@ bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   auto tuple = TupleType(func_type->arg_types);
   auto in_types = FlattenTupleType(tuple);
   auto out_types = FlattenTupleType(func_type->ret_type);
+  Array<Integer> is_input;
+  for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
+    auto const& aty = func_type->arg_types[i];
+    size_t num_types = 1;
+    if (aty.as<TupleTypeNode>()) {
+      num_types = FlattenTupleType(aty).size();
+    }
+    for (size_t j = 0; j < num_types; ++j) {
+      is_input.push_back(shape_func_attrs->is_input[i]);
+    }
+  }
 
   Array<Type> shape_func_ins, shape_func_outs;
   for (size_t i = 0; i < in_types.size(); i++) {
     auto in_type = in_types[i];
 
-    if (shape_func_attrs->is_input[i]) {
+    if (is_input[i]) {
       shape_func_ins.push_back(in_type);
     } else {
       auto shape = RankShape(in_type->shape);
index c9de675..5e5542d 100644 (file)
@@ -680,6 +680,47 @@ def test_recursive_concat_with_wrong_annotation():
     except Exception as e:
         assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
 
+def test_tuple_get_item():
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    indices_or_sections = 2
+    axis = 1
+    data = relay.var('data', shape=data_shape, dtype=dtype)
+    y = relay.split(data, indices_or_sections, axis)
+    y = relay.expr.TupleGetItem(y.astuple(), 0)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 2)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data_np)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(ret.asnumpy().shape))
+
+def test_mixed_input_type():
+    mod = tvm.IRModule()
+    dtype = "float32"
+    static_data_shape = (9, 4)
+    data_shape = (relay.Any(), 4)
+    tensor_type = relay.TensorType(data_shape, dtype)
+    tuple_type = relay.TupleType([tensor_type, tensor_type])
+    data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type]))
+    data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype)
+    data_tuple = relay.expr.TupleWrapper(data0, 2)
+    nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2)
+    y = nested_data_tuple[1] * data_tuple[1] + data1
+    mod["main"] = relay.Function([data0, data1], y)
+    data_np0 = np.random.uniform(size=static_data_shape).astype(dtype)
+    data_np1 = np.random.uniform(size=static_data_shape).astype(dtype)
+    ref_out_shape = (9, 4)
+    for kind in ["vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1)
+        assert result.asnumpy().shape == ref_out_shape, \
+            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(ret.asnumpy().shape))
+
 if __name__ == "__main__":
     test_any_full()
     test_any_broadcast()
@@ -708,3 +749,6 @@ if __name__ == "__main__":
     test_arange_with_dynamic_shape()
     test_recursive_concat()
     test_recursive_concat_with_wrong_annotation()
+    test_tuple_get_item()
+    test_mixed_input_type()
+