Prepare for inlining of SUM intrinsic
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37 #include "mlir/Transforms/RegionUtils.h"
38
39 namespace {
40
41 class SimplifyIntrinsicsPass
42     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
43 public:
44   mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc,
45                                          fir::FirOpBuilder &builder,
46                                          const mlir::Type &type,
47                                          const mlir::StringRef &basename);
48   void runOnOperation() override;
49 };
50
51 } // namespace
52
53 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
54     const mlir::Location &loc, fir::FirOpBuilder &builder,
55     const mlir::Type &type, const mlir::StringRef &baseName) {
56   // In future, the idea is that instead of building the function inside
57   // this function, this does the base creation, and calls a callback
58   // function (e.g. a lambda function) that fills in the actual content.
59   // For now, check that it's the ONLY the SUM runtime call.
60   assert(baseName.startswith("_FortranASum"));
61
62   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
63   mlir::ModuleOp module = builder.getModule();
64   // If we already have a function, just return it.
65   mlir::func::FuncOp newFunc =
66       fir::FirOpBuilder::getNamedFunction(module, replacementName);
67   if (newFunc)
68     return newFunc;
69
70   // Need to build the function!
71   // Basic idea:
72   // function FortranASum<T>_simplified(arr)
73   //   T, dimension(:) :: arr
74   //   T sum = 0
75   //   integer iter
76   //   do iter = 0, extent(arr)
77   //     sum = sum + arr[iter]
78   //   end do
79   //   FortranASum<T>_simplified = sum
80   // end function FortranASum<T>_simplified
81   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
82   mlir::FunctionType fType =
83       mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
84   newFunc =
85       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
86   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
87   auto linkage =
88       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
89   newFunc->setAttr("llvm.linkage", linkage);
90
91   // Save the position of the original call.
92   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
93   builder.setInsertionPointToEnd(newFunc.addEntryBlock());
94
95   mlir::IndexType idxTy = builder.getIndexType();
96
97   mlir::Value zero = type.isa<mlir::FloatType>()
98                          ? builder.createRealConstant(loc, type, 0.0)
99                          : builder.createIntegerConstant(loc, type, 0);
100   mlir::Value sum = builder.create<fir::AllocaOp>(loc, type);
101   builder.create<fir::StoreOp>(loc, zero, sum);
102
103   mlir::Block::BlockArgListType args = newFunc.front().getArguments();
104   mlir::Value arg = args[0];
105
106   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
107
108   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
109   mlir::Type arrTy = fir::SequenceType::get(flatShape, type);
110   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
111   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
112   auto dims =
113       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
114   mlir::Value len = dims.getResult(1);
115   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
116   mlir::Value step = one;
117
118   // We use C indexing here, so len-1 as loopcount
119   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
120   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
121
122   // Begin loop code
123   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
124   builder.setInsertionPointToStart(loop.getBody());
125
126   mlir::Type eleRefTy = builder.getRefType(type);
127   mlir::Value index = loop.getInductionVar();
128   mlir::Value addr =
129       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
130   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
131   mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
132
133   mlir::Value res;
134   if (type.isa<mlir::FloatType>())
135     res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
136   else if (type.isa<mlir::IntegerType>())
137     res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
138   else
139     TODO(loc, "Unsupported type");
140
141   builder.create<fir::StoreOp>(loc, res, sum);
142   // End of loop.
143   builder.restoreInsertionPoint(loopEndPt);
144
145   mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
146   builder.create<mlir::func::ReturnOp>(loc, resultVal);
147
148   // Now back to where we were adding code earlier...
149   builder.restoreInsertionPoint(insertPt);
150
151   return newFunc;
152 }
153
154 static bool isOperandAbsent(mlir::Value val) {
155   if (mlir::Operation *op = val.getDefiningOp())
156     return mlir::isa_and_nonnull<fir::AbsentOp>(
157         op->getOperand(0).getDefiningOp());
158   return false;
159 }
160
161 static bool isZero(mlir::Value val) {
162   if (mlir::Operation *op = val.getDefiningOp())
163     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
164       return mlir::matchPattern(defOp, mlir::m_Zero());
165   return false;
166 }
167
168 static mlir::Value findShape(mlir::Value val) {
169   mlir::Operation *defOp = val.getDefiningOp();
170   while (defOp) {
171     defOp = defOp->getOperand(0).getDefiningOp();
172     if (fir::EmboxOp box = mlir::dyn_cast_or_null<fir::EmboxOp>(defOp))
173       return box.getShape();
174   }
175   return {};
176 }
177
178 static unsigned getDimCount(mlir::Value val) {
179   if (mlir::Value shapeVal = findShape(val)) {
180     mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
181     return fir::getRankOfShapeType(resType);
182   }
183   return 0;
184 }
185
186 void SimplifyIntrinsicsPass::runOnOperation() {
187   mlir::ModuleOp module = getOperation();
188   fir::KindMapping kindMap = fir::getKindMapping(module);
189   module.walk([&](mlir::Operation *op) {
190     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
191       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
192         mlir::StringRef funcName = callee.getLeafReference().getValue();
193         // Replace call to runtime function for SUM when it has single
194         // argument (no dim or mask argument) for 1D arrays with either
195         // Integer4 or Real8 types. Other forms are ignored.
196         // The new function is added to the module.
197         //
198         // Prototype for runtime call (from sum.cpp):
199         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
200         //                int dim, const Descriptor *mask)
201         if (funcName.startswith("_FortranASum")) {
202           mlir::Operation::operand_range args = call.getArgs();
203           // args[1] and args[2] are source filename and line number, ignored.
204           const mlir::Value &dim = args[3];
205           const mlir::Value &mask = args[4];
206           // dim is zero when it is absent, which is an implementation
207           // detail in the runtime library.
208           bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
209           unsigned rank = getDimCount(args[0]);
210           if (dimAndMaskAbsent && rank == 1) {
211             mlir::Location loc = call.getLoc();
212             mlir::Type type;
213             fir::FirOpBuilder builder(op, kindMap);
214             if (funcName.endswith("Integer4")) {
215               type = mlir::IntegerType::get(builder.getContext(), 32);
216             } else if (funcName.endswith("Real8")) {
217               type = mlir::FloatType::getF64(builder.getContext());
218             } else {
219               return;
220             }
221             mlir::func::FuncOp newFunc =
222                 getOrCreateFunction(loc, builder, type, funcName);
223             auto newCall = builder.create<fir::CallOp>(
224                 loc, newFunc, mlir::ValueRange{args[0]});
225             call->replaceAllUsesWith(newCall.getResults());
226             call->dropAllReferences();
227             call->erase();
228           }
229         }
230       }
231     }
232   });
233 }
234
235 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
236   return std::make_unique<SimplifyIntrinsicsPass>();
237 }