From 9ab1efc77ab1ba104c6dce48e5d3e34d1174bb5d Mon Sep 17 00:00:00 2001 From: peter klausler Date: Thu, 26 Aug 2021 16:01:03 -0700 Subject: [PATCH] [flang] Fold UNPACK and TRANSPOSE Implement constant folding for the transformational intrinsic functions UNPACK and TRANSPOSE. Differential Revision: https://reviews.llvm.org/D109010 --- flang/lib/Evaluate/fold-implementation.h | 80 +++++++++++++++++++++++++++++++- flang/test/Evaluate/folding19.f90 | 8 +++- flang/test/Evaluate/folding25.f90 | 10 ++++ flang/test/Evaluate/folding26.f90 | 7 +++ 4 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 flang/test/Evaluate/folding25.f90 create mode 100644 flang/test/Evaluate/folding26.f90 diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h index 134222d..f68e2ea 100644 --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -65,6 +65,8 @@ public: Expr EOSHIFT(FunctionRef &&); Expr PACK(FunctionRef &&); Expr RESHAPE(FunctionRef &&); + Expr TRANSPOSE(FunctionRef &&); + Expr UNPACK(FunctionRef &&); private: FoldingContext &context_; @@ -853,6 +855,78 @@ template Expr Folder::RESHAPE(FunctionRef &&funcRef) { return MakeInvalidIntrinsic(std::move(funcRef)); } +template Expr Folder::TRANSPOSE(FunctionRef &&funcRef) { + auto args{funcRef.arguments()}; + CHECK(args.size() == 1); + const auto *matrix{UnwrapConstantValue(args[0])}; + if (!matrix) { + return Expr{std::move(funcRef)}; + } + // Argument is constant. Traverse its elements in transposed order. + std::vector> resultElements; + ConstantSubscripts at(2); + for (ConstantSubscript j{0}; j < matrix->shape()[0]; ++j) { + at[0] = matrix->lbounds()[0] + j; + for (ConstantSubscript k{0}; k < matrix->shape()[1]; ++k) { + at[1] = matrix->lbounds()[1] + k; + resultElements.push_back(matrix->At(at)); + } + } + at = matrix->shape(); + std::swap(at[0], at[1]); + return Expr{PackageConstant(std::move(resultElements), *matrix, at)}; +} + +template Expr Folder::UNPACK(FunctionRef &&funcRef) { + auto args{funcRef.arguments()}; + CHECK(args.size() == 3); + const auto *vector{UnwrapConstantValue(args[0])}; + auto convertedMask{Fold(context_, + ConvertToType( + Expr{DEREF(UnwrapExpr>(args[1]))}))}; + const auto *mask{UnwrapConstantValue(convertedMask)}; + const auto *field{UnwrapConstantValue(args[2])}; + if (!vector || !mask || !field) { + return Expr{std::move(funcRef)}; + } + // Arguments are constant. + if (field->Rank() > 0 && field->shape() != mask->shape()) { + // Error already emitted from intrinsic processing + return MakeInvalidIntrinsic(std::move(funcRef)); + } + ConstantSubscript maskElements{GetSize(mask->shape())}; + ConstantSubscript truths{0}; + ConstantSubscripts maskAt{mask->lbounds()}; + for (ConstantSubscript j{0}; j < maskElements; + ++j, mask->IncrementSubscripts(maskAt)) { + if (mask->At(maskAt).IsTrue()) { + ++truths; + } + } + if (truths > GetSize(vector->shape())) { + context_.messages().Say( + "Invalid 'vector=' argument in UNPACK: the 'mask=' argument has %jd true elements, but the vector has only %jd elements"_err_en_US, + static_cast(truths), + static_cast(GetSize(vector->shape()))); + return MakeInvalidIntrinsic(std::move(funcRef)); + } + std::vector> resultElements; + ConstantSubscripts vectorAt{vector->lbounds()}; + ConstantSubscripts fieldAt{field->lbounds()}; + for (ConstantSubscript j{0}; j < maskElements; ++j) { + if (mask->At(maskAt).IsTrue()) { + resultElements.push_back(vector->At(vectorAt)); + vector->IncrementSubscripts(vectorAt); + } else { + resultElements.push_back(field->At(fieldAt)); + } + mask->IncrementSubscripts(maskAt); + field->IncrementSubscripts(fieldAt); + } + return Expr{ + PackageConstant(std::move(resultElements), *vector, mask->shape())}; +} + template Expr FoldMINorMAX( FoldingContext &context, FunctionRef &&funcRef, Ordering order) { @@ -943,8 +1017,12 @@ Expr FoldOperation(FoldingContext &context, FunctionRef &&funcRef) { return Folder{context}.PACK(std::move(funcRef)); } else if (name == "reshape") { return Folder{context}.RESHAPE(std::move(funcRef)); + } else if (name == "transpose") { + return Folder{context}.TRANSPOSE(std::move(funcRef)); + } else if (name == "unpack") { + return Folder{context}.UNPACK(std::move(funcRef)); } - // TODO: spread, unpack, transpose + // TODO: spread // TODO: extends_type_of, same_type_as if constexpr (!std::is_same_v) { return FoldIntrinsicFunction(context, std::move(funcRef)); diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/folding19.f90 index 5940f25..8cfaeb1 100644 --- a/flang/test/Evaluate/folding19.f90 +++ b/flang/test/Evaluate/folding19.f90 @@ -43,5 +43,11 @@ module m !CHECK: error: Invalid 'vector=' argument in PACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements x = pack(array, mask, [0,0]) end subroutine + subroutine s5 + logical, parameter :: mask(2,3) = reshape([.false., .true., .true., .false., .false., .true.], shape(mask)) + integer, parameter :: field(3,2) = reshape([(-j,j=1,6)], shape(field)) + integer :: x(2,3) + !CHECK: error: Invalid 'vector=' argument in UNPACK: the 'mask=' argument has 3 true elements, but the vector has only 2 elements + x = unpack([1,2], mask, 0) + end subroutine end module - diff --git a/flang/test/Evaluate/folding25.f90 b/flang/test/Evaluate/folding25.f90 new file mode 100644 index 0000000..a94565f --- /dev/null +++ b/flang/test/Evaluate/folding25.f90 @@ -0,0 +1,10 @@ +! RUN: %S/test_folding.sh %s %t %flang_fc1 +! REQUIRES: shell +! Tests folding of UNPACK (valid cases) +module m + integer, parameter :: vector(*) = [1, 2, 3, 4] + integer, parameter :: field(2,3) = reshape([(-j,j=1,6)], shape(field)) + logical, parameter :: mask(*,*) = reshape([.false., .true., .true., .false., .false., .true.], shape(field)) + logical, parameter :: test_unpack_1 = all(unpack(vector, mask, 0) == reshape([0,1,2,0,0,3], shape(mask))) + logical, parameter :: test_unpack_2 = all(unpack(vector, mask, field) == reshape([-1,1,2,-4,-5,3], shape(mask))) +end module diff --git a/flang/test/Evaluate/folding26.f90 b/flang/test/Evaluate/folding26.f90 new file mode 100644 index 0000000..09fd08a --- /dev/null +++ b/flang/test/Evaluate/folding26.f90 @@ -0,0 +1,7 @@ +! RUN: %S/test_folding.sh %s %t %flang_fc1 +! REQUIRES: shell +! Tests folding of TRANSPOSE +module m + integer, parameter :: matrix(0:1,0:2) = reshape([1,2,3,4,5,6],shape(matrix)) + logical, parameter :: test_transpose_1 = all(transpose(matrix) == reshape([1,3,5,2,4,6],[3,2])) +end module -- 2.7.4