[Relay][Dyn] Dynamic full operator (#6260)
author Lily Orth-Smith <lorthsmith@octoml.ai>
Thu, 13 Aug 2020 16:08:34 +0000 (09:08 -0700)
committer GitHub <noreply@github.com>
Thu, 13 Aug 2020 16:08:34 +0000 (09:08 -0700)
* moved full from other branch

* fixed some typos

* fix lint

* add final newline

* fix int64 test

13 files changed:
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/dyn/_tensor.py
python/tvm/relay/op/dyn/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/dyn/tensor/transform.cc
src/relay/op/make_op.h
src/relay/op/tensor/transform.cc
src/relay/transforms/dynamic_to_static.cc
src/relay/transforms/pattern_util.h
tests/python/relay/dyn/test_dynamic_op_level3.py
tests/python/relay/dyn/test_dynamic_op_level6.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_pass_dynamic_to_static.py

index 28336cf..eccc2c3 100644 (file)
@@ -201,11 +201,11 @@ def elemwise_shape_func(attrs, inputs, _):
     return [topi.math.identity(inputs[0])]
 
 register_shape_func("cast", False, elemwise_shape_func)
-register_shape_func("zeros", True, no_data_full_shape_func)
+register_shape_func("zeros", False, full_shape_func)
 register_shape_func("zeros_like", False, elemwise_shape_func)
-register_shape_func("ones", True, no_data_full_shape_func)
+register_shape_func("ones", False, full_shape_func)
 register_shape_func("ones_like", False, elemwise_shape_func)
-register_shape_func("full", True, full_shape_func)
+register_shape_func("full", False, full_shape_func)
 register_shape_func("full_like", False, elemwise_shape_func)
 register_shape_func("broadcast_to", True, full_shape_func)
 
index 371e4ad..cd53641 100644 (file)
@@ -44,3 +44,4 @@ register_pattern("dyn.zeros", OpPattern.ELEMWISE)
 register_shape_func("dyn.broadcast_to", True, full_shape_func)
 register_shape_func("dyn.ones", True, no_data_full_shape_func)
 register_shape_func("dyn.zeros", True, no_data_full_shape_func)
+register_shape_func("dyn.full", True, full_shape_func)
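The True flag marks full_shape_func as data dependent for dyn.full: the output shape must be read from the shape tensor's runtime values rather than from attributes. A sketch of what such a hybrid-script shape function looks like, loosely modeled on full_shape_func (the name _full_shape_func here is illustrative):

    from tvm.te.hybrid import script

    @script
    def _full_shape_func(shape):
        # Copy the runtime values of the 1-D shape tensor into the output
        # shape, one dimension per element.
        out = output_tensor((shape.shape[0],), "int64")
        for i in const_range(shape.shape[0]):
            out[i] = int64(shape[i])
        return out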
index 3a80f5a..46778fe 100644 (file)
@@ -26,7 +26,7 @@ _reg.register_broadcast_schedule("dyn.broadcast_to")
 _reg.register_injective_schedule("dyn.reshape")
 _reg.register_broadcast_schedule("dyn.tile")
 _reg.register_injective_schedule("dyn.one_hot")
-
+_reg.register_injective_schedule("dyn.full")
 
 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
index 5e5b867..b46b156 100644 (file)
@@ -376,8 +376,12 @@ def full(fill_value, shape=(), dtype=""):
     result : relay.Expr
         The resulting tensor.
     """
+    if isinstance(shape, Expr):
+        return _dyn_make.full(fill_value, shape, dtype)
+    if isinstance(shape, int):
+        shape = [shape]
     if isinstance(shape, (list, tuple)):
-        shape = const(list(shape), "int32")
+        shape = list(shape)
     return _make.full(fill_value, shape, dtype)
 
 
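A relay.Expr shape now dispatches to the new dyn.full op, while ints, lists, and tuples keep lowering to the static full op with the shape stored in its attributes. A minimal sketch of the two paths, assuming the Relay Python API at this commit:

    import tvm
    from tvm import relay

    # Static path: the tuple is normalized to a list and ends up in
    # InitOpAttrs, so "full" takes only the fill value as input.
    static_out = relay.full(relay.const(4.0), (2, 3), "float32")

    # Dynamic path: an Expr shape becomes the second input of "dyn.full".
    shape = relay.var("shape", relay.TensorType((2,), "int64"))
    dyn_out = relay.full(relay.const(4.0), shape, "float32")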
index d2d6d69..06e1c57 100644 (file)
@@ -28,6 +28,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/topi/broadcast.h>
+#include <tvm/topi/elemwise.h>
 #include <tvm/topi/transform.h>
 
 #include <utility>
@@ -374,6 +375,61 @@ RELAY_REGISTER_OP("dyn.one_hot")
     .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
     .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
 
+bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+             const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* fill_value = types[0].as<TensorTypeNode>();
+  const auto* fill_shape = types[1].as<TensorTypeNode>();
+  if (fill_value == nullptr || fill_shape == nullptr) {
+    return false;
+  }
+
+  DataType out_dtype = param->dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = fill_value->dtype;
+  }
+
+  CHECK_EQ(fill_value->shape.size(), 0)
+      << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
+
+  const IntImmNode* rank = fill_shape->shape[0].as<IntImmNode>();
+  CHECK(rank) << "Parameter shape must have static rank";
+
+  std::vector<IndexExpr> oshape;
+  for (int i = 0; i < rank->value; ++i) {
+    oshape.push_back(Any());
+  }
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
+  auto attrs = make_object<InitOpAttrs>();
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("dyn.full");
+  return Call(op, {fill_value, shape}, Attrs(attrs), {});
+}
+
+Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_type) {
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
+}
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.full").set_body_typed(MakeFull);
+
+RELAY_REGISTER_OP("dyn.full")
+    .describe(R"code(Fill array with scalar value.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<InitOpAttrs>()
+    .set_num_inputs(2)
+    .add_argument("fill_value", "double", "The value to fill.")
+    .add_argument("shape", "Tensor", "Target shape.")
+    .set_support_level(3)
+    .add_type_rel("DynamicFull", FullRel)
+    .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
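FullRel above only requires the rank of the shape tensor to be a static constant; every output dimension is Any. A small sketch of the inferred type, under the same API assumptions as above:

    import tvm
    from tvm import relay

    shape = relay.var("shape", relay.TensorType((2,), "int64"))
    out = relay.full(relay.const(1.0), shape, "float32")
    mod = tvm.IRModule.from_expr(relay.Function([shape], out))
    mod = relay.transform.InferType()(mod)
    # The function body is typed Tensor[(?, ?), float32]: rank 2 comes from
    # the shape tensor's static extent, and each dimension is Any.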
index 8ca2203..1e17bbe 100644 (file)
@@ -48,7 +48,7 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype);
 
 Expr MakeExpandDims(Expr data, int axis, int num_newaxis);
 
-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype);
+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype);
 
 Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout);
 
index 79a8da4..be7cab1 100644 (file)
@@ -994,10 +994,9 @@ TVM_REGISTER_NODE_TYPE(InitOpAttrs);
 
 bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 2);
   const InitOpAttrs* param = attrs.as<InitOpAttrs>();
   const auto* fill_value = types[0].as<TensorTypeNode>();
-  const auto* fill_shape = types[1].as<TensorTypeNode>();
   if (fill_value == nullptr) {
     return false;
   }
@@ -1010,40 +1009,29 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK_EQ(fill_value->shape.size(), 0)
       << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
 
-  const IntImmNode* shape_shape = fill_shape->shape[0].as<IntImmNode>();
-  CHECK(shape_shape) << "Parameter shape must have static shape";
-
   std::vector<IndexExpr> oshape;
-  if (param->shape) {
-    const Array<Integer>& cshape_array = param->shape.value();
-    for (size_t i = 0; i < cshape_array.size(); ++i) {
-      oshape.push_back(cshape_array[i]);
-    }
-  } else {
-    for (int i = 0; i < shape_shape->value; ++i) {
-      oshape.push_back(Any());
-    }
+  const Array<Integer>& cshape_array = param->shape.value();
+  for (size_t i = 0; i < cshape_array.size(); ++i) {
+    oshape.push_back(cshape_array[i]);
   }
-  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  reporter->Assign(types[1], TensorType(oshape, out_dtype));
   return true;
 }
 
+Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) {
+  auto attrs = make_object<InitOpAttrs>();
+  attrs->dtype = std::move(dtype);
+  attrs->shape = std::move(shape);
+  static const Op& op = Op::Get("full");
+  return Call(op, {fill_value}, Attrs(attrs), {});
+}
+
 Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                               const Type& out_type) {
   const auto* out_ttype = out_type.as<TensorTypeNode>();
   return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
 }
 
-Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) {
-  auto attrs = make_object<InitOpAttrs>();
-  if (const auto* cshape = shape.as<ConstantNode>()) {
-    attrs->shape = ToVector(cshape->data);
-  }
-  attrs->dtype = std::move(dtype);
-  static const Op& op = Op::Get("full");
-  return Call(op, {fill_value, shape}, Attrs(attrs), {});
-}
-
 TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);
 
 RELAY_REGISTER_OP("full")
@@ -1051,9 +1039,8 @@ RELAY_REGISTER_OP("full")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .add_argument("fill_value", "double", "The value to fill.")
-    .add_argument("shape", "Tensor", "Target shape.")
     .set_support_level(3)
     .add_type_rel("Full", FullRel)
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
index d0a6b07..0ccc4c3 100644 (file)
@@ -114,6 +114,16 @@ class DynamicToStaticMutator : public MixedModeMutator {
            }
            return Expr(nullptr);
          }},
+        {Op::Get("dyn.full"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
     };
   }
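With this handler, a dyn.full whose shape argument folds to a constant is rewritten back into the static full, with the shape moved into InitOpAttrs. A sketch of the round trip (same API assumptions as above):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", relay.scalar_type("float32"))
    # A constant shape is still an Expr, so relay.full emits dyn.full here.
    z = relay.full(x, relay.const(np.array([2, 3]), "int64"), "float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], z))
    mod = relay.transform.InferType()(mod)
    # DynamicToStatic detects the constant shape and substitutes static full.
    mod = relay.transform.DynamicToStatic()(mod)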
 
index ee65503..a7063f5 100644 (file)
@@ -596,7 +596,7 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
 }
 
 static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
-  return MakeFull(fill_value, CheckConstantShape(shape), dtype);
+  return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype);
 }
 
 static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
index ff98c48..91e9cc7 100644 (file)
@@ -103,19 +103,27 @@ def test_dyn_zeros_ones():
 
             func = relay.Function([dyn_shape], y)
             ref_res = ref(shape, dtype)
-            for target, ctx in ctx_list():
-                if (target != 'cuda'): #skip cuda because no dynamic support for GPU 
-                    for kind in ["vm", "debug"]:
-                        mod = tvm.ir.IRModule.from_expr(func)
-                        intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                        op_res = intrp.evaluate(func)(np.array(shape).astype('int64'))
-                        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+            verify_func(func, [np.array(shape).astype('int64')], ref_res.astype('int64'))
+    verify_zeros_ones((1, 3), 'int64')
+    verify_zeros_ones((8, 9, 1, 2), 'float32')
 
+def test_dyn_full():
+    def verify_full(fill_value, src_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        rank = len(src_shape)
+        dyn_src_shape = relay.var("dyn_src_shape", relay.ty.TensorType((rank,), 'int64'))
+        z = relay.full(x, dyn_src_shape, dtype)
+        func = relay.Function([x, dyn_src_shape], z)
+        ref_res = np.full(src_shape, fill_value).astype(dtype)
 
-    verify_zeros_ones((124, 50), 'float64')
+        verify_func(func, [np.array(fill_value).astype(dtype), np.array(src_shape).astype('int64')], ref_res)
+    verify_full(4, (1, 3, 4, 4), 'int32')
+    verify_full(4, (1, 3, 4, 4), 'int64')
+    verify_full(4.0, (2, 50), 'float32')
 
 if __name__ == "__main__":
     test_dyn_reshape()
     test_dyn_shape_reshape()
     test_dyn_tile()
     test_dyn_zeros_ones()
+    test_dyn_full()
index 60a1433..ddfab55 100644 (file)
@@ -73,4 +73,4 @@ def test_dynamic_topk():
 
 
 if __name__ == "__main__":
-    test_topk()
+    test_dynamic_topk()
index db45fcb..745130d 100644 (file)
@@ -460,6 +460,7 @@ def test_full():
                 op_res = intrp.evaluate(func)(np.array(fill_value, dtype))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
     verify_full(4, (1, 3, 4, 4), "int32")
+    # verify_full(4, (1, 3, 4, 4), "int64")  # Does not pass: the Python int fill value is not upcast to int64.
     verify_full(4.0, (1, 4), "float32")
 
 
index 5342f2d..c61f169 100644 (file)
@@ -301,6 +301,25 @@ def test_dynamic_to_static_one_hot():
     _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
     _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
+def test_dynamic_to_static_full():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, 'int64'))
+        z = relay.full(x, relay.shape_of(y), dtype)
+
+        func = run_infer_type(relay.Function([x, y], z))
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.checked_type == relay.TensorType(fill_shape, dtype)
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype('int64')
+        verify_func(func2, [fill_value, y_data], ref_res)
+    
+    verify_full(4, (1, 2, 3, 4), 'int32')
+    verify_full(4.0, (1, 2, 8, 10), 'float32')
 
 if __name__ == "__main__":
     test_dynamic_to_static_reshape()
@@ -312,3 +331,4 @@ if __name__ == "__main__":
     test_dynamic_to_static_zeros_ones()
     test_dynamic_to_static_resize()
     test_dynamic_to_static_one_hot()
+    test_dynamic_to_static_full()