From eb56fa97de96856bb63e31340598a356056470c5 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 29 Apr 2021 10:07:20 +0200 Subject: [PATCH] [MLIR][Shape] Fix `shape.broadcast` to standard lowering Differential Revision: https://reviews.llvm.org/D101456 --- .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 23 +++++++++++----------- .../ShapeToStandard/shape-to-standard.mlir | 11 +++++++++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index e0342f6..9e0020a 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -155,17 +155,18 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( return lb.create(indexTy, maxRank, v); })); - rewriter.replaceOp( - op, lb.create( - getExtentTensorType(lb.getContext()), ValueRange{maxRank}, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value broadcastedDim = getBroadcastedDim( - ImplicitLocOpBuilder(loc, b), transformed.shapes(), - rankDiffs, args[0]); - - b.create(loc, broadcastedDim); - }) - ->getResults()); + Value replacement = lb.create( + getExtentTensorType(lb.getContext()), ValueRange{maxRank}, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value broadcastedDim = + getBroadcastedDim(ImplicitLocOpBuilder(loc, b), + transformed.shapes(), rankDiffs, args[0]); + + b.create(loc, broadcastedDim); + }); + if (replacement.getType() != op.getType()) + replacement = lb.create(op.getType(), replacement); + rewriter.replaceOp(op, replacement); return success(); } diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 751f500..9800044 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -593,6 +593,17 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>, return } +// ---- + +// CHECK-LABEL: @broadcast_to_known_rank +func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>) + -> tensor<3xindex> { + // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor to tensor<3xindex> + // CHECK: return %[[RES]] : tensor<3xindex> + %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex> + return %0 : tensor<3xindex> +} + // ----- // Lower `split_at` -- 2.7.4