[mlir][tosa] Add clamp + clamp as single clamp canonicalization
authornot-jenni <jennik@google.com>
Sat, 22 Jan 2022 00:16:29 +0000 (16:16 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Sat, 22 Jan 2022 00:24:43 +0000 (16:24 -0800)
When 2 clamp ops are in a row, they can be canonicalized into a single clamp
that uses the most constrained range

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D117934

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir

index f93e5b2..af8fa30 100644 (file)
@@ -526,9 +526,40 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   }
 };
 
+struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+
+    Operation *definingOp = input.getDefiningOp();
+    if (!definingOp)
+      return failure();
+
+    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
+      auto min_fp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
+      auto max_fp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
+
+      auto min_int = std::max(op.min_int(), clampOp.min_int());
+      auto max_int = std::min(op.max_int(), clampOp.max_int());
+
+      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+          op, op.getType(), clampOp.input(),
+          rewriter.getI64IntegerAttr(min_int),
+          rewriter.getI64IntegerAttr(max_int), rewriter.getF32FloatAttr(min_fp),
+          rewriter.getF32FloatAttr(max_fp));
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 void ClampOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
   results.insert<ClampIsNoOp>(context);
+  results.insert<ClampClampOptimization>(context);
 }
 
 //===----------------------------------------------------------------------===//
index c1e828c..41303ee 100644 (file)
@@ -98,6 +98,16 @@ func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
 
 // -----
 
+// CHECK-LABEL: @clamp_twice_is_single_clamp
+func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: "tosa.clamp"(%arg0) {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  %0 = "tosa.clamp"(%arg0) {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
+  %1 = "tosa.clamp"(%0) {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} :  (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @concat_fold
 func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0