[mlir] Remove the dependency on StandardOps from FoldUtils
authorRiver Riddle <riddleriver@gmail.com>
Thu, 10 Dec 2020 22:13:37 +0000 (14:13 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 10 Dec 2020 22:13:57 +0000 (14:13 -0800)
OperationFolder currently uses ConstantOp as a backup when trying to materialize a constant after an operation is folded. This dependency isn't really useful or necessary given that dialects can/should provide a `materializeConstant` implementation.

Fixes PR#44866

Differential Revision: https://reviews.llvm.org/D92980

mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir

index c71360c..ef29ddc 100644 (file)
@@ -110,7 +110,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
-  if (type.isa<IndexType>())
+  if (ConstantOp::isBuildableWith(value, type))
     return builder.create<ConstantOp>(loc, type, value);
   return nullptr;
 }
index 074f71c..ba755a7 100644 (file)
@@ -13,7 +13,6 @@
 
 #include "mlir/Transforms/FoldUtils.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
@@ -60,11 +59,6 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
     assert(matchPattern(constOp, m_Constant()));
     return constOp;
   }
-
-  // If the dialect is unable to materialize a constant, check to see if the
-  // standard constant can be used.
-  if (ConstantOp::isBuildableWith(value, type))
-    return builder.create<ConstantOp>(loc, type, value);
   return nullptr;
 }
 
index c7e1b7f..eeff840 100644 (file)
@@ -178,6 +178,11 @@ void TestDialect::initialize() {
   allowUnknownOperations();
 }
 
+Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  return builder.create<TestOpConstant>(loc, type, value);
+}
+
 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
                           llvm::SetVector<Type> &stack) {
   StringRef typeTag;
index 1579e53..9008ee7 100644 (file)
@@ -23,6 +23,7 @@ include "TestInterfaces.td"
 def Test_Dialect : Dialect {
   let name = "test";
   let cppNamespace = "::mlir::test";
+  let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
index 5496209..0425cf8 100644 (file)
@@ -254,7 +254,7 @@ func @verifyUnitAttr() -> (i32, i32) {
 
 // CHECK-LABEL: testConstOp
 func @testConstOp() -> (i32) {
-  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
   %0 = "test.constant"() {value = 1 : i32} : () -> i32
 
   // CHECK-NEXT: return [[C0]]
@@ -263,7 +263,7 @@ func @testConstOp() -> (i32) {
 
 // CHECK-LABEL: testConstOpUsed
 func @testConstOpUsed() -> (i32) {
-  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
   %0 = "test.constant"() {value = 1 : i32} : () -> i32
 
   // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
@@ -275,7 +275,7 @@ func @testConstOpUsed() -> (i32) {
 
 // CHECK-LABEL: testConstOpReplaced
 func @testConstOpReplaced() -> (i32) {
-  // CHECK-NEXT: [[C0:%.+]] = constant 1
+  // CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
   %0 = "test.constant"() {value = 1 : i32} : () -> i32
   %1 = "test.constant"() {value = 2 : i32} : () -> i32
 
@@ -288,10 +288,10 @@ func @testConstOpReplaced() -> (i32) {
 
 // CHECK-LABEL: testConstOpMatchFailure
 func @testConstOpMatchFailure() -> (i64) {
-  // CHECK-DAG: [[C0:%.+]] = constant 1
+  // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
   %0 = "test.constant"() {value = 1 : i64} : () -> i64
 
-  // CHECK-DAG: [[C1:%.+]] = constant 2
+  // CHECK-DAG: [[C1:%.+]] = "test.constant"() {value = 2
   %1 = "test.constant"() {value = 2 : i64} : () -> i64
 
   // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
@@ -303,7 +303,7 @@ func @testConstOpMatchFailure() -> (i64) {
 
 // CHECK-LABEL: testConstOpMatchNonConst
 func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
-  // CHECK-DAG: [[C0:%.+]] = constant 1
+  // CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
   %0 = "test.constant"() {value = 1 : i32} : () -> i32
 
   // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], %arg0)