[flang] Support DOT_PRODUCT in late inlining.
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Transforms/DialectConversion.h"
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37 #include "mlir/Transforms/RegionUtils.h"
38 #include "llvm/Support/Debug.h"
39
40 #define DEBUG_TYPE "flang-simplify-intrinsics"
41
42 namespace {
43
44 class SimplifyIntrinsicsPass
45     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
46   using FunctionTypeGeneratorTy =
47       std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
48   using FunctionBodyGeneratorTy =
49       std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
50
51 public:
52   /// Generate a new function implementing a simplified version
53   /// of a Fortran runtime function defined by \p basename name.
54   /// \p typeGenerator is a callback that generates the new function's type.
55   /// \p bodyGenerator is a callback that generates the new function's body.
56   /// The new function is created in the \p builder's Module.
57   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
58                                          const mlir::StringRef &basename,
59                                          FunctionTypeGeneratorTy typeGenerator,
60                                          FunctionBodyGeneratorTy bodyGenerator);
61   void runOnOperation() override;
62   void getDependentDialects(mlir::DialectRegistry &registry) const override;
63 };
64
65 } // namespace
66
67 /// Generate function type for the simplified version of FortranASum
68 /// operating on the given \p elementType.
69 static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder,
70                                              const mlir::Type &elementType) {
71   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
72   return mlir::FunctionType::get(builder.getContext(), {boxType},
73                                  {elementType});
74 }
75
76 /// Generate function body of the simplified version of FortranASum
77 /// with signature provided by \p funcOp. The caller is responsible
78 /// for saving/restoring the original insertion point of \p builder.
79 /// \p funcOp is expected to be empty on entry to this function.
80 static void genFortranASumBody(fir::FirOpBuilder &builder,
81                                mlir::func::FuncOp &funcOp) {
82   // function FortranASum<T>_simplified(arr)
83   //   T, dimension(:) :: arr
84   //   T sum = 0
85   //   integer iter
86   //   do iter = 0, extent(arr)
87   //     sum = sum + arr[iter]
88   //   end do
89   //   FortranASum<T>_simplified = sum
90   // end function FortranASum<T>_simplified
91   auto loc = mlir::UnknownLoc::get(builder.getContext());
92   mlir::Type elementType = funcOp.getResultTypes()[0];
93   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
94
95   mlir::IndexType idxTy = builder.getIndexType();
96
97   mlir::Value zero = elementType.isa<mlir::FloatType>()
98                          ? builder.createRealConstant(loc, elementType, 0.0)
99                          : builder.createIntegerConstant(loc, elementType, 0);
100   mlir::Value sum = builder.create<fir::AllocaOp>(loc, elementType);
101   builder.create<fir::StoreOp>(loc, zero, sum);
102
103   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
104   mlir::Value arg = args[0];
105
106   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
107
108   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
109   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
110   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
111   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
112   auto dims =
113       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
114   mlir::Value len = dims.getResult(1);
115   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
116   mlir::Value step = one;
117
118   // We use C indexing here, so len-1 as loopcount
119   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
120   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
121
122   // Begin loop code
123   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
124   builder.setInsertionPointToStart(loop.getBody());
125
126   mlir::Type eleRefTy = builder.getRefType(elementType);
127   mlir::Value index = loop.getInductionVar();
128   mlir::Value addr =
129       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
130   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
131   mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
132
133   mlir::Value res;
134   if (elementType.isa<mlir::FloatType>())
135     res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
136   else if (elementType.isa<mlir::IntegerType>())
137     res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
138   else
139     TODO(loc, "Unsupported type");
140
141   builder.create<fir::StoreOp>(loc, res, sum);
142   // End of loop.
143   builder.restoreInsertionPoint(loopEndPt);
144
145   mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
146   builder.create<mlir::func::ReturnOp>(loc, resultVal);
147 }
148
149 /// Generate function type for the simplified version of FortranADotProduct
150 /// operating on the given \p elementType.
151 static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
152                                              const mlir::Type &elementType) {
153   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
154   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
155                                  {elementType});
156 }
157
158 /// Generate function body of the simplified version of FortranADotProduct
159 /// with signature provided by \p funcOp. The caller is responsible
160 /// for saving/restoring the original insertion point of \p builder.
161 /// \p funcOp is expected to be empty on entry to this function.
162 static void genFortranADotBody(fir::FirOpBuilder &builder,
163                                mlir::func::FuncOp &funcOp) {
164   // function FortranADotProduct<T>_simplified(arr1, arr2)
165   //   T, dimension(:) :: arr1, arr2
166   //   T product = 0
167   //   integer iter
168   //   do iter = 0, extent(arr1)
169   //     product = product + arr1[iter] * arr2[iter]
170   //   end do
171   //   FortranADotProduct<T>_simplified = product
172   // end function FortranADotProduct<T>_simplified
173   auto loc = mlir::UnknownLoc::get(builder.getContext());
174   mlir::Type elementType = funcOp.getResultTypes()[0];
175   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
176
177   mlir::IndexType idxTy = builder.getIndexType();
178
179   mlir::Value zero = elementType.isa<mlir::FloatType>()
180                          ? builder.createRealConstant(loc, elementType, 0.0)
181                          : builder.createIntegerConstant(loc, elementType, 0);
182
183   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
184   mlir::Value arg1 = args[0];
185   mlir::Value arg2 = args[1];
186
187   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
188
189   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
190   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
191   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
192   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg1);
193   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg2);
194   // This version takes the loop trip count from the first argument.
195   // If the first argument's box has unknown (at compilation time)
196   // extent, then it may be better to take the extent from the second
197   // argument - so that after inlining the loop may be better optimized, e.g.
198   // fully unrolled. This requires generating two versions of the simplified
199   // function and some analysis at the call site to choose which version
200   // is more profitable to call.
201   // Note that we can assume that both arguments have the same extent.
202   auto dims =
203       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
204   mlir::Value len = dims.getResult(1);
205   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
206   mlir::Value step = one;
207
208   // We use C indexing here, so len-1 as loopcount
209   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
210   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
211                                             /*unordered=*/false,
212                                             /*finalCountValue=*/false, zero);
213   mlir::Value sumVal = loop.getRegionIterArgs()[0];
214
215   // Begin loop code
216   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
217   builder.setInsertionPointToStart(loop.getBody());
218
219   mlir::Type eleRefTy = builder.getRefType(elementType);
220   mlir::Value index = loop.getInductionVar();
221   mlir::Value addr1 =
222       builder.create<fir::CoordinateOp>(loc, eleRefTy, array1, index);
223   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
224   mlir::Value addr2 =
225       builder.create<fir::CoordinateOp>(loc, eleRefTy, array2, index);
226   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
227
228   if (elementType.isa<mlir::FloatType>())
229     sumVal = builder.create<mlir::arith::AddFOp>(
230         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
231   else if (elementType.isa<mlir::IntegerType>())
232     sumVal = builder.create<mlir::arith::AddIOp>(
233         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
234   else
235     llvm_unreachable("unsupported type");
236
237   builder.create<fir::ResultOp>(loc, sumVal);
238   // End of loop.
239   builder.restoreInsertionPoint(loopEndPt);
240
241   mlir::Value resultVal = loop.getResult(0);
242   builder.create<mlir::func::ReturnOp>(loc, resultVal);
243 }
244
245 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
246     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
247     FunctionTypeGeneratorTy typeGenerator,
248     FunctionBodyGeneratorTy bodyGenerator) {
249   // WARNING: if the function generated here changes its signature
250   //          or behavior (the body code), we should probably embed some
251   //          versioning information into its name, otherwise libraries
252   //          statically linked with older versions of Flang may stop
253   //          working with object files created with newer Flang.
254   //          We can also avoid this by using internal linkage, but
255   //          this may increase the size of final executable/shared library.
256   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
257   mlir::ModuleOp module = builder.getModule();
258   // If we already have a function, just return it.
259   mlir::func::FuncOp newFunc =
260       fir::FirOpBuilder::getNamedFunction(module, replacementName);
261   mlir::FunctionType fType = typeGenerator(builder);
262   if (newFunc) {
263     assert(newFunc.getFunctionType() == fType &&
264            "type mismatch for simplified function");
265     return newFunc;
266   }
267
268   // Need to build the function!
269   auto loc = mlir::UnknownLoc::get(builder.getContext());
270   newFunc =
271       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
272   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
273   auto linkage =
274       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
275   newFunc->setAttr("llvm.linkage", linkage);
276
277   // Save the position of the original call.
278   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
279
280   bodyGenerator(builder, newFunc);
281
282   // Now back to where we were adding code earlier...
283   builder.restoreInsertionPoint(insertPt);
284
285   return newFunc;
286 }
287
288 static bool isOperandAbsent(mlir::Value val) {
289   if (mlir::Operation *op = val.getDefiningOp())
290     return mlir::isa_and_nonnull<fir::AbsentOp>(
291         op->getOperand(0).getDefiningOp());
292   return false;
293 }
294
295 static bool isZero(mlir::Value val) {
296   if (mlir::Operation *op = val.getDefiningOp())
297     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
298       return mlir::matchPattern(defOp, mlir::m_Zero());
299   return false;
300 }
301
302 static mlir::Value findShape(mlir::Value val) {
303   mlir::Operation *defOp = val.getDefiningOp();
304   while (defOp) {
305     defOp = defOp->getOperand(0).getDefiningOp();
306     if (fir::EmboxOp box = mlir::dyn_cast_or_null<fir::EmboxOp>(defOp))
307       return box.getShape();
308   }
309   return {};
310 }
311
312 static unsigned getDimCount(mlir::Value val) {
313   if (mlir::Value shapeVal = findShape(val)) {
314     mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
315     return fir::getRankOfShapeType(resType);
316   }
317   return 0;
318 }
319
320 void SimplifyIntrinsicsPass::runOnOperation() {
321   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
322   mlir::ModuleOp module = getOperation();
323   fir::KindMapping kindMap = fir::getKindMapping(module);
324   module.walk([&](mlir::Operation *op) {
325     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
326       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
327         mlir::StringRef funcName = callee.getLeafReference().getValue();
328         // Replace call to runtime function for SUM when it has single
329         // argument (no dim or mask argument) for 1D arrays with either
330         // Integer4 or Real8 types. Other forms are ignored.
331         // The new function is added to the module.
332         //
333         // Prototype for runtime call (from sum.cpp):
334         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
335         //                int dim, const Descriptor *mask)
336         if (funcName.startswith("_FortranASum")) {
337           mlir::Operation::operand_range args = call.getArgs();
338           // args[1] and args[2] are source filename and line number, ignored.
339           const mlir::Value &dim = args[3];
340           const mlir::Value &mask = args[4];
341           // dim is zero when it is absent, which is an implementation
342           // detail in the runtime library.
343           bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
344           unsigned rank = getDimCount(args[0]);
345           if (dimAndMaskAbsent && rank == 1) {
346             mlir::Location loc = call.getLoc();
347             mlir::Type type;
348             fir::FirOpBuilder builder(op, kindMap);
349             if (funcName.endswith("Integer4")) {
350               type = mlir::IntegerType::get(builder.getContext(), 32);
351             } else if (funcName.endswith("Real8")) {
352               type = mlir::FloatType::getF64(builder.getContext());
353             } else {
354               return;
355             }
356             auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
357               return genFortranASumType(builder, type);
358             };
359             mlir::func::FuncOp newFunc = getOrCreateFunction(
360                 builder, funcName, typeGenerator, genFortranASumBody);
361             auto newCall = builder.create<fir::CallOp>(
362                 loc, newFunc, mlir::ValueRange{args[0]});
363             call->replaceAllUsesWith(newCall.getResults());
364             call->dropAllReferences();
365             call->erase();
366           }
367
368           return;
369         }
370         if (funcName.startswith("_FortranADotProduct")) {
371           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
372           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
373                      llvm::dbgs() << "\n");
374           mlir::Operation::operand_range args = call.getArgs();
375           const mlir::Value &v1 = args[0];
376           const mlir::Value &v2 = args[1];
377           mlir::Location loc = call.getLoc();
378           fir::FirOpBuilder builder(op, kindMap);
379           mlir::Type type = call.getResult(0).getType();
380           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
381             return;
382
383           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
384             return genFortranADotType(builder, type);
385           };
386           mlir::func::FuncOp newFunc = getOrCreateFunction(
387               builder, funcName, typeGenerator, genFortranADotBody);
388           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
389                                                      mlir::ValueRange{v1, v2});
390           call->replaceAllUsesWith(newCall.getResults());
391           call->dropAllReferences();
392           call->erase();
393
394           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
395                      llvm::dbgs() << "\n");
396           return;
397         }
398       }
399     }
400   });
401   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
402 }
403
404 void SimplifyIntrinsicsPass::getDependentDialects(
405     mlir::DialectRegistry &registry) const {
406   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
407   registry.insert<mlir::LLVM::LLVMDialect>();
408 }
409 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
410   return std::make_unique<SimplifyIntrinsicsPass>();
411 }