[flang] Support more data types for reduction in SimplifyIntrinsicsPass.
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "flang/Optimizer/Builder/BoxValue.h"
26 #include "flang/Optimizer/Builder/FIRBuilder.h"
27 #include "flang/Optimizer/Builder/Todo.h"
28 #include "flang/Optimizer/Dialect/FIROps.h"
29 #include "flang/Optimizer/Dialect/FIRType.h"
30 #include "flang/Optimizer/Support/FIRContext.h"
31 #include "flang/Optimizer/Transforms/Passes.h"
32 #include "flang/Runtime/entry-names.h"
33 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Pass/Pass.h"
37 #include "mlir/Transforms/DialectConversion.h"
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
39 #include "mlir/Transforms/RegionUtils.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43
44 namespace fir {
45 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
46 #include "flang/Optimizer/Transforms/Passes.h.inc"
47 } // namespace fir
48
49 #define RTNAME_STRINGIFY2(x) #x
50 #define RTNAME_STRINGIFY(x) RTNAME_STRINGIFY2(x)
51 #define RTNAME_STRING(x) RTNAME_STRINGIFY(RTNAME(x))
52
53 #define DEBUG_TYPE "flang-simplify-intrinsics"
54
55 namespace {
56
57 class SimplifyIntrinsicsPass
58     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
59   using FunctionTypeGeneratorTy =
60       llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
61   using FunctionBodyGeneratorTy =
62       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
63   using GenReductionBodyTy = llvm::function_ref<void(
64       fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;
65
66 public:
67   /// Generate a new function implementing a simplified version
68   /// of a Fortran runtime function defined by \p basename name.
69   /// \p typeGenerator is a callback that generates the new function's type.
70   /// \p bodyGenerator is a callback that generates the new function's body.
71   /// The new function is created in the \p builder's Module.
72   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
73                                          const mlir::StringRef &basename,
74                                          FunctionTypeGeneratorTy typeGenerator,
75                                          FunctionBodyGeneratorTy bodyGenerator);
76   void runOnOperation() override;
77   void getDependentDialects(mlir::DialectRegistry &registry) const override;
78
79 private:
80   /// Helper function to replace a reduction type of call with its
81   /// simplified form. The actual function is generated using a callback
82   /// function.
83   /// \p call is the call to be replaced
84   /// \p kindMap is used to create FIROpBuilder
85   /// \p genBodyFunc is the callback that builds the replacement function
86   void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
87                          GenReductionBodyTy genBodyFunc);
88 };
89
90 } // namespace
91
92 /// Generate function type for the simplified version of RTNAME(Sum) and
93 /// similar functions with a fir.box<none> type returning \p elementType.
94 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
95                                          const mlir::Type &elementType) {
96   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
97   return mlir::FunctionType::get(builder.getContext(), {boxType},
98                                  {elementType});
99 }
100
101 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
102     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
103     mlir::Value)>;
104 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
105     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
106
107 /// Generate the reduction loop into \p funcOp.
108 ///
109 /// \p initVal is a function, called to get the initial value for
110 ///    the reduction value
111 /// \p genBody is called to fill in the actual reduciton operation
112 ///    for example add for SUM, MAX for MAXVAL, etc.
113 static void genReductionLoop(fir::FirOpBuilder &builder,
114                              mlir::func::FuncOp &funcOp,
115                              InitValGeneratorTy initVal,
116                              BodyOpGeneratorTy genBody) {
117   auto loc = mlir::UnknownLoc::get(builder.getContext());
118   mlir::Type elementType = funcOp.getResultTypes()[0];
119   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
120
121   mlir::IndexType idxTy = builder.getIndexType();
122
123   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
124   mlir::Value arg = args[0];
125
126   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
127
128   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
129   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
130   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
131   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
132   auto dims =
133       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
134   mlir::Value len = dims.getResult(1);
135   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
136   mlir::Value step = one;
137
138   // We use C indexing here, so len-1 as loopcount
139   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
140   mlir::Value init = initVal(builder, loc, elementType);
141   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
142                                             /*unordered=*/false,
143                                             /*finalCountValue=*/false, init);
144   mlir::Value reductionVal = loop.getRegionIterArgs()[0];
145
146   // Begin loop code
147   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
148   builder.setInsertionPointToStart(loop.getBody());
149
150   mlir::Type eleRefTy = builder.getRefType(elementType);
151   mlir::Value index = loop.getInductionVar();
152   mlir::Value addr =
153       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
154   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
155
156   reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
157
158   builder.create<fir::ResultOp>(loc, reductionVal);
159   // End of loop.
160   builder.restoreInsertionPoint(loopEndPt);
161
162   mlir::Value resultVal = loop.getResult(0);
163   builder.create<mlir::func::ReturnOp>(loc, resultVal);
164 }
165
166 /// Generate function body of the simplified version of RTNAME(Sum)
167 /// with signature provided by \p funcOp. The caller is responsible
168 /// for saving/restoring the original insertion point of \p builder.
169 /// \p funcOp is expected to be empty on entry to this function.
170 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
171                               mlir::func::FuncOp &funcOp) {
172   // function RTNAME(Sum)<T>_simplified(arr)
173   //   T, dimension(:) :: arr
174   //   T sum = 0
175   //   integer iter
176   //   do iter = 0, extent(arr)
177   //     sum = sum + arr[iter]
178   //   end do
179   //   RTNAME(Sum)<T>_simplified = sum
180   // end function RTNAME(Sum)<T>_simplified
181   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
182                  mlir::Type elementType) {
183     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
184       const llvm::fltSemantics &sem = ty.getFloatSemantics();
185       return builder.createRealConstant(loc, elementType,
186                                         llvm::APFloat::getZero(sem));
187     }
188     return builder.createIntegerConstant(loc, elementType, 0);
189   };
190
191   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
192                       mlir::Type elementType, mlir::Value elem1,
193                       mlir::Value elem2) -> mlir::Value {
194     if (elementType.isa<mlir::FloatType>())
195       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
196     if (elementType.isa<mlir::IntegerType>())
197       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
198
199     llvm_unreachable("unsupported type");
200     return {};
201   };
202
203   genReductionLoop(builder, funcOp, zero, genBodyOp);
204 }
205
206 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
207                                  mlir::func::FuncOp &funcOp) {
208   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
209                  mlir::Type elementType) {
210     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
211       const llvm::fltSemantics &sem = ty.getFloatSemantics();
212       return builder.createRealConstant(
213           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
214     }
215     unsigned bits = elementType.getIntOrFloatBitWidth();
216     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
217     return builder.createIntegerConstant(loc, elementType, minInt);
218   };
219
220   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
221                       mlir::Type elementType, mlir::Value elem1,
222                       mlir::Value elem2) -> mlir::Value {
223     if (elementType.isa<mlir::FloatType>())
224       return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
225     if (elementType.isa<mlir::IntegerType>())
226       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
227
228     llvm_unreachable("unsupported type");
229     return {};
230   };
231   genReductionLoop(builder, funcOp, init, genBodyOp);
232 }
233
234 /// Generate function type for the simplified version of RTNAME(DotProduct)
235 /// operating on the given \p elementType.
236 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
237                                             const mlir::Type &elementType) {
238   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
239   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
240                                  {elementType});
241 }
242
243 /// Generate function body of the simplified version of RTNAME(DotProduct)
244 /// with signature provided by \p funcOp. The caller is responsible
245 /// for saving/restoring the original insertion point of \p builder.
246 /// \p funcOp is expected to be empty on entry to this function.
247 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
248 /// of the underlying array objects - they are used to generate proper
249 /// element accesses.
250 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
251                               mlir::func::FuncOp &funcOp,
252                               mlir::Type arg1ElementTy,
253                               mlir::Type arg2ElementTy) {
254   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
255   //   T, dimension(:) :: arr1, arr2
256   //   T product = 0
257   //   integer iter
258   //   do iter = 0, extent(arr1)
259   //     product = product + arr1[iter] * arr2[iter]
260   //   end do
261   //   RTNAME(ADotProduct)<T>_simplified = product
262   // end function RTNAME(DotProduct)<T>_simplified
263   auto loc = mlir::UnknownLoc::get(builder.getContext());
264   mlir::Type resultElementType = funcOp.getResultTypes()[0];
265   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
266
267   mlir::IndexType idxTy = builder.getIndexType();
268
269   mlir::Value zero =
270       resultElementType.isa<mlir::FloatType>()
271           ? builder.createRealConstant(loc, resultElementType, 0.0)
272           : builder.createIntegerConstant(loc, resultElementType, 0);
273
274   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
275   mlir::Value arg1 = args[0];
276   mlir::Value arg2 = args[1];
277
278   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
279
280   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
281   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
282   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
283   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
284   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
285   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
286   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
287   // This version takes the loop trip count from the first argument.
288   // If the first argument's box has unknown (at compilation time)
289   // extent, then it may be better to take the extent from the second
290   // argument - so that after inlining the loop may be better optimized, e.g.
291   // fully unrolled. This requires generating two versions of the simplified
292   // function and some analysis at the call site to choose which version
293   // is more profitable to call.
294   // Note that we can assume that both arguments have the same extent.
295   auto dims =
296       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
297   mlir::Value len = dims.getResult(1);
298   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
299   mlir::Value step = one;
300
301   // We use C indexing here, so len-1 as loopcount
302   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
303   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
304                                             /*unordered=*/false,
305                                             /*finalCountValue=*/false, zero);
306   mlir::Value sumVal = loop.getRegionIterArgs()[0];
307
308   // Begin loop code
309   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
310   builder.setInsertionPointToStart(loop.getBody());
311
312   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
313   mlir::Value index = loop.getInductionVar();
314   mlir::Value addr1 =
315       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
316   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
317   // Convert to the result type.
318   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
319
320   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
321   mlir::Value addr2 =
322       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
323   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
324   // Convert to the result type.
325   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
326
327   if (resultElementType.isa<mlir::FloatType>())
328     sumVal = builder.create<mlir::arith::AddFOp>(
329         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
330   else if (resultElementType.isa<mlir::IntegerType>())
331     sumVal = builder.create<mlir::arith::AddIOp>(
332         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
333   else
334     llvm_unreachable("unsupported type");
335
336   builder.create<fir::ResultOp>(loc, sumVal);
337   // End of loop.
338   builder.restoreInsertionPoint(loopEndPt);
339
340   mlir::Value resultVal = loop.getResult(0);
341   builder.create<mlir::func::ReturnOp>(loc, resultVal);
342 }
343
344 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
345     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
346     FunctionTypeGeneratorTy typeGenerator,
347     FunctionBodyGeneratorTy bodyGenerator) {
348   // WARNING: if the function generated here changes its signature
349   //          or behavior (the body code), we should probably embed some
350   //          versioning information into its name, otherwise libraries
351   //          statically linked with older versions of Flang may stop
352   //          working with object files created with newer Flang.
353   //          We can also avoid this by using internal linkage, but
354   //          this may increase the size of final executable/shared library.
355   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
356   mlir::ModuleOp module = builder.getModule();
357   // If we already have a function, just return it.
358   mlir::func::FuncOp newFunc =
359       fir::FirOpBuilder::getNamedFunction(module, replacementName);
360   mlir::FunctionType fType = typeGenerator(builder);
361   if (newFunc) {
362     assert(newFunc.getFunctionType() == fType &&
363            "type mismatch for simplified function");
364     return newFunc;
365   }
366
367   // Need to build the function!
368   auto loc = mlir::UnknownLoc::get(builder.getContext());
369   newFunc =
370       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
371   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
372   auto linkage =
373       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
374   newFunc->setAttr("llvm.linkage", linkage);
375
376   // Save the position of the original call.
377   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
378
379   bodyGenerator(builder, newFunc);
380
381   // Now back to where we were adding code earlier...
382   builder.restoreInsertionPoint(insertPt);
383
384   return newFunc;
385 }
386
387 fir::ConvertOp expectConvertOp(mlir::Value val) {
388   if (fir::ConvertOp op =
389           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
390     return op;
391   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
392   return nullptr;
393 }
394
395 static bool isOperandAbsent(mlir::Value val) {
396   if (auto op = expectConvertOp(val)) {
397     assert(op->getOperands().size() != 0);
398     return mlir::isa_and_nonnull<fir::AbsentOp>(
399         op->getOperand(0).getDefiningOp());
400   }
401   return false;
402 }
403
404 static bool isZero(mlir::Value val) {
405   if (auto op = expectConvertOp(val)) {
406     assert(op->getOperands().size() != 0);
407     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
408       return mlir::matchPattern(defOp, mlir::m_Zero());
409   }
410   return false;
411 }
412
413 static mlir::Value findShape(mlir::Value val) {
414   if (auto op = expectConvertOp(val)) {
415     assert(op->getOperands().size() != 0);
416     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
417             op->getOperand(0).getDefiningOp()))
418       return box.getShape();
419   }
420   return {};
421 }
422
423 static unsigned getDimCount(mlir::Value val) {
424   if (mlir::Value shapeVal = findShape(val)) {
425     mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
426     return fir::getRankOfShapeType(resType);
427   }
428   return 0;
429 }
430
431 /// Given the call operation's box argument \p val, discover
432 /// the element type of the underlying array object.
433 /// \returns the element type or llvm::None if the type cannot
434 /// be reliably found.
435 /// We expect that the argument is a result of fir.convert
436 /// with the destination type of !fir.box<none>.
437 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
438   mlir::Operation *defOp;
439   do {
440     defOp = val.getDefiningOp();
441     // Analyze only sequences of convert operations.
442     if (!mlir::isa<fir::ConvertOp>(defOp))
443       return llvm::None;
444     val = defOp->getOperand(0);
445     // The convert operation is expected to convert from one
446     // box type to another box type.
447     auto boxType = val.getType().cast<fir::BoxType>();
448     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
449     if (!elementType.isa<mlir::NoneType>())
450       return elementType;
451   } while (true);
452 }
453
454 void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
455                                                const fir::KindMapping &kindMap,
456                                                GenReductionBodyTy genBodyFunc) {
457   mlir::SymbolRefAttr callee = call.getCalleeAttr();
458   mlir::StringRef funcName = callee.getLeafReference().getValue();
459   mlir::Operation::operand_range args = call.getArgs();
460   // args[1] and args[2] are source filename and line number, ignored.
461   const mlir::Value &dim = args[3];
462   const mlir::Value &mask = args[4];
463   // dim is zero when it is absent, which is an implementation
464   // detail in the runtime library.
465   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
466   unsigned rank = getDimCount(args[0]);
467   if (dimAndMaskAbsent && rank == 1) {
468     mlir::Location loc = call.getLoc();
469     fir::FirOpBuilder builder(call, kindMap);
470
471     // Support only floating point and integer results now.
472     mlir::Type resultType = call.getResult(0).getType();
473     if (!resultType.isa<mlir::FloatType>() &&
474         !resultType.isa<mlir::IntegerType>())
475       return;
476
477     auto argType = getArgElementType(args[0]);
478     if (!argType)
479       return;
480     assert(*argType == resultType &&
481            "Argument/result types mismatch in reduction");
482
483     auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
484       return genNoneBoxType(builder, resultType);
485     };
486     mlir::func::FuncOp newFunc =
487         getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
488     auto newCall =
489         builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
490     call->replaceAllUsesWith(newCall.getResults());
491     call->dropAllReferences();
492     call->erase();
493   }
494 }
495
496 void SimplifyIntrinsicsPass::runOnOperation() {
497   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
498   mlir::ModuleOp module = getOperation();
499   fir::KindMapping kindMap = fir::getKindMapping(module);
500   module.walk([&](mlir::Operation *op) {
501     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
502       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
503         mlir::StringRef funcName = callee.getLeafReference().getValue();
504         // Replace call to runtime function for SUM when it has single
505         // argument (no dim or mask argument) for 1D arrays with either
506         // Integer4 or Real8 types. Other forms are ignored.
507         // The new function is added to the module.
508         //
509         // Prototype for runtime call (from sum.cpp):
510         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
511         //                int dim, const Descriptor *mask)
512         //
513         if (funcName.startswith(RTNAME_STRING(Sum))) {
514           simplifyReduction(call, kindMap, genRuntimeSumBody);
515           return;
516         }
517         if (funcName.startswith(RTNAME_STRING(DotProduct))) {
518           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
519           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
520                      llvm::dbgs() << "\n");
521           mlir::Operation::operand_range args = call.getArgs();
522           const mlir::Value &v1 = args[0];
523           const mlir::Value &v2 = args[1];
524           mlir::Location loc = call.getLoc();
525           fir::FirOpBuilder builder(op, kindMap);
526
527           mlir::Type type = call.getResult(0).getType();
528           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
529             return;
530
531           // Try to find the element types of the boxed arguments.
532           auto arg1Type = getArgElementType(v1);
533           auto arg2Type = getArgElementType(v2);
534
535           if (!arg1Type || !arg2Type)
536             return;
537
538           // Support only floating point and integer arguments
539           // now (e.g. logical is skipped here).
540           if (!arg1Type->isa<mlir::FloatType>() &&
541               !arg1Type->isa<mlir::IntegerType>())
542             return;
543           if (!arg2Type->isa<mlir::FloatType>() &&
544               !arg2Type->isa<mlir::IntegerType>())
545             return;
546
547           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
548             return genRuntimeDotType(builder, type);
549           };
550           auto bodyGenerator = [&arg1Type,
551                                 &arg2Type](fir::FirOpBuilder &builder,
552                                            mlir::func::FuncOp &funcOp) {
553             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
554           };
555
556           // Suffix the function name with the element types
557           // of the arguments.
558           std::string typedFuncName(funcName);
559           llvm::raw_string_ostream nameOS(typedFuncName);
560           nameOS << "_";
561           arg1Type->print(nameOS);
562           nameOS << "_";
563           arg2Type->print(nameOS);
564
565           mlir::func::FuncOp newFunc = getOrCreateFunction(
566               builder, typedFuncName, typeGenerator, bodyGenerator);
567           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
568                                                      mlir::ValueRange{v1, v2});
569           call->replaceAllUsesWith(newCall.getResults());
570           call->dropAllReferences();
571           call->erase();
572
573           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
574                      llvm::dbgs() << "\n");
575           return;
576         }
577         if (funcName.startswith(RTNAME_STRING(Maxval))) {
578           simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
579           return;
580         }
581       }
582     }
583   });
584   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
585 }
586
587 void SimplifyIntrinsicsPass::getDependentDialects(
588     mlir::DialectRegistry &registry) const {
589   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
590   registry.insert<mlir::LLVM::LLVMDialect>();
591 }
592 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
593   return std::make_unique<SimplifyIntrinsicsPass>();
594 }