[MLIR][TOSA] add tosa erf operator
authorManupa Karunaratne <manupa.karunaratne@amd.com>
Fri, 19 May 2023 00:10:25 +0000 (17:10 -0700)
committerEric Kunze <eric.kunze@arm.com>
Fri, 19 May 2023 21:50:14 +0000 (14:50 -0700)
This commit adds tosa erf operator and its lowering
to math lib functions.

Reviewed By: eric-k256, jpienaar

Differential Revision: https://reviews.llvm.org/D150357

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index b9fa2f8..842ac74 100644 (file)
@@ -455,6 +455,32 @@ def Tosa_TanhOp : Tosa_Op<"tanh", [
   );
 }
 
+
+//===----------------------------------------------------------------------===//
+// Operator: erf
+//===----------------------------------------------------------------------===//
+def Tosa_ErfOp : Tosa_Op<"erf", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Computes gauss error function of input";
+
+  let description = [{
+    Gauss error function: $ erf(x) = \frac{2}{\sqrt(\pi)} \int_{0}^{x} e^{-t^2} \,dt $
+    For quantized integer data types, the TABLE operator should be used instead
+    with the following definition.  The erf_table has 513 entries each of
+    16-bit/8-bit precision and covering the input range -4.0 to +4.0 in steps of 1/64.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor:$output
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.4
 // Operator Class: Elementwise unary/binary/ternary operators.
index 6aa0751..2faf7f1 100644 (file)
@@ -299,6 +299,10 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
   if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
 
+  // tosa::ErfOp
+  if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
+    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+
   // tosa::GreaterOp
   if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
@@ -2044,6 +2048,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::ExpOp>,
       PointwiseConverter<tosa::AbsOp>,
       PointwiseConverter<tosa::TanhOp>,
+      PointwiseConverter<tosa::ErfOp>,
       PointwiseConverter<tosa::BitwiseAndOp>,
       PointwiseConverter<tosa::BitwiseOrOp>,
       PointwiseConverter<tosa::BitwiseNotOp>,
index 1040d4c..d2c732c 100644 (file)
@@ -1058,6 +1058,7 @@ NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SelectOp)
 NARY_SHAPE_INFER(tosa::SubOp)
 NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
index 3e654ab..65d56ad 100644 (file)
@@ -258,6 +258,10 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: arith.divf
   %23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
+  // CHECK: linalg.generic
+  // CHECK: math.erf
+  %24 = "tosa.erf"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
   return
 }
 
index bf3cf3d..72f0203 100644 (file)
@@ -115,6 +115,13 @@ func.func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
 }
 
 // -----
+// CHECK-LABEL: erf
+func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.erf"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
 // CHECK-LABEL: add
 func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
index 56b0887..5bbb6e1 100644 (file)
@@ -65,6 +65,9 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
 
   // CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
   %12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
+
+  // CHECK: "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+  %13 = "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
   return
 }