[mlir][tosa] Add lowerings for tosa.equal and tosa.arithmetic_right_shift
authornatashaknk <natashaknk@google.com>
Tue, 4 May 2021 01:08:14 +0000 (18:08 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Tue, 4 May 2021 01:26:49 +0000 (18:26 -0700)
Lowerings equal and arithmetic_right_shift for elementwise ops to linalg dialect using linalg.generic

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D101804

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 8107d97..ee4f29c 100644 (file)
@@ -227,6 +227,45 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
     return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);
 
+  // tosa::ArithmeticRightShiftOp
+  if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
+    auto result =
+        rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, args);
+    auto round = op->getAttr("round").cast<BoolAttr>().getValue();
+    if (!round) {
+      return result;
+    }
+
+    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
+    auto one =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    auto i1one =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
+
+    // Checking that input2 != 0
+    auto shiftValueGreaterThanZero =
+        rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[1], zero);
+
+    // Checking for the last bit of input1 to be 1
+    auto subtract =
+        rewriter.create<mlir::SubIOp>(loc, resultTypes, args[1], one);
+    auto shifted = rewriter
+                       .create<mlir::SignedShiftRightOp>(loc, resultTypes,
+                                                         args[0], subtract)
+                       ->getResults();
+    auto truncated =
+        rewriter.create<mlir::TruncateIOp>(loc, i1Ty, shifted, mlir::None);
+    auto isInputOdd = rewriter.create<mlir::AndOp>(loc, i1Ty, truncated, i1one);
+
+    auto shouldRound = rewriter.create<mlir::AndOp>(
+        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
+    auto extended =
+        rewriter.create<ZeroExtendIOp>(loc, resultTypes, shouldRound);
+    return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
+  }
+
   // tosa::LogicalAnd
   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
     return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -284,6 +323,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                          args[1]);
 
+  // tosa::EqualOp
+  if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OEQ, args[0],
+                                         args[1]);
+
+  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
+    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0],
+                                         args[1]);
+
   // tosa::SelectOp
   if (isa<tosa::SelectOp>(op)) {
     elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
@@ -2202,9 +2250,11 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::CastOp>,
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
+      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
       PointwiseConverter<tosa::SelectOp>,
       PointwiseConverter<tosa::GreaterOp>,
       PointwiseConverter<tosa::GreaterEqualOp>,
+      PointwiseConverter<tosa::EqualOp>,
       PointwiseConverter<tosa::MaximumOp>,
       PointwiseConverter<tosa::MinimumOp>,
       PointwiseConverter<tosa::CeilOp>,
index 775ac6d..9bd03f1 100644 (file)
@@ -152,64 +152,68 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
+  // CHECK: cmpf
+  %12 = "tosa.equal"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
+
+  // CHECK: linalg.generic
   // CHECK: select
-  %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %13 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %14 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: ceil
-  %15 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %16 = "tosa.ceil"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: floor
-  %16 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+  %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: cmpf
   // CHECK: select
-  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: negf
   // CHECK: exp
   // CHECK: addf
   // CHECK: divf
-  %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: fptosi
-  %20 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
+  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpf
-  %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
+  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: fptrunc
-  %22 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
+  %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: divf
-  %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -286,57 +290,75 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
+  // CHECK: shift_right_signed
+  %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: constant 1
+  // CHECK: constant 0
+  // CHECK: constant true
+  // CHECK: cmpi
+  // CHECK: subi
+  // CHECK: shift_right_signed
+  // CHECK: trunci
+  // CHECK: and
+  // CHECK: and
+  // CHECK: zexti
+  // CHECK: addi
+  %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
   // CHECK: cmpi
-  %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: trunci
-  %17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   return
 }