# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
+ # Broadcast b to match batch size of a
+ new_b_shape = list(infer_shape(b))
+ new_a_shape = infer_shape(a)
+ if new_a_shape[0] > new_b_shape[0]:
+ new_b_shape[0] = new_a_shape[0]
+ b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
"""
@classmethod
def _impl_v9(cls, inputs, attr, params):
    """Convert a three-input Where node, broadcasting its operands first.

    ``inputs[0]`` is the condition, ``inputs[1]`` is x and ``inputs[2]`` is y.
    Any of the three may need broadcasting, so every operand whose inferred
    shape differs from the joint broadcast shape is expanded with
    ``_op.broadcast_to`` before ``_op.where`` is applied.

    Comparing ranks alone is not enough: operands of equal rank can still
    differ in extent (e.g. ``(3,)`` vs ``(1,)``), and x or y may have a higher
    rank than the condition.

    NOTE(review): assumes ``infer_shape`` yields static integer extents;
    confirm the behaviour for dynamic/symbolic dimensions.
    """
    operands = [inputs[0], inputs[1], inputs[2]]
    shapes = [tuple(infer_shape(operand)) for operand in operands]

    # Right-align the shapes by left-padding the shorter ones with 1s.
    rank = max(len(shape) for shape in shapes)
    padded = [(1,) * (rank - len(shape)) + shape for shape in shapes]

    # Per axis the broadcast extent is the first extent that is not 1 (or 1
    # when every operand has extent 1). Incompatible extents are not checked
    # here; they are left for broadcast_to to reject.
    target = tuple(
        next((dim for dim in dims if dim != 1), 1) for dims in zip(*padded)
    )

    # Build new expressions instead of writing back into the caller's list.
    condition, x, y = (
        operand if shape == target else _op.broadcast_to(operand, target)
        for operand, shape in zip(operands, shapes)
    )
    return _op.where(condition, x, y)
class Or(Elemwise):
    """Converter that maps the Or operator onto ``_op.logical_or``."""

    # The method takes ``cls`` as its first parameter, so it has to be bound
    # as a classmethod to be callable as ``Or._impl_v7(inputs, attr, params)``.
    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        """Return the logical OR of ``inputs[0]`` and ``inputs[1]``.

        ``attr`` and ``params`` are accepted but not used.
        """
        return _op.logical_or(inputs[0], inputs[1])
+
# Compatible operators that do NOT require any conversion.
# The list is currently empty.
_identity_list = []
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
-
-def test_batch_matmul():
- a_shape = (2, 3, 4, 3)
- b_shape = (2, 3, 3, 4)
-
+def verify_batch_matmul(a_shape, b_shape):
a_array = np.random.uniform(size=a_shape).astype('float32')
b_array = np.random.uniform(size=b_shape).astype('float32')
out_np = np.matmul(a_array, b_array)
model, [a_array, b_array], target, ctx, out_np.shape)
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
def test_batch_matmul():
    """Run verify_batch_matmul over each (a_shape, b_shape) pair, in order."""
    shape_pairs = (
        ((2, 3, 4, 3), (2, 3, 3, 4)),  # both operands carry batch dimensions
        ((2, 4, 3), (3, 4)),           # b has no batch dimension
        ((2, 3, 4, 3), (3, 4)),        # a has two batch dimensions, b has none
    )
    for a_shape, b_shape in shape_pairs:
        verify_batch_matmul(a_shape, b_shape)
def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
in_array = np.random.uniform(size=shape).astype(dtype)
outdata = np.where(condition, x, y)
verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+ x = np.array(1, dtype=np.float32)
+ y = np.array([2], dtype=np.float32)
+ outdata = np.where(condition, x, y)
+ verify_where(condition, x, y, TensorProto.FLOAT, outdata)
+
def verify_or(indata, dtype):
x = indata[0].astype(dtype)