[mlir] Add folder for shape.add
authorJacques Pienaar <jpienaar@google.com>
Sat, 16 Oct 2021 00:30:17 +0000 (17:30 -0700)
committerJacques Pienaar <jpienaar@google.com>
Sat, 16 Oct 2021 00:30:17 +0000 (17:30 -0700)
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index 606f906..a05f7f3 100644 (file)
@@ -55,6 +55,8 @@ def Shape_AddOp : Shape_Op<"add",
     // InferTypeOpInterface
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+
+  let hasFolder = 1;
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
index d9ced56..f21fafe 100644 (file)
@@ -9,12 +9,14 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -436,6 +438,15 @@ bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
 }
 
+OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
+  // add(x, 0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
+
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a + b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
index c3e72a0..4a8839b 100644 (file)
@@ -1020,6 +1020,19 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
 
 // -----
 
+// Fold `add` for constant sizes.
+// CHECK-LABEL: @fold_add_size
+func @fold_add_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 2
+  %c3 = shape.const_size 3
+  %result = shape.add %c2, %c3 : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
 // Fold `mul` for constant sizes.
 // CHECK-LABEL: @fold_mul_size
 func @fold_mul_size() -> !shape.size {