[mlir] Remove the non-templated DenseElementsAttr::getSplatValue
authorRiver Riddle <riddleriver@gmail.com>
Tue, 9 Nov 2021 01:40:17 +0000 (01:40 +0000)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 9 Nov 2021 01:40:40 +0000 (01:40 +0000)
This predates the templated variant, and has been simply forwarding
to getSplatValue<Attribute> for some time. Removing this makes the
API a bit more uniform, and also helps prevent users from thinking
it is "cheap".

mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/Matchers.h
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

index 37da2eb..ba0fbe4 100644 (file)
@@ -353,7 +353,6 @@ public:
 
   /// Return the splat value for this attribute. This asserts that the attribute
   /// corresponds to a splat.
-  Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
   template <typename T>
   typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
                               std::is_same<Attribute, T>::value,
index 548cbe3..1cac3ea 100644 (file)
@@ -110,7 +110,7 @@ struct constant_int_op_binder {
     if (type.isa<VectorType, RankedTensorType>()) {
       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
         return attr_value_binder<IntegerAttr>(bind_value)
-            .match(splatAttr.getSplatValue());
+            .match(splatAttr.getSplatValue<Attribute>());
       }
     }
     return false;
index b4ea696..521b3fc 100644 (file)
@@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering
       // For scalar memrefs, the global variable created is of the element type,
       // so unpack the elements attribute to extract the value.
       if (type.getRank() == 0)
-        initialValue = elementsAttr.getValues<Attribute>()[0];
+        initialValue = elementsAttr.getSplatValue<Attribute>();
     }
 
     uint64_t alignment = global.alignment().getValueOr(0);
index b97a046..a9f3c7d 100644 (file)
@@ -349,7 +349,8 @@ static void convertConstantOp(arith::ConstantOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   assert(constantSupportsMMAMatrixType(op));
   OpBuilder b(op);
-  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+  Attribute splat =
+      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
   auto scalarConstant =
       b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
   const char *fragType = inferFragType(op);
index 8831a8d..566b3a2 100644 (file)
@@ -1574,7 +1574,7 @@ static bool isZeroAttribute(Attribute value) {
   if (auto fpValue = value.dyn_cast<FloatAttr>())
     return fpValue.getValue().isZero();
   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
-    return isZeroAttribute(splatValue.getSplatValue());
+    return isZeroAttribute(splatValue.getSplatValue<Attribute>());
   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
index bda4ee7..703bc9c 100644 (file)
@@ -1395,7 +1395,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (operands[0].getType().isIntOrIndexOrFloat())
     return DenseElementsAttr::get(vectorType, operands[0]);
   if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
-    return DenseElementsAttr::get(vectorType, attr.getSplatValue());
+    return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
 }
 
@@ -2212,7 +2212,7 @@ public:
     if (!dense)
       return failure();
     auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
-                                          dense.getSplatValue());
+                                          dense.getSplatValue<Attribute>());
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                    newAttr);
     return success();
@@ -3670,8 +3670,9 @@ public:
     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
     if (!dense)
       return failure();
-    auto newAttr = DenseElementsAttr::get(
-        shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
+    auto newAttr =
+        DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
+                               dense.getSplatValue<Attribute>());
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
     return success();
   }
index 39b83c2..db0e7c2 100644 (file)
@@ -139,7 +139,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
   if (denseElementsAttr.isSplat() &&
       (type.isa<VectorType>() || hasVectorElementType)) {
     llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
-        innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
+        innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
         moduleTranslation, /*isTopLevel=*/false);
     llvm::Constant *splatVector =
         llvm::ConstantDataVector::getSplat(0, splatValue);
@@ -254,8 +254,9 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
         isa<llvm::ArrayType, llvm::VectorType>(elementType);
     llvm::Constant *child = getLLVMConstant(
         elementType,
-        elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc,
-        moduleTranslation, false);
+        elementTypeSequential ? splatAttr
+                              : splatAttr.getSplatValue<Attribute>(),
+        loc, moduleTranslation, false);
     if (!child)
       return nullptr;
     if (llvmType->isVectorTy())