cdc3ab9393f3bb4450fce8de66511ce39e4246d4
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "flang/Optimizer/Builder/BoxValue.h"
26 #include "flang/Optimizer/Builder/FIRBuilder.h"
27 #include "flang/Optimizer/Builder/Todo.h"
28 #include "flang/Optimizer/Dialect/FIROps.h"
29 #include "flang/Optimizer/Dialect/FIRType.h"
30 #include "flang/Optimizer/Support/FIRContext.h"
31 #include "flang/Optimizer/Transforms/Passes.h"
32 #include "flang/Runtime/entry-names.h"
33 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Pass/Pass.h"
37 #include "mlir/Transforms/DialectConversion.h"
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
39 #include "mlir/Transforms/RegionUtils.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43
44 namespace fir {
45 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
46 #include "flang/Optimizer/Transforms/Passes.h.inc"
47 } // namespace fir
48
49 #define DEBUG_TYPE "flang-simplify-intrinsics"
50
51 namespace {
52
53 class SimplifyIntrinsicsPass
54     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
55   using FunctionTypeGeneratorTy =
56       llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
57   using FunctionBodyGeneratorTy =
58       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
59   using GenReductionBodyTy = llvm::function_ref<void(
60       fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
61
62 public:
63   /// Generate a new function implementing a simplified version
64   /// of a Fortran runtime function defined by \p basename name.
65   /// \p typeGenerator is a callback that generates the new function's type.
66   /// \p bodyGenerator is a callback that generates the new function's body.
67   /// The new function is created in the \p builder's Module.
68   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
69                                          const mlir::StringRef &basename,
70                                          FunctionTypeGeneratorTy typeGenerator,
71                                          FunctionBodyGeneratorTy bodyGenerator);
72   void runOnOperation() override;
73   void getDependentDialects(mlir::DialectRegistry &registry) const override;
74
75 private:
76   /// Helper function to replace a reduction type of call with its
77   /// simplified form. The actual function is generated using a callback
78   /// function.
79   /// \p call is the call to be replaced
80   /// \p kindMap is used to create FIROpBuilder
81   /// \p genBodyFunc is the callback that builds the replacement function
82   void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
83                          GenReductionBodyTy genBodyFunc);
84 };
85
86 } // namespace
87
88 /// Generate function type for the simplified version of RTNAME(Sum) and
89 /// similar functions with a fir.box<none> type returning \p elementType.
90 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
91                                          const mlir::Type &elementType) {
92   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
93   return mlir::FunctionType::get(builder.getContext(), {boxType},
94                                  {elementType});
95 }
96
97 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
98     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
99     mlir::Value)>;
100 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
101     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
102
103 /// Generate the reduction loop into \p funcOp.
104 ///
105 /// \p initVal is a function, called to get the initial value for
106 ///    the reduction value
107 /// \p genBody is called to fill in the actual reduciton operation
108 ///    for example add for SUM, MAX for MAXVAL, etc.
109 /// \p rank is the rank of the input argument.
110 static void genReductionLoop(fir::FirOpBuilder &builder,
111                              mlir::func::FuncOp &funcOp,
112                              InitValGeneratorTy initVal,
113                              BodyOpGeneratorTy genBody, unsigned rank) {
114   auto loc = mlir::UnknownLoc::get(builder.getContext());
115   mlir::Type elementType = funcOp.getResultTypes()[0];
116   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
117
118   mlir::IndexType idxTy = builder.getIndexType();
119
120   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
121   mlir::Value arg = args[0];
122
123   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
124
125   fir::SequenceType::Shape flatShape(rank,
126                                      fir::SequenceType::getUnknownExtent());
127   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
128   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
129   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
130   mlir::Value init = initVal(builder, loc, elementType);
131
132   llvm::SmallVector<mlir::Value, 15> bounds;
133
134   assert(rank > 0 && "rank cannot be zero");
135   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
136
137   // Compute all the upper bounds before the loop nest.
138   // It is not strictly necessary for performance, since the loop nest
139   // does not have any store operations and any LICM optimization
140   // should be able to optimize the redundancy.
141   for (unsigned i = 0; i < rank; ++i) {
142     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
143     auto dims =
144         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
145     mlir::Value len = dims.getResult(1);
146     // We use C indexing here, so len-1 as loopcount
147     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
148     bounds.push_back(loopCount);
149   }
150
151   // Create a loop nest consisting of DoLoopOp operations.
152   // Collect the loops' induction variables into indices array,
153   // which will be used in the innermost loop to load the input
154   // array's element.
155   // The loops are generated such that the innermost loop processes
156   // the 0 dimension.
157   llvm::SmallVector<mlir::Value, 15> indices;
158   for (unsigned i = rank; 0 < i; --i) {
159     mlir::Value step = one;
160     mlir::Value loopCount = bounds[i - 1];
161     auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
162                                               /*unordered=*/false,
163                                               /*finalCountValue=*/false, init);
164     init = loop.getRegionIterArgs()[0];
165     indices.push_back(loop.getInductionVar());
166     // Set insertion point to the loop body so that the next loop
167     // is inserted inside the current one.
168     builder.setInsertionPointToStart(loop.getBody());
169   }
170
171   // Reverse the indices such that they are ordered as:
172   //   <dim-0-idx, dim-1-idx, ...>
173   std::reverse(indices.begin(), indices.end());
174
175   // We are in the innermost loop: generate the reduction body.
176   mlir::Type eleRefTy = builder.getRefType(elementType);
177   mlir::Value addr =
178       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
179   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
180
181   mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
182
183   // Unwind the loop nest and insert ResultOp on each level
184   // to return the updated value of the reduction to the enclosing
185   // loops.
186   for (unsigned i = 0; i < rank; ++i) {
187     auto result = builder.create<fir::ResultOp>(loc, reductionVal);
188     // Proceed to the outer loop.
189     auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
190     reductionVal = loop.getResult(0);
191     // Set insertion point after the loop operation that we have
192     // just processed.
193     builder.setInsertionPointAfter(loop.getOperation());
194   }
195
196   // End of loop nest. The insertion point is after the outermost loop.
197   // Return the reduction value from the function.
198   builder.create<mlir::func::ReturnOp>(loc, reductionVal);
199 }
200
201 /// Generate function body of the simplified version of RTNAME(Sum)
202 /// with signature provided by \p funcOp. The caller is responsible
203 /// for saving/restoring the original insertion point of \p builder.
204 /// \p funcOp is expected to be empty on entry to this function.
205 /// \p rank specifies the rank of the input argument.
206 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
207                               mlir::func::FuncOp &funcOp, unsigned rank) {
208   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
209   //   T, dimension(:) :: arr
210   //   T sum = 0
211   //   integer iter
212   //   do iter = 0, extent(arr)
213   //     sum = sum + arr[iter]
214   //   end do
215   //   RTNAME(Sum)<T>x<rank>_simplified = sum
216   // end function RTNAME(Sum)<T>x<rank>_simplified
217   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
218                  mlir::Type elementType) {
219     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
220       const llvm::fltSemantics &sem = ty.getFloatSemantics();
221       return builder.createRealConstant(loc, elementType,
222                                         llvm::APFloat::getZero(sem));
223     }
224     return builder.createIntegerConstant(loc, elementType, 0);
225   };
226
227   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
228                       mlir::Type elementType, mlir::Value elem1,
229                       mlir::Value elem2) -> mlir::Value {
230     if (elementType.isa<mlir::FloatType>())
231       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
232     if (elementType.isa<mlir::IntegerType>())
233       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
234
235     llvm_unreachable("unsupported type");
236     return {};
237   };
238
239   genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
240 }
241
242 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
243                                  mlir::func::FuncOp &funcOp, unsigned rank) {
244   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
245                  mlir::Type elementType) {
246     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
247       const llvm::fltSemantics &sem = ty.getFloatSemantics();
248       return builder.createRealConstant(
249           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
250     }
251     unsigned bits = elementType.getIntOrFloatBitWidth();
252     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
253     return builder.createIntegerConstant(loc, elementType, minInt);
254   };
255
256   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
257                       mlir::Type elementType, mlir::Value elem1,
258                       mlir::Value elem2) -> mlir::Value {
259     if (elementType.isa<mlir::FloatType>())
260       return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
261     if (elementType.isa<mlir::IntegerType>())
262       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
263
264     llvm_unreachable("unsupported type");
265     return {};
266   };
267   genReductionLoop(builder, funcOp, init, genBodyOp, rank);
268 }
269
270 /// Generate function type for the simplified version of RTNAME(DotProduct)
271 /// operating on the given \p elementType.
272 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
273                                             const mlir::Type &elementType) {
274   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
275   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
276                                  {elementType});
277 }
278
279 /// Generate function body of the simplified version of RTNAME(DotProduct)
280 /// with signature provided by \p funcOp. The caller is responsible
281 /// for saving/restoring the original insertion point of \p builder.
282 /// \p funcOp is expected to be empty on entry to this function.
283 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
284 /// of the underlying array objects - they are used to generate proper
285 /// element accesses.
286 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
287                               mlir::func::FuncOp &funcOp,
288                               mlir::Type arg1ElementTy,
289                               mlir::Type arg2ElementTy) {
290   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
291   //   T, dimension(:) :: arr1, arr2
292   //   T product = 0
293   //   integer iter
294   //   do iter = 0, extent(arr1)
295   //     product = product + arr1[iter] * arr2[iter]
296   //   end do
297   //   RTNAME(ADotProduct)<T>_simplified = product
298   // end function RTNAME(DotProduct)<T>_simplified
299   auto loc = mlir::UnknownLoc::get(builder.getContext());
300   mlir::Type resultElementType = funcOp.getResultTypes()[0];
301   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
302
303   mlir::IndexType idxTy = builder.getIndexType();
304
305   mlir::Value zero =
306       resultElementType.isa<mlir::FloatType>()
307           ? builder.createRealConstant(loc, resultElementType, 0.0)
308           : builder.createIntegerConstant(loc, resultElementType, 0);
309
310   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
311   mlir::Value arg1 = args[0];
312   mlir::Value arg2 = args[1];
313
314   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
315
316   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
317   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
318   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
319   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
320   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
321   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
322   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
323   // This version takes the loop trip count from the first argument.
324   // If the first argument's box has unknown (at compilation time)
325   // extent, then it may be better to take the extent from the second
326   // argument - so that after inlining the loop may be better optimized, e.g.
327   // fully unrolled. This requires generating two versions of the simplified
328   // function and some analysis at the call site to choose which version
329   // is more profitable to call.
330   // Note that we can assume that both arguments have the same extent.
331   auto dims =
332       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
333   mlir::Value len = dims.getResult(1);
334   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
335   mlir::Value step = one;
336
337   // We use C indexing here, so len-1 as loopcount
338   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
339   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
340                                             /*unordered=*/false,
341                                             /*finalCountValue=*/false, zero);
342   mlir::Value sumVal = loop.getRegionIterArgs()[0];
343
344   // Begin loop code
345   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
346   builder.setInsertionPointToStart(loop.getBody());
347
348   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
349   mlir::Value index = loop.getInductionVar();
350   mlir::Value addr1 =
351       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
352   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
353   // Convert to the result type.
354   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
355
356   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
357   mlir::Value addr2 =
358       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
359   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
360   // Convert to the result type.
361   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
362
363   if (resultElementType.isa<mlir::FloatType>())
364     sumVal = builder.create<mlir::arith::AddFOp>(
365         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
366   else if (resultElementType.isa<mlir::IntegerType>())
367     sumVal = builder.create<mlir::arith::AddIOp>(
368         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
369   else
370     llvm_unreachable("unsupported type");
371
372   builder.create<fir::ResultOp>(loc, sumVal);
373   // End of loop.
374   builder.restoreInsertionPoint(loopEndPt);
375
376   mlir::Value resultVal = loop.getResult(0);
377   builder.create<mlir::func::ReturnOp>(loc, resultVal);
378 }
379
380 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
381     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
382     FunctionTypeGeneratorTy typeGenerator,
383     FunctionBodyGeneratorTy bodyGenerator) {
384   // WARNING: if the function generated here changes its signature
385   //          or behavior (the body code), we should probably embed some
386   //          versioning information into its name, otherwise libraries
387   //          statically linked with older versions of Flang may stop
388   //          working with object files created with newer Flang.
389   //          We can also avoid this by using internal linkage, but
390   //          this may increase the size of final executable/shared library.
391   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
392   mlir::ModuleOp module = builder.getModule();
393   // If we already have a function, just return it.
394   mlir::func::FuncOp newFunc =
395       fir::FirOpBuilder::getNamedFunction(module, replacementName);
396   mlir::FunctionType fType = typeGenerator(builder);
397   if (newFunc) {
398     assert(newFunc.getFunctionType() == fType &&
399            "type mismatch for simplified function");
400     return newFunc;
401   }
402
403   // Need to build the function!
404   auto loc = mlir::UnknownLoc::get(builder.getContext());
405   newFunc =
406       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
407   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
408   auto linkage =
409       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
410   newFunc->setAttr("llvm.linkage", linkage);
411
412   // Save the position of the original call.
413   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
414
415   bodyGenerator(builder, newFunc);
416
417   // Now back to where we were adding code earlier...
418   builder.restoreInsertionPoint(insertPt);
419
420   return newFunc;
421 }
422
423 fir::ConvertOp expectConvertOp(mlir::Value val) {
424   if (fir::ConvertOp op =
425           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
426     return op;
427   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
428   return nullptr;
429 }
430
431 static bool isOperandAbsent(mlir::Value val) {
432   if (auto op = expectConvertOp(val)) {
433     assert(op->getOperands().size() != 0);
434     return mlir::isa_and_nonnull<fir::AbsentOp>(
435         op->getOperand(0).getDefiningOp());
436   }
437   return false;
438 }
439
440 static bool isZero(mlir::Value val) {
441   if (auto op = expectConvertOp(val)) {
442     assert(op->getOperands().size() != 0);
443     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
444       return mlir::matchPattern(defOp, mlir::m_Zero());
445   }
446   return false;
447 }
448
449 static mlir::Value findBoxDef(mlir::Value val) {
450   if (auto op = expectConvertOp(val)) {
451     assert(op->getOperands().size() != 0);
452     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
453             op->getOperand(0).getDefiningOp()))
454       return box.getResult();
455     if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
456             op->getOperand(0).getDefiningOp()))
457       return box.getResult();
458   }
459   return {};
460 }
461
462 static unsigned getDimCount(mlir::Value val) {
463   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
464   // and take the count from its *result* type. Note that in case
465   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
466   // have different types.
467   // Actually, we can take the box type from the operand of
468   // the first ConvertOp that has non-opaque box type that we meet
469   // going through the ConvertOp chain.
470   if (mlir::Value emboxVal = findBoxDef(val))
471     if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
472       if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
473         return seqTy.getDimension();
474   return 0;
475 }
476
477 /// Given the call operation's box argument \p val, discover
478 /// the element type of the underlying array object.
479 /// \returns the element type or llvm::None if the type cannot
480 /// be reliably found.
481 /// We expect that the argument is a result of fir.convert
482 /// with the destination type of !fir.box<none>.
483 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
484   mlir::Operation *defOp;
485   do {
486     defOp = val.getDefiningOp();
487     // Analyze only sequences of convert operations.
488     if (!mlir::isa<fir::ConvertOp>(defOp))
489       return llvm::None;
490     val = defOp->getOperand(0);
491     // The convert operation is expected to convert from one
492     // box type to another box type.
493     auto boxType = val.getType().cast<fir::BoxType>();
494     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
495     if (!elementType.isa<mlir::NoneType>())
496       return elementType;
497   } while (true);
498 }
499
500 void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
501                                                const fir::KindMapping &kindMap,
502                                                GenReductionBodyTy genBodyFunc) {
503   mlir::SymbolRefAttr callee = call.getCalleeAttr();
504   mlir::Operation::operand_range args = call.getArgs();
505   // args[1] and args[2] are source filename and line number, ignored.
506   const mlir::Value &dim = args[3];
507   const mlir::Value &mask = args[4];
508   // dim is zero when it is absent, which is an implementation
509   // detail in the runtime library.
510   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
511   unsigned rank = getDimCount(args[0]);
512   if (dimAndMaskAbsent && rank > 0) {
513     mlir::Location loc = call.getLoc();
514     fir::FirOpBuilder builder(call, kindMap);
515
516     // Support only floating point and integer results now.
517     mlir::Type resultType = call.getResult(0).getType();
518     if (!resultType.isa<mlir::FloatType>() &&
519         !resultType.isa<mlir::IntegerType>())
520       return;
521
522     auto argType = getArgElementType(args[0]);
523     if (!argType)
524       return;
525     assert(*argType == resultType &&
526            "Argument/result types mismatch in reduction");
527
528     auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
529       return genNoneBoxType(builder, resultType);
530     };
531     auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
532                                                mlir::func::FuncOp &funcOp) {
533       genBodyFunc(builder, funcOp, rank);
534     };
535     // Mangle the function name with the rank value as "x<rank>".
536     std::string funcName =
537         (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
538          mlir::Twine{rank})
539             .str();
540     mlir::func::FuncOp newFunc =
541         getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
542     auto newCall =
543         builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
544     call->replaceAllUsesWith(newCall.getResults());
545     call->dropAllReferences();
546     call->erase();
547   }
548 }
549
550 void SimplifyIntrinsicsPass::runOnOperation() {
551   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
552   mlir::ModuleOp module = getOperation();
553   fir::KindMapping kindMap = fir::getKindMapping(module);
554   module.walk([&](mlir::Operation *op) {
555     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
556       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
557         mlir::StringRef funcName = callee.getLeafReference().getValue();
558         // Replace call to runtime function for SUM when it has single
559         // argument (no dim or mask argument) for 1D arrays with either
560         // Integer4 or Real8 types. Other forms are ignored.
561         // The new function is added to the module.
562         //
563         // Prototype for runtime call (from sum.cpp):
564         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
565         //                int dim, const Descriptor *mask)
566         //
567         if (funcName.startswith(RTNAME_STRING(Sum))) {
568           simplifyReduction(call, kindMap, genRuntimeSumBody);
569           return;
570         }
571         if (funcName.startswith(RTNAME_STRING(DotProduct))) {
572           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
573           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
574                      llvm::dbgs() << "\n");
575           mlir::Operation::operand_range args = call.getArgs();
576           const mlir::Value &v1 = args[0];
577           const mlir::Value &v2 = args[1];
578           mlir::Location loc = call.getLoc();
579           fir::FirOpBuilder builder(op, kindMap);
580
581           mlir::Type type = call.getResult(0).getType();
582           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
583             return;
584
585           // Try to find the element types of the boxed arguments.
586           auto arg1Type = getArgElementType(v1);
587           auto arg2Type = getArgElementType(v2);
588
589           if (!arg1Type || !arg2Type)
590             return;
591
592           // Support only floating point and integer arguments
593           // now (e.g. logical is skipped here).
594           if (!arg1Type->isa<mlir::FloatType>() &&
595               !arg1Type->isa<mlir::IntegerType>())
596             return;
597           if (!arg2Type->isa<mlir::FloatType>() &&
598               !arg2Type->isa<mlir::IntegerType>())
599             return;
600
601           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
602             return genRuntimeDotType(builder, type);
603           };
604           auto bodyGenerator = [&arg1Type,
605                                 &arg2Type](fir::FirOpBuilder &builder,
606                                            mlir::func::FuncOp &funcOp) {
607             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
608           };
609
610           // Suffix the function name with the element types
611           // of the arguments.
612           std::string typedFuncName(funcName);
613           llvm::raw_string_ostream nameOS(typedFuncName);
614           nameOS << "_";
615           arg1Type->print(nameOS);
616           nameOS << "_";
617           arg2Type->print(nameOS);
618
619           mlir::func::FuncOp newFunc = getOrCreateFunction(
620               builder, typedFuncName, typeGenerator, bodyGenerator);
621           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
622                                                      mlir::ValueRange{v1, v2});
623           call->replaceAllUsesWith(newCall.getResults());
624           call->dropAllReferences();
625           call->erase();
626
627           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
628                      llvm::dbgs() << "\n");
629           return;
630         }
631         if (funcName.startswith(RTNAME_STRING(Maxval))) {
632           simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
633           return;
634         }
635       }
636     }
637   });
638   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
639 }
640
641 void SimplifyIntrinsicsPass::getDependentDialects(
642     mlir::DialectRegistry &registry) const {
643   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
644   registry.insert<mlir::LLVM::LLVMDialect>();
645 }
646 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
647   return std::make_unique<SimplifyIntrinsicsPass>();
648 }