From 683a6e1c9e5396f64086c07bec334a38acd0ec7a Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Fri, 17 Mar 2023 15:49:22 +0000 Subject: [PATCH] [flang][hlfir] lower hlfir.shape_of If possible the shape is gotten from the bufferization of the expr argument. The simple cases should already have been resolved during lowering. This is mostly intended for cases where shape information is added in between lowering and the end of bufferization (for example transformational intrinsics with assumed shape arguments). Depends on: D146832 Differential Revision: https://reviews.llvm.org/D146833 --- .../Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp | 45 +++++++++++++++--- flang/test/HLFIR/shapeof-lowering.fir | 55 ++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 flang/test/HLFIR/shapeof-lowering.fir diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 4b631b2..21fe2d9 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -27,8 +27,9 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include +#include "llvm/ADT/TypeSwitch.h" namespace hlfir { #define GEN_PASS_DEF_BUFFERIZEHLFIR @@ -169,6 +170,38 @@ struct AsExprOpConversion : public mlir::OpConversionPattern { } }; +struct ShapeOfOpConversion + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(hlfir::ShapeOfOp shapeOf, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = shapeOf.getLoc(); + mlir::ModuleOp mod = shapeOf->getParentOfType(); + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + + mlir::Value shape; + hlfir::Entity bufferizedExpr{getBufferizedExprStorage(adaptor.getExpr())}; + if (bufferizedExpr.isVariable()) { + shape = hlfir::genShape(loc, builder, bufferizedExpr); + } else { + // everything else failed so try to create a shape from static type info + hlfir::ExprType exprTy = + adaptor.getExpr().getType().dyn_cast_or_null(); + if (exprTy) + shape = hlfir::genExprShape(builder, loc, exprTy); + } + // expected to never happen + if (!shape) + return emitError(loc, + "Unresolvable hlfir.shape_of where extents are unknown"); + + rewriter.replaceOp(shapeOf, shape); + return mlir::success(); + } +}; + struct ApplyOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit ApplyOpConversion(mlir::MLIRContext *ctx) @@ -529,11 +562,11 @@ public: auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns - .insert(context); + patterns.insert(context); mlir::ConversionTarget target(*context); target.addIllegalOp>>) -> !fir.shape<1> { + %c0 = arith.constant 0 : index + %59:3 = fir.box_dims %arg0, %c0 : (!fir.box>>, index) -> (index, index, index) + %60 = fir.box_addr %arg0 : (!fir.box>>) -> !fir.heap> + %61 = fir.shape_shift %59#0, %59#1 : (index, index) -> !fir.shapeshift<1> + %62:2 = hlfir.declare %60(%61) {uniq_name = ".tmp.intrinsic_result"} : (!fir.heap>, !fir.shapeshift<1>) -> (!fir.box>, !fir.heap>) + %true = arith.constant true + %63 = hlfir.as_expr %62#0 move %true : (!fir.box>, i1) -> !hlfir.expr + %64 = hlfir.shape_of %63 : (!hlfir.expr) -> !fir.shape<1> + return %64 : !fir.shape<1> +} +// CHECK-LABEL: @shapeof_asexpr +// CHECK: %[[ARG0:.*]]: !fir.box>> +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 +// CHECK-NEXT: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] +// CHECK-NEXT: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] +// CHECK-NEXT: %[[SHPE_SHFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[VAR:.*]]:2 = hlfir.declare %[[BOX_ADDR]](%[[SHPE_SHFT]]) +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[TUPLE0:.*]] = fir.undefined tuple +// CHECK-NEXT: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE]] +// CHECK-NEXT: %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[VAR]]#0 +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 +// CHECK-NEXT: return %[[SHAPE]] + +func.func @shapeof_elemental() -> !fir.shape<1> { + %c1 = arith.constant 1 : index + %0 = fir.shape %c1 : (index) -> !fir.shape<1> + %1 = hlfir.elemental %0 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg3: index): + hlfir.yield_element %arg3 : index + } + %2 = hlfir.shape_of %1 : (!hlfir.expr) -> !fir.shape<1> + return %2 : !fir.shape<1> +} +// CHECK-LABEL: @shapeof_elemental +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]] +// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %[[C1:.*]] +// CHECK: return %[[SHAPE]] + +func.func @shapeof_fallback(%arg0: !hlfir.expr<1x2x3xi32>) -> !fir.shape<3> { + %shape = hlfir.shape_of %arg0 : (!hlfir.expr<1x2x3xi32>) -> !fir.shape<3> + return %shape : !fir.shape<3> +} +// CHECK-LABEL: @shapeof_fallback +// CHECK: %[[EXPR:.*]]: !hlfir.expr<1x2x3xi32> +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index +// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]], %[[C2]], %[[C3]] : +// CHECK-NEXT: return %[[SHAPE]] -- 2.7.4