From 8178e41dc1a395fc0a99f945bf27fbfca871d3e5 Mon Sep 17 00:00:00 2001
From: Tres Popp
Date: Fri, 9 Oct 2020 16:45:50 +0200
Subject: [PATCH] [mlir] Type erase inputs to select statements in
 shape.broadcast lowering.

This is required; otherwise, broadcasting with operands of different ranks
will lead to failures, as the select op requires both possible outputs and
its output type to be the same.

Differential Revision: https://reviews.llvm.org/D89134
---
 .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 10 ++++-
 .../ShapeToStandard/shape-to-standard.mlir         | 50 ++++++++++++++++++++--
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index b1319a8..2840761 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -99,10 +99,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
       rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
   Value greaterRank =
       rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Value rankErasedLhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+  Value rankErasedRhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
   Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
   Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
 
   // Allocate stack memory for the broadcasted extent tensor.
   Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 6207486..4dc4868 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,9 +305,9 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL: @broadcast
+// CHECK-LABEL: @broadcast_unknown_extents
 // CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
@@ -315,8 +315,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
   // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
   // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
   // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
@@ -340,3 +342,43 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_known_different_extents
+// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
+func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+  // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
+  // CHECK: } else {
+  // CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK: }
+  // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  %0 = shape.broadcast %a, %b
+      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+  return
+}
--
2.7.4
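The snippet below is not part of the patch; it is a minimal, hypothetical MLIR sketch of the constraint the commit message describes: select requires its two candidate values and its result to share a single type, so extent tensors of different static ranks must first be erased to tensor<?xindex> with tensor_cast, exactly as the lowering above does. The function and value names here are made up for illustration.

// Hypothetical sketch (not from the patch): erase both operand types so the
// select is well-typed even though %a and %b have different static ranks.
func @select_extent_tensors(%pred : i1, %a : tensor<2xindex>, %b : tensor<3xindex>) -> tensor<?xindex> {
  // Cast away the static extents; both values now have type tensor<?xindex>.
  %lhs = tensor_cast %a : tensor<2xindex> to tensor<?xindex>
  %rhs = tensor_cast %b : tensor<3xindex> to tensor<?xindex>
  // Operands and result of the select all share the erased type.
  %result = select %pred, %lhs, %rhs : tensor<?xindex>
  return %result : tensor<?xindex>
}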