This patch prepares MLIR code base to change the value of kDynamicSize.
https://discourse.llvm.org/t/rfc-unify-kdynamicsize-and-kdynamicstrideoroffset/64534/4
Differential Revision: https://reviews.llvm.org/D136327
static constexpr LenType singleton() { return 1; }
/// Character has a LEN value which is not a compile-time known constant.
- static constexpr LenType unknownLen() { return -1; }
+ static constexpr LenType unknownLen() { return mlir::ShapedType::kDynamicSize; }
/// Character LEN is a runtime value.
bool hasDynamicLen() { return getLen() == unknownLen(); }
auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
affineMap, indexArgs);
auto arrayElementType = coordinateArrayElement(acoOp);
- auto newType = mlir::MemRefType::get({-1}, arrayElementType);
+ auto newType =
+ mlir::MemRefType::get({mlir::ShapedType::kDynamicSize}, arrayElementType);
auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
acoOp.getMemref());
return std::make_pair(affineApply, arrayConvert);
if (consumeIf(Token::question)) {
if (!allowDynamic)
return emitError(loc, "expected static shape");
- dimensions.push_back(-1);
+ dimensions.push_back(ShapedType::kDynamicSize);
} else {
int64_t value;
if (failed(parseIntegerInDimensionList(value)))
bool isDynamic) {
if (isDynamic) {
// TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
- intermediateShape = {-1};
+ intermediateShape = {ShapedType::kDynamicSize};
return true;
}
// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;
for (int i = 0; i < rank; i++) {
- genericShape.push_back(multiples[i]);
+ int64_t dim = multiples[i];
+ genericShape.push_back(dim == -1 ? ShapedType::kDynamicSize : dim);
genericShape.push_back(inputShape[i]);
}
PatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
Value input = sliceOp.getInput();
- SmallVector<int64_t> strides;
+ SmallVector<int64_t> strides, sizes;
auto starts = sliceOp.getStart();
- auto sizes = sliceOp.getSize();
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
SmallVector<Value> dynSizes;
- for (const auto &i : llvm::enumerate(sizes)) {
+ for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
int64_t size = i.value().cast<IntegerAttr>().getInt();
size_t index = i.index();
- if (size != ShapedType::kDynamicSize)
+ sizes.push_back(size == -1 ? ShapedType::kDynamicSize : size);
+ if (!ShapedType::isDynamic(sizes.back()))
continue;
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
- ValueRange({}), starts, sizes, rewriter.getI64ArrayAttr(strides));
+ ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes),
+ rewriter.getI64ArrayAttr(strides));
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
return success();
bool isDynDim =
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
if (isDynDim) {
- newShape[d] = -1;
+ newShape[d] = ShapedType::kDynamicSize;
} else {
// The lower bound for the shape is always zero.
Optional<int64_t> ubConst =
// We parsed a generic dimension list, but vectors only support two forms:
// - single non-dynamic entry in the list (fixed vector);
- // - two elements, the first dynamic (indicated by -1) and the second
+ // - two elements, the first dynamic (indicated by
+ //   ShapedType::kDynamicSize) and the second
// non-dynamic (scalable vector).
if (dims.empty() || dims.size() > 2 ||
- ((dims.size() == 2) ^ (dims[0] == -1)) ||
- (dims.size() == 2 && dims[1] == -1)) {
+ ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) ||
+ (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) {
parser.emitError(dimPos)
<< "expected '? x <integer> x <type>' or '<integer> x <type>'";
return Type();
}
// Fallback dynamic buffer.
- auto dynamicBufferType = MemRefType::get(-1, b.getIntegerType(8));
+ auto dynamicBufferType =
+ MemRefType::get(ShapedType::kDynamicSize, b.getIntegerType(8));
Value mul = b.createOrFold<arith::MulIOp>(
b.create<arith::ConstantIndexOp>(width), allocSize);
if (options.useAlloca)
partialSizes.push_back(
b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
}
- SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
+ SmallVector<int64_t, 4> dynSizes(fullSizes.size(), ShapedType::kDynamicSize);
// If a callback is not specified, then use the default implementation for
// allocating the promoted buffer.
Optional<Value> fullLocalView = allocationFn(b, subView, fullSizes, layout);
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
int64_t dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it.
- if (dimSize != -1) {
+ if (!ShapedType::isDynamic(dimSize)) {
newShapeConstants.push_back(dimSize);
continue;
}
newShapeConstants.push_back(constantIndexOp.value());
} else {
// Dynamic shape dimension not folded; copy dynamicSize from old memref.
- newShapeConstants.push_back(-1);
+ newShapeConstants.push_back(ShapedType::kDynamicSize);
dynamicSizes.push_back(dynamicSize);
}
dynamicDimPos++;
for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
- if (aDim != -1 && bDim != -1 && aDim != bDim)
+ if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
+ aDim != bDim)
return false;
}
return true;
"sum of all the concatenation dimensions of the input tensors.");
}
} else {
- int prev = dstDim;
+ int64_t prev = dstDim;
for (auto src : getInputs()) {
auto d = src.getType().cast<RankedTensorType>().getShape()[i];
if (prev != ShapedType::kDynamicSize && d != prev)
}
// Determine the dimension size along the concatenation axis.
- int concatDimSize = 0;
+ int64_t concatDimSize = 0;
for (auto operand : operands) {
ShapeAdaptor operandShape = operands.getShape(operand);
// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- int dim = inputShape.getDimSize(i);
+ int64_t dim = inputShape.getDimSize(i);
if (dim != ShapedType::kDynamicSize)
dim *= multipleValues[i];
outputShape.push_back(dim);
return success();
}
+static SmallVector<int64_t> ConvertToMlirShape(ArrayRef<int64_t> shape) {
+ return to_vector(llvm::map_range(shape, [](int64_t dim) {
+ return dim == -1 ? ShapedType::kDynamicSize : dim;
+ }));
+}
+
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ArrayAttr newShape = adaptor.getNewShape();
llvm::SmallVector<int64_t> newShapeValue;
getI64Values(newShape, newShapeValue);
+ newShapeValue = ConvertToMlirShape(newShapeValue);
// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
int64_t numElements = inputShape.getNumElements();
int64_t staticMul = 1;
for (auto val : newShapeValue) {
- if (val != ShapedType::kDynamicSize) {
+ if (!ShapedType::isDynamic(val)) {
staticMul *= val;
}
}
// Determine the length of the dynamic dimension.
for (auto &val : newShapeValue) {
- if (val == ShapedType::kDynamicSize)
+ if (ShapedType::isDynamic(val))
val = numElements / staticMul;
}
outputShape[0] = inputShape.getDimSize(0);
outputShape[3] = inputShape.getDimSize(3);
- int32_t inputHeight = inputShape.getDimSize(1);
- int32_t inputWidth = inputShape.getDimSize(2);
+ int64_t inputHeight = inputShape.getDimSize(1);
+ int64_t inputWidth = inputShape.getDimSize(2);
if ((inputHeight == ShapedType::kDynamicSize) ||
(inputWidth == ShapedType::kDynamicSize))
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
llvm::SmallVector<int64_t> outputShape;
- outputShape.resize(4, -1);
+ outputShape.resize(4, ShapedType::kDynamicSize);
// We only know the rank if the input type is unranked.
if (!inputShape) {
outputShape[0] = inputShape.getDimSize(0);
outputShape[3] = inputShape.getDimSize(3);
- int32_t height = inputShape.getDimSize(1);
- int32_t width = inputShape.getDimSize(2);
+ int64_t height = inputShape.getDimSize(1);
+ int64_t width = inputShape.getDimSize(2);
llvm::SmallVector<int64_t> kernel;
llvm::SmallVector<int64_t> stride;
getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);
- if (height != -1) {
- int32_t padded = height + pad[0] + pad[1] - kernel[0];
+ if (!ShapedType::isDynamic(height)) {
+ int64_t padded = height + pad[0] + pad[1] - kernel[0];
outputShape[1] = padded / stride[0] + 1;
}
- if (width != -1) {
- int32_t padded = width + pad[2] + pad[3] - kernel[1];
+ if (!ShapedType::isDynamic(width)) {
+ int64_t padded = width + pad[2] + pad[3] - kernel[1];
outputShape[2] = padded / stride[1] + 1;
}
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
- int32_t inputWidth = ShapedType::kDynamicSize;
- int32_t inputHeight = ShapedType::kDynamicSize;
- int32_t weightWidth = ShapedType::kDynamicSize;
- int32_t weightHeight = ShapedType::kDynamicSize;
+ int64_t inputWidth = ShapedType::kDynamicSize;
+ int64_t inputHeight = ShapedType::kDynamicSize;
+ int64_t weightWidth = ShapedType::kDynamicSize;
+ int64_t weightHeight = ShapedType::kDynamicSize;
// Input shape describes input width/height and batch.
if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {
- int32_t inputSize = inputHeight + padding[0] + padding[1];
- int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
+ int64_t inputSize = inputHeight + padding[0] + padding[1];
+ int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+ int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
}
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightWidth)) {
- int32_t inputSize = inputWidth + padding[2] + padding[3];
- int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
+ int64_t inputSize = inputWidth + padding[2] + padding[3];
+ int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+ int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
}
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
- int32_t inputWidth = ShapedType::kDynamicSize;
- int32_t inputHeight = ShapedType::kDynamicSize;
- int32_t inputDepth = ShapedType::kDynamicSize;
+ int64_t inputWidth = ShapedType::kDynamicSize;
+ int64_t inputHeight = ShapedType::kDynamicSize;
+ int64_t inputDepth = ShapedType::kDynamicSize;
- int32_t weightWidth = ShapedType::kDynamicSize;
- int32_t weightHeight = ShapedType::kDynamicSize;
- int32_t weightDepth = ShapedType::kDynamicSize;
+ int64_t weightWidth = ShapedType::kDynamicSize;
+ int64_t weightHeight = ShapedType::kDynamicSize;
+ int64_t weightDepth = ShapedType::kDynamicSize;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
- int32_t inputWidth = ShapedType::kDynamicSize;
- int32_t inputHeight = ShapedType::kDynamicSize;
- int32_t inputChannels = ShapedType::kDynamicSize;
+ int64_t inputWidth = ShapedType::kDynamicSize;
+ int64_t inputHeight = ShapedType::kDynamicSize;
+ int64_t inputChannels = ShapedType::kDynamicSize;
- int32_t weightWidth = ShapedType::kDynamicSize;
- int32_t weightHeight = ShapedType::kDynamicSize;
- int32_t depthChannels = ShapedType::kDynamicSize;
+ int64_t weightWidth = ShapedType::kDynamicSize;
+ int64_t weightHeight = ShapedType::kDynamicSize;
+ int64_t depthChannels = ShapedType::kDynamicSize;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {
- int32_t inputSize = inputHeight + padding[0] + padding[1];
- int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
+ int64_t inputSize = inputHeight + padding[0] + padding[1];
+ int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+ int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
}
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightWidth)) {
- int32_t inputSize = inputWidth + padding[2] + padding[3];
- int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
+ int64_t inputSize = inputWidth + padding[2] + padding[3];
+ int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+ int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
}
TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
llvm::SmallVector<int64_t> outputShape;
getI64Values(adaptor.getOutShape(), outputShape);
+ outputShape = ConvertToMlirShape(outputShape);
- int32_t inputWidth = ShapedType::kDynamicSize;
- int32_t inputHeight = ShapedType::kDynamicSize;
- int32_t weightWidth = ShapedType::kDynamicSize;
- int32_t weightHeight = ShapedType::kDynamicSize;
+ int64_t inputWidth = ShapedType::kDynamicSize;
+ int64_t inputHeight = ShapedType::kDynamicSize;
+ int64_t weightWidth = ShapedType::kDynamicSize;
+ int64_t weightHeight = ShapedType::kDynamicSize;
// Input shape describes input width/height and batch.
ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {
- int32_t calculateSize =
+ int64_t calculateSize =
(inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
- outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
+ outputShape[1] =
+ ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
}
if (!ShapedType::isDynamic(inputWidth) &&
!ShapedType::isDynamic(weightWidth)) {
- int32_t calculateSize =
+ int64_t calculateSize =
(inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
- outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
+ outputShape[2] =
+ ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
namespace {
+SmallVector<int64_t> ConvertFromMlirShape(ArrayRef<int64_t> shape) {
+ return to_vector(llvm::map_range(shape, [](int64_t dim) {
+ return ShapedType::isDynamic(dim) ? -1 : dim;
+ }));
+}
+
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
// Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
- if (combined < 0)
- combined = ShapedType::kDynamicSize;
+ int64_t combined = ShapedType::kDynamicSize;
+ if (numDynamic == 0)
+ combined = inputShape[0] * inputShape[1] * inputShape[2];
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
auto revisedInputShapeType =
RankedTensorType::get(revisedInputShape, inputType.getElementType());
auto reshapedInput = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedInputShapeType, input,
- rewriter.getI64ArrayAttr(revisedInputShape))
+ rewriter.getI64ArrayAttr(
+ ConvertFromMlirShape(revisedInputShape)))
.getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getI64ArrayAttr(revisedWeightShape))
+ rewriter.getI64ArrayAttr(
+ ConvertFromMlirShape(revisedWeightShape)))
.getResult();
// Perform a fully connected network over the reshaped input and weight.
inputShape[2], weightShape[0]};
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, fullyConnectedValue,
- rewriter.getI64ArrayAttr(outputShape));
+ rewriter.getI64ArrayAttr(ConvertFromMlirShape(outputShape)));
return success();
}
};
// Check each dimension is consistent.
for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
- if (*i1 == -1 || *i2 == -1) {
+ if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
// One or both dimensions is unknown. Follow TensorFlow behavior:
// - If either dimension is greater than 1, we assume that the program is
// correct, and the other dimension will be broadcast to match it.
} else if (*i2 == 1) {
*iR = *i1;
} else {
- *iR = -1;
+ *iR = ShapedType::kDynamicSize;
}
} else {
if (*i1 == *i2 || *i2 == 1) {
// then it is compatible, else if the inferred dim is 1 then it is also
// compatible. But if the existing dim is 1 and the inferred is greater than
// 1 then flag.
- return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1;
+ return dim1 == dim2 || ShapedType::isDynamic(dim1) ||
+ ShapedType::isDynamic(dim2) || dim1 == 1;
};
if (inferred.size() != existing.size())
return false;
ArrayRef<int64_t> shape, Type elementType,
Attribute encoding) {
for (int64_t s : shape)
- if (s < -1)
+ if (s < 0 && !ShapedType::isDynamic(s))
return emitError() << "invalid tensor dimension size";
if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
if (failed(v.verifyEncoding(shape, elementType, emitError)))
if (!BaseMemRefType::isValidElementType(elementType))
return emitError() << "invalid memref element type";
- // Negative sizes are not allowed except for `-1` that means dynamic size.
+ // Negative sizes are not allowed except for `ShapedType::kDynamicSize`,
+ // which denotes a dynamic dimension.
for (int64_t s : shape)
- if (s < -1)
+ if (s < 0 && !ShapedType::isDynamic(s))
return emitError() << "invalid memref size";
assert(layout && "missing layout specification");
if isinstance(s, int):
static_sizes.append(s)
else:
- static_sizes.append(-1)
+ static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
op = self.build_generic(
# CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
# CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
# CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
- @func.FuncOp.from_py_func(RankedTensorType.get((12, -1), f32))
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
def fill_tensor(out):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
return linalg.fill(zero, outs=[out])
# CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
# CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
# CHECK-NEXT: return
- @func.FuncOp.from_py_func(MemRefType.get((12, -1), f32))
+ @func.FuncOp.from_py_func(
+ MemRefType.get((12, ShapedType.get_dynamic_size()), f32))
def fill_buffer(out):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
linalg.fill(zero, outs=[out])
f32 = F32Type.get()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(
- RankedTensorType.get((12, -1), f32))
+ RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
def const_shape_tensor(arg):
return shape.ConstShapeOp(
DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
indexType = IndexType.get()
with InsertionPoint(module.body):
- @func.FuncOp.from_py_func(RankedTensorType.get((-1, -1), f32Type))
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get(
+ (ShapedType.get_dynamic_size(), ShapedType.get_dynamic_size()),
+ f32Type))
# CHECK: func @tensor_static_dim
# CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
module = Module.create()
with InsertionPoint(module.body):
vector_type = VectorType.get([2, 3], F32Type.get())
- memref_type = MemRefType.get([-1, -1], F32Type.get())
+ memref_type = MemRefType.get(
+ [ShapedType.get_dynamic_size(),
+ ShapedType.get_dynamic_size()], F32Type.get())
index_type = IndexType.get()
mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
identity_map = AffineMap.get_identity(vector_type.rank)
TEST(BroadcastShapeTest, InterleavingUnknowns) {
SmallVector<int64_t, 4> result;
- ASSERT_TRUE(
- getBroadcastedShape({1, 2, -1, -1, -1}, {-1, -1, -1, 4, 1}, result));
- EXPECT_THAT(result, ElementsAre(-1, 2, -1, 4, -1));
+ int64_t dyn = mlir::ShapedType::kDynamicSize;
+ ASSERT_TRUE(getBroadcastedShape({1, 2, dyn, dyn, dyn}, {dyn, dyn, dyn, 4, 1},
+ result));
+ EXPECT_THAT(result, ElementsAre(dyn, 2, dyn, 4, dyn));
}
TEST(BroadcastShapeTest, IncompatibleLowDim) {