[mlir][arith] Add arith.constant materialization helper
authorRahul Kayaith <rkayaith@gmail.com>
Sun, 16 Apr 2023 21:26:03 +0000 (17:26 -0400)
committerRahul Kayaith <rkayaith@gmail.com>
Thu, 20 Apr 2023 20:31:52 +0000 (16:31 -0400)
This adds `arith::ConstantOp::materialize`, which builds a constant from
an attribute and type only if it would result in a valid op. This is
useful for dialect `materializeConstant` hooks, and allows for removing
the previous `Attribute, Type` builder which was only used during
materialization.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D148491

12 files changed:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp

index 35f4d0761db9370b0349e4928013f93bafbce6e4..7b7b30e84ce2d987c2a693dd2914cc80242d3d36 100644 (file)
@@ -178,15 +178,15 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
   // splitting the Standard dialect.
   let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
 
-  let builders = [
-    OpBuilder<(ins "Attribute":$value, "Type":$type),
-    [{ build($_builder, $_state, type, value); }]>,
-  ];
-
   let extraClassDeclaration = [{
     /// Whether the constant op can be constructed with a particular value and
     /// type.
     static bool isBuildableWith(Attribute value, Type type);
+
+    /// Build the constant op with `value` and `type` if possible, otherwise
+    /// returns null.
+    static ConstantOp materialize(OpBuilder &builder, Attribute value,
+                                  Type type, Location loc);
   }];
 
   let hasFolder = 1;
index 0ffb258f3f48875e11657b6707e7a90fa468019c..bd9811095356b1ab8d951f110a508fafae5f41ea 100644 (file)
@@ -225,7 +225,7 @@ void AffineDialect::initialize() {
 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 /// A utility function to check if a value is defined at the top level of an
index 0a7b2c4f0b561a1dcac32f3e9ba151ef3c048870..7f2d79355fe0f0fc636b4bc4ffe0f089c60d2c42 100644 (file)
@@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() {
 Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
                                                     Attribute value, Type type,
                                                     Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return ConstantOp::materialize(builder, value, type, loc);
 }
index e203dbc847339014c2d9be6025e8d967e03c38ac..d3ca1987a17071fdbefe816ca42bff50d6c199af 100644 (file)
@@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
 }
 
+ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
+                                          Type type, Location loc) {
+  if (isBuildableWith(value, type))
+    return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
+  return nullptr;
+}
+
 OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
index ec56d93c6f1568660988c0e0b928707be755dfaf..0a2691a113f718e8ea3f28eded77971a671329dd 100644 (file)
@@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());
   }
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 #define GET_ATTRDEF_CLASSES
index 07ddc02c00f47854b08fb330264cc4bed85cb0b1..f9265b43eb379ee497f5afc0e260a0da2471a04d 100644 (file)
@@ -2109,5 +2109,5 @@ void LinalgDialect::getCanonicalizationPatterns(
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
index f3c8c5c06a087b03af0ccc3e9aebd859b8a8fec4..57e6e2a6c81e416fdbd68bc27aeb9931b524addc 100644 (file)
@@ -1688,8 +1688,8 @@ public:
       }
 
       // Create a constant scalar value from the splat constant.
-      Value scalarConstant = rewriter.create<arith::ConstantOp>(
-          def->getLoc(), constantAttr, constantAttr.getType());
+      Value scalarConstant =
+          rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
 
       SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp = rewriter.create<GenericOp>(
index 5c93f9f7017c97d55992e7e1add4cd4f242543d4..ae9dc08c745b4fdd35bdb4d526deda8c2dc4f8e4 100644 (file)
@@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
                                                   Location loc) {
-  return builder.create<arith::ConstantOp>(loc, value, type);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
index ee47547a1775b54c14f7e7b01b3d0936f5468a58..a828fb6a7a6798780561416413a2e4a5b2099cc8 100644 (file)
@@ -74,9 +74,7 @@ struct Wrapper {
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 //===----------------------------------------------------------------------===//
index 9af32fb5afe748848117462aeebcee24af7ca752..3417388d0bb9201621e3f8c1e7300338a3f2c633 100644 (file)
@@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
-  return nullptr;
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
index 0eca1843ea19f27dee7a2558bf6d3174f8a6c348..99382a375c1f47510155fae2007bc082689ed467 100644 (file)
@@ -38,8 +38,8 @@ using namespace mlir::tensor;
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, value, type);
+  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
+    return op;
   if (complex::ConstantOp::isBuildableWith(value, type))
     return builder.create<complex::ConstantOp>(loc, type,
                                                value.cast<ArrayAttr>());
index 89ca09911230948a960615d1676e340dc922a435..e214820c2f47f67ec6b908ac19c0b523556e3ac8 100644 (file)
@@ -280,7 +280,7 @@ void VectorDialect::initialize() {
 Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                               Attribute value, Type type,
                                               Location loc) {
-  return builder.create<arith::ConstantOp>(loc, type, value);
+  return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
 IntegerType vector::getVectorSubscriptType(Builder &builder) {