1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
11 /// This pass looks for calls to the runtime library for intrinsics that can be
12 /// simplified/specialized and replaces them with calls to specialized functions.
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 /// for higher dimension arrays, of course).
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "flang/Optimizer/Builder/BoxValue.h"
26 #include "flang/Optimizer/Builder/FIRBuilder.h"
27 #include "flang/Optimizer/Builder/Todo.h"
28 #include "flang/Optimizer/Dialect/FIROps.h"
29 #include "flang/Optimizer/Dialect/FIRType.h"
30 #include "flang/Optimizer/Support/FIRContext.h"
31 #include "flang/Optimizer/Transforms/Passes.h"
32 #include "flang/Runtime/entry-names.h"
33 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
34 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/TypeUtilities.h"
36 #include "mlir/Pass/Pass.h"
37 #include "mlir/Transforms/DialectConversion.h"
38 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
39 #include "mlir/Transforms/RegionUtils.h"
40 #include "llvm/ADT/Optional.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
45 #define GEN_PASS_DEF_SIMPLIFYINTRINSICS
46 #include "flang/Optimizer/Transforms/Passes.h.inc"
49 #define DEBUG_TYPE "flang-simplify-intrinsics"
/// Pass that walks the module looking for calls to selected Fortran runtime
/// entry points (Sum, DotProduct and Maxval are matched in runOnOperation) and
/// replaces them with calls to simplified functions generated into the module.
53 class SimplifyIntrinsicsPass
54 : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
/// Callback that builds the type of a replacement function.
55 using FunctionTypeGeneratorTy =
56 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
/// Callback that fills in the body of an (empty) replacement function.
57 using FunctionBodyGeneratorTy =
58 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
/// Callback that fills in the body of a reduction function for an argument
/// of the given rank.
59 using GenReductionBodyTy = llvm::function_ref<void(
60 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
// NOTE(review): the numbering gap before the next line suggests dropped
// line(s) here (possibly an access specifier) - confirm against the file.
63 /// Generate a new function implementing a simplified version
64 /// of a Fortran runtime function defined by \p basename name.
65 /// \p typeGenerator is a callback that generates the new function's type.
66 /// \p bodyGenerator is a callback that generates the new function's body.
67 /// The new function is created in the \p builder's Module.
/// The generated function is named "<basename>_simplified".
68 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
69 const mlir::StringRef &basename,
70 FunctionTypeGeneratorTy typeGenerator,
71 FunctionBodyGeneratorTy bodyGenerator);
/// Pass entry point: visits every operation of the module.
72 void runOnOperation() override;
/// Registers the dialects this pass needs loaded.
// NOTE(review): the parameter name below looks mis-encoded ("®istry"); the
// out-of-line definition's body refers to `registry`, so `&registry` is
// presumably intended - confirm and repair.
73 void getDependentDialects(mlir::DialectRegistry ®istry) const override;
76 /// Helper function to replace a reduction type of call with its
77 /// simplified form. The actual function is generated using a callback.
79 /// \p call is the call to be replaced
80 /// \p kindMap is used to create FIROpBuilder
81 /// \p genBodyFunc is the callback that builds the replacement function
82 void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
83 GenReductionBodyTy genBodyFunc);
// NOTE(review): the class's closing "};" is not visible in this listing.
88 /// Generate function type for the simplified version of RTNAME(Sum) and
89 /// similar functions with a fir.box<none> type returning \p elementType.
/// \p builder supplies the MLIR context and the none type wrapped by the box.
/// The single input of the generated type is that fir.box<none>.
90 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
91 const mlir::Type &elementType) {
92 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
93 return mlir::FunctionType::get(builder.getContext(), {boxType},
// NOTE(review): the rest of this call (the result-type list) and the closing
// brace are not visible in this listing.
/// Callback that emits one reduction step. genReductionLoop calls it with the
/// builder, a location, the element type, the loaded element and the current
/// reduction value, and uses the returned value as the updated reduction.
97 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
98 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
// NOTE(review): the end of this parameter list is not visible in this listing.
/// Callback that emits the initial value of a reduction for the given
/// element type.
100 using InitValGeneratorTy = llvm::function_ref<mlir::Value(
101 fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
103 /// Generate the reduction loop into \p funcOp.
/// \p builder creates all operations; its insertion point is moved into the
/// entry block added to \p funcOp, so the caller must save/restore it.
/// The element type is taken from the first result type of \p funcOp.
105 /// \p initVal is a function, called to get the initial value for
106 /// the reduction value
107 /// \p genBody is called to fill in the actual reduction operation
108 /// for example add for SUM, MAX for MAXVAL, etc.
109 /// \p rank is the rank of the input argument.
110 static void genReductionLoop(fir::FirOpBuilder &builder,
111 mlir::func::FuncOp &funcOp,
112 InitValGeneratorTy initVal,
113 BodyOpGeneratorTy genBody, unsigned rank) {
114 auto loc = mlir::UnknownLoc::get(builder.getContext());
115 mlir::Type elementType = funcOp.getResultTypes()[0];
116 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
118 mlir::IndexType idxTy = builder.getIndexType();
120 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
121 mlir::Value arg = args[0];
123 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
// Convert the incoming box argument to a box of a rank-`rank` array of
// elementType with unknown extents, so that it can be indexed below.
125 fir::SequenceType::Shape flatShape(rank,
126 fir::SequenceType::getUnknownExtent());
127 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
128 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
129 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
130 mlir::Value init = initVal(builder, loc, elementType);
132 llvm::SmallVector<mlir::Value, 15> bounds;
134 assert(rank > 0 && "rank cannot be zero");
135 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
137 // Compute all the upper bounds before the loop nest.
138 // It is not strictly necessary for performance, since the loop nest
139 // does not have any store operations and any LICM optimization
140 // should be able to optimize the redundancy.
141 for (unsigned i = 0; i < rank; ++i) {
142 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
// NOTE(review): the declaration binding this BoxDimsOp to `dims` (used on the
// following line) is not visible in this listing.
144 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
145 mlir::Value len = dims.getResult(1);
146 // We use C indexing here, so len-1 as loopcount
147 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
148 bounds.push_back(loopCount);
// NOTE(review): the closing brace of the loop above is not visible here.
151 // Create a loop nest consisting of DoLoopOp operations.
152 // Collect the loops' induction variables into indices array,
153 // which will be used in the innermost loop to load the input element.
155 // The loops are generated such that the innermost loop processes dimension 0.
157 llvm::SmallVector<mlir::Value, 15> indices;
// Iterate from the last dimension down to the first: the first loop created
// (outermost) uses bounds[rank - 1], the last one (innermost) uses bounds[0].
158 for (unsigned i = rank; 0 < i; --i) {
159 mlir::Value step = one;
160 mlir::Value loopCount = bounds[i - 1];
// NOTE(review): the numbering gap inside this call suggests an argument line
// was dropped from the listing - confirm against the file.
161 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
163 /*finalCountValue=*/false, init);
// The loop-carried value of this loop seeds the next (inner) loop.
164 init = loop.getRegionIterArgs()[0];
165 indices.push_back(loop.getInductionVar());
166 // Set insertion point to the loop body so that the next loop
167 // is inserted inside the current one.
168 builder.setInsertionPointToStart(loop.getBody());
171 // Reverse the indices such that they are ordered as:
172 // <dim-0-idx, dim-1-idx, ...>
173 std::reverse(indices.begin(), indices.end());
175 // We are in the innermost loop: generate the reduction body.
176 mlir::Type eleRefTy = builder.getRefType(elementType);
// NOTE(review): the declaration binding this CoordinateOp to `addr` (loaded
// on the following line) is not visible in this listing.
178 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
179 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
181 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
183 // Unwind the loop nest and insert ResultOp on each level
184 // to return the updated value of the reduction to the enclosing loop.
186 for (unsigned i = 0; i < rank; ++i) {
187 auto result = builder.create<fir::ResultOp>(loc, reductionVal);
188 // Proceed to the outer loop.
189 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
190 reductionVal = loop.getResult(0);
191 // Set insertion point after the loop operation that we have just processed.
193 builder.setInsertionPointAfter(loop.getOperation());
196 // End of loop nest. The insertion point is after the outermost loop.
197 // Return the reduction value from the function.
198 builder.create<mlir::func::ReturnOp>(loc, reductionVal);
201 /// Generate function body of the simplified version of RTNAME(Sum)
202 /// with signature provided by \p funcOp. The caller is responsible
203 /// for saving/restoring the original insertion point of \p builder.
204 /// \p funcOp is expected to be empty on entry to this function.
205 /// \p rank specifies the rank of the input argument.
/// The loop nest itself is produced by genReductionLoop; this function only
/// supplies the initial value (zero) and the per-element operation (add).
206 static void genRuntimeSumBody(fir::FirOpBuilder &builder,
207 mlir::func::FuncOp &funcOp, unsigned rank) {
208 // function RTNAME(Sum)<T>x<rank>_simplified(arr)
209 // T, dimension(:) :: arr
212 // do iter = 0, extent(arr) - 1
213 // sum = sum + arr[iter]
215 // RTNAME(Sum)<T>x<rank>_simplified = sum
216 // end function RTNAME(Sum)<T>x<rank>_simplified
// Initial value: a zero constant of the element type (floating point or
// integer).
// NOTE(review): both lambdas below take fir::FirOpBuilder by value, while the
// InitValGeneratorTy/BodyOpGeneratorTy signatures pass it by reference; this
// copies the builder on every call - presumably unintended.
217 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
218 mlir::Type elementType) {
219 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
220 const llvm::fltSemantics &sem = ty.getFloatSemantics();
221 return builder.createRealConstant(loc, elementType,
222 llvm::APFloat::getZero(sem));
224 return builder.createIntegerConstant(loc, elementType, 0);
// Reduction step: floating point or integer addition of the two operands.
227 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
228 mlir::Type elementType, mlir::Value elem1,
229 mlir::Value elem2) -> mlir::Value {
230 if (elementType.isa<mlir::FloatType>())
231 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
232 if (elementType.isa<mlir::IntegerType>())
233 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
// Any other element type is not expected to reach this point.
235 llvm_unreachable("unsupported type");
239 genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
/// Generate function body of the simplified version of RTNAME(Maxval)
/// with signature provided by \p funcOp.
/// \p builder is used to create the operations; \p rank is the rank of the
/// input argument. The loop nest is produced by genReductionLoop; this
/// function supplies the initial value and the per-element max operation.
242 static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
243 mlir::func::FuncOp &funcOp, unsigned rank) {
// Initial value: the largest-magnitude negative finite value for floating
// point element types, the minimum signed value otherwise.
244 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
245 mlir::Type elementType) {
246 if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
247 const llvm::fltSemantics &sem = ty.getFloatSemantics();
248 return builder.createRealConstant(
249 loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/true));
251 unsigned bits = elementType.getIntOrFloatBitWidth();
// NOTE(review): the minimum value is narrowed to int64_t here; for integer
// types wider than 64 bits getSExtValue() presumably cannot represent it -
// confirm how such element types are handled.
252 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
253 return builder.createIntegerConstant(loc, elementType, minInt);
// Reduction step: floating point max or signed integer max of the operands.
256 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
257 mlir::Type elementType, mlir::Value elem1,
258 mlir::Value elem2) -> mlir::Value {
259 if (elementType.isa<mlir::FloatType>())
260 return builder.create<mlir::arith::MaxFOp>(loc, elem1, elem2);
261 if (elementType.isa<mlir::IntegerType>())
262 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
// Any other element type is not expected to reach this point.
264 llvm_unreachable("unsupported type");
267 genReductionLoop(builder, funcOp, init, genBodyOp, rank);
270 /// Generate function type for the simplified version of RTNAME(DotProduct)
271 /// operating on the given \p elementType.
/// \p builder supplies the MLIR context and the none type wrapped by the box.
/// Both inputs of the generated type are fir.box<none>.
272 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
273 const mlir::Type &elementType) {
274 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
275 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
// NOTE(review): the rest of this call (the result-type list) and the closing
// brace are not visible in this listing.
279 /// Generate function body of the simplified version of RTNAME(DotProduct)
280 /// with signature provided by \p funcOp. The caller is responsible
281 /// for saving/restoring the original insertion point of \p builder.
282 /// \p funcOp is expected to be empty on entry to this function.
283 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
284 /// of the underlying array objects - they are used to generate proper
285 /// element accesses.
/// Each loaded element is converted to the function's result type before
/// the multiply/add, so the two arrays may have different element types.
286 static void genRuntimeDotBody(fir::FirOpBuilder &builder,
287 mlir::func::FuncOp &funcOp,
288 mlir::Type arg1ElementTy,
289 mlir::Type arg2ElementTy) {
290 // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
291 // T, dimension(:) :: arr1, arr2
294 // do iter = 0, extent(arr1) - 1
295 // product = product + arr1[iter] * arr2[iter]
297 // RTNAME(DotProduct)<T>_simplified = product
298 // end function RTNAME(DotProduct)<T>_simplified
299 auto loc = mlir::UnknownLoc::get(builder.getContext());
300 mlir::Type resultElementType = funcOp.getResultTypes()[0];
301 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
303 mlir::IndexType idxTy = builder.getIndexType();
// Zero constant of the result type, used as the initial sum.
// NOTE(review): the declaration binding this expression to `zero` (passed to
// the DoLoopOp below) is not visible in this listing.
306 resultElementType.isa<mlir::FloatType>()
307 ? builder.createRealConstant(loc, resultElementType, 0.0)
308 : builder.createIntegerConstant(loc, resultElementType, 0);
310 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
311 mlir::Value arg1 = args[0];
312 mlir::Value arg2 = args[1];
314 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
// Convert both box arguments to boxes of rank-1 arrays with unknown extent
// of their respective element types.
316 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
317 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
318 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
319 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
320 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
321 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
322 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
323 // This version takes the loop trip count from the first argument.
324 // If the first argument's box has unknown (at compilation time)
325 // extent, then it may be better to take the extent from the second
326 // argument - so that after inlining the loop may be better optimized, e.g.
327 // fully unrolled. This requires generating two versions of the simplified
328 // function and some analysis at the call site to choose which version
329 // is more profitable to call.
330 // Note that we can assume that both arguments have the same extent.
// NOTE(review): the declaration binding this BoxDimsOp to `dims` (used on the
// following line) is not visible in this listing.
332 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
333 mlir::Value len = dims.getResult(1);
334 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
335 mlir::Value step = one;
337 // We use C indexing here, so len-1 as loopcount
338 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
// NOTE(review): the numbering gap inside this call suggests an argument line
// was dropped from the listing - confirm against the file.
339 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
341 /*finalCountValue=*/false, zero);
// The running sum is the loop-carried value of the loop.
342 mlir::Value sumVal = loop.getRegionIterArgs()[0];
// Remember the position after the loop, then emit the loop body.
345 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
346 builder.setInsertionPointToStart(loop.getBody());
348 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
349 mlir::Value index = loop.getInductionVar();
// NOTE(review): the declaration binding this CoordinateOp to `addr1` is not
// visible in this listing.
351 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
352 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
353 // Convert to the result type.
354 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
356 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
// NOTE(review): the declaration binding this CoordinateOp to `addr2` is not
// visible in this listing.
358 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
359 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
360 // Convert to the result type.
361 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
// sumVal = elem1 * elem2 + sumVal, using floating point or integer ops
// depending on the result type.
363 if (resultElementType.isa<mlir::FloatType>())
364 sumVal = builder.create<mlir::arith::AddFOp>(
365 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
366 else if (resultElementType.isa<mlir::IntegerType>())
367 sumVal = builder.create<mlir::arith::AddIOp>(
368 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
// NOTE(review): as listed, nothing guards the llvm_unreachable below; a line
// (likely a trailing `else`) appears to be missing from the listing.
370 llvm_unreachable("unsupported type");
// Yield the updated sum to the next iteration / as the loop result.
372 builder.create<fir::ResultOp>(loc, sumVal);
374 builder.restoreInsertionPoint(loopEndPt);
376 mlir::Value resultVal = loop.getResult(0);
377 builder.create<mlir::func::ReturnOp>(loc, resultVal);
/// Look up the function named "<baseName>_simplified" in \p builder's module
/// and, when it has to be built, create it with the type produced by
/// \p typeGenerator, mark it with LinkonceODR linkage and let
/// \p bodyGenerator fill in its body. The builder's insertion point is saved
/// before and restored after the body generation.
380 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
381 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
382 FunctionTypeGeneratorTy typeGenerator,
383 FunctionBodyGeneratorTy bodyGenerator) {
384 // WARNING: if the function generated here changes its signature
385 // or behavior (the body code), we should probably embed some
386 // versioning information into its name, otherwise libraries
387 // statically linked with older versions of Flang may stop
388 // working with object files created with newer Flang.
389 // We can also avoid this by using internal linkage, but
390 // this may increase the size of final executable/shared library.
391 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
392 mlir::ModuleOp module = builder.getModule();
393 // If we already have a function, just return it.
394 mlir::func::FuncOp newFunc =
395 fir::FirOpBuilder::getNamedFunction(module, replacementName);
// The type is generated even for an existing function so that it can be
// checked against that function's type.
396 mlir::FunctionType fType = typeGenerator(builder);
// NOTE(review): the guard around this assert and the early return for an
// already existing function are not visible in this listing.
398 assert(newFunc.getFunctionType() == fType &&
399 "type mismatch for simplified function");
403 // Need to build the function!
404 auto loc = mlir::UnknownLoc::get(builder.getContext());
// NOTE(review): the left-hand side receiving the created function is not
// visible in this listing (presumably an assignment to newFunc).
406 fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
407 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
// NOTE(review): the declaration of `linkage` (used below) is not visible in
// this listing.
409 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
410 newFunc->setAttr("llvm.linkage", linkage);
412 // Save the position of the original call.
413 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
415 bodyGenerator(builder, newFunc);
417 // Now back to where we were adding code earlier...
418 builder.restoreInsertionPoint(insertPt);
// NOTE(review): the final return statement and the closing brace are not
// visible in this listing.
/// Check whether \p val is defined by a fir::ConvertOp. dyn_cast_or_null is
/// used, so a value without a defining operation is handled as "not found",
/// in which case a debug message is printed.
// NOTE(review): both return statements of this function are not visible in
// this listing; callers test the result in a boolean context.
423 fir::ConvertOp expectConvertOp(mlir::Value val) {
424 if (fir::ConvertOp op =
425 mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
427 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
/// \returns true when \p val is the result of a fir::ConvertOp whose operand
/// is defined by a fir::AbsentOp.
// NOTE(review): the fall-through return (when no ConvertOp is found) is not
// visible in this listing.
431 static bool isOperandAbsent(mlir::Value val) {
432 if (auto op = expectConvertOp(val)) {
433 assert(op->getOperands().size() != 0);
// isa_and_nonnull tolerates an operand with no defining operation.
434 return mlir::isa_and_nonnull<fir::AbsentOp>(
435 op->getOperand(0).getDefiningOp());
/// \returns true when \p val is the result of a fir::ConvertOp whose operand
/// is defined by an operation matching the mlir::m_Zero() pattern.
// NOTE(review): the fall-through return (no ConvertOp, or an operand without
// a defining operation) is not visible in this listing.
440 static bool isZero(mlir::Value val) {
441 if (auto op = expectConvertOp(val)) {
442 assert(op->getOperands().size() != 0);
443 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
444 return mlir::matchPattern(defOp, mlir::m_Zero());
/// Given \p val produced by a fir::ConvertOp, return the result of the
/// fir::EmboxOp or fir::ReboxOp that defines the converted operand.
// NOTE(review): the fall-through return (no ConvertOp, or no embox/rebox
// found) is not visible in this listing; getDimCount tests the result in a
// boolean context.
449 static mlir::Value findBoxDef(mlir::Value val) {
450 if (auto op = expectConvertOp(val)) {
451 assert(op->getOperands().size() != 0);
452 if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
453 op->getOperand(0).getDefiningOp()))
454 return box.getResult();
455 if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
456 op->getOperand(0).getDefiningOp()))
457 return box.getResult();
/// \returns the number of dimensions of the array boxed by the
/// EmboxOp/ReboxOp found behind \p val (see findBoxDef), taken from the
/// sequence type inside the box result type.
// NOTE(review): the fall-through return is not visible in this listing;
// simplifyReduction only proceeds when the returned rank is greater than zero.
462 static unsigned getDimCount(mlir::Value val) {
463 // In order to find the dimensions count, we look for EmboxOp/ReboxOp
464 // and take the count from its *result* type. Note that in case
465 // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
466 // have different types.
467 // Actually, we can take the box type from the operand of
468 // the first ConvertOp that has non-opaque box type that we meet
469 // going through the ConvertOp chain.
470 if (mlir::Value emboxVal = findBoxDef(val))
471 if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
472 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
473 return seqTy.getDimension();
477 /// Given the call operation's box argument \p val, discover
478 /// the element type of the underlying array object.
479 /// \returns the element type or llvm::None if the type cannot
480 /// be reliably found.
481 /// We expect that the argument is a result of fir.convert
482 /// with the destination type of !fir.box<none>.
483 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
484 mlir::Operation *defOp;
// NOTE(review): the loop construct that repeats the steps below along the
// chain of converts, and both return statements, are not visible in this
// listing.
486 defOp = val.getDefiningOp();
487 // Analyze only sequences of convert operations.
// NOTE(review): defOp comes straight from getDefiningOp() and is not checked
// for null before isa<>; other helpers in this file use isa_and_nonnull /
// dyn_cast_or_null for the same situation - confirm a null cannot reach here.
488 if (!mlir::isa<fir::ConvertOp>(defOp))
// Step to the value being converted.
490 val = defOp->getOperand(0);
491 // The convert operation is expected to convert from one
492 // box type to another box type.
// cast<> (rather than dyn_cast<>) relies on that expectation holding.
493 auto boxType = val.getType().cast<fir::BoxType>();
494 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
// A none element type means this box is still opaque.
495 if (!elementType.isa<mlir::NoneType>())
/// Replace the reduction runtime call \p call with a call to a simplified
/// function whose body is built by \p genBodyFunc. The replacement is only
/// attempted when the dim argument is zero, the mask argument is absent, the
/// rank of the array argument is known (greater than zero) and the result is
/// of floating point or integer type. \p kindMap is used to create the
/// builder. The generated function's name is the callee name suffixed with
/// "x<rank>" (getOrCreateFunction further appends "_simplified").
500 void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
501 const fir::KindMapping &kindMap,
502 GenReductionBodyTy genBodyFunc) {
503 mlir::SymbolRefAttr callee = call.getCalleeAttr();
504 mlir::Operation::operand_range args = call.getArgs();
505 // args[1] and args[2] are source filename and line number, ignored.
506 const mlir::Value &dim = args[3];
507 const mlir::Value &mask = args[4];
508 // dim is zero when it is absent, which is an implementation
509 // detail in the runtime library.
510 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
511 unsigned rank = getDimCount(args[0]);
512 if (dimAndMaskAbsent && rank > 0) {
513 mlir::Location loc = call.getLoc();
514 fir::FirOpBuilder builder(call, kindMap);
516 // Support only floating point and integer results now.
517 mlir::Type resultType = call.getResult(0).getType();
// NOTE(review): the statement controlled by this condition is not visible in
// this listing (presumably an early return).
518 if (!resultType.isa<mlir::FloatType>() &&
519 !resultType.isa<mlir::IntegerType>())
522 auto argType = getArgElementType(args[0]);
// NOTE(review): argType is dereferenced below; the check for an empty
// Optional is not visible in this listing - confirm it exists in the file.
525 assert(*argType == resultType &&
526 "Argument/result types mismatch in reduction");
// Callbacks handed to getOrCreateFunction; they capture by reference and are
// only used within this scope.
528 auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
529 return genNoneBoxType(builder, resultType);
531 auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
532 mlir::func::FuncOp &funcOp) {
533 genBodyFunc(builder, funcOp, rank);
535 // Mangle the function name with the rank value as "x<rank>".
// NOTE(review): the remainder of this expression (the rank part and the
// conversion to std::string) is not visible in this listing.
536 std::string funcName =
537 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
540 mlir::func::FuncOp newFunc =
541 getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
// Only the array argument is passed to the simplified function.
// NOTE(review): the declaration of `newCall` is not visible in this listing.
543 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
544 call->replaceAllUsesWith(newCall.getResults());
545 call->dropAllReferences();
/// Pass entry point. Walks all operations of the module and, for every
/// fir::CallOp with a symbol callee, dispatches on the callee name prefix:
/// Sum and Maxval go through simplifyReduction, DotProduct is handled inline.
550 void SimplifyIntrinsicsPass::runOnOperation() {
551 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
552 mlir::ModuleOp module = getOperation();
553 fir::KindMapping kindMap = fir::getKindMapping(module);
554 module.walk([&](mlir::Operation *op) {
555 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
556 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
557 mlir::StringRef funcName = callee.getLeafReference().getValue();
558 // Replace call to runtime function for SUM when it has single
559 // argument (no dim or mask argument), the rank of the array is known
560 // and the result is of integer or real type. Other forms are ignored.
561 // The new function is added to the module.
563 // Prototype for runtime call (from sum.cpp):
564 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
565 // int dim, const Descriptor *mask)
567 if (funcName.startswith(RTNAME_STRING(Sum))) {
568 simplifyReduction(call, kindMap, genRuntimeSumBody);
// NOTE(review): the end of this branch is not visible in this listing.
571 if (funcName.startswith(RTNAME_STRING(DotProduct))) {
572 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
573 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
574 llvm::dbgs() << "\n");
575 mlir::Operation::operand_range args = call.getArgs();
576 const mlir::Value &v1 = args[0];
577 const mlir::Value &v2 = args[1];
578 mlir::Location loc = call.getLoc();
579 fir::FirOpBuilder builder(op, kindMap);
// Only floating point and integer results are handled.
// NOTE(review): the statements controlled by the type checks in this branch
// are not visible in this listing (presumably early returns).
581 mlir::Type type = call.getResult(0).getType();
582 if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
585 // Try to find the element types of the boxed arguments.
586 auto arg1Type = getArgElementType(v1);
587 auto arg2Type = getArgElementType(v2);
589 if (!arg1Type || !arg2Type)
592 // Support only floating point and integer arguments
593 // now (e.g. logical is skipped here).
594 if (!arg1Type->isa<mlir::FloatType>() &&
595 !arg1Type->isa<mlir::IntegerType>())
597 if (!arg2Type->isa<mlir::FloatType>() &&
598 !arg2Type->isa<mlir::IntegerType>())
// Callbacks handed to getOrCreateFunction; they capture by reference and are
// only used within this scope.
601 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
602 return genRuntimeDotType(builder, type);
604 auto bodyGenerator = [&arg1Type,
605 &arg2Type](fir::FirOpBuilder &builder,
606 mlir::func::FuncOp &funcOp) {
607 genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
610 // Suffix the function name with the element types of the arguments.
// The types are printed into typedFuncName through nameOS.
// NOTE(review): any separator written between the name and the two types is
// not visible in this listing.
612 std::string typedFuncName(funcName);
613 llvm::raw_string_ostream nameOS(typedFuncName);
615 arg1Type->print(nameOS);
617 arg2Type->print(nameOS);
619 mlir::func::FuncOp newFunc = getOrCreateFunction(
620 builder, typedFuncName, typeGenerator, bodyGenerator);
// Both box arguments are forwarded unchanged to the simplified function.
621 auto newCall = builder.create<fir::CallOp>(loc, newFunc,
622 mlir::ValueRange{v1, v2});
623 call->replaceAllUsesWith(newCall.getResults());
624 call->dropAllReferences();
627 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
628 llvm::dbgs() << "\n");
631 if (funcName.startswith(RTNAME_STRING(Maxval))) {
632 simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
// NOTE(review): the closing braces of the nested scopes and of the walk
// callback are not visible in this listing.
638 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
641 void SimplifyIntrinsicsPass::getDependentDialects(
642 mlir::DialectRegistry ®istry) const {
643 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
644 registry.insert<mlir::LLVM::LLVMDialect>();
/// Factory for the pass: returns a newly allocated SimplifyIntrinsicsPass
/// owned by the returned std::unique_ptr.
646 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
647 return std::make_unique<SimplifyIntrinsicsPass>();
// NOTE(review): the closing brace of this function is not visible in this
// listing.