[mlir][tosa] Fix clamp to restrict only within valid bitwidth range
authorRobert Suderman <suderman@google.com>
Wed, 18 Aug 2021 18:55:54 +0000 (11:55 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Wed, 18 Aug 2021 19:14:01 +0000 (12:14 -0700)
Its possible for the clamp to have invalid min/max values on its range. To fix
this we validate the range of the min/max and clamp to a valid range.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108256

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index ca64e5d..f600c27 100644 (file)
@@ -428,12 +428,32 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   }
 
   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
-    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
-                                                    rewriter);
-    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
-                                                    rewriter);
-    return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
-                                     rewriter);
+    auto intTy = elementTy.cast<IntegerType>();
+    int32_t min = static_cast<int32_t>(
+        op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
+    int32_t max = static_cast<int32_t>(
+        op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
+
+    if (intTy.isUnsignedInteger()) {
+      min = std::max<int32_t>(min, 0);
+      max = std::min<int32_t>(
+          max,
+          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
+    } else {
+      min = std::max<int32_t>(
+          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                   .getSExtValue());
+      max = std::min<int32_t>(
+          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                   .getSExtValue());
+    }
+
+    auto minVal =
+        rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
+    auto maxVal =
+        rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
+    return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
+                                     CmpIPredicate::slt, rewriter);
   }
 
   // tosa::ReluNOp
index 88906b7..99c33d9 100644 (file)
@@ -404,6 +404,31 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_i8
+func @test_i8(%arg0: tensor<1xi8>) -> () {
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C127:.+]] = constant -127
+  // CHECK-DAG: %[[C126:.+]] = constant 126
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]]
+  %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[C128:.+]] = constant -128
+  // CHECK-DAG: %[[C127:.+]] = constant 127
+  // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]]
+  // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]]
+  // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1
+  // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]]
+  %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_bool
 func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
   // CHECK: linalg.generic