[mlir][tosa] Add lowering for tosa.clz using scf::whileOp
authornatashaknk <natashaknk@google.com>
Thu, 9 Sep 2021 22:57:22 +0000 (15:57 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Thu, 9 Sep 2021 22:57:35 +0000 (15:57 -0700)
Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D109540

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 07b519a..558ba4d 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -310,6 +311,55 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
   }
 
+  // tosa::ClzOp
+  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+    int bitWidth = elementTy.getIntOrFloatBitWidth();
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    auto leadingZeros = rewriter.create<mlir::ConstantOp>(
+        loc, IntegerAttr::get(elementTy, bitWidth));
+
+    SmallVector<Value> operands = {args[0], leadingZeros, zero};
+    SmallVector<Type> types = {elementTy, elementTy, elementTy};
+
+    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+    Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
+    Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
+    // The conditional block of the while loop.
+    {
+      rewriter.setInsertionPointToStart(&whileOp.before().front());
+      Value input = before->getArgument(0);
+      Value zero = before->getArgument(2);
+
+      Value inputLargerThanZero =
+          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
+      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
+                                        before->getArguments());
+    }
+
+    // The body of the while loop: shift right until reaching a value of 0.
+    {
+      rewriter.setInsertionPointToStart(&whileOp.after().front());
+      Value input = after->getArgument(0);
+      Value leadingZeros = after->getArgument(1);
+
+      auto one = rewriter.create<mlir::ConstantOp>(
+          loc, IntegerAttr::get(elementTy, 1));
+      auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
+          loc, resultTypes, input, one);
+      auto leadingZerosMinusOne =
+          rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);
+
+      rewriter.create<scf::YieldOp>(
+          loc,
+          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
+    }
+
+    rewriter.setInsertionPointAfter(whileOp);
+    return whileOp->getResult(1);
+  }
+
   // tosa::LogicalAnd
   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
     return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -2905,6 +2955,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
       PointwiseConverter<tosa::ArithmeticRightShiftOp>,
+      PointwiseConverter<tosa::ClzOp>,
       PointwiseConverter<tosa::SelectOp>,
       PointwiseConverter<tosa::GreaterOp>,
       PointwiseConverter<tosa::GreaterEqualOp>,
index b89d3f6..232f85a 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -32,15 +33,16 @@ struct TosaToLinalgOnTensors
     : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect, math::MathDialect,
-                    StandardOpsDialect, tensor::TensorDialect>();
+    registry
+        .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect,
+                tensor::TensorDialect, scf::SCFDialect>();
   }
 
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           tensor::TensorDialect>();
+                           tensor::TensorDialect, scf::SCFDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
 
     // Not every TOSA op can be legalized to linalg.
index 9cf3eba..4209172 100644 (file)
@@ -357,37 +357,45 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: addi
   %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: scf.while
+  // CHECK: cmpi ne
+  // CHECK: scf.condition
+  // CHECK: shift_right_unsigned
+  // CHECK: subi
+  // CHECK: scf.yield
+  %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %14 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %15 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.select"(%14, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %17 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant -32768
@@ -397,27 +405,27 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: cmpi slt
   // CHECK: select
   // CHECK: trunci
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi sgt
   // CHECK: subi
   // CHECK: select
-  %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }
index f9e468e..631620d 100644 (file)
@@ -6497,6 +6497,7 @@ cc_library(
         ":LinalgOps",
         ":MathDialect",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":TensorDialect",
         ":TosaDialect",