[fir] TargetRewrite: Rewrite fir.address_of(func)
authorDiana Picus <diana.picus@linaro.org>
Thu, 2 Dec 2021 04:27:18 +0000 (04:27 +0000)
committerDiana Picus <diana.picus@linaro.org>
Fri, 3 Dec 2021 10:56:24 +0000 (10:56 +0000)
Rewrite AddrOfOp if taking the address of a function.

Differential Revision: https://reviews.llvm.org/D114925

Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
flang/test/Fir/target-rewrite-boxchar.fir
flang/test/Fir/target-rewrite-complex.fir

index 25e1e44..7a762fb 100644 (file)
@@ -100,6 +100,10 @@ public:
       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType()))
           convertCallOp(dispatch);
+      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
+        if (addr.getType().isa<mlir::FunctionType>() &&
+            !hasPortableSignature(addr.getType()))
+          convertAddrOp(addr);
       }
     });
 
@@ -319,6 +323,55 @@ public:
         newInTys.push_back(std::get<mlir::Type>(tup));
   }
 
+  /// Taking the address of a function. Modify the signature as needed.
+  void convertAddrOp(AddrOfOp addrOp) {
+    rewriter->setInsertionPoint(addrOp);
+    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
+    llvm::SmallVector<mlir::Type> newResTys;
+    llvm::SmallVector<mlir::Type> newInTys;
+    for (mlir::Type ty : addrTy.getResults()) {
+      llvm::TypeSwitch<mlir::Type>(ty)
+          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+            lowerComplexSignatureRes(ty, newResTys, newInTys);
+          })
+          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+            lowerComplexSignatureRes(ty, newResTys, newInTys);
+          })
+          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+    }
+    llvm::SmallVector<mlir::Type> trailingInTys;
+    for (mlir::Type ty : addrTy.getInputs()) {
+      llvm::TypeSwitch<mlir::Type>(ty)
+          .Case<BoxCharType>([&](BoxCharType box) {
+            if (noCharacterConversion) {
+              newInTys.push_back(box);
+            } else {
+              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
+                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
+                auto argTy = std::get<mlir::Type>(tup);
+                llvm::SmallVector<mlir::Type> &vec =
+                    attr.isAppend() ? trailingInTys : newInTys;
+                vec.push_back(argTy);
+              }
+            }
+          })
+          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
+            lowerComplexSignatureArg(ty, newInTys);
+          })
+          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
+            lowerComplexSignatureArg(ty, newInTys);
+          })
+          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
+    }
+    // append trailing input types
+    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
+    // replace this op with a new one with the updated signature
+    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
+    auto newOp =
+        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
+    replaceOp(addrOp, newOp.getResult());
+  }
+
   /// Convert the type signatures on all the functions present in the module.
   /// As the type signature is being changed, this must also update the
   /// function itself to use any new arguments, etc.
index e2fb31f..400ae54 100644 (file)
@@ -93,3 +93,13 @@ fir.global @name constant : !fir.char<1,9> {
   //constant 1
   fir.has_value %str : !fir.char<1,9>
 }
+
+// Test that we rewrite the fir.address_of operator
+// INT32-LABEL: @addrof
+// INT64-LABEL: @addrof
+func @addrof() {
+  // INT32: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i32) -> ()
+  // INT64: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
+  %f = fir.address_of(@boxcharcallee) : (!fir.boxchar<1>) -> ()
+  return
+}
index 54fd2f2..49c9586 100644 (file)
@@ -452,3 +452,23 @@ func private @mlircomplexf32(%z1: complex<f32>, %z2: complex<f32>) -> complex<f3
   // PPC: return [[RES]] : tuple<f32, f32>
   return %0 : complex<f32>
 }
+
+// Test that we rewrite the fir.address_of operator.
+// I32-LABEL: func @addrof()
+// X64-LABEL: func @addrof()
+// AARCH64-LABEL: func @addrof()
+// PPC-LABEL: func @addrof()
+func @addrof() {
+  // I32: {{%.*}} = fir.address_of(@returncomplex4) : () -> i64
+  // X64: {{%.*}} = fir.address_of(@returncomplex4) : () -> !fir.vector<2:!fir.real<4>>
+  // AARCH64: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  // PPC: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  %r = fir.address_of(@returncomplex4) : () -> !fir.complex<4>
+
+  // I32: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> ()
+  // X64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.vector<2:!fir.real<4>>) -> ()
+  // AARCH64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.array<2x!fir.real<4>>) -> ()
+  // PPC: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.real<4>, !fir.real<4>) -> ()
+  %p = fir.address_of(@paramcomplex4) : (!fir.complex<4>) -> ()
+  return
+}