From 49bd444fc3617a140ef67047d756c4d652a2a835 Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Fri, 17 Mar 2023 09:26:27 +0000 Subject: [PATCH] [flang][hlfir] add hlfir.matmul_transpose operation This operation will be used to transform MATMUL(TRANSPOSE(a), b). The transformation will go in the following stages: 1. Lowering to hlfir.transpose and hlfir.matmul 2. Canonicalise to hlfir.matmul_transpose 3. hlfir.matmul_transpose will be lowered to FIR as a new runtime library call Step 2 (and this operation) are included for consistency with the other hlfir intrinsic operations and to avoid mixing concerns in the intrinsic lowering pass. In step 3, a new runtime library call is used because this operation is most easily implemented in one go (the transposed indexing actually makes the indexing simpler than for a normal matrix multiplication). In the long run, it is intended that HLFIR will allow the same buffer to be shared between different runtime calls without temporary allocations, but in this specific case we can do even better than that with a dedicated implementation. This should speed up galgel from SPEC2000 (but this hadn't been tested yet). The optimization was implemented in Classic Flang. Reviewed By: vzakhari Differential Revision: https://reviews.llvm.org/D145957 --- flang/docs/HighLevelFIR.md | 34 +++++++++- flang/include/flang/Optimizer/HLFIR/HLFIROps.td | 23 +++++++ flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 65 ++++++++++++++++++ flang/test/HLFIR/invalid.fir | 42 ++++++++++++ flang/test/HLFIR/matmul_transpose.fir | 87 +++++++++++++++++++++++++ 5 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 flang/test/HLFIR/matmul_transpose.fir diff --git a/flang/docs/HighLevelFIR.md b/flang/docs/HighLevelFIR.md index 033ac41..8c671ae 100644 --- a/flang/docs/HighLevelFIR.md +++ b/flang/docs/HighLevelFIR.md @@ -652,7 +652,6 @@ Syntax: %element = hlfir.apply %array_expr %i, %j: (hlfir.expr) -> i32 ``` - #### Introducing operations for transformational intrinsic functions Motivation: Represent transformational intrinsics functions at a high-level so @@ -701,6 +700,39 @@ call will probably be used since there is little point to keep them high level: - selected_char_kind, selected_int_kind, selected_real_kind that returns scalar integers +#### Introducing operations for composed intrinsic functions + +Motivation: optimize commonly composed intrinsic functions (e.g. +MATMUL(TRANSPOSE(a), b)). This optimization is implemented in Classic Flang. + +An operation and runtime function will be added for each commonly used +composition of intrinsic functions. The operation will be the canonical way to +write this chained operation (the MLIR canonicalization pass will rewrite the +operations for the composed intrinsics into this one operation). + +These new operations will be treated as though they were standard +transformational intrinsic functions. + +The composed intrinsic operation will return a hlfir.expr. The arguments +may be hlfir.expr, boxed arrays, simple scalar types (e.g. i32, f32), or +variables. + +To keep things simple, these operations will only match one form of the composed +intrinsic functions: therefore there will be no optional arguments. + +Syntax: +``` +%res = hlfir."intrinsic_name" %expr_or_var, ... +``` + +The composed intrinsic operation will be lowered to a `fir.call` to the newly +added runtime implementation of the operation. + +These operations should not be added where the only improvement is to avoid +creating a temporary intermediate buffer which would otherwise be removed by +intelligent bufferization of a hlfir.expr. Similarly, these should not replace +profitable uses of hlfir.elemental. + #### Introducing operations for character operations and elemental intrinsic functions diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index b797cd2..86318f6 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -402,6 +402,29 @@ def hlfir_TransposeOp : hlfir_Op<"transpose", []> { let hasVerifier = 1; } +def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose", + [DeclareOpInterfaceMethods]> { + let summary = "Optimized MATMUL(TRANSPOSE(...), ...)"; + let description = [{ + Matrix multiplication where the left hand side is transposed + }]; + + let arguments = (ins + AnyFortranNumericalOrLogicalArrayObject:$lhs, + AnyFortranNumericalOrLogicalArrayObject:$rhs, + DefaultValuedAttr:$fastmath + ); + + let results = (outs hlfir_ExprType); + + let assemblyFormat = [{ + $lhs $rhs attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let summary = "Create a variable from an expression value"; diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 13103c5..0114c12 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -669,6 +669,71 @@ mlir::LogicalResult hlfir::TransposeOp::verify() { } //===----------------------------------------------------------------------===// +// MatmulTransposeOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult hlfir::MatmulTransposeOp::verify() { + mlir::Value lhs = getLhs(); + mlir::Value rhs = getRhs(); + fir::SequenceType lhsTy = + hlfir::getFortranElementOrSequenceType(lhs.getType()) + .cast(); + fir::SequenceType rhsTy = + hlfir::getFortranElementOrSequenceType(rhs.getType()) + .cast(); + llvm::ArrayRef lhsShape = lhsTy.getShape(); + llvm::ArrayRef rhsShape = rhsTy.getShape(); + std::size_t lhsRank = lhsShape.size(); + std::size_t rhsRank = rhsShape.size(); + mlir::Type lhsEleTy = lhsTy.getEleTy(); + mlir::Type rhsEleTy = rhsTy.getEleTy(); + hlfir::ExprType resultTy = getResult().getType().cast(); + llvm::ArrayRef resultShape = resultTy.getShape(); + mlir::Type resultEleTy = resultTy.getEleTy(); + + // lhs must have rank 2 for the transpose to be valid + if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2))) + return emitOpError("array must have either rank 1 or rank 2"); + + if (mlir::isa(lhsEleTy) != + mlir::isa(rhsEleTy)) + return emitOpError("if one array is logical, so should the other be"); + + // for matmul we compare the last dimension of lhs with the first dimension of + // rhs, but for MatmulTranspose, dimensions of lhs are inverted by the + // transpose + int64_t firstLhsDim = lhsShape[0]; + int64_t firstRhsDim = rhsShape[0]; + constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent(); + if (firstLhsDim != firstRhsDim) + if ((firstLhsDim != unknownExtent) && (firstRhsDim != unknownExtent)) + return emitOpError( + "the first dimension of LHS should match the first dimension of RHS"); + + if (mlir::isa(lhsEleTy) != + mlir::isa(resultEleTy)) + return emitOpError("the result type should be a logical only if the " + "argument types are logical"); + + llvm::SmallVector expectedResultShape; + if (rhsRank == 2) { + expectedResultShape.push_back(lhsShape[1]); + expectedResultShape.push_back(rhsShape[1]); + } else { + // rhsRank == 1 + expectedResultShape.push_back(lhsShape[1]); + } + if (resultShape.size() != expectedResultShape.size()) + return emitOpError("incorrect result shape"); + if (resultShape[0] != expectedResultShape[0]) + return emitOpError("incorrect result shape"); + if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1]) + return emitOpError("incorrect result shape"); + + return mlir::success(); +} + +//===----------------------------------------------------------------------===// // AssociateOp //===----------------------------------------------------------------------===// diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir index 2ec7c68..a8ba337 100644 --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -398,6 +398,48 @@ func.func @bad_transpose3(%arg0: !hlfir.expr<2x3xi32>) { } // ----- +func.func @bad_matmultranspose1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose2(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose3(%arg0: !hlfir.expr>, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op if one array is logical, so should the other be}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr) -> !hlfir.expr + return +} + +// ----- +func.func @bad_matmultranspose5(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + // expected-error@+1 {{'hlfir.matmul_transpose' op the result type should be a logical only if the argument types are logical}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr> + return +} + +// ----- +func.func @bad_matmultranspose6(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2x3xi32>) { + // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32> + return +} + +// ----- +func.func @bad_matmultranspose7(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2xi32>) { + // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}} + %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32> + return +} + +// ----- func.func @bad_assign_1(%arg0: !fir.box>, %arg1: !fir.box>) { // expected-error@+1 {{'hlfir.assign' op lhs must be an allocatable when `realloc` is set}} hlfir.assign %arg1 to %arg0 realloc : !fir.box>, !fir.box> diff --git a/flang/test/HLFIR/matmul_transpose.fir b/flang/test/HLFIR/matmul_transpose.fir new file mode 100644 index 0000000..967edec --- /dev/null +++ b/flang/test/HLFIR/matmul_transpose.fir @@ -0,0 +1,87 @@ +// Test hlfir.matmul_transpose operation parse, verify (no errors), and unparse + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +// arguments are expressions of known shape +func.func @matmul_transpose0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul_transpose0 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x2xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions of assumed shape +func.func @matmul_transpose1(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose1 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #1 +func.func @matmul_transpose2(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr<2x2xi32> + return +} +// CHECK-LABEL: func.func @matmul_transpose2 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr<2x2xi32> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are expressions where only some dimensions are known #2 +func.func @matmul_transpose3(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<2x?xi32>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose3 +// CHECK: %[[ARG0:.*]]: !hlfir.expr<2x?xi32>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are logicals +func.func @matmul_transpose4(%arg0: !hlfir.expr>, %arg1: !hlfir.expr>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> + return +} +// CHECK-LABEL: func.func @matmul_transpose4 +// CHECK: %[[ARG0:.*]]: !hlfir.expr>, +// CHECK: %[[ARG1:.*]]: !hlfir.expr>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr>, !hlfir.expr>) -> !hlfir.expr> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// rhs is rank 1 +func.func @matmul_transpose6(%arg0: !hlfir.expr, %arg1: !hlfir.expr) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr + return +} +// CHECK-LABEL: func.func @matmul_transpose6 +// CHECK: %[[ARG0:.*]]: !hlfir.expr, +// CHECK: %[[ARG1:.*]]: !hlfir.expr) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr, !hlfir.expr) -> !hlfir.expr +// CHECK-NEXT: return +// CHECK-NEXT: } + +// arguments are boxed arrays +func.func @matmul_transpose7(%arg0: !fir.box>, %arg1: !fir.box>) { + %res = hlfir.matmul_transpose %arg0 %arg1 : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> + return +} +// CHECK-LABEL: func.func @matmul_transpose7 +// CHECK: %[[ARG0:.*]]: !fir.box>, +// CHECK: %[[ARG1:.*]]: !fir.box>) { +// CHECK-NEXT: %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!fir.box>, !fir.box>) -> !hlfir.expr<2x2xf32> +// CHECK-NEXT: return +// CHECK-NEXT: } -- 2.7.4