[flang] Support multidimensional reductions in SimplifyIntrinsicsPass.
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "flang/Optimizer/Builder/BoxValue.h"
26 #include "flang/Optimizer/Builder/FIRBuilder.h"
27 #include "flang/Optimizer/Builder/Todo.h"
28 #include "flang/Optimizer/Dialect/FIROps.h"
29 #include "flang/Optimizer/Dialect/FIRType.h"
30 #include "flang/Optimizer/Support/FIRContext.h"
31 #include "flang/Optimizer/Transforms/Passes.h"
32 #include "flang/Runtime/entry-names.h"
33 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Pass/Pass.h"
37 #include "mlir/Transforms/DialectConversion.h"
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
39 #include "mlir/Transforms/RegionUtils.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43
44 namespace fir {
45 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
46 #include "flang/Optimizer/Transforms/Passes.h.inc"
47 } // namespace fir
48
49 #define RTNAME_STRINGIFY2(x) #x
50 #define RTNAME_STRINGIFY(x) RTNAME_STRINGIFY2(x)
51 #define RTNAME_STRING(x) RTNAME_STRINGIFY(RTNAME(x))
52
53 #define DEBUG_TYPE "flang-simplify-intrinsics"
54
55 namespace {
56
57 class SimplifyIntrinsicsPass
58     : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
59   using FunctionTypeGeneratorTy =
60       llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
61   using FunctionBodyGeneratorTy =
62       llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
63   using GenReductionBodyTy = llvm::function_ref<void(
64       fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
65
66 public:
67   /// Generate a new function implementing a simplified version
68   /// of a Fortran runtime function defined by \p basename name.
69   /// \p typeGenerator is a callback that generates the new function's type.
70   /// \p bodyGenerator is a callback that generates the new function's body.
71   /// The new function is created in the \p builder's Module.
72   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
73                                          const mlir::StringRef &basename,
74                                          FunctionTypeGeneratorTy typeGenerator,
75                                          FunctionBodyGeneratorTy bodyGenerator);
76   void runOnOperation() override;
77   void getDependentDialects(mlir::DialectRegistry &registry) const override;
78
79 private:
80   /// Helper function to replace a reduction type of call with its
81   /// simplified form. The actual function is generated using a callback
82   /// function.
83   /// \p call is the call to be replaced
84   /// \p kindMap is used to create FIROpBuilder
85   /// \p genBodyFunc is the callback that builds the replacement function
86   void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
87                          GenReductionBodyTy genBodyFunc);
88 };
89
90 } // namespace
91
92 /// Generate function type for the simplified version of RTNAME(Sum) and
93 /// similar functions with a fir.box<none> type returning \p elementType.
94 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
95                                          const mlir::Type &elementType) {
96   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
97   return mlir::FunctionType::get(builder.getContext(), {boxType},
98                                  {elementType});
99 }
100
101 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
102     fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
103     mlir::Value)>;
104 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
105     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
106
107 /// Generate the reduction loop into \p funcOp.
108 ///
109 /// \p initVal is a function, called to get the initial value for
110 ///    the reduction value
111 /// \p genBody is called to fill in the actual reduciton operation
112 ///    for example add for SUM, MAX for MAXVAL, etc.
113 /// \p rank is the rank of the input argument.
114 static void genReductionLoop(fir::FirOpBuilder &builder,
115                              mlir::func::FuncOp &funcOp,
116                              InitValGeneratorTy initVal,
117                              BodyOpGeneratorTy genBody, unsigned rank) {
118   auto loc = mlir::UnknownLoc::get(builder.getContext());
119   mlir::Type elementType = funcOp.getResultTypes()[0];
120   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
121
122   mlir::IndexType idxTy = builder.getIndexType();
123
124   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
125   mlir::Value arg = args[0];
126
127   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
128
129   fir::SequenceType::Shape flatShape(rank,
130                                      fir::SequenceType::getUnknownExtent());
131   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
132   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
133   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
134   mlir::Value init = initVal(builder, loc, elementType);
135
136   llvm::SmallVector<mlir::Value, 15> bounds;
137
138   assert(rank > 0 && "rank cannot be zero");
139   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
140
141   // Compute all the upper bounds before the loop nest.
142   // It is not strictly necessary for performance, since the loop nest
143   // does not have any store operations and any LICM optimization
144   // should be able to optimize the redundancy.
145   for (unsigned i = 0; i < rank; ++i) {
146     mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
147     auto dims =
148         builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
149     mlir::Value len = dims.getResult(1);
150     // We use C indexing here, so len-1 as loopcount
151     mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
152     bounds.push_back(loopCount);
153   }
154
155   // Create a loop nest consisting of DoLoopOp operations.
156   // Collect the loops' induction variables into indices array,
157   // which will be used in the innermost loop to load the input
158   // array's element.
159   // The loops are generated such that the innermost loop processes
160   // the 0 dimension.
161   llvm::SmallVector<mlir::Value, 15> indices;
162   for (unsigned i = rank; 0 < i; --i) {
163     mlir::Value step = one;
164     mlir::Value loopCount = bounds[i - 1];
165     auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
166                                               /*unordered=*/false,
167                                               /*finalCountValue=*/false, init);
168     init = loop.getRegionIterArgs()[0];
169     indices.push_back(loop.getInductionVar());
170     // Set insertion point to the loop body so that the next loop
171     // is inserted inside the current one.
172     builder.setInsertionPointToStart(loop.getBody());
173   }
174
175   // Reverse the indices such that they are ordered as:
176   //   <dim-0-idx, dim-1-idx, ...>
177   std::reverse(indices.begin(), indices.end());
178
179   // We are in the innermost loop: generate the reduction body.
180   mlir::Type eleRefTy = builder.getRefType(elementType);
181   mlir::Value addr =
182       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
183   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
184
185   mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
186
187   // Unwind the loop nest and insert ResultOp on each level
188   // to return the updated value of the reduction to the enclosing
189   // loops.
190   for (unsigned i = 0; i < rank; ++i) {
191     auto result = builder.create<fir::ResultOp>(loc, reductionVal);
192     // Proceed to the outer loop.
193     auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
194     reductionVal = loop.getResult(0);
195     // Set insertion point after the loop operation that we have
196     // just processed.
197     builder.setInsertionPointAfter(loop.getOperation());
198   }
199
200   // End of loop nest. The insertion point is after the outermost loop.
201   // Return the reduction value from the function.
202   builder.create<mlir::func::ReturnOp>(loc, reductionVal);
203 }
204
205 /// Generate function body of the simplified version of RTNAME(Sum)
206 /// with signature provided by \p funcOp. The caller is responsible
207 /// for saving/restoring the original insertion point of \p builder.
208 /// \p funcOp is expected to be empty on entry to this function.
209 /// \p rank specifies the rank of the input argument.
210 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
211                               mlir::func::FuncOp &funcOp, unsigned rank) {
212   // function RTNAME(Sum)<T>x<rank>_simplified(arr)
213   //   T, dimension(:) :: arr
214   //   T sum = 0
215   //   integer iter
216   //   do iter = 0, extent(arr)
217   //     sum = sum + arr[iter]
218   //   end do
219   //   RTNAME(Sum)<T>x<rank>_simplified = sum
220   // end function RTNAME(Sum)<T>x<rank>_simplified
221   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
222                  mlir::Type elementType) {
223     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
224       const llvm::fltSemantics &sem = ty.getFloatSemantics();
225       return builder.createRealConstant(loc, elementType,
226                                         llvm::APFloat::getZero(sem));
227     }
228     return builder.createIntegerConstant(loc, elementType, 0);
229   };
230
231   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
232                       mlir::Type elementType, mlir::Value elem1,
233                       mlir::Value elem2) -> mlir::Value {
234     if (elementType.isa<mlir::FloatType>())
235       return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
236     if (elementType.isa<mlir::IntegerType>())
237       return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
238
239     llvm_unreachable("unsupported type");
240     return {};
241   };
242
243   genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
244 }
245
246 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
247                                  mlir::func::FuncOp &funcOp, unsigned rank) {
248   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
249                  mlir::Type elementType) {
250     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
251       const llvm::fltSemantics &sem = ty.getFloatSemantics();
252       return builder.createRealConstant(
253           loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
254     }
255     unsigned bits = elementType.getIntOrFloatBitWidth();
256     int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
257     return builder.createIntegerConstant(loc, elementType, minInt);
258   };
259
260   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
261                       mlir::Type elementType, mlir::Value elem1,
262                       mlir::Value elem2) -> mlir::Value {
263     if (elementType.isa<mlir::FloatType>())
264       return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
265     if (elementType.isa<mlir::IntegerType>())
266       return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
267
268     llvm_unreachable("unsupported type");
269     return {};
270   };
271   genReductionLoop(builder, funcOp, init, genBodyOp, rank);
272 }
273
274 /// Generate function type for the simplified version of RTNAME(DotProduct)
275 /// operating on the given \p elementType.
276 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
277                                             const mlir::Type &elementType) {
278   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
279   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
280                                  {elementType});
281 }
282
283 /// Generate function body of the simplified version of RTNAME(DotProduct)
284 /// with signature provided by \p funcOp. The caller is responsible
285 /// for saving/restoring the original insertion point of \p builder.
286 /// \p funcOp is expected to be empty on entry to this function.
287 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
288 /// of the underlying array objects - they are used to generate proper
289 /// element accesses.
290 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
291                               mlir::func::FuncOp &funcOp,
292                               mlir::Type arg1ElementTy,
293                               mlir::Type arg2ElementTy) {
294   // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
295   //   T, dimension(:) :: arr1, arr2
296   //   T product = 0
297   //   integer iter
298   //   do iter = 0, extent(arr1)
299   //     product = product + arr1[iter] * arr2[iter]
300   //   end do
301   //   RTNAME(ADotProduct)<T>_simplified = product
302   // end function RTNAME(DotProduct)<T>_simplified
303   auto loc = mlir::UnknownLoc::get(builder.getContext());
304   mlir::Type resultElementType = funcOp.getResultTypes()[0];
305   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
306
307   mlir::IndexType idxTy = builder.getIndexType();
308
309   mlir::Value zero =
310       resultElementType.isa<mlir::FloatType>()
311           ? builder.createRealConstant(loc, resultElementType, 0.0)
312           : builder.createIntegerConstant(loc, resultElementType, 0);
313
314   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
315   mlir::Value arg1 = args[0];
316   mlir::Value arg2 = args[1];
317
318   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
319
320   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
321   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
322   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
323   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
324   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
325   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
326   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
327   // This version takes the loop trip count from the first argument.
328   // If the first argument's box has unknown (at compilation time)
329   // extent, then it may be better to take the extent from the second
330   // argument - so that after inlining the loop may be better optimized, e.g.
331   // fully unrolled. This requires generating two versions of the simplified
332   // function and some analysis at the call site to choose which version
333   // is more profitable to call.
334   // Note that we can assume that both arguments have the same extent.
335   auto dims =
336       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
337   mlir::Value len = dims.getResult(1);
338   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
339   mlir::Value step = one;
340
341   // We use C indexing here, so len-1 as loopcount
342   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
343   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
344                                             /*unordered=*/false,
345                                             /*finalCountValue=*/false, zero);
346   mlir::Value sumVal = loop.getRegionIterArgs()[0];
347
348   // Begin loop code
349   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
350   builder.setInsertionPointToStart(loop.getBody());
351
352   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
353   mlir::Value index = loop.getInductionVar();
354   mlir::Value addr1 =
355       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
356   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
357   // Convert to the result type.
358   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
359
360   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
361   mlir::Value addr2 =
362       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
363   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
364   // Convert to the result type.
365   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
366
367   if (resultElementType.isa<mlir::FloatType>())
368     sumVal = builder.create<mlir::arith::AddFOp>(
369         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
370   else if (resultElementType.isa<mlir::IntegerType>())
371     sumVal = builder.create<mlir::arith::AddIOp>(
372         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
373   else
374     llvm_unreachable("unsupported type");
375
376   builder.create<fir::ResultOp>(loc, sumVal);
377   // End of loop.
378   builder.restoreInsertionPoint(loopEndPt);
379
380   mlir::Value resultVal = loop.getResult(0);
381   builder.create<mlir::func::ReturnOp>(loc, resultVal);
382 }
383
384 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
385     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
386     FunctionTypeGeneratorTy typeGenerator,
387     FunctionBodyGeneratorTy bodyGenerator) {
388   // WARNING: if the function generated here changes its signature
389   //          or behavior (the body code), we should probably embed some
390   //          versioning information into its name, otherwise libraries
391   //          statically linked with older versions of Flang may stop
392   //          working with object files created with newer Flang.
393   //          We can also avoid this by using internal linkage, but
394   //          this may increase the size of final executable/shared library.
395   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
396   mlir::ModuleOp module = builder.getModule();
397   // If we already have a function, just return it.
398   mlir::func::FuncOp newFunc =
399       fir::FirOpBuilder::getNamedFunction(module, replacementName);
400   mlir::FunctionType fType = typeGenerator(builder);
401   if (newFunc) {
402     assert(newFunc.getFunctionType() == fType &&
403            "type mismatch for simplified function");
404     return newFunc;
405   }
406
407   // Need to build the function!
408   auto loc = mlir::UnknownLoc::get(builder.getContext());
409   newFunc =
410       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
411   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
412   auto linkage =
413       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
414   newFunc->setAttr("llvm.linkage", linkage);
415
416   // Save the position of the original call.
417   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
418
419   bodyGenerator(builder, newFunc);
420
421   // Now back to where we were adding code earlier...
422   builder.restoreInsertionPoint(insertPt);
423
424   return newFunc;
425 }
426
427 fir::ConvertOp expectConvertOp(mlir::Value val) {
428   if (fir::ConvertOp op =
429           mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
430     return op;
431   LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
432   return nullptr;
433 }
434
435 static bool isOperandAbsent(mlir::Value val) {
436   if (auto op = expectConvertOp(val)) {
437     assert(op->getOperands().size() != 0);
438     return mlir::isa_and_nonnull<fir::AbsentOp>(
439         op->getOperand(0).getDefiningOp());
440   }
441   return false;
442 }
443
444 static bool isZero(mlir::Value val) {
445   if (auto op = expectConvertOp(val)) {
446     assert(op->getOperands().size() != 0);
447     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
448       return mlir::matchPattern(defOp, mlir::m_Zero());
449   }
450   return false;
451 }
452
453 static mlir::Value findBoxDef(mlir::Value val) {
454   if (auto op = expectConvertOp(val)) {
455     assert(op->getOperands().size() != 0);
456     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
457             op->getOperand(0).getDefiningOp()))
458       return box.getResult();
459     if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
460             op->getOperand(0).getDefiningOp()))
461       return box.getResult();
462   }
463   return {};
464 }
465
466 static unsigned getDimCount(mlir::Value val) {
467   // In order to find the dimensions count, we look for EmboxOp/ReboxOp
468   // and take the count from its *result* type. Note that in case
469   // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
470   // have different types.
471   // Actually, we can take the box type from the operand of
472   // the first ConvertOp that has non-opaque box type that we meet
473   // going through the ConvertOp chain.
474   if (mlir::Value emboxVal = findBoxDef(val))
475     if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
476       if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
477         return seqTy.getDimension();
478   return 0;
479 }
480
481 /// Given the call operation's box argument \p val, discover
482 /// the element type of the underlying array object.
483 /// \returns the element type or llvm::None if the type cannot
484 /// be reliably found.
485 /// We expect that the argument is a result of fir.convert
486 /// with the destination type of !fir.box<none>.
487 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
488   mlir::Operation *defOp;
489   do {
490     defOp = val.getDefiningOp();
491     // Analyze only sequences of convert operations.
492     if (!mlir::isa<fir::ConvertOp>(defOp))
493       return llvm::None;
494     val = defOp->getOperand(0);
495     // The convert operation is expected to convert from one
496     // box type to another box type.
497     auto boxType = val.getType().cast<fir::BoxType>();
498     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
499     if (!elementType.isa<mlir::NoneType>())
500       return elementType;
501   } while (true);
502 }
503
504 void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
505                                                const fir::KindMapping &kindMap,
506                                                GenReductionBodyTy genBodyFunc) {
507   mlir::SymbolRefAttr callee = call.getCalleeAttr();
508   mlir::Operation::operand_range args = call.getArgs();
509   // args[1] and args[2] are source filename and line number, ignored.
510   const mlir::Value &dim = args[3];
511   const mlir::Value &mask = args[4];
512   // dim is zero when it is absent, which is an implementation
513   // detail in the runtime library.
514   bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
515   unsigned rank = getDimCount(args[0]);
516   if (dimAndMaskAbsent && rank > 0) {
517     mlir::Location loc = call.getLoc();
518     fir::FirOpBuilder builder(call, kindMap);
519
520     // Support only floating point and integer results now.
521     mlir::Type resultType = call.getResult(0).getType();
522     if (!resultType.isa<mlir::FloatType>() &&
523         !resultType.isa<mlir::IntegerType>())
524       return;
525
526     auto argType = getArgElementType(args[0]);
527     if (!argType)
528       return;
529     assert(*argType == resultType &&
530            "Argument/result types mismatch in reduction");
531
532     auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
533       return genNoneBoxType(builder, resultType);
534     };
535     auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
536                                                mlir::func::FuncOp &funcOp) {
537       genBodyFunc(builder, funcOp, rank);
538     };
539     // Mangle the function name with the rank value as "x<rank>".
540     std::string funcName =
541         (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
542          mlir::Twine{rank})
543             .str();
544     mlir::func::FuncOp newFunc =
545         getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
546     auto newCall =
547         builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
548     call->replaceAllUsesWith(newCall.getResults());
549     call->dropAllReferences();
550     call->erase();
551   }
552 }
553
554 void SimplifyIntrinsicsPass::runOnOperation() {
555   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
556   mlir::ModuleOp module = getOperation();
557   fir::KindMapping kindMap = fir::getKindMapping(module);
558   module.walk([&](mlir::Operation *op) {
559     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
560       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
561         mlir::StringRef funcName = callee.getLeafReference().getValue();
562         // Replace call to runtime function for SUM when it has single
563         // argument (no dim or mask argument) for 1D arrays with either
564         // Integer4 or Real8 types. Other forms are ignored.
565         // The new function is added to the module.
566         //
567         // Prototype for runtime call (from sum.cpp):
568         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
569         //                int dim, const Descriptor *mask)
570         //
571         if (funcName.startswith(RTNAME_STRING(Sum))) {
572           simplifyReduction(call, kindMap, genRuntimeSumBody);
573           return;
574         }
575         if (funcName.startswith(RTNAME_STRING(DotProduct))) {
576           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
577           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
578                      llvm::dbgs() << "\n");
579           mlir::Operation::operand_range args = call.getArgs();
580           const mlir::Value &v1 = args[0];
581           const mlir::Value &v2 = args[1];
582           mlir::Location loc = call.getLoc();
583           fir::FirOpBuilder builder(op, kindMap);
584
585           mlir::Type type = call.getResult(0).getType();
586           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
587             return;
588
589           // Try to find the element types of the boxed arguments.
590           auto arg1Type = getArgElementType(v1);
591           auto arg2Type = getArgElementType(v2);
592
593           if (!arg1Type || !arg2Type)
594             return;
595
596           // Support only floating point and integer arguments
597           // now (e.g. logical is skipped here).
598           if (!arg1Type->isa<mlir::FloatType>() &&
599               !arg1Type->isa<mlir::IntegerType>())
600             return;
601           if (!arg2Type->isa<mlir::FloatType>() &&
602               !arg2Type->isa<mlir::IntegerType>())
603             return;
604
605           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
606             return genRuntimeDotType(builder, type);
607           };
608           auto bodyGenerator = [&arg1Type,
609                                 &arg2Type](fir::FirOpBuilder &builder,
610                                            mlir::func::FuncOp &funcOp) {
611             genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
612           };
613
614           // Suffix the function name with the element types
615           // of the arguments.
616           std::string typedFuncName(funcName);
617           llvm::raw_string_ostream nameOS(typedFuncName);
618           nameOS << "_";
619           arg1Type->print(nameOS);
620           nameOS << "_";
621           arg2Type->print(nameOS);
622
623           mlir::func::FuncOp newFunc = getOrCreateFunction(
624               builder, typedFuncName, typeGenerator, bodyGenerator);
625           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
626                                                      mlir::ValueRange{v1, v2});
627           call->replaceAllUsesWith(newCall.getResults());
628           call->dropAllReferences();
629           call->erase();
630
631           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
632                      llvm::dbgs() << "\n");
633           return;
634         }
635         if (funcName.startswith(RTNAME_STRING(Maxval))) {
636           simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
637           return;
638         }
639       }
640     }
641   });
642   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
643 }
644
645 void SimplifyIntrinsicsPass::getDependentDialects(
646     mlir::DialectRegistry &registry) const {
647   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
648   registry.insert<mlir::LLVM::LLVMDialect>();
649 }
650 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
651   return std::make_unique<SimplifyIntrinsicsPass>();
652 }