1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
/// For example, when the argument is a 1D array, SUM(arr) can be specialized
/// as a simple function with one argument and one loop, compared to the three
/// arguments (plus file and line info) that the runtime call takes (multiple
/// loops may be needed for higher-dimension arrays, of course).
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/TypeUtilities.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
38 #include "mlir/Transforms/RegionUtils.h"
39 #include "llvm/ADT/Optional.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
43 #define DEBUG_TYPE "flang-simplify-intrinsics"
47 class SimplifyIntrinsicsPass
48 : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
49 using FunctionTypeGeneratorTy =
50 std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
51 using FunctionBodyGeneratorTy =
52 std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
55 /// Generate a new function implementing a simplified version
56 /// of a Fortran runtime function defined by \p basename name.
57 /// \p typeGenerator is a callback that generates the new function's type.
58 /// \p bodyGenerator is a callback that generates the new function's body.
59 /// The new function is created in the \p builder's Module.
60 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
61 const mlir::StringRef &basename,
62 FunctionTypeGeneratorTy typeGenerator,
63 FunctionBodyGeneratorTy bodyGenerator);
64 void runOnOperation() override;
65 void getDependentDialects(mlir::DialectRegistry ®istry) const override;
70 /// Generate function type for the simplified version of FortranASum
71 /// operating on the given \p elementType.
72 static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder,
73 const mlir::Type &elementType) {
74 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
75 return mlir::FunctionType::get(builder.getContext(), {boxType},
79 /// Generate function body of the simplified version of FortranASum
80 /// with signature provided by \p funcOp. The caller is responsible
81 /// for saving/restoring the original insertion point of \p builder.
82 /// \p funcOp is expected to be empty on entry to this function.
83 static void genFortranASumBody(fir::FirOpBuilder &builder,
84 mlir::func::FuncOp &funcOp) {
85 // function FortranASum<T>_simplified(arr)
86 // T, dimension(:) :: arr
89 // do iter = 0, extent(arr)
90 // sum = sum + arr[iter]
92 // FortranASum<T>_simplified = sum
93 // end function FortranASum<T>_simplified
94 auto loc = mlir::UnknownLoc::get(builder.getContext());
95 mlir::Type elementType = funcOp.getResultTypes()[0];
96 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
98 mlir::IndexType idxTy = builder.getIndexType();
100 mlir::Value zero = elementType.isa<mlir::FloatType>()
101 ? builder.createRealConstant(loc, elementType, 0.0)
102 : builder.createIntegerConstant(loc, elementType, 0);
103 mlir::Value sum = builder.create<fir::AllocaOp>(loc, elementType);
104 builder.create<fir::StoreOp>(loc, zero, sum);
106 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
107 mlir::Value arg = args[0];
109 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
111 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
112 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
113 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
114 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
116 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
117 mlir::Value len = dims.getResult(1);
118 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
119 mlir::Value step = one;
121 // We use C indexing here, so len-1 as loopcount
122 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
123 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
126 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
127 builder.setInsertionPointToStart(loop.getBody());
129 mlir::Type eleRefTy = builder.getRefType(elementType);
130 mlir::Value index = loop.getInductionVar();
132 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
133 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
134 mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
137 if (elementType.isa<mlir::FloatType>())
138 res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
139 else if (elementType.isa<mlir::IntegerType>())
140 res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
142 TODO(loc, "Unsupported type");
144 builder.create<fir::StoreOp>(loc, res, sum);
146 builder.restoreInsertionPoint(loopEndPt);
148 mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
149 builder.create<mlir::func::ReturnOp>(loc, resultVal);
152 /// Generate function type for the simplified version of FortranADotProduct
153 /// operating on the given \p elementType.
154 static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
155 const mlir::Type &elementType) {
156 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
157 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
161 /// Generate function body of the simplified version of FortranADotProduct
162 /// with signature provided by \p funcOp. The caller is responsible
163 /// for saving/restoring the original insertion point of \p builder.
164 /// \p funcOp is expected to be empty on entry to this function.
165 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
166 /// of the underlying array objects - they are used to generate proper
167 /// element accesses.
168 static void genFortranADotBody(fir::FirOpBuilder &builder,
169 mlir::func::FuncOp &funcOp,
170 mlir::Type arg1ElementTy,
171 mlir::Type arg2ElementTy) {
172 // function FortranADotProduct<T>_simplified(arr1, arr2)
173 // T, dimension(:) :: arr1, arr2
176 // do iter = 0, extent(arr1)
177 // product = product + arr1[iter] * arr2[iter]
179 // FortranADotProduct<T>_simplified = product
180 // end function FortranADotProduct<T>_simplified
181 auto loc = mlir::UnknownLoc::get(builder.getContext());
182 mlir::Type resultElementType = funcOp.getResultTypes()[0];
183 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
185 mlir::IndexType idxTy = builder.getIndexType();
188 resultElementType.isa<mlir::FloatType>()
189 ? builder.createRealConstant(loc, resultElementType, 0.0)
190 : builder.createIntegerConstant(loc, resultElementType, 0);
192 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
193 mlir::Value arg1 = args[0];
194 mlir::Value arg2 = args[1];
196 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
198 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
199 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
200 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
201 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
202 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
203 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
204 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
205 // This version takes the loop trip count from the first argument.
206 // If the first argument's box has unknown (at compilation time)
207 // extent, then it may be better to take the extent from the second
208 // argument - so that after inlining the loop may be better optimized, e.g.
209 // fully unrolled. This requires generating two versions of the simplified
210 // function and some analysis at the call site to choose which version
211 // is more profitable to call.
212 // Note that we can assume that both arguments have the same extent.
214 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
215 mlir::Value len = dims.getResult(1);
216 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
217 mlir::Value step = one;
219 // We use C indexing here, so len-1 as loopcount
220 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
221 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
223 /*finalCountValue=*/false, zero);
224 mlir::Value sumVal = loop.getRegionIterArgs()[0];
227 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
228 builder.setInsertionPointToStart(loop.getBody());
230 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
231 mlir::Value index = loop.getInductionVar();
233 builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
234 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
235 // Convert to the result type.
236 elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
238 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
240 builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
241 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
242 // Convert to the result type.
243 elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
245 if (resultElementType.isa<mlir::FloatType>())
246 sumVal = builder.create<mlir::arith::AddFOp>(
247 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
248 else if (resultElementType.isa<mlir::IntegerType>())
249 sumVal = builder.create<mlir::arith::AddIOp>(
250 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
252 llvm_unreachable("unsupported type");
254 builder.create<fir::ResultOp>(loc, sumVal);
256 builder.restoreInsertionPoint(loopEndPt);
258 mlir::Value resultVal = loop.getResult(0);
259 builder.create<mlir::func::ReturnOp>(loc, resultVal);
262 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
263 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
264 FunctionTypeGeneratorTy typeGenerator,
265 FunctionBodyGeneratorTy bodyGenerator) {
266 // WARNING: if the function generated here changes its signature
267 // or behavior (the body code), we should probably embed some
268 // versioning information into its name, otherwise libraries
269 // statically linked with older versions of Flang may stop
270 // working with object files created with newer Flang.
271 // We can also avoid this by using internal linkage, but
272 // this may increase the size of final executable/shared library.
273 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
274 mlir::ModuleOp module = builder.getModule();
275 // If we already have a function, just return it.
276 mlir::func::FuncOp newFunc =
277 fir::FirOpBuilder::getNamedFunction(module, replacementName);
278 mlir::FunctionType fType = typeGenerator(builder);
280 assert(newFunc.getFunctionType() == fType &&
281 "type mismatch for simplified function");
285 // Need to build the function!
286 auto loc = mlir::UnknownLoc::get(builder.getContext());
288 fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
289 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
291 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
292 newFunc->setAttr("llvm.linkage", linkage);
294 // Save the position of the original call.
295 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
297 bodyGenerator(builder, newFunc);
299 // Now back to where we were adding code earlier...
300 builder.restoreInsertionPoint(insertPt);
305 fir::ConvertOp expectConvertOp(mlir::Value val) {
306 if (fir::ConvertOp op =
307 mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
309 LLVM_DEBUG(llvm::dbgs() << "Didn't find expected fir::ConvertOp\n");
313 static bool isOperandAbsent(mlir::Value val) {
314 if (auto op = expectConvertOp(val)) {
315 assert(op->getOperands().size() != 0);
316 return mlir::isa_and_nonnull<fir::AbsentOp>(
317 op->getOperand(0).getDefiningOp());
322 static bool isZero(mlir::Value val) {
323 if (auto op = expectConvertOp(val)) {
324 assert(op->getOperands().size() != 0);
325 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
326 return mlir::matchPattern(defOp, mlir::m_Zero());
331 static mlir::Value findShape(mlir::Value val) {
332 if (auto op = expectConvertOp(val)) {
333 assert(op->getOperands().size() != 0);
335 mlir::dyn_cast<fir::EmboxOp>(op->getOperand(0).getDefiningOp()))
336 return box.getShape();
341 static unsigned getDimCount(mlir::Value val) {
342 if (mlir::Value shapeVal = findShape(val)) {
343 mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
344 return fir::getRankOfShapeType(resType);
349 /// Given the call operation's box argument \p val, discover
350 /// the element type of the underlying array object.
351 /// \returns the element type or llvm::None if the type cannot
352 /// be reliably found.
353 /// We expect that the argument is a result of fir.convert
354 /// with the destination type of !fir.box<none>.
355 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
356 mlir::Operation *defOp;
358 defOp = val.getDefiningOp();
359 // Analyze only sequences of convert operations.
360 if (!mlir::isa<fir::ConvertOp>(defOp))
362 val = defOp->getOperand(0);
363 // The convert operation is expected to convert from one
364 // box type to another box type.
365 auto boxType = val.getType().cast<fir::BoxType>();
366 auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
367 if (!elementType.isa<mlir::NoneType>())
372 void SimplifyIntrinsicsPass::runOnOperation() {
373 LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
374 mlir::ModuleOp module = getOperation();
375 fir::KindMapping kindMap = fir::getKindMapping(module);
376 module.walk([&](mlir::Operation *op) {
377 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
378 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
379 mlir::StringRef funcName = callee.getLeafReference().getValue();
380 // Replace call to runtime function for SUM when it has single
381 // argument (no dim or mask argument) for 1D arrays with either
382 // Integer4 or Real8 types. Other forms are ignored.
383 // The new function is added to the module.
385 // Prototype for runtime call (from sum.cpp):
386 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
387 // int dim, const Descriptor *mask)
388 if (funcName.startswith("_FortranASum")) {
389 mlir::Operation::operand_range args = call.getArgs();
390 // args[1] and args[2] are source filename and line number, ignored.
391 const mlir::Value &dim = args[3];
392 const mlir::Value &mask = args[4];
393 // dim is zero when it is absent, which is an implementation
394 // detail in the runtime library.
395 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
396 unsigned rank = getDimCount(args[0]);
397 if (dimAndMaskAbsent && rank == 1) {
398 mlir::Location loc = call.getLoc();
400 fir::FirOpBuilder builder(op, kindMap);
401 if (funcName.endswith("Integer4")) {
402 type = mlir::IntegerType::get(builder.getContext(), 32);
403 } else if (funcName.endswith("Real8")) {
404 type = mlir::FloatType::getF64(builder.getContext());
408 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
409 return genFortranASumType(builder, type);
411 mlir::func::FuncOp newFunc = getOrCreateFunction(
412 builder, funcName, typeGenerator, genFortranASumBody);
413 auto newCall = builder.create<fir::CallOp>(
414 loc, newFunc, mlir::ValueRange{args[0]});
415 call->replaceAllUsesWith(newCall.getResults());
416 call->dropAllReferences();
422 if (funcName.startswith("_FortranADotProduct")) {
423 LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
424 LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
425 llvm::dbgs() << "\n");
426 mlir::Operation::operand_range args = call.getArgs();
427 const mlir::Value &v1 = args[0];
428 const mlir::Value &v2 = args[1];
429 mlir::Location loc = call.getLoc();
430 fir::FirOpBuilder builder(op, kindMap);
431 mlir::Type type = call.getResult(0).getType();
432 if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
435 // Try to find the element types of the boxed arguments.
436 auto arg1Type = getArgElementType(v1);
437 auto arg2Type = getArgElementType(v2);
439 if (!arg1Type || !arg2Type)
442 // Support only floating point and integer arguments
443 // now (e.g. logical is skipped here).
444 if (!arg1Type->isa<mlir::FloatType>() &&
445 !arg1Type->isa<mlir::IntegerType>())
447 if (!arg2Type->isa<mlir::FloatType>() &&
448 !arg2Type->isa<mlir::IntegerType>())
451 auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
452 return genFortranADotType(builder, type);
454 auto bodyGenerator = [&arg1Type,
455 &arg2Type](fir::FirOpBuilder &builder,
456 mlir::func::FuncOp &funcOp) {
457 genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
460 // Suffix the function name with the element types
462 std::string typedFuncName(funcName);
463 llvm::raw_string_ostream nameOS(typedFuncName);
465 arg1Type->print(nameOS);
467 arg2Type->print(nameOS);
469 mlir::func::FuncOp newFunc = getOrCreateFunction(
470 builder, typedFuncName, typeGenerator, bodyGenerator);
471 auto newCall = builder.create<fir::CallOp>(loc, newFunc,
472 mlir::ValueRange{v1, v2});
473 call->replaceAllUsesWith(newCall.getResults());
474 call->dropAllReferences();
477 LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
478 llvm::dbgs() << "\n");
484 LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
487 void SimplifyIntrinsicsPass::getDependentDialects(
488 mlir::DialectRegistry ®istry) const {
489 // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
490 registry.insert<mlir::LLVM::LLVMDialect>();
/// Factory for the simplify-intrinsics pass: returns a new, heap-allocated
/// SimplifyIntrinsicsPass owned by the caller.
std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
  return std::make_unique<SimplifyIntrinsicsPass>();