LogicalResult BroadcastOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- if (!llvm::is_sorted(dimensionsRef))
- return emitOpError() << "dimensions should be in sorted order, implicit "
- "transpose is not supported";
-
auto inputType = getInput().getType();
auto initType = getInit().getType();
auto inputShape = inputType.getShape();
auto initShape = initType.getShape();
auto inputRank = inputType.getRank();
auto initRank = initType.getRank();
- if ((size_t)inputRank != dimensionsRef.size())
- return emitOpError()
- << "input rank does match the number of dimensions. expected: "
- << inputRank << ", got: " << dimensionsRef.size();
-
- // Mapping from init dims to input dims.
- const int64_t kUnmappedDim = -1;
- SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+ if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
+ return emitOpError() << "input rank plus added dimensions does not "
+ "match init rank. input rank: "
+ << inputRank
+ << ", dimensions size: " << dimensionsRef.size()
+ << ", init rank: " << initRank;
for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
if (dim < 0 || dim >= initRank)
return emitOpError() << "dimension " << idx
<< " is out of range. expected range: [0, "
<< initRank - 1 << "], got: " << dim;
+ }
- reverseDimMap[dim] = idx;
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensionsRef, dim))
+ dimMap.push_back(dim);
}
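+  // E.g. with init rank 3 and dimensions = [1], dimMap is [0, 2]: input
+  // dim 0 maps to init dim 0 and input dim 1 maps to init dim 2.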
- for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
- if (inputDimIdx != kUnmappedDim) {
- // This dimensions is mapped from the input. Init and input dims should
- // match.
- if (inputShape[inputDimIdx] != initShape[idx])
- return emitOpError()
- << "input dim " << inputDimIdx << " should match init dim "
- << idx << ". input: " << inputShape[inputDimIdx]
- << ", init: " << initShape[idx];
- }
+ for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
+    // This dimension is mapped from the input. Init and input dims should
+    // match.
+ if (inputShape[inputDimIdx] != initShape[initDimIdx])
+ return emitOpError() << "input dim " << inputDimIdx
+ << " should match init dim " << initDimIdx
+ << ". input: " << inputShape[inputDimIdx]
+ << ", init: " << initShape[initDimIdx];
}
return success();
}

ArrayAttr BroadcastOp::getIndexingMaps() {
Builder builder(getContext());
int64_t rank = getInit().getType().getRank();
return builder.getAffineMapArrayAttr(
- {builder.getMultiDimIdentityMap(rank).getSubMap(
- llvm::to_vector_of<unsigned>(getDimensions())),
+ {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
builder.getMultiDimIdentityMap(rank)});
}
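To make the new contract concrete, here is an illustrative sketch (not part of the patch; shapes and value names are made up) of an op under the new semantics, together with the affine maps the patched map construction above would produce for it:

// dimensions = [1] adds init dim 1; input dims map to init dims [0, 2].
%bcast = linalg.broadcast
    ins(%input: tensor<4x16xf32>)
    outs(%init: tensor<4x8x16xf32>)
    dimensions = [1]

// For init rank 3, the input map is the rank-3 identity with result 1
// dropped and the init map is the full identity:
//   input: affine_map<(d0, d1, d2) -> (d0, d2)>
//   init:  affine_map<(d0, d1, d2) -> (d0, d1, d2)>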
// -----
-func.func @broadcast_unsorted_dims(
- %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
- -> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
- %bcast = linalg.broadcast
- ins(%input:tensor<4x16xf32>)
- outs(%init:tensor<4x8x16xf32>)
- dimensions = [1, 0]
- func.return %bcast : tensor<4x8x16xf32>
-}
-
-// -----
-
func.func @broadcast_input_dims_rank_mismatch(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+  // expected-error @+1 {{'linalg.broadcast' op input rank plus added dimensions does not match init rank. input rank: 2, dimensions size: 2, init rank: 3}}
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<4x8x16xf32>)
- dimensions = [0]
+ dimensions = [1, 2]
func.return %bcast : tensor<4x8x16xf32>
}
func.func @broadcast_dim_out_of_range(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+ // expected-error @+1 {{'linalg.broadcast' op dimension 0 is out of range. expected range: [0, 2], got: 5}}
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<4x8x16xf32>)
- dimensions = [0, 5]
+ dimensions = [5]
func.return %bcast : tensor<4x8x16xf32>
}
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<5x8x16xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<5x8x16xf32>
}
%bcast = linalg.broadcast
ins(%input:tensor<1x16xf32>)
outs(%init:tensor<4x?x16xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<4x?x16xf32>
}
%bcast = linalg.broadcast
ins(%input:tensor<8x32xf32>)
outs(%init:tensor<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<8x16x32xf32>
}
// CHECK-LABEL: func @broadcast_static_sizes
%bcast = linalg.broadcast
ins(%input:tensor<8x?xf32>)
outs(%init:tensor<8x16x?xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<8x16x?xf32>
}
// CHECK-LABEL: func @broadcast_with_dynamic_sizes
linalg.broadcast
ins(%input:memref<8x32xf32>)
outs(%init:memref<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return
}