mlir::Value matrixABox, mlir::Value matrixBBox,
mlir::Value resultBox);
+/// Generate call to MatmulTranspose intrinsic runtime routine.
+/// The parameter names follow the order used by the definition and by its
+/// caller: the result box comes first, then the two operand boxes. All three
+/// are mlir::Value, so a mismatch would not be diagnosed by the compiler.
+void genMatmulTranspose(fir::FirOpBuilder &builder, mlir::Location loc,
+                        mlir::Value resultBox, mlir::Value matrixABox,
+                        mlir::Value matrixBBox);
+
/// Generate call to Pack intrinsic runtime routine.
/// Takes the result box first, followed by the array, mask and vector boxes.
void genPack(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value arrayBox, mlir::Value maskBox,
mlir::Value vectorBox);
// Intrinsic generator declarations. genMask is a template over a Shift type
// and works on plain mlir::Value arguments; the generators that follow take
// an mlir::Type plus fir::ExtendedValue arguments and return a
// fir::ExtendedValue.
template <typename Shift>
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ fir::ExtendedValue genMatmulTranspose(mlir::Type,
+ llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMerge(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
&I::genMatmul,
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
/*isElemental=*/false},
+ {"matmul_transpose",
+ &I::genMatmulTranspose,
+ {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+ /*isElemental=*/false},
{"max", &I::genExtremum<Extremum::Max, ExtremumBehavior::MinMaxss>},
{"maxloc",
&I::genMaxloc,
return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL");
}
+// MATMUL_TRANSPOSE
+fir::ExtendedValue
+IntrinsicLibrary::genMatmulTranspose(mlir::Type resultType,
+                                     llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // Both operands are required. Box them in argument order (left operand
+  // first) and keep the underlying values for the runtime call.
+  fir::BoxValue lhsBox = builder.createBox(loc, args[0]);
+  fir::BoxValue rhsBox = builder.createBox(loc, args[1]);
+  mlir::Value lhsBase = fir::getBase(lhsBox);
+  mlir::Value rhsBase = fir::getBase(rhsBox);
+
+  // The result has rank 1 as soon as either operand has rank 1, and rank 2
+  // otherwise.
+  unsigned resultRank = 2;
+  if (lhsBox.rank() == 1 || rhsBox.rank() == 1)
+    resultRank = 1;
+
+  // The runtime allocates the result, so hand it a temporary mutable fir.box
+  // whose array type has the expected rank.
+  fir::MutableBoxValue resultMutableBox = fir::factory::createTempMutableBox(
+      builder, loc, builder.getVarLenSeqTy(resultType, resultRank));
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  fir::runtime::genMatmulTranspose(builder, loc, resultIrBox, lhsBase,
+                                   rhsBase);
+
+  // Read the result back out of the mutable fir.box and register it with the
+  // temporaries finalized by the StatementContext.
+  return readAndAddCleanUp(resultMutableBox, resultType, "MATMUL_TRANSPOSE");
+}
+
// MERGE
fir::ExtendedValue
IntrinsicLibrary::genMerge(mlir::Type,
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Runtime/matmul-transpose.h"
#include "flang/Runtime/matmul.h"
#include "flang/Runtime/transformational.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
builder.create<fir::CallOp>(loc, func, args);
}
+/// Generate call to MatmulTranspose intrinsic runtime routine.
+/// The call receives the result box, both operand boxes, and the source file
+/// name and line number derived from \p loc.
+void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
+                                      mlir::Location loc, mlir::Value resultBox,
+                                      mlir::Value matrixABox,
+                                      mlir::Value matrixBBox) {
+  mlir::func::FuncOp runtimeFunc =
+      fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+  mlir::FunctionType funcType = runtimeFunc.getFunctionType();
+
+  // Source position operands: the file name first, then the line number
+  // built with the type of the runtime function's input at index 4.
+  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
+  mlir::Value lineNumber =
+      fir::factory::locationToLineNo(builder, loc, funcType.getInput(4));
+
+  llvm::SmallVector<mlir::Value> callArgs = fir::runtime::createArguments(
+      builder, loc, funcType, resultBox, matrixABox, matrixBBox, fileName,
+      lineNumber);
+  builder.create<fir::CallOp>(loc, runtimeFunc, callArgs);
+}
+
/// Generate call to Pack intrinsic runtime routine.
void fir::runtime::genPack(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value arrayBox,
}
};
+/// Rewrite pattern that replaces an hlfir::MatmulTransposeOp with the code
+/// generated by fir::genIntrinsicCall for the "matmul_transpose" intrinsic.
+/// The op's two operands are lowered with the argument rules registered for
+/// that same intrinsic name, and the call result replaces the op through
+/// processReturnValue.
+struct MatmulTransposeOpConversion
+    : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> {
+  using HlfirIntrinsicConversion<
+      hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
+                  mlir::PatternRewriter &rewriter) const override {
+    // One name drives both the argument-lowering lookup and the call
+    // generation, so the two cannot drift apart if the "matmul" and
+    // "matmul_transpose" table entries ever diverge.
+    const llvm::StringRef intrinsicName = "matmul_transpose";
+
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = multranspose->getLoc();
+
+    // Pair each operand with its type, left-hand side first.
+    mlir::Value lhs = multranspose.getLhs();
+    mlir::Value rhs = multranspose.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs;
+    inArgs.push_back({lhs, lhs.getType()});
+    inArgs.push_back({rhs, rhs.getType()});
+
+    auto *argLowering = fir::getIntrinsicArgumentLowering(intrinsicName);
+    llvm::SmallVector<fir::ExtendedValue, 2> args =
+        lowerArguments(multranspose, inArgs, rewriter, argLowering);
+
+    // genIntrinsicCall is given the element type of the op's result type.
+    mlir::Type scalarResultType =
+        hlfir::getFortranElementType(multranspose.getType());
+
+    auto [resultExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, intrinsicName, scalarResultType, args);
+
+    processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
+};
+
class LowerHLFIRIntrinsics
: public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> {
public:
mlir::ModuleOp module = this->getOperation();
mlir::MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
- patterns.insert<MatmulOpConversion, SumOpConversion, TransposeOpConversion>(
- context);
+ patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
+ SumOpConversion, TransposeOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
mlir::func::FuncDialect, fir::FIROpsDialect,
hlfir::hlfirDialect>();
- target.addIllegalOp<hlfir::MatmulOp, hlfir::SumOp, hlfir::TransposeOp>();
+ target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
+ hlfir::TransposeOp>();
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(
! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING-OPT --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
! Test passing a hlfir.expr from one intrinsic to another
! CHECK-LOWERING-NEXT: hlfir.destroy %[[MUL_EXPR]]
! CHECK-LOWERING-NEXT: hlfir.destroy %[[TRANSPOSE_EXPR]]
+! CHECK-LOWERING-OPT: %[[LHS_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[B_BOX:.*]] = fir.embox %[[B_DECL]]#1(%{{.*}})
+! CHECK-LOWERING-OPT: %[[MUL_CONV_RES:.*]] = fir.convert %[[MUL_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+! CHECK-LOWERING-OPT: %[[LHS_CONV:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT: %[[B_BOX_CONV:.*]] = fir.convert %[[B_BOX]] : (!fir.box<!fir.array<2x2xf32>>) -> !fir.box<none>
+! CHECK-LOWERING-OPT: fir.call @_FortranAMatmulTranspose(%[[MUL_CONV_RES]], %[[LHS_CONV]], %[[B_BOX_CONV]], %[[LOC_STR2:.*]], %[[LOC_N2:.*]])
+! CHECK-LOWERING-OPT: %[[MUL_RES_LD:.*]] = fir.load %[[MUL_RES_BOX:.*]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_ADDR:.*]] = fir.box_addr %[[MUL_RES_LD]]
+! CHECK-LOWERING-OPT: %[[MUL_RES_VAR:.*]]:2 = hlfir.declare %[[MUL_RES_ADDR]]({{.*}}) {uniq_name = ".tmp.intrinsic_result"}
+! CHECK-LOWERING-OPT: %[[TRUE2:.*]] = arith.constant true
+! CHECK-LOWERING-OPT: %[[MUL_EXPR:.*]] = hlfir.as_expr %[[MUL_RES_VAR]]#0 move %[[TRUE2]] : (!fir.box<!fir.array<?x?xf32>>, i1) -> !hlfir.expr<?x?xf32>
+! CHECK-LOWERING-OPT: hlfir.assign %[[MUL_EXPR]] to %[[RES_DECL]]#0 : !hlfir.expr<?x?xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-LOWERING-OPT: hlfir.destroy %[[MUL_EXPR]]
+
! [argument handling unchanged]
! CHECK-BUFFERING: fir.call @_FortranATranspose(
! CHECK-BUFFERING: %[[TRANSPOSE_RES_LD:.*]] = fir.load %[[TRANSPOSE_RES_BOX:.*]]