[flang] add hlfir.sum operation
authorTom Eccles <tom.eccles@arm.com>
Tue, 17 Jan 2023 17:37:15 +0000 (17:37 +0000)
committerTom Eccles <tom.eccles@arm.com>
Mon, 13 Feb 2023 10:50:11 +0000 (10:50 +0000)
Add an HLFIR operation for the SUM transformational intrinsic, according
to the design set out in flang/doc/HighLevelFIR.md.

I decided to make hlfir.sum very lenient about the form of its
arguments. This allows the sum intrinsic to be lowered to only this HLFIR
operation, without needing several operations to convert and box
arguments. Having only one operation generated for the intrinsic
invocation should make optimisation passes on HLFIR simpler.

Differential Revision: https://reviews.llvm.org/D142897

flang/include/flang/Optimizer/Builder/HLFIRTools.h
flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
flang/include/flang/Optimizer/HLFIR/HLFIROps.h
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/Builder/HLFIRTools.cpp
flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/invalid.fir
flang/test/HLFIR/sum.fir [new file with mode: 0644]

index e610478..389f5ce 100644 (file)
@@ -373,6 +373,7 @@ convertToAddress(mlir::Location loc, fir::FirOpBuilder &builder,
 std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
 convertToBox(mlir::Location loc, fir::FirOpBuilder &builder,
              const hlfir::Entity &entity, mlir::Type targetType);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
index 8aace6b..6a9acb4 100644 (file)
@@ -69,6 +69,13 @@ inline bool isBoxAddressOrValueType(mlir::Type type) {
   return fir::unwrapRefType(type).isa<fir::BaseBoxType>();
 }
 
+bool isFortranScalarNumericalType(mlir::Type);
+bool isFortranNumericalArrayObject(mlir::Type);
+bool isPassByRefOrIntegerType(mlir::Type);
+bool isI1Type(mlir::Type);
+// scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
+bool isMaskArgument(mlir::Type);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
index d17a4cb..23ad2ed 100644 (file)
@@ -107,4 +107,19 @@ def IsFortranScalarCharacterExprPred
 def AnyScalarCharacterExpr : Type<IsFortranScalarCharacterExprPred,
     "any character scalar expression type">;
 
+def IsFortranNumericalArrayObjectPred
+        : CPred<"::hlfir::isFortranNumericalArrayObject($_self)">;
+def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
+    "any array-like object containing a numerical type">;
+
+def IsPassByRefOrIntegerTypePred
+        : CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
+def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,
+    "an integer type either by value or by reference">;
+
+def IsMaskArgumentPred
+        : CPred<"::hlfir::isMaskArgument($_self)">;
+def AnyFortranLogicalOrI1ArrayObject : Type<IsMaskArgumentPred,
+    "A scalar i1 or logical or an array-like object containing logicals">;
+
 #endif // FORTRAN_DIALECT_HLFIR_OP_BASE
index 33530d1..e0e7183 100644 (file)
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/FortranVariableInterface.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
index 3228685..1321eab 100644 (file)
@@ -18,6 +18,8 @@ include "flang/Optimizer/HLFIR/HLFIROpBase.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
 include "flang/Optimizer/Dialect/FortranVariableInterface.td"
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/IR/BuiltinAttributes.td"
 
 // Base class for FIR operations.
@@ -256,6 +258,31 @@ def hlfir_SetLengthOp : hlfir_Op<"set_length", []> {
   let builders = [OpBuilder<(ins "mlir::Value":$string,"mlir::Value":$len)>];
 }
 
+def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "SUM transformational intrinsic";
+  let description = [{
+    Sums the elements of an array, optionally along a particular dimension,
+    optionally if a mask is true.
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalArrayObject:$array,
+    Optional<AnyIntegerType>:$dim,
+    Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
   let summary = "Create a variable from an expression value";
index 41cc800..072fb5c 100644 (file)
@@ -17,6 +17,8 @@
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to
index fbb7d77..f23be5d 100644 (file)
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -99,3 +100,43 @@ bool hlfir::isFortranScalarCharacterExprType(mlir::Type type) {
            exprType.getElementType().isa<fir::CharacterType>();
   return false;
 }
+
+bool hlfir::isFortranScalarNumericalType(mlir::Type type) {
+  return fir::isa_integer(type) || fir::isa_real(type) ||
+         fir::isa_complex(type);
+}
+
+bool hlfir::isFortranNumericalArrayObject(mlir::Type type) {
+  if (isBoxAddressType(type))
+    return false;
+  if (auto arrayTy =
+          getFortranElementOrSequenceType(type).dyn_cast<fir::SequenceType>())
+    return isFortranScalarNumericalType(arrayTy.getEleTy());
+  return false;
+}
+
+bool hlfir::isPassByRefOrIntegerType(mlir::Type type) {
+  mlir::Type unwrappedType = fir::unwrapPassByRefType(type);
+  return fir::isa_integer(unwrappedType);
+}
+
+bool hlfir::isI1Type(mlir::Type type) {
+  if (mlir::IntegerType integer = type.dyn_cast<mlir::IntegerType>())
+    if (integer.getWidth() == 1)
+      return true;
+  return false;
+}
+
+bool hlfir::isMaskArgument(mlir::Type type) {
+  if (isBoxAddressType(type))
+    return false;
+
+  mlir::Type unwrappedType = fir::unwrapPassByRefType(fir::unwrapRefType(type));
+  mlir::Type elementType = getFortranElementType(unwrappedType);
+  if (unwrappedType != elementType)
+    // input type is an array
+    return mlir::isa<fir::LogicalType>(elementType);
+
+  // input is a scalar, so allow i1 too
+  return mlir::isa<fir::LogicalType>(elementType) || isI1Type(elementType);
+}
index 9bf5601..feea448 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -421,6 +424,72 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
 }
 
 //===----------------------------------------------------------------------===//
+// SumOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult hlfir::SumOp::verify() {
+  mlir::Operation *op = getOperation();
+
+  auto results = op->getResultTypes();
+  assert(results.size() == 1);
+
+  mlir::Value array = getArray();
+  mlir::Value mask = getMask();
+
+  fir::SequenceType arrayTy =
+      hlfir::getFortranElementOrSequenceType(array.getType())
+          .cast<fir::SequenceType>();
+  mlir::Type numTy = arrayTy.getEleTy();
+  llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+  hlfir::ExprType resultTy = results[0].cast<hlfir::ExprType>();
+
+  if (mask) {
+    fir::SequenceType maskSeq =
+        hlfir::getFortranElementOrSequenceType(mask.getType())
+            .dyn_cast<fir::SequenceType>();
+    llvm::ArrayRef<int64_t> maskShape;
+
+    if (maskSeq)
+      maskShape = maskSeq.getShape();
+
+    if (!maskShape.empty()) {
+      if (maskShape.size() != arrayShape.size())
+        return emitWarning("MASK must be conformable to ARRAY");
+      static_assert(fir::SequenceType::getUnknownExtent() ==
+                    hlfir::ExprType::getUnknownExtent());
+      constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+      for (std::size_t i = 0; i < arrayShape.size(); ++i) {
+        int64_t arrayExtent = arrayShape[i];
+        int64_t maskExtent = maskShape[i];
+        if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
+            (maskExtent != unknownExtent))
+          return emitWarning("MASK must be conformable to ARRAY");
+      }
+    }
+  }
+
+  if (resultTy.isArray()) {
+    // Result is of the same type as ARRAY
+    if (resultTy.getEleTy() != numTy)
+      return emitOpError(
+          "result must have the same element type as ARRAY argument");
+
+    llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+
+    // Result has rank n-1
+    if (resultShape.size() != (arrayShape.size() - 1))
+      return emitOpError("result rank must be one less than ARRAY");
+  } else {
+    // Result is of the same type as ARRAY
+    if (resultTy.getElementType() != numTy)
+      return emitOpError(
+          "result must have the same element type as ARRAY argument");
+  }
+
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//
 
index a7a3150..dd80163 100644 (file)
@@ -295,3 +295,27 @@ func.func @bad_concat_4(%arg0: !fir.ref<!fir.char<1,30>>) {
   %0 = hlfir.concat %arg0 len %c30 : (!fir.ref<!fir.char<1,30>>, index) -> (!hlfir.expr<!fir.char<1,30>>)
   return
 }
+
+// -----
+func.func @bad_sum1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.sum' op result must have the same element type as ARRAY argument}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<f32>
+}
+
+// -----
+func.func @bad_sum2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
+  // expected-warning@+1 {{MASK must be conformable to ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+}
+
+// -----
+func.func @bad_sum3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
+  // expected-warning@+1 {{MASK must be conformable to ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+}
+
+// -----
+func.func @bad_sum4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.sum' op result rank must be one less than ARRAY}}
+  %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
+}
diff --git a/flang/test/HLFIR/sum.fir b/flang/test/HLFIR/sum.fir
new file mode 100644 (file)
index 0000000..45388c3
--- /dev/null
@@ -0,0 +1,239 @@
+// Test hlfir.sum operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// array is an expression of known shape
+func.func @sum0(%arg0: !hlfir.expr<42xi32>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!hlfir.expr<42xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum0(%[[ARRAY:.*]]: !hlfir.expr<42xi32>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %0 : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[BOX]] : (!hlfir.expr<42xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// array is an expression of assumed shape
+func.func @sum1(%arg0: !hlfir.expr<?xi32>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!hlfir.expr<?xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum1(%[[ARRAY:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL:.*]] to %[[MASK:.*]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK:.*]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY:.*]] dim %[[C1]] mask %[[BOX]] : (!hlfir.expr<?xi32>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// boxed array
+func.func @sum2(%arg0: !fir.box<!fir.array<42xi32>>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!fir.box<!fir.array<42xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum2(%[[ARRAY:.*]]: !fir.box<!fir.array<42xi32>>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL:.*]] to %[[MASK:.*]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK:.*]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY:.*]] dim %[[C1]] mask %[[BOX]] : (!fir.box<!fir.array<42xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// assumed shape boxed array
+func.func @sum3(%arg0: !fir.box<!fir.array<?xi32>>) {
+  %mask = fir.alloca !fir.logical<4>
+  %c_1 = arith.constant 1 : index
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %c_1 mask %mask_box : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum3(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL:.*]] to %[[MASK:.*]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %[[MASK:.*]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY:.*]] dim %[[C1]] mask %[[BOX]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// known shape expr mask
+func.func @sum4(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !hlfir.expr<42x!fir.logical<4>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<42x!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum4(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !hlfir.expr<42x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<42x!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// assumed shape expr mask
+func.func @sum5(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum5(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// known shape array mask
+func.func @sum6(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<42x!fir.logical<4>>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<42x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum6(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.box<!fir.array<42x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<42x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// assumed shape array mask
+func.func @sum7(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+  %c_1 = arith.constant 1 : index
+  %0 = hlfir.sum %arg0 dim %c_1 mask %arg1 : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum7(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
+// CHECK-NEXT:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[C1]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, index, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// known shape expr return
+func.func @sum8(%arg0: !fir.box<!fir.array<2x2xi32>>, %arg1: i32) {
+  %mask = fir.alloca !fir.logical<4>
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %arg1 mask %mask_box : (!fir.box<!fir.array<2x2xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<2xi32>
+  return
+}
+// CHECK:      func.func @sum8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2xi32>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %0 : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[DIM]] mask %[[BOX]] : (!fir.box<!fir.array<2x2xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<2xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// assumed shape expr return
+func.func @sum9(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: i32) {
+  %mask = fir.alloca !fir.logical<4>
+  %true = arith.constant true
+  %true_logical = fir.convert %true : (i1) -> !fir.logical<4>
+  fir.store %true_logical to %mask : !fir.ref<!fir.logical<4>>
+  %mask_box = fir.embox %mask : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+  %0 = hlfir.sum %arg0 dim %arg1 mask %mask_box : (!fir.box<!fir.array<?x?xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum9(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: i32) {
+// CHECK-NEXT:   %[[MASK:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[LOGICAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:   fir.store %[[LOGICAL]] to %[[MASK]] : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:   %[[BOX:.*]] = fir.embox %0 : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
+// CHECK-NEXT:   hlfir.sum %[[ARRAY]] dim %[[DIM]] mask %[[BOX]] : (!fir.box<!fir.array<?x?xi32>>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with only an array argument
+func.func @sum10(%arg0: !fir.box<!fir.array<?x?xi32>>) {
+  %sum = hlfir.sum %arg0 : (!fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] : (!fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with array and dim argument
+func.func @sum11(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: i32) {
+  %sum = hlfir.sum %arg0 dim %arg1 : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum11(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: i32
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with array and mask argument
+func.func @sum12(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.logical<4>) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.logical<4>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum12(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.logical<4>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, !fir.logical<4>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with dim argument with an unusual type
+func.func @sum13(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: index) {
+  %sum = hlfir.sum %arg0 dim %arg1 : (!fir.box<!fir.array<?x?xi32>>, index) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK:      func.func @sum13(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?xi32>>, %[[DIM:.*]]: index
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?xi32>>, index) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with mask argument of unusual type
+func.func @sum14(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: i1) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum14(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: i1
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// hlfir.sum with mask argument of ref<array<>> type
+func.func @sum15(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
+  %sum = hlfir.sum %arg0 mask %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+  return
+}
+// CHECK:      func.func @sum15(%[[ARRAY:.*]]: !fir.box<!fir.array<?xi32>>, %[[MASK:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>
+// CHECK-NEXT:   %[[SUM:.*]] = hlfir.sum %[[ARRAY]] mask %[[MASK]] : (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?x!fir.logical<4>>>) -> !hlfir.expr<i32>
+// CHECK-NEXT:   return
+// CHECK-NEXT: }