Added the ability to run a mapping function across the values of an elements
authorRob Suderman <suderman@google.com>
Wed, 22 May 2019 22:55:17 +0000 (15:55 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:57:13 +0000 (19:57 -0700)
    attr. This supports both the SplatElementsAttr and DenseElementsAttr for both
    float and integer inputs / outputs.

--

PiperOrigin-RevId: 249538085

mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/Attributes.cpp

index 39ca337..106a9f5 100644 (file)
@@ -405,6 +405,20 @@ public:
   /// element, then a null attribute is returned.
   Attribute getValue(ArrayRef<uint64_t> index) const;
 
+  /// Generates a new ElementsAttr by mapping each int value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain integers.
+  ElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Generates a new ElementsAttr by mapping each float value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain floats.
+  ElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr) {
     return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
@@ -424,6 +438,20 @@ public:
   static SplatElementsAttr get(ShapedType type, Attribute elt);
   Attribute getValue() const;
 
+  /// Generates a new SplatElementsAttr by mapping each int value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain integers.
+  SplatElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Generates a new SplatElementsAttr by mapping each float value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This ElementsAttr should contain floats.
+  SplatElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool kindof(unsigned kind) {
     return kind == StandardAttributes::SplatElements;
@@ -454,6 +482,20 @@ public:
 
   void getValues(SmallVectorImpl<Attribute> &values) const;
 
+  /// Generates a new DenseElementsAttr by mapping each int value to a new
+  /// underlying APInt. The new values can represent either a integer or float.
+  /// This underlying type must be an DenseIntElementsAttr.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
+  /// Generates a new DenseElementsAttr by mapping each float value to a new
+  /// underlying APInt. the new values can represent either a integer or float.
+  /// This underlying type must be an DenseFPElementsAttr.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
   ArrayRef<char> getRawData() const;
 
   /// Writes value to the bit position `bitPos` in array `rawData`. 'rawData' is
@@ -550,6 +592,12 @@ public:
   /// type of 'type'.
   static DenseIntElementsAttr get(ShapedType type, ArrayRef<int64_t> values);
 
+  /// Generates a new DenseElementsAttr by mapping each value attribute, and
+  /// constructing the DenseElementsAttr given the new element type.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APInt &)> mapping) const;
+
   /// Gets the integer value of each of the dense elements.
   void getValues(SmallVectorImpl<APInt> &values) const;
 
@@ -593,6 +641,12 @@ public:
   /// Gets the float value of each of the dense elements.
   void getValues(SmallVectorImpl<APFloat> &values) const;
 
+  /// Generates a new DenseElementsAttr by mapping each value attribute, and
+  /// constructing the DenseElementsAttr given the new element type.
+  DenseElementsAttr
+  mapValues(Type newElementType,
+            llvm::function_ref<APInt(const APFloat &)> mapping) const;
+
   /// Iterator access to the float element values.
   iterator begin() const;
   iterator end() const;
index add8a85..9b0d744 100644 (file)
@@ -287,6 +287,36 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   }
 }
 
+ElementsAttr ElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  switch (getKind()) {
+  case StandardAttributes::DenseIntElements:
+  case StandardAttributes::DenseFPElements:
+    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+  case StandardAttributes::SplatElements:
+    return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
+  default:
+    llvm_unreachable("unsupported ElementsAttr subtype");
+  }
+}
+
+ElementsAttr ElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  switch (getKind()) {
+  case StandardAttributes::DenseIntElements:
+  case StandardAttributes::DenseFPElements:
+    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
+  case StandardAttributes::SplatElements:
+    return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
+  default:
+    break;
+  }
+
+  llvm_unreachable("unsupported ElementsAttr subtype");
+}
+
 //===----------------------------------------------------------------------===//
 // SplatElementsAttr
 //===----------------------------------------------------------------------===//
@@ -300,6 +330,74 @@ SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
 
 Attribute SplatElementsAttr::getValue() const { return getImpl()->elt; }
 
+SplatElementsAttr SplatElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  Type inType = getType();
+  auto inElementType = getType().getElementType();
+
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType>())
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  else if (inType.isa<UnrankedTensorType>())
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  else
+    assert(false && "Unhandled tensor type");
+
+  assert(inElementType.isa<IntegerType>() &&
+         "Attempting to map non-integer array as integers");
+
+  if (newElementType.isa<IntegerType>()) {
+    APInt newValue = mapping(getValue().cast<IntegerAttr>().getValue());
+    auto newAttr = IntegerAttr::get(newElementType, newValue);
+    return get(newArrayType, newAttr);
+  }
+
+  if (newElementType.isa<FloatType>()) {
+    APFloat newValue(newElementType.cast<FloatType>().getFloatSemantics(),
+                     mapping(getValue().cast<IntegerAttr>().getValue()));
+    auto newAttr = FloatAttr::get(newElementType, newValue);
+    return get(newArrayType, newAttr);
+  }
+
+  llvm_unreachable("unknown output splat type");
+}
+
+SplatElementsAttr SplatElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  Type inType = getType();
+  auto inElementType = getType().getElementType();
+
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType>()) {
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  } else if (inType.isa<UnrankedTensorType>()) {
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  }
+
+  assert(newArrayType && "Unhandled tensor type");
+  assert(inElementType.isa<FloatType>() &&
+         "mapping function expects float tensor");
+
+  Attribute newAttr;
+  if (newElementType.isa<IntegerType>()) {
+    APInt newValue = mapping(getValue().cast<FloatAttr>().getValue());
+    newAttr = IntegerAttr::get(newElementType, newValue);
+    return get(newArrayType, newAttr);
+  }
+
+  if (newElementType.isa<FloatType>()) {
+    APFloat newValue =
+        APFloat(newElementType.cast<FloatType>().getFloatSemantics(),
+                mapping(getValue().cast<FloatAttr>().getValue()));
+    newAttr = FloatAttr::get(newElementType, newValue);
+    return get(newArrayType, newAttr);
+  }
+
+  llvm_unreachable("unknown output splat type");
+}
+
 //===----------------------------------------------------------------------===//
 // RawElementIterator
 //===----------------------------------------------------------------------===//
@@ -459,6 +557,18 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
   }
 }
 
+DenseElementsAttr DenseElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+}
+
+DenseElementsAttr DenseElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+}
+
 ArrayRef<char> DenseElementsAttr::getRawData() const {
   return static_cast<ImplType *>(impl)->data;
 }
@@ -562,6 +672,35 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<APInt> &values) const {
   values.assign(raw_begin(), raw_end());
 }
 
+DenseElementsAttr DenseIntElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APInt &)> mapping) const {
+  Type inType = getType();
+  size_t bitWidth = getDenseElementBitwidth(newElementType);
+
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType>()) {
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  } else if (inType.isa<UnrankedTensorType>()) {
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  }
+
+  assert(newArrayType && "Unhandled tensor type");
+
+  llvm::SmallVector<char, 8> elementData(APInt::getNumWords(bitWidth * size()) *
+                                         APInt::APINT_WORD_SIZE);
+
+  uint64_t elementIdx = 0;
+  for (auto value : *this) {
+    auto newInt = mapping(value);
+    assert(newInt.getBitWidth() == bitWidth);
+    writeBits(elementData.data(), elementIdx * bitWidth, newInt);
+    ++elementIdx;
+  }
+
+  return get(newArrayType, elementData);
+}
+
 //===----------------------------------------------------------------------===//
 // DenseFPElementsAttr
 //===----------------------------------------------------------------------===//
@@ -589,6 +728,34 @@ void DenseFPElementsAttr::getValues(SmallVectorImpl<APFloat> &values) const {
   values.assign(begin(), end());
 }
 
+DenseElementsAttr DenseFPElementsAttr::mapValues(
+    Type newElementType,
+    llvm::function_ref<APInt(const APFloat &)> mapping) const {
+  Type inType = getType();
+  size_t bitWidth = getDenseElementBitwidth(newElementType);
+
+  ShapedType newArrayType;
+  if (inType.isa<RankedTensorType>())
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  else if (inType.isa<UnrankedTensorType>())
+    newArrayType = RankedTensorType::get(getType().getShape(), newElementType);
+  else
+    assert(false && "Unhandled tensor type");
+
+  llvm::SmallVector<char, 80> elementData(
+      APInt::getNumWords(bitWidth * size()) * APInt::APINT_WORD_SIZE);
+
+  uint64_t elementIdx = 0;
+  for (auto value : *this) {
+    auto newInt = mapping(value);
+    assert(newInt.getBitWidth() == bitWidth);
+    writeBits(elementData.data(), elementIdx * bitWidth, newInt);
+    ++elementIdx;
+  }
+
+  return get(newArrayType, elementData);
+}
+
 /// Iterator access to the float element values.
 DenseFPElementsAttr::iterator DenseFPElementsAttr::begin() const {
   auto elementType = getType().getElementType().cast<FloatType>();