From 8b109bc2eae0d33a140982c02c77501932bfa394 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 6 Apr 2021 20:22:42 -0700 Subject: [PATCH] [mlir,shape] Add max/min folder for simple case When both arguments are the same for these ops, propagate this argument. --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 4 ++++ mlir/lib/Dialect/Shape/IR/Shape.cpp | 22 ++++++++++++++++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 0b8c26d..41e6f8a 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; + + let hasFolder = 1; } def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { @@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }]; + + let hasFolder = 1; } def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index bb7ed5c..388a3a5 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -938,6 +938,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result, } //===----------------------------------------------------------------------===// +// MaxOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaxOp::fold(llvm::ArrayRef operands) { + // If operands are equal, just propagate one. + if (lhs() == rhs()) + return lhs(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// MinOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinOp::fold(llvm::ArrayRef operands) { + // If operands are equal, just propagate one. + if (lhs() == rhs()) + return lhs(); + return nullptr; +} + +//===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index b0c12ea..86ac4c9 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> { %1 = tensor.cast %0 : tensor to tensor<3xindex> return %1 : tensor<3xindex> } + +// ---- + +// CHECK-LABEL: max_same_arg +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) +func @max_same_arg(%a: !shape.shape) -> !shape.shape { + %1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape + // CHECK: return %[[SHAPE]] + return %1 : !shape.shape +} + +// ---- + +// CHECK-LABEL: min_same_arg +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) +func @min_same_arg(%a: !shape.shape) -> !shape.shape { + %1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape + // CHECK: return %[[SHAPE]] + return %1 : !shape.shape +} -- 2.7.4