[FLANG]Remove experimental flag from SUM simplification
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/TypeUtilities.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
38 #include "mlir/Transforms/RegionUtils.h"
39 #include "llvm/ADT/Optional.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42
43 #define DEBUG_TYPE "flang-simplify-intrinsics"
44
45 namespace {
46
47 class SimplifyIntrinsicsPass
48     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
49   using FunctionTypeGeneratorTy =
50       std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
51   using FunctionBodyGeneratorTy =
52       std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
53
54 public:
55   /// Generate a new function implementing a simplified version
56   /// of a Fortran runtime function defined by \p basename name.
57   /// \p typeGenerator is a callback that generates the new function's type.
58   /// \p bodyGenerator is a callback that generates the new function's body.
59   /// The new function is created in the \p builder's Module.
60   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
61                                          const mlir::StringRef &basename,
62                                          FunctionTypeGeneratorTy typeGenerator,
63                                          FunctionBodyGeneratorTy bodyGenerator);
64   void runOnOperation() override;
65   void getDependentDialects(mlir::DialectRegistry &registry) const override;
66 };
67
68 } // namespace
69
70 /// Generate function type for the simplified version of FortranASum and
71 /// similar functions with a fir.box<none> type returning \p elementType.
72 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
73                                          const mlir::Type &elementType) {
74   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
75   return mlir::FunctionType::get(builder.getContext(), {boxType},
76                                  {elementType});
77 }
78
79 using BodyOpGeneratorTy =
80     std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
81                               const mlir::Type &, mlir::Value, mlir::Value)>;
82 using InitValGeneratorTy = std::function<mlir::Value(
83     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
84
85 /// Generate the reduction loop into \p funcOp.
86 ///
87 /// \p initVal is a function, called to get the initial value for
88 ///    the reduction value
89 /// \p genBody is called to fill in the actual reduciton operation
90 ///    for example add for SUM, MAX for MAXVAL, etc.
91 static void genReductionLoop(fir::FirOpBuilder &builder,
92                              mlir::func::FuncOp &funcOp,
93                              InitValGeneratorTy initVal,
94                              BodyOpGeneratorTy genBody) {
95   auto loc = mlir::UnknownLoc::get(builder.getContext());
96   mlir::Type elementType = funcOp.getResultTypes()[0];
97   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
98
99   mlir::IndexType idxTy = builder.getIndexType();
100
101   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
102   mlir::Value arg = args[0];
103
104   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
105
106   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
107   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
108   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
109   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
110   auto dims =
111       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
112   mlir::Value len = dims.getResult(1);
113   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
114   mlir::Value step = one;
115
116   // We use C indexing here, so len-1 as loopcount
117   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
118   mlir::Value init = initVal(builder, loc, elementType);
119   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
120                                             /*unordered=*/false,
121                                             /*finalCountValue=*/false, init);
122   mlir::Value reductionVal = loop.getRegionIterArgs()[0];
123
124   // Begin loop code
125   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
126   builder.setInsertionPointToStart(loop.getBody());
127
128   mlir::Type eleRefTy = builder.getRefType(elementType);
129   mlir::Value index = loop.getInductionVar();
130   mlir::Value addr =
131       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
132   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
133
134   reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
135
136   builder.create<fir::ResultOp>(loc, reductionVal);
137   // End of loop.
138   builder.restoreInsertionPoint(loopEndPt);
139
140   mlir::Value resultVal = loop.getResult(0);
141   builder.create<mlir::func::ReturnOp>(loc, resultVal);
142 }
143
144 /// Generate function body of the simplified version of FortranASum
145 /// with signature provided by \p funcOp. The caller is responsible
146 /// for saving/restoring the original insertion point of \p builder.
147 /// \p funcOp is expected to be empty on entry to this function.
148 static void genFortranASumBody(fir::FirOpBuilder &builder,
149                                mlir::func::FuncOp &funcOp) {
150   // function FortranASum<T>_simplified(arr)
151   //   T, dimension(:) :: arr
152   //   T sum = 0
153   //   integer iter
154   //   do iter = 0, extent(arr)
155   //     sum = sum + arr[iter]
156   //   end do
157   //   FortranASum<T>_simplified = sum
158   // end function FortranASum<T>_simplified
159   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
160                  mlir::Type elementType) {
161     return elementType.isa<mlir::FloatType>()
162                ? builder.createRealConstant(loc, elementType,
163                                             llvm::APFloat(0.0))
164                : builder.createIntegerConstant(loc, elementType, 0);
165   };
166
167   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
168                       mlir::Type elementType, mlir::Value elem1,
169                       mlir::Value elem2) -> mlir::Value {
170     if (elementType.isa<mlir::FloatType>())
171       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
172     if (elementType.isa<mlir::IntegerType>())
173       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
174
175     llvm_unreachable("unsupported type");
176     return {};
177   };
178
179   genReductionLoop(builder, funcOp, zero, genBodyOp);
180 }
181
182 static void genFortranAMaxvalBody(fir::FirOpBuilder &builder,
183                                   mlir::func::FuncOp &funcOp) {
184   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
185                  mlir::Type elementType) {
186     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
187       const llvm::fltSemantics &sem = ty.getFloatSemantics();
188       return builder.createRealConstant(
189           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
190     }
191     unsigned bits = elementType.getIntOrFloatBitWidth();
192     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
193     return builder.createIntegerConstant(loc, elementType, minInt);
194   };
195
196   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
197                       mlir::Type elementType, mlir::Value elem1,
198                       mlir::Value elem2) -> mlir::Value {
199     if (elementType.isa<mlir::FloatType>())
200       return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
201     if (elementType.isa<mlir::IntegerType>())
202       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
203
204     llvm_unreachable("unsupported type");
205     return {};
206   };
207   genReductionLoop(builder, funcOp, init, genBodyOp);
208 }
209
210 /// Generate function type for the simplified version of FortranADotProduct
211 /// operating on the given \p elementType.
212 static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
213                                              const mlir::Type &elementType) {
214   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
215   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
216                                  {elementType});
217 }
218
219 /// Generate function body of the simplified version of FortranADotProduct
220 /// with signature provided by \p funcOp. The caller is responsible
221 /// for saving/restoring the original insertion point of \p builder.
222 /// \p funcOp is expected to be empty on entry to this function.
223 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
224 /// of the underlying array objects - they are used to generate proper
225 /// element accesses.
226 static void genFortranADotBody(fir::FirOpBuilder &builder,
227                                mlir::func::FuncOp &funcOp,
228                                mlir::Type arg1ElementTy,
229                                mlir::Type arg2ElementTy) {
230   // function FortranADotProduct<T>_simplified(arr1, arr2)
231   //   T, dimension(:) :: arr1, arr2
232   //   T product = 0
233   //   integer iter
234   //   do iter = 0, extent(arr1)
235   //     product = product + arr1[iter] * arr2[iter]
236   //   end do
237   //   FortranADotProduct<T>_simplified = product
238   // end function FortranADotProduct<T>_simplified
239   auto loc = mlir::UnknownLoc::get(builder.getContext());
240   mlir::Type resultElementType = funcOp.getResultTypes()[0];
241   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
242
243   mlir::IndexType idxTy = builder.getIndexType();
244
245   mlir::Value zero =
246       resultElementType.isa<mlir::FloatType>()
247           ? builder.createRealConstant(loc, resultElementType, 0.0)
248           : builder.createIntegerConstant(loc, resultElementType, 0);
249
250   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
251   mlir::Value arg1 = args[0];
252   mlir::Value arg2 = args[1];
253
254   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
255
256   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
257   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
258   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
259   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
260   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
261   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
262   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
263   // This version takes the loop trip count from the first argument.
264   // If the first argument's box has unknown (at compilation time)
265   // extent, then it may be better to take the extent from the second
266   // argument - so that after inlining the loop may be better optimized, e.g.
267   // fully unrolled. This requires generating two versions of the simplified
268   // function and some analysis at the call site to choose which version
269   // is more profitable to call.
270   // Note that we can assume that both arguments have the same extent.
271   auto dims =
272       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
273   mlir::Value len = dims.getResult(1);
274   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
275   mlir::Value step = one;
276
277   // We use C indexing here, so len-1 as loopcount
278   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
279   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
280                                             /*unordered=*/false,
281                                             /*finalCountValue=*/false, zero);
282   mlir::Value sumVal = loop.getRegionIterArgs()[0];
283
284   // Begin loop code
285   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
286   builder.setInsertionPointToStart(loop.getBody());
287
288   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
289   mlir::Value index = loop.getInductionVar();
290   mlir::Value addr1 =
291       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
292   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
293   // Convert to the result type.
294   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
295
296   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
297   mlir::Value addr2 =
298       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
299   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
300   // Convert to the result type.
301   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
302
303   if (resultElementType.isa<mlir::FloatType>())
304     sumVal = builder.create<mlir::arith::AddFOp>(
305         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
306   else if (resultElementType.isa<mlir::IntegerType>())
307     sumVal = builder.create<mlir::arith::AddIOp>(
308         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
309   else
310     llvm_unreachable("unsupported type");
311
312   builder.create<fir::ResultOp>(loc, sumVal);
313   // End of loop.
314   builder.restoreInsertionPoint(loopEndPt);
315
316   mlir::Value resultVal = loop.getResult(0);
317   builder.create<mlir::func::ReturnOp>(loc, resultVal);
318 }
319
320 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
321     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
322     FunctionTypeGeneratorTy typeGenerator,
323     FunctionBodyGeneratorTy bodyGenerator) {
324   // WARNING: if the function generated here changes its signature
325   //          or behavior (the body code), we should probably embed some
326   //          versioning information into its name, otherwise libraries
327   //          statically linked with older versions of Flang may stop
328   //          working with object files created with newer Flang.
329   //          We can also avoid this by using internal linkage, but
330   //          this may increase the size of final executable/shared library.
331   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
332   mlir::ModuleOp module = builder.getModule();
333   // If we already have a function, just return it.
334   mlir::func::FuncOp newFunc =
335       fir::FirOpBuilder::getNamedFunction(module, replacementName);
336   mlir::FunctionType fType = typeGenerator(builder);
337   if (newFunc) {
338     assert(newFunc.getFunctionType() == fType &&
339            "type mismatch for simplified function");
340     return newFunc;
341   }
342
343   // Need to build the function!
344   auto loc = mlir::UnknownLoc::get(builder.getContext());
345   newFunc =
346       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
347   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
348   auto linkage =
349       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
350   newFunc->setAttr("llvm.linkage", linkage);
351
352   // Save the position of the original call.
353   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
354
355   bodyGenerator(builder, newFunc);
356
357   // Now back to where we were adding code earlier...
358   builder.restoreInsertionPoint(insertPt);
359
360   return newFunc;
361 }
362
363 fir::ConvertOp expectConvertOp(mlir::Value val) {
364   if (fir::ConvertOp op =
365           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
366     return op;
367   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
368   return nullptr;
369 }
370
371 static bool isOperandAbsent(mlir::Value val) {
372   if (auto op = expectConvertOp(val)) {
373     assert(op->getOperands().size() != 0);
374     return mlir::isa_and_nonnull<fir::AbsentOp>(
375         op->getOperand(0).getDefiningOp());
376   }
377   return false;
378 }
379
380 static bool isZero(mlir::Value val) {
381   if (auto op = expectConvertOp(val)) {
382     assert(op->getOperands().size() != 0);
383     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
384       return mlir::matchPattern(defOp, mlir::m_Zero());
385   }
386   return false;
387 }
388
389 static mlir::Value findShape(mlir::Value val) {
390   if (auto op = expectConvertOp(val)) {
391     assert(op->getOperands().size() != 0);
392     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
393             op->getOperand(0).getDefiningOp()))
394       return box.getShape();
395   }
396   return {};
397 }
398
399 static unsigned getDimCount(mlir::Value val) {
400   if (mlir::Value shapeVal = findShape(val)) {
401     mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
402     return fir::getRankOfShapeType(resType);
403   }
404   return 0;
405 }
406
407 /// Given the call operation's box argument \p val, discover
408 /// the element type of the underlying array object.
409 /// \returns the element type or llvm::None if the type cannot
410 /// be reliably found.
411 /// We expect that the argument is a result of fir.convert
412 /// with the destination type of !fir.box<none>.
413 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
414   mlir::Operation *defOp;
415   do {
416     defOp = val.getDefiningOp();
417     // Analyze only sequences of convert operations.
418     if (!mlir::isa<fir::ConvertOp>(defOp))
419       return llvm::None;
420     val = defOp->getOperand(0);
421     // The convert operation is expected to convert from one
422     // box type to another box type.
423     auto boxType = val.getType().cast<fir::BoxType>();
424     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
425     if (!elementType.isa<mlir::NoneType>())
426       return elementType;
427   } while (true);
428 }
429
430 void SimplifyIntrinsicsPass::runOnOperation() {
431   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
432   mlir::ModuleOp module = getOperation();
433   fir::KindMapping kindMap = fir::getKindMapping(module);
434   module.walk([&](mlir::Operation *op) {
435     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
436       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
437         mlir::StringRef funcName = callee.getLeafReference().getValue();
438         // Replace call to runtime function for SUM when it has single
439         // argument (no dim or mask argument) for 1D arrays with either
440         // Integer4 or Real8 types. Other forms are ignored.
441         // The new function is added to the module.
442         //
443         // Prototype for runtime call (from sum.cpp):
444         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
445         //                int dim, const Descriptor *mask)
446         //
447         if (funcName.startswith("_FortranASum")) {
448           mlir::Operation::operand_range args = call.getArgs();
449           // args[1] and args[2] are source filename and line number, ignored.
450           const mlir::Value &dim = args[3];
451           const mlir::Value &mask = args[4];
452           // dim is zero when it is absent, which is an implementation
453           // detail in the runtime library.
454           bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
455           unsigned rank = getDimCount(args[0]);
456           if (dimAndMaskAbsent && rank == 1) {
457             mlir::Location loc = call.getLoc();
458             mlir::Type type;
459             fir::FirOpBuilder builder(op, kindMap);
460             if (funcName.endswith("Integer4")) {
461               type = mlir::IntegerType::get(builder.getContext(), 32);
462             } else if (funcName.endswith("Real8")) {
463               type = mlir::FloatType::getF64(builder.getContext());
464             } else {
465               return;
466             }
467             auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
468               return genNoneBoxType(builder, type);
469             };
470             mlir::func::FuncOp newFunc = getOrCreateFunction(
471                 builder, funcName, typeGenerator, genFortranASumBody);
472             auto newCall = builder.create<fir::CallOp>(
473                 loc, newFunc, mlir::ValueRange{args[0]});
474             call->replaceAllUsesWith(newCall.getResults());
475             call->dropAllReferences();
476             call->erase();
477           }
478
479           return;
480         }
481         if (funcName.startswith("_FortranADotProduct")) {
482           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
483           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
484                      llvm::dbgs() << "\n");
485           mlir::Operation::operand_range args = call.getArgs();
486           const mlir::Value &v1 = args[0];
487           const mlir::Value &v2 = args[1];
488           mlir::Location loc = call.getLoc();
489           fir::FirOpBuilder builder(op, kindMap);
490
491           mlir::Type type = call.getResult(0).getType();
492           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
493             return;
494
495           // Try to find the element types of the boxed arguments.
496           auto arg1Type = getArgElementType(v1);
497           auto arg2Type = getArgElementType(v2);
498
499           if (!arg1Type || !arg2Type)
500             return;
501
502           // Support only floating point and integer arguments
503           // now (e.g. logical is skipped here).
504           if (!arg1Type->isa<mlir::FloatType>() &&
505               !arg1Type->isa<mlir::IntegerType>())
506             return;
507           if (!arg2Type->isa<mlir::FloatType>() &&
508               !arg2Type->isa<mlir::IntegerType>())
509             return;
510
511           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
512             return genFortranADotType(builder, type);
513           };
514           auto bodyGenerator = [&arg1Type,
515                                 &arg2Type](fir::FirOpBuilder &builder,
516                                            mlir::func::FuncOp &funcOp) {
517             genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
518           };
519
520           // Suffix the function name with the element types
521           // of the arguments.
522           std::string typedFuncName(funcName);
523           llvm::raw_string_ostream nameOS(typedFuncName);
524           nameOS << "_";
525           arg1Type->print(nameOS);
526           nameOS << "_";
527           arg2Type->print(nameOS);
528
529           mlir::func::FuncOp newFunc = getOrCreateFunction(
530               builder, typedFuncName, typeGenerator, bodyGenerator);
531           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
532                                                      mlir::ValueRange{v1, v2});
533           call->replaceAllUsesWith(newCall.getResults());
534           call->dropAllReferences();
535           call->erase();
536
537           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
538                      llvm::dbgs() << "\n");
539           return;
540         }
541         if (funcName.startswith("_FortranAMaxval")) {
542           mlir::Operation::operand_range args = call.getArgs();
543           const mlir::Value &dim = args[3];
544           const mlir::Value &mask = args[4];
545           // dim is zero when it is absent, which is an implementation
546           // detail in the runtime library.
547           bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
548           unsigned rank = getDimCount(args[0]);
549           if (dimAndMaskAbsent && rank == 1) {
550             mlir::Location loc = call.getLoc();
551             mlir::Type type;
552             fir::FirOpBuilder builder(op, kindMap);
553             if (funcName.endswith("Integer4")) {
554               type = mlir::IntegerType::get(builder.getContext(), 32);
555             } else if (funcName.endswith("Real8")) {
556               type = mlir::FloatType::getF64(builder.getContext());
557             } else {
558               return;
559             }
560             auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
561               return genNoneBoxType(builder, type);
562             };
563             mlir::func::FuncOp newFunc = getOrCreateFunction(
564                 builder, funcName, typeGenerator, genFortranAMaxvalBody);
565             auto newCall = builder.create<fir::CallOp>(
566                 loc, newFunc, mlir::ValueRange{args[0]});
567             call->replaceAllUsesWith(newCall.getResults());
568             call->dropAllReferences();
569             call->erase();
570             return;
571           }
572         }
573       }
574     }
575   });
576   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
577 }
578
579 void SimplifyIntrinsicsPass::getDependentDialects(
580     mlir::DialectRegistry &registry) const {
581   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
582   registry.insert<mlir::LLVM::LLVMDialect>();
583 }
584 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
585   return std::make_unique<SimplifyIntrinsicsPass>();
586 }