From: Robert Suderman Date: Wed, 18 Aug 2021 18:55:54 +0000 (-0700) Subject: [mlir][tosa] Fix clamp to restrict only within valid bitwidth range X-Git-Tag: upstream/15.0.7~33576 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=76c9712196906a1c5c1598e196b6abed139f090e;p=platform%2Fupstream%2Fllvm.git [mlir][tosa] Fix clamp to restrict only within valid bitwidth range Its possible for the clamp to have invalid min/max values on its range. To fix this we validate the range of the min/max and clamp to a valid range. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D108256 --- diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ca64e5d..f600c27 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -428,12 +428,32 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } if (isa(op) && elementTy.isa()) { - auto min = createConstFromIntAttribute(op, "min_int", elementTy, - rewriter); - auto max = createConstFromIntAttribute(op, "max_int", elementTy, - rewriter); - return clampHelper(loc, args[0], min, max, CmpIPredicate::slt, - rewriter); + auto intTy = elementTy.cast(); + int32_t min = static_cast( + op->getAttr("min_int").cast().getValue().getSExtValue()); + int32_t max = static_cast( + op->getAttr("max_int").cast().getValue().getSExtValue()); + + if (intTy.isUnsignedInteger()) { + min = std::max(min, 0); + max = std::min( + max, + APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue()); + } else { + min = std::max( + min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); + max = std::min( + max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); + } + + auto minVal = + rewriter.create(loc, min, intTy.getIntOrFloatBitWidth()); + auto maxVal = + rewriter.create(loc, max, intTy.getIntOrFloatBitWidth()); + return clampHelper(loc, args[0], minVal, maxVal, + CmpIPredicate::slt, rewriter); } // tosa::ReluNOp diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 88906b7..99c33d9 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -404,6 +404,31 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () { // ----- +// CHECK-LABEL: @test_i8 +func @test_i8(%arg0: tensor<1xi8>) -> () { + // CHECK: linalg.generic + // CHECK-DAG: %[[C127:.+]] = constant -127 + // CHECK-DAG: %[[C126:.+]] = constant 126 + // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]] + // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]] + // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1 + // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]] + %0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> + + // CHECK: linalg.generic + // CHECK-DAG: %[[C128:.+]] = constant -128 + // CHECK-DAG: %[[C127:.+]] = constant 127 + // CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]] + // CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]] + // CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1 + // CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]] + %1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> + + return +} + +// ----- + // CHECK-LABEL: @test_bool func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () { // CHECK: linalg.generic