[flang][hlfir] Add hlfir.copy_in and hlfir.copy_out codegen to FIR.
authorJean Perier <jperier@nvidia.com>
Wed, 25 Jan 2023 08:54:50 +0000 (09:54 +0100)
committerJean Perier <jperier@nvidia.com>
Wed, 25 Jan 2023 09:46:06 +0000 (10:46 +0100)
Use runtime Assign to deal with the copy (and the temporary creation, so
that this code can deal with polymorphic temps without any change).

Using Assign for the copy is desired here since the copy happens when
the data is not contiguous, and it happens inside an if/then which
makes it hard to optimize.
See https://github.com/llvm/llvm-project/commit/2b60ed405b8110b20ab2e383839759ea34003127
for more details (note that, contrary to this last commit, the code at
hand is only dealing with copy-in/copy-out, it is not intended to deal
with preparing VALUE arguments).

Differential Revision: https://reviews.llvm.org/D142475

flang/include/flang/Optimizer/Builder/FIRBuilder.h
flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
flang/test/HLFIR/copy-in-out-codegen.fir [new file with mode: 0644]

index 5d1f855..5e32a1b 100644 (file)
@@ -404,6 +404,11 @@ public:
     return IfBuilder(op, *this);
   }
 
+  mlir::Value genNot(mlir::Location loc, mlir::Value boolean) {
+    return create<mlir::arith::CmpIOp>(loc, mlir::arith::CmpIPredicate::eq,
+                                       boolean, createBool(loc, false));
+  }
+
   /// Generate code testing \p addr is not a null address.
   mlir::Value genIsNotNullAddr(mlir::Location loc, mlir::Value addr);
 
index 5487def..a71463d 100644 (file)
@@ -13,6 +13,7 @@
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
 #include "flang/Optimizer/Builder/Runtime/Assign.h"
+#include "flang/Optimizer/Builder/Runtime/Inquiry.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
@@ -29,6 +30,37 @@ namespace hlfir {
 
 using namespace mlir;
 
+static mlir::Value genAllocatableTempFromSourceBox(mlir::Location loc,
+                                                   fir::FirOpBuilder &builder,
+                                                   mlir::Value sourceBox) {
+  assert(sourceBox.getType().isa<fir::BaseBoxType>() &&
+         "must be a base box type");
+  // Use the runtime to make a quick and dirty temp with the rhs value.
+  // Overkill for scalar rhs that could be done in much more clever ways.
+  // Note that temp descriptor must have the allocatable flag set so that
+  // the runtime will allocate it with the shape and type parameters of
+  // the RHS.
+  // This has the huge benefit of dealing with all cases, including
+  // polymorphic entities.
+  mlir::Type fromHeapType = fir::HeapType::get(
+      fir::unwrapRefType(sourceBox.getType().cast<fir::BoxType>().getEleTy()));
+  mlir::Type fromBoxHeapType = fir::BoxType::get(fromHeapType);
+  auto fromMutableBox = builder.createTemporary(loc, fromBoxHeapType);
+  mlir::Value unallocatedBox =
+      fir::factory::createUnallocatedBox(builder, loc, fromBoxHeapType, {});
+  builder.create<fir::StoreOp>(loc, unallocatedBox, fromMutableBox);
+  fir::runtime::genAssign(builder, loc, fromMutableBox, sourceBox);
+  mlir::Value copy = builder.create<fir::LoadOp>(loc, fromMutableBox);
+  return copy;
+}
+
+static std::pair<mlir::Value, bool>
+genTempFromSourceBox(mlir::Location loc, fir::FirOpBuilder &builder,
+                     mlir::Value sourceBox) {
+  return {genAllocatableTempFromSourceBox(loc, builder, sourceBox),
+          /*cleanUpTemp=*/true};
+}
+
 namespace {
 /// May \p lhs alias with \p rhs?
 /// TODO: implement HLFIR alias analysis.
@@ -65,23 +97,9 @@ public:
       auto to = fir::getBase(builder.createBox(loc, lhsExv));
       auto from = fir::getBase(builder.createBox(loc, rhsExv));
       bool cleanUpTemp = false;
-      mlir::Type fromHeapType = fir::HeapType::get(
-          fir::unwrapRefType(from.getType().cast<fir::BoxType>().getEleTy()));
-      if (mayAlias(rhs, lhs)) {
-        /// Use the runtime to make a quick and dirty temp with the rhs value.
-        /// Overkill for scalar rhs that could be done in much more clever ways.
-        /// Note that temp descriptor must have the allocatable flag set so that
-        /// the runtime will allocate it with the shape and type parameters of
-        //  the RHS.
-        mlir::Type fromBoxHeapType = fir::BoxType::get(fromHeapType);
-        auto fromMutableBox = builder.createTemporary(loc, fromBoxHeapType);
-        mlir::Value unallocatedBox = fir::factory::createUnallocatedBox(
-            builder, loc, fromBoxHeapType, {});
-        builder.create<fir::StoreOp>(loc, unallocatedBox, fromMutableBox);
-        fir::runtime::genAssign(builder, loc, fromMutableBox, from);
-        cleanUpTemp = true;
-        from = builder.create<fir::LoadOp>(loc, fromMutableBox);
-      }
+      if (mayAlias(rhs, lhs))
+        std::tie(from, cleanUpTemp) = genTempFromSourceBox(loc, builder, from);
+
       auto toMutableBox = builder.createTemporary(loc, to.getType());
       // As per 10.2.1.2 point 1 (1) polymorphic variables must be allocatable.
       // It is assumed here that they have been reallocated with the dynamic
@@ -89,8 +107,7 @@ public:
       builder.create<fir::StoreOp>(loc, to, toMutableBox);
       fir::runtime::genAssign(builder, loc, toMutableBox, from);
       if (cleanUpTemp) {
-        mlir::Value addr =
-            builder.create<fir::BoxAddrOp>(loc, fromHeapType, from);
+        mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, from);
         builder.create<fir::FreeMemOp>(loc, addr);
       }
     } else {
@@ -109,6 +126,122 @@ public:
   }
 };
 
+class CopyInOpConversion : public mlir::OpRewritePattern<hlfir::CopyInOp> {
+public:
+  explicit CopyInOpConversion(mlir::MLIRContext *ctx) : OpRewritePattern{ctx} {}
+
+  struct CopyInResult {
+    mlir::Value addr;
+    mlir::Value wasCopied;
+  };
+
+  static CopyInResult genNonOptionalCopyIn(mlir::Location loc,
+                                           fir::FirOpBuilder &builder,
+                                           hlfir::CopyInOp copyInOp) {
+    mlir::Value inputVariable = copyInOp.getVar();
+    mlir::Type resultAddrType = copyInOp.getCopiedIn().getType();
+    mlir::Value isContiguous =
+        fir::runtime::genIsContiguous(builder, loc, inputVariable);
+    mlir::Value addr =
+        builder
+            .genIfOp(loc, {resultAddrType}, isContiguous,
+                     /*withElseRegion=*/true)
+            .genThen(
+                [&]() { builder.create<fir::ResultOp>(loc, inputVariable); })
+            .genElse([&] {
+              // Create temporary on the heap. Note that the runtime is used and
+              // that is desired: since the data copy happens under a runtime
+              // check (for IsContiguous) the copy loops can hardly provide any
+              // value to optimizations, instead, the optimizer just wastes
+              // compilation time on these loops.
+              mlir::Value temp =
+                  genAllocatableTempFromSourceBox(loc, builder, inputVariable);
+              // Get rid of allocatable flag in the fir.box.
+              temp = builder.create<fir::ReboxOp>(loc, resultAddrType, temp,
+                                                  /*shape=*/mlir::Value{},
+                                                  /*slice=*/mlir::Value{});
+              builder.create<fir::ResultOp>(loc, temp);
+            })
+            .getResults()[0];
+    return {addr, builder.genNot(loc, isContiguous)};
+  }
+
+  static CopyInResult genOptionalCopyIn(mlir::Location loc,
+                                        fir::FirOpBuilder &builder,
+                                        hlfir::CopyInOp copyInOp) {
+    mlir::Type resultAddrType = copyInOp.getCopiedIn().getType();
+    mlir::Value isPresent = copyInOp.getVarIsPresent();
+    auto res =
+        builder
+            .genIfOp(loc, {resultAddrType, builder.getI1Type()}, isPresent,
+                     /*withElseRegion=*/true)
+            .genThen([&]() {
+              CopyInResult res = genNonOptionalCopyIn(loc, builder, copyInOp);
+              builder.create<fir::ResultOp>(
+                  loc, mlir::ValueRange{res.addr, res.wasCopied});
+            })
+            .genElse([&] {
+              mlir::Value absent =
+                  builder.create<fir::AbsentOp>(loc, resultAddrType);
+              builder.create<fir::ResultOp>(
+                  loc, mlir::ValueRange{absent, isPresent});
+            })
+            .getResults();
+    return {res[0], res[1]};
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::CopyInOp copyInOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = copyInOp.getLoc();
+    auto module = copyInOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    CopyInResult result = copyInOp.getVarIsPresent()
+                              ? genOptionalCopyIn(loc, builder, copyInOp)
+                              : genNonOptionalCopyIn(loc, builder, copyInOp);
+    rewriter.replaceOp(copyInOp, {result.addr, result.wasCopied});
+    return mlir::success();
+  }
+};
+
+class CopyOutOpConversion : public mlir::OpRewritePattern<hlfir::CopyOutOp> {
+public:
+  explicit CopyOutOpConversion(mlir::MLIRContext *ctx)
+      : OpRewritePattern{ctx} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::CopyOutOp copyOutOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = copyOutOp.getLoc();
+    auto module = copyOutOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+
+    builder.genIfThen(loc, copyOutOp.getWasCopied())
+        .genThen([&]() {
+          mlir::Value temp = copyOutOp.getTemp();
+          if (mlir::Value var = copyOutOp.getVar()) {
+            auto mutableBox = builder.createTemporary(loc, var.getType());
+            builder.create<fir::StoreOp>(loc, var, mutableBox);
+            // Generate Assign() call to copy data from the temporary
+            // to the variable. Note that in case the actual argument
+            // is ALLOCATABLE/POINTER the Assign() implementation
+            // should not engage its reallocation, because the temporary
+            // is rank, shape and type compatible with it (it was created
+            // from the variable).
+            fir::runtime::genAssign(builder, loc, mutableBox, temp);
+          }
+          mlir::Type heapType =
+              fir::HeapType::get(fir::dyn_cast_ptrOrBoxEleTy(temp.getType()));
+          mlir::Value tempAddr =
+              builder.create<fir::BoxAddrOp>(loc, heapType, temp);
+          builder.create<fir::FreeMemOp>(loc, tempAddr);
+        })
+        .end();
+    rewriter.eraseOp(copyOutOp);
+    return mlir::success();
+  }
+};
+
 class DeclareOpConversion : public mlir::OpRewritePattern<hlfir::DeclareOp> {
 public:
   explicit DeclareOpConversion(mlir::MLIRContext *ctx)
@@ -390,9 +523,9 @@ public:
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns
-        .insert<AssignOpConversion, DeclareOpConversion, DesignateOpConversion,
-                NoReassocOpConversion, NullOpConversion>(context);
+    patterns.insert<AssignOpConversion, CopyInOpConversion, CopyOutOpConversion,
+                    DeclareOpConversion, DesignateOpConversion,
+                    NoReassocOpConversion, NullOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addIllegalDialect<hlfir::hlfirDialect>();
     target.markUnknownOpDynamicallyLegal(
@@ -405,6 +538,7 @@ public:
     }
   }
 };
+
 } // namespace
 
 std::unique_ptr<mlir::Pass> hlfir::createConvertHLFIRtoFIRPass() {
diff --git a/flang/test/HLFIR/copy-in-out-codegen.fir b/flang/test/HLFIR/copy-in-out-codegen.fir
new file mode 100644 (file)
index 0000000..b6c7a3c
--- /dev/null
@@ -0,0 +1,96 @@
+// Test hlfir.copy_in and hlfir.copy_out operation codegen
+
+// RUN: fir-opt %s -convert-hlfir-to-fir | FileCheck %s
+
+func.func @test_copy_in(%box: !fir.box<!fir.array<?xf64>>) {
+  %0:2 = hlfir.copy_in %box : (!fir.box<!fir.array<?xf64>>) -> (!fir.box<!fir.array<?xf64>>, i1)
+  return
+}
+// CHECK-LABEL:   func.func @test_copy_in(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.box<!fir.array<?xf64>>) {
+// CHECK:    %[[VAL_1:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf64>>>
+// CHECK:    %[[VAL_2:.*]] = fir.convert %[[VAL_0]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:    %[[VAL_3:.*]] = fir.call @_FortranAIsContiguous(%[[VAL_2]]) : (!fir.box<none>) -> i1
+// CHECK:    %[[VAL_4:.*]] = fir.if %[[VAL_3]] -> (!fir.box<!fir.array<?xf64>>) {
+// CHECK:      fir.result %[[VAL_0]] : !fir.box<!fir.array<?xf64>>
+// CHECK:    } else {
+// CHECK:      %[[VAL_5:.*]] = fir.zero_bits !fir.heap<!fir.array<?xf64>>
+// CHECK:      %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:      %[[VAL_7:.*]] = fir.shape %[[VAL_6]] : (index) -> !fir.shape<1>
+// CHECK:      %[[VAL_8:.*]] = fir.embox %[[VAL_5]](%[[VAL_7]]) : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf64>>>
+// CHECK:      fir.store %[[VAL_8]] to %[[VAL_1]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+// CHECK:      %[[VAL_12:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK:      %[[VAL_13:.*]] = fir.convert %[[VAL_0]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:      %[[VAL_15:.*]] = fir.call @_FortranAAssign(%[[VAL_12]], %[[VAL_13]],
+// CHECK:      %[[VAL_16:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+// CHECK:      %[[VAL_17:.*]] = fir.rebox %[[VAL_16]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.box<!fir.array<?xf64>>
+// CHECK:      fir.result %[[VAL_17]] : !fir.box<!fir.array<?xf64>>
+// CHECK:    }
+// CHECK:    %[[VAL_18:.*]] = arith.constant false
+// CHECK:    %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_3]], %[[VAL_18]] : i1
+// CHECK:    return
+// CHECK:  }
+
+func.func @test_copy_in_optional(%box: !fir.box<!fir.array<?xf64>>, %is_present: i1) {
+  %0:2 = hlfir.copy_in %box handle_optional %is_present : (!fir.box<!fir.array<?xf64>>, i1) -> (!fir.box<!fir.array<?xf64>>, i1)
+  return
+}
+// CHECK-LABEL:   func.func @test_copy_in_optional(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.box<!fir.array<?xf64>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: i1) {
+// CHECK:    %[[VAL_2:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf64>>>
+// CHECK:    %[[VAL_3:.*]]:2 = fir.if %[[VAL_1]] -> (!fir.box<!fir.array<?xf64>>, i1) {
+// CHECK:      %[[VAL_4:.*]] = fir.convert %[[VAL_0]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:      %[[VAL_5:.*]] = fir.call @_FortranAIsContiguous(%[[VAL_4]]) : (!fir.box<none>) -> i1
+// CHECK:      %[[VAL_6:.*]] = fir.if %[[VAL_5]] -> (!fir.box<!fir.array<?xf64>>) {
+// CHECK:        fir.result %[[VAL_0]] : !fir.box<!fir.array<?xf64>>
+// CHECK:      } else {
+// CHECK:        %[[VAL_7:.*]] = fir.zero_bits !fir.heap<!fir.array<?xf64>>
+// CHECK:        %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:        %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
+// CHECK:        %[[VAL_10:.*]] = fir.embox %[[VAL_7]](%[[VAL_9]]) : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf64>>>
+// CHECK:        fir.store %[[VAL_10]] to %[[VAL_2]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+// CHECK:        %[[VAL_14:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK:        %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:        %[[VAL_17:.*]] = fir.call @_FortranAAssign(%[[VAL_14]], %[[VAL_15]],
+// CHECK:        %[[VAL_18:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+// CHECK:        %[[VAL_19:.*]] = fir.rebox %[[VAL_18]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.box<!fir.array<?xf64>>
+// CHECK:        fir.result %[[VAL_19]] : !fir.box<!fir.array<?xf64>>
+// CHECK:      }
+// CHECK:      %[[VAL_20:.*]] = arith.constant false
+// CHECK:      %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_5]], %[[VAL_20]] : i1
+// CHECK:      fir.result %[[VAL_22:.*]], %[[VAL_21]] : !fir.box<!fir.array<?xf64>>, i1
+// CHECK:    } else {
+// CHECK:      %[[VAL_23:.*]] = fir.absent !fir.box<!fir.array<?xf64>>
+// CHECK:      fir.result %[[VAL_23]], %[[VAL_1]] : !fir.box<!fir.array<?xf64>>, i1
+// CHECK:    }
+
+func.func @test_copy_out_no_copy_back(%temp: !fir.box<!fir.array<?xf64>>, %was_copied: i1) {
+  hlfir.copy_out %temp, %was_copied : (!fir.box<!fir.array<?xf64>>, i1) -> ()
+  return
+}
+// CHECK-LABEL:   func.func @test_copy_out_no_copy_back(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.box<!fir.array<?xf64>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: i1) {
+// CHECK-NEXT:    fir.if %[[VAL_1]] {
+// CHECK-NEXT:      %[[VAL_2:.*]] = fir.box_addr %[[VAL_0]] : (!fir.box<!fir.array<?xf64>>) -> !fir.heap<!fir.array<?xf64>>
+// CHECK-NEXT:      fir.freemem %[[VAL_2]] : !fir.heap<!fir.array<?xf64>>
+// CHECK-NEXT:    }
+
+func.func @test_copy_out_copy_back(%box: !fir.box<!fir.array<?xf64>>, %temp: !fir.box<!fir.array<?xf64>>, %was_copied: i1) {
+  hlfir.copy_out %temp, %was_copied to %box : (!fir.box<!fir.array<?xf64>>, i1, !fir.box<!fir.array<?xf64>>) -> ()
+  return
+}
+// CHECK-LABEL:   func.func @test_copy_out_copy_back(
+// CHECK-SAME:    %[[VAL_0:[^:]*]]: !fir.box<!fir.array<?xf64>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !fir.box<!fir.array<?xf64>>,
+// CHECK-SAME:    %[[VAL_2:.*]]: i1) {
+// CHECK:    %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf64>>
+// CHECK:    fir.if %[[VAL_2]] {
+// CHECK:      fir.store %[[VAL_0]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf64>>>
+// CHECK:      %[[VAL_7:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
+// CHECK:      %[[VAL_8:.*]] = fir.convert %[[VAL_1]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK:      %[[VAL_10:.*]] = fir.call @_FortranAAssign(%[[VAL_7]], %[[VAL_8]],
+// CHECK:      %[[VAL_11:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<?xf64>>) -> !fir.heap<!fir.array<?xf64>>
+// CHECK:      fir.freemem %[[VAL_11]] : !fir.heap<!fir.array<?xf64>>
+// CHECK:    }