[Relay][Op] Add type check to dense (#4724)
authorWei Chen <ipondering.weic@gmail.com>
Thu, 16 Jan 2020 17:01:24 +0000 (09:01 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 Jan 2020 17:01:24 +0000 (09:01 -0800)
src/relay/op/nn/nn.h
tests/python/relay/test_op_level1.py

index 1b27dea..7389909 100644 (file)
@@ -57,6 +57,11 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   } else {
     if (weight == nullptr) return false;
     Array<tvm::PrimExpr> wshape = weight->shape;
+    CHECK(static_cast<int>(weight->shape.size()) == 2);
+    CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1],
+                             weight->shape[1]))
+        << "DenseRel: input dimension doesn't match,"
+        << " data shape=" << data->shape << ", weight shape=" << weight->shape;
     oshape.Set((oshape.size() - 1), wshape[0]);
   }
 
index adfcbb1..f73826e 100644 (file)
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import pytest
 import tvm
 import scipy
 from tvm import relay
@@ -336,6 +337,16 @@ def test_batch_norm():
             relay.ty.TensorType((3,), dtype)
         ]))
 
+@pytest.mark.xfail
+def test_dense_type_check():
+    dtype = 'float16'
+    n, c , h, w = 2, 2 , 2 ,2
+    x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+    # it should fail since it does not match with m(2)
+    mismatch_w = 3
+    w = relay.var("w", relay.TensorType((2, mismatch_w), dtype))
+    y = relay.nn.dense(x, w)
+    yy = run_infer_type(y)
 
 def test_dense():
     for dtype in ['float16', 'float32']: