[flang] Handle mixed types in DOT_PRODUCT simplification.
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
1 //===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 //===----------------------------------------------------------------------===//
10 /// \file
11 /// This pass looks for suitable calls to runtime library for intrinsics that
12 /// can be simplified/specialized and replaces with a specialized function.
13 ///
14 /// For example, SUM(arr) can be specialized as a simple function with one loop,
15 /// compared to the three arguments (plus file & line info) that the runtime
16 /// call has - when the argument is a 1D-array (multiple loops may be needed
17 //  for higher dimension arrays, of course)
18 ///
19 /// The general idea is that besides making the call simpler, it can also be
20 /// inlined by other passes that run after this pass, which further improves
21 /// performance, particularly when the work done in the function is trivial
22 /// and small in size.
23 //===----------------------------------------------------------------------===//
24
25 #include "PassDetail.h"
26 #include "flang/Optimizer/Builder/BoxValue.h"
27 #include "flang/Optimizer/Builder/FIRBuilder.h"
28 #include "flang/Optimizer/Builder/Todo.h"
29 #include "flang/Optimizer/Dialect/FIROps.h"
30 #include "flang/Optimizer/Dialect/FIRType.h"
31 #include "flang/Optimizer/Support/FIRContext.h"
32 #include "flang/Optimizer/Transforms/Passes.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/TypeUtilities.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Transforms/DialectConversion.h"
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
38 #include "mlir/Transforms/RegionUtils.h"
39 #include "llvm/ADT/Optional.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42
43 #define DEBUG_TYPE "flang-simplify-intrinsics"
44
45 namespace {
46
47 class SimplifyIntrinsicsPass
48     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
49   using FunctionTypeGeneratorTy =
50       std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
51   using FunctionBodyGeneratorTy =
52       std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
53
54 public:
55   /// Generate a new function implementing a simplified version
56   /// of a Fortran runtime function defined by \p basename name.
57   /// \p typeGenerator is a callback that generates the new function's type.
58   /// \p bodyGenerator is a callback that generates the new function's body.
59   /// The new function is created in the \p builder's Module.
60   mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
61                                          const mlir::StringRef &basename,
62                                          FunctionTypeGeneratorTy typeGenerator,
63                                          FunctionBodyGeneratorTy bodyGenerator);
64   void runOnOperation() override;
65   void getDependentDialects(mlir::DialectRegistry &registry) const override;
66 };
67
68 } // namespace
69
70 /// Generate function type for the simplified version of FortranASum
71 /// operating on the given \p elementType.
72 static mlir::FunctionType genFortranASumType(fir::FirOpBuilder &builder,
73                                              const mlir::Type &elementType) {
74   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
75   return mlir::FunctionType::get(builder.getContext(), {boxType},
76                                  {elementType});
77 }
78
79 /// Generate function body of the simplified version of FortranASum
80 /// with signature provided by \p funcOp. The caller is responsible
81 /// for saving/restoring the original insertion point of \p builder.
82 /// \p funcOp is expected to be empty on entry to this function.
83 static void genFortranASumBody(fir::FirOpBuilder &builder,
84                                mlir::func::FuncOp &funcOp) {
85   // function FortranASum<T>_simplified(arr)
86   //   T, dimension(:) :: arr
87   //   T sum = 0
88   //   integer iter
89   //   do iter = 0, extent(arr)
90   //     sum = sum + arr[iter]
91   //   end do
92   //   FortranASum<T>_simplified = sum
93   // end function FortranASum<T>_simplified
94   auto loc = mlir::UnknownLoc::get(builder.getContext());
95   mlir::Type elementType = funcOp.getResultTypes()[0];
96   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
97
98   mlir::IndexType idxTy = builder.getIndexType();
99
100   mlir::Value zero = elementType.isa<mlir::FloatType>()
101                          ? builder.createRealConstant(loc, elementType, 0.0)
102                          : builder.createIntegerConstant(loc, elementType, 0);
103   mlir::Value sum = builder.create<fir::AllocaOp>(loc, elementType);
104   builder.create<fir::StoreOp>(loc, zero, sum);
105
106   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
107   mlir::Value arg = args[0];
108
109   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
110
111   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
112   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
113   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
114   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
115   auto dims =
116       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
117   mlir::Value len = dims.getResult(1);
118   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
119   mlir::Value step = one;
120
121   // We use C indexing here, so len-1 as loopcount
122   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
123   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
124
125   // Begin loop code
126   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
127   builder.setInsertionPointToStart(loop.getBody());
128
129   mlir::Type eleRefTy = builder.getRefType(elementType);
130   mlir::Value index = loop.getInductionVar();
131   mlir::Value addr =
132       builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
133   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
134   mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
135
136   mlir::Value res;
137   if (elementType.isa<mlir::FloatType>())
138     res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
139   else if (elementType.isa<mlir::IntegerType>())
140     res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
141   else
142     TODO(loc, "Unsupported type");
143
144   builder.create<fir::StoreOp>(loc, res, sum);
145   // End of loop.
146   builder.restoreInsertionPoint(loopEndPt);
147
148   mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
149   builder.create<mlir::func::ReturnOp>(loc, resultVal);
150 }
151
152 /// Generate function type for the simplified version of FortranADotProduct
153 /// operating on the given \p elementType.
154 static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
155                                              const mlir::Type &elementType) {
156   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
157   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
158                                  {elementType});
159 }
160
161 /// Generate function body of the simplified version of FortranADotProduct
162 /// with signature provided by \p funcOp. The caller is responsible
163 /// for saving/restoring the original insertion point of \p builder.
164 /// \p funcOp is expected to be empty on entry to this function.
165 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
166 /// of the underlying array objects - they are used to generate proper
167 /// element accesses.
168 static void genFortranADotBody(fir::FirOpBuilder &builder,
169                                mlir::func::FuncOp &funcOp,
170                                mlir::Type arg1ElementTy,
171                                mlir::Type arg2ElementTy) {
172   // function FortranADotProduct<T>_simplified(arr1, arr2)
173   //   T, dimension(:) :: arr1, arr2
174   //   T product = 0
175   //   integer iter
176   //   do iter = 0, extent(arr1)
177   //     product = product + arr1[iter] * arr2[iter]
178   //   end do
179   //   FortranADotProduct<T>_simplified = product
180   // end function FortranADotProduct<T>_simplified
181   auto loc = mlir::UnknownLoc::get(builder.getContext());
182   mlir::Type resultElementType = funcOp.getResultTypes()[0];
183   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
184
185   mlir::IndexType idxTy = builder.getIndexType();
186
187   mlir::Value zero =
188       resultElementType.isa<mlir::FloatType>()
189           ? builder.createRealConstant(loc, resultElementType, 0.0)
190           : builder.createIntegerConstant(loc, resultElementType, 0);
191
192   mlir::Block::BlockArgListType args = funcOp.front().getArguments();
193   mlir::Value arg1 = args[0];
194   mlir::Value arg2 = args[1];
195
196   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
197
198   fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
199   mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
200   mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
201   mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
202   mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
203   mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
204   mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
205   // This version takes the loop trip count from the first argument.
206   // If the first argument's box has unknown (at compilation time)
207   // extent, then it may be better to take the extent from the second
208   // argument - so that after inlining the loop may be better optimized, e.g.
209   // fully unrolled. This requires generating two versions of the simplified
210   // function and some analysis at the call site to choose which version
211   // is more profitable to call.
212   // Note that we can assume that both arguments have the same extent.
213   auto dims =
214       builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array1, zeroIdx);
215   mlir::Value len = dims.getResult(1);
216   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
217   mlir::Value step = one;
218
219   // We use C indexing here, so len-1 as loopcount
220   mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
221   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
222                                             /*unordered=*/false,
223                                             /*finalCountValue=*/false, zero);
224   mlir::Value sumVal = loop.getRegionIterArgs()[0];
225
226   // Begin loop code
227   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
228   builder.setInsertionPointToStart(loop.getBody());
229
230   mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
231   mlir::Value index = loop.getInductionVar();
232   mlir::Value addr1 =
233       builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
234   mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
235   // Convert to the result type.
236   elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
237
238   mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
239   mlir::Value addr2 =
240       builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
241   mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
242   // Convert to the result type.
243   elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
244
245   if (resultElementType.isa<mlir::FloatType>())
246     sumVal = builder.create<mlir::arith::AddFOp>(
247         loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
248   else if (resultElementType.isa<mlir::IntegerType>())
249     sumVal = builder.create<mlir::arith::AddIOp>(
250         loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
251   else
252     llvm_unreachable("unsupported type");
253
254   builder.create<fir::ResultOp>(loc, sumVal);
255   // End of loop.
256   builder.restoreInsertionPoint(loopEndPt);
257
258   mlir::Value resultVal = loop.getResult(0);
259   builder.create<mlir::func::ReturnOp>(loc, resultVal);
260 }
261
262 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
263     fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
264     FunctionTypeGeneratorTy typeGenerator,
265     FunctionBodyGeneratorTy bodyGenerator) {
266   // WARNING: if the function generated here changes its signature
267   //          or behavior (the body code), we should probably embed some
268   //          versioning information into its name, otherwise libraries
269   //          statically linked with older versions of Flang may stop
270   //          working with object files created with newer Flang.
271   //          We can also avoid this by using internal linkage, but
272   //          this may increase the size of final executable/shared library.
273   std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
274   mlir::ModuleOp module = builder.getModule();
275   // If we already have a function, just return it.
276   mlir::func::FuncOp newFunc =
277       fir::FirOpBuilder::getNamedFunction(module, replacementName);
278   mlir::FunctionType fType = typeGenerator(builder);
279   if (newFunc) {
280     assert(newFunc.getFunctionType() == fType &&
281            "type mismatch for simplified function");
282     return newFunc;
283   }
284
285   // Need to build the function!
286   auto loc = mlir::UnknownLoc::get(builder.getContext());
287   newFunc =
288       fir::FirOpBuilder::createFunction(loc, module, replacementName, fType);
289   auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
290   auto linkage =
291       mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
292   newFunc->setAttr("llvm.linkage", linkage);
293
294   // Save the position of the original call.
295   mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
296
297   bodyGenerator(builder, newFunc);
298
299   // Now back to where we were adding code earlier...
300   builder.restoreInsertionPoint(insertPt);
301
302   return newFunc;
303 }
304
305 static bool isOperandAbsent(mlir::Value val) {
306   if (mlir::Operation *op = val.getDefiningOp())
307     return mlir::isa_and_nonnull<fir::AbsentOp>(
308         op->getOperand(0).getDefiningOp());
309   return false;
310 }
311
312 static bool isZero(mlir::Value val) {
313   if (mlir::Operation *op = val.getDefiningOp())
314     if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
315       return mlir::matchPattern(defOp, mlir::m_Zero());
316   return false;
317 }
318
319 static mlir::Value findShape(mlir::Value val) {
320   mlir::Operation *defOp = val.getDefiningOp();
321   while (defOp) {
322     defOp = defOp->getOperand(0).getDefiningOp();
323     if (fir::EmboxOp box = mlir::dyn_cast_or_null<fir::EmboxOp>(defOp))
324       return box.getShape();
325   }
326   return {};
327 }
328
329 static unsigned getDimCount(mlir::Value val) {
330   if (mlir::Value shapeVal = findShape(val)) {
331     mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
332     return fir::getRankOfShapeType(resType);
333   }
334   return 0;
335 }
336
337 /// Given the call operation's box argument \p val, discover
338 /// the element type of the underlying array object.
339 /// \returns the element type or llvm::None if the type cannot
340 /// be reliably found.
341 /// We expect that the argument is a result of fir.convert
342 /// with the destination type of !fir.box<none>.
343 static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
344   mlir::Operation *defOp;
345   do {
346     defOp = val.getDefiningOp();
347     // Analyze only sequences of convert operations.
348     if (!mlir::isa<fir::ConvertOp>(defOp))
349       return llvm::None;
350     val = defOp->getOperand(0);
351     // The convert operation is expected to convert from one
352     // box type to another box type.
353     auto boxType = val.getType().cast<fir::BoxType>();
354     auto elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
355     if (!elementType.isa<mlir::NoneType>())
356       return elementType;
357   } while (true);
358 }
359
360 void SimplifyIntrinsicsPass::runOnOperation() {
361   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
362   mlir::ModuleOp module = getOperation();
363   fir::KindMapping kindMap = fir::getKindMapping(module);
364   module.walk([&](mlir::Operation *op) {
365     if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
366       if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
367         mlir::StringRef funcName = callee.getLeafReference().getValue();
368         // Replace call to runtime function for SUM when it has single
369         // argument (no dim or mask argument) for 1D arrays with either
370         // Integer4 or Real8 types. Other forms are ignored.
371         // The new function is added to the module.
372         //
373         // Prototype for runtime call (from sum.cpp):
374         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
375         //                int dim, const Descriptor *mask)
376         if (funcName.startswith("_FortranASum")) {
377           mlir::Operation::operand_range args = call.getArgs();
378           // args[1] and args[2] are source filename and line number, ignored.
379           const mlir::Value &dim = args[3];
380           const mlir::Value &mask = args[4];
381           // dim is zero when it is absent, which is an implementation
382           // detail in the runtime library.
383           bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
384           unsigned rank = getDimCount(args[0]);
385           if (dimAndMaskAbsent && rank == 1) {
386             mlir::Location loc = call.getLoc();
387             mlir::Type type;
388             fir::FirOpBuilder builder(op, kindMap);
389             if (funcName.endswith("Integer4")) {
390               type = mlir::IntegerType::get(builder.getContext(), 32);
391             } else if (funcName.endswith("Real8")) {
392               type = mlir::FloatType::getF64(builder.getContext());
393             } else {
394               return;
395             }
396             auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
397               return genFortranASumType(builder, type);
398             };
399             mlir::func::FuncOp newFunc = getOrCreateFunction(
400                 builder, funcName, typeGenerator, genFortranASumBody);
401             auto newCall = builder.create<fir::CallOp>(
402                 loc, newFunc, mlir::ValueRange{args[0]});
403             call->replaceAllUsesWith(newCall.getResults());
404             call->dropAllReferences();
405             call->erase();
406           }
407
408           return;
409         }
410         if (funcName.startswith("_FortranADotProduct")) {
411           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
412           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
413                      llvm::dbgs() << "\n");
414           mlir::Operation::operand_range args = call.getArgs();
415           const mlir::Value &v1 = args[0];
416           const mlir::Value &v2 = args[1];
417           mlir::Location loc = call.getLoc();
418           fir::FirOpBuilder builder(op, kindMap);
419           mlir::Type type = call.getResult(0).getType();
420           if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
421             return;
422
423           // Try to find the element types of the boxed arguments.
424           auto arg1Type = getArgElementType(v1);
425           auto arg2Type = getArgElementType(v2);
426
427           if (!arg1Type || !arg2Type)
428             return;
429
430           // Support only floating point and integer arguments
431           // now (e.g. logical is skipped here).
432           if (!arg1Type->isa<mlir::FloatType>() &&
433               !arg1Type->isa<mlir::IntegerType>())
434             return;
435           if (!arg2Type->isa<mlir::FloatType>() &&
436               !arg2Type->isa<mlir::IntegerType>())
437             return;
438
439           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
440             return genFortranADotType(builder, type);
441           };
442           auto bodyGenerator = [&arg1Type,
443                                 &arg2Type](fir::FirOpBuilder &builder,
444                                            mlir::func::FuncOp &funcOp) {
445             genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
446           };
447
448           // Suffix the function name with the element types
449           // of the arguments.
450           std::string typedFuncName(funcName);
451           llvm::raw_string_ostream nameOS(typedFuncName);
452           nameOS << "_";
453           arg1Type->print(nameOS);
454           nameOS << "_";
455           arg2Type->print(nameOS);
456
457           mlir::func::FuncOp newFunc = getOrCreateFunction(
458               builder, typedFuncName, typeGenerator, bodyGenerator);
459           auto newCall = builder.create<fir::CallOp>(loc, newFunc,
460                                                      mlir::ValueRange{v1, v2});
461           call->replaceAllUsesWith(newCall.getResults());
462           call->dropAllReferences();
463           call->erase();
464
465           LLVM_DEBUG(llvm::dbgs() << "Replaced with:\n"; newCall.dump();
466                      llvm::dbgs() << "\n");
467           return;
468         }
469       }
470     }
471   });
472   LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
473 }
474
475 void SimplifyIntrinsicsPass::getDependentDialects(
476     mlir::DialectRegistry &registry) const {
477   // LLVM::LinkageAttr creation requires that LLVM dialect is loaded.
478   registry.insert<mlir::LLVM::LLVMDialect>();
479 }
480 std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
481   return std::make_unique<SimplifyIntrinsicsPass>();
482 }