Refactor the various operand/result/type iterators to use indexed_accessor_range.
authorRiver Riddle <riverriddle@google.com>
Tue, 10 Dec 2019 21:20:50 +0000 (13:20 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 10 Dec 2019 21:21:22 +0000 (13:21 -0800)
This has several benefits:
* The implementation is much cleaner and more efficient.
* The ranges now have support for many useful operations: operator[], slice, drop_front, size, etc.
* Value ranges can now directly query a range for their types via 'getTypes()': e.g:
   void foo(Operation::operand_range operands) {
     auto operandTypes = operands.getTypes();
   }

PiperOrigin-RevId: 284834912

mlir/include/mlir/IR/BlockSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/TypeUtilities.h
mlir/include/mlir/Support/STLExtras.h
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/TypeUtilities.cpp

index 83fbee3..fd30c36 100644 (file)
@@ -68,14 +68,12 @@ class SuccessorRange final
     : public detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
                                                  Block *, Block *, Block *> {
 public:
-  using detail::indexed_accessor_range_base<
-      SuccessorRange, BlockOperand *, Block *, Block *,
-      Block *>::indexed_accessor_range_base;
+  using RangeBaseT::RangeBaseT;
   SuccessorRange(Block *block);
 
 private:
   /// See `detail::indexed_accessor_range_base` for details.
-  static BlockOperand *offset_object(BlockOperand *object, ptrdiff_t index) {
+  static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
     return object + index;
   }
   /// See `detail::indexed_accessor_range_base` for details.
@@ -83,9 +81,8 @@ private:
     return object[index].get();
   }
 
-  /// Allow access to `offset_object` and `dereference_iterator`.
-  friend detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
-                                             Block *, Block *, Block *>;
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
 };
 
 } // end namespace mlir
index 037c4fc..ac78647 100644 (file)
 #include "llvm/ADT/Twine.h"
 
 namespace mlir {
-class BlockAndValueMapping;
-class Location;
-class MLIRContext;
-class OperandIterator;
-class OperandTypeIterator;
-struct OperationState;
-class ResultIterator;
-class ResultTypeIterator;
-
 /// Terminator operations can have Block operands to represent successors.
 using BlockOperand = IROperandImpl<Block>;
 
@@ -230,14 +221,14 @@ public:
   }
 
   // Support operand iteration.
-  using operand_iterator = OperandIterator;
-  using operand_range = llvm::iterator_range<operand_iterator>;
+  using operand_range = OperandRange;
+  using operand_iterator = operand_range::iterator;
 
-  operand_iterator operand_begin();
-  operand_iterator operand_end();
+  operand_iterator operand_begin() { return getOperands().begin(); }
+  operand_iterator operand_end() { return getOperands().end(); }
 
   /// Returns an iterator on the underlying Value's (Value *).
-  operand_range getOperands();
+  operand_range getOperands() { return operand_range(this); }
 
   /// Erase the operand at position `idx`.
   void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); }
@@ -249,11 +240,11 @@ public:
   OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
 
   // Support operand type iteration.
-  using operand_type_iterator = OperandTypeIterator;
-  using operand_type_range = llvm::iterator_range<operand_type_iterator>;
-  operand_type_iterator operand_type_begin();
-  operand_type_iterator operand_type_end();
-  operand_type_range getOperandTypes();
+  using operand_type_iterator = operand_range::type_iterator;
+  using operand_type_range = iterator_range<operand_type_iterator>;
+  operand_type_iterator operand_type_begin() { return operand_begin(); }
+  operand_type_iterator operand_type_end() { return operand_end(); }
+  operand_type_range getOperandTypes() { return getOperands().getTypes(); }
 
   //===--------------------------------------------------------------------===//
   // Results
@@ -266,14 +257,13 @@ public:
 
   Value *getResult(unsigned idx) { return &getOpResult(idx); }
 
-  // Support result iteration.
-  using result_iterator = ResultIterator;
-  using result_range = llvm::iterator_range<result_iterator>;
-
-  result_iterator result_begin();
-  result_iterator result_end();
+  /// Support result iteration.
+  using result_range = ResultRange;
+  using result_iterator = result_range::iterator;
 
-  result_range getResults();
+  result_iterator result_begin() { return getResults().begin(); }
+  result_iterator result_end() { return getResults().end(); }
+  result_range getResults() { return result_range(this); }
 
   MutableArrayRef<OpResult> getOpResults() {
     return {getTrailingObjects<OpResult>(), numResults};
@@ -281,12 +271,12 @@ public:
 
   OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; }
 
-  // Support result type iteration.
-  using result_type_iterator = ResultTypeIterator;
-  using result_type_range = llvm::iterator_range<result_type_iterator>;
-  result_type_iterator result_type_begin();
-  result_type_iterator result_type_end();
-  result_type_range getResultTypes();
+  /// Support result type iteration.
+  using result_type_iterator = result_range::type_iterator;
+  using result_type_range = iterator_range<result_type_iterator>;
+  result_type_iterator result_type_begin() { return result_begin(); }
+  result_type_iterator result_type_end() { return result_end(); }
+  result_type_range getResultTypes() { return getResults().getTypes(); }
 
   //===--------------------------------------------------------------------===//
   // Attributes
@@ -657,91 +647,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Operation &op) {
   return os;
 }
 
-/// This class implements the const/non-const operand iterators for the
-/// Operation class in terms of getOperand(idx).
-class OperandIterator final
-    : public indexed_accessor_iterator<OperandIterator, Operation *, Value *,
-                                       Value *, Value *> {
-public:
-  /// Initializes the operand iterator to the specified operand index.
-  OperandIterator(Operation *object, unsigned index)
-      : indexed_accessor_iterator<OperandIterator, Operation *, Value *,
-                                  Value *, Value *>(object, index) {}
-
-  Value *operator*() const { return this->base->getOperand(this->index); }
-};
-
-/// This class implements the operand type iterators for the Operation
-/// class in terms of operand_iterator->getType().
-class OperandTypeIterator final
-    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
-  static Type unwrap(Value *value) { return value->getType(); }
-
-public:
-  using reference = Type;
-
-  /// Provide a const deference method.
-  Type operator*() const { return unwrap(*I); }
-
-  /// Initializes the operand type iterator to the specified operand iterator.
-  OperandTypeIterator(OperandIterator it)
-      : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
-  }
-};
-
-// Implement the inline operand iterator methods.
-inline auto Operation::operand_begin() -> operand_iterator {
-  return operand_iterator(this, 0);
-}
-
-inline auto Operation::operand_end() -> operand_iterator {
-  return operand_iterator(this, getNumOperands());
-}
-
-inline auto Operation::getOperands() -> operand_range {
-  return {operand_begin(), operand_end()};
-}
-
-inline auto Operation::operand_type_begin() -> operand_type_iterator {
-  return operand_type_iterator(operand_begin());
-}
-
-inline auto Operation::operand_type_end() -> operand_type_iterator {
-  return operand_type_iterator(operand_end());
-}
-
-inline auto Operation::getOperandTypes() -> operand_type_range {
-  return {operand_type_begin(), operand_type_end()};
-}
-
-/// This class implements the result iterators for the Operation class
-/// in terms of getResult(idx).
-class ResultIterator final
-    : public indexed_accessor_iterator<ResultIterator, Operation *, Value *,
-                                       Value *, Value *> {
-public:
-  /// Initializes the result iterator to the specified index.
-  ResultIterator(Operation *base, unsigned index)
-      : indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
-                                  Value *>(base, index) {}
-
-  Value *operator*() const { return this->base->getResult(this->index); }
-};
-
-/// This class implements the result type iterators for the Operation
-/// class in terms of result_iterator->getType().
-class ResultTypeIterator final
-    : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
-  static Type unwrap(Value *value) { return value->getType(); }
-
-public:
-  using reference = Type;
-
-  /// Initializes the result type iterator to the specified result iterator.
-  ResultTypeIterator(ResultIterator it)
-      : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
-};
-
 /// This class implements use iterator for the Operation. This iterates over all
 /// uses of all results of an Operation.
 class UseIterator final
@@ -768,75 +673,6 @@ private:
   /// The use of the result.
   Value::use_iterator use;
 };
-
-// Implement the inline result iterator methods.
-inline auto Operation::result_begin() -> result_iterator {
-  return result_iterator(this, 0);
-}
-
-inline auto Operation::result_end() -> result_iterator {
-  return result_iterator(this, getNumResults());
-}
-
-inline auto Operation::getResults() -> llvm::iterator_range<result_iterator> {
-  return {result_begin(), result_end()};
-}
-
-inline auto Operation::result_type_begin() -> result_type_iterator {
-  return result_type_iterator(result_begin());
-}
-
-inline auto Operation::result_type_end() -> result_type_iterator {
-  return result_type_iterator(result_end());
-}
-
-inline auto Operation::getResultTypes() -> result_type_range {
-  return {result_type_begin(), result_type_end()};
-}
-
-/// This class provides an abstraction over the different types of ranges over
-/// Value*s. In many cases, this prevents the need to explicitly materialize a
-/// SmallVector/std::vector. This class should be used in places that are not
-/// suitable for a more derived type (e.g. ArrayRef) or a template range
-/// parameter.
-class ValueRange
-    : public detail::indexed_accessor_range_base<
-          ValueRange,
-          llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
-          Value *, Value *> {
-  /// The type representing the owner of this range. This is either a list of
-  /// values, operands, or results.
-  using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
-
-public:
-  using detail::indexed_accessor_range_base<
-      ValueRange, OwnerT, Value *, Value *,
-      Value *>::indexed_accessor_range_base;
-
-  template <typename Arg,
-            typename = typename std::enable_if_t<
-                std::is_constructible<ArrayRef<Value *>, Arg>::value &&
-                !std::is_convertible<Arg, Value *>::value>>
-  ValueRange(Arg &&arg)
-      : ValueRange(ArrayRef<Value *>(std::forward<Arg>(arg))) {}
-  ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {}
-  ValueRange(const std::initializer_list<Value *> &values)
-      : ValueRange(ArrayRef<Value *>(values)) {}
-  ValueRange(ArrayRef<Value *> values = llvm::None);
-  ValueRange(iterator_range<OperandIterator> values);
-  ValueRange(iterator_range<ResultIterator> values);
-
-private:
-  /// See `detail::indexed_accessor_range_base` for details.
-  static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
-  /// See `detail::indexed_accessor_range_base` for details.
-  static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
-
-  /// Allow access to `offset_base` and `dereference_iterator`.
-  friend detail::indexed_accessor_range_base<ValueRange, OwnerT, Value *,
-                                             Value *, Value *>;
-};
-
 } // end namespace mlir
 
 namespace llvm {
index 14ddf2d..0a0e1ac 100644 (file)
@@ -60,6 +60,10 @@ template <typename OpTy> using OperandAdaptor = typename OpTy::OperandAdaptor;
 
 class OwningRewritePatternList;
 
+//===----------------------------------------------------------------------===//
+// AbstractOperation
+//===----------------------------------------------------------------------===//
+
 enum class OperationProperty {
   /// This bit is set for an operation if it is a commutative operation: that
   /// is a binary operator (two inputs) where "a op b" and "b op a" produce the
@@ -201,6 +205,10 @@ private:
   bool (&hasRawTrait)(ClassID *traitID);
 };
 
+//===----------------------------------------------------------------------===//
+// OperationName
+//===----------------------------------------------------------------------===//
+
 class OperationName {
 public:
   using RepresentationUnion =
@@ -251,6 +259,10 @@ inline llvm::hash_code hash_value(OperationName arg) {
   return llvm::hash_value(arg.getAsOpaquePointer());
 }
 
+//===----------------------------------------------------------------------===//
+// OperationState
+//===----------------------------------------------------------------------===//
+
 /// This represents an operation in an abstracted form, suitable for use with
 /// the builder APIs.  This object is a large and heavy weight object meant to
 /// be used as a temporary object on the stack.  It is generally unwise to put
@@ -322,6 +334,10 @@ public:
   MLIRContext *getContext() { return location->getContext(); }
 };
 
+//===----------------------------------------------------------------------===//
+// OperandStorage
+//===----------------------------------------------------------------------===//
+
 namespace detail {
 /// A utility class holding the information necessary to dynamically resize
 /// operands.
@@ -445,6 +461,10 @@ private:
 };
 } // end namespace detail
 
+//===----------------------------------------------------------------------===//
+// OpPrintingFlags
+//===----------------------------------------------------------------------===//
+
 /// Set of flags used to control the behavior of the various IR print methods
 /// (e.g. Operation::Print).
 class OpPrintingFlags {
@@ -504,6 +524,138 @@ private:
   bool printLocalScope : 1;
 };
 
+//===----------------------------------------------------------------------===//
+// Operation Value-Iterators
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// ValueTypeRange
+
+/// This class implements iteration on the types of a given range of values.
+template <typename ValueIteratorT>
+class ValueTypeIterator final
+    : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)> {
+  static Type unwrap(Value *value) { return value->getType(); }
+
+public:
+  using reference = Type;
+
+  /// Provide a const dereference method.
+  Type operator*() const { return unwrap(*this->I); }
+
+  /// Initializes the type iterator to the specified value iterator.
+  ValueTypeIterator(ValueIteratorT it)
+      : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)>(it, &unwrap) {}
+};
+
+//===----------------------------------------------------------------------===//
+// OperandRange
+
+/// This class implements the operand iterators for the Operation class.
+class OperandRange final
+    : public detail::indexed_accessor_range_base<OperandRange, OpOperand *,
+                                                 Value *, Value *, Value *> {
+public:
+  using RangeBaseT::RangeBaseT;
+  OperandRange(Operation *op);
+
+  /// Returns the types of the values within this range.
+  using type_iterator = ValueTypeIterator<iterator>;
+  iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+
+private:
+  /// See `detail::indexed_accessor_range_base` for details.
+  static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
+    return object + index;
+  }
+  /// See `detail::indexed_accessor_range_base` for details.
+  static Value *dereference_iterator(OpOperand *object, ptrdiff_t index) {
+    return object[index].get();
+  }
+
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
+};
+
+//===----------------------------------------------------------------------===//
+// ResultRange
+
+/// This class implements the result iterators for the Operation class.
+class ResultRange final
+    : public detail::indexed_accessor_range_base<ResultRange, OpResult *,
+                                                 Value *, Value *, Value *> {
+public:
+  using RangeBaseT::RangeBaseT;
+  ResultRange(Operation *op);
+
+  /// Returns the types of the values within this range.
+  using type_iterator = ValueTypeIterator<iterator>;
+  iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+
+private:
+  /// See `detail::indexed_accessor_range_base` for details.
+  static OpResult *offset_base(OpResult *object, ptrdiff_t index) {
+    return object + index;
+  }
+  /// See `detail::indexed_accessor_range_base` for details.
+  static Value *dereference_iterator(OpResult *object, ptrdiff_t index) {
+    return &object[index];
+  }
+
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
+};
+
+//===----------------------------------------------------------------------===//
+// ValueRange
+
+/// This class provides an abstraction over the different types of ranges over
+/// Value*s. In many cases, this prevents the need to explicitly materialize a
+/// SmallVector/std::vector. This class should be used in places that are not
+/// suitable for a more derived type (e.g. ArrayRef) or a template range
+/// parameter.
+class ValueRange final
+    : public detail::indexed_accessor_range_base<
+          ValueRange,
+          llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
+          Value *, Value *> {
+public:
+  using RangeBaseT::RangeBaseT;
+
+  template <typename Arg,
+            typename = typename std::enable_if_t<
+                std::is_constructible<ArrayRef<Value *>, Arg>::value &&
+                !std::is_convertible<Arg, Value *>::value>>
+  ValueRange(Arg &&arg)
+      : ValueRange(ArrayRef<Value *>(std::forward<Arg>(arg))) {}
+  ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {}
+  ValueRange(const std::initializer_list<Value *> &values)
+      : ValueRange(ArrayRef<Value *>(values)) {}
+  ValueRange(iterator_range<OperandRange::iterator> values)
+      : ValueRange(OperandRange(values)) {}
+  ValueRange(iterator_range<ResultRange::iterator> values)
+      : ValueRange(ResultRange(values)) {}
+  ValueRange(ArrayRef<Value *> values = llvm::None);
+  ValueRange(OperandRange values);
+  ValueRange(ResultRange values);
+
+  /// Returns the types of the values within this range.
+  using type_iterator = ValueTypeIterator<iterator>;
+  iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+
+private:
+  /// The type representing the owner of this range. This is either a list of
+  /// values, operands, or results.
+  using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
+
+  /// See `detail::indexed_accessor_range_base` for details.
+  static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
+  /// See `detail::indexed_accessor_range_base` for details.
+  static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
+
+  /// Allow access to `offset_base` and `dereference_iterator`.
+  friend RangeBaseT;
+};
 } // end namespace mlir
 
 namespace llvm {
index 3d25140..27b20c2 100644 (file)
@@ -175,9 +175,7 @@ class RegionRange
   using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
 
 public:
-  using detail::indexed_accessor_range_base<
-      RegionRange, OwnerT, Region *, Region *,
-      Region *>::indexed_accessor_range_base;
+  using RangeBaseT::RangeBaseT;
 
   RegionRange(MutableArrayRef<Region> regions = llvm::None);
 
@@ -196,8 +194,7 @@ private:
   static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
 
   /// Allow access to `offset_base` and `dereference_iterator`.
-  friend detail::indexed_accessor_range_base<RegionRange, OwnerT, Region *,
-                                             Region *, Region *>;
+  friend RangeBaseT;
 };
 
 } // end namespace mlir
index 49d57e8..6512f8f 100644 (file)
@@ -65,13 +65,14 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2);
 
 // An iterator for the element types of an op's operands of shaped types.
 class OperandElementTypeIterator final
-    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+    : public llvm::mapped_iterator<Operation::operand_iterator,
+                                   Type (*)(Value *)> {
 public:
   using reference = Type;
 
   /// Initializes the result element type iterator to the specified operand
   /// iterator.
-  explicit OperandElementTypeIterator(OperandIterator it);
+  explicit OperandElementTypeIterator(Operation::operand_iterator it);
 
 private:
   static Type unwrap(Value *value);
@@ -82,13 +83,14 @@ using OperandElementTypeRange =
 
 // An iterator for the tensor element types of an op's results of shaped types.
 class ResultElementTypeIterator final
-    : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
+    : public llvm::mapped_iterator<Operation::result_iterator,
+                                   Type (*)(Value *)> {
 public:
   using reference = Type;
 
   /// Initializes the result element type iterator to the specified result
   /// iterator.
-  explicit ResultElementTypeIterator(ResultIterator it);
+  explicit ResultElementTypeIterator(Operation::result_iterator it);
 
 private:
   static Type unwrap(Value *value);
index 07db06a..c98f925 100644 (file)
@@ -204,6 +204,9 @@ template <typename DerivedT, typename BaseT, typename T,
           typename PointerT = T *, typename ReferenceT = T &>
 class indexed_accessor_range_base {
 public:
+  using RangeBaseT =
+      indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>;
+
   /// An iterator element of this range.
   class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
                                                     PointerT, ReferenceT> {
@@ -223,11 +226,17 @@ public:
                                        ReferenceT>;
   };
 
+  indexed_accessor_range_base(iterator begin, iterator end)
+      : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())),
+        count(end.getIndex() - begin.getIndex()) {}
+  indexed_accessor_range_base(const iterator_range<iterator> &range)
+      : indexed_accessor_range_base(range.begin(), range.end()) {}
+
   iterator begin() const { return iterator(base, 0); }
   iterator end() const { return iterator(base, count); }
   ReferenceT operator[](unsigned index) const {
     assert(index < size() && "invalid index for value range");
-    return *std::next(begin(), index);
+    return DerivedT::dereference_iterator(base, index);
   }
 
   /// Return the size of this range.
@@ -237,22 +246,35 @@ public:
   bool empty() const { return size() == 0; }
 
   /// Drop the first N elements, and keep M elements.
-  DerivedT slice(unsigned n, unsigned m) const {
+  DerivedT slice(size_t n, size_t m) const {
     assert(n + m <= size() && "invalid size specifiers");
     return DerivedT(DerivedT::offset_base(base, n), m);
   }
 
   /// Drop the first n elements.
-  DerivedT drop_front(unsigned n = 1) const {
+  DerivedT drop_front(size_t n = 1) const {
     assert(size() >= n && "Dropping more elements than exist");
     return slice(n, size() - n);
   }
   /// Drop the last n elements.
-  DerivedT drop_back(unsigned n = 1) const {
+  DerivedT drop_back(size_t n = 1) const {
     assert(size() >= n && "Dropping more elements than exist");
     return DerivedT(base, size() - n);
   }
 
+  /// Take the first n elements.
+  DerivedT take_front(size_t n = 1) const {
+    return n < size() ? drop_back(size() - n)
+                      : static_cast<const DerivedT &>(*this);
+  }
+
+  /// Allow conversion to SmallVector if necessary.
+  /// TODO(riverriddle) Remove this when SmallVector accepts different range
+  /// types in its constructor.
+  template <typename SVT, unsigned N> operator SmallVector<SVT, N>() const {
+    return {begin(), end()};
+  }
+
 protected:
   indexed_accessor_range_base(BaseT base, ptrdiff_t count)
       : base(base), count(count) {}
index ee90cea..7726c04 100644 (file)
@@ -2819,7 +2819,7 @@ public:
       return matchFailure();
     }
     SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
-    for (auto size : enumerate(subViewOp.sizes())) {
+    for (auto size : llvm::enumerate(subViewOp.sizes())) {
       auto defOp = size.value()->getDefiningOp();
       assert(defOp);
       staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
@@ -2865,7 +2865,7 @@ public:
     }
 
     SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
-    for (auto stride : enumerate(subViewOp.strides())) {
+    for (auto stride : llvm::enumerate(subViewOp.strides())) {
       auto defOp = stride.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[stride.index()] > 0);
@@ -2916,7 +2916,7 @@ public:
     }
 
     auto staticOffset = baseOffset;
-    for (auto offset : enumerate(subViewOp.offsets())) {
+    for (auto offset : llvm::enumerate(subViewOp.offsets())) {
       auto defOp = offset.value()->getDefiningOp();
       assert(defOp);
       assert(baseStrides[offset.index()] > 0);
index 0483c27..fd747a9 100644 (file)
@@ -597,9 +597,8 @@ void Operation::setSuccessor(Block *block, unsigned index) {
 }
 
 auto Operation::getNonSuccessorOperands() -> operand_range {
-  return {operand_iterator(this, 0),
-          operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0)
-                                                 : getNumOperands())};
+  return getOperands().take_front(hasSuccessors() ? getSuccessorOperandIndex(0)
+                                                  : getNumOperands());
 }
 
 /// Get the index of the first operand of the successor at the provided
@@ -635,9 +634,7 @@ Operation::decomposeSuccessorOperandIndex(unsigned operandIndex) {
 
 auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
   unsigned succOperandIndex = getSuccessorOperandIndex(index);
-  return {operand_iterator(this, succOperandIndex),
-          operand_iterator(this,
-                           succOperandIndex + getNumSuccessorOperands(index))};
+  return getOperands().slice(succOperandIndex, getNumSuccessorOperands(index));
 }
 
 /// Attempt to fold this operation using the Op's registered foldHook.
@@ -746,48 +743,6 @@ Operation *Operation::clone() {
 }
 
 //===----------------------------------------------------------------------===//
-// ValueRange
-//===----------------------------------------------------------------------===//
-
-ValueRange::ValueRange(ArrayRef<Value *> values)
-    : ValueRange(values.data(), values.size()) {}
-ValueRange::ValueRange(llvm::iterator_range<OperandIterator> values)
-    : ValueRange(nullptr, llvm::size(values)) {
-  if (!empty()) {
-    auto begin = values.begin();
-    base = &begin.getBase()->getOpOperand(begin.getIndex());
-  }
-}
-ValueRange::ValueRange(llvm::iterator_range<ResultIterator> values)
-    : ValueRange(nullptr, llvm::size(values)) {
-  if (!empty()) {
-    auto begin = values.begin();
-    base = &begin.getBase()->getOpResult(begin.getIndex());
-  }
-}
-
-/// See `detail::indexed_accessor_range_base` for details.
-ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
-                                           ptrdiff_t index) {
-  if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
-    return operand + index;
-  if (OpResult *result = owner.dyn_cast<OpResult *>())
-    return result + index;
-  return owner.get<Value *const *>() + index;
-}
-/// See `detail::indexed_accessor_range_base` for details.
-Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
-  // Operands access the held value via 'get'.
-  if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
-    return operand[index].get();
-  // An OpResult is a value, so we can return it directly.
-  if (OpResult *result = owner.dyn_cast<OpResult *>())
-    return &result[index];
-  // Otherwise, this is a raw value array so just index directly.
-  return owner.get<Value *const *>()[index];
-}
-
-//===----------------------------------------------------------------------===//
 // OpState trait class.
 //===----------------------------------------------------------------------===//
 
@@ -979,7 +934,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
   auto elementType = getElementTypeOrSelf(op->getResult(0));
 
   // Verify result element type matches first result's element type.
-  for (auto result : drop_begin(op->getResults(), 1)) {
+  for (auto result : llvm::drop_begin(op->getResults(), 1)) {
     if (getElementTypeOrSelf(result) != elementType)
       return op->emitOpError(
           "requires the same element type for all operands and results");
@@ -1210,7 +1165,7 @@ Value *impl::foldCastOp(Operation *op) {
 }
 
 //===----------------------------------------------------------------------===//
-// CastOp implementation
+// Misc. utils
 //===----------------------------------------------------------------------===//
 
 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
@@ -1230,6 +1185,10 @@ void impl::ensureRegionTerminator(
   block.push_back(buildTerminatorOp());
 }
 
+//===----------------------------------------------------------------------===//
+// UseIterator
+//===----------------------------------------------------------------------===//
+
 UseIterator::UseIterator(Operation *op, bool end)
     : op(op), res(end ? op->result_end() : op->result_begin()) {
   // Only initialize current use if there are results/can be uses.
index e4ff889..256a261 100644 (file)
@@ -144,3 +144,50 @@ void detail::OperandStorage::grow(ResizableStorage &resizeUtil,
     operand.~OpOperand();
   resizeUtil.setDynamicStorage(newStorage);
 }
+
+//===----------------------------------------------------------------------===//
+// Operation Value-Iterators
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// OperandRange
+
+OperandRange::OperandRange(Operation *op)
+    : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
+
+//===----------------------------------------------------------------------===//
+// ResultRange
+
+ResultRange::ResultRange(Operation *op)
+    : ResultRange(op->getOpResults().data(), op->getNumResults()) {}
+
+//===----------------------------------------------------------------------===//
+// ValueRange
+
+ValueRange::ValueRange(ArrayRef<Value *> values)
+    : ValueRange(values.data(), values.size()) {}
+ValueRange::ValueRange(OperandRange values)
+    : ValueRange(values.begin().getBase(), values.size()) {}
+ValueRange::ValueRange(ResultRange values)
+    : ValueRange(values.begin().getBase(), values.size()) {}
+
+/// See `detail::indexed_accessor_range_base` for details.
+ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
+                                           ptrdiff_t index) {
+  if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
+    return operand + index;
+  if (OpResult *result = owner.dyn_cast<OpResult *>())
+    return result + index;
+  return owner.get<Value *const *>() + index;
+}
+/// See `detail::indexed_accessor_range_base` for details.
+Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
+  // Operands access the held value via 'get'.
+  if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
+    return operand[index].get();
+  // An OpResult is a value, so we can return it directly.
+  if (OpResult *result = owner.dyn_cast<OpResult *>())
+    return &result[index];
+  // Otherwise, this is a raw value array so just index directly.
+  return owner.get<Value *const *>()[index];
+}
index a963a8d..0172141 100644 (file)
@@ -92,15 +92,19 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
   return success();
 }
 
-OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
-    : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
+OperandElementTypeIterator::OperandElementTypeIterator(
+    Operation::operand_iterator it)
+    : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value *)>(
+          it, &unwrap) {}
 
 Type OperandElementTypeIterator::unwrap(Value *value) {
   return value->getType().cast<ShapedType>().getElementType();
 }
 
-ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
-    : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
+ResultElementTypeIterator::ResultElementTypeIterator(
+    Operation::result_iterator it)
+    : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value *)>(
+          it, &unwrap) {}
 
 Type ResultElementTypeIterator::unwrap(Value *value) {
   return value->getType().cast<ShapedType>().getElementType();