From 670ae4b6da874270aa0cd8ab32120c17b2eadb95 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 24 Jul 2020 13:29:51 +0000 Subject: [PATCH] [MLIR][Shape] Fold `shape.mul` Implement constant folding for `shape.mul`. Differential Revision: https://reviews.llvm.org/D84438 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 1 + mlir/lib/Dialect/Shape/IR/Shape.cpp | 12 ++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 40 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 425cf91..797dc0b 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { }]; let verifier = [{ return ::verify(*this); }]; + let hasFolder = 1; } def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 2f64130..d2b0dbd 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) { return success(); } +OpFoldResult MulOp::fold(ArrayRef operands) { + auto lhs = operands[0].dyn_cast_or_null(); + if (!lhs) + return nullptr; + auto rhs = operands[1].dyn_cast_or_null(); + if (!rhs) + return nullptr; + APInt folded = lhs.getValue() * rhs.getValue(); + Type indexTy = IndexType::get(getContext()); + return IntegerAttr::get(indexTy, folded); +} + //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index b4dca5e..577656a 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 { %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape return %result : i1 } + +// ----- + +// Fold `mul` for constant sizes. +// CHECK-LABEL: @fold_mul_size +func @fold_mul_size() -> !shape.size { + // CHECK: %[[RESULT:.*]] = shape.const_size 6 + // CHECK: return %[[RESULT]] : !shape.size + %c2 = shape.const_size 2 + %c3 = shape.const_size 3 + %result = shape.mul %c2, %c3 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +} + +// ----- + +// Fold `mul` for constant indices. +// CHECK-LABEL: @fold_mul_index +func @fold_mul_index() -> index { + // CHECK: %[[RESULT:.*]] = constant 6 : index + // CHECK: return %[[RESULT]] : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %result = shape.mul %c2, %c3 : index, index -> index + return %result : index +} + +// ----- + +// Fold `mul` for mixed constants. +// CHECK-LABEL: @fold_mul_mixed +func @fold_mul_mixed() -> !shape.size { + // CHECK: %[[RESULT:.*]] = shape.const_size 6 + // CHECK: return %[[RESULT]] : !shape.size + %c2 = shape.const_size 2 + %c3 = constant 3 : index + %result = shape.mul %c2, %c3 : !shape.size, index -> !shape.size + return %result : !shape.size +} + -- 2.7.4