auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
- auto clamp = clampHelper<arith::CmpIOp>(
- loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
+ auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OGT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
- return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
- arith::CmpFPredicate::OLT, rewriter);
+ return clampFloatHelper(loc, args[0], min, max, rewriter);
}
if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal = rewriter.create<arith::ConstantIntOp>(
loc, max, intTy.getIntOrFloatBitWidth());
- return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
- arith::CmpIPredicate::slt, rewriter);
+ return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
}
// tosa::ReluNOp
APFloat::rmNearestTiesToEven, &losesInfo);
auto n = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
- return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
- arith::CmpFPredicate::OLT, rewriter);
+ return clampFloatHelper(loc, args[0], zero, n, rewriter);
}
if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
- return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
- arith::CmpIPredicate::slt, rewriter);
+ return clampIntHelper(loc, args[0], zero, n, rewriter);
}
// tosa::SigmoidOp
auto rounded =
rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
- auto clamped = clampHelper<arith::CmpFOp>(
- loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
+ auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
}
.getSExtValue(),
srcTy.getIntOrFloatBitWidth());
- auto clamped = clampHelper<arith::CmpIOp>(
- loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
+ auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter);
return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
}
}
}
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
}
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
- auto predicate = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OGT, args[0], args[1]);
- return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
+ return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getI32IntegerAttr(intMax));
- value = clampHelper<arith::CmpIOp>(
- nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
- nestedBuilder);
+ value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
+ nestedBuilder);
if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
// Clamp the coordinates to be within the bounds of the input image.
- iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
- ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
+ iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
+ ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);
// Read the value from the input array.
iy =
Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
- y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
- y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
- arith::CmpIPredicate::slt, rewriter);
+ y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
+ y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);
- x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
- x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
- arith::CmpIPredicate::slt, rewriter);
+ x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
+ x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);
y0 =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
%13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.maxf
%14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
%15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
%17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
%18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
- // CHECK: arith.cmpf
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
%19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: arith.cmpf olt
// CHECK: select
- // CHECK: arith.cmpf olt
- // CHECK: select
- // CHECK: arith.cmpf olt
- // CHECK: select
+ // CHECK: arith.minf
+ // CHECK: arith.maxf
// CHECK: arith.fptosi
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
// CHECK-LABEL: @test_i8
func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C127:.+]] = arith.constant -127
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]]
+ // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1
+ // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]]
%0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C128:.+]] = arith.constant -128
// CHECK-DAG: %[[C127:.+]] = arith.constant 127
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]]
+ // CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1
+ // CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]]
%1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: f16,
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
- // CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
- // CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
- // CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
- // CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
+ // CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C0]]
+ // CHECK-DAG: %[[MAX:.+]] = arith.maxf %[[MIN]], %[[C6]]
%0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
return
// CHECK: arith.constant 3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
- // CHECK: arith.cmpf olt
- // CHECK: select
+ // CHECK: arith.minf
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: arith.constant -3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
- // CHECK: arith.cmpf ogt
- // CHECK: select
+ // CHECK: arith.maxf
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
- // CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
- // CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32
- // CHECK: linalg.yield %[[RES]] : f32
+ // CHECK: %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32
+ // CHECK: linalg.yield %[[MAX]] : f32
// CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return