[TF] Fix some shape mismatches between TF and Relay (#6166)
author    lixiaoquan <radioheads@163.com>
          Wed, 29 Jul 2020 17:34:32 +0000 (01:34 +0800)
committer GitHub <noreply@github.com>
          Wed, 29 Jul 2020 17:34:32 +0000 (10:34 -0700)
Make ndarray_size output a rank-0 (scalar) tensor
Make gather_nd output a scalar when the indices consume every data dimension
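
For context, TensorFlow already yields rank-0 tensors in both cases; a minimal TF sketch (assuming TF 2.x eager execution, not part of this patch) of the shapes Relay now matches:

    import numpy as np
    import tensorflow as tf

    # tf.size returns a rank-0 (scalar) tensor, not shape (1,)
    print(tf.size(np.zeros((2, 3))).shape)     # ()

    # gather_nd is also scalar when the indices address one full element
    params = tf.constant([[1, 2], [3, 4]])
    print(tf.gather_nd(params, [1, 0]).shape)  # (), value 3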

src/relay/op/tensor/transform.cc
src/relay/op/tensor/unary.cc
src/relay/transforms/fold_constant.cc
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_pass_fold_constant.py
topi/include/topi/transform.h
topi/tests/python/test_topi_transform.py

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 7ebca66..99a1f59 100644
@@ -2740,9 +2740,6 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   Array<IndexExpr> oshape;
   for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
   for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
-  if (oshape.size() == 0) {
-    oshape.push_back(tir::make_const(DataType::Int(32), 1));
-  }
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
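
With the padding to {1} removed, type inference reports a rank-0 tensor whenever the indices consume every data dimension. A minimal Relay sketch (hypothetical shapes; note Relay's gather_nd keeps the index depth on axis 0, unlike TF):

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(2, 2), dtype="float32")
    indices = relay.var("indices", shape=(2,), dtype="int32")  # depth-2 index on axis 0
    func = relay.Function([data, indices], relay.gather_nd(data, indices))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    print(mod["main"].ret_type)  # Tensor[(), float32], previously Tensor[(1,), float32]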
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index fc61661..5809798 100644
@@ -462,7 +462,7 @@ bool NdarraySizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   CHECK(tt != nullptr);
   const auto* param = attrs.as<NdarraySizeAttrs>();
   CHECK(param != nullptr);
-  reporter->Assign(types[1], TensorType({1}, param->dtype));
+  reporter->Assign(types[1], TensorType({}, param->dtype));
   return true;
 }
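
ndarray_size is likewise typed as rank-0 now. A quick sketch to confirm the inferred type (same hypothetical setup as above):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3), dtype="float32")
    func = relay.Function([x], relay.ndarray_size(x, dtype="int32"))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    print(mod["main"].ret_type)  # Tensor[(), int32], previously Tensor[(1,), int32]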
 
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 3f5ecaa..b077a8a 100644
@@ -288,7 +288,7 @@ class ConstantFolder : public ExprMutator {
     ctx.device_id = 0;
     runtime::NDArray value;
     DLDataType cdtype = DataType::Int(32);
-    value = runtime::NDArray::Empty({1}, cdtype, ctx);
+    value = runtime::NDArray::Empty({}, cdtype, ctx);
     int32_t* data = static_cast<int32_t*>(value->data);
     if (ishape.size() == 0) {
       *data = 0;
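
The constant folder evaluates ndarray_size from the input's static shape at compile time, so it must allocate a matching rank-0 NDArray. A sketch of the end-to-end effect (hypothetical shapes):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(3, 4), dtype="float32")
    func = relay.Function([x], relay.ndarray_size(x, dtype="int32"))
    mod = relay.transform.FoldConstant()(tvm.IRModule.from_expr(func))
    print(mod["main"].body)  # a rank-0 constant 12, no longer the 1-element array [12]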
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 4a4a2cd..e6f1687 100644
@@ -73,7 +73,7 @@ tf_dtypes = {
 
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
-        return [o.asnumpy().tolist()]
+        return [o.asnumpy()]
     elif isinstance(o, tvm.runtime.container.ADT):
         result = []
         for f in o:
@@ -211,6 +211,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
+                if not isinstance(tf_output[i], np.ndarray):
+                    assert len(tvm_output[i].shape) == 0
                 tvm.testing.assert_allclose(
                     tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
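
Session.run returns a plain numpy scalar, not an ndarray, for rank-0 outputs, which is what the new branch detects; a small numpy illustration (not part of the patch):

    import numpy as np

    out = np.float32(6.0)               # what session.run yields for a 0-d tensor
    print(isinstance(out, np.ndarray))  # False -> the TVM output must be rank-0 too
    print(np.asarray(out).shape)        # ()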
 
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index e985268..e0c813d 100644
@@ -175,7 +175,7 @@ def test_fold_ndarray_size():
     def expected(dtype):
         x = relay.var("x", shape=c_shape, dtype="float32")
         y = relay.var("y", shape=c_shape, dtype="float32")
-        z = relay.const([np.size(np.zeros(c_shape))], dtype=dtype)
+        z = relay.const(np.size(np.zeros(c_shape)), dtype=dtype)
         func = relay.Function([x, y], z)
         return func
 
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index b5fc02a..0b339d2 100644
@@ -1126,9 +1126,6 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
   for (size_t i = indices_dim0; i < ndim_d; ++i) {
     out_shape.push_back(data->shape[i]);
   }
-  if (out_shape.size() == 0) {
-    out_shape.push_back(make_const(DataType::Int(32), 1));
-  }
   return compute(
       out_shape,
       [&](const Array<Var>& out_index) {
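
The same fix at the TOPI level: gather_nd's output shape may now be empty. A small te sketch (assuming the top-level topi package this tree still uses):

    import tvm
    from tvm import te
    import topi

    data = te.placeholder((2, 2), name="data", dtype="float32")
    indices = te.placeholder((2,), name="indices", dtype="int32")
    out = topi.gather_nd(data, indices)
    print(out.shape)  # [] -- rank 0 after this change, previously [1]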
@@ -1401,7 +1398,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
                            const std::string& name = "ndarray_size",
                            const std::string& tag = kInjective) {
   int ndim = static_cast<int>(src->shape.size());
-  Array<PrimExpr> out_ndarray_size = {1};
+  Array<PrimExpr> out_ndarray_size = {};
   return compute(
       out_ndarray_size,
       [&](const Array<Var>& indices) {
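
Correspondingly, topi.ndarray_size computes over an empty output shape; a sketch under the same assumptions:

    import tvm
    from tvm import te
    import topi

    A = te.placeholder((2, 3), name="A", dtype="float32")
    B = topi.ndarray_size(A, dtype="int32")
    print(B.shape)  # [] -- rank 0, previously [1]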
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index b0aee6a..ee7f114 100644
@@ -1029,7 +1029,7 @@ def test_ndarray_size():
             print("Skip because %s is not enabled" % device)
             return
         tvm_input = tvm.nd.array(input, ctx=ctx)
-        tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
+        tvm_output = tvm.nd.empty((), ctx=ctx, dtype=B.dtype)
         print("Running on target: %s" % device)
         with tvm.target.create(device):
             s = topi.testing.get_injective_schedule(device)(B)
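
For reference, a compact end-to-end run of the updated kernel on CPU (a sketch assuming the llvm target is enabled):

    import numpy as np
    import tvm
    from tvm import te
    import topi
    import topi.testing

    A = te.placeholder((2, 3), name="A", dtype="float32")
    B = topi.ndarray_size(A, dtype="int32")
    with tvm.target.create("llvm"):
        s = topi.testing.get_injective_schedule("llvm")(B)
    f = tvm.build(s, [A, B], "llvm")
    a = tvm.nd.array(np.zeros((2, 3), dtype="float32"))
    b = tvm.nd.empty((), dtype="int32")  # rank-0 output buffer
    f(a, b)
    assert b.asnumpy() == 6  # scalar result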