[MLIR][Shape] Fold `shape.mul`
authorFrederik Gossen <frgossen@google.com>
Fri, 24 Jul 2020 13:29:51 +0000 (13:29 +0000)
committerFrederik Gossen <frgossen@google.com>
Fri, 24 Jul 2020 13:30:45 +0000 (13:30 +0000)
Implement constant folding for `shape.mul`.

Differential Revision: https://reviews.llvm.org/D84438

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index 425cf91..797dc0b 100644 (file)
@@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
   }];
 
   let verifier = [{ return ::verify(*this); }];
+  let hasFolder = 1;
 }
 
 def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
index 2f64130..d2b0dbd 100644 (file)
@@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) {
   return success();
 }
 
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return nullptr;
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return nullptr;
+  APInt folded = lhs.getValue() * rhs.getValue();
+  Type indexTy = IndexType::get(getContext());
+  return IntegerAttr::get(indexTy, folded);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
index b4dca5e..577656a 100644 (file)
@@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1
 }
+
+// -----
+
+// Fold `mul` for constant sizes.
+// CHECK-LABEL: @fold_mul_size
+func @fold_mul_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 2
+  %c3 = shape.const_size 3
+  %result = shape.mul %c2, %c3 : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// Fold `mul` for constant indices.
+// CHECK-LABEL: @fold_mul_index
+func @fold_mul_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 6 : index
+  // CHECK: return %[[RESULT]] : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %result = shape.mul %c2, %c3 : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `mul` for mixed constants.
+// CHECK-LABEL: @fold_mul_mixed
+func @fold_mul_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 6
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 2
+  %c3 = constant 3 : index
+  %result = shape.mul %c2, %c3 : !shape.size, index -> !shape.size
+  return %result : !shape.size
+}
+