view = slice.getParentView();
assert(viewType.isa<ViewType>() && "expected a ViewType");
}
- return view->getDefiningOp()->cast<ViewOp>();
+ return cast<ViewOp>(view->getDefiningOp());
}
Value *linalg::getViewSupportingMemRef(Value *view) {
if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
return std::make_pair(viewOp.getIndexing(dim), dim);
- auto sliceOp = view->getDefiningOp()->cast<SliceOp>();
+ auto sliceOp = cast<SliceOp>(view->getDefiningOp());
auto *parentView = sliceOp.getParentView();
unsigned sliceDim = sliceOp.getSlicingDim();
auto *indexing = sliceOp.getIndexing();
if (indexing->getDefiningOp()) {
- if (auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>()) {
+ if (auto rangeOp = cast<RangeOp>(indexing->getDefiningOp())) {
// If I sliced with a range and I sliced at this dim, then I'm it.
if (dim == sliceDim) {
return std::make_pair(rangeOp.getResult(), dim);
auto lb = rangeOp.getMin();
auto ub = rangeOp.getMax();
// This must be a constexpr index until we relax the affine.for constraint
- auto step =
- rangeOp.getStep()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ auto step = llvm::cast<ConstantIndexOp>(rangeOp.getStep()->getDefiningOp())
+ .getValue();
loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
}
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto rangeOp = op->cast<linalg::RangeOp>();
+ auto rangeOp = cast<linalg::RangeOp>(op);
auto rangeDescriptorType =
linalg::convertLinalgType(rangeOp.getResult()->getType());
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto viewOp = op->cast<linalg::ViewOp>();
+ auto viewOp = cast<linalg::ViewOp>(op);
auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
auto memrefType =
viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto sliceOp = op->cast<linalg::SliceOp>();
+ auto sliceOp = cast<linalg::SliceOp>(op);
auto newViewDescriptorType =
linalg::convertLinalgType(sliceOp.getViewType());
auto elementType = rewriter.getType<LLVM::LLVMType>(
assert(view->getType().isa<ViewType>() && "expected a ViewType");
if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
return viewOp.getRank();
- return view->getDefiningOp()->cast<SliceOp>().getRank();
+ return cast<SliceOp>(view->getDefiningOp()).getRank();
}
ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) {
#include "mlir/IR/StandardTypes.h"
using llvm::ArrayRef;
+using llvm::cast;
using llvm::SmallVector;
using mlir::FuncBuilder;
using mlir::MemRefType;
SmallVector<mlir::Value *, 8> tmp;
do {
- auto sliceOp = v->getDefiningOp()->cast<SliceOp>(); // must be a slice op
+ auto sliceOp = cast<SliceOp>(v->getDefiningOp()); // must be a slice op
tmp.push_back(v);
v = sliceOp.getParentView();
} while (!v->getType().isa<ViewType>());
ArrayRef<Value *> chain) {
using namespace mlir::edsc::op;
assert(chain.front()->getType().isa<ViewType>() && "must be a ViewType");
- auto viewOp = chain.front()->getDefiningOp()->cast<ViewOp>();
+ auto viewOp = cast<ViewOp>(chain.front()->getDefiningOp());
auto *indexing = viewOp.getIndexing(dim);
if (!indexing->getType().isa<RangeType>())
return indexing;
- auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>();
+ auto rangeOp = cast<RangeOp>(indexing->getDefiningOp());
Value *min = rangeOp.getMin(), *max = rangeOp.getMax(),
*step = rangeOp.getStep();
for (auto *v : chain.drop_front(1)) {
- auto slice = v->getDefiningOp()->cast<SliceOp>();
+ auto slice = cast<SliceOp>(v->getDefiningOp());
if (slice.getRank() != slice.getParentRank()) {
// Rank-reducing slice.
if (slice.getSlicingDim() == dim) {
dim = (slice.getSlicingDim() < dim) ? dim - 1 : dim;
} else { // not a rank-reducing slice.
if (slice.getSlicingDim() == dim) {
- auto range = slice.getIndexing()->getDefiningOp()->cast<RangeOp>();
+ auto range = cast<RangeOp>(slice.getIndexing()->getDefiningOp());
auto oldMin = min;
min = ValueHandle(min) + ValueHandle(range.getMin());
// ideally: max = min(oldMin + ValueHandle(range.getMax()), oldMax);
for (unsigned idx = 0; idx < rank; ++idx) {
ranges.push_back(createFullyComposedIndexing(idx, chain));
}
- return view(memRef, ranges).getOperation()->cast<ViewOp>();
+ return cast<ViewOp>(view(memRef, ranges).getOperation());
}
if (auto viewOp = llvm::dyn_cast<linalg::ViewOp>(view->getDefiningOp()))
return viewOp.getRanges();
- auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+ auto sliceOp = llvm::cast<linalg::SliceOp>(view->getDefiningOp());
unsigned slicingDim = sliceOp.getSlicingDim();
auto *indexing = *(sliceOp.getIndexings().begin());
bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
// a getelementptr.
Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
- auto loadOp = op->cast<Op>();
+ auto loadOp = cast<Op>(op);
auto elementType =
loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
auto *llvmPtrType = linalg::convertLinalgType(elementType)
SmallVector<Value *, 4> res;
res.reserve(ranges.size());
for (auto *v : ranges) {
- auto r = v->getDefiningOp()->cast<RangeOp>();
+ auto r = cast<RangeOp>(v->getDefiningOp());
res.push_back(extract(r));
}
return res;
for (auto z : llvm::zip(res.steps, tileSizes)) {
auto *step = std::get<0>(z);
auto tileSize = std::get<1>(z);
- auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
auto tileSizeValue =
- tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
assert(stepValue > 0);
tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
}
operands.push_back(indexing);
continue;
}
- RangeOp range = indexing->getDefiningOp()->cast<RangeOp>();
+ RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
ValueHandle min(range.getMin());
Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
using edsc::op::operator+;
PatternMatchResult
Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto load = op->cast<linalg::LoadOp>();
+ auto load = cast<linalg::LoadOp>(op);
SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
- : load.getView()->getDefiningOp()->cast<ViewOp>();
+ : cast<ViewOp>(load.getView()->getDefiningOp());
ScopedContext scope(FuncBuilder(load), load.getLoc());
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(load, view);
PatternMatchResult
Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto store = op->cast<linalg::StoreOp>();
+ auto store = cast<linalg::StoreOp>(op);
SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
- : store.getView()->getDefiningOp()->cast<ViewOp>();
+ : cast<ViewOp>(store.getView()->getDefiningOp());
ScopedContext scope(FuncBuilder(store), store.getLoc());
auto *valueToStore = store.getValueToStore();
auto *memRef = view.getSupportingMemRef();
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f->getBlocks()) {
- auto ret = block.getTerminator()->cast<ReturnOp>();
+ auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
mlir::PatternRewriter &rewriter) const override {
// We can directly cast the current operation as this will only get invoked
// on TransposeOp.
- TransposeOp transpose = op->cast<TransposeOp>();
+ TransposeOp transpose = llvm::cast<TransposeOp>(op);
// Look through the input of the current transpose.
mlir::Value *transposeInput = transpose.getOperand();
TransposeOp transposeInputOp =
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
mlir::Value *reshapeInput = reshape.getOperand();
// If the input is defined by another reshape, bingo!
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
return matchFailure();
rewriter.replaceOp(reshape, {reshape.getOperand()});
using intrinsics::constant_index;
using linalg::intrinsics::range;
using linalg::intrinsics::view;
- toy::MulOp mul = op->cast<toy::MulOp>();
+ toy::MulOp mul = cast<toy::MulOp>(op);
auto loc = mul.getLoc();
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
/// number must match the number of result of `op`.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto add = op->cast<toy::AddOp>();
+ auto add = cast<toy::AddOp>(op);
auto loc = add.getLoc();
// Create a `toy.alloc` operation to allocate the output buffer for this op.
Value *result = memRefTypeCast(
// Get or create the declaration of the printf function in the module.
Function *printfFunc = getPrintf(*op->getFunction()->getModule());
- auto print = op->cast<toy::PrintOp>();
+ auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
+ toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
auto loc = cstOp.getLoc();
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
auto shape = retTy.getShape();
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto transpose = op->cast<toy::TransposeOp>();
+ auto transpose = cast<toy::TransposeOp>(op);
auto loc = transpose.getLoc();
Value *result = memRefTypeCast(
rewriter,
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto retOp = op->cast<toy::ReturnOp>();
+ auto retOp = cast<toy::ReturnOp>(op);
using namespace edsc;
auto loc = retOp.getLoc();
// Argument is optional, handle both cases.
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f->getBlocks()) {
- auto ret = block.getTerminator()->cast<ReturnOp>();
+ auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
mlir::PatternRewriter &rewriter) const override {
// We can directly cast the current operation as this will only get invoked
// on TransposeOp.
- TransposeOp transpose = op->cast<TransposeOp>();
+ TransposeOp transpose = llvm::cast<TransposeOp>(op);
// look through the input to the current transpose
mlir::Value *transposeInput = transpose.getOperand();
mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// look through the input to the current reshape
mlir::Value *reshapeInput = reshape.getOperand();
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// look through the input to the current reshape
mlir::Value *reshapeInput = reshape.getOperand();
mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = op->cast<ReshapeOp>();
+ ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
if (reshape.getOperand()->getType() != reshape.getResult()->getType())
return matchFailure();
rewriter.replaceOp(reshape, {reshape.getOperand()});
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
- TypeCastOp typeCast = op->cast<TypeCastOp>();
+ TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
auto resTy = typeCast.getResult()->getType();
auto *candidateOp = op;
while (candidateOp && candidateOp->isa<TypeCastOp>()) {
/// This is an implementation detail of the constant-folding hook for
/// AbstractOperation: it forwards to the concrete op's multi-result
/// constantFold implementation. `op` must be an instance of ConcreteType;
/// `cast` aborts otherwise.
static LogicalResult constantFoldHook(Operation *op,
                                      ArrayRef<Attribute> operands,
                                      SmallVectorImpl<Attribute> &results) {
  return cast<ConcreteType>(op).constantFold(operands, results,
                                             op->getContext());
}
/// Op implementations can implement this hook. It should attempt to constant
/// This is an implementation detail of the folder hook for AbstractOperation:
/// it forwards to the concrete op's multi-result fold implementation. `op`
/// must be an instance of ConcreteType; `cast` aborts otherwise.
static LogicalResult foldHook(Operation *op,
                              SmallVectorImpl<Value *> &results) {
  return cast<ConcreteType>(op).fold(results);
}
/// This hook implements a generalized folder for this operation. Operations
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
auto result =
- op->cast<ConcreteType>().constantFold(operands, op->getContext());
+ cast<ConcreteType>(op).constantFold(operands, op->getContext());
if (!result)
return failure();
/// This is an implementation detail of the folder hook for AbstractOperation.
static LogicalResult foldHook(Operation *op,
SmallVectorImpl<Value *> &results) {
- auto *result = op->cast<ConcreteType>().fold();
+ auto *result = cast<ConcreteType>(op).fold();
if (!result)
return failure();
if (result != op->getResult(0))
/// Verifies the op's invariants: first every trait verifier, then the
/// op-specific verify() hook. Returns failure if any of them fails.
static LogicalResult verifyInvariants(Operation *op) {
  return failure(
      failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
      failed(cast<ConcreteType>(op).verify()));
}
// Returns the properties of an operation by combining the properties of the
// Conversions to declared operations like DimOp
//===--------------------------------------------------------------------===//
- /// The cast methods perform a cast from an Operation to a typed Op like
- /// DimOp. This aborts if the parameter to the template isn't an instance of
- /// the template type argument.
- template <typename OpClass> OpClass cast() {
- assert(isa<OpClass>() && "cast<Ty>() argument of incompatible type!");
- return OpClass(this);
- }
-
/// The isa methods return true if the operation is a typed op (like DimOp)
/// of the given class, as determined by OpClass::classof.
template <typename OpClass> bool isa() { return OpClass::classof(this); }
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto apply = op->cast<AffineApplyOp>();
+ auto apply = cast<AffineApplyOp>(op);
auto map = apply.getAffineMap();
AffineMap oldMap = map;
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto forOp = op->cast<AffineForOp>();
+ auto forOp = cast<AffineForOp>(op);
auto foldLowerOrUpperBound = [&forOp](bool lower) {
// Check to see if each of the operands is the result of a constant. If
// so, get the value. If not, ignore it.
return false;
}
- auto composeOp = affineApplyOps[0]->cast<AffineApplyOp>();
+ auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
// We need yet another level of indirection because the `dim` index of the
// access may not correspond to the `dim` index of composeOp.
return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
auto sliceLoopNest =
- b.clone(*srcLoopIVs[0].getOperation())->cast<AffineForOp>();
+ cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
Operation *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
auto invariants = getInvariantAccesses(
- kvp.first->cast<AffineForOp>().getInductionVar(), indices);
+ cast<AffineForOp>(kvp.first).getInductionVar(), indices);
unsigned numIndices = indices.size();
unsigned countInvariantIndices = 0;
for (unsigned dim = 0; dim < numIndices; ++dim) {
return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim);
}
- auto store = op->cast<StoreOp>();
+ auto store = cast<StoreOp>(op);
return ::makePermutationMap(store.getIndices(), enclosingLoopToVectorDim);
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto dcastOp = op->cast<DequantizeCastOp>();
+ auto dcastOp = cast<DequantizeCastOp>(op);
Type inputType = dcastOp.arg()->getType();
Type outputType = dcastOp.getResult()->getType();
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto addOp = op->cast<RealAddEwOp>();
+ auto addOp = cast<RealAddEwOp>(op);
const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
addOp.clamp_min(), addOp.clamp_max());
if (!info.isValid()) {
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto mulOp = op->cast<RealMulEwOp>();
+ auto mulOp = cast<RealMulEwOp>(op);
const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
mulOp.clamp_min(), mulOp.clamp_max());
if (!info.isValid()) {
/// Matches alloc ops whose memref type is supported by the LLVM lowering;
/// the base LLVMLegalizationPattern match must also succeed.
PatternMatchResult match(Operation *op) const override {
  if (!LLVMLegalizationPattern<AllocOp>::match(op))
    return matchFailure();
  // The base pattern only matches AllocOp, so the cast cannot fail here.
  auto allocOp = cast<AllocOp>(op);
  MemRefType type = allocOp.getType();
  return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto allocOp = op->cast<AllocOp>();
+ auto allocOp = cast<AllocOp>(op);
MemRefType type = allocOp.getType();
// Get actual sizes of the memref as values: static sizes are constant
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
return matchFailure();
- auto memRefCastOp = op->cast<MemRefCastOp>();
+ auto memRefCastOp = cast<MemRefCastOp>(op);
MemRefType sourceType =
memRefCastOp.getOperand()->getType().cast<MemRefType>();
MemRefType targetType = memRefCastOp.getType();
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto memRefCastOp = op->cast<MemRefCastOp>();
+ auto memRefCastOp = cast<MemRefCastOp>(op);
auto targetType = memRefCastOp.getType();
auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
/// Matches dim ops whose memref operand type is supported by the LLVM
/// lowering; the base LLVMLegalizationPattern match must also succeed.
PatternMatchResult match(Operation *op) const override {
  if (!LLVMLegalizationPattern<DimOp>::match(op))
    return this->matchFailure();
  // The base pattern only matches DimOp, so the cast cannot fail here.
  auto dimOp = cast<DimOp>(op);
  MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
  return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
}
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
assert(operands.size() == 1 && "expected exactly one operand");
- auto dimOp = op->cast<DimOp>();
+ auto dimOp = cast<DimOp>(op);
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
SmallVector<Value *, 4> results;
PatternMatchResult match(Operation *op) const override {
if (!LLVMLegalizationPattern<Derived>::match(op))
return this->matchFailure();
- auto loadOp = op->cast<Derived>();
+ auto loadOp = cast<Derived>(op);
MemRefType type = loadOp.getMemRefType();
return isSupportedMemRefType(type) ? this->matchSuccess()
: this->matchFailure();
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto loadOp = op->cast<LoadOp>();
+ auto loadOp = cast<LoadOp>(op);
auto type = loadOp.getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto storeOp = op->cast<StoreOp>();
+ auto storeOp = cast<StoreOp>(op);
auto type = storeOp.getMemRefType();
Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
}
/// Returns the ViewOp that produces this slice's base view (operand 0).
/// `cast` aborts if the defining op is not a ViewOp.
ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
  return cast<ViewOp>(getOperand(0)->getDefiningOp());
}
ViewType mlir::linalg::SliceOp::getBaseViewType() {
/// ```
/// Prints a BufferSizeOp as: `<op-name> <operand> {attrs} : <operand-type>`.
void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
  assert(op->getAbstractOperation() && "unregistered operation");
  *p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
  p->printOptionalAttrDict(op->getAttrs());
  *p << " : " << op->getOperand(0)->getType();
}
}
// Get MLIR types for injecting element pointer.
- auto allocOp = op->cast<BufferAllocOp>();
+ auto allocOp = cast<BufferAllocOp>(op);
auto elementType = allocOp.getElementType();
uint64_t elementSize = 0;
if (auto vectorType = elementType.dyn_cast<VectorType>())
}
// Get MLIR types for extracting element pointer.
- auto deallocOp = op->cast<BufferDeallocOp>();
+ auto deallocOp = cast<BufferDeallocOp>(op);
auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
// a getelementptr. This must be called under an edsc::ScopedContext.
Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
- auto loadOp = op->cast<Op>();
+ auto loadOp = cast<Op>(op);
auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(loadOp.getViewType(), lowering));
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto rangeOp = op->cast<RangeOp>();
+ auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
convertLinalgType(rangeOp.getResult()->getType(), lowering);
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto sliceOp = op->cast<SliceOp>();
+ auto sliceOp = cast<SliceOp>(op);
auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
FuncBuilder &rewriter) const override {
- auto viewOp = op->cast<ViewOp>();
+ auto viewOp = cast<ViewOp>(op);
auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
auto elementTy = rewriter.getType<LLVMType>(
getPtrToElementType(viewOp.getViewType(), lowering));
}
/// Returns true iff `v` is produced by a constant index op with value 0.
/// Values with no defining op (e.g. block arguments) are never zero.
static bool isZero(Value *v) {
  // Hoist the defining op: the original queried getDefiningOp() per check.
  auto *defOp = v->getDefiningOp();
  return defOp && isa<ConstantIndexOp>(defOp) &&
         cast<ConstantIndexOp>(defOp).getValue() == 0;
}
/// Returns a map that can be used to filter the zero values out of tileSizes.
// Steps must be constant for now to abide by affine.for semantics.
auto *newStep =
state.getOrCreate(
- step->getDefiningOp()->cast<ConstantIndexOp>().getValue() *
- tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue());
+ cast<ConstantIndexOp>(step->getDefiningOp()).getValue() *
+ cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue());
res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], newStep));
// clang-format on
}
assert(ranges[i].getType() && "expected !linalg.range type");
assert(ranges[i].getValue()->getDefiningOp() &&
"need operations to extract range parts");
- auto rangeOp = ranges[i].getValue()->getDefiningOp()->cast<RangeOp>();
+ auto rangeOp = cast<RangeOp>(ranges[i].getValue()->getDefiningOp());
auto lb = rangeOp.min();
auto ub = rangeOp.max();
// This must be a constexpr index until we relax the affine.for constraint
auto step =
- rangeOp.step()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+ cast<ConstantIndexOp>(rangeOp.step()->getDefiningOp()).getValue();
loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
}
assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
return view.getResult();
return b->create<SliceOp>(loc, view.getResult(), ranges);
}
- auto slice = viewDefiningOp->cast<SliceOp>();
+ auto slice = cast<SliceOp>(viewDefiningOp);
unsigned idxRange = 0;
SmallVector<Value *, 4> newIndexings;
bool elide = true;
: RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
- auto scastOp = op->cast<StorageCastOp>();
+ auto scastOp = cast<StorageCastOp>(op);
if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
- auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+ auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
return matchSuccess();
}
}
/// Folds a storage cast of a storage cast: replaces the outer cast with the
/// inner cast's argument. match() has already established both casts exist.
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
  auto scastOp = cast<StorageCastOp>(op);
  auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
  rewriter.replaceOp(op, srcScastOp.arg());
}
};
State state;
// Is the operand a constant?
- auto qbarrier = op->cast<QuantizeCastOp>();
+ auto qbarrier = cast<QuantizeCastOp>(op);
if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
return matchFailure();
}
}
bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
- auto fqOp = op->cast<ConstFakeQuant>();
+ auto fqOp = cast<ConstFakeQuant>(op);
auto converter =
ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
: RewritePattern(AllocOp::getOperationName(), 1, context) {}
PatternMatchResult match(Operation *op) const override {
- auto alloc = op->cast<AllocOp>();
+ auto alloc = cast<AllocOp>(op);
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto allocOp = op->cast<AllocOp>();
+ auto allocOp = cast<AllocOp>(op);
auto memrefType = allocOp.getType();
// Ok, we have one or more constant operands. Collect the non-constant ones
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Check if the alloc'ed value has any uses.
- auto alloc = op->cast<AllocOp>();
+ auto alloc = cast<AllocOp>(op);
if (!alloc.use_empty())
return matchFailure();
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto indirectCall = op->cast<CallIndirectOp>();
+ auto indirectCall = cast<CallIndirectOp>(op);
// Check that the callee is a constant operation.
Attribute callee;
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto condbr = op->cast<CondBranchOp>();
+ auto condbr = cast<CondBranchOp>(op);
// Check that the condition is a constant.
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto dealloc = op->cast<DeallocOp>();
+ auto dealloc = cast<DeallocOp>(op);
// Check that the memref operand's defining operation is an AllocOp.
Value *memref = dealloc.memref();
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto subi = op->cast<SubIOp>();
+ auto subi = cast<SubIOp>(op);
if (subi.getOperand(0) != subi.getOperand(1))
return matchFailure();
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- auto xorOp = op->cast<XOrOp>();
+ auto xorOp = cast<XOrOp>(op);
if (xorOp.lhs() != xorOp.rhs())
return matchFailure();
void collect(Operation *opToWalk) {
opToWalk->walk([&](Operation *op) {
if (op->isa<AffineForOp>())
- forOps.push_back(op->cast<AffineForOp>());
+ forOps.push_back(cast<AffineForOp>(op));
else if (op->getNumRegions() != 0)
hasNonForRegion = true;
else if (op->isa<LoadOp>())
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
- if (memref == loadOpInst->cast<LoadOp>().getMemRef())
+ if (memref == cast<LoadOp>(loadOpInst).getMemRef())
++loadOpCount;
}
return loadOpCount;
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
- if (memref == storeOpInst->cast<StoreOp>().getMemRef())
+ if (memref == cast<StoreOp>(storeOpInst).getMemRef())
++storeOpCount;
}
return storeOpCount;
// Appends to 'storeOps' all store ops in this node that write to 'memref'.
void getStoreOpsForMemref(Value *memref,
                          SmallVectorImpl<Operation *> *storeOps) {
  for (auto *storeOpInst : stores) {
    if (memref == cast<StoreOp>(storeOpInst).getMemRef())
      storeOps->push_back(storeOpInst);
  }
}
// Appends to 'loadOps' all load ops in this node that read from 'memref'.
void getLoadOpsForMemref(Value *memref,
                         SmallVectorImpl<Operation *> *loadOps) {
  for (auto *loadOpInst : loads) {
    if (memref == cast<LoadOp>(loadOpInst).getMemRef())
      loadOps->push_back(loadOpInst);
  }
}
void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
llvm::SmallDenseSet<Value *, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
- loadMemrefs.insert(loadOpInst->cast<LoadOp>().getMemRef());
+ loadMemrefs.insert(cast<LoadOp>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+ auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+ auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
auto *op = memref->getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
// Return false if there exist out edges from 'id' on 'memref'.
- if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>().getMemRef()) > 0)
+ if (getOutEdgeCount(id, cast<StoreOp>(storeOpInst).getMemRef()) > 0)
return false;
}
return true;
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto *memref = opInst->cast<LoadOp>().getMemRef();
+ auto *memref = cast<LoadOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto *memref = opInst->cast<StoreOp>().getMemRef();
+ auto *memref = cast<StoreOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto *memref = op.cast<LoadOp>().getMemRef();
+ auto *memref = cast<LoadOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (auto storeOp = dyn_cast<StoreOp>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto *memref = op.cast<StoreOp>().getMemRef();
+ auto *memref = cast<StoreOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
dstLoads->clear();
SmallVector<Operation *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
- if (load->cast<LoadOp>().getMemRef() == memref)
+ if (cast<LoadOp>(load).getMemRef() == memref)
dstLoads->push_back(load);
else
srcLoadsToKeep.push_back(load);
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
assert(node->op->isa<AffineForOp>());
SmallVector<AffineForOp, 4> loops;
- AffineForOp curr = node->op->cast<AffineForOp>();
+ AffineForOp curr = cast<AffineForOp>(node->op);
getPerfectlyNestedLoops(loops, curr);
if (loops.size() < 2)
return;
// Builder to create constants at the top level.
FuncBuilder top(forInst->getFunction());
// Create new memref type based on slice bounds.
- auto *oldMemRef = srcStoreOpInst->cast<StoreOp>().getMemRef();
+ auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
// Gather all memrefs from 'srcNode' store ops.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : srcNode->stores) {
- storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+ storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
}
// Return false if any of the following are true:
// *) 'srcNode' writes to a live in/out memref other than 'memref'.
DenseSet<Value *> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto *memref = loads.back()->cast<LoadOp>().getMemRef();
+ auto *memref = cast<LoadOp>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
- if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+ if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
unsigned bestDstLoopDepth;
LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
<< *sliceLoopNest.getOperation() << "\n");
// Move 'dstAffineForOp' before 'insertPointInst' if needed.
- auto dstAffineForOp = dstNode->op->cast<AffineForOp>();
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
if (insertPointInst != dstAffineForOp.getOperation()) {
dstAffineForOp.getOperation()->moveBefore(insertPointInst);
}
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+ if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
storesForMemref.push_back(storeOpInst);
}
assert(storesForMemref.size() == 1);
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto *loadMemRef = loadOpInst->cast<LoadOp>().getMemRef();
+ auto *loadMemRef = cast<LoadOp>(loadOpInst).getMemRef();
if (visitedMemrefs.count(loadMemRef) == 0)
loads.push_back(loadOpInst);
}
auto sliceLoopNest = mlir::insertBackwardComputationSlice(
sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
if (sliceLoopNest != nullptr) {
- auto dstForInst = dstNode->op->cast<AffineForOp>();
+ auto dstForInst = cast<AffineForOp>(dstNode->op);
// Update operation position of fused loop nest (if needed).
if (insertPointInst != dstForInst.getOperation()) {
dstForInst.getOperation()->moveBefore(insertPointInst);
// Check that all stores are to the same memref.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
- storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+ storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
}
if (storeMemrefs.size() != 1)
return false;
}
// Collect dst loop stats after memref privatizaton transformation.
- auto dstForInst = dstNode->op->cast<AffineForOp>();
+ auto dstForInst = cast<AffineForOp>(dstNode->op);
LoopNestStateCollector dstLoopCollector;
dstLoopCollector.collect(dstForInst.getOperation());
// Clear and add back loads and stores
// function.
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
mdg->removeNode(sibNode->id);
- sibNode->op->cast<AffineForOp>().erase();
+ sibNode->op->erase();
}
}
hasInnerLoops |= walkPostOrder(block.begin(), block.end());
if (opInst->isa<AffineForOp>()) {
if (!hasInnerLoops)
- loops.push_back(opInst->cast<AffineForOp>());
+ loops.push_back(cast<AffineForOp>(opInst));
return true;
}
return hasInnerLoops;
// Insert the cleanup loop right after 'forOp'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
- auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
+ auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
// Adjust the lower bound of the cleanup loop; its upper bound is the same
// as the original loop's upper bound.
AffineMap cleanupMap;
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
if (lowerAffineFor(forOp))
return signalPassFailure();
- } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+ } else if (lowerAffineApply(cast<AffineApplyOp>(op))) {
return signalPassFailure();
}
}
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
- VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
+ VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
// 1. Setup all the captures.
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
using namespace mlir::edsc::op;
using namespace mlir::edsc::intrinsics;
- VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
+ VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
// 1. Setup all the captures.
ScopedContext scope(FuncBuilder(op), transfer.getLoc());
continue;
}
- auto terminator = term->cast<VectorTransferWriteOp>();
+ auto terminator = cast<VectorTransferWriteOp>(term);
LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
// Get the transitive use-defs starting from terminator, limited to the
return;
// Perform the actual store to load forwarding.
- Value *storeVal = lastWriteStoreOp->cast<StoreOp>().getValueToStore();
+ Value *storeVal = cast<StoreOp>(lastWriteStoreOp).getValueToStore();
loadOp.getResult()->replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp.getMemRef());
// For each start operation, we look for a matching finish operation.
for (auto *dmaStartInst : dmaStartInsts) {
for (auto *dmaFinishInst : dmaFinishInsts) {
- if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
- dmaFinishInst->cast<DmaWaitOp>())) {
+ if (checkTagMatch(cast<DmaStartOp>(dmaStartInst),
+ cast<DmaWaitOp>(dmaFinishInst))) {
startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
break;
}
for (auto &pair : startWaitPairs) {
auto *dmaStartInst = pair.first;
Value *oldMemRef = dmaStartInst->getOperand(
- dmaStartInst->cast<DmaStartOp>().getFasterMemPos());
+ cast<DmaStartOp>(dmaStartInst).getFasterMemPos());
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
Operation *op = forOp.getOperation();
if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
FuncBuilder builder(op->getBlock(), ++Block::iterator(op));
- auto cleanupForInst = builder.clone(*op)->cast<AffineForOp>();
+ auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
AffineMap cleanupMap;
SmallVector<Value *, 4> cleanupOperands;
getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
for (unsigned i = 0; i < loopDepth; ++i) {
assert(forOp.getBody()->front().isa<AffineForOp>());
- AffineForOp nextForOp = forOp.getBody()->front().cast<AffineForOp>();
+ AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
interchangeLoops(forOp, nextForOp);
}
}
SmallVector<NestedMatch, 8> matches;
pattern.match(f, &matches);
for (auto m : matches) {
- auto app = m.getMatchedOperation()->cast<AffineApplyOp>();
+ auto app = cast<AffineApplyOp>(m.getMatchedOperation());
FuncBuilder b(m.getMatchedOperation());
SmallVector<Value *, 8> operands(app.getOperands());
makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands);
isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> ¶llelLoops,
int fastestVaryingMemRefDimension) {
return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
- auto loop = forOp.cast<AffineForOp>();
+ auto loop = cast<AffineForOp>(forOp);
auto parallelIt = parallelLoops.find(loop);
if (parallelIt == parallelLoops.end())
return false;
vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
VectorizationState *state) {
auto *loopInst = oneMatch.getMatchedOperation();
- auto loop = loopInst->cast<AffineForOp>();
+ auto loop = cast<AffineForOp>(loopInst);
auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
/// anything below it fails.
static LogicalResult vectorizeRootMatch(NestedMatch m,
VectorizationStrategy *strategy) {
- auto loop = m.getMatchedOperation()->cast<AffineForOp>();
+ auto loop = cast<AffineForOp>(m.getMatchedOperation());
VectorizationState state;
state.strategy = strategy;
/// RAII.
auto *loopInst = loop.getOperation();
FuncBuilder builder(loopInst);
- auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
+ auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
struct Guard {
LogicalResult failure() {
loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar());