[Relay][Frontend][ONNX] Add support for broadcasting to Where and MatMul (#4267)
authorJon Soifer <soiferj@gmail.com>
Thu, 7 Nov 2019 22:10:30 +0000 (14:10 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 7 Nov 2019 22:10:30 +0000 (14:10 -0800)
python/tvm/relay/frontend/onnx.py
tests/python/frontend/onnx/test_forward.py

index a28b8f6..0c581c9 100644 (file)
@@ -298,6 +298,12 @@ class MatMul(OnnxOpConverter):
             # Convert a and b into 3 dimensional tensors.
             a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
             b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
+            # Broadcast b to match batch size of a
+            new_b_shape = list(infer_shape(b))
+            new_a_shape = infer_shape(a)
+            if new_a_shape[0] > new_b_shape[0]:
+                new_b_shape[0] = new_a_shape[0]
+                b = _op.broadcast_to(b, new_b_shape)
             # Transpose matrix dimensions of b.
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
@@ -987,6 +993,14 @@ class Where(OnnxOpConverter):
     """
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
+        # x and y can be broadcasted
+        condition_shape = infer_shape(inputs[0])
+        x_shape = infer_shape(inputs[1])
+        y_shape = infer_shape(inputs[2])
+        if len(condition_shape) > len(x_shape):
+            inputs[1] = _op.broadcast_to(inputs[1], condition_shape)
+        if len(condition_shape) > len(y_shape):
+            inputs[2] = _op.broadcast_to(inputs[2], condition_shape)
         return _op.where(inputs[0], inputs[1], inputs[2])
 
 class Or(Elemwise):
@@ -996,6 +1010,7 @@ class Or(Elemwise):
     def _impl_v7(cls, inputs, attr, params):
         return _op.logical_or(inputs[0], inputs[1])
 
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
index 5dfaee4..6391a1a 100644 (file)
@@ -498,11 +498,7 @@ def test_matmul():
             model, [a_array, b_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
-
-def test_batch_matmul():
-    a_shape = (2, 3, 4, 3)
-    b_shape = (2, 3, 3, 4)
-
+def verify_batch_matmul(a_shape, b_shape):
     a_array = np.random.uniform(size=a_shape).astype('float32')
     b_array = np.random.uniform(size=b_shape).astype('float32')
     out_np = np.matmul(a_array, b_array)
@@ -525,6 +521,10 @@ def test_batch_matmul():
             model, [a_array, b_array], target, ctx, out_np.shape)
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+def test_batch_matmul():
+    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
+    verify_batch_matmul((2, 4, 3), (3, 4))
+    verify_batch_matmul((2, 3, 4, 3), (3, 4))
 
 def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
     in_array = np.random.uniform(size=shape).astype(dtype)
@@ -1600,6 +1600,11 @@ def test_where():
     outdata = np.where(condition, x, y)
     verify_where(condition, x, y, TensorProto.FLOAT, outdata)
 
+    x = np.array(1, dtype=np.float32)
+    y = np.array([2], dtype=np.float32)
+    outdata = np.where(condition, x, y)
+    verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+
 
 def verify_or(indata, dtype):
     x = indata[0].astype(dtype)