[FLANG][NFCI]De-duplicate code in SimplifyIntrinsics
authorMats Petersson <mats.petersson@arm.com>
Fri, 19 Aug 2022 16:45:35 +0000 (17:45 +0100)
committerMats Petersson <mats.petersson@arm.com>
Fri, 2 Sep 2022 09:49:25 +0000 (10:49 +0100)
This removes a bunch of duplicated code, by adding an intermediate
function simplifyReduction that takes a std::function argument
for the actual replacement of the code.

No functional change intended.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D132588

flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp

index 91480e8..c51d859 100644 (file)
@@ -52,9 +52,11 @@ namespace {
 class SimplifyIntrinsicsPass
     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
   using FunctionTypeGeneratorTy =
-      std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
   using FunctionBodyGeneratorTy =
-      std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+  using GenReductionBodyTy = llvm::function_ref<void(
+      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
 
 public:
   /// Generate a new function implementing a simplified version
@@ -68,6 +70,16 @@ public:
                                          FunctionBodyGeneratorTy bodyGenerator);
   void runOnOperation() override;
   void getDependentDialects(mlir::DialectRegistry &registry) const override;
+
+private:
+  /// Helper function to replace a reduction type of call with its
+  /// simplified form. The actual function is generated using a callback
+  /// function.
+  /// \p call is the call to be replaced
+  /// \p kindMap is used to create FIROpBuilder
+  /// \p genBodyFunc is the callback that builds the replacement function
+  void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
+                         GenReductionBodyTy genBodyFunc);
 };
 
 } // namespace
@@ -81,10 +93,10 @@ static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
                                  {elementType});
 }
 
-using BodyOpGeneratorTy =
-    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
-                              const mlir::Type &, mlir::Value, mlir::Value)>;
-using InitValGeneratorTy = std::function<mlir::Value(
+using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+    mlir::Value)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
 
 /// Generate the reduction loop into \p funcOp.
@@ -432,6 +444,43 @@ static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
   } while (true);
 }
 
+void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
+                                               const fir::KindMapping &kindMap,
+                                               GenReductionBodyTy genBodyFunc) {
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+  mlir::StringRef funcName = callee.getLeafReference().getValue();
+  mlir::Operation::operand_range args = call.getArgs();
+  // args[1] and args[2] are source filename and line number, ignored.
+  const mlir::Value &dim = args[3];
+  const mlir::Value &mask = args[4];
+  // dim is zero when it is absent, which is an implementation
+  // detail in the runtime library.
+  bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
+  unsigned rank = getDimCount(args[0]);
+  if (dimAndMaskAbsent && rank == 1) {
+    mlir::Location loc = call.getLoc();
+    mlir::Type type;
+    fir::FirOpBuilder builder(call, kindMap);
+    if (funcName.endswith("Integer4")) {
+      type = mlir::IntegerType::get(builder.getContext(), 32);
+    } else if (funcName.endswith("Real8")) {
+      type = mlir::FloatType::getF64(builder.getContext());
+    } else {
+      return;
+    }
+    auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
+      return genNoneBoxType(builder, type);
+    };
+    mlir::func::FuncOp newFunc =
+        getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+    auto newCall =
+        builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+    call->replaceAllUsesWith(newCall.getResults());
+    call->dropAllReferences();
+    call->erase();
+  }
+}
+
 void SimplifyIntrinsicsPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::ModuleOp module = getOperation();
@@ -450,37 +499,7 @@ void SimplifyIntrinsicsPass::runOnOperation() {
         //                int dim, const Descriptor *mask)
         //
         if (funcName.startswith("_FortranASum")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          // args[1] and args[2] are source filename and line number, ignored.
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranASumBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-          }
-
+          simplifyReduction(call, kindMap, genFortranASumBody);
           return;
         }
         if (funcName.startswith("_FortranADotProduct")) {
@@ -544,36 +563,8 @@ void SimplifyIntrinsicsPass::runOnOperation() {
           return;
         }
         if (funcName.startswith("_FortranAMaxval")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranAMaxvalBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-            return;
-          }
+          simplifyReduction(call, kindMap, genFortranAMaxvalBody);
+          return;
         }
       }
     }