#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
}
+ // tosa::ClzOp
+ if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+ int bitWidth = elementTy.getIntOrFloatBitWidth();
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto leadingZeros = rewriter.create<mlir::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, bitWidth));
+
+ SmallVector<Value> operands = {args[0], leadingZeros, zero};
+ SmallVector<Type> types = {elementTy, elementTy, elementTy};
+
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+ Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
+ Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
+ // The conditional block of the while loop.
+ {
+ rewriter.setInsertionPointToStart(&whileOp.before().front());
+ Value input = before->getArgument(0);
+ Value zero = before->getArgument(2);
+
+ Value inputLargerThanZero =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
+ rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
+ before->getArguments());
+ }
+
+ // The body of the while loop: shift right until reaching a value of 0.
+ {
+ rewriter.setInsertionPointToStart(&whileOp.after().front());
+ Value input = after->getArgument(0);
+ Value leadingZeros = after->getArgument(1);
+
+ auto one = rewriter.create<mlir::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 1));
+ auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
+ loc, resultTypes, input, one);
+ auto leadingZerosMinusOne =
+ rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);
+
+ rewriter.create<scf::YieldOp>(
+ loc,
+ ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
+ }
+
+ rewriter.setInsertionPointAfter(whileOp);
+ return whileOp->getResult(1);
+ }
+
// tosa::LogicalAnd
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
PointwiseConverter<tosa::ArithmeticRightShiftOp>,
+ PointwiseConverter<tosa::ClzOp>,
PointwiseConverter<tosa::SelectOp>,
PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
: public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
public:
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect, math::MathDialect,
- StandardOpsDialect, tensor::TensorDialect>();
+ registry
+ .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
- tensor::TensorDialect>();
+ tensor::TensorDialect, scf::SCFDialect>();
target.addIllegalDialect<tosa::TosaDialect>();
// Not every TOSA op can be legalized to linalg.
// CHECK: addi
%12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.while
+ // CHECK: cmpi ne
+ // CHECK: scf.condition
+ // CHECK: shift_right_unsigned
+ // CHECK: subi
+ // CHECK: scf.yield
+ %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
// CHECK: linalg.generic
// CHECK: cmpi
- %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %14 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %15 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: select
- %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %16 = "tosa.select"(%14, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %17 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %18 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: constant -32768
// CHECK: cmpi slt
// CHECK: select
// CHECK: trunci
- %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+ %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
// CHECK: linalg.generic
// CHECK: sexti
- %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+ %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
- %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+ %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: sitofp
- %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+ %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi sgt
// CHECK: subi
// CHECK: select
- %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
return
}
":LinalgOps",
":MathDialect",
":Pass",
+ ":SCFDialect",
":StandardOps",
":TensorDialect",
":TosaDialect",