[mlir][llvm] Add support for importing masked intrinsics from LLVM IR.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Mon, 17 Oct 2022 12:43:35 +0000 (15:43 +0300)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Mon, 17 Oct 2022 12:53:47 +0000 (15:53 +0300)
The revision adds support for importing the masked load/store and
gather/scatter intrinsics from LLVM IR. To enable the import, the
revision also includes an extension of the mlirBuilder code generation
to support variadic arguments.

Depends on D136057

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D136058

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/test/Target/LLVMIR/Import/intrinsic.ll
mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

index 5f19aab..acf6f51 100644 (file)
@@ -400,60 +400,84 @@ def LLVM_GetActiveLaneMaskOp
 }
 
 /// Create a call to Masked Load intrinsic.
-def LLVM_MaskedLoadOp : LLVM_Op<"intr.masked.load"> {
+def LLVM_MaskedLoadOp : LLVM_OneResultIntrOp<"masked.load"> {
   let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask,
                    Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
   let results = (outs LLVM_AnyVector:$res);
+  let assemblyFormat =
+    "operands attr-dict `:` functional-type(operands, results)";
+
   string llvmBuilder = [{
     $res = $pass_thru.empty() ? builder.CreateMaskedLoad(
         $_resultType, $data, llvm::Align($alignment), $mask) :
       builder.CreateMaskedLoad(
         $_resultType, $data, llvm::Align($alignment), $mask, $pass_thru[0]);
   }];
-  let assemblyFormat =
-    "operands attr-dict `:` functional-type(operands, results)";
+  string mlirBuilder = [{
+    $res = $_builder.create<LLVM::MaskedLoadOp>($_location,
+      $_resultType, $data, $mask, $pass_thru, $_int_attr($alignment));
+  }];
+  list<int> llvmArgIndices = [0, 2, 3, 1];
 }
 
 /// Create a call to Masked Store intrinsic.
-def LLVM_MaskedStoreOp : LLVM_Op<"intr.masked.store"> {
+def LLVM_MaskedStoreOp : LLVM_ZeroResultIntrOp<"masked.store"> {
   let arguments = (ins LLVM_Type:$value, LLVM_Type:$data, LLVM_Type:$mask,
                    I32Attr:$alignment);
   let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
+  let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
+    "type($value) `,` type($mask) `into` type($data)";
+
   string llvmBuilder = [{
     builder.CreateMaskedStore(
       $value, $data, llvm::Align($alignment), $mask);
   }];
-  let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
-    "type($value) `,` type($mask) `into` type($data)";
+  string mlirBuilder = [{
+    $_builder.create<LLVM::MaskedStoreOp>($_location,
+      $value, $data, $mask, $_int_attr($alignment));
+  }];
+  list<int> llvmArgIndices = [0, 1, 3, 2];
 }
 
 /// Create a call to Masked Gather intrinsic.
-def LLVM_masked_gather : LLVM_Op<"intr.masked.gather"> {
+def LLVM_masked_gather : LLVM_OneResultIntrOp<"masked.gather"> {
   let arguments = (ins LLVM_AnyVector:$ptrs, LLVM_Type:$mask,
                    Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat =
+    "operands attr-dict `:` functional-type(operands, results)";
+
   string llvmBuilder = [{
     $res = $pass_thru.empty() ? builder.CreateMaskedGather(
         $_resultType, $ptrs, llvm::Align($alignment), $mask) :
       builder.CreateMaskedGather(
         $_resultType, $ptrs, llvm::Align($alignment), $mask, $pass_thru[0]);
   }];
-  let assemblyFormat =
-    "operands attr-dict `:` functional-type(operands, results)";
+  string mlirBuilder = [{
+    $res = $_builder.create<LLVM::masked_gather>($_location,
+      $_resultType, $ptrs, $mask, $pass_thru, $_int_attr($alignment));
+  }];
+  list<int> llvmArgIndices = [0, 2, 3, 1];
 }
 
 /// Create a call to Masked Scatter intrinsic.
-def LLVM_masked_scatter : LLVM_Op<"intr.masked.scatter"> {
+def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
   let arguments = (ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask,
                    I32Attr:$alignment);
   let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
+  let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` "
+    "type($value) `,` type($mask) `into` type($ptrs)";
+
   string llvmBuilder = [{
     builder.CreateMaskedScatter(
       $value, $ptrs, llvm::Align($alignment), $mask);
   }];
-  let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` "
-    "type($value) `,` type($mask) `into` type($ptrs)";
+  string mlirBuilder = [{
+    $_builder.create<LLVM::masked_scatter>($_location,
+      $value, $ptrs, $mask, $_int_attr($alignment));
+  }];
+  list<int> llvmArgIndices = [0, 1, 3, 2];
 }
 
 /// Create a call to Masked Expand Load intrinsic.
index 2ed9c57..f5f48fd 100644 (file)
@@ -228,7 +228,6 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
   //   - $_builder - substituted with the MLIR builder;
   //   - $_qualCppClassName - substitiuted with the MLIR operation class name.
   // Additionally, `$$` can be used to produce the dollar character.
-  // FIXME: The $name variable resolution does not support variadic arguments.
   string mlirBuilder = "";
 
   // An array that specifies a mapping from MLIR argument indices to LLVM IR
index 4b82ccd..b25d151 100644 (file)
@@ -1021,7 +1021,7 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [Pure]> {
   string llvmInstName = "Ret";
   string llvmBuilder = [{
     if ($_numOperands != 0)
-      builder.CreateRet($arg[0]);
+      builder.CreateRet($arg);
     else
       builder.CreateRetVoid();
   }];
index 0070000..0ff4fe0 100644 (file)
@@ -451,8 +451,10 @@ LogicalResult Importer::convertIntrinsic(OpBuilder &odsBuilder,
   if (!isConvertibleIntrinsic(intrinsicID))
     return failure();
 
-  // Copy the call arguments to an operands array used by the conversion.
-  SmallVector<llvm::Value *> llvmOperands(inst->args());
+  // Copy the call arguments to initialize operands array reference used by
+  // the conversion.
+  SmallVector<llvm::Value *> args(inst->args());
+  ArrayRef<llvm::Value *> llvmOperands(args);
 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"
 
   return failure();
@@ -460,8 +462,10 @@ LogicalResult Importer::convertIntrinsic(OpBuilder &odsBuilder,
 
 LogicalResult Importer::convertOperation(OpBuilder &odsBuilder,
                                          llvm::Instruction *inst) {
-  // Copy the instruction operands used for the conversion.
-  SmallVector<llvm::Value *> llvmOperands(inst->operands());
+  // Copy the instruction operands to initialize the operands array reference
+  // used by the conversion.
+  SmallVector<llvm::Value *> operands(inst->operands());
+  ArrayRef<llvm::Value *> llvmOperands(operands);
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
 
   return failure();
index 19f741d..e036f09 100644 (file)
@@ -278,19 +278,35 @@ define <7 x i1> @get_active_lane_mask(i64 %0, i64 %1) {
   ret <7 x i1> %3
 }
 
-; TODO: masked load store intrinsics should be handled specially.
-define void @masked_load_store_intrinsics(<7 x float>* %0, <7 x i1> %1) {
-  %3 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %0, i32 1, <7 x i1> %1, <7 x float> undef)
-  %4 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %0, i32 1, <7 x i1> %1, <7 x float> %3)
-  call void @llvm.masked.store.v7f32.p0v7f32(<7 x float> %4, <7 x float>* %0, i32 1, <7 x i1> %1)
-  ret void
-}
-
-; TODO: masked gather scatter intrinsics should be handled specially.
-define void @masked_gather_scatter_intrinsics(<7 x float*> %0, <7 x i1> %1) {
-  %3 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %0, i32 1, <7 x i1> %1, <7 x float> undef)
-  %4 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %0, i32 1, <7 x i1> %1, <7 x float> %3)
-  call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %4, <7 x float*> %0, i32 1, <7 x i1> %1)
+; CHECK-LABEL: @masked_load_store_intrinsics
+; CHECK-SAME:  %[[VEC:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[MASK:[a-zA-Z0-9]+]]
+define void @masked_load_store_intrinsics(<7 x float>* %vec, <7 x i1> %mask) {
+  ; CHECK:  %[[UNDEF:.+]] = llvm.mlir.undef
+  ; CHECK:  %[[VAL1:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[UNDEF]] {alignment = 1 : i32}
+  ; CHECK-SAME:  (!llvm.ptr<vector<7xf32>>, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
+  %1 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %vec, i32 1, <7 x i1> %mask, <7 x float> undef)
+  ; CHECK:  %[[VAL2:.+]] = llvm.intr.masked.load %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32}
+  %2 = call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %vec, i32 4, <7 x i1> %mask, <7 x float> %1)
+  ; CHECK:  llvm.intr.masked.store %[[VAL2]], %[[VEC]], %[[MASK]] {alignment = 8 : i32}
+  ; CHECK-SAME:  vector<7xf32>, vector<7xi1> into !llvm.ptr<vector<7xf32>>
+  call void @llvm.masked.store.v7f32.p0v7f32(<7 x float> %2, <7 x float>* %vec, i32 8, <7 x i1> %mask)
+  ret void
+}
+
+; CHECK-LABEL: @masked_gather_scatter_intrinsics
+; CHECK-SAME:  %[[VEC:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[MASK:[a-zA-Z0-9]+]]
+define void @masked_gather_scatter_intrinsics(<7 x float*> %vec, <7 x i1> %mask) {
+  ; CHECK:  %[[UNDEF:.+]] = llvm.mlir.undef
+  ; CHECK:  %[[VAL1:.+]] = llvm.intr.masked.gather %[[VEC]], %[[MASK]], %[[UNDEF]] {alignment = 1 : i32}
+  ; CHECK-SAME:  (!llvm.vec<7 x ptr<f32>>, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
+  %1 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %vec, i32 1, <7 x i1> %mask, <7 x float> undef)
+  ; CHECK:  %[[VAL2:.+]] = llvm.intr.masked.gather %[[VEC]], %[[MASK]], %[[VAL1]] {alignment = 4 : i32}
+  %2 = call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %vec, i32 4, <7 x i1> %mask, <7 x float> %1)
+  ; CHECK:  llvm.intr.masked.scatter %[[VAL2]], %[[VEC]], %[[MASK]] {alignment = 8 : i32}
+  ; CHECK-SAME:  vector<7xf32>, vector<7xi1> into !llvm.vec<7 x ptr<f32>>
+  call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %2, <7 x float*> %vec, i32 8, <7 x i1> %mask)
   ret void
 }
 
index e0a773f..8124569 100644 (file)
@@ -71,14 +71,14 @@ static StringLoc findNextVariable(StringRef str) {
   return {startPos, endPos - startPos};
 }
 
-// Check if `name` is the name of the variadic operand of `op`.  The variadic
-// operand can only appear at the last position in the list of operands.
+// Check if `name` is a variadic operand of `op`. Seach all operands since the
+// MLIR and LLVM IR operand order may differ and only for the latter the
+// variadic operand is guaranteed to be at the end of the operands list.
 static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
-  unsigned numOperands = op.getNumOperands();
-  if (numOperands == 0)
-    return false;
-  const auto &operand = op.getOperand(numOperands - 1);
-  return operand.isVariableLength() && operand.name == name;
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i)
+    if (op.getOperand(i).name == name)
+      return op.getOperand(i).isVariadic();
+  return false;
 }
 
 // Check if `result` is a known name of a result of `op`.
@@ -232,14 +232,17 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
     if (succeeded(argIndex)) {
       // Access the LLVM IR operand that maps to the given argument index using
       // the provided argument indices mapping.
-      // FIXME: support trailing variadic arguments.
-      int64_t operandIdx = llvmArgIndices[*argIndex];
-      if (operandIdx < 0) {
+      int64_t idx = llvmArgIndices[*argIndex];
+      if (idx < 0) {
         return emitError(
             record, "expected non-negative operand index for argument " + name);
       }
-      assert(!isVariadicOperandName(op, name) && "unexpected variadic operand");
-      bs << formatv("processValue(llvmOperands[{0}])", operandIdx);
+      bool isVariadicOperand = isVariadicOperandName(op, name);
+      auto result =
+          isVariadicOperand
+              ? formatv("processValues(llvmOperands.drop_front({0}))", idx)
+              : formatv("processValue(llvmOperands[{0}])", idx);
+      bs << result;
     } else if (isResultName(op, name)) {
       if (op.getNumResults() != 1)
         return emitError(record, "expected op to have one result");