Update syntax for amx.tile_muli to use two Unit attr to mark the zext case
authorMehdi Amini <joker.eph@gmail.com>
Sat, 20 Mar 2021 01:23:12 +0000 (01:23 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 20 Mar 2021 04:12:24 +0000 (04:12 +0000)
This makes the annotation tied to the operand and the use of a keyword
more explicit/readable on what it means.

Differential Revision: https://reviews.llvm.org/D99001

mlir/include/mlir/Dialect/AMX/AMX.td
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/AMX/invalid.mlir
mlir/test/Dialect/AMX/legalize-for-llvm.mlir
mlir/test/Dialect/AMX/roundtrip.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir

index 45c63a9..24052ed 100644 (file)
@@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
     into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
     combinations (4 bytes packed into dwords in the columns of both the
     source operand tiles; the zero or sign extension is specified with
-    the attributes). The operation is eventually lowered into one of
-    the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with
-    the corresponding tile configuration.
+    the attributes and default to sign extended). The operation is eventually
+    lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
+    instructions with the corresponding tile configuration.
 
     Example:
 
     ```mlir
-      %0 = amx.tile_muli %a, %b, %c [true, true]
+      %0 = amx.tile_muli %a zext, %b zext, %c 
         : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
     ```
   }];
@@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
   let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
                        VectorOfRankAndType<[2], [I32, I8]>:$rhs,
                        VectorOfRankAndType<[2], [I32, I8]>:$acc,
-                       BoolArrayAttr:$zext);
+                       UnitAttr:$isZextLhs,
+                       UnitAttr:$isZextRhs
+                       );
   let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
   let extraClassDeclaration = [{
     VectorType getLhsVectorType() {
@@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
       return res().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` "
+  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
                        "type($lhs) `,` type($rhs) `,` type($acc) ";
 }
 
index 5ebef7e..ab98820 100644 (file)
@@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) {
 }
 
 static LogicalResult verify(amx::TileMulIOp op) {
-  if (op.zext().size() != 2)
-    return op.emitOpError("unexpected zext length");
   VectorType aType = op.getLhsVectorType();
   VectorType bType = op.getRhsVectorType();
   VectorType cType = op.getVectorType();
index 6e082ce..7db57d3 100644 (file)
@@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    bool zexta = op.zext()[0].cast<BoolAttr>().getValue();
-    bool zextb = op.zext()[1].cast<BoolAttr>().getValue();
+    bool zexta = op.isZextLhs();
+    bool zextb = op.isZextRhs();
     if (zexta && zextb)
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(),
index b3a7286..6f147cf 100644 (file)
@@ -46,13 +46,3 @@ func @multsize() {
   // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
   %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32>
 }
-
-// -----
-
-func @zextsize() {
-  %0 = amx.tile_zero : vector<8x8xi8>
-  %1 = amx.tile_zero : vector<8x8xi8>
-  %2 = amx.tile_zero : vector<8x8xi32>
-  // expected-error@+1 {{'amx.tile_muli' op unexpected zext length}}
-  %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>
-}
index f88d83d..37382b3 100644 (file)
@@ -17,13 +17,13 @@ func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
   %1 = amx.tile_zero : vector<16x64xi8>
   %2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
-  %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, vector<16x16xi32>
-  %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, vector<16x16xi32>
-  %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg1[%0, %0], %7  : memref<?x?xi32>, vector<16x16xi32>
   return
 }
index 98b8024..93f3ea4 100644 (file)
@@ -28,14 +28,22 @@ func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
 // CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
 // CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into vector<16x64xi8>
 // CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into vector<16x16xi32>
-// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
 // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, vector<16x16xi32>
+// Verify the parsing/printing of the sign-extension annotation.
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
+// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
 func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
   %0 = constant 0 : index
   %1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into vector<16x64xi8>
   %3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into vector<16x16xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, vector<16x16xi32>
+  // Verify the various `zext` combinations.
+  %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+  %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
   return
 }
index dee283c..45e9816 100644 (file)
@@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
@@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8>  into vector<16x16xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8>  into vector<4x16xi8>
   %3 = amx.tile_zero : vector<16x4xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32>
   return
 }
index a52f66c..df848a0 100644 (file)
@@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %3 = amx.tile_zero : vector<2x2xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }
@@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>,
   %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8>  into vector<2x8xi8>
   %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32>
-  %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
+  %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32>
   amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32>
   return
 }