scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
+ name: mmt4d
+ cpp_class_name: Mmt4DOp
+ doc: |-
+ Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+ Differences from linalg.matmul:
+ * The right hand side is transposed, whence the 't' in 'mmt'.
+ * The input and output tensors have a 4D shape instead of a 2D shape. They
+ are interpreted as 2D matrices with one level of 2D tile subdivision,
+ whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+ '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+ as: MxK tiles, each of shape M0xK0.
+ implements:
+ - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ # Shape symbols in declaration order: s0=M, s1=K, s2=M0, s3=K0, s4=N, s5=N0
+ # (matching the Python opdsl definition of mmt4d this config is generated from).
+ args:
+ - !LinalgOperandDefConfig
+ name: lhs
+ usage: InputOperand
+ type_var: LhsType
+ # lhs is (M, K, M0, K0): MxK tiles, each of shape M0xK0.
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+ - !LinalgOperandDefConfig
+ name: rhs
+ usage: InputOperand
+ type_var: RhsType
+ # rhs is (N, K, N0, K0): transposed layout, sharing K/K0 with lhs.
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+ - !LinalgOperandDefConfig
+ name: accum
+ usage: OutputOperand
+ type_var: AccumType
+ # accum is (M, N, M0, N0): MxN tiles, each of shape M0xN0.
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+ # Iteration dims: d0..d3 are parallel (m, m0, n, n0); d4/d5 are the
+ # reduction dims (k, k0) — see iterator_types below.
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+ d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+ d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+ d3)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ # Scalar computation per iteration point:
+ #   accum = accum + cast<AccumType>(lhs) * cast<AccumType>(rhs)
+ # i.e. both operands are promoted to the accumulator type before the multiply.
+ assignments:
+ - !ScalarAssign
+ arg: accum
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: accum
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: AccumType
+ operands:
+ - !ScalarExpression
+ scalar_arg: lhs
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: AccumType
+ operands:
+ - !ScalarExpression
+ scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
name: batch_matmul
cpp_class_name: BatchMatmulOp
doc: |-
@linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+ rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+ accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+ output=True)):
+ """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+ Differences from linalg.matmul:
+ * The right hand side is transposed, whence the 't' in 'mmt'.
+ * The input and output tensors have a 4D shape instead of a 2D shape. They
+ are interpreted as 2D matrices with one level of 2D tile subdivision,
+ whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+ '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+ as: MxK tiles, each of shape M0xK0.
+ """
+ # The declaration order here fixes d0..d5 = (m, m0, n, n0, k, k0) in the
+ # generated indexing maps; k and k0 are the reduction dims since they do
+ # not appear on the left-hand side of the accumulation below.
+ domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+ implements(ContractionOpInterface)
+ # Both operands are cast to the accumulator element type before the
+ # multiply, and the product is accumulated into `accum` in place (+=).
+ accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
+@linalg_structured_op
def batch_matmul(
A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.K, S.N),