args.front(), zero);
}
- if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
- return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
- mlir::None);
+ if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) { // Branch taken when an FPToSIOp cast from srcTy to dstTy is compatible: round, clamp, then convert.
+ auto zero =
+ rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+ auto half =
+ rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));
+ // Clamp bounds: signed min/max of dstTy's bit width, materialized through getF32FloatAttr. NOTE(review): all constants here are f32 -- confirm args[0] is always f32 on this path.
+ auto intMin = rewriter.create<ConstantOp>(
+ loc, rewriter.getF32FloatAttr(
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+ // NOTE(review): the i32 maximum (2147483647) is not exactly representable in f32; confirm the upper bound still converts in range after the int -> float conversion.
+ auto intMax = rewriter.create<ConstantOp>(
+ loc, rewriter.getF32FloatAttr(
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue()));
+ // Rounding step: select (x - 0.5) when x < 0, otherwise (x + 0.5); presumably paired with a truncating fptosi to round half away from zero -- TODO confirm.
+ auto added = rewriter.create<AddFOp>(loc, args[0], half);
+ auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
+ auto negative =
+ rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
+ auto rounded =
+ rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);
+ // clampHelper is defined outside this excerpt; presumably limits `rounded` to [intMin, intMax] using ordered less-than compares -- verify against its definition.
+ auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
+ CmpFPredicate::OLT, rewriter);
+
+ return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped); // Convert the rounded, clamped value to dstTy.
+ }
// Casting to boolean, integers need to only be checked as not-equal to
// zero.
return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
mlir::None);
- if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
- return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
- mlir::None);
+ if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) { // Integer-to-integer cast when bitExtend is false: clamp, then truncate to dstTy.
+ auto intMin = rewriter.create<ConstantIntOp>(
+ loc,
+ APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue(),
+ srcTy.getIntOrFloatBitWidth());
+ // Both bounds hold dstTy's signed min/max but are created at srcTy's bit width, so they are compared against args[0] before the truncation.
+ auto intMax = rewriter.create<ConstantIntOp>(
+ loc,
+ APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+ .getSExtValue(),
+ srcTy.getIntOrFloatBitWidth());
+ // clampHelper is defined outside this excerpt; presumably limits args[0] to [intMin, intMax] using signed less-than compares -- verify against its definition.
+ auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
+ CmpIPredicate::slt, rewriter);
+ return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped); // Truncate the clamped value to dstTy.
+ }
}
(void)rewriter.notifyMatchFailure(
%20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
+ // CHECK: constant 0.000000e+00
+ // CHECK: constant 5.000000e-01
+ // CHECK: constant -2.14748365E+9
+ // CHECK: constant 2.14748365E+9
+ // CHECK: addf
+ // CHECK: subf
+ // CHECK: cmpf olt
+ // CHECK: select
+ // CHECK: cmpf olt
+ // CHECK: select
+ // CHECK: cmpf olt
+ // CHECK: select
// CHECK: fptosi
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
%18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
+ // CHECK: constant -32768
+ // CHECK: constant 32767
+ // CHECK: cmpi slt
+ // CHECK: select
+ // CHECK: cmpi slt
+ // CHECK: select
// CHECK: trunci
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>