[mlir][tosa] Add tosa.div integer lowering to linalg.generic.
authornatashaknk <natashaknk@google.com>
Thu, 13 May 2021 20:15:57 +0000 (13:15 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Thu, 13 May 2021 20:16:00 +0000 (13:16 -0700)
Lowering div elementwise op to the linalg dialect. Since tosa only supports integer division, that is the only version that is currently implemented.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D102430

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 66e1747..3934690 100644 (file)
@@ -126,6 +126,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
   }
 
+  // tosa::DivOp
+  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);
+
   // tosa::ReciprocalOp
   if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
     auto one =
@@ -2335,6 +2339,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::AddOp>,
       PointwiseConverter<tosa::SubOp>,
       PointwiseConverter<tosa::MulOp>,
+      PointwiseConverter<tosa::DivOp>,
       PointwiseConverter<tosa::NegateOp>,
       PointwiseConverter<tosa::PowOp>,
       PointwiseConverter<tosa::ReciprocalOp>,
index 46841a1..1e250c8 100644 (file)
@@ -295,33 +295,37 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
+  // CHECK: divi
+  %4 = "tosa.div"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
+  // CHECK: linalg.generic
   // CHECK: [[ZERO:%.+]] = constant 0
   // CHECK: subi [[ZERO]], %arg1
-  %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %5 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: and
-  %5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %6 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: or
-  %6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %7 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: xor
-  %7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %8 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_left
-  %8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %9 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_unsigned
-  %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %10 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_signed
-  %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant 1
@@ -335,39 +339,39 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: and
   // CHECK: zexti
   // CHECK: addi
-  %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant -32768
@@ -377,24 +381,24 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: cmpi slt
   // CHECK: select
   // CHECK: trunci
-  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   return
 }