[nnvm] fix nnvm compiler build module error (#3378)
authorHowave <myhouseng@gmail.com>
Mon, 17 Jun 2019 16:56:58 +0000 (00:56 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 17 Jun 2019 16:56:58 +0000 (09:56 -0700)
nnvm/python/nnvm/compiler/build_module.py

index 30ddc7c..c52265f 100644 (file)
@@ -148,7 +148,7 @@ def _update_shape_dtype(shape, dtype, params):
     shape.update({k : v.shape for k, v in params.items()})
     if isinstance(dtype, str):
         for k, v in params.items():
-            if v.dtype != dtype:
+            if v.dtype != dtype and v.shape:
                 raise ValueError(
                     "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
     else: