From 3598c24983be90a582cdafb7864e302193c340f4 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 21 Nov 2022 13:15:50 +0100 Subject: [PATCH] [mlir][linalg] Change linalg.broadcast `dimensions` attribute to represent added dimensions. Original [RFC](discourse.llvm.org/t/rfc-primitive-ops-add-broadcastop-to-linalg/66313) defined `dimensions` as a map from input to init, but a discussion in reviews.llvm.org/D138291 concluded that it's more natural for `dimensions` to represent added dims. Also this way is more consistent with `linalg.reduce`. Differential Revision: https://reviews.llvm.org/D138408 --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 9 +---- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 46 ++++++++++------------ mlir/test/Dialect/Linalg/invalid.mlir | 25 +++--------- mlir/test/Dialect/Linalg/one-shot-bufferize.mlir | 2 +- mlir/test/Dialect/Linalg/roundtrip.mlir | 6 +-- .../lower-to-loops-using-interface.mlir | 2 +- 6 files changed, 34 insertions(+), 56 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e822435..815d542 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -463,19 +463,14 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Static broadcast operator"; let description = [{ - Broadcast the input into the given shape by adding dimensions. - - Each index in `dimensions` attribute maps input dimension into the - corresponding target dimension. The length of the `dimensions` list should - match the `input` rank and dimensions should be in sorted order. There is no - ambiguity at compile-time about shape information. + Broadcast the input into the given shape by adding `dimensions`. Example: ``` %bcast = linalg.broadcast ins(%input:tensor<16xf32>) inits(%init:tensor<16x64xf32>) - dimensions = [0] + dimensions = [1] ``` }]; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index ea5263c..5ce936e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1511,10 +1511,6 @@ void BroadcastOp::print(OpAsmPrinter &p) { LogicalResult BroadcastOp::verify() { ArrayRef dimensionsRef = getDimensions(); - if (!llvm::is_sorted(dimensionsRef)) - return emitOpError() << "dimensions should be in sorted order, implicit " - "transpose is not supported"; - auto inputType = getInput().getType(); auto initType = getInit().getType(); @@ -1524,34 +1520,35 @@ LogicalResult BroadcastOp::verify() { auto inputShape = inputType.getShape(); auto initShape = initType.getShape(); - if ((size_t)inputRank != dimensionsRef.size()) - return emitOpError() - << "input rank does match the number of dimensions. expected: " - << inputRank << ", got: " << dimensionsRef.size(); - - // Mapping from init dims to input dims. - const int64_t kUnmappedDim = -1; - SmallVector reverseDimMap(initRank, kUnmappedDim); + if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank) + return emitOpError() << "input rank plus added dimensions does not " + "match init rank. input rank: " + << inputRank + << ", dimensions size: " << dimensionsRef.size() + << ", init rank: " << initRank; for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) { if (dim < 0 || dim >= initRank) return emitOpError() << "dimension " << idx << " is out of range. expected range: [0, " << initRank - 1 << "], got: " << dim; + } - reverseDimMap[dim] = idx; + // Mapping from input dims to init dims. + SmallVector dimMap; + for (auto dim : llvm::seq(0, initRank)) { + if (!llvm::is_contained(dimensionsRef, dim)) + dimMap.push_back(dim); } - for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) { - if (inputDimIdx != kUnmappedDim) { - // This dimensions is mapped from the input. Init and input dims should - // match. - if (inputShape[inputDimIdx] != initShape[idx]) - return emitOpError() - << "input dim " << inputDimIdx << " should match init dim " - << idx << ". input: " << inputShape[inputDimIdx] - << ", init: " << initShape[idx]; - } + for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) { + // This dimensions is mapped from the input. Init and input dims should + // match. + if (inputShape[inputDimIdx] != initShape[initDimIdx]) + return emitOpError() << "input dim " << inputDimIdx + << " should match init dim " << initDimIdx + << ". input: " << inputShape[inputDimIdx] + << ", init: " << initShape[initDimIdx]; } return success(); @@ -1566,8 +1563,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() { Builder builder(getContext()); int64_t rank = getInit().getType().getRank(); return builder.getAffineMapArrayAttr( - {builder.getMultiDimIdentityMap(rank).getSubMap( - llvm::to_vector_of(getDimensions())), + {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()), builder.getMultiDimIdentityMap(rank)}); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 9eddc1c7..03540be 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -676,27 +676,14 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>, // ----- -func.func @broadcast_unsorted_dims( - %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) - -> tensor<4x8x16xf32> { - // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}} - %bcast = linalg.broadcast - ins(%input:tensor<4x16xf32>) - outs(%init:tensor<4x8x16xf32>) - dimensions = [1, 0] - func.return %bcast : tensor<4x8x16xf32> -} - -// ----- - func.func @broadcast_input_dims_rank_mismatch( %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}} + // expected-error @+1 {{'linalg.broadcast' op input rank plus added dimensions does not match init rank. }} %bcast = linalg.broadcast ins(%input:tensor<4x16xf32>) outs(%init:tensor<4x8x16xf32>) - dimensions = [0] + dimensions = [1, 2] func.return %bcast : tensor<4x8x16xf32> } @@ -705,11 +692,11 @@ func.func @broadcast_input_dims_rank_mismatch( func.func @broadcast_unsorted_dims( %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}} + // expected-error @+1 {{'linalg.broadcast' op dimension 0 is out of range. expected range: [0, 2], got: 5}} %bcast = linalg.broadcast ins(%input:tensor<4x16xf32>) outs(%init:tensor<4x8x16xf32>) - dimensions = [0, 5] + dimensions = [5] func.return %bcast : tensor<4x8x16xf32> } @@ -722,7 +709,7 @@ func.func @broadcast_mapped_dim_mismatch( %bcast = linalg.broadcast ins(%input:tensor<4x16xf32>) outs(%init:tensor<5x8x16xf32>) - dimensions = [0, 2] + dimensions = [1] func.return %bcast : tensor<5x8x16xf32> } @@ -735,6 +722,6 @@ func.func @broadcast_size_1_extension_not_supported( %bcast = linalg.broadcast ins(%input:tensor<1x16xf32>) outs(%init:tensor<4x?x16xf32>) - dimensions = [0, 2] + dimensions = [1] func.return %bcast : tensor<4x?x16xf32> } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index 9d100d5..424539b 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -395,7 +395,7 @@ func.func @broadcast(%input: tensor<8x32xf32>, %bcast = linalg.broadcast ins(%input:tensor<8x32xf32>) outs(%init:tensor<8x16x32xf32>) - dimensions = [0, 2] + dimensions = [1] func.return %bcast : tensor<8x16x32xf32> } diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 64c2bea..8f0c83f 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -525,7 +525,7 @@ func.func @broadcast_static_sizes(%input: tensor<8x32xf32>, %bcast = linalg.broadcast ins(%input:tensor<8x32xf32>) outs(%init:tensor<8x16x32xf32>) - dimensions = [0, 2] + dimensions = [1] func.return %bcast : tensor<8x16x32xf32> } // CHECK-LABEL: func @broadcast_static_sizes @@ -542,7 +542,7 @@ func.func @broadcast_with_dynamic_sizes( %bcast = linalg.broadcast ins(%input:tensor<8x?xf32>) outs(%init:tensor<8x16x?xf32>) - dimensions = [0, 2] + dimensions = [1] func.return %bcast : tensor<8x16x?xf32> } // CHECK-LABEL: func @broadcast_with_dynamic_sizes @@ -558,7 +558,7 @@ func.func @broadcast_memref(%input: memref<8x32xf32>, linalg.broadcast ins(%input:memref<8x32xf32>) outs(%init:memref<8x16x32xf32>) - dimensions = [0, 2] + dimensions = [1] func.return } diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index b2e3fd5..f0d1938 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -248,7 +248,7 @@ func.func @broadcast(%input: memref<8x32xf32>, linalg.broadcast ins(%input:memref<8x32xf32>) outs(%init:memref<8x16x32xf32>) - dimensions = [0, 2] + dimensions = [1] func.return } // CHECK-LABEL: func.func @broadcast( -- 2.7.4