[flang] Lower F08 NORM2 intrinsic
authorTarun Prabhu <tarun@lanl.gov>
Mon, 5 Dec 2022 20:50:33 +0000 (13:50 -0700)
committerTarun Prabhu <tarun@lanl.gov>
Mon, 5 Dec 2022 20:53:35 +0000 (13:53 -0700)
The implementation follows the pattern used in comparable intrinsics.
Change the runtime API for Norm2 so it does not expect a mask argument
since the Norm2 intrinsic does not accept a mask in Fortran.

Differential Revision: https://reviews.llvm.org/D138150

flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
flang/include/flang/Runtime/reduction.h
flang/lib/Lower/IntrinsicCall.cpp
flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
flang/runtime/extrema.cpp
flang/test/Lower/Intrinsics/norm2.f90 [new file with mode: 0644]

index 8faf568..1e848e3 100644 (file)
@@ -148,6 +148,16 @@ void genMinvalDim(fir::FirOpBuilder &builder, mlir::Location loc,
                   mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim,
                   mlir::Value maskBox);
 
+/// Generate call to `Norm2` intrinsic runtime routine. This is the version
+/// that does not take a dim argument.
+mlir::Value genNorm2(fir::FirOpBuilder &builder, mlir::Location loc,
+                     mlir::Value arrayBox);
+
+/// Generate call to `Norm2Dim` intrinsic runtime routine. This is the version
+/// that takes a dim argument.
+void genNorm2Dim(fir::FirOpBuilder &builder, mlir::Location loc,
+                 mlir::Value resultBox, mlir::Value arrayBox, mlir::Value dim);
+
 /// Generate call to `Parity` runtime routine. This version of `parity` is
 /// specialized for rank 1 mask arguments.
 /// This calls the version that returns a scalar logical value.
index 43da1d3..82c67ca 100644 (file)
@@ -347,23 +347,23 @@ void RTNAME(MinvalDim)(Descriptor &, const Descriptor &, int dim,
     const char *source, int line, const Descriptor *mask = nullptr);
 
 // NORM2
-float RTNAME(Norm2_2)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
-float RTNAME(Norm2_3)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
-float RTNAME(Norm2_4)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
-double RTNAME(Norm2_8)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
+float RTNAME(Norm2_2)(
+    const Descriptor &, const char *source, int line, int dim = 0);
+float RTNAME(Norm2_3)(
+    const Descriptor &, const char *source, int line, int dim = 0);
+float RTNAME(Norm2_4)(
+    const Descriptor &, const char *source, int line, int dim = 0);
+double RTNAME(Norm2_8)(
+    const Descriptor &, const char *source, int line, int dim = 0);
 #if LDBL_MANT_DIG == 64
-long double RTNAME(Norm2_10)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
+long double RTNAME(Norm2_10)(
+    const Descriptor &, const char *source, int line, int dim = 0);
 #elif LDBL_MANT_DIG == 113
-long double RTNAME(Norm2_16)(const Descriptor &, const char *source, int line,
-    int dim = 0, const Descriptor *mask = nullptr);
+long double RTNAME(Norm2_16)(
+    const Descriptor &, const char *source, int line, int dim = 0);
 #endif
-void RTNAME(Norm2Dim)(Descriptor &, const Descriptor &, int dim,
-    const char *source, int line, const Descriptor *mask = nullptr);
+void RTNAME(Norm2Dim)(
+    Descriptor &, const Descriptor &, int dim, const char *source, int line);
 
 // ALL, ANY, COUNT, & PARITY logical reductions
 bool RTNAME(All)(const Descriptor &, const char *source, int line, int dim = 0);
index 76ee6a9..bda5806 100644 (file)
@@ -546,6 +546,7 @@ struct IntrinsicLibrary {
   void genMvbits(llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genNearest(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genNint(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  fir::ExtendedValue genNorm2(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genNot(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genNull(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genPack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
@@ -940,6 +941,10 @@ static constexpr IntrinsicHandler handlers[]{
        {"topos", asValue}}}},
     {"nearest", &I::genNearest},
     {"nint", &I::genNint},
+    {"norm2",
+     &I::genNorm2,
+     {{{"array", asBox}, {"dim", asValue}}},
+     /*isElemental=*/false},
     {"not", &I::genNot},
     {"null", &I::genNull, {{{"mold", asInquired}}}, /*isElemental=*/false},
     {"pack",
@@ -4100,6 +4105,50 @@ mlir::Value IntrinsicLibrary::genNint(mlir::Type resultType,
   return genRuntimeCall("nint", resultType, {args[0]});
 }
 
+// NORM2
+fir::ExtendedValue
+IntrinsicLibrary::genNorm2(mlir::Type resultType,
+                           llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 2);
+
+  // Handle required array argument
+  mlir::Value array = builder.createBox(loc, args[0]);
+  unsigned rank = fir::BoxValue(array).rank();
+  assert(rank >= 1);
+
+  // Check if the dim argument is present
+  bool absentDim = isStaticallyAbsent(args[1]);
+
+  // If dim argument is absent or the array is rank 1, then the result is
+  // a scalar (since the the result is rank-1 or 0). Otherwise, the result is
+  // an array.
+  if (absentDim || rank == 1) {
+    return fir::runtime::genNorm2(builder, loc, array);
+  } else {
+    // Create mutable fir.box to be passed to the runtime for the result.
+    mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
+    fir::MutableBoxValue resultMutableBox =
+        fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+    mlir::Value resultIrBox =
+        fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+
+    mlir::Value dim = fir::getBase(args[1]);
+    fir::runtime::genNorm2Dim(builder, loc, resultIrBox, array, dim);
+
+    // Handle cleanup of allocatable result descriptor and return
+    fir::ExtendedValue res =
+        fir::factory::genMutableBoxRead(builder, loc, resultMutableBox);
+    return res.match(
+        [&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue {
+          addCleanUpForTemp(loc, box.getAddr());
+          return box;
+        },
+        [&](const auto &) -> fir::ExtendedValue {
+          fir::emitFatalError(loc, "unexpected result for Norm2");
+        });
+  }
+}
+
 // NOT
 mlir::Value IntrinsicLibrary::genNot(mlir::Type resultType,
                                      llvm::ArrayRef<mlir::Value> args) {
index b1c7bae..9d47d65 100644 (file)
@@ -119,6 +119,36 @@ struct ForcedMinvalInteger16 {
   }
 };
 
+/// Placeholder for real*10 version of Norm2 Intrinsic
+struct ForcedNorm2Real10 {
+  static constexpr const char *name = ExpandAndQuoteKey(RTNAME(Norm2_10));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::FloatType::getF80(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      return mlir::FunctionType::get(ctx, {boxTy, strTy, intTy, intTy}, {ty});
+    };
+  }
+};
+
+/// Placeholder for real*16 version of Norm2 Intrinsic
+struct ForcedNorm2Real16 {
+  static constexpr const char *name = ExpandAndQuoteKey(RTNAME(Norm2_16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto ty = mlir::FloatType::getF128(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      return mlir::FunctionType::get(ctx, {boxTy, strTy, intTy, intTy}, {ty});
+    };
+  }
+};
+
 /// Placeholder for real*10 version of Product Intrinsic
 struct ForcedProductReal10 {
   static constexpr const char *name = ExpandAndQuoteKey(RTNAME(ProductReal10));
@@ -849,6 +879,55 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
   return builder.create<fir::CallOp>(loc, func, args).getResult(0);
 }
 
+/// Generate call to `Norm2Dim` intrinsic runtime routine. This is the version
+/// that takes a dim argument.
+void fir::runtime::genNorm2Dim(fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Value resultBox, mlir::Value arrayBox,
+                               mlir::Value dim) {
+  auto func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2Dim)>(loc, builder);
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
+  auto args = fir::runtime::createArguments(
+      builder, loc, fTy, resultBox, arrayBox, dim, sourceFile, sourceLine);
+
+  builder.create<fir::CallOp>(loc, func, args);
+}
+
+/// Generate call to `Norm2` intrinsic runtime routine. This is the version
+/// that does not take a dim argument.
+mlir::Value fir::runtime::genNorm2(fir::FirOpBuilder &builder,
+                                   mlir::Location loc, mlir::Value arrayBox) {
+  mlir::func::FuncOp func;
+  auto ty = arrayBox.getType();
+  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
+  auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
+
+  if (eleTy.isF16() || eleTy.isBF16())
+    TODO(loc, "half-precision NORM2");
+  else if (eleTy.isF32())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2_4)>(loc, builder);
+  else if (eleTy.isF64())
+    func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2_8)>(loc, builder);
+  else if (eleTy.isF80())
+    func = fir::runtime::getRuntimeFunc<ForcedNorm2Real10>(loc, builder);
+  else if (eleTy.isF128())
+    func = fir::runtime::getRuntimeFunc<ForcedNorm2Real16>(loc, builder);
+  else
+    fir::emitFatalError(loc, "invalid type in NORM2");
+
+  auto fTy = func.getFunctionType();
+  auto sourceFile = fir::factory::locationToFilename(builder, loc);
+  auto sourceLine =
+      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+  auto args = fir::runtime::createArguments(builder, loc, fTy, arrayBox,
+                                            sourceFile, sourceLine, dim);
+
+  return builder.create<fir::CallOp>(loc, func, args).getResult(0);
+}
+
 /// Generate call to `Parity` intrinsic runtime routine. This routine is
 /// specialized for mask arguments with rank == 1.
 mlir::Value fir::runtime::genParity(fir::FirOpBuilder &builder,
index 00c02ee..c9dcc65 100644 (file)
@@ -833,39 +833,39 @@ template <int KIND> struct Norm2Helper {
 
 extern "C" {
 // TODO: REAL(2 & 3)
-CppTypeFor<TypeCategory::Real, 4> RTNAME(Norm2_4)(const Descriptor &x,
-    const char *source, int line, int dim, const Descriptor *mask) {
+CppTypeFor<TypeCategory::Real, 4> RTNAME(Norm2_4)(
+    const Descriptor &x, const char *source, int line, int dim) {
   return GetTotalReduction<TypeCategory::Real, 4>(
-      x, source, line, dim, mask, Norm2Accumulator<4>{x}, "NORM2");
+      x, source, line, dim, nullptr, Norm2Accumulator<4>{x}, "NORM2");
 }
-CppTypeFor<TypeCategory::Real, 8> RTNAME(Norm2_8)(const Descriptor &x,
-    const char *source, int line, int dim, const Descriptor *mask) {
+CppTypeFor<TypeCategory::Real, 8> RTNAME(Norm2_8)(
+    const Descriptor &x, const char *source, int line, int dim) {
   return GetTotalReduction<TypeCategory::Real, 8>(
-      x, source, line, dim, mask, Norm2Accumulator<8>{x}, "NORM2");
+      x, source, line, dim, nullptr, Norm2Accumulator<8>{x}, "NORM2");
 }
 #if LDBL_MANT_DIG == 64
-CppTypeFor<TypeCategory::Real, 10> RTNAME(Norm2_10)(const Descriptor &x,
-    const char *source, int line, int dim, const Descriptor *mask) {
+CppTypeFor<TypeCategory::Real, 10> RTNAME(Norm2_10)(
+    const Descriptor &x, const char *source, int line, int dim) {
   return GetTotalReduction<TypeCategory::Real, 10>(
-      x, source, line, dim, mask, Norm2Accumulator<10>{x}, "NORM2");
+      x, source, line, dim, nullptr, Norm2Accumulator<10>{x}, "NORM2");
 }
 #endif
 #if LDBL_MANT_DIG == 113
-CppTypeFor<TypeCategory::Real, 16> RTNAME(Norm2_16)(const Descriptor &x,
-    const char *source, int line, int dim, const Descriptor *mask) {
+CppTypeFor<TypeCategory::Real, 16> RTNAME(Norm2_16)(
+    const Descriptor &x, const char *source, int line, int dim) {
   return GetTotalReduction<TypeCategory::Real, 16>(
-      x, source, line, dim, mask, Norm2Accumulator<16>{x}, "NORM2");
+      x, source, line, dim, nullptr, Norm2Accumulator<16>{x}, "NORM2");
 }
 #endif
 
 void RTNAME(Norm2Dim)(Descriptor &result, const Descriptor &x, int dim,
-    const char *source, int line, const Descriptor *mask) {
+    const char *source, int line) {
   Terminator terminator{source, line};
   auto type{x.type().GetCategoryAndKind()};
   RUNTIME_CHECK(terminator, type);
   if (type->first == TypeCategory::Real) {
     ApplyFloatingPointKind<Norm2Helper, void>(
-        type->second, terminator, result, x, dim, mask, terminator);
+        type->second, terminator, result, x, dim, nullptr, terminator);
   } else {
     terminator.Crash("NORM2: bad type code %d", x.type().raw());
   }
diff --git a/flang/test/Lower/Intrinsics/norm2.f90 b/flang/test/Lower/Intrinsics/norm2.f90
new file mode 100644 (file)
index 0000000..01c5ad5
--- /dev/null
@@ -0,0 +1,78 @@
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+
+! CHECK-LABEL: func @_QPnorm2_test_4(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?xf32>>{{.*}}) -> f32
+real(4) function norm2_test_4(a)
+  real(4) :: a(:)
+  ! CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
+  ! CHECK-DAG:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  ! CHECK:  %[[dim:.*]] = fir.convert %[[c0]] : (index) -> i32
+  norm2_test_4 = norm2(a)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2_4(%[[arr]], %{{.*}}, %{{.*}}, %[[dim]]) {{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> f32
+end function norm2_test_4
+
+! CHECK-LABEL: func @_QPnorm2_test_8(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?xf64>>{{.*}}) -> f64
+real(8) function norm2_test_8(a)
+  real(8) :: a(:,:)
+  ! CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
+  ! CHECK-DAG:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?xf64>>) -> !fir.box<none>
+  ! CHECK:  %[[dim:.*]] = fir.convert %[[c0]] : (index) -> i32
+  norm2_test_8 = norm2(a)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2_8(%[[arr]], %{{.*}}, %{{.*}}, %[[dim]]) {{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> f64
+end function norm2_test_8
+
+! CHECK-LABEL: func @_QPnorm2_test_10(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?x?xf80>>{{.*}}) -> f80
+real(10) function norm2_test_10(a)
+  real(10) :: a(:,:,:)
+  ! CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
+  ! CHECK-DAG:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?x?xf80>>) -> !fir.box<none>
+  ! CHECK:  %[[dim:.*]] = fir.convert %[[c0]] : (index) -> i32
+  norm2_test_10 = norm2(a)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2_10(%[[arr]], %{{.*}}, %{{.*}}, %[[dim]]) {{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> f80
+end function norm2_test_10
+
+! CHECK-LABEL: func @_QPnorm2_test_16(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?x?xf128>>{{.*}}) -> f128
+real(16) function norm2_test_16(a)
+  real(16) :: a(:,:,:)
+  ! CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
+  ! CHECK-DAG:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?x?xf128>>) -> !fir.box<none>
+  ! CHECK:  %[[dim:.*]] = fir.convert %[[c0]] : (index) -> i32
+  norm2_test_16 = norm2(a)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2_16(%[[arr]], %{{.*}}, %{{.*}}, %[[dim]]) {{.*}} : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> f128
+end function norm2_test_16
+
+! CHECK-LABEL: func @_QPnorm2_test_dim_2(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?xf32>>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?xf32>>{{.*}})
+subroutine norm2_test_dim_2(a,r)
+  real :: a(:,:)
+  real :: r(:)
+  ! CHECK-DAG:  %[[dim:.*]] = arith.constant 1 : i32
+  ! CHECK-DAG:  %[[r:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>>
+  ! CHECK-DAG:  %[[res:.*]] = fir.convert %[[r]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  ! CHECK:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?xf32>>) -> !fir.box<none>
+  r = norm2(a,dim=1)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2Dim(%[[res]], %[[arr]], %[[dim]], %{{.*}}, %{{.*}}) {{.*}} : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+  ! CHECK:  %[[box:.*]] = fir.load %[[r]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+  ! CHECK-DAG:  %[[addr:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
+  ! CHECK-DAG:  fir.freemem %[[addr]]
+end subroutine norm2_test_dim_2
+
+! CHECK-LABEL: func @_QPnorm2_test_dim_3(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?x?xf32>>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x?xf32>>{{.*}})
+subroutine norm2_test_dim_3(a,r)
+  real :: a(:,:,:)
+  real :: r(:,:)
+  ! CHECK-DAG:  %[[dim:.*]] = arith.constant 3 : i32
+  ! CHECK-DAG:  %[[r:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf32>>>
+  ! CHECK-DAG:  %[[res:.*]] = fir.convert %[[r]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  ! CHECK:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?x?xf32>>) -> !fir.box<none>
+  r = norm2(a,dim=3)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2Dim(%[[res]], %[[arr]], %[[dim]], %{{.*}}, %{{.*}}) {{.*}} : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+  ! CHECK:  %[[box:.*]] = fir.load %[[r]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+  ! CHECK-DAG:  %[[addr:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
+  ! CHECK-DAG:  fir.freemem %[[addr]]
+end subroutine norm2_test_dim_3