From cdb6eb7e8372027e74d6b0fb1258fff37e2b3b5a Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 20 Mar 2021 01:23:12 +0000 Subject: [PATCH] Update syntax for amx.tile_muli to use two Unit attr to mark the zext case This makes the annotation tied to the operand and the use of a keyword more explicit/readable on what it means. Differential Revision: https://reviews.llvm.org/D99001 --- mlir/include/mlir/Dialect/AMX/AMX.td | 14 ++++++++------ mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 2 -- mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 4 ++-- mlir/test/Dialect/AMX/invalid.mlir | 10 ---------- mlir/test/Dialect/AMX/legalize-for-llvm.mlir | 8 ++++---- mlir/test/Dialect/AMX/roundtrip.mlir | 12 ++++++++++-- .../Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir | 8 ++++---- .../test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir | 4 ++-- 8 files changed, 30 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 45c63a9..24052ed 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -196,14 +196,14 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" combinations (4 bytes packed into dwords in the columns of both the source operand tiles; the zero or sign extension is specified with - the attributes). The operation is eventually lowered into one of - the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with - the corresponding tile configuration. + the attributes and default to sign extended). The operation is eventually + lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" + instructions with the corresponding tile configuration. Example: ```mlir - %0 = amx.tile_muli %a, %b, %c [true, true] + %0 = amx.tile_muli %a zext, %b zext, %c : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> ``` }]; @@ -211,7 +211,9 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs, VectorOfRankAndType<[2], [I32, I8]>:$rhs, VectorOfRankAndType<[2], [I32, I8]>:$acc, - BoolArrayAttr:$zext); + UnitAttr:$isZextLhs, + UnitAttr:$isZextRhs + ); let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res); let extraClassDeclaration = [{ VectorType getLhsVectorType() { @@ -224,7 +226,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] return res().getType().cast(); } }]; - let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` " + let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; } diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 5ebef7e..ab98820 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -85,8 +85,6 @@ static LogicalResult verify(amx::TileMulFOp op) { } static LogicalResult verify(amx::TileMulIOp op) { - if (op.zext().size() != 2) - return op.emitOpError("unexpected zext length"); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 6e082ce..7db57d3 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -191,8 +191,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern { getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); // Replace operation with intrinsic. Type resType = typeConverter->convertType(cType); - bool zexta = op.zext()[0].cast().getValue(); - bool zextb = op.zext()[1].cast().getValue(); + bool zexta = op.isZextLhs(); + bool zextb = op.isZextRhs(); if (zexta && zextb) rewriter.replaceOpWithNewOp( op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir index b3a7286..6f147cf 100644 --- a/mlir/test/Dialect/AMX/invalid.mlir +++ b/mlir/test/Dialect/AMX/invalid.mlir @@ -46,13 +46,3 @@ func @multsize() { // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}} %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32> } - -// ----- - -func @zextsize() { - %0 = amx.tile_zero : vector<8x8xi8> - %1 = amx.tile_zero : vector<8x8xi8> - %2 = amx.tile_zero : vector<8x8xi32> - // expected-error@+1 {{'amx.tile_muli' op unexpected zext length}} - %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32> -} diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir index f88d83d..37382b3 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -17,13 +17,13 @@ func @muli(%arg0: memref, %arg1: memref) { %1 = amx.tile_zero : vector<16x64xi8> %2 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> %3 = amx.tile_load %arg1[%0, %0] : memref into vector<16x16xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %4 : memref, vector<16x16xi32> - %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %5 : memref, vector<16x16xi32> - %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %6 : memref, vector<16x16xi32> - %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %7 : memref, vector<16x16xi32> return } diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir index 98b8024..93f3ea4 100644 --- a/mlir/test/Dialect/AMX/roundtrip.mlir +++ b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -28,14 +28,22 @@ func @tmulf(%arg0: memref, %arg1: memref) { // CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> // CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> // CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xi32> -// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xi32> +// Verify the parsing/printing of the sign-extension annotation. +// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}} +// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}} +// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}} func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) { %0 = constant 0 : index %1 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> %2 = amx.tile_load %arg1[%0, %0] : memref into vector<16x64xi8> %3 = amx.tile_load %arg2[%0, %0] : memref into vector<16x16xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg2[%0, %0], %4 : memref, vector<16x16xi32> + // Verify the various `zext` combinations. + %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir index dee283c..45e9816 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-ext.mlir @@ -24,7 +24,7 @@ func @kernel1(%arg0: memref<16x16xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -36,7 +36,7 @@ func @kernel2(%arg0: memref<16x16xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -48,7 +48,7 @@ func @kernel3(%arg0: memref<16x16xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -60,7 +60,7 @@ func @kernel4(%arg0: memref<16x16xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir index a52f66c..df848a0 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir @@ -13,7 +13,7 @@ func @kernel1(%arg0: memref<2x8xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> %3 = amx.tile_zero : vector<2x2xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> return } @@ -26,7 +26,7 @@ func @kernel2(%arg0: memref<2x8xi8>, %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> return } -- 2.7.4