[flang] Add cpowi function to runtime and use instead of pgmath
[platform/upstream/llvm.git] / flang / lib / Optimizer / Transforms / SimplifyIntrinsics.cpp
index 9a9e066..cdc3ab9 100644 (file)
@@ -22,7 +22,6 @@
 /// and small in size.
 //===----------------------------------------------------------------------===//
 
-#include "PassDetail.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -30,6 +29,8 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Support/FIRContext.h"
 #include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
+namespace fir {
+#define GEN_PASS_DEF_SIMPLIFYINTRINSICS
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
 #define DEBUG_TYPE "flang-simplify-intrinsics"
 
 namespace {
 
 class SimplifyIntrinsicsPass
-    : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+    : public fir::impl::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
   using FunctionTypeGeneratorTy =
-      std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
   using FunctionBodyGeneratorTy =
-      std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+  using GenReductionBodyTy = llvm::function_ref<void(
+      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank)>;
 
 public:
   /// Generate a new function implementing a simplified version
@@ -63,11 +71,21 @@ public:
                                          FunctionBodyGeneratorTy bodyGenerator);
   void runOnOperation() override;
   void getDependentDialects(mlir::DialectRegistry &registry) const override;
+
+private:
+  /// Helper function to replace a reduction type of call with its
+  /// simplified form. The actual function is generated using a callback
+  /// function.
+  /// \p call is the call to be replaced
+  /// \p kindMap is used to create FIROpBuilder
+  /// \p genBodyFunc is the callback that builds the replacement function
+  void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
+                         GenReductionBodyTy genBodyFunc);
 };
 
 } // namespace
 
-/// Generate function type for the simplified version of FortranASum and
+/// Generate function type for the simplified version of RTNAME(Sum) and
 /// similar functions with a fir.box<none> type returning \p elementType.
 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
                                          const mlir::Type &elementType) {
@@ -76,10 +94,10 @@ static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
                                  {elementType});
 }
 
-using BodyOpGeneratorTy =
-    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
-                              const mlir::Type &, mlir::Value, mlir::Value)>;
-using InitValGeneratorTy = std::function<mlir::Value(
+using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+    mlir::Value)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
 
 /// Generate the reduction loop into \p funcOp.
@@ -88,10 +106,11 @@ using InitValGeneratorTy = std::function<mlir::Value(
 ///    the reduction value
 /// \p genBody is called to fill in the actual reduciton operation
 ///    for example add for SUM, MAX for MAXVAL, etc.
+/// \p rank is the rank of the input argument.
 static void genReductionLoop(fir::FirOpBuilder &builder,
                              mlir::func::FuncOp &funcOp,
                              InitValGeneratorTy initVal,
-                             BodyOpGeneratorTy genBody) {
+                             BodyOpGeneratorTy genBody, unsigned rank) {
   auto loc = mlir::UnknownLoc::get(builder.getContext());
   mlir::Type elementType = funcOp.getResultTypes()[0];
   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
@@ -103,65 +122,106 @@ static void genReductionLoop(fir::FirOpBuilder &builder,
 
   mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
 
-  fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+  fir::SequenceType::Shape flatShape(rank,
+                                     fir::SequenceType::getUnknownExtent());
   mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
   mlir::Type boxArrTy = fir::BoxType::get(arrTy);
   mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
-  auto dims =
-      builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
-  mlir::Value len = dims.getResult(1);
+  mlir::Value init = initVal(builder, loc, elementType);
+
+  llvm::SmallVector<mlir::Value, 15> bounds;
+
+  assert(rank > 0 && "rank cannot be zero");
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  mlir::Value step = one;
 
-  // We use C indexing here, so len-1 as loopcount
-  mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
-  mlir::Value init = initVal(builder, loc, elementType);
-  auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
-                                            /*unordered=*/false,
-                                            /*finalCountValue=*/false, init);
-  mlir::Value reductionVal = loop.getRegionIterArgs()[0];
+  // Compute all the upper bounds before the loop nest.
+  // It is not strictly necessary for performance, since the loop nest
+  // does not have any store operations and any LICM optimization
+  // should be able to optimize the redundancy.
+  for (unsigned i = 0; i < rank; ++i) {
+    mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
+    auto dims =
+        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
+    mlir::Value len = dims.getResult(1);
+    // We use C indexing here, so len-1 as loopcount
+    mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+    bounds.push_back(loopCount);
+  }
 
-  // Begin loop code
-  mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(loop.getBody());
+  // Create a loop nest consisting of DoLoopOp operations.
+  // Collect the loops' induction variables into indices array,
+  // which will be used in the innermost loop to load the input
+  // array's element.
+  // The loops are generated such that the innermost loop processes
+  // the 0 dimension.
+  llvm::SmallVector<mlir::Value, 15> indices;
+  for (unsigned i = rank; 0 < i; --i) {
+    mlir::Value step = one;
+    mlir::Value loopCount = bounds[i - 1];
+    auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
+                                              /*unordered=*/false,
+                                              /*finalCountValue=*/false, init);
+    init = loop.getRegionIterArgs()[0];
+    indices.push_back(loop.getInductionVar());
+    // Set insertion point to the loop body so that the next loop
+    // is inserted inside the current one.
+    builder.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Reverse the indices such that they are ordered as:
+  //   <dim-0-idx, dim-1-idx, ...>
+  std::reverse(indices.begin(), indices.end());
 
+  // We are in the innermost loop: generate the reduction body.
   mlir::Type eleRefTy = builder.getRefType(elementType);
-  mlir::Value index = loop.getInductionVar();
   mlir::Value addr =
-      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
   mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
 
-  reductionVal = genBody(builder, loc, elementType, elem, reductionVal);
-
-  builder.create<fir::ResultOp>(loc, reductionVal);
-  // End of loop.
-  builder.restoreInsertionPoint(loopEndPt);
+  mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
+
+  // Unwind the loop nest and insert ResultOp on each level
+  // to return the updated value of the reduction to the enclosing
+  // loops.
+  for (unsigned i = 0; i < rank; ++i) {
+    auto result = builder.create<fir::ResultOp>(loc, reductionVal);
+    // Proceed to the outer loop.
+    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+    reductionVal = loop.getResult(0);
+    // Set insertion point after the loop operation that we have
+    // just processed.
+    builder.setInsertionPointAfter(loop.getOperation());
+  }
 
-  mlir::Value resultVal = loop.getResult(0);
-  builder.create<mlir::func::ReturnOp>(loc, resultVal);
+  // End of loop nest. The insertion point is after the outermost loop.
+  // Return the reduction value from the function.
+  builder.create<mlir::func::ReturnOp>(loc, reductionVal);
 }
 
-/// Generate function body of the simplified version of FortranASum
+/// Generate function body of the simplified version of RTNAME(Sum)
 /// with signature provided by \p funcOp. The caller is responsible
 /// for saving/restoring the original insertion point of \p builder.
 /// \p funcOp is expected to be empty on entry to this function.
-static void genFortranASumBody(fir::FirOpBuilder &builder,
-                               mlir::func::FuncOp &funcOp) {
-  // function FortranASum<T>_simplified(arr)
+/// \p rank specifies the rank of the input argument.
+static void genRuntimeSumBody(fir::FirOpBuilder &builder,
+                              mlir::func::FuncOp &funcOp, unsigned rank) {
+  // function RTNAME(Sum)<T>x<rank>_simplified(arr)
   //   T, dimension(:) :: arr
   //   T sum = 0
   //   integer iter
   //   do iter = 0, extent(arr)
   //     sum = sum + arr[iter]
   //   end do
-  //   FortranASum<T>_simplified = sum
-  // end function FortranASum<T>_simplified
+  //   RTNAME(Sum)<T>x<rank>_simplified = sum
+  // end function RTNAME(Sum)<T>x<rank>_simplified
   auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
                  mlir::Type elementType) {
-    return elementType.isa<mlir::FloatType>()
-               ? builder.createRealConstant(loc, elementType,
-                                            llvm::APFloat(0.0))
-               : builder.createIntegerConstant(loc, elementType, 0);
+    if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(loc, elementType,
+                                        llvm::APFloat::getZero(sem));
+    }
+    return builder.createIntegerConstant(loc, elementType, 0);
   };
 
   auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
@@ -176,11 +236,11 @@ static void genFortranASumBody(fir::FirOpBuilder &builder,
     return {};
   };
 
-  genReductionLoop(builder, funcOp, zero, genBodyOp);
+  genReductionLoop(builder, funcOp, zero, genBodyOp, rank);
 }
 
-static void genFortranAMaxvalBody(fir::FirOpBuilder &builder,
-                                  mlir::func::FuncOp &funcOp) {
+static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder,
+                                 mlir::func::FuncOp &funcOp, unsigned rank) {
   auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
                  mlir::Type elementType) {
     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -204,38 +264,38 @@ static void genFortranAMaxvalBody(fir::FirOpBuilder &builder,
     llvm_unreachable("unsupported type");
     return {};
   };
-  genReductionLoop(builder, funcOp, init, genBodyOp);
+  genReductionLoop(builder, funcOp, init, genBodyOp, rank);
 }
 
-/// Generate function type for the simplified version of FortranADotProduct
+/// Generate function type for the simplified version of RTNAME(DotProduct)
 /// operating on the given \p elementType.
-static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
-                                             const mlir::Type &elementType) {
+static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
+                                            const mlir::Type &elementType) {
   mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
   return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
                                  {elementType});
 }
 
-/// Generate function body of the simplified version of FortranADotProduct
+/// Generate function body of the simplified version of RTNAME(DotProduct)
 /// with signature provided by \p funcOp. The caller is responsible
 /// for saving/restoring the original insertion point of \p builder.
 /// \p funcOp is expected to be empty on entry to this function.
 /// \p arg1ElementTy and \p arg2ElementTy specify elements types
 /// of the underlying array objects - they are used to generate proper
 /// element accesses.
-static void genFortranADotBody(fir::FirOpBuilder &builder,
-                               mlir::func::FuncOp &funcOp,
-                               mlir::Type arg1ElementTy,
-                               mlir::Type arg2ElementTy) {
-  // function FortranADotProduct<T>_simplified(arr1, arr2)
+static void genRuntimeDotBody(fir::FirOpBuilder &builder,
+                              mlir::func::FuncOp &funcOp,
+                              mlir::Type arg1ElementTy,
+                              mlir::Type arg2ElementTy) {
+  // function RTNAME(DotProduct)<T>_simplified(arr1, arr2)
   //   T, dimension(:) :: arr1, arr2
   //   T product = 0
   //   integer iter
   //   do iter = 0, extent(arr1)
   //     product = product + arr1[iter] * arr2[iter]
   //   end do
-  //   FortranADotProduct<T>_simplified = product
-  // end function FortranADotProduct<T>_simplified
+  //   RTNAME(ADotProduct)<T>_simplified = product
+  // end function RTNAME(DotProduct)<T>_simplified
   auto loc = mlir::UnknownLoc::get(builder.getContext());
   mlir::Type resultElementType = funcOp.getResultTypes()[0];
   builder.setInsertionPointToEnd(funcOp.addEntryBlock());
@@ -386,21 +446,31 @@ static bool isZero(mlir::Value val) {
   return false;
 }
 
-static mlir::Value findShape(mlir::Value val) {
+static mlir::Value findBoxDef(mlir::Value val) {
   if (auto op = expectConvertOp(val)) {
     assert(op->getOperands().size() != 0);
     if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(
             op->getOperand(0).getDefiningOp()))
-      return box.getShape();
+      return box.getResult();
+    if (auto box = mlir::dyn_cast_or_null<fir::ReboxOp>(
+            op->getOperand(0).getDefiningOp()))
+      return box.getResult();
   }
   return {};
 }
 
 static unsigned getDimCount(mlir::Value val) {
-  if (mlir::Value shapeVal = findShape(val)) {
-    mlir::Type resType = shapeVal.getDefiningOp()->getResultTypes()[0];
-    return fir::getRankOfShapeType(resType);
-  }
+  // In order to find the dimensions count, we look for EmboxOp/ReboxOp
+  // and take the count from its *result* type. Note that in case
+  // of sliced emboxing the operand and the result of EmboxOp/ReboxOp
+  // have different types.
+  // Actually, we can take the box type from the operand of
+  // the first ConvertOp that has non-opaque box type that we meet
+  // going through the ConvertOp chain.
+  if (mlir::Value emboxVal = findBoxDef(val))
+    if (auto boxTy = emboxVal.getType().dyn_cast<fir::BoxType>())
+      if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
+        return seqTy.getDimension();
   return 0;
 }
 
@@ -427,6 +497,56 @@ static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
   } while (true);
 }
 
+void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
+                                               const fir::KindMapping &kindMap,
+                                               GenReductionBodyTy genBodyFunc) {
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+  mlir::Operation::operand_range args = call.getArgs();
+  // args[1] and args[2] are source filename and line number, ignored.
+  const mlir::Value &dim = args[3];
+  const mlir::Value &mask = args[4];
+  // dim is zero when it is absent, which is an implementation
+  // detail in the runtime library.
+  bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
+  unsigned rank = getDimCount(args[0]);
+  if (dimAndMaskAbsent && rank > 0) {
+    mlir::Location loc = call.getLoc();
+    fir::FirOpBuilder builder(call, kindMap);
+
+    // Support only floating point and integer results now.
+    mlir::Type resultType = call.getResult(0).getType();
+    if (!resultType.isa<mlir::FloatType>() &&
+        !resultType.isa<mlir::IntegerType>())
+      return;
+
+    auto argType = getArgElementType(args[0]);
+    if (!argType)
+      return;
+    assert(*argType == resultType &&
+           "Argument/result types mismatch in reduction");
+
+    auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) {
+      return genNoneBoxType(builder, resultType);
+    };
+    auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder,
+                                               mlir::func::FuncOp &funcOp) {
+      genBodyFunc(builder, funcOp, rank);
+    };
+    // Mangle the function name with the rank value as "x<rank>".
+    std::string funcName =
+        (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
+         mlir::Twine{rank})
+            .str();
+    mlir::func::FuncOp newFunc =
+        getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
+    auto newCall =
+        builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+    call->replaceAllUsesWith(newCall.getResults());
+    call->dropAllReferences();
+    call->erase();
+  }
+}
+
 void SimplifyIntrinsicsPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::ModuleOp module = getOperation();
@@ -444,41 +564,11 @@ void SimplifyIntrinsicsPass::runOnOperation() {
         // RTNAME(Sum<T>)(const Descriptor &x, const char *source, int line,
         //                int dim, const Descriptor *mask)
         //
-        if (funcName.startswith("_FortranASum")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          // args[1] and args[2] are source filename and line number, ignored.
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranASumBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-          }
-
+        if (funcName.startswith(RTNAME_STRING(Sum))) {
+          simplifyReduction(call, kindMap, genRuntimeSumBody);
           return;
         }
-        if (funcName.startswith("_FortranADotProduct")) {
+        if (funcName.startswith(RTNAME_STRING(DotProduct))) {
           LLVM_DEBUG(llvm::dbgs() << "Handling " << funcName << "\n");
           LLVM_DEBUG(llvm::dbgs() << "Call operation:\n"; op->dump();
                      llvm::dbgs() << "\n");
@@ -509,12 +599,12 @@ void SimplifyIntrinsicsPass::runOnOperation() {
             return;
 
           auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-            return genFortranADotType(builder, type);
+            return genRuntimeDotType(builder, type);
           };
           auto bodyGenerator = [&arg1Type,
                                 &arg2Type](fir::FirOpBuilder &builder,
                                            mlir::func::FuncOp &funcOp) {
-            genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
+            genRuntimeDotBody(builder, funcOp, *arg1Type, *arg2Type);
           };
 
           // Suffix the function name with the element types
@@ -538,37 +628,9 @@ void SimplifyIntrinsicsPass::runOnOperation() {
                      llvm::dbgs() << "\n");
           return;
         }
-        if (funcName.startswith("_FortranAMaxval")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranAMaxvalBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-            return;
-          }
+        if (funcName.startswith(RTNAME_STRING(Maxval))) {
+          simplifyReduction(call, kindMap, genRuntimeMaxvalBody);
+          return;
         }
       }
     }