if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);
+ // tosa::ArithmeticRightShiftOp: emits a signed right shift; when the "round" attribute is true, adds 1 to the result if input2 > 0 and bit (input2 - 1) of input1 is set.
+ if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
+ auto result =
+ rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
+ auto round = op->getAttr("round").cast<BoolAttr>().getValue();
+ if (!round) {
+ return result;
+ }
+
+ Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
+ auto one =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto i1one =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
+
+ // Checking that input2 > 0 (signed comparison against zero), so the rounding term is dropped for zero or negative shift amounts
+ auto shiftValueGreaterThanZero =
+ rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);
+
+ // Checking that bit (input2 - 1) of input1 is 1: shift input1 right by (input2 - 1), truncate to i1, and mask with 1
+ auto subtract =
+ rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
+ auto shifted = rewriter
+ .create<mlir::SignedShiftRightOp>(loc, resultTypes,
+ args[0], subtract)
+ ->getResults();
+ auto truncated =
+ rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
+ auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);
+
+ auto shouldRound = rewriter.create<mlir::AndOp>(
+ loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+ auto extended =
+ rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
+ return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
+ }
+
// tosa::LogicalAndOp
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
args[1]);
+ // tosa::EqualOp: float element types lower to an ordered-equal CmpFOp (OEQ); signless integer element types lower to CmpIOp with the eq predicate.
+ if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
+ args[1]);
+
+ if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
+ return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
+ args[1]);
+
// tosa::SelectOp
if (isa<tosa::SelectOp>(op)) {
elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
PointwiseConverter<tosa::CastOp>,
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
+ PointwiseConverter<tosa::ArithmeticRightShiftOp>,
PointwiseConverter<tosa::SelectOp>,
PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
+ PointwiseConverter<tosa::EqualOp>,
PointwiseConverter<tosa::MaximumOp>,
PointwiseConverter<tosa::MinimumOp>,
PointwiseConverter<tosa::CeilOp>,
%11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
// CHECK: linalg.generic
+ // CHECK: cmpf
+ %12 = "tosa.equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+ // CHECK: linalg.generic
// CHECK: select
- %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: ceil
- %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %16 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: floor
- %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: cmpf
// CHECK: select
- %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+ %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: negf
// CHECK: exp
// CHECK: addf
// CHECK: divf
- %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: fptosi
- %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+ %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpf
- %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+ %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: fptrunc
- %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+ %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
// CHECK: linalg.generic
// CHECK: yield
- %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: divf
- %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
return
}
%9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
+ // CHECK: shift_right_signed
+ %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
+ // CHECK: constant 1
+ // CHECK: constant 0
+ // CHECK: constant true
+ // CHECK: cmpi
+ // CHECK: subi
+ // CHECK: shift_right_signed
+ // CHECK: trunci
+ // CHECK: and
+ // CHECK: and
+ // CHECK: zexti
+ // CHECK: addi
+ %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+ // CHECK: linalg.generic
// CHECK: cmpi
- %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: select
- %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: trunci
- %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+ %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
// CHECK: linalg.generic
// CHECK: yield
- %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+ %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: sexti
- %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+ %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
- %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+ %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: sitofp
- %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+ %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
return
}