// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
/// operands inferred.
/// Returns true if every operand of `op` already has a statically ranked
/// tensor type, i.e. all of its inputs have been shape-inferred.
static bool allOperandsInferred(Operation *op) {
  return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
    return llvm::isa<RankedTensorType>(operandType);
  });
}
/// shaped result.
/// Returns true if `op` has at least one result whose type is not a ranked
/// tensor, i.e. the result shape is still dynamic and needs inference.
static bool returnsDynamicShape(Operation *op) {
  return llvm::any_of(op->getResultTypes(), [](Type resultType) {
    return !llvm::isa<RankedTensorType>(resultType);
  });
}
};
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
/// operands inferred.
/// Returns true if every operand of `op` already has a statically ranked
/// tensor type, i.e. all of its inputs have been shape-inferred.
static bool allOperandsInferred(Operation *op) {
  return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
    return llvm::isa<RankedTensorType>(operandType);
  });
}
/// shaped result.
/// Returns true if `op` has at least one result whose type is not a ranked
/// tensor, i.e. the result shape is still dynamic and needs inference.
static bool returnsDynamicShape(Operation *op) {
  return llvm::any_of(op->getResultTypes(), [](Type resultType) {
    return !llvm::isa<RankedTensorType>(resultType);
  });
}
};
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
/// operands inferred.
/// Returns true if every operand of `op` already has a statically ranked
/// tensor type, i.e. all of its inputs have been shape-inferred.
static bool allOperandsInferred(Operation *op) {
  return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
    return llvm::isa<RankedTensorType>(operandType);
  });
}
/// shaped result.
/// Returns true if `op` has at least one result whose type is not a ranked
/// tensor, i.e. the result shape is still dynamic and needs inference.
static bool returnsDynamicShape(Operation *op) {
  return llvm::any_of(op->getResultTypes(), [](Type resultType) {
    return !llvm::isa<RankedTensorType>(resultType);
  });
}
};
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
mlir::Attribute opaqueValue,
mlir::Operation *op) {
- if (type.isa<mlir::TensorType>()) {
+ if (llvm::isa<mlir::TensorType>(type)) {
// Check that the value is an elements attribute.
- auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
+ auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
if (!attrValue)
return op->emitError("constant of TensorType must be initialized by "
"a DenseFPElementsAttr, got ")
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = type.dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the
// constant result type.
- auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
if (attrType.getRank() != resultType.getRank()) {
return op->emitOpError("return type must match the one of the attached "
"value attribute: ")
}
return mlir::success();
}
- auto resultType = type.cast<StructType>();
+ auto resultType = llvm::cast<StructType>(type);
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
// Verify that the initializer is an Array.
- auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
+ auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
return op->emitError("constant of StructType must be initialized by an "
"ArrayAttr with the same number of elements, got ")
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
mlir::Value input, size_t index) {
// Extract the result type from the input type.
- StructType structTy = input.getType().cast<StructType>();
+ StructType structTy = llvm::cast<StructType>(input.getType());
assert(index < structTy.getNumElementTypes());
mlir::Type resultType = structTy.getElementTypes()[index];
}
mlir::LogicalResult StructAccessOp::verify() {
- StructType structTy = getInput().getType().cast<StructType>();
+ StructType structTy = llvm::cast<StructType>(getInput().getType());
size_t indexValue = getIndex();
if (indexValue >= structTy.getNumElementTypes())
return emitOpError()
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType, StructType>()) {
+ if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
void ToyDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const {
// Currently the only toy type is a struct type.
- StructType structType = type.cast<StructType>();
+ StructType structType = llvm::cast<StructType>(type);
// Print the struct type according to the parser format.
printer << "struct<";
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
- if (type.isa<StructType>())
+ if (llvm::isa<StructType>(type))
return builder.create<StructConstantOp>(loc, type,
- value.cast<mlir::ArrayAttr>());
+ llvm::cast<mlir::ArrayAttr>(value));
return builder.create<ConstantOp>(loc, type,
- value.cast<mlir::DenseElementsAttr>());
+ llvm::cast<mlir::DenseElementsAttr>(value));
}
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
/// operands inferred.
/// Returns true if every operand of `op` already has a statically ranked
/// tensor type, i.e. all of its inputs have been shape-inferred.
static bool allOperandsInferred(Operation *op) {
  return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
    return llvm::isa<RankedTensorType>(operandType);
  });
}
/// shaped result.
/// Returns true if `op` has at least one result whose type is not a ranked
/// tensor, i.e. the result shape is still dynamic and needs inference.
static bool returnsDynamicShape(Operation *op) {
  return llvm::any_of(op->getResultTypes(), [](Type resultType) {
    return !llvm::isa<RankedTensorType>(resultType);
  });
}
};
/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
- auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
+ auto structAttr =
+ llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
if (!structAttr)
return nullptr;
public:
Breakpoint *match(const Action &action) const override {
for (const IRUnit &unit : action.getContextIRUnits()) {
- if (auto *op = unit.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
if (auto match = matchFromLocation(op->getLoc()))
return *match;
continue;
}
- if (auto *block = unit.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
for (auto &op : block->getOperations()) {
if (auto match = matchFromLocation(op.getLoc()))
return *match;
}
continue;
}
- if (Region *region = unit.dyn_cast<Region *>()) {
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
if (auto match = matchFromLocation(region->getLoc()))
return *match;
continue;
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
// Types for all sparse handles.
// Sparse environment handle type. The predicate uses the free-function
// llvm::isa form rather than the deprecated $_self.isa<> member form.
def GPU_SparseEnvHandle :
    DialectType<GPU_Dialect,
                CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
                "sparse environment handle type">,
    BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
// Dense vector handle type, checked via the free-function llvm::isa form.
def GPU_SparseDnVecHandle :
    DialectType<GPU_Dialect,
                CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
                "dense vector handle type">,
    BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
// Dense matrix handle type, checked via the free-function llvm::isa form.
def GPU_SparseDnMatHandle :
    DialectType<GPU_Dialect,
                CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
                "dense matrix handle type">,
    BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
// Sparse matrix handle type, checked via the free-function llvm::isa form.
def GPU_SparseSpMatHandle :
    DialectType<GPU_Dialect,
                CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
                "sparse matrix handle type">,
    BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
/*methodName=*/"getDeclareTargetDeviceType",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
- if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+ if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getDeviceType().getValue();
return {};
}]>,
/*methodName=*/"getDeclareTargetCaptureClause",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
- if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+ if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getCaptureClause().getValue();
return {};
}]>
static bool classof(Type type);
/// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return cast<ShapedType>(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
//===----------------------------------------------------------------------===//
unsigned getMemorySpaceAsInt() const;
/// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return cast<ShapedType>(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
} // namespace mlir
}
static bool isEmptyKey(mlir::TypeRange range) {
  // The empty key is encoded as a TypeRange whose base holds the sentinel
  // Type pointer produced by getEmptyKeyPointer().
  if (const auto *type =
          llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
    return type == getEmptyKeyPointer();
  return false;
}
static bool isTombstoneKey(mlir::TypeRange range) {
  // The tombstone key is encoded as a TypeRange whose base holds the
  // sentinel Type pointer produced by getTombstoneKeyPointer().
  if (const auto *type =
          llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
    return type == getTombstoneKeyPointer();
  return false;
}
/// Return the value the effect is applied on, or nullptr if there isn't a
/// known value being affected.
- Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
+ Value getValue() const { return value ? llvm::dyn_cast_if_present<Value>(value) : Value(); }
/// Return the symbol reference the effect is applied on, or nullptr if there
/// isn't a known smbol being affected.
SymbolRefAttr getSymbolRef() const {
  // Yields a null SymbolRefAttr when the union holds a Value (or nothing).
  return value ? llvm::dyn_cast_if_present<SymbolRefAttr>(value)
               : SymbolRefAttr();
}
/// Return the resource that the effect applies to.
/// Returns the parent analysis map for this analysis map, or null if this is
/// the top-level map.
const NestedAnalysisMap *getParent() const {
  // The union holds either the parent map or an instrumentor; report null
  // when this is the top-level map.
  return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
}
/// Returns a pass instrumentation object for the current operation. This
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
- if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
}
void Executable::onUpdate(DataFlowSolver *solver) const {
- if (auto *block = point.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({block, analysis});
for (DataFlowAnalysis *analysis : subscribers)
for (Operation &op : *block)
solver->enqueue({&op, analysis});
- } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
+ } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
for (DataFlowAnalysis *analysis : subscribers)
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
- auto *op = point.dyn_cast<Operation *>();
+ auto *op = llvm::dyn_cast_if_present<Operation *>(point);
if (!op)
return emitError(point.getLoc(), "unknown program point kind");
}
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = point.dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
processOperation(op);
- else if (auto *block = point.dyn_cast<Block *>())
+ else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();
if (auto bound =
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
return bound.getValue();
- } else if (auto value = loopBound->dyn_cast<Value>()) {
+ } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
getLatticeElementFor(op, value);
if (lattice != nullptr)
}
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = point.dyn_cast<Operation *>())
+ if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
- else if (Block *block = point.dyn_cast<Block *>())
+ else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (point.dyn_cast<Operation *>()) {
+ if (llvm::dyn_cast_if_present<Operation *>(point)) {
if (!inputs.empty())
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = point.dyn_cast<Operation *>())
+ if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
- else if (point.dyn_cast<Block *>())
+ else if (llvm::dyn_cast_if_present<Block *>(point))
// For backward dataflow, we don't have to do any work for the blocks
// themselves. CFG edges between blocks are processed by the BranchOp
// logic in `visitOperation`, and entry blocks for functions are tied
os << "<NULL POINT>";
return;
}
- if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->print(os);
- if (auto *op = dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->print(os);
- if (auto value = dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast<Value>(*this))
return value.print(os);
return get<Block *>()->print(os);
}
Location ProgramPoint::getLoc() const {
- if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->getLoc();
- if (auto *op = dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->getLoc();
- if (auto value = dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
return get<Block *>()->getParent()->getLoc();
}
if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();
- if (auto *op = opOrArgument.dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
op->setLoc(directLoc);
else
opOrArgument.get<BlockArgument>().setLoc(directLoc);
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
DictionaryAttr attributeDict;
if (!mlirAttributeIsNull(attributes))
- attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+ attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
return attributeDict;
}
// TODO: safer and more flexible to store data type in actual op instead?
static Type getSpMatElemType(Value spMat) {
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
- return op.getValues().getType().cast<MemRefType>().getElementType();
+ return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
- return op.getValues().getType().cast<MemRefType>().getElementType();
+ return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
llvm_unreachable("cannot find spmat def");
}
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
- Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
- Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
- Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
- Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+ Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto iw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
- Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
- Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
- Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+ Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
+ Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto pw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
auto iw = rewriter.create<LLVM::ConstantOp>(
return failure();
if (!(*converted)) // Conversion to default is 0.
return 0;
- if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
+ if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
return explicitSpace.getInt();
return failure();
}
Attribute initialValue = nullptr;
if (!global.isExternal() && !global.isUninitialized()) {
- auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
+ auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
initialValue = elementsAttr;
// For scalar memrefs, the global variable created is of the element type,
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
builder.create<pdl_interp::CheckTypesOp>(
- loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
+ loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
else
builder.create<pdl_interp::CheckTypeOp>(
- loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
+ loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
break;
}
case Predicates::AttributeQuestion: {
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
// tosa::ErfOp
- if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
// tosa::GreaterOp
auto addDynamicDimension = [&](Value source, int64_t dim) {
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
- if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+ if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
results.push_back(dimValue);
};
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
state.cursor = op->getBlock();
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
state.cursor = region->getParentOp();
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
state.cursor = block->getParent();
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
<< " but got " << index << "\n";
return;
}
state.cursor = &op->getRegion(index);
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
auto block = region->begin();
int count = 0;
while (block != region->end() && count != index) {
return;
}
state.cursor = &*block;
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
auto op = block->begin();
int count = 0;
while (op != block->end() && count != index) {
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *previous = op->getPrevNode();
if (!previous) {
llvm::outs() << "No previous operation in the current block\n";
return;
}
state.cursor = previous;
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
llvm::outs() << "Has region\n";
Operation *parent = region->getParentOp();
if (!parent) {
}
state.cursor =
®ion->getParentOp()->getRegion(region->getRegionNumber() - 1);
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *previous = block->getPrevNode();
if (!previous) {
llvm::outs() << "No previous block in the current region\n";
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *next = op->getNextNode();
if (!next) {
llvm::outs() << "No next operation in the current block\n";
return;
}
state.cursor = next;
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
Operation *parent = region->getParentOp();
if (!parent) {
llvm::outs() << "No parent operation for the current region\n";
}
state.cursor =
®ion->getParentOp()->getRegion(region->getRegionNumber() + 1);
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *next = block->getNextNode();
if (!next) {
llvm::outs() << "No next block in the current region\n";
actualValues.reserve(values.size());
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
for (OpFoldResult ofr : values) {
- if (auto value = ofr.dyn_cast<Value>()) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
actualValues.push_back(value);
continue;
}
if (staticDim.has_value())
return builder.create<arith::ConstantIndexOp>(result.location,
*staticDim);
- return ofr.dyn_cast<Value>();
+ return llvm::dyn_cast_if_present<Value>(ofr);
});
result.addOperands(basisValues);
}
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
/// or(x, <all ones>) -> <all ones>
- if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
+ if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
if (rhsAttr.getValue().isAllOnes())
return rhsAttr;
/// Always fold extension of FP constants.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
+ auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
if (!constOperand)
return {};
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
- if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
+ if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getI1SameShape(lhs.getType()),
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
- auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
+ auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
+ auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
// If one operand is NaN, making them both NaN does not change the result.
if (lhs && lhs.getValue().isNaN())
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
if (auto cond =
- adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
if (auto lhs =
- adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
if (auto rhs =
- adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
// If the buffers have different types, they differ only in their layout
// map.
- auto memrefType = trueType->cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(*trueType);
return getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
if (ofr.is<Attribute>())
continue;
// Newly static, move from Value to constant.
- if (auto cstOp =
- ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
+ if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
+ .getDefiningOp<arith::ConstantIndexOp>()) {
ofr = b.getIndexAttr(cstOp.value());
valuesChanged = true;
}
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
- if (auto value = ofr.dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
return value;
- auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
+ auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
assert(attr && "expect the op fold result casts to an integer attribute");
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
}
FailureOr<Value> alloc = options.createAlloc(
- rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
+ rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
if (failed(alloc))
return failure();
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
op->walk([&](Operation *op) {
SmallVector<Attribute> aliasSets;
for (OpResult opResult : op->getOpResults()) {
- if (opResult.getType().isa<TensorType>()) {
+ if (llvm::isa<TensorType>(opResult.getType())) {
SmallVector<Attribute> aliases;
state.applyOnAliases(opResult, [&](Value alias) {
std::string buffer;
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
//===----------------------------------------------------------------------===//
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
- ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+ ArrayAttr arrayAttr =
+ llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[1];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
//===----------------------------------------------------------------------===//
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
- ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+ ArrayAttr arrayAttr =
+ llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[0];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
os << DataLayoutEntryAttr::kAttrKeyword << "<";
- if (auto type = getKey().dyn_cast<Type>())
+ if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
os << type;
else
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
DenseSet<Type> types;
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
if (!types.insert(type).second)
return emitError() << "repeated layout entry key: " << type;
} else {
// error. All other canonicalization is done in the fold method.
bool requiresConst = !rawConstantIndices.empty() &&
currType.isa_and_nonnull<LLVMStructType>();
- if (Value val = iter.dyn_cast<Value>()) {
+ if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
intC.isSignedIntN(kGEPConstantBitWidth)) {
llvm::interleaveComma(
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
[&](PointerUnion<IntegerAttr, Value> cst) {
- if (Value val = cst.dyn_cast<Value>())
+ if (Value val = llvm::dyn_cast_if_present<Value>(cst))
printer.printOperand(val);
else
printer << cst.get<IntegerAttr>().getInt();
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
- if (Value val = existing.dyn_cast<Value>())
+ if (Value val = llvm::dyn_cast_if_present<Value>(existing))
gepArgs.emplace_back(val);
else
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
return llvm::all_of(gepOp.getIndices(), [](auto index) {
- auto indexAttr = index.template dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
return indexAttr && indexAttr.getValue() == 0;
});
}
// Ensures all indices are static and fetches them.
SmallVector<IntegerAttr> indices;
for (auto index : gep.getIndices()) {
- IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
+ IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
if (!indexInt)
return {};
indices.push_back(indexInt);
for (IntegerAttr index : llvm::drop_begin(indices)) {
// Ensure the structure of the type being indexed can be reasoned about.
// This includes rejecting any potential typed pointer.
- auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
+ auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
if (!destructurable)
return {};
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
+ auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
if (!basePtrType)
return false;
return false;
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
assert(slot.elementPtrs.contains(firstLevelIndex));
- if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
+ if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
usedIndices.insert(firstLevelIndex);
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
+ IntegerAttr firstLevelIndex =
+ llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
}
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
- auto indexAttr = index.dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();
}
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
- auto indexAttr = index.dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
newType.getAddressSpace();
}
});
if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
}
return false;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
- if (auto attr = ofr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
continue;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
- if (auto attr = ofr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
} else {
cast<LinalgOp>(genericOp.getOperation())
.createLoopRanges(rewriter, genericOp.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
- if (auto attr = ofr.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
// to look for the bound.
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
Value size;
- if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
Value materializedSize =
rewriter, op.getLoc(), d0 + d1 - d2,
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
minSplitPoint});
- if (auto attr = remainingSize.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
if (cast<IntegerAttr>(attr).getValue().isZero())
return {op, TilingInterface()};
}
static bool isZero(OpFoldResult v) {
if (!v)
return false;
- if (auto attr = v.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
- if (auto attr = value.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;
PatternRewriter &rewriter) const {
// Given an OpFoldResult, return an index-typed value.
auto getIdxValue = [&](OpFoldResult ofr) {
- if (auto val = ofr.dyn_cast<Value>())
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
return val;
return rewriter
.create<arith::ConstantIndexOp>(
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
for (auto o : ofrs) {
- if (auto val = o.template dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
result.push_back(val);
} else {
result.push_back(rewriter.create<arith::ConstantIndexOp>(
continue;
// Other cases: Take a deeper look at defining ops of values.
- auto v1 = size1.dyn_cast<Value>();
- auto v2 = size2.dyn_cast<Value>();
+ auto v1 = llvm::dyn_cast_if_present<Value>(size1);
+ auto v2 = llvm::dyn_cast_if_present<Value>(size2);
if (!v1 || !v2)
return false;
auto dim = it.index();
auto size = it.value();
curr.push_back(dim);
- auto attr = size.dyn_cast<Attribute>();
+ auto attr = llvm::dyn_cast_if_present<Attribute>(size);
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
continue;
reassociation.emplace_back(ReassociationIndices{});
//===----------------------------------------------------------------------===//
static bool isSupportedElementType(Type type) {
- return type.isa<MemRefType>() ||
+ return llvm::isa<MemRefType>(type) ||
OpBuilder(type.getContext()).getZeroAttr(type);
}
SmallVector<DestructurableMemorySlot>
memref::AllocaOp::getDestructurableSlots() {
MemRefType memrefType = getType();
- auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
+ auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
if (!destructurable)
return {};
DenseMap<Attribute, MemorySlot> slotMap;
- auto memrefType = getType().cast<DestructurableTypeInterface>();
+ auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
MemRefDestructurableTypeExternalModel, MemRefType> {
std::optional<DenseMap<Attribute, Type>>
getSubelementIndexMap(Type type) const {
- auto memrefType = type.cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(type);
constexpr int64_t maxMemrefSizeForDestructuring = 16;
if (!memrefType.hasStaticShape() ||
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
}
Type getTypeAtIndex(Type type, Attribute index) const {
- auto memrefType = type.cast<MemRefType>();
- auto coordArrAttr = index.dyn_cast<ArrayAttr>();
+ auto memrefType = llvm::cast<MemRefType>(type);
+ auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
return {};
Type indexType = IndexType::get(memrefType.getContext());
for (const auto &[coordAttr, dimSize] :
llvm::zip(coordArrAttr, memrefType.getShape())) {
- auto coord = coordAttr.dyn_cast<IntegerAttr>();
+ auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
coord.getInt() >= dimSize)
return {};
return unusedDims;
for (const auto &dim : llvm::enumerate(sizes))
- if (auto attr = dim.value().dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
unusedDims.set(dim.index());
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+ auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
// Because we only support input strides of 1, the output stride is also
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
- Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
}
sizes.push_back(opSize);
- Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
- sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+ Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
+ sourceOffsetAttr =
+ llvm::dyn_cast_if_present<Attribute>(sourceOffset);
if (opOffsetAttr && sourceOffsetAttr) {
// If both offsets are static we can simply calculate the combined
AffineExpr expr = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
- if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =
<< operandName << " operand appears more than once";
mlir::Type varType = operand.getType();
- auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+ auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
if (!decl)
return op->emitOpError()
for (const auto &mapTypeOp : *map_types) {
int64_t mapTypeBits = 0x00;
- if (!mapTypeOp.isa<mlir::IntegerAttr>())
+ if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
return failure();
- mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
+ mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
bool to =
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
// map.
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
- auto iterRanked = initArgBufferType->cast<MemRefType>();
+ auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
"expected same shape");
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
if (!isa<TensorType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
- return bufferization::getBufferType(bbArg, options)->cast<Type>();
+ return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
}));
// Construct a new scf.while op with memref instead of tensor values.
return failure();
unsigned dimIv = cstr.appendDimVar(iv);
- auto lbv = lb.dyn_cast<Value>();
+ auto lbv = llvm::dyn_cast_if_present<Value>(lb);
unsigned symLb =
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
- auto ubv = ub.dyn_cast<Value>();
+ auto ubv = llvm::dyn_cast_if_present<Value>(ub);
unsigned symUb =
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
- auto i = getIndices().begin()->cast<IntegerAttr>();
+ auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertPtrToUOp::verify() {
- auto operandType = getPointer().getType().cast<spirv::PointerType>();
- auto resultType = getResult().getType().cast<spirv::ScalarType>();
+ auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
if (!resultType || !resultType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertUToPtrOp::verify() {
- auto operandType = getOperand().getType().cast<spirv::ScalarType>();
- auto resultType = getResult().getType().cast<spirv::PointerType>();
+ auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
+ auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
if (!operandType || !operandType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
}
unsigned CompositeType::getNumElements() const {
- if (auto arrayType = dyn_cast<ArrayType>())
+ if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getNumElements();
- if (auto matrixType = dyn_cast<MatrixType>())
+ if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
return matrixType.getNumColumns();
- if (auto structType = dyn_cast<StructType>())
+ if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getNumElements();
- if (auto vectorType = dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
- if (isa<CooperativeMatrixNVType>()) {
+ if (llvm::isa<CooperativeMatrixNVType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
- if (isa<JointMatrixINTELType>()) {
+ if (llvm::isa<JointMatrixINTELType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::JointMatrix type");
}
- if (isa<RuntimeArrayType>()) {
+ if (llvm::isa<RuntimeArrayType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
}
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
- RuntimeArrayType>();
+ return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
+ RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
}
std::optional<int64_t> CompositeType::getSizeInBytes() {
- if (auto arrayType = dyn_cast<ArrayType>())
+ if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getSizeInBytes();
- if (auto structType = dyn_cast<StructType>())
+ if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getSizeInBytes();
- if (auto vectorType = dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
std::optional<int64_t> elementSize =
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
if (!elementSize)
capabilities.push_back(ref); \
} break
- if (auto intType = dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
switch (bitwidth) {
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
- assert(isa<FloatType>());
+ assert(llvm::isa<FloatType>(*this));
switch (bitwidth) {
WIDTH_CASE(Float, 16);
WIDTH_CASE(Float, 64);
}
bool SPIRVType::isScalarOrVector() {
- return isIntOrFloat() || isa<VectorType>();
+ return isIntOrFloat() || llvm::isa<VectorType>(*this);
}
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- if (auto scalarType = dyn_cast<ScalarType>()) {
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getExtensions(extensions, storage);
- } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getExtensions(extensions, storage);
- } else if (auto imageType = dyn_cast<ImageType>()) {
+ } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getExtensions(extensions, storage);
- } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+ } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getExtensions(extensions, storage);
- } else if (auto matrixType = dyn_cast<MatrixType>()) {
+ } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getExtensions(extensions, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
+ } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
void SPIRVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- if (auto scalarType = dyn_cast<ScalarType>()) {
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getCapabilities(capabilities, storage);
- } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getCapabilities(capabilities, storage);
- } else if (auto imageType = dyn_cast<ImageType>()) {
+ } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getCapabilities(capabilities, storage);
- } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+ } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getCapabilities(capabilities, storage);
- } else if (auto matrixType = dyn_cast<MatrixType>()) {
+ } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getCapabilities(capabilities, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
+ } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
- if (auto scalarType = dyn_cast<ScalarType>())
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
return scalarType.getSizeInBytes();
- if (auto compositeType = dyn_cast<CompositeType>())
+ if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
return compositeType.getSizeInBytes();
return std::nullopt;
}
if (!adaptor.getLhs() || !adaptor.getRhs())
return nullptr;
auto lhsShape = llvm::to_vector<6>(
- adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
- adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
resultShape.append(lhsShape.begin(), lhsShape.end());
resultShape.append(rhsShape.begin(), rhsShape.end());
if (!operand)
return false;
extents.push_back(llvm::to_vector<6>(
- operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+ llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
}
return OpTrait::util::staticallyKnownBroadcastable(extents);
}())
//===----------------------------------------------------------------------===//
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+ auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
- auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
}
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
- auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto elements =
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!elements)
return nullptr;
std::optional<int64_t> dim = getConstantDim();
//===----------------------------------------------------------------------===//
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
- auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto shape =
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!shape)
return {};
int64_t rank = shape.getNumElements();
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+ auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
- auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
if (!adaptor.getOperand() || !adaptor.getIndex())
return failure();
auto shapeVec = llvm::to_vector<6>(
- adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
auto shape = llvm::ArrayRef(shapeVec);
- auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
+ auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
// Verify that the split point is in the correct range.
// TODO: Constant fold to an "error".
int64_t rank = shape.size();
return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
- adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
builder.getIndexType());
return DenseIntElementsAttr::get(type, shape);
Level cooStartLvl = getCOOStart(stt.getEncoding());
if (cooStartLvl < stt.getLvlRank()) {
// We only supports trailing COO for now, must be the last input.
- auto cooTp = lvlTps.back().cast<ShapedType>();
+ auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
// The coordinates should be in shape of <? x rank>
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
inputTp = lvlTps[idx++];
}
// The input element type and expected element type should match.
- Type inpElemTp = inputTp.cast<TensorType>().getElementType();
+ Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
Type expElemTp = getFieldElemType(stt, fKind);
if (inpElemTp != expElemTp) {
misMatch = true;
/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
Value tensor) {
- auto tensorType = tensor.getType().cast<ShapedType>();
+ auto tensorType = llvm::cast<ShapedType>(tensor.getType());
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
OpBuilder &builder, Location loc) {
- const SparseTensorType stt(rtp.cast<RankedTensorType>());
+ const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
const Level lvlRank = stt.getLvlRank();
// Extract fields and coordinates from args.
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
// The mangled name of the function has this format:
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
- const SparseTensorType stt(rtp.cast<RankedTensorType>());
+ const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor) {
- auto tTp = tensor.getType().cast<TensorType>();
+ auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
.getResult();
}
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
- auto elemTp = mem.getType().cast<MemRefType>().getElementType();
+ auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
return builder
.create<memref::SubViewOp>(
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
AffineExpr expr = b.getAffineDimExpr(0);
unsigned numSymbols = 0;
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
- if (Value v = valueOrAttr.dyn_cast<Value>()) {
+ if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
expr = expr + b.getAffineSymbolExpr(numSymbols++);
mapOperands.push_back(v);
return;
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+ auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
OpFoldResult currDim = std::get<1>(it);
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
- if (auto attr = currDim.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
if (ShapedType::isDynamic(newDim) ||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
// Something is off, the cast result shape cannot be more dynamic
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
- if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
+ if (auto splat =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
auto resultType = llvm::cast<ShapedType>(getResult().getType());
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
SmallVector<int64_t> result;
for (auto o : ofrs) {
// Have to do this first, as getConstantIntValue special-cases constants.
- if (o.dyn_cast<Value>())
+ if (llvm::dyn_cast_if_present<Value>(o))
result.push_back(ShapedType::kDynamic);
else
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
- maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
+ llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
collapseShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
- auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
srcBufferType, collapseShapeOp.getReassociationIndices());
expandShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
- auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
srcBufferType, expandShapeOp.getResultType().getShape(),
expandShapeOp.getReassociationIndices());
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
+ loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+ extractSliceOp.getType().getShape(),
+ llvm::cast<MemRefType>(*srcMemrefType),
mixedOffsets, mixedSizes, mixedStrides));
}
};
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
if (rhsTy == resultTy) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
if (getInput().getType() == getType())
return getInput();
- auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
+ auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
if (!operand)
return {};
if (inputTy == outputTy)
return getInput1();
- auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto operand =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
if (adaptor.getPadding()) {
- auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
+ auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
auto operand = getInput();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
- auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
+ auto operandAttr =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
if (operandAttr)
return operandAttr;
!outputTy.getElementType().isIntOrIndexOrFloat())
return {};
- auto operand = adaptor.getInput().cast<ElementsAttr>();
+ auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
if (getOnTrue() == getOnFalse())
return getOnTrue();
- auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto predicate =
+ llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
if (!predicate)
return {};
auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
- if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
+ if (auto input =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
inputTy.getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
// Verify the rank agrees with the output type if the output type is ranked.
if (outputType) {
if (outputType.getRank() !=
- input1_copy.getType().cast<RankedTensorType>().getRank() ||
+ llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
outputType.getRank() !=
- input2_copy.getType().cast<RankedTensorType>().getRank())
+ llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
return rewriter.notifyMatchFailure(
loc, "the reshaped type doesn't agrees with the ranked output type");
}
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2) {
- auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
- auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+ auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
+ auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
if (!input1Ty || !input2Ty) {
return failure();
}
ArrayRef<int64_t> higherRankShape =
- higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
ArrayRef<int64_t> lowerRankShape =
- lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+ llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
SmallVector<int64_t, 4> reshapeOutputShape;
.failed())
return failure();
- auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeInputType =
+ llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
SmallVector<Operation *> operations;
operations.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto *op = value.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
operations.push_back(op);
continue;
}
SmallVector<Value> payloadValues;
payloadValues.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto v = value.dyn_cast<Value>()) {
+ if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
payloadValues.push_back(v);
continue;
}
SmallVector<transform::Param> parameters;
parameters.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto attr = value.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
parameters.push_back(attr);
continue;
}
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
- if (auto attr = v.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec) {
- auto v = ofr.dyn_cast<Value>();
+ auto v = llvm::dyn_cast_if_present<Value>(ofr);
if (!v) {
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
staticVec.push_back(apInt.getSExtValue());
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
- if (auto val = ofr.dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
- Attribute attr = ofr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
if (cst1 && cst2 && *cst1 == *cst2)
return true;
- auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+ auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
+ v2 = llvm::dyn_cast_if_present<Value>(ofr2);
return v1 && v1 == v2;
}
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes, properties);
- auto vectorType = op.getVector().getType().cast<VectorType>();
+ auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
- if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+ if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, adaptor.getSource());
- if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
+ if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
}
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes, properties);
- auto v1Type = op.getV1().getType().cast<VectorType>();
+ auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
- if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
+ if (auto attr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
if (attr.isSplat())
return attr.reshape(getResultVectorType());
if (auto *op = getDefiningOp())
return op->print(os, flags);
// TODO: Improve BlockArgument print'ing.
- BlockArgument arg = this->cast<BlockArgument>();
+ BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
return op->print(os, state);
// TODO: Improve BlockArgument print'ing.
- BlockArgument arg = this->cast<BlockArgument>();
+ BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
Operation *op;
- if (auto result = dyn_cast<OpResult>()) {
+ if (auto result = llvm::dyn_cast<OpResult>(*this)) {
op = result.getOwner();
} else {
- op = cast<BlockArgument>().getOwner()->getParentOp();
+ op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
if (!op) {
os << "<<UNKNOWN SSA VALUE>>";
return;
/// See `llvm::detail::indexed_accessor_range_base` for details.
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
- if (auto *operand = object.dyn_cast<BlockOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return {operand + index};
- return {object.dyn_cast<Block *const *>() + index};
+ return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
- if (const auto *operand = object.dyn_cast<BlockOperand *>())
+ if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return operand[index].get();
- return object.dyn_cast<Block *const *>()[index];
+ return llvm::dyn_cast_if_present<Block *const *>(object)[index];
}
Type expectedType = std::get<1>(it);
// Normal values get pushed back directly.
- if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (value.getType() != expectedType)
return cleanupFailure();
DenseElementsAttr
DenseElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
- return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+ return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType,
+ mapping);
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
- return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+ return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType,
+ mapping);
}
ShapedType DenseElementsAttr::getType() const {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
- if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
+ if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
return 8;
- if (isa<Float16Type, BFloat16Type>())
+ if (llvm::isa<Float16Type, BFloat16Type>(*this))
return 16;
- if (isa<Float32Type>())
+ if (llvm::isa<Float32Type>(*this))
return 32;
- if (isa<Float64Type>())
+ if (llvm::isa<Float64Type>(*this))
return 64;
- if (isa<Float80Type>())
+ if (llvm::isa<Float80Type>(*this))
return 80;
- if (isa<Float128Type>())
+ if (llvm::isa<Float128Type>(*this))
return 128;
llvm_unreachable("unexpected float type");
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (isa<Float8E5M2Type>())
+ if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
- if (isa<Float8E4M3FNType>())
+ if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
- if (isa<Float8E5M2FNUZType>())
+ if (llvm::isa<Float8E5M2FNUZType>(*this))
return APFloat::Float8E5M2FNUZ();
- if (isa<Float8E4M3FNUZType>())
+ if (llvm::isa<Float8E4M3FNUZType>(*this))
return APFloat::Float8E4M3FNUZ();
- if (isa<Float8E4M3B11FNUZType>())
+ if (llvm::isa<Float8E4M3B11FNUZType>(*this))
return APFloat::Float8E4M3B11FNUZ();
- if (isa<BFloat16Type>())
+ if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
- if (isa<Float16Type>())
+ if (llvm::isa<Float16Type>(*this))
return APFloat::IEEEhalf();
- if (isa<Float32Type>())
+ if (llvm::isa<Float32Type>(*this))
return APFloat::IEEEsingle();
- if (isa<Float64Type>())
+ if (llvm::isa<Float64Type>(*this))
return APFloat::IEEEdouble();
- if (isa<Float80Type>())
+ if (llvm::isa<Float80Type>(*this))
return APFloat::x87DoubleExtended();
- if (isa<Float128Type>())
+ if (llvm::isa<Float128Type>(*this))
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}
[](auto type) { return type.getElementType(); });
}
-bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
+bool TensorType::hasRank() const {
+ return !llvm::isa<UnrankedTensorType>(*this);
+}
ArrayRef<int64_t> TensorType::getShape() const {
- return cast<RankedTensorType>().getShape();
+ return llvm::cast<RankedTensorType>(*this).getShape();
}
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
+ if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
if (shape)
return RankedTensorType::get(*shape, elementType);
return UnrankedTensorType::get(elementType);
}
- auto rankedTy = cast<RankedTensorType>();
+ auto rankedTy = llvm::cast<RankedTensorType>(*this);
if (!shape)
return RankedTensorType::get(rankedTy.getShape(), elementType,
rankedTy.getEncoding());
[](auto type) { return type.getElementType(); });
}
-bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
+bool BaseMemRefType::hasRank() const {
+ return !llvm::isa<UnrankedMemRefType>(*this);
+}
ArrayRef<int64_t> BaseMemRefType::getShape() const {
- return cast<MemRefType>().getShape();
+ return llvm::cast<MemRefType>(*this).getShape();
}
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
+ if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
if (!shape)
return UnrankedMemRefType::get(elementType, getMemorySpace());
MemRefType::Builder builder(*shape, elementType);
return builder;
}
- MemRefType::Builder builder(cast<MemRefType>());
+ MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
if (shape)
builder.setShape(*shape);
builder.setElementType(elementType);
}
Attribute BaseMemRefType::getMemorySpace() const {
- if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+ if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpace();
- return cast<UnrankedMemRefType>().getMemorySpace();
+ return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
- if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+ if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpaceAsInt();
- return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+ return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}
//===----------------------------------------------------------------------===//
/// See `llvm::detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
- if (const auto *value = owner.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return {value + index};
- if (auto *operand = owner.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
- if (const auto *value = owner.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return value[index];
- if (auto *operand = owner.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
- if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+  if (auto *region =
+          llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region + index;
- if (auto **region = owner.dyn_cast<Region **>())
+ if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region + index;
return &owner.get<Region *>()[index];
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Region *RegionRange::dereference_iterator(const OwnerT &owner,
ptrdiff_t index) {
- if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+  if (auto *region =
+          llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region[index].get();
- if (auto **region = owner.dyn_cast<Region **>())
+ if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region[index];
return &owner.get<Region *>()[index];
}
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
std::optional<WalkResult> walk(CallbackT cback) {
- if (Region *region = limit.dyn_cast<Region *>())
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return walkSymbolUses(*region, cback);
return walkSymbolUses(limit.get<Operation *>(), cback);
}
/// traversing into any nested symbol tables.
template <typename CallbackT>
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
- if (Region *region = limit.dyn_cast<Region *>())
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return ::walkSymbolTable(*region, cback);
return ::walkSymbolTable(limit.get<Operation *>(), cback);
}
if (count == 0)
return;
ValueRange::OwnerT owner = values.begin().getBase();
- if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
this->base = result;
- else if (auto *operand = owner.dyn_cast<OpOperand *>())
+ else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
else
this->base = owner.get<const Value *>();
/// See `llvm::detail::indexed_accessor_range_base` for details.
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
- if (const auto *value = object.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return {value + index};
- if (auto *operand = object.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return {operand + index};
- if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
- return {object.dyn_cast<const Type *>() + index};
+ return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
- if (const auto *value = object.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return (value + index)->getType();
- if (auto *operand = object.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return (operand + index)->get().getType();
- if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
- return object.dyn_cast<const Type *>()[index];
+ return llvm::dyn_cast_if_present<const Type *>(object)[index];
}
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
-bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
-bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
-bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
-bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
-bool Type::isBF16() const { return isa<BFloat16Type>(); }
-bool Type::isF16() const { return isa<Float16Type>(); }
-bool Type::isF32() const { return isa<Float32Type>(); }
-bool Type::isF64() const { return isa<Float64Type>(); }
-bool Type::isF80() const { return isa<Float80Type>(); }
-bool Type::isF128() const { return isa<Float128Type>(); }
-
-bool Type::isIndex() const { return isa<IndexType>(); }
+bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
+bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
+bool Type::isFloat8E5M2FNUZ() const {
+ return llvm::isa<Float8E5M2FNUZType>(*this);
+}
+bool Type::isFloat8E4M3FNUZ() const {
+ return llvm::isa<Float8E4M3FNUZType>(*this);
+}
+bool Type::isFloat8E4M3B11FNUZ() const {
+ return llvm::isa<Float8E4M3B11FNUZType>(*this);
+}
+bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
+bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
+bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
+bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
+bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
+bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+
+bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.getWidth() == width;
return false;
}
bool Type::isSignlessInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless();
return false;
}
bool Type::isSignlessInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless() && intTy.getWidth() == width;
return false;
}
bool Type::isSignedInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned();
return false;
}
bool Type::isSignedInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned() && intTy.getWidth() == width;
return false;
}
bool Type::isUnsignedInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned();
return false;
}
bool Type::isUnsignedInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned() && intTy.getWidth() == width;
return false;
}
bool Type::isSignlessIntOrIndex() const {
- return isSignlessInteger() || isa<IndexType>();
+ return isSignlessInteger() || llvm::isa<IndexType>(*this);
}
bool Type::isSignlessIntOrIndexOrFloat() const {
- return isSignlessInteger() || isa<IndexType, FloatType>();
+ return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
}
bool Type::isSignlessIntOrFloat() const {
- return isSignlessInteger() || isa<FloatType>();
+ return isSignlessInteger() || llvm::isa<FloatType>(*this);
}
-bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
+bool Type::isIntOrIndex() const {
+ return llvm::isa<IntegerType>(*this) || isIndex();
+}
-bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
+bool Type::isIntOrFloat() const {
+ return llvm::isa<IntegerType, FloatType>(*this);
+}
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
- if (auto intType = dyn_cast<IntegerType>())
+ if (auto intType = llvm::dyn_cast<IntegerType>(*this))
return intType.getWidth();
- return cast<FloatType>().getWidth();
+ return llvm::cast<FloatType>(*this).getWidth();
}
}
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
- if (auto *op = this->dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
return printOp(os, op, flags);
- if (auto *region = this->dyn_cast<Region *>())
+ if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
return printRegion(os, region, flags);
- if (auto *block = this->dyn_cast<Block *>())
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
return printBlock(os, block, flags);
llvm_unreachable("unknown IRUnit");
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *Value::getDefiningOp() const {
- if (auto result = dyn_cast<OpResult>())
+ if (auto result = llvm::dyn_cast<OpResult>(*this))
return result.getOwner();
return nullptr;
}
if (auto *op = getDefiningOp())
return op->getLoc();
- return cast<BlockArgument>().getLoc();
+ return llvm::cast<BlockArgument>(*this).getLoc();
}
void Value::setLoc(Location loc) {
if (auto *op = getDefiningOp())
return op->setLoc(loc);
- return cast<BlockArgument>().setLoc(loc);
+ return llvm::cast<BlockArgument>(*this).setLoc(loc);
}
/// Return the Region in which this Value is defined.
Region *Value::getParentRegion() {
if (auto *op = getDefiningOp())
return op->getParentRegion();
- return cast<BlockArgument>().getOwner()->getParent();
+ return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
}
/// Return the Block in which this Value is defined.
Block *Value::getParentBlock() {
if (Operation *op = getDefiningOp())
return op->getBlock();
- return cast<BlockArgument>().getOwner();
+ return llvm::cast<BlockArgument>(*this).getOwner();
}
//===----------------------------------------------------------------------===//
TypeID typeID) {
return llvm::to_vector<4>(llvm::make_filter_range(
entries, [typeID](DataLayoutEntryInterface entry) {
- auto type = entry.getKey().dyn_cast<Type>();
+ auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
return type && type.getTypeID() == typeID;
}));
}
DenseMap<TypeID, DataLayoutEntryList> &types,
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
for (DataLayoutEntryInterface entry : getEntries()) {
- if (auto type = entry.getKey().dyn_cast<Type>())
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
types[type.getTypeID()].push_back(entry);
else
ids[entry.getKey().get<StringAttr>()] = entry;
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasRank();
if (val.is<Attribute>())
return true;
Type ShapeAdaptor::getElementType() const {
if (val.isNull())
return nullptr;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getElementType();
if (val.is<Attribute>())
return nullptr;
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>()) {
+ if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
res.assign(vals.begin(), vals.end());
- } else if (auto attr = val.dyn_cast<Attribute>()) {
+ } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
res.clear();
res.reserve(dattr.size());
int64_t ShapeAdaptor::getDimSize(int index) const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getDimSize(index);
- if (auto attr = val.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr)
.getValues<APInt>()[index]
.getSExtValue();
int64_t ShapeAdaptor::getRank() const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getRank();
- if (auto attr = val.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr).size();
return val.get<ShapedTypeComponents *>()->getDims().size();
}
if (!hasRank())
return false;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasStaticShape();
- if (auto attr = val.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
int64_t ShapeAdaptor::getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getNumElements();
- if (auto attr = val.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
int64_t num = 1;
for (auto index : dattr.getValues<APInt>()) {
/// If ofr is a constant integer or an IntegerAttr, return the integer.
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
- if (auto val = ofr.dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
- Attribute attr = ofr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
}
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
- if (Value value = ofr.dyn_cast<Value>())
+ if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
return getExpr(value, /*dim=*/std::nullopt);
auto constInt = getConstantIntValue(ofr);
assert(constInt.has_value() && "expected Integer constant");
const Pass &getPass() const { return pass; }
Operation *getOp() const {
ArrayRef<IRUnit> irUnits = getContextIRUnits();
- return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+ return irUnits.empty() ? nullptr
+ : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
}
public:
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
- NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
+      NamedTypeConstraint *operand =
+          llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
void Operator::print(llvm::raw_ostream &os) const {
os << "op '" << getOperationName() << "'\n";
for (Argument arg : arguments) {
- if (auto *attr = arg.dyn_cast<NamedAttribute *>())
+ if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
os << "[attribute] " << attr->name << '\n';
else
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
return nullptr;
SmallVector<uint32_t> weightValues;
weightValues.reserve(weights->size());
- for (APInt weight : weights->cast<DenseIntElementsAttr>())
+ for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
weightValues.push_back(weight.getLimitedValue());
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(weightValues);
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));
for (auto i :
- llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
+ llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
switchOp.getCaseDestinations()))
switchInst->addCase(
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
// Returns the static shape of the provided type if possible.
auto getConstantShape = [&](llvm::Type *type) {
- return getBuiltinTypeForAttr(convertType(type))
- .dyn_cast_or_null<ShapedType>();
+    return llvm::dyn_cast_if_present<ShapedType>(
+        getBuiltinTypeForAttr(convertType(type)));
};
// Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
// Convert zero aggregates.
if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
- auto shape = getBuiltinTypeForAttr(convertType(constZero->getType()))
- .dyn_cast_or_null<ShapedType>();
+    auto shape = llvm::dyn_cast_if_present<ShapedType>(
+        getBuiltinTypeForAttr(convertType(constZero->getType())));
if (!shape)
return {};
// Convert zero aggregates with a static shape to splat elements attributes.
std::string llvmDataLayout;
llvm::raw_string_ostream layoutStream(llvmDataLayout);
for (DataLayoutEntryInterface entry : attribute.getEntries()) {
- auto key = entry.getKey().dyn_cast<StringAttr>();
+ auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
if (!key)
continue;
if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
// specified in entries. Where possible, data layout queries are used instead
// of directly inspecting the entries.
for (DataLayoutEntryInterface entry : attribute.getEntries()) {
- auto type = entry.getKey().dyn_cast<Type>();
+ auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
if (!type)
continue;
// Data layout for the index type is irrelevant at this point.
static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
linkageName, linkageTypeAttr);
- decorations[words[0]].set(symbol, linkageAttr.dyn_cast<Attribute>());
+ decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
break;
}
case spirv::Decoration::Aliased:
if (values) {
for (auto &intVal : values.getValue()) {
operands.push_back(static_cast<uint32_t>(
- intVal.cast<IntegerAttr>().getValue().getZExtValue()));
+ llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
}
}
encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = attr.getValue().dyn_cast<spirv::LinkageAttributesAttr>();
+    auto linkageAttr =
+        llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
/// Return the location of the definition of this symbol.
SMRange getDefLoc() const {
- if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
+    if (const ast::Decl *decl =
+            llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
const ast::Name *declName = decl->getName();
return declName ? declName->getLoc() : decl->getLoc();
}
return std::nullopt;
// Add hover for operation names.
- if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
+  if (const auto *op =
+          llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
return buildHoverForOpName(op, hoverRange);
const auto *decl = symbol->definition.get<const ast::Decl *>();
return findHover(decl, hoverRange);
#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
- if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
+ if (auto sym = llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value.
- if (auto repl = foldResults[i].dyn_cast<Value>()) {
+ if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
if (repl.getType() != op->getResult(i).getType()) {
results.clear();
return failure();
// Remap the locations of the inlined operations if a valid source location
// was provided.
- if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
+ if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
remapInlinedLocations(newBlocks, *inlineLoc);
// If the blocks were moved in-place, make sure to remap any necessary
}
LogicalResult FooAnalysis::visit(ProgramPoint point) {
- if (auto *op = point.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
visitOperation(op);
return success();
}
- if (auto *block = point.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
visitBlock(block);
return success();
}
}
// Replace the op with the reified bound.
- if (auto val = reified->dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
rewriter.replaceOp(op, val);
return WalkResult::skip();
}
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
.addSubst("_diag", propertyDiag)),
name);
} else {
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+      const auto *namedAttr =
+          llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
setPropMethod << formatv(R"decl(
{{
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
.addSubst("_storage", propertyStorage)));
continue;
}
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr =
+        llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
getPropMethod << formatv(R"decl(
{{
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
llvm::interleaveComma(
attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
hashMethod << "\n hash_" << namedProperty->name << "(prop."
<< namedProperty->name << ")";
return;
}
const auto *namedAttr =
- attrOrProp.dyn_cast<const AttributeMetadata *>();
+ llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
hashMethod << "\n llvm::hash_value(prop." << name
<< ".getAsOpaquePointer())";
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedAttr =
- attrOrProp.dyn_cast<const AttributeMetadata *>()) {
+ llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
StringRef name = namedAttr->attrName;
getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
// syntax. This method verifies the constraint on the properties attributes
// before they are set, since dyn_cast<> will silently omit failures.
for (const auto &attrOrProp : attrOrProperties) {
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr =
+        llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
if (!namedAttr || !namedAttr->constraint)
continue;
Attribute attr = *namedAttr->constraint;
// Calculate the start index from which we can attach default values in the
// builder declaration.
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
- auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
+    auto *namedAttr =
+        llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i));
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
break;
for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
Argument arg = op.getArg(i);
- if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
+ if (const auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
StringRef type;
if (operand->isVariadicOfVariadic())
type = "::llvm::ArrayRef<::mlir::ValueRange>";
operand->isOptional());
continue;
}
- if (const auto *operand = arg.dyn_cast<NamedProperty *>()) {
+ if (const auto *operand = llvm::dyn_cast_if_present<NamedProperty *>(arg)) {
// TODO
continue;
}
llvm::raw_string_ostream comparatorOs(comparator);
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
if (name.empty())
report_fatal_error("missing name for property");
.addSubst("_storage", propertyStorage)));
continue;
}
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr =
+        llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
const Attribute *attr = nullptr;
if (namedAttr->constraint)
attr = &*namedAttr->constraint;
/// Get the variable this type is resolved to, or nullptr.
const NamedTypeConstraint *getVariable() const {
- return resolver.dyn_cast<const NamedTypeConstraint *>();
+ return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
}
/// Get the attribute this type is resolved to, or nullptr.
const NamedAttribute *getAttribute() const {
- return resolver.dyn_cast<const NamedAttribute *>();
+ return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
}
/// Get the transformer for the type of the variable, or std::nullopt.
std::optional<StringRef> getVarTransformer() const {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
- auto *attribute = arg.dyn_cast<NamedAttribute *>();
+ auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
if (!attribute)
continue;
// - default-valued named attributes
// - optional operands
Argument a = op.getArg(builderArgIndex - numResultArgs);
- if (auto *nattr = a.dyn_cast<NamedAttribute *>())
+ if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
- if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
+ if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
return ntype->isOptional();
return false;
};
++opArgIdx;
continue;
}
- if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
int valueIndex = 0; // An index for uniquing local variable names.
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto *operand =
- resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
+        llvm::dyn_cast_if_present<NamedTypeConstraint *>(
+            resultOp.getArg(argIndex));
// We do not need special handling for attributes.
if (!operand)
continue;
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
- if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
if (!operand->name.empty())
os << "/*" << operand->name << "=*/";
os << childNodeNames.lookup(argIndex);
// Process operands/attributes
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
- if (auto *valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
// Handle built-in types that are not handled by the default process.
if (auto iType = dyn_cast<IntegerType>(type)) {
for (DataLayoutEntryInterface entry : params)
- if (entry.getKey().dyn_cast<Type>() == type)
+ if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
return 8 *
cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
return 8 * iType.getIntOrFloatBitWidth();