LogicalResult shape::getShapeVec(Value input,
SmallVectorImpl<int64_t> &shapeValues) {
if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
- auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
+ auto type = inputOp.getArg().getType().cast<ShapedType>();
if (!type.hasRank())
return failure();
- shapeValues = llvm::to_vector<6>(type.getShape());
+ llvm::append_range(shapeValues, type.getShape());
return success();
}
- if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
- shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
- return success();
- }
- if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
- shapeValues = llvm::to_vector<6>(
- inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ DenseIntElementsAttr attr;
+ if (matchPattern(input, m_Constant(&attr))) {
+ llvm::append_range(shapeValues, attr.getValues<int64_t>());
return success();
}
return failure();