Add definition for OperandElementTypeIterator and ResultElementTypeIterator
authorLei Zhang <antiagainst@google.com>
Thu, 20 Jun 2019 12:59:19 +0000 (05:59 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 22 Jun 2019 16:14:01 +0000 (09:14 -0700)
These are useful utility iterators helping use to get the element types of
operands/results of shaped types.

Also defined ranges for these iterators.

PiperOrigin-RevId: 254180888

mlir/include/mlir/Support/TypeUtilities.h
mlir/lib/Support/TypeUtilities.cpp

index 255281f..6f8a90c 100644 (file)
 #ifndef MLIR_SUPPORT_TYPEUTILITIES_H
 #define MLIR_SUPPORT_TYPEUTILITIES_H
 
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 
 class Attribute;
 class Type;
 class Value;
 
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
 /// Return the element type or return the type itself.
 Type getElementTypeOrSelf(Type type);
 
@@ -36,6 +43,40 @@ Type getElementTypeOrSelf(Attribute attr);
 Type getElementTypeOrSelf(Value *val);
 Type getElementTypeOrSelf(Value &val);
 
+//===----------------------------------------------------------------------===//
+// Utility Iterators
+//===----------------------------------------------------------------------===//
+
+// An iterator for the element types of an op's operands of shaped types.
+class OperandElementTypeIterator final
+    : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+public:
+  /// Initializes the result element type iterator to the specified operand
+  /// iterator.
+  explicit OperandElementTypeIterator(OperandIterator it);
+
+private:
+  static Type unwrap(Value *value);
+};
+
+using OperandElementTypeRange =
+    llvm::iterator_range<OperandElementTypeIterator>;
+
+// An iterator for the tensor element types of an op's results of shaped types.
+class ResultElementTypeIterator final
+    : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
+
+public:
+  /// Initializes the result element type iterator to the specified result
+  /// iterator.
+  explicit ResultElementTypeIterator(ResultIterator it);
+
+private:
+  static Type unwrap(Value *value);
+};
+
+using ResultElementTypeRange = llvm::iterator_range<ResultElementTypeIterator>;
+
 } // end namespace mlir
 
 #endif // MLIR_SUPPORT_TYPEUTILITIES_H
index 8114b73..63b04f4 100644 (file)
@@ -44,3 +44,17 @@ Type mlir::getElementTypeOrSelf(Value &val) {
 Type mlir::getElementTypeOrSelf(Attribute attr) {
   return getElementTypeOrSelf(attr.getType());
 }
+
+OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
+    : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type OperandElementTypeIterator::unwrap(Value *value) {
+  return value->getType().cast<ShapedType>().getElementType();
+}
+
+ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
+    : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type ResultElementTypeIterator::unwrap(Value *value) {
+  return value->getType().cast<ShapedType>().getElementType();
+}