From 244105f791539a84eeef7e8e50c180e413675b60 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Fri, 18 Nov 2022 14:45:33 +0100 Subject: [PATCH] [mlir][linalg] Do not check if added dimensions are static in linalg.broadcast. Added dimensions can be both static and dynamic. Mapped dimensions should be the same in the input and the init. Differential Revision: https://reviews.llvm.org/D138291 --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 8 +------- mlir/test/Dialect/Linalg/invalid.mlir | 13 ------------- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 18e399e..ea5263c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1543,13 +1543,7 @@ LogicalResult BroadcastOp::verify() { } for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) { - if (inputDimIdx == kUnmappedDim) { - // This dimensions is being added. Should be statically known. - if (ShapedType::isDynamic(initShape[idx])) - return emitOpError() - << "init dim " << idx - << " can't be dynamic, because it's not matched to input"; - } else { + if (inputDimIdx != kUnmappedDim) { // This dimensions is mapped from the input. Init and input dims should // match. 
if (inputShape[inputDimIdx] != initShape[idx]) diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index ebce71f..9eddc1c7 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -728,19 +728,6 @@ func.func @broadcast_mapped_dim_mismatch( // ----- -func.func @broadcast_added_dynamic_mismatch( - %input: tensor<4x16xf32>, %init: tensor<4x?x16xf32>) - -> tensor<4x?x16xf32> { - // expected-error @+1 {{'linalg.broadcast' op init dim 1 can't be dynamic, because it's not matched to input}} - %bcast = linalg.broadcast - ins(%input:tensor<4x16xf32>) - outs(%init:tensor<4x?x16xf32>) - dimensions = [0, 2] - func.return %bcast : tensor<4x?x16xf32> -} - -// ----- - func.func @broadcast_size_1_extension_not_supported( %input: tensor<1x16xf32>, %init: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { -- 2.7.4