#ifndef MLIR_SUPPORT_TYPEUTILITIES_H
#define MLIR_SUPPORT_TYPEUTILITIES_H
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+
namespace mlir {
class Attribute;
class Type;
class Value;
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
/// Return the element type or return the type itself.
Type getElementTypeOrSelf(Type type);
Type getElementTypeOrSelf(Value *val);
Type getElementTypeOrSelf(Value &val);
+//===----------------------------------------------------------------------===//
+// Utility Iterators
+//===----------------------------------------------------------------------===//
+
+// An iterator for the element types of an op's operands of shaped types.
+class OperandElementTypeIterator final
+ : public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
+public:
+ /// Initializes the result element type iterator to the specified operand
+ /// iterator.
+ explicit OperandElementTypeIterator(OperandIterator it);
+
+private:
+ static Type unwrap(Value *value);
+};
+
+using OperandElementTypeRange =
+ llvm::iterator_range<OperandElementTypeIterator>;
+
+// An iterator for the tensor element types of an op's results of shaped types.
+class ResultElementTypeIterator final
+ : public llvm::mapped_iterator<ResultIterator, Type (*)(Value *)> {
+
+public:
+ /// Initializes the result element type iterator to the specified result
+ /// iterator.
+ explicit ResultElementTypeIterator(ResultIterator it);
+
+private:
+ static Type unwrap(Value *value);
+};
+
+using ResultElementTypeRange = llvm::iterator_range<ResultElementTypeIterator>;
+
} // end namespace mlir
#endif // MLIR_SUPPORT_TYPEUTILITIES_H
Type mlir::getElementTypeOrSelf(Attribute attr) {
return getElementTypeOrSelf(attr.getType());
}
+
+OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
+ : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type OperandElementTypeIterator::unwrap(Value *value) {
+ return value->getType().cast<ShapedType>().getElementType();
+}
+
+ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it)
+ : llvm::mapped_iterator<ResultIterator, Type (*)(Value *)>(it, &unwrap) {}
+
+Type ResultElementTypeIterator::unwrap(Value *value) {
+ return value->getType().cast<ShapedType>().getElementType();
+}