[mlir][AVX512] Add mask.compress to AVX512 dialect.
authorMatthias Springer <springerm@google.com>
Fri, 5 Mar 2021 04:08:05 +0000 (13:08 +0900)
committerMatthias Springer <springerm@google.com>
Sat, 6 Mar 2021 01:02:48 +0000 (10:02 +0900)
Adds mask.compress to the AVX512 dialect and defines a lowering to the LLVM dialect.

Differential Revision: https://reviews.llvm.org/D97611

mlir/include/mlir/Dialect/AVX512/AVX512.td
mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/AVX512/roundtrip.mlir
mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir [new file with mode: 0644]
mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-vp2intersect-i32.mlir
mlir/test/Target/avx512.mlir

index 7140b01..c2487a0 100644 (file)
@@ -31,6 +31,42 @@ def AVX512_Dialect : Dialect {
 class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
   Op<AVX512_Dialect, mnemonic, traits> {}
 
+def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
+  // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
+  // then be removed from assemblyFormat.
+  AllTypesMatch<["a", "dst"]>,
+  TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
+                 "dst", "k",
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
+  let summary = "Masked compress op";
+  let description = [{
+  The mask.compress op is an AVX512 specific op that can lower to the
+  `llvm.mask.compress` instruction. Instead of `src`, a constant vector
+  vector attribute `constant_src` may be specified. If neither `src` nor
+  `constant_src` is specified, the remaining elements in the result vector are
+  set to zero.
+
+  #### From the Intel Intrinsics Guide:
+
+  Contiguously store the active integer/floating-point elements in `a` (those
+  with their respective bit set in writemask `k`) to `dst`, and pass through the
+  remaining elements from `src`.
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins VectorOfLengthAndType<[16, 16, 8, 8],
+                                             [I1, I1, I1, I1]>:$k,
+                   VectorOfLengthAndType<[16, 16, 8, 8],
+                                         [F32, I32, F64, I64]>:$a,
+                   Optional<VectorOfLengthAndType<[16, 16, 8, 8],
+                                                  [F32, I32, F64, I64]>>:$src,
+                   OptionalAttr<ElementsAttr>:$constant_src);
+  let results = (outs VectorOfLengthAndType<[16, 16, 8, 8],
+                                            [F32, I32, F64, I64]>:$dst);
+  let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
+                       " `:` type($dst) (`,` type($src)^)?";
+}
+
 def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
   AllTypesMatch<["src", "a", "dst"]>,
   TypesMatchWith<"imm has the same number of bits as elements in dst",
index 9bcbdb5..20fb803 100644 (file)
@@ -33,6 +33,16 @@ class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits =
                   "x86_avx512_" # !subst(".", "_", mnemonic),
                   [], [], traits, numResults>;
 
+// Defined by first result overload. May have to be extended for other
+// instructions in the future.
+class LLVMAVX512_IntrOverloadedOp<string mnemonic,
+                                  list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[],
+                  traits, /*numResults=*/1>;
+
 def LLVM_x86_avx512_mask_rndscale_ps_512 :
   LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
   Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
@@ -49,6 +59,10 @@ def LLVM_x86_avx512_mask_scalef_pd_512 :
   LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
   Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+def LLVM_x86_avx512_mask_compress :
+  LLVMAVX512_IntrOverloadedOp<"mask.compress">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 def LLVM_x86_avx512_vp2intersect_d_512 :
   LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
   Arguments<(ins LLVM_Type, LLVM_Type)>;
index 3381ad8..74b9197 100644 (file)
@@ -56,6 +56,34 @@ struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern {
   }
 };
 
+struct MaskCompressOpConversion
+    : public ConvertOpToLLVMPattern<MaskCompressOp> {
+  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MaskCompressOp::Adaptor adaptor(operands);
+    auto opType = adaptor.a().getType();
+
+    Value src;
+    if (op.src()) {
+      src = adaptor.src();
+    } else if (op.constant_src()) {
+      src = rewriter.create<ConstantOp>(op.getLoc(), opType,
+                                        op.constant_srcAttr());
+    } else {
+      Attribute zeroAttr = rewriter.getZeroAttr(opType);
+      src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
+    }
+
+    rewriter.replaceOpWithNewOp<LLVM::x86_avx512_mask_compress>(
+        op, opType, adaptor.a(), src, adaptor.k());
+
+    return success();
+  }
+};
+
 struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
   explicit ScaleFOp512Conversion(MLIRContext *context,
                                  LLVMTypeConverter &typeConverter)
@@ -110,5 +138,6 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
                   ScaleFOp512Conversion,
                   Vp2IntersectOp512Conversion>(&converter.getContext(),
                                                converter);
+  patterns.insert<MaskCompressOpConversion>(converter);
   // clang-format on
 }
index 697f008..023018a 100644 (file)
@@ -25,5 +25,21 @@ void avx512::AVX512Dialect::initialize() {
       >();
 }
 
+static LogicalResult verify(avx512::MaskCompressOp op) {
+  if (op.src() && op.constant_src())
+    return emitError(op.getLoc(), "cannot use both src and constant_src");
+
+  if (op.src() && (op.src().getType() != op.dst().getType()))
+    return emitError(op.getLoc(),
+                     "failed to verify that src and dst have same type");
+
+  if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType()))
+    return emitError(
+        op.getLoc(),
+        "failed to verify that constant_src and dst have same type");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AVX512/AVX512.cpp.inc"
index b6f7ad8..0d03917 100644 (file)
@@ -17,6 +17,19 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
   return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
 }
 
+func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
+                           %k2: vector<8xi1>, %a2: vector<8xi64>)
+  -> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
+{
+  // CHECK: llvm_avx512.mask.compress
+  %0 = avx512.mask.compress %k1, %a1 : vector<16xf32>
+  // CHECK: llvm_avx512.mask.compress
+  %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
+  // CHECK: llvm_avx512.mask.compress
+  %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
+  return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
+}
+
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
   -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
index 865f918..dc1a65b 100644 (file)
@@ -29,3 +29,16 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
   return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }
+
+func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
+                           %k2: vector<8xi1>, %a2: vector<8xi64>)
+  -> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
+{
+  // CHECK: avx512.mask.compress {{.*}} : vector<16xf32>
+  %0 = avx512.mask.compress %k1, %a1 : vector<16xf32>
+  // CHECK: avx512.mask.compress {{.*}} : vector<16xf32>
+  %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
+  // CHECK: avx512.mask.compress {{.*}} : vector<8xi64>
+  %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
+  return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir
new file mode 100644 (file)
index 0000000..ae34524
--- /dev/null
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm  | \
+// RUN: mlir-translate  --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() -> i32 {
+  %i0 = constant 0 : i32
+
+  %a = std.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
+  %k = std.constant dense<[1,  0,  1,  1,  1,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  0]> : vector<16xi1>
+  %r1 = avx512.mask.compress %k, %a : vector<16xf32>
+  %r2 = avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
+
+  vector.print %r1 : vector<16xf32>
+  // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  vector.print %r2 : vector<16xf32>
+  // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
+
+  %src = std.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
+  %r3 = avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
+
+  vector.print %r3 : vector<16xf32>
+  // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )
+
+  return %i0 : i32
+}
index d29789a..e291e80 100644 (file)
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm  | \
-// RUN: mlir-translate  --avx512-mlir-to-llvmir | \
+// RUN: mlir-translate  --mlir-to-llvmir | \
 // RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
index 940873b..abf36bd 100644 (file)
@@ -30,6 +30,16 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
   llvm.return %1: vector<8xf64>
 }
 
+// CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress
+llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
+  -> vector<16xf32>
+{
+  // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
+  %0 = "llvm_avx512.mask.compress"(%a, %a, %k) :
+    (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32>
+  llvm.return %0 : vector<16xf32>
+}
+
 // CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512
 llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
   -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>