[flang] Add hlfir.associate and hlfir.end_associate definitions
authorJean Perier <jperier@nvidia.com>
Wed, 30 Nov 2022 12:41:44 +0000 (13:41 +0100)
committerJean Perier <jperier@nvidia.com>
Wed, 30 Nov 2022 12:43:56 +0000 (13:43 +0100)
These operations allow creating an HLFIR variable from a HLFIR value and
destroying it at the end of the variable lifetime.
This will both be used to implement procedure reference arguments association
when the actual is an expression, and to implement the Fortran associate
construct when the associated entity is an expression.

See https://github.com/llvm/llvm-project/blob/main/flang/docs/HighLevelFIR.md
for more details.

Differential Revision: https://reviews.llvm.org/D138996

flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/associate.fir [new file with mode: 0644]

index c9ba19b..d576df5 100644 (file)
@@ -229,4 +229,75 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
   let hasVerifier = 1;
 }
 
+def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
+  let summary = "Create a variable from an expression value";
+  let description = [{
+    Create a variable from an expression value.
+    For expressions, this operation is an incentive to re-use the expression
+    storage, if any, after the bufferization pass when possible (if the
+    expression is not used afterwards).
+  }];
+
+  let arguments = (ins
+    AnyFortranValue:$source,
+    Optional<AnyShapeOrShiftType>:$shape,
+    Variadic<AnyIntegerType>:$typeparams,
+    Builtin_StringAttr:$uniq_name,
+    OptionalAttr<fir_FortranVariableFlagsAttr>:$fortran_attrs
+  );
+
+  let results = (outs AnyFortranVariable, AnyRefOrBoxLike, I1);
+
+  let assemblyFormat = [{
+    $source (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
+     attr-dict `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$source, "llvm::StringRef":$uniq_name,
+      CArg<"mlir::Value", "{}">:$shape, CArg<"mlir::ValueRange", "{}">:$typeparams,
+      CArg<"fir::FortranVariableFlagsAttr", "{}">:$fortran_attrs)>];
+
+  let extraClassDeclaration = [{
+    /// Override FortranVariableInterface default implementation
+    mlir::Value getBase() {
+      return getResult(0);
+    }
+
+    /// Get the variable FIR base (same as input). It lacks
+    /// any explicit lower bounds and the extents might not be retrievable
+    /// from it. This matches what is used as a "base" in FIR. All non
+    /// polymorphic expressions FIR base is a simple raw address (they are
+    /// contiguous in memory).
+    mlir::Value getFirBase() {
+      return getResult(1);
+    }
+
+    /// Return the result value that indicates if the variable storage
+    /// was allocated on the heap. At the HLFIR level, this may not be
+    /// known yet, and lowering will need to conditionally free the storage.
+    mlir::Value getMustFreeStrorageFlag() {
+      return getResult(2);
+    }
+  }];
+}
+
+def hlfir_EndAssociateOp : hlfir_Op<"end_associate", []> {
+  let summary = "Mark the end of life of a variable associated to an expression";
+
+  let description = [{
+    Mark the end of life of a variable associated to an expression.
+  }];
+
+  let arguments = (ins AnyRefOrBoxLike:$var,
+                   I1:$must_free);
+
+  let assemblyFormat = [{
+    $var `,` $must_free attr-dict `:` type(operands)
+  }];
+
+  let builders = [OpBuilder<(ins "hlfir::AssociateOp":$associate)>];
+}
+
 #endif // FORTRAN_DIALECT_HLFIR_OPS
index dea3203..6cb3170 100644 (file)
@@ -384,5 +384,36 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
   build(builder, result, resultType, strings, len);
 }
 
+//===----------------------------------------------------------------------===//
+// AssociateOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::AssociateOp::build(mlir::OpBuilder &builder,
+                               mlir::OperationState &result, mlir::Value source,
+                               llvm::StringRef uniq_name, mlir::Value shape,
+                               mlir::ValueRange typeparams,
+                               fir::FortranVariableFlagsAttr fortran_attrs) {
+  auto nameAttr = builder.getStringAttr(uniq_name);
+  // TODO: preserve polymorphism of polymorphic expr.
+  mlir::Type firVarType = fir::ReferenceType::get(
+      getFortranElementOrSequenceType(source.getType()));
+  mlir::Type hlfirVariableType =
+      DeclareOp::getHLFIRVariableType(firVarType, /*hasExplicitLbs=*/false);
+  mlir::Type i1Type = builder.getI1Type();
+  build(builder, result, {hlfirVariableType, firVarType, i1Type}, source, shape,
+        typeparams, nameAttr, fortran_attrs);
+}
+
+//===----------------------------------------------------------------------===//
+// EndAssociateOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::EndAssociateOp::build(mlir::OpBuilder &builder,
+                                  mlir::OperationState &result,
+                                  hlfir::AssociateOp associate) {
+  return build(builder, result, associate.getFirBase(),
+               associate.getMustFreeStrorageFlag());
+}
+
 #define GET_OP_CLASSES
 #include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc"
diff --git a/flang/test/HLFIR/associate.fir b/flang/test/HLFIR/associate.fir
new file mode 100644 (file)
index 0000000..d0a1e36
--- /dev/null
@@ -0,0 +1,93 @@
+// Test hlfir.associate and hlfir.end_associate operations parse, verify
+// (no errors), and unparse.
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+func.func @test_cst_char(%arg0: !hlfir.expr<!fir.char<1,12>>) {
+  %c12 = arith.constant 12 : index
+  %0:3 = hlfir.associate %arg0 typeparams %c12 {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,12>>, index) -> (!fir.ref<!fir.char<1,12>>, !fir.ref<!fir.char<1,12>>, i1)
+  fir.call @foo(%0#0) : (!fir.ref<!fir.char<1,12>>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<!fir.char<1,12>>, i1
+  return
+}
+func.func private @foo(!fir.ref<!fir.char<1,12>>)
+// CHECK-LABEL:   func.func @test_cst_char(
+// CHECK-SAME:    %[[VAL_0:.*]]: !hlfir.expr<!fir.char<1,12>>) {
+// CHECK:  %[[VAL_1:.*]] = arith.constant 12 : index
+// CHECK:  %[[VAL_2:.*]]:3 = hlfir.associate %[[VAL_0]] typeparams %[[VAL_1]] {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,12>>, index) -> (!fir.ref<!fir.char<1,12>>, !fir.ref<!fir.char<1,12>>, i1)
+// CHECK:  fir.call @foo(%[[VAL_2]]#0) : (!fir.ref<!fir.char<1,12>>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_2]]#1, %[[VAL_2]]#2 : !fir.ref<!fir.char<1,12>>, i1
+
+
+func.func @test_dyn_char(%arg0: !hlfir.expr<!fir.char<1,?>>) {
+  %c12 = arith.constant 12 : index
+  %0:3 = hlfir.associate %arg0 typeparams %c12 {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>, i1)
+  fir.call @foo2(%0#0) : (!fir.boxchar<1>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<!fir.char<1,?>>, i1
+  return
+}
+func.func private @foo2(!fir.boxchar<1>)
+// CHECK-LABEL:   func.func @test_dyn_char(
+// CHECK-SAME:    %[[VAL_0:.*]]: !hlfir.expr<!fir.char<1,?>>) {
+// CHECK:  %[[VAL_1:.*]] = arith.constant 12 : index
+// CHECK:  %[[VAL_2:.*]]:3 = hlfir.associate %[[VAL_0]] typeparams %[[VAL_1]] {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>, i1)
+// CHECK:  fir.call @foo2(%[[VAL_2]]#0) : (!fir.boxchar<1>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_2]]#1, %[[VAL_2]]#2 : !fir.ref<!fir.char<1,?>>, i1
+
+
+func.func @test_integer(%arg0: i32) {
+  %0:3 = hlfir.associate %arg0 {uniq_name = "x"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+  fir.call @foo3(%0#0) : (!fir.ref<i32>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<i32>, i1
+  return
+}
+func.func private @foo3(!fir.ref<i32>)
+// CHECK-LABEL:   func.func @test_integer(
+// CHECK-SAME:    %[[VAL_0:.*]]: i32) {
+// CHECK:  %[[VAL_1:.*]]:3 = hlfir.associate %[[VAL_0]] {uniq_name = "x"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+// CHECK:  fir.call @foo3(%[[VAL_1]]#0) : (!fir.ref<i32>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_1]]#1, %[[VAL_1]]#2 : !fir.ref<i32>, i1
+
+
+func.func @test_logical(%arg0: !fir.logical<8>) {
+  %0:3 = hlfir.associate %arg0 {uniq_name = "x"} : (!fir.logical<8>) -> (!fir.ref<!fir.logical<8>>, !fir.ref<!fir.logical<8>>, i1)
+  fir.call @foo4(%0#0) : (!fir.ref<!fir.logical<8>>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<!fir.logical<8>>, i1
+  return
+}
+func.func private @foo4(!fir.ref<!fir.logical<8>>)
+// CHECK-LABEL:   func.func @test_logical(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.logical<8>) {
+// CHECK:  %[[VAL_1:.*]]:3 = hlfir.associate %[[VAL_0]] {uniq_name = "x"} : (!fir.logical<8>) -> (!fir.ref<!fir.logical<8>>, !fir.ref<!fir.logical<8>>, i1)
+// CHECK:  fir.call @foo4(%[[VAL_1]]#0) : (!fir.ref<!fir.logical<8>>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_1]]#1, %[[VAL_1]]#2 : !fir.ref<!fir.logical<8>>, i1
+
+
+func.func @test_complex(%arg0: !fir.complex<8>) {
+  %0:3 = hlfir.associate %arg0 {uniq_name = "x"} : (!fir.complex<8>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>, i1)
+  fir.call @foo5(%0#0) : (!fir.ref<!fir.complex<8>>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<!fir.complex<8>>, i1
+  return
+}
+func.func private @foo5(!fir.ref<!fir.complex<8>>)
+// CHECK-LABEL:   func.func @test_complex(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.complex<8>) {
+// CHECK:  %[[VAL_1:.*]]:3 = hlfir.associate %[[VAL_0]] {uniq_name = "x"} : (!fir.complex<8>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>, i1)
+// CHECK:  fir.call @foo5(%[[VAL_1]]#0) : (!fir.ref<!fir.complex<8>>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_1]]#1, %[[VAL_1]]#2 : !fir.ref<!fir.complex<8>>, i1
+
+
+func.func @test_array(%arg0: !hlfir.expr<!fir.array<10x?xi32>>) {
+  %shape = fir.undefined !fir.shape<2>
+  %0:3 = hlfir.associate %arg0(%shape) {uniq_name = "x"} : (!hlfir.expr<!fir.array<10x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<10x?xi32>>, !fir.ref<!fir.array<10x?xi32>>, i1)
+  fir.call @foo2(%0#0) : (!fir.box<!fir.array<10x?xi32>>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<!fir.array<10x?xi32>>, i1
+  return
+}
+func.func private @foo6(!fir.box<!fir.array<10x?xi32>>)
+// CHECK-LABEL:   func.func @test_array(
+// CHECK-SAME:    %[[VAL_0:.*]]: !hlfir.expr<!fir.array<10x?xi32>>) {
+// CHECK:  %[[VAL_1:.*]] = fir.undefined !fir.shape<2>
+// CHECK:  %[[VAL_2:.*]]:3 = hlfir.associate %[[VAL_0]](%[[VAL_1]]) {uniq_name = "x"} : (!hlfir.expr<!fir.array<10x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<10x?xi32>>, !fir.ref<!fir.array<10x?xi32>>, i1)
+// CHECK:  fir.call @foo2(%[[VAL_2]]#0) : (!fir.box<!fir.array<10x?xi32>>) -> ()
+// CHECK:  hlfir.end_associate %[[VAL_2]]#1, %[[VAL_2]]#2 : !fir.ref<!fir.array<10x?xi32>>, i1