[flang][hlfir] add hlfir.matmul_transpose operation
authorTom Eccles <tom.eccles@arm.com>
Fri, 17 Mar 2023 09:26:27 +0000 (09:26 +0000)
committerTom Eccles <tom.eccles@arm.com>
Fri, 17 Mar 2023 09:30:04 +0000 (09:30 +0000)
This operation will be used to transform MATMUL(TRANSPOSE(a), b). The
transformation will go in the following stages:
        1. Lowering to hlfir.transpose and hlfir.matmul
        2. Canonicalise to hlfir.matmul_transpose
        3. hlfir.matmul_transpose will be lowered to FIR as a new runtime
          library call

Step 2 (and this operation) are included for consistency with the other
hlfir intrinsic operations and to avoid mixing concerns in the intrinsic
lowering pass.

In step 3, a new runtime library call is used because this operation is
most easily implemented in one go (the transposed indexing actually
makes the indexing simpler than for a normal matrix multiplication). In
the long run, it is intended that HLFIR will allow the same buffer
to be shared between different runtime calls without temporary
allocations, but in this specific case we can do even better than that
with a dedicated implementation.

This should speed up galgel from SPEC2000 (but this hadn't been tested
yet). The optimization was implemented in Classic Flang.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D145957

flang/docs/HighLevelFIR.md
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/invalid.fir
flang/test/HLFIR/matmul_transpose.fir [new file with mode: 0644]

index 033ac41..8c671ae 100644 (file)
@@ -652,7 +652,6 @@ Syntax:
 %element = hlfir.apply %array_expr %i, %j: (hlfir.expr<?x?xi32>) -> i32
 ```
 
-
 #### Introducing operations for transformational intrinsic functions
 
 Motivation: Represent transformational intrinsics functions at a high-level so
@@ -701,6 +700,39 @@ call will probably be used since there is little point to keep them high level:
 - selected_char_kind, selected_int_kind, selected_real_kind that returns scalar
   integers
 
+#### Introducing operations for composed intrinsic functions
+
+Motivation: optimize commonly composed intrinsic functions (e.g.
+MATMUL(TRANSPOSE(a), b)). This optimization is implemented in Classic Flang.
+
+An operation and runtime function will be added for each commonly used
+composition of intrinsic functions. The operation will be the canonical way to
+write this chained operation (the MLIR canonicalization pass will rewrite the
+operations for the composed intrinsics into this one operation).
+
+These new operations will be treated as though they were standard
+transformational intrinsic functions.
+
+The composed intrinsic operation will return a hlfir.expr<T>. The arguments
+may be hlfir.expr<T>, boxed arrays, simple scalar types (e.g. i32, f32), or
+variables.
+
+To keep things simple, these operations will only match one form of the composed
+intrinsic functions: therefore there will be no optional arguments.
+
+Syntax:
+```
+%res = hlfir."intrinsic_name" %expr_or_var, ...
+```
+
+The composed intrinsic operation will be lowered to a `fir.call` to the newly
+added runtime implementation of the operation.
+
+These operations should not be added where the only improvement is to avoid
+creating a temporary intermediate buffer which would otherwise be removed by
+intelligent bufferization of a hlfir.expr. Similarly, these should not replace
+profitable uses of hlfir.elemental.
+
 #### Introducing operations for character operations and elemental intrinsic functions
 
 
index b797cd2..86318f6 100644 (file)
@@ -402,6 +402,29 @@ def hlfir_TransposeOp : hlfir_Op<"transpose", []> {
   let hasVerifier = 1;
 }
 
+def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "Optimized MATMUL(TRANSPOSE(...), ...)";
+  let description = [{
+    Matrix multiplication where the left hand side is transposed
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
   let summary = "Create a variable from an expression value";
index 13103c5..0114c12 100644 (file)
@@ -669,6 +669,71 @@ mlir::LogicalResult hlfir::TransposeOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// MatmulTransposeOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+
+  // lhs must have rank 2 for the transpose to be valid
+  if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
+    return emitOpError("array must have either rank 1 or rank 2");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  // for matmul we compare the last dimension of lhs with the first dimension of
+  // rhs, but for MatmulTranspose, dimensions of lhs are inverted by the
+  // transpose
+  int64_t firstLhsDim = lhsShape[0];
+  int64_t firstRhsDim = rhsShape[0];
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  if (firstLhsDim != firstRhsDim)
+    if ((firstLhsDim != unknownExtent) && (firstRhsDim != unknownExtent))
+      return emitOpError(
+          "the first dimension of LHS should match the first dimension of RHS");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultEleTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  llvm::SmallVector<int64_t, 2> expectedResultShape;
+  if (rhsRank == 2) {
+    expectedResultShape.push_back(lhsShape[1]);
+    expectedResultShape.push_back(rhsShape[1]);
+  } else {
+    // rhsRank == 1
+    expectedResultShape.push_back(lhsShape[1]);
+  }
+  if (resultShape.size() != expectedResultShape.size())
+    return emitOpError("incorrect result shape");
+  if (resultShape[0] != expectedResultShape[0])
+    return emitOpError("incorrect result shape");
+  if (resultShape.size() == 2 && resultShape[1] != expectedResultShape[1])
+    return emitOpError("incorrect result shape");
+
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//
 
index 2ec7c68..a8ba337 100644 (file)
@@ -398,6 +398,48 @@ func.func @bad_transpose3(%arg0: !hlfir.expr<2x3xi32>) {
 }
 
 // -----
+func.func @bad_matmultranspose1(%arg0: !hlfir.expr<?x?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose2(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op if one array is logical, so should the other be}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose5(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose6(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32>
+  return
+}
+
+// -----
+func.func @bad_matmultranspose7(%arg0: !hlfir.expr<2x1xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.matmul_transpose' op incorrect result shape}}
+  %0 = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x1xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}
+
+// -----
 func.func @bad_assign_1(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<?xi32>>) {
   // expected-error@+1 {{'hlfir.assign' op lhs must be an allocatable when `realloc` is set}}
   hlfir.assign %arg1 to %arg0 realloc : !fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>
diff --git a/flang/test/HLFIR/matmul_transpose.fir b/flang/test/HLFIR/matmul_transpose.fir
new file mode 100644 (file)
index 0000000..967edec
--- /dev/null
@@ -0,0 +1,87 @@
+// Test hlfir.matmul_transpose operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// arguments are expressions of known shape
+func.func @matmul_transpose0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose0
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions of assumed shape
+func.func @matmul_transpose1(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose1
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions where only some dimensions are known #1
+func.func @matmul_transpose2(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<?x2xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose2
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x2xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are expressions where only some dimensions are known #2
+func.func @matmul_transpose3(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<2x?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose3
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are logicals
+func.func @matmul_transpose4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose4
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?x!fir.logical<4>>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// rhs is rank 1
+func.func @matmul_transpose6(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose6
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// arguments are boxed arrays
+func.func @matmul_transpose7(%arg0: !fir.box<!fir.array<2x2xf32>>, %arg1: !fir.box<!fir.array<2x2xf32>>) {
+  %res = hlfir.matmul_transpose %arg0 %arg1 : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+  return
+}
+// CHECK-LABEL: func.func @matmul_transpose7
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<2x2xf32>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<2x2xf32>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul_transpose %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }