} else {
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
+ CHECK(static_cast<int>(weight->shape.size()) == 2);
+ CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1],
+ weight->shape[1]))
+ << "DenseRel: input dimension doesn't match,"
+ << " data shape=" << data->shape << ", weight shape=" << weight->shape;
oshape.Set((oshape.size() - 1), wshape[0]);
}
# specific language governing permissions and limitations
# under the License.
import numpy as np
+import pytest
import tvm
import scipy
from tvm import relay
relay.ty.TensorType((3,), dtype)
]))
+@pytest.mark.xfail
+def test_dense_type_check():
+ dtype = 'float16'
+ n, c , h, w = 2, 2 , 2 ,2
+ x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
+ # it should fail since it does not match with m(2)
+ mismatch_w = 3
+ w = relay.var("w", relay.TensorType((2, mismatch_w), dtype))
+ y = relay.nn.dense(x, w)
+ yy = run_infer_type(y)
def test_dense():
for dtype in ['float16', 'float32']: