Add linalg.mmt4d named op
authorAhmed Taei <ataei@google.com>
Wed, 30 Jun 2021 23:03:19 +0000 (16:03 -0700)
committerAhmed Taei <ataei@google.com>
Thu, 1 Jul 2021 19:41:08 +0000 (12:41 -0700)
This op performs matrix-matrix-transpose multiplication of 4-d inputs as the following:

```
C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0])
```

Reviewed By: Benoit

Differential Revision: https://reviews.llvm.org/D105244

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

index 8781e16..a8baf23 100644 (file)
@@ -63,6 +63,79 @@ structured_op: !LinalgStructuredOpConfig
                   scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
+  name: mmt4d
+  cpp_class_name: Mmt4DOp
+  doc: |-
+    Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+    Differences from linalg.matmul:
+    * The right hand side is transposed, whence the 't' in 'mmt'.
+    * The input and output tensors have a 4D shape instead of a 2D shape. They
+      are interpreted as 2D matrices with one level of 2D tile subdivision,
+      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+      as: MxK tiles, each of shape M0xK0.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    usage: InputOperand
+    type_var: LhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: rhs
+    usage: InputOperand
+    type_var: RhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: accum
+    usage: OutputOperand
+    type_var: AccumType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+      d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: accum
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: accum
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: lhs
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: rhs
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
   doc: |-
index 561cd2e..095d949 100644 (file)
@@ -22,6 +22,26 @@ def matmul(
 
 
 @linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+                                  output=True)):
+  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+    Differences from linalg.matmul:
+    * The right hand side is transposed, whence the 't' in 'mmt'.
+    * The input and output tensors have a 4D shape instead of a 2D shape. They
+      are interpreted as 2D matrices with one level of 2D tile subdivision,
+      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+      as: MxK tiles, each of shape M0xK0.
+  """
+  domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+  implements(ContractionOpInterface)
+  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
+@linalg_structured_op
 def batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
     B=TensorDef(T2, Batch, S.K, S.N),