);
}
+
+//===----------------------------------------------------------------------===//
+// Operator: erf
+//===----------------------------------------------------------------------===//
+def Tosa_ErfOp : Tosa_Op<"erf", [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
+ Pure]> {
+ let summary = "Computes gauss error function of input";
+
+ let description = [{
+ Gauss error function: $ erf(x) = \frac{2}{\sqrt(\pi)} \int_{0}^{x} e^{-t^2} \,dt $
+ For quantized integer data types, the TABLE operator should be used instead
+ with the following definition. The erf_table has 513 entries each of
+ 16-bit/8-bit precision and covering the input range -4.0 to +4.0 in steps of 1/64.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor:$input
+ );
+
+ let results = (outs
+ Tosa_Tensor:$output
+ );
+}
+
//===----------------------------------------------------------------------===//
// TOSA Spec Section 2.4
// Operator Class: Elementwise unary/binary/ternary operators.
if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
+ // tosa::ErfOp
+ if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
+ return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
+
// tosa::GreaterOp
if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
PointwiseConverter<tosa::ExpOp>,
PointwiseConverter<tosa::AbsOp>,
PointwiseConverter<tosa::TanhOp>,
+ PointwiseConverter<tosa::ErfOp>,
PointwiseConverter<tosa::BitwiseAndOp>,
PointwiseConverter<tosa::BitwiseOrOp>,
PointwiseConverter<tosa::BitwiseNotOp>,
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
// CHECK: arith.divf
%23 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: linalg.generic
+ // CHECK: math.erf
+ %24 = "tosa.erf"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+
return
}
}
// -----
+// CHECK-LABEL: erf
+func.func @test_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.erf"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
// CHECK-LABEL: add
func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<4xi32>
%12 = "tosa.cast"(%arg0) : (tensor<4xf32>) -> tensor<*xi32>
+
+ // CHECK: "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+ %13 = "tosa.erf"(%arg0) : (tensor<4xf32>) -> tensor<*xf32>
return
}