[flang][hlfir] lower hlfir.matmul_transpose to runtime call
authorTom Eccles <tom.eccles@arm.com>
Fri, 17 Mar 2023 09:26:45 +0000 (09:26 +0000)
committerTom Eccles <tom.eccles@arm.com>
Fri, 17 Mar 2023 09:30:04 +0000 (09:30 +0000)
Depends on D145960

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D145961

flang/include/flang/Optimizer/Builder/Runtime/Transformational.h
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
flang/test/HLFIR/mul_transpose.f90

index f084e2e..ae0a097 100644 (file)
@@ -55,6 +55,10 @@ void genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
                mlir::Value matrixABox, mlir::Value matrixBBox,
                mlir::Value resultBox);
 
+void genMatmulTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                        mlir::Value matrixABox, mlir::Value matrixBBox,
+                        mlir::Value resultBox);
+
 void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
              mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
              mlir::Value vectorBox);
index bfc36b8..b933603 100644 (file)
@@ -269,6 +269,8 @@ struct IntrinsicLibrary {
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+  fir::ExtendedValue genMatmulTranspose(mlir::Type,
+                                        llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMerge(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -679,6 +681,10 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
      /*isElemental=*/false},
+    {"matmul_transpose",
+     &I::genMatmulTranspose,
+     {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+     /*isElemental=*/false},
     {"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
     {"maxloc",
      &I::genMaxloc,
@@ -4015,6 +4021,33 @@ IntrinsicLibrary::genMatmul(mlir::Type resultType,
   return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL");
 }
 
+// MATMUL_TRANSPOSE
+fir::ExtendedValue
+IntrinsicLibrary::genMatmulTranspose(mlir::Type resultType,
+                                     llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // Handle required matmul_transpose arguments
+  fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+  mlir::Value matrixA = fir::getBase(matrixTmpA);
+  fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+  mlir::Value matrixB = fir::getBase(matrixTmpB);
+  unsigned resultRank =
+      (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+  // Create mutable fir.box to be passed to the runtime for the result.
+  mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+  fir::MutableBoxValue resultMutableBox =
+      fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  // Call runtime. The runtime is allocating the result.
+  fir::runtime::genMatmulTranspose(builder, loc, resultIrBox, matrixA, matrixB);
+  // Read result from mutable fir.box and add it to the list of temps to be
+  // finalized by the StatementContext.
+  return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL_TRANSPOSE");
+}
+
 // MERGE
 fir::ExtendedValue
 IntrinsicLibrary::genMerge(mlir::Type,
index c51f386..3ae0977 100644 (file)
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Runtime/matmul-transpose.h"
 #include "flang/Runtime/matmul.h"
 #include "flang/Runtime/transformational.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -351,6 +352,23 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
   builder.create<fir::CallOp>(loc, func, args);
 }
 
+/// Generate call to MatmulTranspose intrinsic runtime routine.
+void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
+                                      mlir::Location loc, mlir::Value resultBox,
+                                      mlir::Value matrixABox,
+                                      mlir::Value matrixBBox) {
+  auto func =
+      fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto args =
+      fir::runtime::createArguments(builder, loc, fTy, resultBox, matrixABox,
+                                    matrixBBox, sourceFile, sourceLine);
+  builder.create<fir::CallOp>(loc, func, args);
+}
+
 /// Generate call to Pack intrinsic runtime routine.
 void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value resultBox, mlir::Value arrayBox,
index 59e5b3b..95192a3 100644 (file)
@@ -257,6 +257,39 @@ class TransposeOpConversion
   }
 };
 
+struct MatmulTransposeOpConversion
+    : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
+  using HlfirIntrinsicConversion<
+      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
+                  mlir::PatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    const mlir::Location &loc = multranspose->getLoc();
+
+    mlir::Value lhs = multranspose.getLhs();
+    mlir::Value rhs = multranspose.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
+    inArgs.push_back({lhs, lhs.getType()});
+    inArgs.push_back({rhs, rhs.getType()});
+
+    auto *argLowering = fir::getIntrinsicArgumentLowering("matmul");
+    llvm::SmallVector<fir::ExtendedValue, 2> args =
+        lowerArguments(multranspose, inArgs, rewriter, argLowering);
+
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(multranspose.getType());
+
+    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, "matmul_transpose", scalarResultType, args);
+
+    processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
 class LowerHLFIRIntrinsics
     : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
 public:
@@ -271,13 +304,14 @@ public:
     mlir::ModuleOp module = this->getOperation();
     mlir::MLIRContext *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<MatmulOpConversion, SumOpConversion, TransposeOpConversion>(
-        context);
+    patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
+                    SumOpConversion, TransposeOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, fir::FIROpsDialect,
                            hlfir::hlfirDialect>();
-    target.addIllegalOp<hlfir::MatmulOp, hlfir::SumOp, hlfir::TransposeOp>();
+    target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
+                        hlfir::TransposeOp>();
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
     if (mlir::failed(
index ab3742f..c79c742 100644 (file)
@@ -1,6 +1,7 @@
 ! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING-OPT --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
 
 ! Test passing a hlfir.expr from one intrinsic to another
@@ -56,6 +57,20 @@ endsubroutine
 ! CHECK-LOWERING-NEXT:  hlfir.destroy %[[MUL_EXPR]]
 ! CHECK-LOWERING-NEXT:  hlfir.destroy %[[TRANSPOSE_EXPR]]
 
+! CHECK-LOWERING-OPT:   %[[LHS_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT:   %[[B_BOX:.*]] = fir.embox %[[B_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT:   %[[MUL_CONV_RES:.*]] = fir.convert %[[MUL_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK-LOWERING-OPT:   %[[LHS_CONV:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT:   %[[B_BOX_CONV:.*]] = fir.convert %[[B_BOX]] : (!fir.box<!fir.array<2x2xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT:   fir.call @_FortranAMatmulTranspose(%[[MUL_CONV_RES]], %[[LHS_CONV]], %[[B_BOX_CONV]], %[[LOC_STR2:.*]], %[[LOC_N2:.*]])
+! CHECK-LOWERING-OPT:   %[[MUL_RES_LD:.*]] = fir.load %[[MUL_RES_BOX:.*]]
+! CHECK-LOWERING-OPT:   %[[MUL_RES_ADDR:.*]] = fir.box_addr %[[MUL_RES_LD]]
+! CHECK-LOWERING-OPT:   %[[MUL_RES_VAR:.*]]:2 = hlfir.declare %[[MUL_RES_ADDR]]({{.*}}) {uniq_name = ".tmp.intrinsic_result"}
+! CHECK-LOWERING-OPT:   %[[TRUE2:.*]] = arith.constant true
+! CHECK-LOWERING-OPT:   %[[MUL_EXPR:.*]] = hlfir.as_expr %[[MUL_RES_VAR]]#0 move %[[TRUE2]] : (!fir.box<!fir.array<?x?xf32>>, i1) -> !hlfir.expr<?x?xf32>
+! CHECK-LOWERING-OPT:   hlfir.assign %[[MUL_EXPR]] to %[[RES_DECL]]#0 : !hlfir.expr<?x?xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-LOWERING-OPT:   hlfir.destroy %[[MUL_EXPR]]
+
 ! [argument handling unchanged]
 ! CHECK-BUFFERING:      fir.call @_FortranATranspose(
 ! CHECK-BUFFERING:      %[[TRANSPOSE_RES_LD:.*]] = fir.load %[[TRANSPOSE_RES_BOX:.*]]