[flang][hlfir] add support for elemental intrinsics with custom handling
authorTom Eccles <tom.eccles@arm.com>
Tue, 11 Jul 2023 14:40:29 +0000 (14:40 +0000)
committerTom Eccles <tom.eccles@arm.com>
Tue, 18 Jul 2023 11:03:34 +0000 (11:03 +0000)
Only minimal argument processing is needed here because they will be
lowered properly either by the elemental intrinsic call builder or the
lowering of the scalar call inside the elemental kernel.

Dynamically optional arrays are coming in the next patch.

Depends On: D155291

Differential Revision: https://reviews.llvm.org/D155292

flang/lib/Lower/ConvertCall.cpp
flang/test/Lower/HLFIR/custom-intrinsic.f90

index 5f2fba8..ec2c8ab 100644 (file)
@@ -1740,14 +1740,65 @@ genIsPresentIfArgMaybeAbsent(mlir::Location loc, hlfir::Entity actual,
       .getResult();
 }
 
+// Lower a reference to an elemental intrinsic procedure with array arguments
+// and custom optional handling
+static std::optional<hlfir::EntityWithAttributes>
+genCustomElementalIntrinsicRef(
+    const Fortran::evaluate::SpecificIntrinsic *intrinsic,
+    CallContext &callContext) {
+  assert(callContext.isElementalProcWithArrayArgs() &&
+         "Use genCustomIntrinsicRef for scalar calls");
+  mlir::Location loc = callContext.loc;
+  auto &converter = callContext.converter;
+  Fortran::lower::PreparedActualArguments operands;
+  assert(intrinsic && Fortran::lower::intrinsicRequiresCustomOptionalHandling(
+                          callContext.procRef, *intrinsic, converter));
+
+  // callback for optional arguments
+  auto prepareOptionalArg = [&](const Fortran::lower::SomeExpr &expr) {
+    hlfir::EntityWithAttributes actual = Fortran::lower::convertExprToHLFIR(
+        loc, converter, expr, callContext.symMap, callContext.stmtCtx);
+    if (expr.Rank() == 0) {
+      std::optional<mlir::Value> isPresent =
+          genIsPresentIfArgMaybeAbsent(loc, actual, expr, callContext,
+                                       /*passAsAllocatableOrPointer=*/false);
+      operands.emplace_back(
+          Fortran::lower::PreparedActualArgument{actual, isPresent});
+    } else {
+      TODO(loc, "elemental intrinsic with custom optional handling optional "
+                "array argument");
+    }
+  };
+
+  // callback for non-optional arguments
+  auto prepareOtherArg = [&](const Fortran::lower::SomeExpr &expr,
+                             fir::LowerIntrinsicArgAs lowerAs) {
+    hlfir::EntityWithAttributes actual = Fortran::lower::convertExprToHLFIR(
+        loc, converter, expr, callContext.symMap, callContext.stmtCtx);
+    operands.emplace_back(Fortran::lower::PreparedActualArgument{
+        actual, /*isPresent=*/std::nullopt});
+  };
+
+  Fortran::lower::prepareCustomIntrinsicArgument(
+      callContext.procRef, *intrinsic, callContext.resultType,
+      prepareOptionalArg, prepareOtherArg, converter);
+
+  const fir::IntrinsicArgumentLoweringRules *argLowering =
+      fir::getIntrinsicArgumentLowering(callContext.getProcedureName());
+  // All of the custom intrinsic elementals with custom handling are pure
+  // functions
+  return ElementalIntrinsicCallBuilder{intrinsic, argLowering,
+                                       /*isFunction=*/true}
+      .genElementalCall(operands, /*isImpure=*/false, callContext);
+}
+
 // Lower a reference to an intrinsic procedure with custom optional handling
 static std::optional<hlfir::EntityWithAttributes>
 genCustomIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
                       CallContext &callContext) {
+  assert(!callContext.isElementalProcWithArrayArgs() &&
+         "Needs to be run through ElementalIntrinsicCallBuilder first");
   mlir::Location loc = callContext.loc;
-  if (callContext.isElementalProcWithArrayArgs())
-    TODO(loc, "Elemental proc with array args with custom optional argument "
-              "handling");
   fir::FirOpBuilder &builder = callContext.getBuilder();
   auto &converter = callContext.converter;
   auto &stmtCtx = callContext.stmtCtx;
@@ -1843,6 +1894,8 @@ genIntrinsicRef(const Fortran::evaluate::SpecificIntrinsic *intrinsic,
   auto &converter = callContext.converter;
   if (intrinsic && Fortran::lower::intrinsicRequiresCustomOptionalHandling(
                        callContext.procRef, *intrinsic, converter)) {
+    if (callContext.isElementalProcWithArrayArgs())
+      return genCustomElementalIntrinsicRef(intrinsic, callContext);
     return genCustomIntrinsicRef(intrinsic, callContext);
   }
 
index c409827..c6c856b 100644 (file)
@@ -95,6 +95,39 @@ end function
 ! CHECK:           return %[[VAL_27]] : i32
 ! CHECK:         }
 
+function max_array(a, b)
+   integer, dimension(42) :: a, b, max_array
+   max_array = max(a, b)
+end function
+! CHECK-LABEL:   func.func @_QPmax_array(
+! CHECK-SAME:                            %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "a"},
+! CHECK-SAME:                            %[[VAL_1:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "b"}) -> !fir.array<42xi32> {
+! CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "_QFmax_arrayEa"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_5:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_6]]) {uniq_name = "_QFmax_arrayEb"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_8:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_9:.*]] = fir.alloca !fir.array<42xi32> {bindc_name = "max_array", uniq_name = "_QFmax_arrayEmax_array"}
+! CHECK:           %[[VAL_10:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_10]]) {uniq_name = "_QFmax_arrayEmax_array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_12:.*]] = hlfir.elemental %[[VAL_3]] unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+! CHECK:           ^bb0(%[[VAL_13:.*]]: index):
+! CHECK:             %[[VAL_14:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+! CHECK:             %[[VAL_16:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+! CHECK:             %[[VAL_18:.*]] = arith.cmpi sgt, %[[VAL_15]], %[[VAL_17]] : i32
+! CHECK:             %[[VAL_19:.*]] = arith.select %[[VAL_18]], %[[VAL_15]], %[[VAL_17]] : i32
+! CHECK:             hlfir.yield_element %[[VAL_19]] : i32
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_20:.*]] to %[[VAL_11]]#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+! CHECK:           hlfir.destroy %[[VAL_20]] : !hlfir.expr<42xi32>
+! CHECK:           %[[VAL_21:.*]] = fir.load %[[VAL_11]]#1 : !fir.ref<!fir.array<42xi32>>
+! CHECK:           return %[[VAL_21]] : !fir.array<42xi32>
+! CHECK:         }
+
 function min_simple(a, b)
   integer :: a, b, min_simple
   min_simple = min(a, b)
@@ -190,6 +223,39 @@ end function
 ! CHECK:           return %[[VAL_27]] : i32
 ! CHECK:         }
 
+function min_array(a, b)
+   integer, dimension(42) :: a, b, min_array
+   min_array = min(a, b)
+end function
+! CHECK-LABEL:   func.func @_QPmin_array(
+! CHECK-SAME:                            %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "a"},
+! CHECK-SAME:                            %[[VAL_1:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "b"}) -> !fir.array<42xi32> {
+! CHECK:           %[[VAL_2:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "_QFmin_arrayEa"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_5:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_6]]) {uniq_name = "_QFmin_arrayEb"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_8:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_9:.*]] = fir.alloca !fir.array<42xi32> {bindc_name = "min_array", uniq_name = "_QFmin_arrayEmin_array"}
+! CHECK:           %[[VAL_10:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_9]](%[[VAL_10]]) {uniq_name = "_QFmin_arrayEmin_array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_12:.*]] = hlfir.elemental %[[VAL_3]] unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+! CHECK:           ^bb0(%[[VAL_13:.*]]: index):
+! CHECK:             %[[VAL_14:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+! CHECK:             %[[VAL_16:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+! CHECK:             %[[VAL_18:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_17]] : i32
+! CHECK:             %[[VAL_19:.*]] = arith.select %[[VAL_18]], %[[VAL_15]], %[[VAL_17]] : i32
+! CHECK:             hlfir.yield_element %[[VAL_19]] : i32
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_20:.*]] to %[[VAL_11]]#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+! CHECK:           hlfir.destroy %[[VAL_20]] : !hlfir.expr<42xi32>
+! CHECK:           %[[VAL_21:.*]] = fir.load %[[VAL_11]]#1 : !fir.ref<!fir.array<42xi32>>
+! CHECK:           return %[[VAL_21]] : !fir.array<42xi32>
+! CHECK:         }
+
 function associated_simple(pointer)
     integer, pointer :: pointer
     logical :: associated_simple
@@ -389,4 +455,70 @@ end function
 ! CHECK:           hlfir.assign %[[VAL_43]] to %[[VAL_5]]#0 : i32, !fir.ref<i32>
 ! CHECK:           %[[VAL_44:.*]] = fir.load %[[VAL_5]]#1 : !fir.ref<i32>
 ! CHECK:           return %[[VAL_44]] : i32
+! CHECK:         }
+
+function ishftc_array(i, shift, size)
+   integer, dimension(42) :: ishftc_array, i, shift, size
+   ishftc_array = ishftc(i, shift, size)
+end function
+! CHECK-LABEL:   func.func @_QPishftc_array(
+! CHECK-SAME:                               %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "i"},
+! CHECK-SAME:                               %[[VAL_1:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "shift"},
+! CHECK-SAME:                               %[[VAL_2:.*]]: !fir.ref<!fir.array<42xi32>> {fir.bindc_name = "size"}) -> !fir.array<42xi32> {
+! CHECK:           %[[VAL_3:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) {uniq_name = "_QFishftc_arrayEi"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_6:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_7:.*]] = fir.alloca !fir.array<42xi32> {bindc_name = "ishftc_array", uniq_name = "_QFishftc_arrayEishftc_array"}
+! CHECK:           %[[VAL_8:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_8]]) {uniq_name = "_QFishftc_arrayEishftc_array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_10:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_11]]) {uniq_name = "_QFishftc_arrayEshift"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_13:.*]] = arith.constant 42 : index
+! CHECK:           %[[VAL_14:.*]] = fir.shape %[[VAL_13]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_2]](%[[VAL_14]]) {uniq_name = "_QFishftc_arrayEsize"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+! CHECK:           %[[VAL_16:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+! CHECK:           ^bb0(%[[VAL_17:.*]]: index):
+! CHECK:             %[[VAL_18:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+! CHECK:             %[[VAL_20:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_21:.*]] = fir.load %[[VAL_20]] : !fir.ref<i32>
+! CHECK:             %[[VAL_22:.*]] = hlfir.designate %[[VAL_15]]#0 (%[[VAL_17]])  : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+! CHECK:             %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<i32>
+! CHECK:             %[[VAL_24:.*]] = arith.constant 32 : i32
+! CHECK:             %[[VAL_25:.*]] = arith.constant 0 : i32
+! CHECK:             %[[VAL_26:.*]] = arith.constant -1 : i32
+! CHECK:             %[[VAL_27:.*]] = arith.constant 31 : i32
+! CHECK:             %[[VAL_28:.*]] = arith.shrsi %[[VAL_21]], %[[VAL_27]] : i32
+! CHECK:             %[[VAL_29:.*]] = arith.xori %[[VAL_21]], %[[VAL_28]] : i32
+! CHECK:             %[[VAL_30:.*]] = arith.subi %[[VAL_29]], %[[VAL_28]] : i32
+! CHECK:             %[[VAL_31:.*]] = arith.subi %[[VAL_23]], %[[VAL_30]] : i32
+! CHECK:             %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_25]] : i32
+! CHECK:             %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_30]], %[[VAL_23]] : i32
+! CHECK:             %[[VAL_34:.*]] = arith.ori %[[VAL_32]], %[[VAL_33]] : i1
+! CHECK:             %[[VAL_35:.*]] = arith.cmpi sgt, %[[VAL_21]], %[[VAL_25]] : i32
+! CHECK:             %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_30]], %[[VAL_31]] : i32
+! CHECK:             %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_31]], %[[VAL_30]] : i32
+! CHECK:             %[[VAL_38:.*]] = arith.cmpi ne, %[[VAL_23]], %[[VAL_24]] : i32
+! CHECK:             %[[VAL_39:.*]] = arith.shrui %[[VAL_19]], %[[VAL_23]] : i32
+! CHECK:             %[[VAL_40:.*]] = arith.shli %[[VAL_39]], %[[VAL_23]] : i32
+! CHECK:             %[[VAL_41:.*]] = arith.select %[[VAL_38]], %[[VAL_40]], %[[VAL_25]] : i32
+! CHECK:             %[[VAL_42:.*]] = arith.subi %[[VAL_24]], %[[VAL_36]] : i32
+! CHECK:             %[[VAL_43:.*]] = arith.shrui %[[VAL_26]], %[[VAL_42]] : i32
+! CHECK:             %[[VAL_44:.*]] = arith.shrui %[[VAL_19]], %[[VAL_37]] : i32
+! CHECK:             %[[VAL_45:.*]] = arith.andi %[[VAL_44]], %[[VAL_43]] : i32
+! CHECK:             %[[VAL_46:.*]] = arith.subi %[[VAL_24]], %[[VAL_37]] : i32
+! CHECK:             %[[VAL_47:.*]] = arith.shrui %[[VAL_26]], %[[VAL_46]] : i32
+! CHECK:             %[[VAL_48:.*]] = arith.andi %[[VAL_19]], %[[VAL_47]] : i32
+! CHECK:             %[[VAL_49:.*]] = arith.shli %[[VAL_48]], %[[VAL_36]] : i32
+! CHECK:             %[[VAL_50:.*]] = arith.ori %[[VAL_41]], %[[VAL_45]] : i32
+! CHECK:             %[[VAL_51:.*]] = arith.ori %[[VAL_50]], %[[VAL_49]] : i32
+! CHECK:             %[[VAL_52:.*]] = arith.select %[[VAL_34]], %[[VAL_19]], %[[VAL_51]] : i32
+! CHECK:             hlfir.yield_element %[[VAL_52]] : i32
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_53:.*]] to %[[VAL_9]]#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+! CHECK:           hlfir.destroy %[[VAL_53]] : !hlfir.expr<42xi32>
+! CHECK:           %[[VAL_54:.*]] = fir.load %[[VAL_9]]#1 : !fir.ref<!fir.array<42xi32>>
+! CHECK:           return %[[VAL_54]] : !fir.array<42xi32>
 ! CHECK:         }
\ No newline at end of file