Replace Operation::cast with llvm::cast.
authorRiver Riddle <riverriddle@google.com>
Sun, 12 May 2019 00:57:32 +0000 (17:57 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:37:42 +0000 (13:37 -0700)
--

PiperOrigin-RevId: 247785983

41 files changed:
mlir/examples/Linalg/Linalg1/lib/Analysis.cpp
mlir/examples/Linalg/Linalg1/lib/Common.cpp
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg1/lib/Utils.cpp
mlir/examples/Linalg/Linalg2/lib/Transforms.cpp
mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Analysis/Utils.cpp
mlir/lib/Analysis/VectorAnalysis.cpp
mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp
mlir/lib/Linalg/Utils/Utils.cpp
mlir/lib/Quantization/IR/QuantOps.cpp
mlir/lib/Quantization/Transforms/ConvertConst.cpp
mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/LoopUnroll.cpp
mlir/lib/Transforms/LoopUnrollAndJam.cpp
mlir/lib/Transforms/LowerAffine.cpp
mlir/lib/Transforms/LowerVectorTransfers.cpp
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/MemRefDataFlowOpt.cpp
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
mlir/lib/Transforms/Vectorize.cpp

index a7fba17..092b83a 100644 (file)
@@ -35,7 +35,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) {
     view = slice.getParentView();
     assert(viewType.isa<ViewType>() && "expected a ViewType");
   }
-  return view->getDefiningOp()->cast<ViewOp>();
+  return cast<ViewOp>(view->getDefiningOp());
 }
 
 Value *linalg::getViewSupportingMemRef(Value *view) {
@@ -51,12 +51,12 @@ std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view,
   if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
     return std::make_pair(viewOp.getIndexing(dim), dim);
 
-  auto sliceOp = view->getDefiningOp()->cast<SliceOp>();
+  auto sliceOp = cast<SliceOp>(view->getDefiningOp());
   auto *parentView = sliceOp.getParentView();
   unsigned sliceDim = sliceOp.getSlicingDim();
   auto *indexing = sliceOp.getIndexing();
   if (indexing->getDefiningOp()) {
-    if (auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>()) {
+    if (auto rangeOp = cast<RangeOp>(indexing->getDefiningOp())) {
       // If I sliced with a range and I sliced at this dim, then I'm it.
       if (dim == sliceDim) {
         return std::make_pair(rangeOp.getResult(), dim);
index 278f9c5..1e211bf 100644 (file)
@@ -47,8 +47,8 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
     auto lb = rangeOp.getMin();
     auto ub = rangeOp.getMax();
     // This must be a constexpr index until we relax the affine.for constraint
-    auto step =
-        rangeOp.getStep()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    auto step = llvm::cast<ConstantIndexOp>(rangeOp.getStep()->getDefiningOp())
+                    .getValue();
     loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
   }
 }
index 6097240..48884b1 100644 (file)
@@ -155,7 +155,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto rangeOp = op->cast<linalg::RangeOp>();
+    auto rangeOp = cast<linalg::RangeOp>(op);
     auto rangeDescriptorType =
         linalg::convertLinalgType(rangeOp.getResult()->getType());
 
@@ -187,7 +187,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto viewOp = op->cast<linalg::ViewOp>();
+    auto viewOp = cast<linalg::ViewOp>(op);
     auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
     auto memrefType =
         viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
@@ -319,7 +319,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto sliceOp = op->cast<linalg::SliceOp>();
+    auto sliceOp = cast<linalg::SliceOp>(op);
     auto newViewDescriptorType =
         linalg::convertLinalgType(sliceOp.getViewType());
     auto elementType = rewriter.getType<LLVM::LLVMType>(
index 5bcebc7..05070a9 100644 (file)
@@ -35,7 +35,7 @@ unsigned linalg::getViewRank(Value *view) {
   assert(view->getType().isa<ViewType>() && "expected a ViewType");
   if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
     return viewOp.getRank();
-  return view->getDefiningOp()->cast<SliceOp>().getRank();
+  return cast<SliceOp>(view->getDefiningOp()).getRank();
 }
 
 ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) {
index 83fd9ad..9df0af8 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/StandardTypes.h"
 
 using llvm::ArrayRef;
+using llvm::cast;
 using llvm::SmallVector;
 using mlir::FuncBuilder;
 using mlir::MemRefType;
@@ -49,7 +50,7 @@ static SmallVector<Value *, 8> getViewChain(mlir::Value *v) {
 
   SmallVector<mlir::Value *, 8> tmp;
   do {
-    auto sliceOp = v->getDefiningOp()->cast<SliceOp>(); // must be a slice op
+    auto sliceOp = cast<SliceOp>(v->getDefiningOp()); // must be a slice op
     tmp.push_back(v);
     v = sliceOp.getParentView();
   } while (!v->getType().isa<ViewType>());
@@ -62,15 +63,15 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim,
                                                 ArrayRef<Value *> chain) {
   using namespace mlir::edsc::op;
   assert(chain.front()->getType().isa<ViewType>() && "must be a ViewType");
-  auto viewOp = chain.front()->getDefiningOp()->cast<ViewOp>();
+  auto viewOp = cast<ViewOp>(chain.front()->getDefiningOp());
   auto *indexing = viewOp.getIndexing(dim);
   if (!indexing->getType().isa<RangeType>())
     return indexing;
-  auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>();
+  auto rangeOp = cast<RangeOp>(indexing->getDefiningOp());
   Value *min = rangeOp.getMin(), *max = rangeOp.getMax(),
         *step = rangeOp.getStep();
   for (auto *v : chain.drop_front(1)) {
-    auto slice = v->getDefiningOp()->cast<SliceOp>();
+    auto slice = cast<SliceOp>(v->getDefiningOp());
     if (slice.getRank() != slice.getParentRank()) {
       // Rank-reducing slice.
       if (slice.getSlicingDim() == dim) {
@@ -82,7 +83,7 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim,
       dim = (slice.getSlicingDim() < dim) ? dim - 1 : dim;
     } else { // not a rank-reducing slice.
       if (slice.getSlicingDim() == dim) {
-        auto range = slice.getIndexing()->getDefiningOp()->cast<RangeOp>();
+        auto range = cast<RangeOp>(slice.getIndexing()->getDefiningOp());
         auto oldMin = min;
         min = ValueHandle(min) + ValueHandle(range.getMin());
         // ideally: max = min(oldMin + ValueHandle(range.getMax()), oldMax);
@@ -110,5 +111,5 @@ ViewOp linalg::emitAndReturnFullyComposedView(Value *v) {
   for (unsigned idx = 0; idx < rank; ++idx) {
     ranges.push_back(createFullyComposedIndexing(idx, chain));
   }
-  return view(memRef, ranges).getOperation()->cast<ViewOp>();
+  return cast<ViewOp>(view(memRef, ranges).getOperation());
 }
index 3090f29..2c47541 100644 (file)
@@ -94,7 +94,7 @@ extractRangesFromViewOrSliceOp(mlir::Value *view) {
   if (auto viewOp = llvm::dyn_cast<linalg::ViewOp>(view->getDefiningOp()))
     return viewOp.getRanges();
 
-  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  auto sliceOp = llvm::cast<linalg::SliceOp>(view->getDefiningOp());
   unsigned slicingDim = sliceOp.getSlicingDim();
   auto *indexing = *(sliceOp.getIndexings().begin());
   bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
index f1bb90d..22feb66 100644 (file)
@@ -71,7 +71,7 @@ public:
   // a getelementptr.
   Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
-    auto loadOp = op->cast<Op>();
+    auto loadOp = cast<Op>(op);
     auto elementType =
         loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
     auto *llvmPtrType = linalg::convertLinalgType(elementType)
index bce7f58..6309300 100644 (file)
@@ -95,7 +95,7 @@ extractFromRanges(ArrayRef<Value *> ranges,
   SmallVector<Value *, 4> res;
   res.reserve(ranges.size());
   for (auto *v : ranges) {
-    auto r = v->getDefiningOp()->cast<RangeOp>();
+    auto r = cast<RangeOp>(v->getDefiningOp());
     res.push_back(extract(r));
   }
   return res;
@@ -149,9 +149,9 @@ linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
   for (auto z : llvm::zip(res.steps, tileSizes)) {
     auto *step = std::get<0>(z);
     auto tileSize = std::get<1>(z);
-    auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
     auto tileSizeValue =
-        tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
     assert(stepValue > 0);
     tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
   }
@@ -236,7 +236,7 @@ emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
       operands.push_back(indexing);
       continue;
     }
-    RangeOp range = indexing->getDefiningOp()->cast<RangeOp>();
+    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
     ValueHandle min(range.getMin());
     Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
     using edsc::op::operator+;
@@ -275,10 +275,10 @@ template <>
 PatternMatchResult
 Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
                                           PatternRewriter &rewriter) const {
-  auto load = op->cast<linalg::LoadOp>();
+  auto load = cast<linalg::LoadOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
-                      : load.getView()->getDefiningOp()->cast<ViewOp>();
+                      : cast<ViewOp>(load.getView()->getDefiningOp());
   ScopedContext scope(FuncBuilder(load), load.getLoc());
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(load, view);
@@ -290,10 +290,10 @@ template <>
 PatternMatchResult
 Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
                                            PatternRewriter &rewriter) const {
-  auto store = op->cast<linalg::StoreOp>();
+  auto store = cast<linalg::StoreOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
-                      : store.getView()->getDefiningOp()->cast<ViewOp>();
+                      : cast<ViewOp>(store.getView()->getDefiningOp());
   ScopedContext scope(FuncBuilder(store), store.getLoc());
   auto *valueToStore = store.getValueToStore();
   auto *memRef = view.getSupportingMemRef();
index c9f98e7..5f024ea 100644 (file)
@@ -350,7 +350,7 @@ public:
     // Finally, update the return type of the function based on the argument to
     // the return operation.
     for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->cast<ReturnOp>();
+      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
       if (!ret)
         continue;
       if (ret.getNumOperands() &&
index 942ce86..4175fc2 100644 (file)
@@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
                   mlir::PatternRewriter &rewriter) const override {
     // We can directly cast the current operation as this will only get invoked
     // on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
+    TransposeOp transpose = llvm::cast<TransposeOp>(op);
     // Look through the input of the current transpose.
     mlir::Value *transposeInput = transpose.getOperand();
     TransposeOp transposeInputOp =
@@ -73,7 +73,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
     ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
         reshape.getOperand()->getDefiningOp());
@@ -120,7 +120,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
     mlir::Value *reshapeInput = reshape.getOperand();
     // If the input is defined by another reshape, bingo!
@@ -142,7 +142,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     if (reshape.getOperand()->getType() != reshape.getResult()->getType())
       return matchFailure();
     rewriter.replaceOp(reshape, {reshape.getOperand()});
index db6ba73..3e640de 100644 (file)
@@ -92,7 +92,7 @@ public:
     using intrinsics::constant_index;
     using linalg::intrinsics::range;
     using linalg::intrinsics::view;
-    toy::MulOp mul = op->cast<toy::MulOp>();
+    toy::MulOp mul = cast<toy::MulOp>(op);
     auto loc = mul.getLoc();
     Value *result = memRefTypeCast(
         rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
index 534b5cb..0a2ff1d 100644 (file)
@@ -93,7 +93,7 @@ public:
   /// number must match the number of result of `op`.
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto add = op->cast<toy::AddOp>();
+    auto add = cast<toy::AddOp>(op);
     auto loc = add.getLoc();
     // Create a `toy.alloc` operation to allocate the output buffer for this op.
     Value *result = memRefTypeCast(
@@ -135,7 +135,7 @@ public:
     // Get or create the declaration of the printf function in the module.
     Function *printfFunc = getPrintf(*op->getFunction()->getModule());
 
-    auto print = op->cast<toy::PrintOp>();
+    auto print = cast<toy::PrintOp>(op);
     auto loc = print.getLoc();
     // We will operate on a MemRef abstraction, we use a type.cast to get one
     // if our operand is still a Toy array.
@@ -234,7 +234,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
+    toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
     auto loc = cstOp.getLoc();
     auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
     auto shape = retTy.getShape();
@@ -277,7 +277,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto transpose = op->cast<toy::TransposeOp>();
+    auto transpose = cast<toy::TransposeOp>(op);
     auto loc = transpose.getLoc();
     Value *result = memRefTypeCast(
         rewriter,
@@ -309,7 +309,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto retOp = op->cast<toy::ReturnOp>();
+    auto retOp = cast<toy::ReturnOp>(op);
     using namespace edsc;
     auto loc = retOp.getLoc();
     // Argument is optional, handle both cases.
index 4e17b23..ab99019 100644 (file)
@@ -357,7 +357,7 @@ public:
     // Finally, update the return type of the function based on the argument to
     // the return operation.
     for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->cast<ReturnOp>();
+      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
       if (!ret)
         continue;
       if (ret.getNumOperands() &&
index 8d6aed6..260f6a6 100644 (file)
@@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
                   mlir::PatternRewriter &rewriter) const override {
     // We can directly cast the current operation as this will only get invoked
     // on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
+    TransposeOp transpose = llvm::cast<TransposeOp>(op);
     // look through the input to the current transpose
     mlir::Value *transposeInput = transpose.getOperand();
     mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
@@ -74,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // look through the input to the current reshape
     mlir::Value *reshapeInput = reshape.getOperand();
     mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
@@ -125,7 +125,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // look through the input to the current reshape
     mlir::Value *reshapeInput = reshape.getOperand();
     mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
@@ -150,7 +150,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     if (reshape.getOperand()->getType() != reshape.getResult()->getType())
       return matchFailure();
     rewriter.replaceOp(reshape, {reshape.getOperand()});
@@ -185,7 +185,7 @@ struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    TypeCastOp typeCast = op->cast<TypeCastOp>();
+    TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
     auto resTy = typeCast.getResult()->getType();
     auto *candidateOp = op;
     while (candidateOp && candidateOp->isa<TypeCastOp>()) {
index 2eff412..250fb94 100644 (file)
@@ -183,8 +183,8 @@ public:
   static LogicalResult constantFoldHook(Operation *op,
                                         ArrayRef<Attribute> operands,
                                         SmallVectorImpl<Attribute> &results) {
-    return op->cast<ConcreteType>().constantFold(operands, results,
-                                                 op->getContext());
+    return cast<ConcreteType>(op).constantFold(operands, results,
+                                               op->getContext());
   }
 
   /// Op implementations can implement this hook.  It should attempt to constant
@@ -205,7 +205,7 @@ public:
   /// This is an implementation detail of the folder hook for AbstractOperation.
   static LogicalResult foldHook(Operation *op,
                                 SmallVectorImpl<Value *> &results) {
-    return op->cast<ConcreteType>().fold(results);
+    return cast<ConcreteType>(op).fold(results);
   }
 
   /// This hook implements a generalized folder for this operation.  Operations
@@ -253,7 +253,7 @@ public:
                                         ArrayRef<Attribute> operands,
                                         SmallVectorImpl<Attribute> &results) {
     auto result =
-        op->cast<ConcreteType>().constantFold(operands, op->getContext());
+        cast<ConcreteType>(op).constantFold(operands, op->getContext());
     if (!result)
       return failure();
 
@@ -277,7 +277,7 @@ public:
   /// This is an implementation detail of the folder hook for AbstractOperation.
   static LogicalResult foldHook(Operation *op,
                                 SmallVectorImpl<Value *> &results) {
-    auto *result = op->cast<ConcreteType>().fold();
+    auto *result = cast<ConcreteType>(op).fold();
     if (!result)
       return failure();
     if (result != op->getResult(0))
@@ -808,7 +808,7 @@ public:
   static LogicalResult verifyInvariants(Operation *op) {
     return failure(
         failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
-        failed(op->cast<ConcreteType>().verify()));
+        failed(cast<ConcreteType>(op).verify()));
   }
 
   // Returns the properties of an operation by combining the properties of the
index 31ec8ea..088a4e4 100644 (file)
@@ -389,14 +389,6 @@ public:
   // Conversions to declared operations like DimOp
   //===--------------------------------------------------------------------===//
 
-  /// The cast methods perform a cast from an Operation to a typed Op like
-  /// DimOp.  This aborts if the parameter to the template isn't an instance of
-  /// the template type argument.
-  template <typename OpClass> OpClass cast() {
-    assert(isa<OpClass>() && "cast<Ty>() argument of incompatible type!");
-    return OpClass(this);
-  }
-
   /// The is methods return true if the operation is a typed op (like DimOp) of
   /// of the given class.
   template <typename OpClass> bool isa() { return OpClass::classof(this); }
index 2dfed93..f551afb 100644 (file)
@@ -661,7 +661,7 @@ struct SimplifyAffineApply : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto apply = op->cast<AffineApplyOp>();
+    auto apply = cast<AffineApplyOp>(op);
     auto map = apply.getAffineMap();
 
     AffineMap oldMap = map;
@@ -1010,7 +1010,7 @@ struct AffineForLoopBoundFolder : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto forOp = op->cast<AffineForOp>();
+    auto forOp = cast<AffineForOp>(op);
     auto foldLowerOrUpperBound = [&forOp](bool lower) {
       // Check to see if each of the operands is the result of a constant.  If
       // so, get the value.  If not, ignore it.
index 60f2b14..3d984c5 100644 (file)
@@ -192,7 +192,7 @@ bool mlir::isAccessInvariant(Value *iv, Value *index) {
     return false;
   }
 
-  auto composeOp = affineApplyOps[0]->cast<AffineApplyOp>();
+  auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
   // We need yet another level of indirection because the `dim` index of the
   // access may not correspond to the `dim` index of composeOp.
   return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
index 8d963e4..cc46d65 100644 (file)
@@ -603,7 +603,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
   auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
   FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
   auto sliceLoopNest =
-      b.clone(*srcLoopIVs[0].getOperation())->cast<AffineForOp>();
+      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
 
   Operation *sliceInst =
       getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
index 8fecf05..627ca7a 100644 (file)
@@ -123,7 +123,7 @@ static AffineMap makePermutationMap(
   for (auto kvp : enclosingLoopToVectorDim) {
     assert(kvp.second < perm.size());
     auto invariants = getInvariantAccesses(
-        kvp.first->cast<AffineForOp>().getInductionVar(), indices);
+        cast<AffineForOp>(kvp.first).getInductionVar(), indices);
     unsigned numIndices = indices.size();
     unsigned countInvariantIndices = 0;
     for (unsigned dim = 0; dim < numIndices; ++dim) {
@@ -181,7 +181,7 @@ AffineMap mlir::makePermutationMap(
     return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim);
   }
 
-  auto store = op->cast<StoreOp>();
+  auto store = cast<StoreOp>(op);
   return ::makePermutationMap(store.getIndices(), enclosingLoopToVectorDim);
 }
 
index afd8152..8bfcfa5 100644 (file)
@@ -124,7 +124,7 @@ struct UniformDequantizePattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto dcastOp = op->cast<DequantizeCastOp>();
+    auto dcastOp = cast<DequantizeCastOp>(op);
     Type inputType = dcastOp.arg()->getType();
     Type outputType = dcastOp.getResult()->getType();
 
@@ -328,7 +328,7 @@ struct UniformRealAddEwPattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto addOp = op->cast<RealAddEwOp>();
+    auto addOp = cast<RealAddEwOp>(op);
     const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
                                    addOp.clamp_min(), addOp.clamp_max());
     if (!info.isValid()) {
@@ -350,7 +350,7 @@ struct UniformRealMulEwPattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto mulOp = op->cast<RealMulEwOp>();
+    auto mulOp = cast<RealMulEwOp>(op);
     const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
                                    mulOp.clamp_min(), mulOp.clamp_max());
     if (!info.isValid()) {
index 0d9025b..e9aee95 100644 (file)
@@ -414,14 +414,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<AllocOp>::match(op))
       return matchFailure();
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     MemRefType type = allocOp.getType();
     return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   }
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     MemRefType type = allocOp.getType();
 
     // Get actual sizes of the memref as values: static sizes are constant
@@ -557,7 +557,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
       return matchFailure();
-    auto memRefCastOp = op->cast<MemRefCastOp>();
+    auto memRefCastOp = cast<MemRefCastOp>(op);
     MemRefType sourceType =
         memRefCastOp.getOperand()->getType().cast<MemRefType>();
     MemRefType targetType = memRefCastOp.getType();
@@ -569,7 +569,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto memRefCastOp = op->cast<MemRefCastOp>();
+    auto memRefCastOp = cast<MemRefCastOp>(op);
     auto targetType = memRefCastOp.getType();
     auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
 
@@ -636,7 +636,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<DimOp>::match(op))
       return this->matchFailure();
-    auto dimOp = op->cast<DimOp>();
+    auto dimOp = cast<DimOp>(op);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
     return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   }
@@ -644,7 +644,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     assert(operands.size() == 1 && "expected exactly one operand");
-    auto dimOp = op->cast<DimOp>();
+    auto dimOp = cast<DimOp>(op);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
 
     SmallVector<Value *, 4> results;
@@ -683,7 +683,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<Derived>::match(op))
       return this->matchFailure();
-    auto loadOp = op->cast<Derived>();
+    auto loadOp = cast<Derived>(op);
     MemRefType type = loadOp.getMemRefType();
     return isSupportedMemRefType(type) ? this->matchSuccess()
                                        : this->matchFailure();
@@ -794,7 +794,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto loadOp = op->cast<LoadOp>();
+    auto loadOp = cast<LoadOp>(op);
     auto type = loadOp.getMemRefType();
 
     Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
@@ -815,7 +815,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto storeOp = op->cast<StoreOp>();
+    auto storeOp = cast<StoreOp>(op);
     auto type = storeOp.getMemRefType();
 
     Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
index 6998da5..8ea45df 100644 (file)
@@ -320,7 +320,7 @@ void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
 }
 
 ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
-  return getOperand(0)->getDefiningOp()->cast<ViewOp>();
+  return cast<ViewOp>(getOperand(0)->getDefiningOp());
 }
 
 ViewType mlir::linalg::SliceOp::getBaseViewType() {
@@ -505,8 +505,7 @@ ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
 /// ```
 void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
-  *p << op->cast<BufferSizeOp>().getOperationName() << " "
-     << *op->getOperand(0);
+  *p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
   p->printOptionalAttrDict(op->getAttrs());
   *p << " : " << op->getOperand(0)->getType();
 }
index 90111a8..2d1f5f2 100644 (file)
@@ -181,7 +181,7 @@ public:
     }
 
     // Get MLIR types for injecting element pointer.
-    auto allocOp = op->cast<BufferAllocOp>();
+    auto allocOp = cast<BufferAllocOp>(op);
     auto elementType = allocOp.getElementType();
     uint64_t elementSize = 0;
     if (auto vectorType = elementType.dyn_cast<VectorType>())
@@ -239,7 +239,7 @@ public:
     }
 
     // Get MLIR types for extracting element pointer.
-    auto deallocOp = op->cast<BufferDeallocOp>();
+    auto deallocOp = cast<BufferDeallocOp>(op);
     auto elementPtrTy = rewriter.getType<LLVMType>(getPtrToElementType(
         deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
 
@@ -283,7 +283,7 @@ public:
   // a getelementptr. This must be called under an edsc::ScopedContext.
   Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
-    auto loadOp = op->cast<Op>();
+    auto loadOp = cast<Op>(op);
     auto elementTy = rewriter.getType<LLVMType>(
         getPtrToElementType(loadOp.getViewType(), lowering));
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -329,7 +329,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto rangeOp = op->cast<RangeOp>();
+    auto rangeOp = cast<RangeOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeOp.getResult()->getType(), lowering);
 
@@ -355,7 +355,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto sliceOp = op->cast<SliceOp>();
+    auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -453,7 +453,7 @@ public:
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto viewOp = op->cast<ViewOp>();
+    auto viewOp = cast<ViewOp>(op);
     auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
     auto elementTy = rewriter.getType<LLVMType>(
         getPtrToElementType(viewOp.getViewType(), lowering));
index 6e20542..e1fa74d 100644 (file)
@@ -115,8 +115,8 @@ static SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
 }
 
 static bool isZero(Value *v) {
-  return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() &&
-         v->getDefiningOp()->cast<ConstantIndexOp>().getValue() == 0;
+  return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
+         cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
 }
 
 /// Returns a map that can be used to filter the zero values out of tileSizes.
@@ -176,8 +176,8 @@ makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map,
     // Steps must be constant for now to abide by affine.for semantics.
     auto *newStep =
         state.getOrCreate(
-            step->getDefiningOp()->cast<ConstantIndexOp>().getValue() *
-            tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue());
+            cast<ConstantIndexOp>(step->getDefiningOp()).getValue() *
+            cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue());
     res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], newStep));
     // clang-format on
   }
index 98cf4b7..6732fa1 100644 (file)
@@ -42,12 +42,12 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
     assert(ranges[i].getType() && "expected !linalg.range type");
     assert(ranges[i].getValue()->getDefiningOp() &&
            "need operations to extract range parts");
-    auto rangeOp = ranges[i].getValue()->getDefiningOp()->cast<RangeOp>();
+    auto rangeOp = cast<RangeOp>(ranges[i].getValue()->getDefiningOp());
     auto lb = rangeOp.min();
     auto ub = rangeOp.max();
     // This must be a constexpr index until we relax the affine.for constraint
     auto step =
-        rangeOp.step()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+        cast<ConstantIndexOp>(rangeOp.step()->getDefiningOp()).getValue();
     loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
   }
   assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
@@ -106,7 +106,7 @@ Value *mlir::createOrReturnView(FuncBuilder *b, Location loc,
       return view.getResult();
     return b->create<SliceOp>(loc, view.getResult(), ranges);
   }
-  auto slice = viewDefiningOp->cast<SliceOp>();
+  auto slice = cast<SliceOp>(viewDefiningOp);
   unsigned idxRange = 0;
   SmallVector<Value *, 4> newIndexings;
   bool elide = true;
index 046ad85..ab6d97f 100644 (file)
@@ -43,9 +43,9 @@ public:
       : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
 
   PatternMatchResult match(Operation *op) const override {
-    auto scastOp = op->cast<StorageCastOp>();
+    auto scastOp = cast<StorageCastOp>(op);
     if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
-      auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+      auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
       if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
         return matchSuccess();
       }
@@ -54,8 +54,8 @@ public:
   }
 
   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto scastOp = op->cast<StorageCastOp>();
-    auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+    auto scastOp = cast<StorageCastOp>(op);
+    auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
     rewriter.replaceOp(op, srcScastOp.arg());
   }
 };
index 21a0de2..ad41f8f 100644 (file)
@@ -59,7 +59,7 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
   State state;
 
   // Is the operand a constant?
-  auto qbarrier = op->cast<QuantizeCastOp>();
+  auto qbarrier = cast<QuantizeCastOp>(op);
   if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
     return matchFailure();
   }
index 4df7b88..c62adc8 100644 (file)
@@ -59,7 +59,7 @@ public:
   }
 
   bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
-    auto fqOp = op->cast<ConstFakeQuant>();
+    auto fqOp = cast<ConstFakeQuant>(op);
 
     auto converter =
         ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
index bc68a78..59c1400 100644 (file)
@@ -283,7 +283,7 @@ struct SimplifyAllocConst : public RewritePattern {
       : RewritePattern(AllocOp::getOperationName(), 1, context) {}
 
   PatternMatchResult match(Operation *op) const override {
-    auto alloc = op->cast<AllocOp>();
+    auto alloc = cast<AllocOp>(op);
 
     // Check to see if any dimensions operands are constants.  If so, we can
     // substitute and drop them.
@@ -294,7 +294,7 @@ struct SimplifyAllocConst : public RewritePattern {
   }
 
   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     auto memrefType = allocOp.getType();
 
     // Ok, we have one or more constant operands.  Collect the non-constant ones
@@ -352,7 +352,7 @@ struct SimplifyDeadAlloc : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     // Check if the alloc'ed value has any uses.
-    auto alloc = op->cast<AllocOp>();
+    auto alloc = cast<AllocOp>(op);
     if (!alloc.use_empty())
       return matchFailure();
 
@@ -468,7 +468,7 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto indirectCall = op->cast<CallIndirectOp>();
+    auto indirectCall = cast<CallIndirectOp>(op);
 
     // Check that the callee is a constant operation.
     Attribute callee;
@@ -978,7 +978,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto condbr = op->cast<CondBranchOp>();
+    auto condbr = cast<CondBranchOp>(op);
 
     // Check that the condition is a constant.
     if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
@@ -1222,7 +1222,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto dealloc = op->cast<DeallocOp>();
+    auto dealloc = cast<DeallocOp>(op);
 
     // Check that the memref operand's defining operation is an AllocOp.
     Value *memref = dealloc.memref();
@@ -2107,7 +2107,7 @@ struct SimplifyXMinusX : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto subi = op->cast<SubIOp>();
+    auto subi = cast<SubIOp>(op);
     if (subi.getOperand(0) != subi.getOperand(1))
       return matchFailure();
 
@@ -2192,7 +2192,7 @@ struct SimplifyXXOrX : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto xorOp = op->cast<XOrOp>();
+    auto xorOp = cast<XOrOp>(op);
     if (xorOp.lhs() != xorOp.rhs())
       return matchFailure();
 
index 1c4a4d1..d430c5d 100644 (file)
@@ -128,7 +128,7 @@ struct LoopNestStateCollector {
   void collect(Operation *opToWalk) {
     opToWalk->walk([&](Operation *op) {
       if (op->isa<AffineForOp>())
-        forOps.push_back(op->cast<AffineForOp>());
+        forOps.push_back(cast<AffineForOp>(op));
       else if (op->getNumRegions() != 0)
         hasNonForRegion = true;
       else if (op->isa<LoadOp>())
@@ -172,7 +172,7 @@ public:
     unsigned getLoadOpCount(Value *memref) {
       unsigned loadOpCount = 0;
       for (auto *loadOpInst : loads) {
-        if (memref == loadOpInst->cast<LoadOp>().getMemRef())
+        if (memref == cast<LoadOp>(loadOpInst).getMemRef())
           ++loadOpCount;
       }
       return loadOpCount;
@@ -182,7 +182,7 @@ public:
     unsigned getStoreOpCount(Value *memref) {
       unsigned storeOpCount = 0;
       for (auto *storeOpInst : stores) {
-        if (memref == storeOpInst->cast<StoreOp>().getMemRef())
+        if (memref == cast<StoreOp>(storeOpInst).getMemRef())
           ++storeOpCount;
       }
       return storeOpCount;
@@ -192,7 +192,7 @@ public:
     void getStoreOpsForMemref(Value *memref,
                               SmallVectorImpl<Operation *> *storeOps) {
       for (auto *storeOpInst : stores) {
-        if (memref == storeOpInst->cast<StoreOp>().getMemRef())
+        if (memref == cast<StoreOp>(storeOpInst).getMemRef())
           storeOps->push_back(storeOpInst);
       }
     }
@@ -201,7 +201,7 @@ public:
     void getLoadOpsForMemref(Value *memref,
                              SmallVectorImpl<Operation *> *loadOps) {
       for (auto *loadOpInst : loads) {
-        if (memref == loadOpInst->cast<LoadOp>().getMemRef())
+        if (memref == cast<LoadOp>(loadOpInst).getMemRef())
           loadOps->push_back(loadOpInst);
       }
     }
@@ -211,10 +211,10 @@ public:
     void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
       llvm::SmallDenseSet<Value *, 2> loadMemrefs;
       for (auto *loadOpInst : loads) {
-        loadMemrefs.insert(loadOpInst->cast<LoadOp>().getMemRef());
+        loadMemrefs.insert(cast<LoadOp>(loadOpInst).getMemRef());
       }
       for (auto *storeOpInst : stores) {
-        auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+        auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
         if (loadMemrefs.count(memref) > 0)
           loadAndStoreMemrefSet->insert(memref);
       }
@@ -306,7 +306,7 @@ public:
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
-      auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+      auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
       auto *op = memref->getDefiningOp();
       // Return true if 'memref' is a block argument.
       if (!op)
@@ -331,7 +331,7 @@ public:
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
       // Return false if there exist out edges from 'id' on 'memref'.
-      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>().getMemRef()) > 0)
+      if (getOutEdgeCount(id, cast<StoreOp>(storeOpInst).getMemRef()) > 0)
         return false;
     }
     return true;
@@ -656,12 +656,12 @@ bool MemRefDependenceGraph::init(Function &f) {
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
-        auto *memref = opInst->cast<LoadOp>().getMemRef();
+        auto *memref = cast<LoadOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       for (auto *opInst : collector.storeOpInsts) {
         node.stores.push_back(opInst);
-        auto *memref = opInst->cast<StoreOp>().getMemRef();
+        auto *memref = cast<StoreOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       forToNodeMap[&op] = node.id;
@@ -670,14 +670,14 @@ bool MemRefDependenceGraph::init(Function &f) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
-      auto *memref = op.cast<LoadOp>().getMemRef();
+      auto *memref = cast<LoadOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (auto storeOp = dyn_cast<StoreOp>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
-      auto *memref = op.cast<StoreOp>().getMemRef();
+      auto *memref = cast<StoreOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (op.getNumRegions() != 0) {
@@ -887,7 +887,7 @@ static void moveLoadsAccessingMemrefTo(Value *memref,
   dstLoads->clear();
   SmallVector<Operation *, 4> srcLoadsToKeep;
   for (auto *load : *srcLoads) {
-    if (load->cast<LoadOp>().getMemRef() == memref)
+    if (cast<LoadOp>(load).getMemRef() == memref)
       dstLoads->push_back(load);
     else
       srcLoadsToKeep.push_back(load);
@@ -1051,7 +1051,7 @@ computeLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   assert(node->op->isa<AffineForOp>());
   SmallVector<AffineForOp, 4> loops;
-  AffineForOp curr = node->op->cast<AffineForOp>();
+  AffineForOp curr = cast<AffineForOp>(node->op);
   getPerfectlyNestedLoops(loops, curr);
   if (loops.size() < 2)
     return;
@@ -1107,7 +1107,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // Builder to create constants at the top level.
   FuncBuilder top(forInst->getFunction());
   // Create new memref type based on slice bounds.
-  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>().getMemRef();
+  auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
   unsigned rank = oldMemRefType.getRank();
 
@@ -1233,7 +1233,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   // Gather all memrefs from 'srcNode' store ops.
   DenseSet<Value *> storeMemrefs;
   for (auto *storeOpInst : srcNode->stores) {
-    storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+    storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
   }
   // Return false if any of the following are true:
   // *) 'srcNode' writes to a live in/out memref other than 'memref'.
@@ -1842,7 +1842,7 @@ public:
       DenseSet<Value *> visitedMemrefs;
       while (!loads.empty()) {
         // Get memref of load on top of the stack.
-        auto *memref = loads.back()->cast<LoadOp>().getMemRef();
+        auto *memref = cast<LoadOp>(loads.back()).getMemRef();
         if (visitedMemrefs.count(memref) > 0)
           continue;
         visitedMemrefs.insert(memref);
@@ -1898,7 +1898,7 @@ public:
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
-            if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+            if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
               dstStoreOpInsts.push_back(storeOpInst);
 
           unsigned bestDstLoopDepth;
@@ -1916,7 +1916,7 @@ public:
             LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
                                     << *sliceLoopNest.getOperation() << "\n");
             // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-            auto dstAffineForOp = dstNode->op->cast<AffineForOp>();
+            auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
             if (insertPointInst != dstAffineForOp.getOperation()) {
               dstAffineForOp.getOperation()->moveBefore(insertPointInst);
             }
@@ -1934,7 +1934,7 @@ public:
               // Create private memref for 'memref' in 'dstAffineForOp'.
               SmallVector<Operation *, 4> storesForMemref;
               for (auto *storeOpInst : sliceCollector.storeOpInsts) {
-                if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+                if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
                   storesForMemref.push_back(storeOpInst);
               }
               assert(storesForMemref.size() == 1);
@@ -1956,7 +1956,7 @@ public:
             // Add new load ops to current Node load op list 'loads' to
             // continue fusing based on new operands.
             for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-              auto *loadMemRef = loadOpInst->cast<LoadOp>().getMemRef();
+              auto *loadMemRef = cast<LoadOp>(loadOpInst).getMemRef();
               if (visitedMemrefs.count(loadMemRef) == 0)
                 loads.push_back(loadOpInst);
             }
@@ -2072,7 +2072,7 @@ public:
       auto sliceLoopNest = mlir::insertBackwardComputationSlice(
           sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
       if (sliceLoopNest != nullptr) {
-        auto dstForInst = dstNode->op->cast<AffineForOp>();
+        auto dstForInst = cast<AffineForOp>(dstNode->op);
         // Update operation position of fused loop nest (if needed).
         if (insertPointInst != dstForInst.getOperation()) {
           dstForInst.getOperation()->moveBefore(insertPointInst);
@@ -2114,7 +2114,7 @@ public:
       // Check that all stores are to the same memref.
       DenseSet<Value *> storeMemrefs;
       for (auto *storeOpInst : sibNode->stores) {
-        storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+        storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
       }
       if (storeMemrefs.size() != 1)
         return false;
@@ -2214,7 +2214,7 @@ public:
     }
 
     // Collect dst loop stats after memref privatizaton transformation.
-    auto dstForInst = dstNode->op->cast<AffineForOp>();
+    auto dstForInst = cast<AffineForOp>(dstNode->op);
     LoopNestStateCollector dstLoopCollector;
     dstLoopCollector.collect(dstForInst.getOperation());
     // Clear and add back loads and stores
@@ -2226,7 +2226,7 @@ public:
     // function.
     if (mdg->getOutEdgeCount(sibNode->id) == 0) {
       mdg->removeNode(sibNode->id);
-      sibNode->op->cast<AffineForOp>().erase();
+      sibNode->op->erase();
     }
   }
 
index 236ef81..1707f78 100644 (file)
@@ -113,7 +113,7 @@ void LoopUnroll::runOnFunction() {
           hasInnerLoops |= walkPostOrder(block.begin(), block.end());
       if (opInst->isa<AffineForOp>()) {
         if (!hasInnerLoops)
-          loops.push_back(opInst->cast<AffineForOp>());
+          loops.push_back(cast<AffineForOp>(opInst));
         return true;
       }
       return hasInnerLoops;
index 0a23295..43e8f4a 100644 (file)
@@ -187,7 +187,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
     // Insert the cleanup loop right after 'forOp'.
     FuncBuilder builder(forInst->getBlock(),
                         std::next(Block::iterator(forInst)));
-    auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
+    auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
     // Adjust the lower bound of the cleanup loop; its upper bound is the same
     // as the original loop's upper bound.
     AffineMap cleanupMap;
index 1ffe5e3..6f0162e 100644 (file)
@@ -626,7 +626,7 @@ void LowerAffinePass::runOnFunction() {
     } else if (auto forOp = dyn_cast<AffineForOp>(op)) {
       if (lowerAffineFor(forOp))
         return signalPassFailure();
-    } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+    } else if (lowerAffineApply(cast<AffineApplyOp>(op))) {
       return signalPassFailure();
     }
   }
index f7352d6..657169a 100644 (file)
@@ -264,7 +264,7 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
-  VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
+  VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
 
   // 1. Setup all the captures.
   ScopedContext scope(FuncBuilder(op), transfer.getLoc());
@@ -323,7 +323,7 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
-  VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
+  VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
 
   // 1. Setup all the captures.
   ScopedContext scope(FuncBuilder(op), transfer.getLoc());
index 28dfb22..206ae53 100644 (file)
@@ -679,7 +679,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
       continue;
     }
 
-    auto terminator = term->cast<VectorTransferWriteOp>();
+    auto terminator = cast<VectorTransferWriteOp>(term);
     LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
 
     // Get the transitive use-defs starting from terminator, limited to the
index 94df936..118efe5 100644 (file)
@@ -201,7 +201,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
     return;
 
   // Perform the actual store to load forwarding.
-  Value *storeVal = lastWriteStoreOp->cast<StoreOp>().getValueToStore();
+  Value *storeVal = cast<StoreOp>(lastWriteStoreOp).getValueToStore();
   loadOp.getResult()->replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
index 0da97f7..272972d 100644 (file)
@@ -234,8 +234,8 @@ static void findMatchingStartFinishInsts(
   // For each start operation, we look for a matching finish operation.
   for (auto *dmaStartInst : dmaStartInsts) {
     for (auto *dmaFinishInst : dmaFinishInsts) {
-      if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
-                        dmaFinishInst->cast<DmaWaitOp>())) {
+      if (checkTagMatch(cast<DmaStartOp>(dmaStartInst),
+                        cast<DmaWaitOp>(dmaFinishInst))) {
         startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
         break;
       }
@@ -273,7 +273,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
   for (auto &pair : startWaitPairs) {
     auto *dmaStartInst = pair.first;
     Value *oldMemRef = dmaStartInst->getOperand(
-        dmaStartInst->cast<DmaStartOp>().getFasterMemPos());
+        cast<DmaStartOp>(dmaStartInst).getFasterMemPos());
     if (!doubleBuffer(oldMemRef, forOp)) {
       // Normally, double buffering should not fail because we already checked
       // that there are no uses outside.
index 7fbb48e..1ae75b4 100644 (file)
@@ -426,7 +426,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
   Operation *op = forOp.getOperation();
   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
     FuncBuilder builder(op->getBlock(), ++Block::iterator(op));
-    auto cleanupForInst = builder.clone(*op)->cast<AffineForOp>();
+    auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
     AffineMap cleanupMap;
     SmallVector<Value *, 4> cleanupOperands;
     getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
@@ -512,7 +512,7 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
 void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
   for (unsigned i = 0; i < loopDepth; ++i) {
     assert(forOp.getBody()->front().isa<AffineForOp>());
-    AffineForOp nextForOp = forOp.getBody()->front().cast<AffineForOp>();
+    AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
     interchangeLoops(forOp, nextForOp);
   }
 }
index b64dc53..20138d5 100644 (file)
@@ -253,7 +253,7 @@ void VectorizerTestPass::testNormalizeMaps() {
     SmallVector<NestedMatch, 8> matches;
     pattern.match(f, &matches);
     for (auto m : matches) {
-      auto app = m.getMatchedOperation()->cast<AffineApplyOp>();
+      auto app = cast<AffineApplyOp>(m.getMatchedOperation());
       FuncBuilder b(m.getMatchedOperation());
       SmallVector<Value *, 8> operands(app.getOperands());
       makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands);
index 9b8768a..4a58b15 100644 (file)
@@ -859,7 +859,7 @@ static FilterFunctionType
 isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
                              int fastestVaryingMemRefDimension) {
   return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
-    auto loop = forOp.cast<AffineForOp>();
+    auto loop = cast<AffineForOp>(forOp);
     auto parallelIt = parallelLoops.find(loop);
     if (parallelIt == parallelLoops.end())
       return false;
@@ -879,7 +879,7 @@ static LogicalResult
 vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
                                   VectorizationState *state) {
   auto *loopInst = oneMatch.getMatchedOperation();
-  auto loop = loopInst->cast<AffineForOp>();
+  auto loop = cast<AffineForOp>(loopInst);
   auto childrenMatches = oneMatch.getMatchedChildren();
 
   // 1. DFS postorder recursion, if any of my children fails, I fail too.
@@ -1118,7 +1118,7 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
 /// anything below it fails.
 static LogicalResult vectorizeRootMatch(NestedMatch m,
                                         VectorizationStrategy *strategy) {
-  auto loop = m.getMatchedOperation()->cast<AffineForOp>();
+  auto loop = cast<AffineForOp>(m.getMatchedOperation());
   VectorizationState state;
   state.strategy = strategy;
 
@@ -1139,7 +1139,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
   /// RAII.
   auto *loopInst = loop.getOperation();
   FuncBuilder builder(loopInst);
-  auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
+  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
   struct Guard {
     LogicalResult failure() {
       loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar());