[mlir][llvm] Import matrix, vector, and assume intrinsics from LLVM IR.
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Tue, 18 Oct 2022 08:01:07 +0000 (11:01 +0300)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Tue, 18 Oct 2022 08:11:03 +0000 (11:11 +0300)
The revision adds support to import:
- matrix intrinsics
- vector reduce fadd/fmul intrinsics
- assume intrinsics
from LLVM IR.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D136137

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/test/Target/LLVMIR/Import/intrinsic.ll

index acf6f51..38990ac 100644 (file)
@@ -168,14 +168,8 @@ def LLVM_UMulWithOverflowOp
 }
 
 
-def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> {
+def LLVM_AssumeOp : LLVM_ZeroResultIntrOp<"assume", []> {
   let arguments = (ins LLVM_Type:$cond);
-  let llvmBuilder = [{
-    llvm::Module *module = builder.GetInsertBlock()->getModule();
-    llvm::Function *fn =
-        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
-    builder.CreateCall(fn, {$cond});
-  }];
 }
 
 
@@ -308,11 +302,14 @@ def LLVM_vector_reduce_fmul : LLVM_VectorReductionAcc<"fmul">;
 /// isVolatile - True if the load operation is marked as volatile.
 /// columns    - Number of columns in matrix (must be a constant)
 /// stride     - Space between columns
-def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> {
+def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.load"> {
   let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile,
                    I32Attr:$rows, I32Attr:$columns);
   let results = (outs LLVM_AnyVector:$res);
   let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "$data `,` `<` `stride` `=` $stride `>` attr-dict"
+    "`:` type($res) `from` type($data) `stride` type($stride)";
+
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     const llvm::DataLayout &dl =
@@ -324,8 +321,11 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> {
       ElemTy, $data, align, $stride, $isVolatile, $rows,
       $columns);
   }];
-  let assemblyFormat = "$data `,` `<` `stride` `=` $stride `>` attr-dict"
-    "`:` type($res) `from` type($data) `stride` type($stride)";
+  string mlirBuilder = [{
+    $res = $_builder.create<LLVM::MatrixColumnMajorLoadOp>(
+      $_location, $_resultType, $data, $stride,
+      $_int_attr($isVolatile), $_int_attr($rows), $_int_attr($columns));
+  }];
 }
 
 /// Create a column major, strided 2-D matrix store, as specified in the LLVM
@@ -336,11 +336,14 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> {
 /// rows       - Number of rows in matrix (must be a constant)
 /// columns    - Number of columns in matrix (must be a constant)
 /// stride     - Space between columns
-def LLVM_MatrixColumnMajorStoreOp : LLVM_Op<"intr.matrix.column.major.store"> {
+def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.store"> {
   let arguments = (ins LLVM_AnyVector:$matrix, LLVM_Type:$data,
                    LLVM_Type:$stride, I1Attr:$isVolatile, I32Attr:$rows,
                    I32Attr:$columns);
   let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
+  let assemblyFormat = "$matrix `,` $data `,` `<` `stride` `=` $stride `>` "
+    "attr-dict`:` type($matrix) `to` type($data) `stride` type($stride)";
+
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     const llvm::DataLayout &dl =
@@ -352,39 +355,54 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_Op<"intr.matrix.column.major.store"> {
       $matrix, $data, align, $stride, $isVolatile,
       $rows, $columns);
   }];
-  let assemblyFormat = "$matrix `,` $data `,` `<` `stride` `=` $stride `>` "
-    "attr-dict`:` type($matrix) `to` type($data) `stride` type($stride)";
+  string mlirBuilder = [{
+    $_builder.create<LLVM::MatrixColumnMajorStoreOp>(
+      $_location, $matrix, $data, $stride,
+      $_int_attr($isVolatile), $_int_attr($rows), $_int_attr($columns));
+  }];
 }
 
 /// Create a llvm.matrix.multiply call, multiplying 2-D matrices LHS and RHS, as
 /// specified in the LLVM MatrixBuilder.
-def LLVM_MatrixMultiplyOp : LLVM_Op<"intr.matrix.multiply"> {
+def LLVM_MatrixMultiplyOp : LLVM_OneResultIntrOp<"matrix.multiply"> {
   let arguments = (ins LLVM_Type:$lhs, LLVM_Type:$rhs, I32Attr:$lhs_rows,
                    I32Attr:$lhs_columns, I32Attr:$rhs_columns);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "$lhs `,` $rhs attr-dict "
+    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     $res = mb.CreateMatrixMultiply(
       $lhs, $rhs, $lhs_rows, $lhs_columns,
       $rhs_columns);
   }];
-  let assemblyFormat = "$lhs `,` $rhs attr-dict "
-    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
+  string mlirBuilder = [{
+    $res = $_builder.create<LLVM::MatrixMultiplyOp>(
+      $_location, $_resultType, $lhs, $rhs,
+      $_int_attr($lhs_rows), $_int_attr($lhs_columns), $_int_attr($rhs_columns));
+  }];
 }
 
 /// Create a llvm.matrix.transpose call, transposing a `rows` x `columns` 2-D
 /// `matrix`, as specified in the LLVM MatrixBuilder.
-def LLVM_MatrixTransposeOp : LLVM_Op<"intr.matrix.transpose"> {
+def LLVM_MatrixTransposeOp : LLVM_OneResultIntrOp<"matrix.transpose"> {
   let arguments = (ins LLVM_Type:$matrix, I32Attr:$rows, I32Attr:$columns);
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `into` type($res)";
+
   string llvmBuilder = [{
     llvm::MatrixBuilder mb(builder);
     $res = mb.CreateMatrixTranspose(
       $matrix, $rows, $columns);
   }];
-  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `into` type($res)";
+  string mlirBuilder = [{
+    $res = $_builder.create<LLVM::MatrixTransposeOp>(
+      $_location, $_resultType, $matrix,
+      $_int_attr($rows), $_int_attr($columns));
+  }];
 }
 
 //
index f5f48fd..8f63e98 100644 (file)
@@ -417,10 +417,9 @@ class LLVM_VectorReduction<string mnem>
 // LLVM vector reduction over a single vector, with an initial value,
 // and with permission to reassociate the reduction operations.
 class LLVM_VectorReductionAcc<string mnem>
-    : LLVM_OpBase<LLVM_Dialect, "intr.vector.reduce." # mnem,
-                  [Pure]>,
-      Results<(outs LLVM_Type:$res)>,
-      Arguments<(ins LLVM_Type, LLVM_Type,
+    : LLVM_OneResultIntrOp<"vector.reduce." # mnem,
+                           [], [0], [Pure]>,
+      Arguments<(ins LLVM_Type:$start_value, LLVM_Type:$input,
                  DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
   let llvmBuilder = [{
     llvm::Module *module = builder.GetInsertBlock()->getModule();
@@ -438,6 +437,11 @@ class LLVM_VectorReductionAcc<string mnem>
     $res = builder.CreateCall(fn, operands);
     builder.setFastMathFlags(origFM);  // restore fastmath flag
   }];
+  let mlirBuilder = [{
+    bool allowReassoc = inst->getFastMathFlags().allowReassoc();
+    $res = $_builder.create<$_qualCppClassName>($_location,
+      $_resultType, $start_value, $input, allowReassoc);
+  }];
 }
 
 def LLVM_OneResultOpBuilder :
index e036f09..d8fca91 100644 (file)
@@ -230,6 +230,7 @@ define void @umin_test(i32 %0, i32 %1, <8 x i32> %2, <8 x i32> %3) {
   %6 = call <8 x i32> @llvm.umin.v8i32(<8 x i32> %2, <8 x i32> %3)
   ret void
 }
+
 ; CHECK-LABEL:  llvm.func @vector_reductions
 define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) {
   ; CHECK: "llvm.intr.vector.reduce.add"(%{{.*}}) : (vector<8xi32>) -> i32
@@ -252,22 +253,37 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) {
   %12 = call i32 @llvm.vector.reduce.umax.v8i32(<8 x i32> %2)
   ; CHECK: "llvm.intr.vector.reduce.umin"(%{{.*}}) : (vector<8xi32>) -> i32
   %13 = call i32 @llvm.vector.reduce.umin.v8i32(<8 x i32> %2)
-  ; TODO: vector reduce fadd and fmul should be handled specially.
+  ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) {reassoc = false} : (f32, vector<8xf32>) -> f32
   %14 = call float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1)
+  ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) {reassoc = false} : (f32, vector<8xf32>) -> f32
   %15 = call float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1)
+  ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) {reassoc = true} : (f32, vector<8xf32>) -> f32
   %16 = call reassoc float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1)
+  ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) {reassoc = true} : (f32, vector<8xf32>) -> f32
   %17 = call reassoc float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1)
   ; CHECK:  "llvm.intr.vector.reduce.xor"(%{{.*}}) : (vector<8xi32>) -> i32
   %18 = call i32 @llvm.vector.reduce.xor.v8i32(<8 x i32> %2)
   ret void
 }
 
-; TODO: matrix intrinsic should be handled specially.
-define void @matrix_intrinsics(<64 x float> %0, <48 x float> %1, float* %2, i64 %3) {
-  %5 = call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
-  %6 = call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16)
-  %7 = call <48 x float> @llvm.matrix.column.major.load.v48f32.i64(float* align 4 %2, i64 %3, i1 false, i32 3, i32 16)
-  call void @llvm.matrix.column.major.store.v48f32.i64(<48 x float> %7, float* align 4 %2, i64 %3, i1 false, i32 3, i32 16)
+; CHECK-LABEL: @matrix_intrinsics
+; CHECK-SAME:  %[[VEC1:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[VEC2:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[PTR:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[STRIDE:[a-zA-Z0-9]+]]
+define void @matrix_intrinsics(<64 x float> %vec1, <48 x float> %vec2, float* %ptr, i64 %stride) {
+  ; CHECK:  llvm.intr.matrix.multiply %[[VEC1]], %[[VEC2]]
+  ; CHECK-SAME:  {lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32}
+  %1 = call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %vec1, <48 x float> %vec2, i32 4, i32 16, i32 3)
+  ; CHECK:  llvm.intr.matrix.transpose %[[VEC2]]
+  ; CHECK-SAME:  {columns = 16 : i32, rows = 3 : i32}
+  %2 = call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %vec2, i32 3, i32 16)
+  ; CHECK:  %[[VAL1:.+]] = llvm.intr.matrix.column.major.load %[[PTR]], <stride = %[[STRIDE]]>
+  ; CHECK-SAME:  {columns = 16 : i32, isVolatile = false, rows = 3 : i32}
+  %3 = call <48 x float> @llvm.matrix.column.major.load.v48f32.i64(float* align 4 %ptr, i64 %stride, i1 false, i32 3, i32 16)
+  ; CHECK:  llvm.intr.matrix.column.major.store %[[VAL1]], %[[PTR]], <stride = %[[STRIDE]]>
+  ; CHECK-SAME:  {columns = 16 : i32, isVolatile = true, rows = 3 : i32}
+  call void @llvm.matrix.column.major.store.v48f32.i64(<48 x float> %3, float* align 4 %ptr, i64 %stride, i1 true, i32 3, i32 16)
   ret void
 }
 
@@ -412,6 +428,14 @@ define void @va_intrinsics_test(i8* %0, i8* %1) {
   ret void
 }
 
+; CHECK-LABEL: @assume
+; CHECK-SAME:  %[[TRUE:[a-zA-Z0-9]+]]
+define void @assume(i1 %true) {
+  ; CHECK:  "llvm.intr.assume"(%[[TRUE]]) : (i1) -> ()
+  call void @llvm.assume(i1 %true)
+  ret void
+}
+
 ; CHECK-LABEL:  llvm.func @coro_id
 define void @coro_id(i32 %0, i8* %1) {
   ; CHECK: llvm.intr.coro.id %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.token
@@ -775,3 +799,4 @@ declare <8 x i64> @llvm.vp.ptrtoint.v8i64.v8p0i32(<8 x i32*>, <8 x i1>, i32)
 declare <8 x i32*> @llvm.vp.inttoptr.v8p0i32.v8i64(<8 x i64>, <8 x i1>, i32)
 declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
 declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture)
+declare void @llvm.assume(i1)