#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
+OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
+ // add(x, 0) -> x
+ if (matchPattern(rhs(), m_Zero()))
+ return lhs();
+
+ return constFoldBinaryOp<IntegerAttr>(operands,
+ [](APInt a, APInt b) { return a + b; });
+}
+
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
// -----
+// Fold `add` for constant sizes.
+// CHECK-LABEL: @fold_add_size
+func @fold_add_size() -> !shape.size {
+ // CHECK: %[[RESULT:.*]] = shape.const_size 5
+ // CHECK: return %[[RESULT]] : !shape.size
+ %c2 = shape.const_size 2
+ %c3 = shape.const_size 3
+ %result = shape.add %c2, %c3 : !shape.size, !shape.size -> !shape.size
+ return %result : !shape.size
+}
+
+// -----
+
// Fold `mul` for constant sizes.
// CHECK-LABEL: @fold_mul_size
func @fold_mul_size() -> !shape.size {