1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 //===----------------------------------------------------------------------===//
/// This pass looks for suitable calls to the runtime library for intrinsics
/// that can be simplified/specialized, and replaces them with calls to a
/// specialized function.
///
/// For example, SUM(arr) can be specialized as a simple function with one
/// loop, compared to the three arguments (plus file & line info) that the
/// runtime call has - when the argument is a 1D-array (multiple loops may be
/// needed for higher dimension arrays, of course).
///
/// The general idea is that besides making the call simpler, it can also be
/// inlined by other passes that run after this pass, which further improves
/// performance, particularly when the work done in the function is trivial
/// and small in size.
23 //===----------------------------------------------------------------------===//
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37 #include "mlir/Transforms/RegionUtils.h"
41 class SimplifyIntrinsicsPass
42 : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
44 mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc,
45 fir::FirOpBuilder &builder,
46 const mlir::Type &type,
47 const mlir::StringRef &basename);
48 void runOnOperation() override;
53 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
54 const mlir::Location &loc, fir::FirOpBuilder &builder,
55 const mlir::Type &type, const mlir::StringRef &baseName) {
56 // In future, the idea is that instead of building the function inside
57 // this function, this does the base creation, and calls a callback
58 // function (e.g. a lambda function) that fills in the actual content.
59 // For now, check that it's the ONLY the SUM runtime call.
60 assert(baseName.startswith("_FortranASum"));
62 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
63 mlir::ModuleOp module = builder.getModule();
64 // If we already have a function, just return it.
65 mlir::func::FuncOp newFunc =
66 fir::FirOpBuilder::getNamedFunction(module, replacementName);
70 // Need to build the function!
72 // function FortranASum<T>_simplified(arr)
73 // T, dimension(:) :: arr
76 // do iter = 0, extent(arr)
77 // sum = sum + arr[iter]
79 // FortranASum<T>_simplified = sum
80 // end function FortranASum<T>_simplified
81 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
82 mlir::FunctionType fType =
83 mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
85 fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
86 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
88 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
89 newFunc->setAttr("llvm.linkage", linkage);
91 // Save the position of the original call.
92 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
93 builder.setInsertionPointToEnd(newFunc.addEntryBlock());
95 mlir::IndexType idxTy = builder.getIndexType();
97 mlir::Value zero = type.isa<mlir::FloatType>()
98 ? builder.createRealConstant(loc, type, 0.0)
99 : builder.createIntegerConstant(loc, type, 0);
100 mlir::Value sum = builder.create<fir::AllocaOp>(loc, type);
101 builder.create<fir::StoreOp>(loc, zero, sum);
103 mlir::Block::BlockArgListType args = newFunc.front().getArguments();
104 mlir::Value arg = args[0];
106 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
108 fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
109 mlir::Type arrTy = fir::SequenceType::get(flatShape, type);
110 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
111 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
113 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
114 mlir::Value len = dims.getResult(1);
115 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
116 mlir::Value step = one;
118 // We use C indexing here, so len-1 as loopcount
119 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
120 auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
123 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
124 builder.setInsertionPointToStart(loop.getBody());
126 mlir::Type eleRefTy = builder.getRefType(type);
127 mlir::Value index = loop.getInductionVar();
129 builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
130 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
131 mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
134 if (type.isa<mlir::FloatType>())
135 res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
136 else if (type.isa<mlir::IntegerType>())
137 res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
139 TODO(loc, "Unsupported type");
141 builder.create<fir::StoreOp>(loc, res, sum);
143 builder.restoreInsertionPoint(loopEndPt);
145 mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
146 builder.create<mlir::func::ReturnOp>(loc, resultVal);
148 // Now back to where we were adding code earlier...
149 builder.restoreInsertionPoint(insertPt);
154 static bool isOperandAbsent(mlir::Value val) {
155 if (mlir::Operation *op = val.getDefiningOp())
156 return mlir::isa_and_nonnull<fir::AbsentOp>(
157 op->getOperand(0).getDefiningOp());
161 static bool isZero(mlir::Value val) {
162 if (mlir::Operation *op = val.getDefiningOp())
163 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
164 return mlir::matchPattern(defOp, mlir::m_Zero());
168 static mlir::Value findShape(mlir::Value val) {
169 mlir::Operation *defOp = val.getDefiningOp();
171 defOp = defOp->getOperand(0).getDefiningOp();
172 if (fir::EmboxOp box = mlir::dyn_cast_or_null<fir::EmboxOp>(defOp))
173 return box.getShape();
178 static unsigned getDimCount(mlir::Value val) {
179 if (mlir::Value shapeVal = findShape(val)) {
180 mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
181 return fir::getRankOfShapeType(resType);
186 void SimplifyIntrinsicsPass::runOnOperation() {
187 mlir::ModuleOp module = getOperation();
188 fir::KindMapping kindMap = fir::getKindMapping(module);
189 module.walk([&](mlir::Operation *op) {
190 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
191 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
192 mlir::StringRef funcName = callee.getLeafReference().getValue();
193 // Replace call to runtime function for SUM when it has single
194 // argument (no dim or mask argument) for 1D arrays with either
195 // Integer4 or Real8 types. Other forms are ignored.
196 // The new function is added to the module.
198 // Prototype for runtime call (from sum.cpp):
199 // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
200 // int dim, const Descriptor *mask)
201 if (funcName.startswith("_FortranASum")) {
202 mlir::Operation::operand_range args = call.getArgs();
203 // args[1] and args[2] are source filename and line number, ignored.
204 const mlir::Value &dim = args[3];
205 const mlir::Value &mask = args[4];
206 // dim is zero when it is absent, which is an implementation
207 // detail in the runtime library.
208 bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
209 unsigned rank = getDimCount(args[0]);
210 if (dimAndMaskAbsent && rank == 1) {
211 mlir::Location loc = call.getLoc();
213 fir::FirOpBuilder builder(op, kindMap);
214 if (funcName.endswith("Integer4")) {
215 type = mlir::IntegerType::get(builder.getContext(), 32);
216 } else if (funcName.endswith("Real8")) {
217 type = mlir::FloatType::getF64(builder.getContext());
221 mlir::func::FuncOp newFunc =
222 getOrCreateFunction(loc, builder, type, funcName);
223 auto newCall = builder.create<fir::CallOp>(
224 loc, newFunc, mlir::ValueRange{args[0]});
225 call->replaceAllUsesWith(newCall.getResults());
226 call->dropAllReferences();
235 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
236 return std::make_unique<SimplifyIntrinsicsPass>();