[flang][hlfir] Add pass to inline elementals
authorTom Eccles <tom.eccles@arm.com>
Tue, 25 Apr 2023 09:07:11 +0000 (09:07 +0000)
committerTom Eccles <tom.eccles@arm.com>
Thu, 18 May 2023 10:48:45 +0000 (10:48 +0000)
Implement hlfir.elemental inlining as proposed in
flang/docs/HighLevelFIR.md.

This is a separate pass to make the code easier to understand. One
alternative would have been to modify the hlfir.elemental lowering in
the HLFIR bufferization pass.

Currently, a hlfir.elemental can only be inlined once; if there are
more uses, the existing bufferization is used instead.

Usage of mlir::applyPatternsAndFoldGreedily was suggested by @jeanPerier

Differential Revision: https://reviews.llvm.org/D149258

flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/include/flang/Optimizer/HLFIR/Passes.h
flang/include/flang/Optimizer/HLFIR/Passes.td
flang/include/flang/Tools/CLOptions.inc
flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt
flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp [new file with mode: 0644]
flang/test/Driver/mlir-debug-pass-pipeline.f90
flang/test/Driver/mlir-pass-pipeline.f90
flang/test/Fir/basic-program.fir
flang/test/HLFIR/inline-elemental.fir [new file with mode: 0644]

index 0aed277..1fe3eba 100644 (file)
@@ -606,6 +606,11 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects]> {
     The shape and typeparams operands represent the extents and type
     parameters of the resulting array value.
 
+    Currently there is no way to control the iteration order of a hlfir
+    elemental operation and so operations in the body of the elemental must
+    not have side effects. If this is changed, an attribute must be added so
+    that the elemental inlining pass can skip these impure elementals.
+
 
     Example: Y + X,  with Integer :: X(10, 20), Y(10,20)
     ```
@@ -670,9 +675,17 @@ def hlfir_ApplyOp : hlfir_Op<"apply", [NoMemoryEffect, AttrSizedOperandSegments]
   let description = [{
     Given an hlfir.expr array value, hlfir.apply allow retrieving
     the value for an element given one based indices.
+
     When hlfir.apply is used on an hlfir.elemental, and if the hlfir.elemental
     operation evaluation can be moved to the location of the hlfir.apply, it is
     as if the hlfir.elemental body was evaluated given the hlfir.apply indices.
+    Therefore, apply operations on hlfir.elemental expressions should be located
+    such that evaluating the hlfir.elemental at the position of the hlfir.apply
+    operation produces the same result as evaluating the hlfir.elemental at its
+    location in the instruction stream. Attention should be paid to
+    hlfir.elemental memory side effects (in practice these are unlikely).
+    "10.1.4 Evaluation of operations" says that expression evaluation shall not
+    impact/be impacted by other expression evaluation in the statement.
   }];
 
   let arguments = (ins hlfir_ExprType:$expr,
index a5aa35b..eb3cc14 100644 (file)
@@ -26,6 +26,7 @@ std::unique_ptr<mlir::Pass> createConvertHLFIRtoFIRPass();
 std::unique_ptr<mlir::Pass> createBufferizeHLFIRPass();
 std::unique_ptr<mlir::Pass> createLowerHLFIRIntrinsicsPass();
 std::unique_ptr<mlir::Pass> createSimplifyHLFIRIntrinsicsPass();
+std::unique_ptr<mlir::Pass> createInlineElementalsPass();
 std::unique_ptr<mlir::Pass> createLowerHLFIROrderedAssignmentsPass();
 
 #define GEN_PASS_REGISTRATION
index 4932409..7e832a9 100644 (file)
@@ -43,4 +43,9 @@ def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics", "::mlir::func::F
   let constructor = "hlfir::createSimplifyHLFIRIntrinsicsPass()";
 }
 
+def InlineElementals : Pass<"inline-elementals", "::mlir::func::FuncOp"> {
+  let summary = "Inline chained hlfir.elemental operations";
+  let constructor = "hlfir::createInlineElementalsPass()";
+}
+
 #endif //FORTRAN_DIALECT_HLFIR_PASSES
index a799427..16eb998 100644 (file)
@@ -242,6 +242,7 @@ inline void createHLFIRToFIRPassPipeline(
     addCanonicalizerPassWithoutRegionSimplification(pm);
     pm.addPass(hlfir::createSimplifyHLFIRIntrinsicsPass());
   }
+  pm.addPass(hlfir::createInlineElementalsPass());
   pm.addPass(hlfir::createLowerHLFIROrderedAssignmentsPass());
   pm.addPass(hlfir::createLowerHLFIRIntrinsicsPass());
   pm.addPass(hlfir::createBufferizeHLFIRPass());
index f7e51dc..bde1d47 100644 (file)
@@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 add_flang_library(HLFIRTransforms
   BufferizeHLFIR.cpp
   ConvertToFIR.cpp
+  InlineElementals.cpp
   LowerHLFIRIntrinsics.cpp
   LowerHLFIROrderedAssignments.cpp
   ScheduleOrderedAssignments.cpp
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp
new file mode 100644 (file)
index 0000000..f0acd22
--- /dev/null
@@ -0,0 +1,119 @@
+//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Chained elemental operations like a + b + c can inline the first elemental
+// at the hlfir.apply in the body of the second one (as described in
+// docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering
+// so that it happens after the HLFIR intrinsic simplification pass.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+
+namespace hlfir {
+#define GEN_PASS_DEF_INLINEELEMENTALS
+#include "flang/Optimizer/HLFIR/Passes.h.inc"
+} // namespace hlfir
+
+/// If the elemental has only two uses and those two are an apply operation and
+/// a destory operation, return those two, otherwise return {}
+static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
+getTwoUses(hlfir::ElementalOp elemental) {
+  mlir::Operation::user_range users = elemental->getUsers();
+  // don't inline anything with more than one use (plus hfir.destroy)
+  if (std::distance(users.begin(), users.end()) != 2) {
+    return std::nullopt;
+  }
+
+  hlfir::ApplyOp apply;
+  hlfir::DestroyOp destroy;
+  for (mlir::Operation *user : users)
+    mlir::TypeSwitch<mlir::Operation *, void>(user)
+        .Case([&](hlfir::ApplyOp op) { apply = op; })
+        .Case([&](hlfir::DestroyOp op) { destroy = op; });
+
+  if (!apply || !destroy)
+    return std::nullopt;
+  return std::pair{apply, destroy};
+}
+
+namespace {
+class InlineElementalConversion
+    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
+public:
+  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::ElementalOp elemental,
+                  mlir::PatternRewriter &rewriter) const override {
+    std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple =
+        getTwoUses(elemental);
+    if (!maybeTuple) {
+      return rewriter.notifyMatchFailure(elemental.getLoc(),
+                                         [](mlir::Diagnostic &) {});
+    }
+    auto [apply, destroy] = *maybeTuple;
+
+    assert(elemental.getRegion().hasOneBlock() &&
+           "expect elemental region to have one block");
+
+    fir::FirOpBuilder builder{rewriter,
+                              fir::KindMapping{rewriter.getContext()}};
+    builder.setInsertionPointAfter(apply);
+    hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
+        elemental.getLoc(), builder, elemental, apply.getIndices());
+
+    // remove the old elemental and all of the bookkeeping
+    rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue());
+    rewriter.eraseOp(yield);
+    rewriter.eraseOp(apply);
+    rewriter.eraseOp(destroy);
+    rewriter.eraseOp(elemental);
+
+    return mlir::success();
+  }
+};
+
+class InlineElementalsPass
+    : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> {
+public:
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+    mlir::MLIRContext *context = &getContext();
+
+    mlir::GreedyRewriteConfig config;
+    // Prevent the pattern driver from merging blocks.
+    config.enableRegionSimplification = false;
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<InlineElementalConversion>(context);
+
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            func, std::move(patterns), config))) {
+      mlir::emitError(func->getLoc(), "failure in HLFIR elemental inlining");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> hlfir::createInlineElementalsPass() {
+  return std::make_unique<InlineElementalsPass>();
+}
index 320f06a..a3ff416 100644 (file)
@@ -25,6 +25,8 @@ end program
 ! ALL: Pass statistics report
 
 ! ALL: Fortran::lower::VerifierPass
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT:   InlineElementals
 ! ALL-NEXT: LowerHLFIROrderedAssignments
 ! ALL-NEXT: LowerHLFIRIntrinsics
 ! ALL-NEXT: BufferizeHLFIR
index 44d253a..7f92ec2 100644 (file)
@@ -15,7 +15,8 @@ end program
 ! O2-NEXT: Canonicalizer
 ! O2-NEXT: 'func.func' Pipeline
 ! O2-NEXT:   SimplifyHLFIRIntrinsics
-! ALL-NEXT: LowerHLFIROrderedAssignments
+! ALL:       InlineElementals
+! ALL: LowerHLFIROrderedAssignments
 ! ALL-NEXT: LowerHLFIRIntrinsics
 ! ALL-NEXT: BufferizeHLFIR
 ! ALL-NEXT: ConvertHLFIRtoFIR
index e6d849c..4f0efb1 100644 (file)
@@ -19,6 +19,7 @@ func.func @_QQmain() {
 // PASSES:        Canonicalizer
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   SimplifyHLFIRIntrinsics
+// PASSES-NEXT:   InlineElementals
 // PASSES-NEXT:   LowerHLFIROrderedAssignments
 // PASSES-NEXT:   LowerHLFIRIntrinsics
 // PASSES-NEXT:   BufferizeHLFIR
diff --git a/flang/test/HLFIR/inline-elemental.fir b/flang/test/HLFIR/inline-elemental.fir
new file mode 100644 (file)
index 0000000..a9bae19
--- /dev/null
@@ -0,0 +1,245 @@
+// RUN: fir-opt --inline-elementals %s | FileCheck %s
+
+// check inlining one elemental into another
+// a = b * c + d
+func.func @inline_to_elemental(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}, %arg3: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "d"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %3:2 = hlfir.declare %arg3 {uniq_name = "d"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %c0 = arith.constant 0 : index
+  %4:3 = fir.box_dims %1#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+  %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
+  ^bb0(%arg4: index):
+    %8 = hlfir.designate %1#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %9 = hlfir.designate %2#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %10 = fir.load %8 : !fir.ref<i32>
+    %11 = fir.load %9 : !fir.ref<i32>
+    %12 = arith.muli %10, %11 : i32
+    hlfir.yield_element %12 : i32
+  }
+  %7 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
+  ^bb0(%arg4: index):
+    %8 = hlfir.apply %6, %arg4 : (!hlfir.expr<?xi32>, index) -> i32
+    %9 = hlfir.designate %3#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %10 = fir.load %9 : !fir.ref<i32>
+    %11 = arith.addi %8, %10 : i32
+    hlfir.yield_element %11 : i32
+  }
+  hlfir.assign %7 to %0#0 : !hlfir.expr<?xi32>, !fir.box<!fir.array<?xi32>>
+  hlfir.destroy %7 : !hlfir.expr<?xi32>
+  hlfir.destroy %6 : !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @inline_to_elemental
+// CHECK-SAME:      %[[A_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}
+// CHECK-SAME:      %[[B_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}
+// CHECK-SAME:      %[[C_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}
+// CHECK-SAME:      %[[D_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "d"}
+// CHECK-NEXT:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[A:.*]]:2 = hlfir.declare %[[A_ARG]]
+// CHECK-DAG:     %[[B:.*]]:2 = hlfir.declare %[[B_ARG]]
+// CHECK-DAG:     %[[C:.*]]:2 = hlfir.declare %[[C_ARG]]
+// CHECK-DAG:     %[[D:.*]]:2 = hlfir.declare %[[D_ARG]]
+// CHECK-NEXT:    %[[B_DIM0:.*]]:3 = fir.box_dims %[[B]]#0, %[[C0]]
+// CHECK-NEXT:    %[[B_SHAPE:.*]] = fir.shape %[[B_DIM0]]#1
+// CHECK-NEXT:    %[[EXPR:.*]] = hlfir.elemental %[[B_SHAPE]]
+// CHECK-NEXT:    ^bb0(%[[I:.*]]: index):
+// inline the first elemental:
+// CHECK-NEXT:      %[[B_I_REF:.*]] = hlfir.designate %[[B]]#0 (%[[I]])
+// CHECK-NEXT:      %[[C_I_REF:.*]] = hlfir.designate %[[C]]#0 (%[[I]])
+// CHECK-NEXT:      %[[B_I:.*]] = fir.load %[[B_I_REF]]
+// CHECK-NEXT:      %[[C_I:.*]] = fir.load %[[C_I_REF]]
+// CHECK-NEXT:      %[[MUL:.*]] = arith.muli %[[B_I]], %[[C_I]]
+// second elemental:
+// CHECK-NEXT:      %[[D_I_REF:.*]] = hlfir.designate %[[D]]#0 (%[[I]])
+// CHECK-NEXT:      %[[D_I:.*]] = fir.load %[[D_I_REF]]
+// CHECK-NEXT:      %[[ADD:.*]] = arith.addi %[[MUL]], %[[D_I]]
+// CHECK-NEXT:      hlfir.yield_element %[[ADD]]
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[A]]#0
+// CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// check inlining into a do_loop
+func.func @inline_to_loop(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}, %arg3: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "d"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %3:2 = hlfir.declare %arg3 {uniq_name = "d"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %c0 = arith.constant 0 : index
+  %4:3 = fir.box_dims %1#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+  %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
+  ^bb0(%arg4: index):
+    %8 = hlfir.designate %1#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %9 = hlfir.designate %2#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %10 = fir.load %8 : !fir.ref<i32>
+    %11 = fir.load %9 : !fir.ref<i32>
+    %12 = arith.muli %10, %11 : i32
+    hlfir.yield_element %12 : i32
+  }
+  %array = fir.array_load %0#0 : (!fir.box<!fir.array<?xi32>>) -> !fir.array<?xi32>
+  %c1 = arith.constant 1 : index
+  %max = arith.subi %4#1, %c1 : index
+  %7 = fir.do_loop %arg4 = %c0 to %max step %c1 unordered iter_args(%arg5 = %array) -> (!fir.array<?xi32>) {
+    %8 = hlfir.apply %6, %arg4 : (!hlfir.expr<?xi32>, index) -> i32
+    %9 = hlfir.designate %3#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %10 = fir.load %9 : !fir.ref<i32>
+    %11 = arith.addi %8, %10 : i32
+    %12 = fir.array_update %arg5, %11, %arg4 : (!fir.array<?xi32>, i32, index) -> !fir.array<?xi32>
+    fir.result %12 : !fir.array<?xi32>
+  }
+  fir.array_merge_store %array, %7 to %arg0 : !fir.array<?xi32>, !fir.array<?xi32>, !fir.box<!fir.array<?xi32>>
+  hlfir.destroy %6 : !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @inline_to_loop
+// CHECK-SAME:      %[[A_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}
+// CHECK-SAME:      %[[B_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}
+// CHECK-SAME:      %[[C_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}
+// CHECK-SAME:      %[[D_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "d"}
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[A:.*]]:2 = hlfir.declare %[[A_ARG]]
+// CHECK-DAG:     %[[B:.*]]:2 = hlfir.declare %[[B_ARG]]
+// CHECK-DAG:     %[[C:.*]]:2 = hlfir.declare %[[C_ARG]]
+// CHECK-DAG:     %[[D:.*]]:2 = hlfir.declare %[[D_ARG]]
+// CHECK-NEXT:    %[[B_DIM0:.*]]:3 = fir.box_dims %[[B]]#0, %[[C0]]
+// CHECK-NEXT:    %[[ARRAY:.*]] = fir.array_load %[[A]]#0
+// CHECK-NEXT:    %[[MAX:.*]] = arith.subi %[[B_DIM0]]#1, %[[C1]]
+// CHECK-NEXT:    %[[LOOP:.*]] = fir.do_loop %[[I:.*]] = %[[C0]] to %[[MAX]] step %[[C1]] unordered iter_args(%[[LOOP_ARRAY:.*]] = %[[ARRAY]])
+// inline the elemental:
+// CHECK-NEXT:      %[[B_I_REF:.*]] = hlfir.designate %[[B]]#0 (%[[I]])
+// CHECK-NEXT:      %[[C_I_REF:.*]] = hlfir.designate %[[C]]#0 (%[[I]])
+// CHECK-NEXT:      %[[B_I:.*]] = fir.load %[[B_I_REF]]
+// CHECK-NEXT:      %[[C_I:.*]] = fir.load %[[C_I_REF]]
+// CHECK-NEXT:      %[[MUL:.*]] = arith.muli %[[B_I]], %[[C_I]]
+// loop body:
+// CHECK-NEXT:      %[[D_I_REF:.*]] = hlfir.designate %[[D]]#0 (%[[I]])
+// CHECK-NEXT:      %[[D_I:.*]] = fir.load %[[D_I_REF]]
+// CHECK-NEXT:      %[[ADD:.*]] = arith.addi %[[MUL]], %[[D_I]]
+// CHECK-NEXT:      %[[ARRAY_UPD:.*]] = fir.array_update %[[LOOP_ARRAY]], %[[ADD]], %[[I]]
+// CHECK-NEXT:      fir.result %[[ARRAY_UPD]]
+// CHECK-NEXT:    }
+// CHECK-NEXT:    fir.array_merge_store %[[ARRAY]], %[[LOOP]] to %[[A_ARG]]
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// inlining into a single hlfir.apply
+// a = (b * c)[1]
+func.func @inline_to_apply(%arg0: !fir.ref<i32> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "a"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %1:2 = hlfir.declare %arg1 {uniq_name = "b"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %2:2 = hlfir.declare %arg2 {uniq_name = "c"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %c0 = arith.constant 0 : index
+  %4:3 = fir.box_dims %1#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+  %5 = fir.shape %4#1 : (index) -> !fir.shape<1>
+  %6 = hlfir.elemental %5 : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
+  ^bb0(%arg4: index):
+    %8 = hlfir.designate %1#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %9 = hlfir.designate %2#0 (%arg4)  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+    %10 = fir.load %8 : !fir.ref<i32>
+    %11 = fir.load %9 : !fir.ref<i32>
+    %12 = arith.muli %10, %11 : i32
+    hlfir.yield_element %12 : i32
+  }
+  %c1 = arith.constant 1 : index
+  %res = hlfir.apply %6, %c1 : (!hlfir.expr<?xi32>, index) -> i32
+  fir.store %res to %0#0 : !fir.ref<i32>
+  hlfir.destroy %6 : !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @inline_to_apply
+// CHECK-SAME:      %[[A_ARG:.*]]: !fir.ref<i32> {fir.bindc_name = "a"}
+// CHECK-SAME:      %[[B_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "b"}
+// CHECK-SAME:      %[[C_ARG:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "c"}
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[A:.*]]:2 = hlfir.declare %[[A_ARG]]
+// CHECK-DAG:     %[[B:.*]]:2 = hlfir.declare %[[B_ARG]]
+// CHECK-DAG:     %[[C:.*]]:2 = hlfir.declare %[[C_ARG]]
+// inline the elemental:
+// CHECK-NEXT:    %[[B_1_REF:.*]] = hlfir.designate %[[B]]#0 (%[[C1]])
+// CHECK-NEXT:    %[[C_1_REF:.*]] = hlfir.designate %[[C]]#0 (%[[C1]])
+// CHECK-NEXT:    %[[B_1:.*]] = fir.load %[[B_1_REF]]
+// CHECK-NEXT:    %[[C_1:.*]] = fir.load %[[C_1_REF]]
+// CHECK-NEXT:    %[[MUL:.*]] = arith.muli %[[B_1]], %[[C_1]]
+// store:
+// CHECK-NEXT:    fir.store %[[MUL]] to %0#0 : !fir.ref<i32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Check long chains of elementals
+// subroutine reproducer(a)
+//   real, dimension(:) :: a
+//   a = sqrt(a * (a - 1))
+// end subroutine
+func.func @_QPreproducer(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFreproducerEa"} : (!fir.box<!fir.array<?xf32>>) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
+  %cst = arith.constant 1.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %1:3 = fir.box_dims %0#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+  %2 = fir.shape %1#1 : (index) -> !fir.shape<1>
+  %3 = hlfir.elemental %2 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %9 = hlfir.designate %0#0 (%arg1)  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+    %10 = fir.load %9 : !fir.ref<f32>
+    %11 = arith.subf %10, %cst fastmath<contract> : f32
+    hlfir.yield_element %11 : f32
+  }
+  %4 = hlfir.elemental %2 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %9 = hlfir.apply %3, %arg1 : (!hlfir.expr<?xf32>, index) -> f32
+    %10 = hlfir.no_reassoc %9 : f32
+    hlfir.yield_element %10 : f32
+  }
+  %c0_0 = arith.constant 0 : index
+  %5:3 = fir.box_dims %0#0, %c0_0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+  %6 = fir.shape %5#1 : (index) -> !fir.shape<1>
+  %7 = hlfir.elemental %6 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %9 = hlfir.designate %0#0 (%arg1)  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+    %10 = hlfir.apply %4, %arg1 : (!hlfir.expr<?xf32>, index) -> f32
+    %11 = fir.load %9 : !fir.ref<f32>
+    %12 = arith.mulf %11, %10 fastmath<contract> : f32
+    hlfir.yield_element %12 : f32
+  }
+  %8 = hlfir.elemental %6 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %9 = hlfir.apply %7, %arg1 : (!hlfir.expr<?xf32>, index) -> f32
+    %10 = math.sqrt %9 fastmath<contract> : f32
+    hlfir.yield_element %10 : f32
+  }
+  hlfir.assign %8 to %0#0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+  hlfir.destroy %8 : !hlfir.expr<?xf32>
+  hlfir.destroy %7 : !hlfir.expr<?xf32>
+  hlfir.destroy %4 : !hlfir.expr<?xf32>
+  hlfir.destroy %3 : !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL: func.func @_QPreproducer
+// CHECK-SAME:      %[[A_ARG:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"}
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.0000
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0
+// CHECK-DAG:     %[[A_VAR:.*]]:2 = hlfir.declare %[[A_ARG]]
+// CHECK-NEXT:    %[[A_DIMS_0:.*]]:3 = fir.box_dims %[[A_VAR]]#0, %[[C0]]
+// CHECK-NEXT:    %[[SHAPE_0:.*]] = fir.shape %[[A_DIMS_0]]#1
+// all in one elemental:
+// CHECK-NEXT:    %[[EXPR:.*]] = hlfir.elemental %[[SHAPE_0]]
+// CHECK-NEXT:    ^bb0(%[[I:.*]]: index):
+// CHECK-NEXT:      %[[A_I0:.*]] = hlfir.designate %[[A_VAR]]#0 (%[[I]])
+// CHECK-NEXT:      %[[A_I1:.*]] = hlfir.designate %[[A_VAR]]#0 (%[[I]])
+// CHECK-NEXT:      %[[A_I1_VAL:.*]] = fir.load %[[A_I1]]
+// CHECK-NEXT:      %[[SUB:.*]] = arith.subf %[[A_I1_VAL]], %[[CST]]
+// CHECK-NEXT:      %[[SUB0:.*]] = hlfir.no_reassoc %[[SUB]] : f32
+// CHECK-NEXT:      %[[A_I0_VAL:.*]] = fir.load %[[A_I0]]
+// CHECK-NEXT:      %[[MUL:.*]] = arith.mulf %[[A_I0_VAL]], %[[SUB0]]
+// CHECK-NEXT:      %[[SQRT:.*]] = math.sqrt %[[MUL]]
+// CHECK-NEXT:      hlfir.yield_element %[[SQRT]]
+// CHECK-NEXT:    }
+// CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[A_VAR]]#0
+// CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }