[flang] Generate DOT_PRODUCT runtime call based on the result type.
authorSlava Zakharin <szakharin@nvidia.com>
Wed, 31 Aug 2022 16:55:24 +0000 (09:55 -0700)
committerSlava Zakharin <szakharin@nvidia.com>
Wed, 31 Aug 2022 22:20:12 +0000 (15:20 -0700)
We used to select the runtime function based on the first argument's
type, which was not correct behavior. The selection is done using
the result type now.

Differential Revision: https://reviews.llvm.org/D133032

flang/lib/Lower/IntrinsicCall.cpp
flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
flang/test/Lower/Intrinsics/dot_product.f90
flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp

index a8ebe10..c5a1121 100644 (file)
@@ -231,18 +231,18 @@ genDotProd(FN func, mlir::Type resultType, fir::FirOpBuilder &builder,
   // Handle required vector arguments
   mlir::Value vectorA = fir::getBase(args[0]);
   mlir::Value vectorB = fir::getBase(args[1]);
+  // Result type is used for picking appropriate runtime function.
+  mlir::Type eleTy = resultType;
 
-  mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(vectorA.getType())
-                         .cast<fir::SequenceType>()
-                         .getEleTy();
   if (fir::isa_complex(eleTy)) {
     mlir::Value result = builder.createTemporary(loc, eleTy);
     func(builder, loc, vectorA, vectorB, result);
     return builder.create<fir::LoadOp>(loc, result);
   }
 
-  auto resultBox = builder.create<fir::AbsentOp>(
-      loc, fir::BoxType::get(builder.getI1Type()));
+  // This operation is only used to pass the result type
+  // information to the DotProduct generator.
+  auto resultBox = builder.create<fir::AbsentOp>(loc, fir::BoxType::get(eleTy));
   return func(builder, loc, vectorA, vectorB, resultBox);
 }
 
index 7c6d187..0fa0352 100644 (file)
@@ -799,9 +799,10 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
                                         mlir::Value vectorBBox,
                                         mlir::Value resultBox) {
   mlir::func::FuncOp func;
-  auto ty = vectorABox.getType();
-  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
-  auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
+  // For complex data types, resultBox is !fir.ref<!fir.complex<N>>,
+  // otherwise it is !fir.box<T>.
+  auto ty = resultBox.getType();
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
 
   if (eleTy.isF16() || eleTy.isBF16())
     TODO(loc, "half-precision DOTPRODUCT");
index 3b4c77a..42843dc 100644 (file)
@@ -245,3 +245,46 @@ subroutine dot_prod_logical (x, y, z)
   ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductLogical(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> i1
   z = dot_product(x,y)
 end subroutine
+
+! CHECK-LABEL: dot_product_mixed_int_real
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xi32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?xf32>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?xf32>>
+subroutine dot_product_mixed_int_real(x, y, z)
+  integer, dimension(1:) :: x
+  real, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductReal4(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
+  z = dot_product(x,y)
+end subroutine
+
+! CHECK-LABEL: dot_product_mixed_int_complex
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xi32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+subroutine dot_product_mixed_int_complex(x, y, z)
+  integer, dimension(1:) :: x
+  complex, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4>
+  ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref<!fir.complex<4>>) -> !fir.ref<complex<f32>>
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<none>
+  ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+  z = dot_product(x,y)
+end subroutine
+
+! CHECK-LABEL: dot_product_mixed_real_complex
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xf32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+subroutine dot_product_mixed_real_complex(x, y, z)
+  real, dimension(1:) :: x
+  complex, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4>
+  ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref<!fir.complex<4>>) -> !fir.ref<complex<f32>>
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<none>
+  ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+  z = dot_product(x,y)
+end subroutine
index daa2082..fea28d0 100644 (file)
@@ -202,7 +202,8 @@ void testGenDotProduct(
   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
   mlir::Value b = builder.create<fir::UndefOp>(loc, refSeqTy);
-  mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
+  mlir::Value result =
+      builder.create<fir::UndefOp>(loc, fir::ReferenceType::get(eleTy));
   mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result);
   if (fir::isa_complex(eleTy))
     checkCallOpFromResultBox(result, fctName, 3);