[mlir][arith] Add arith.constant materialization helper
authorRahul Kayaith <rkayaith@gmail.com>
Sun, 16 Apr 2023 21:26:03 +0000 (17:26 -0400)
committerRahul Kayaith <rkayaith@gmail.com>
Thu, 20 Apr 2023 20:31:52 +0000 (16:31 -0400)
This adds `arith::ConstantOp::materialize`, which builds a constant from
an attribute and type only if it would result in a valid op. This is
useful for dialect `materializeConstant` hooks, and allows for removing
the previous `Attribute, Type` builder which was only used during
materialization.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D148491

12 files changed:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp

index 35f4d07..7b7b30e 100644 (file)
@@ -178,15 +178,15 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
   // splitting the Standard dialect.
   let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
 
-  let builders = [
-    OpBuilder<(ins "Attribute":$value, "Type":$type),
-    [{ build($_builder, $_state, type, value); }]>,
-  ];
-
   let extraClassDeclaration = [{
     /// Whether the constant op can be constructed with a particular value and
     /// type.
     static bool isBuildableWith(Attribute value, Type type);
+
+    /// Build the constant op with `value` and `type` if possible, otherwise
+    /// returns null.
+    static ConstantOp materialize(OpBuilder &builder, Attribute value,
+                                  Type type, Location loc);
   }];
 
   let hasFolder = 1;
index 0ffb258..bd98110 100644 (file)
@@ -225,7 +225,7 @@ void AffineDialect::initialize() {
 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 /// A utility function to check if a value is defined at the top level of an
index 0a7b2c4..7f2d793 100644 (file)
@@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() {
 Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
                                                     Attribute value, Type type,
                                                     Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return ConstantOp::materialize(builder, value, type, loc);
 }
index e203dbc..d3ca198 100644 (file)
@@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
 }
 
+ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
+                                          Type type, Location loc) {
+  if (isBuildableWith(value, type))
+    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+  return nullptr;
+}
+
 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
index ec56d93..0a2691a 100644 (file)
@@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());
   }
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 #define GET_ATTRDEF_CLASSES
index 07ddc02..f9265b4 100644 (file)
@@ -2109,5 +2109,5 @@ void LinalgDialect::getCanonicalizationPatterns(
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
index f3c8c5c..57e6e2a 100644 (file)
@@ -1688,8 +1688,8 @@ public:
       }
 
       // Create a constant scalar value from the splat constant.
-      Value scalarConstant = rewriter.create<arith::ConstantOp>(
-          def->getLoc(), constantAttr, constantAttr.getType());
+      Value scalarConstant =
+          rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
 
       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp = rewriter.create<GenericOp>(
index 5c93f9f..ae9dc08 100644 (file)
@@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
index ee47547..a828fb6 100644 (file)
@@ -74,9 +74,7 @@ struct Wrapper {
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 //===----------------------------------------------------------------------===//
index 9af32fb..3417388 100644 (file)
@@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
index 0eca184..99382a3 100644 (file)
@@ -38,8 +38,8 @@ using namespace mlir::tensor;
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
+  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
+    return op;
   if (complex::ConstantOp::isBuildableWith(value, type))
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());
index 89ca099..e214820 100644 (file)
@@ -280,7 +280,7 @@ void VectorDialect::initialize() {
 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 IntegerType vector::getVectorSubscriptType(Builder &builder) {