From: Rahul Kayaith Date: Sun, 16 Apr 2023 21:26:03 +0000 (-0400) Subject: [mlir][arith] Add arith.constant materialization helper X-Git-Tag: upstream/17.0.6~10880 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=00e3566d6c98f7c531be5140a614ca7fb3cc03a1;p=platform%2Fupstream%2Fllvm.git [mlir][arith] Add arith.constant materialization helper This adds `arith::ConstantOp::materialize`, which builds a constant from an attribute and type only if it would result in a valid op. This is useful for dialect `materializeConstant` hooks, and allows for removing the previous `Attribute, Type` builder which was only used during materialization. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D148491 --- diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 35f4d0761db9..7b7b30e84ce2 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -178,15 +178,15 @@ def Arith_ConstantOp : Op, - ]; - let extraClassDeclaration = [{ /// Whether the constant op can be constructed with a particular value and /// type. static bool isBuildableWith(Attribute value, Type type); + + /// Build the constant op with `value` and `type` if possible, otherwise + /// returns null. + static ConstantOp materialize(OpBuilder &builder, Attribute value, + Type type, Location loc); }]; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0ffb258f3f48..bd9811095356 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -225,7 +225,7 @@ void AffineDialect::initialize() { Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } /// A utility function to check if a value is defined at the top level of an diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index 0a7b2c4f0b56..7f2d79355fe0 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() { Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, value, type); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index e203dbc84733..d3ca1987a170 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { return value.isa(); } +ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (isBuildableWith(value, type)) + return builder.create(loc, cast(value)); + return nullptr; +} + OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp index ec56d93c6f15..0a2691a113f7 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, type, value.cast()); } - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 07ddc02c00f4..f9265b43eb37 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2109,5 +2109,5 @@ void LinalgDialect::getCanonicalizationPatterns( Operation *LinalgDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index f3c8c5c06a08..57e6e2a6c81e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1688,8 +1688,8 @@ public: } // Create a constant scalar value from the splat constant. - Value scalarConstant = rewriter.create( - def->getLoc(), constantAttr, constantAttr.getType()); + Value scalarConstant = + rewriter.create(def->getLoc(), constantAttr); SmallVector outputOperands = genericOp.getOutputs(); auto fusedOp = rewriter.create( diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 5c93f9f7017c..ae9dc08c745b 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, value, type); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index ee47547a1775..a828fb6a7a67 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -74,9 +74,7 @@ struct Wrapper { Operation *MemRefDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, value, type); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 9af32fb5afe7..3417388d0bb9 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, type, value.cast()); if (type.isa()) return builder.create(loc, type, value.cast()); - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 0eca1843ea19..99382a375c1f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -38,8 +38,8 @@ using namespace mlir::tensor; Operation *TensorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, value, type); + if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) + return op; if (complex::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value.cast()); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 89ca09911230..e214820c2f47 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -280,7 +280,7 @@ void VectorDialect::initialize() { Operation *VectorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } IntegerType vector::getVectorSubscriptType(Builder &builder) {