[mlir][linalg] Change linalg.broadcast `dimensions` attribute to represent added dimensions
author    Oleg Shyshkov <shyshkov@google.com>
          Mon, 21 Nov 2022 12:15:50 +0000 (13:15 +0100)
committer Oleg Shyshkov <shyshkov@google.com>
          Mon, 21 Nov 2022 12:16:41 +0000 (13:16 +0100)
The original [RFC](discourse.llvm.org/t/rfc-primitive-ops-add-broadcastop-to-linalg/66313) defined `dimensions` as a map from input dims to init dims, but the discussion in reviews.llvm.org/D138291 concluded that it is more natural for `dimensions` to represent the added dims. This is also more consistent with `linalg.reduce`.
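For illustration (shapes taken from the op documentation, written with the `outs(...)` form used in the tests), here is the same broadcast before and after this change:

```mlir
// Old semantics: `dimensions` maps input dim 0 to init dim 0.
%bcast_old = linalg.broadcast
    ins(%input : tensor<16xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [0]

// New semantics: `dimensions` lists the added init dims, so dim 1
// (size 64) is named; input dim 0 implicitly aligns with init dim 0.
%bcast_new = linalg.broadcast
    ins(%input : tensor<16xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
```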

Differential Revision: https://reviews.llvm.org/D138408

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e822435..815d542 100644
@@ -463,19 +463,14 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Static broadcast operator";
   let description = [{
-    Broadcast the input into the given shape by adding dimensions.
-
-    Each index in `dimensions` attribute maps input dimension into the
-    corresponding target dimension. The length of the `dimensions` list should
-    match the `input` rank and dimensions should be in sorted order. There is no
-    ambiguity at compile-time about shape information.
+    Broadcast the input into the given shape by adding `dimensions`.
 
     Example:
     ```
       %bcast = linalg.broadcast
           ins(%input:tensor<16xf32>)
           inits(%init:tensor<16x64xf32>)
-          dimensions = [0]
+          dimensions = [1]
     ```
   }];
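Under the new convention, input dim `i` lines up with the `i`-th init dim that is *not* listed in `dimensions`; this is exactly the `dimMap` computed by the rewritten verifier below. A hypothetical rank-2 to rank-4 case makes it concrete:

```mlir
// Init dims 0 and 2 are added; input dims 0 and 1 (8 and 16) align
// with the remaining init dims 1 and 3. (Illustrative shapes only.)
%bcast = linalg.broadcast
    ins(%input : tensor<8x16xf32>)
    outs(%init : tensor<4x8x2x16xf32>)
    dimensions = [0, 2]
```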
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ea5263c..5ce936e 100644
@@ -1511,10 +1511,6 @@ void BroadcastOp::print(OpAsmPrinter &p) {
 LogicalResult BroadcastOp::verify() {
   ArrayRef<int64_t> dimensionsRef = getDimensions();
 
-  if (!llvm::is_sorted(dimensionsRef))
-    return emitOpError() << "dimensions should be in sorted order, implicit "
-                            "transpose is not supported";
-
   auto inputType = getInput().getType();
   auto initType = getInit().getType();
 
@@ -1524,34 +1520,35 @@ LogicalResult BroadcastOp::verify() {
   auto inputShape = inputType.getShape();
   auto initShape = initType.getShape();
 
-  if ((size_t)inputRank != dimensionsRef.size())
-    return emitOpError()
-           << "input rank does match the number of dimensions. expected: "
-           << inputRank << ", got: " << dimensionsRef.size();
-
-  // Mapping from init dims to input dims.
-  const int64_t kUnmappedDim = -1;
-  SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
+    return emitOpError() << "input rank plus added dimensions does not "
+                            "match init rank. input rank: "
+                         << inputRank
+                         << ", dimensions size: " << dimensionsRef.size()
+                         << ", init rank: " << initRank;
 
   for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
     if (dim < 0 || dim >= initRank)
       return emitOpError() << "dimension " << idx
                            << " is out of range. expected range: [0, "
                            << initRank - 1 << "], got: " << dim;
+  }
 
-    reverseDimMap[dim] = idx;
+  // Mapping from input dims to init dims.
+  SmallVector<int64_t> dimMap;
+  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+    if (!llvm::is_contained(dimensionsRef, dim))
+      dimMap.push_back(dim);
   }
 
-  for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
-    if (inputDimIdx != kUnmappedDim) {
-      // This dimensions is mapped from the input. Init and input dims should
-      // match.
-      if (inputShape[inputDimIdx] != initShape[idx])
-        return emitOpError()
-               << "input dim " << inputDimIdx << " should match init dim "
-               << idx << ". input: " << inputShape[inputDimIdx]
-               << ", init: " << initShape[idx];
-    }
+  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
+    // This dimension is mapped from the input. Init and input dims should
+    // match.
+    if (inputShape[inputDimIdx] != initShape[initDimIdx])
+      return emitOpError() << "input dim " << inputDimIdx
+                           << " should match init dim " << initDimIdx
+                           << ". input: " << inputShape[inputDimIdx]
+                           << ", init: " << initShape[initDimIdx];
   }
 
   return success();
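A worked instance of the new check (shapes assumed for illustration): for the op below, the verifier first confirms 2 (input rank) + 1 (`dimensions` size) == 3 (init rank), builds `dimMap = [0, 2]`, and then compares input dim 0 with init dim 0 (4 == 4) and input dim 1 with init dim 2 (16 == 16).

```mlir
// Accepted under the new semantics; with dimensions = [1, 2] it would
// instead fail the rank equation 2 + 2 != 3. (Illustrative shapes only.)
%ok = linalg.broadcast
    ins(%input : tensor<4x16xf32>)
    outs(%init : tensor<4x8x16xf32>)
    dimensions = [1]
```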
@@ -1566,8 +1563,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
   Builder builder(getContext());
   int64_t rank = getInit().getType().getRank();
   return builder.getAffineMapArrayAttr(
-      {builder.getMultiDimIdentityMap(rank).getSubMap(
-           llvm::to_vector_of<unsigned>(getDimensions())),
+      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
        builder.getMultiDimIdentityMap(rank)});
 }
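For instance (an illustrative case, not part of the patch): with init rank 3 and `dimensions = [1]`, `dropResults` derives the input map straight from the identity map, so the two indexing maps become:

```mlir
// Input map: rank-3 identity with result d1 (the added dim) dropped.
#input_map = affine_map<(d0, d1, d2) -> (d0, d2)>
// Init map: full rank-3 identity.
#init_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
```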
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9eddc1c..03540be 100644
@@ -676,27 +676,14 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
 
 // -----
 
-func.func @broadcast_unsorted_dims(
-    %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-    -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
-  %bcast = linalg.broadcast
-      ins(%input:tensor<4x16xf32>)
-      outs(%init:tensor<4x8x16xf32>)
-      dimensions = [1, 0]
-  func.return %bcast : tensor<4x8x16xf32>
-}
-
-// -----
-
 func.func @broadcast_input_dims_rank_mismatch(
     %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+  // expected-error @+1 {{'linalg.broadcast' op input rank plus added dimensions does not match init rank. }}
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<4x8x16xf32>)
-      dimensions = [0]
+      dimensions = [1, 2]
   func.return %bcast : tensor<4x8x16xf32>
 }
 
@@ -705,11 +692,11 @@ func.func @broadcast_input_dims_rank_mismatch(
 func.func @broadcast_unsorted_dims(
     %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
-  // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+  // expected-error @+1 {{'linalg.broadcast' op dimension 0 is out of range. expected range: [0, 2], got: 5}}
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<4x8x16xf32>)
-      dimensions = [0, 5]
+      dimensions = [5]
   func.return %bcast : tensor<4x8x16xf32>
 }
 
@@ -722,7 +709,7 @@ func.func @broadcast_mapped_dim_mismatch(
   %bcast = linalg.broadcast
       ins(%input:tensor<4x16xf32>)
       outs(%init:tensor<5x8x16xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<5x8x16xf32>
 }
 
@@ -735,6 +722,6 @@ func.func @broadcast_size_1_extension_not_supported(
   %bcast = linalg.broadcast
       ins(%input:tensor<1x16xf32>)
       outs(%init:tensor<4x?x16xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<4x?x16xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9d100d5..424539b 100644
@@ -395,7 +395,7 @@ func.func @broadcast(%input: tensor<8x32xf32>,
   %bcast = linalg.broadcast
       ins(%input:tensor<8x32xf32>)
       outs(%init:tensor<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x32xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 64c2bea..8f0c83f 100644
@@ -525,7 +525,7 @@ func.func @broadcast_static_sizes(%input: tensor<8x32xf32>,
   %bcast = linalg.broadcast
       ins(%input:tensor<8x32xf32>)
       outs(%init:tensor<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x32xf32>
 }
 // CHECK-LABEL: func @broadcast_static_sizes
@@ -542,7 +542,7 @@ func.func @broadcast_with_dynamic_sizes(
   %bcast = linalg.broadcast
       ins(%input:tensor<8x?xf32>)
       outs(%init:tensor<8x16x?xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return %bcast : tensor<8x16x?xf32>
 }
 // CHECK-LABEL: func @broadcast_with_dynamic_sizes
@@ -558,7 +558,7 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
   linalg.broadcast
       ins(%input:memref<8x32xf32>)
       outs(%init:memref<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index b2e3fd5..f0d1938 100644
@@ -248,7 +248,7 @@ func.func @broadcast(%input: memref<8x32xf32>,
   linalg.broadcast
       ins(%input:memref<8x32xf32>)
       outs(%init:memref<8x16x32xf32>)
-      dimensions = [0, 2]
+      dimensions = [1]
   func.return
 }
 // CHECK-LABEL: func.func @broadcast(