[mlir][tosa] Use math.ctlz intrinsic for tosa.clz
author Robert Suderman <suderman@google.com>
Mon, 16 May 2022 18:08:49 +0000 (11:08 -0700)
committer Rob Suderman <suderman@google.com>
Mon, 16 May 2022 18:31:35 +0000 (11:31 -0700)
We were counting leading zeros bit by bit with a custom loop for the clz
instruction. The Math dialect now has an intrinsic to do this in one
instruction. Migrated to this intrinsic and fixed a minor bug in the
math-to-llvm conversion for the intrinsic.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D125592

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index ea7a13b..189680c 100644 (file)
@@ -74,8 +74,8 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
         [&](Type llvm1DVectorTy, ValueRange operands) {
           LLVM::ConstantOp zero =
               rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
-          return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
-                                                     operands[0], zero);
+          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
+                                         zero);
         },
         rewriter);
   }
index 3fd69c5..5497e93 100644 (file)
@@ -259,54 +259,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
 
   // tosa::ClzOp
   if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
-    int bitWidth = elementTy.getIntOrFloatBitWidth();
-    auto zero =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
-    auto leadingZeros = rewriter.create<arith::ConstantOp>(
-        loc, IntegerAttr::get(elementTy, bitWidth));
-
-    SmallVector<Value> operands = {args[0], leadingZeros, zero};
-    SmallVector<Type> types = {elementTy, elementTy, elementTy};
-    SmallVector<Location> locations = {loc, loc, loc};
-
-    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
-    Block *before =
-        rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
-    Block *after =
-        rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
-
-    // The conditional block of the while loop.
-    {
-      rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
-      Value input = before->getArgument(0);
-      Value zero = before->getArgument(2);
-
-      Value inputLargerThanZero = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::ne, input, zero);
-      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
-                                        before->getArguments());
-    }
-
-    // The body of the while loop: shift right until reaching a value of 0.
-    {
-      rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
-      Value input = after->getArgument(0);
-      Value leadingZeros = after->getArgument(1);
-
-      auto one = rewriter.create<arith::ConstantOp>(
-          loc, IntegerAttr::get(elementTy, 1));
-      auto shifted =
-          rewriter.create<arith::ShRUIOp>(loc, resultTypes, input, one);
-      auto leadingZerosMinusOne =
-          rewriter.create<arith::SubIOp>(loc, resultTypes, leadingZeros, one);
-
-      rewriter.create<scf::YieldOp>(
-          loc,
-          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
-    }
-
-    rewriter.setInsertionPointAfter(whileOp);
-    return whileOp->getResult(1);
+    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
   }
 
   // tosa::LogicalAnd
index 8864a85..775e7ba 100644 (file)
@@ -366,12 +366,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: arith.addi
   %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
-  // CHECK: scf.while
-  // CHECK: arith.cmpi ne
-  // CHECK: scf.condition
-  // CHECK: arith.shrui
-  // CHECK: arith.subi
-  // CHECK: scf.yield
+  // CHECK: math.ctlz
   %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic