        return chunks
    return _impl
-def _matmul():
-    def _impl(inputs, input_types):
-        data0 = inputs[0]
-        data1 = inputs[1]
-        data1_t = _op.transpose(data1, axes=(1, 0))
-        return _op.nn.dense(data0, data1_t)
+def _matmul(prelude):
+    def _impl(inputs, input_types):
+
+        inputs_0 = inputs[0]
+        inputs_1 = inputs[1]
+
+        # Check the input shapes, since N-dimensional inputs require a batch matmul.
+        a_shape = _infer_shape(inputs_0, prelude.mod)
+        b_shape = _infer_shape(inputs_1, prelude.mod)
+
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if len(a_shape) > 2 or len(b_shape) > 2:
+            # Convert a and b into 3-dimensional tensors.
+            a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
+            b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
+            # Broadcast b to match the batch size of a.
+            new_b_shape = list(_infer_shape(b, prelude.mod))
+            new_a_shape = _infer_shape(a, prelude.mod)
+            if new_a_shape[0] > new_b_shape[0]:
+                new_b_shape[0] = new_a_shape[0]
+                b = _op.broadcast_to(b, new_b_shape)
+            # Transpose the matrix dimensions of b, since nn.batch_matmul
+            # multiplies by the transposed second operand.
+            b = _op.transpose(b, [0, 2, 1])
+            # Perform a batch matmul.
+            output = _op.nn.batch_matmul(a, b)
+            # Reshape the output back to the original leading dimensions.
+            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
+
+        # Otherwise a simple dense op will get the job done.
+        if len(b_shape) == 1:
+            # A 1-D rhs is treated as a (1, N) weight and squeezed back afterwards.
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
+        else:
+            input_1 = _op.transpose(inputs_1, axes=(1, 0))
+
+        out = _op.nn.dense(inputs_0, input_1)
+
+        if len(b_shape) == 1:
+            out = _op.squeeze(out, axis=[-1])
+
+        return out
    return _impl
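For reference, the batched path above performs the following shape arithmetic. This is a minimal NumPy sketch (not part of the patch, variable names are illustrative) of the batched-matrix x broadcasted-matrix case; it also shows why `b` is transposed, since `nn.batch_matmul` multiplies by the transposed second operand.

```python
import numpy as np

# batched matrix x broadcasted matrix, as in the test below: (10, 3, 4) x (4, 5)
a = np.random.randn(10, 3, 4).astype("float32")
b = np.random.randn(4, 5).astype("float32")

# Flatten leading dims so both operands are 3-D.
a3 = a.reshape(-1, a.shape[-2], a.shape[-1])              # (10, 3, 4)
b3 = b.reshape(-1, b.shape[-2], b.shape[-1])              # (1, 4, 5)

# Broadcast b's batch dimension to match a's.
b3 = np.broadcast_to(b3, (a3.shape[0],) + b3.shape[1:])   # (10, 4, 5)

# nn.batch_matmul(x, y) computes x @ swapaxes(y, 1, 2) per batch,
# so the converter feeds it b transposed on its matrix dims.
bt = b3.transpose(0, 2, 1)                                 # (10, 5, 4)
out = np.matmul(a3, np.swapaxes(bt, 1, 2))                 # (10, 3, 5)

# Restore the original leading dims: (*a_shape[:-2], a_shape[-2], b_shape[-1]).
out = out.reshape(*a.shape[:-2], a.shape[-2], b.shape[-1])
assert out.shape == (10, 3, 5)
assert np.allclose(out, a @ b, atol=1e-5)
```

For the 2-D and 1-D cases the converter instead falls back to `nn.dense`, which also multiplies by the transposed weight, hence the transpose (or expand-then-squeeze for a vector rhs) on the right-hand operand.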
"aten::alpha_dropout" : _dropout(),
"aten::mean" : _mean(),
"aten::chunk" : _chunk(prelude),
- "aten::matmul" : _matmul(),
+ "aten::matmul" : _matmul(prelude),
"aten::expand" : _expand(),
"aten::Int" : _int(),
"prim::NumToTensor" : _numtotensor(),
"aten::rsub" : _rsub(),
"aten::embedding" : _embedding(),
"aten::one_hot" : _one_hot(),
- "aten::mm" : _matmul(),
+ "aten::mm" : _matmul(prelude),
"relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude),
"aten::add_" : _add(prelude),
    verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
+def test_forward_matmul():
+    torch.set_grad_enabled(False)
+
+    class MatMul1(Module):
+        def forward(self, *args):
+            return torch.matmul(args[0], args[1])
+
+    # matrix x vector
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(4)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # matrix x matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(4, 10)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x batched matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(10, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # batched matrix x broadcasted matrix
+    tensor1 = torch.randn(10, 3, 4)
+    tensor2 = torch.randn(4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+    # 4-D batched matrix x batched matrix
+    tensor1 = torch.randn(1, 12, 14, 64)
+    tensor2 = torch.randn(1, 12, 64, 14)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+
+
if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
    test_forward_subtract()
    test_forward_multiply()
+    test_forward_matmul()
    test_forward_rsub()
    test_forward_onehot()
    test_forward_embedding()