Export PyTorch erf to ONNX Erf and add Caffe2 Erf operator
author	bddppq <bai@in.tum.de>
Thu, 17 Jan 2019 17:15:14 +0000 (09:15 -0800)
committer	Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 17:18:08 +0000 (09:18 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16106

Differential Revision: D13709490

Pulled By: bddppq

fbshipit-source-id: 1b5b32261f06543371f7bd7ac9b11957a5eb4ad0

14 files changed:
caffe2/operators/erf_op.cc [new file with mode: 0644]
caffe2/operators/erf_op.cu [new file with mode: 0644]
caffe2/operators/erf_op.h [new file with mode: 0644]
caffe2/python/onnx/tests/onnx_backend_test.py
caffe2/python/operator_test/erf_op_test.py [new file with mode: 0644]
caffe2/python/serialized_test/data/operator_test/erf_op_test.test_erf.zip [new file with mode: 0644]
caffe2/utils/eigen_utils.h
caffe2/utils/math.h
caffe2/utils/math_cpu.cc
caffe2/utils/math_gpu.cu
test/onnx/expect/TestOperators.test_erf.expect [new file with mode: 0644]
test/onnx/test_operators.py
test/onnx/test_pytorch_onnx_caffe2.py
torch/onnx/symbolic.py

diff --git a/caffe2/operators/erf_op.cc b/caffe2/operators/erf_op.cc
new file mode 100644 (file)
index 0000000..efb15ae
--- /dev/null
@@ -0,0 +1,74 @@
+#include "caffe2/operators/erf_op.h"
+#include "caffe2/utils/eigen_utils.h"
+
+#include <functional>
+#include <numeric>
+
+namespace caffe2 {
+
+template <>
+template <typename T>
+bool ErfGradientFunctor<CPUContext>::Forward(
+    const std::vector<int>& X_dims,
+    const std::vector<int>& /* dY_dims */,
+    const T* X,
+    const T* dY,
+    T* dX,
+    CPUContext* /* context */) const {
+  const int size = std::accumulate(
+      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
+  ConstEigenVectorArrayMap<T> dY_arr(dY, size);
+  ConstEigenVectorArrayMap<T> X_arr(X, size);
+  EigenVectorMap<T>(dX, size) = T(2) / sqrtf(PI) * (-X_arr.square()).exp() * dY_arr;
+  return true;
+}
+
+REGISTER_CPU_OPERATOR(
+    Erf,
+    UnaryElementwiseOp<
+        TensorTypes<float>,
+        CPUContext,
+        ErfFunctor<CPUContext>>);
+REGISTER_CPU_OPERATOR(
+    ErfGradient,
+    BinaryElementwiseOp<
+        TensorTypes<float>,
+        CPUContext,
+        ErfGradientFunctor<CPUContext>>);
+
+OPERATOR_SCHEMA(Erf)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .IdenticalTypeAndShape()
+    .SetDoc(R"DOC(
+Calculates the error function of the given input tensor, element-wise.
+)DOC")
+    .Input(0, "input", "Input tensor")
+    .Output(
+        0,
+        "output",
+        "The arcsine of the input tensor computed element-wise");
+
+OPERATOR_SCHEMA(ErfGradient)
+    .NumInputs(2)
+    .NumOutputs(1)
+    .IdenticalTypeAndShape();
+
+namespace {
+
+class GetErfGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  std::vector<OperatorDef> GetGradientDefs() override {
+    return SingleGradientDef(
+        "ErfGradient",
+        "",
+        std::vector<std::string>{I(0), GO(0)},
+        std::vector<std::string>{GI(0)});
+  }
+};
+
+} // namespace
+
+REGISTER_GRADIENT(Erf, GetErfGradient);
+
+} // namespace caffe2
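
The backward pass above applies the identity erf'(x) = 2/sqrt(pi) * exp(-x^2), scaled elementwise by the incoming gradient dY. As a quick reader-side check of that identity (a NumPy sketch, not part of the commit), the analytic form can be compared against a central finite difference:

    import math

    import numpy as np

    x = np.linspace(-2.0, 2.0, 9)
    erf = np.vectorize(math.erf)

    # Analytic gradient used by ErfGradient: 2/sqrt(pi) * exp(-x^2).
    analytic = 2.0 / math.sqrt(math.pi) * np.exp(-x ** 2)

    # Independent reference: central finite difference of erf.
    eps = 1e-6
    numeric = (erf(x + eps) - erf(x - eps)) / (2.0 * eps)

    assert np.allclose(analytic, numeric, atol=1e-4)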
diff --git a/caffe2/operators/erf_op.cu b/caffe2/operators/erf_op.cu
new file mode 100644 (file)
index 0000000..72f3895
--- /dev/null
@@ -0,0 +1,60 @@
+#include "caffe2/operators/erf_op.h"
+
+#include <functional>
+#include <numeric>
+
+#include "caffe2/core/context_gpu.h"
+
+namespace caffe2 {
+
+namespace {
+
+__global__ void ErfGradientCUDAKernel(
+    const int N,
+    const float* dY,
+    const float* X,
+    float* dX) {
+  CUDA_1D_KERNEL_LOOP(i, N) {
+#if __CUDA_ARCH__ >= 350
+    dX[i] = 2.0f / sqrtf(PI) * expf(-powf(__ldg(X + i), 2.0f)) * __ldg(dY + i);
+#else
+    dX[i] = 2.0f / sqrtf(PI) * expf(-powf(X[i], 2.0f)) * dY[i];
+#endif
+  }
+}
+
+} // namespace
+
+template <>
+template <typename T>
+bool ErfGradientFunctor<CUDAContext>::Forward(
+    const std::vector<int>& X_dims,
+    const std::vector<int>& /* dY_dims */,
+    const T* X,
+    const T* dY,
+    T* dX,
+    CUDAContext* context) const {
+  const int size = std::accumulate(
+      X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
+  ErfGradientCUDAKernel<<<
+      CAFFE_GET_BLOCKS(size),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(size, dY, X, dX);
+  return true;
+}
+
+REGISTER_CUDA_OPERATOR(
+    Erf,
+    UnaryElementwiseOp<
+        TensorTypes<float>,
+        CUDAContext,
+        ErfFunctor<CUDAContext>>);
+REGISTER_CUDA_OPERATOR(
+    ErfGradient,
+    BinaryElementwiseOp<
+        TensorTypes<float>,
+        CUDAContext,
+        ErfGradientFunctor<CUDAContext>>);
+
+} // namespace caffe2
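
The kernel relies on Caffe2's CUDA_1D_KERNEL_LOOP grid-stride pattern: each thread starts at its global index and advances by the total number of launched threads, so any size N is covered regardless of launch configuration; on sm_35 and newer, the __ldg branch additionally pulls X and dY through the read-only data cache. A Python sketch of the indexing alone (illustrative, not part of the commit):

    def grid_stride_indices(thread_id, block_id, block_dim, grid_dim, n):
        # Indices one CUDA thread visits under CUDA_1D_KERNEL_LOOP(i, n).
        start = block_id * block_dim + thread_id
        stride = block_dim * grid_dim
        return range(start, n, stride)

    # Together, all threads cover 0..n-1 exactly once.
    n, block_dim, grid_dim = 1000, 128, 4
    covered = sorted(
        i
        for b in range(grid_dim)
        for t in range(block_dim)
        for i in grid_stride_indices(t, b, block_dim, grid_dim, n))
    assert covered == list(range(n))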
diff --git a/caffe2/operators/erf_op.h b/caffe2/operators/erf_op.h
new file mode 100644 (file)
index 0000000..40fd8db
--- /dev/null
@@ -0,0 +1,36 @@
+#ifndef CAFFE2_OPERATORS_ERF_OP_H_
+#define CAFFE2_OPERATORS_ERF_OP_H_
+
+#include <vector>
+
+#include "caffe2/operators/elementwise_ops.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+constexpr float PI = 3.14159265358979323846f;
+
+template <class Context>
+struct ErfFunctor {
+  template <typename T>
+  bool operator()(const int N, const T* X, T* Y, Context* context) const {
+    math::Erf(N, X, Y, context);
+    return true;
+  }
+};
+
+template <class Context>
+struct ErfGradientFunctor {
+  template <typename T>
+  bool Forward(
+      const std::vector<int>& X_dims,
+      const std::vector<int>& dY_dims,
+      const T* X,
+      const T* dY,
+      T* dX,
+      Context* context) const;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_ERF_OP_H_
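
With the functor and registrations in place, the operator can be exercised from Python through a Caffe2 workspace. A minimal sketch (assumes a Caffe2 build that includes this patch):

    import math

    import numpy as np
    from caffe2.python import core, workspace

    X = np.random.randn(4, 5).astype(np.float32)
    workspace.FeedBlob("X", X)
    workspace.RunOperatorOnce(core.CreateOperator("Erf", ["X"], ["Y"]))
    Y = workspace.FetchBlob("Y")
    assert np.allclose(Y, np.vectorize(math.erf)(X), atol=1e-6)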
diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py
index 79d3abc..b979759 100644 (file)
@@ -49,7 +49,6 @@ backend_test.exclude(r'(test_hardsigmoid'  # Does not support Hardsigmoid.
                      '|test_atanh.*'  # Needs implementation
                      '|test_onehot.*'  # Needs implementation
                      '|test_scan.*'  # Needs implementation
-                     '|test_erf.*'  # Needs implementation
                      '|test_isnan.*'  # Needs implementation
                      '|test_scatter.*'  # Should be similar to ScatterAssign
                      '|test_constantofshape.*'  # Needs implementation
diff --git a/caffe2/python/operator_test/erf_op_test.py b/caffe2/python/operator_test/erf_op_test.py
new file mode 100644 (file)
index 0000000..d47ead5
--- /dev/null
@@ -0,0 +1,30 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import math
+
+from caffe2.python import core
+from hypothesis import given
+from hypothesis import strategies as st
+import caffe2.python.hypothesis_test_util as hu
+import caffe2.python.serialized_test.serialized_test_util as serial
+
+import numpy as np
+import unittest
+
+
+class TestErfOp(serial.SerializedTestCase):
+    @serial.given(
+        X=hu.tensor(elements=st.floats(min_value=-0.7, max_value=0.7)),
+        **hu.gcs)
+    def test_erf(self, X, gc, dc):
+        op = core.CreateOperator('Erf', ["X"], ["Y"])
+        self.assertReferenceChecks(gc, op, [X], lambda x: (np.vectorize(math.erf)(x),))
+        self.assertDeviceChecks(dc, op, [X], [0])
+        self.assertGradientChecks(gc, op, [X], 0, [0])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/caffe2/python/serialized_test/data/operator_test/erf_op_test.test_erf.zip b/caffe2/python/serialized_test/data/operator_test/erf_op_test.test_erf.zip
new file mode 100644 (file)
index 0000000..3e50fe6
Binary files /dev/null and b/caffe2/python/serialized_test/data/operator_test/erf_op_test.test_erf.zip differ
diff --git a/caffe2/utils/eigen_utils.h b/caffe2/utils/eigen_utils.h
index 0b35a15..83e7cb2 100644 (file)
@@ -5,6 +5,7 @@
 
 #include "Eigen/Core"
 #include "Eigen/Dense"
+
 #include "caffe2/core/logging.h"
 
 namespace caffe2 {
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index f406a1a..cb949ee 100644 (file)
@@ -74,6 +74,8 @@ template <typename T, class Context>
 void Powx(const int N, const T* a, const T b, T* y, Context* context);
 template <typename T, class Context>
 void Inv(const int N, const T* x, T* y, Context* context);
+template <typename T, class Context>
+void Erf(const int N, const T* x, T* y, Context* context);
 
 #define C10_DECLARE_COMPARE_OP(Comp)                                         \
   template <typename T, class Context>                                       \
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index b6dda59..dd78239 100644 (file)
@@ -680,6 +680,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Cbrt, vsCbrt)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Cbrt, vdCbrt)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Inv, vsInv)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Inv, vdInv)
+DELEGATE_SIMPLE_UNARY_FUNCTION(float, Erf, vsErf)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, Erf, vdErf)
 #undef DELEGATE_SIMPLE_UNARY_FUNCTION
 
 #define DELEGATE_SINCOS_FUNCTION(T, OriginalFunc)           \
@@ -748,7 +750,6 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqrt, sqrt)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqrt, sqrt)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Rsqrt, rsqrt)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Rsqrt, rsqrt)
-
 #undef DELEGATE_SIMPLE_UNARY_FUNCTION
 
 #define DELEGATE_SINCOS_FUNCTION(T)                                     \
@@ -784,6 +785,16 @@ DELEGATE_CBRT_FUNCTION(float)
 DELEGATE_CBRT_FUNCTION(double)
 #undef DELEGATE_CBRT_FUNCTION
 
+#define DELEGATE_ERF_FUNCTION(T)                                         \
+  template <>                                                            \
+  C10_EXPORT void Erf<T, CPUContext>(                                    \
+      const int N, const T* X, T* Y, CPUContext*) {                      \
+    std::transform(X, X + N, Y, [](const T x) { return std::erf(x); });  \
+  }
+DELEGATE_ERF_FUNCTION(float)
+DELEGATE_ERF_FUNCTION(double)
+#undef DELEGATE_ERF_FUNCTION
+
 #define DELEGATE_POWX_FUNCTION(T)                                       \
   template <>                                                           \
   C10_EXPORT void Powx<T, CPUContext>(                                  \
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index dc7cb22..820dfe4 100644 (file)
@@ -346,6 +346,8 @@ DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqrt, sqrtf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Rsqrt, rsqrtf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf)
 
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff)
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cube, utils::Cube<float>)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube<double>)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(
diff --git a/test/onnx/expect/TestOperators.test_erf.expect b/test/onnx/expect/TestOperators.test_erf.expect
new file mode 100644 (file)
index 0000000..8c33540
--- /dev/null
@@ -0,0 +1,58 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    input: "0"
+    output: "1"
+    op_type: "Erf"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index af53a3e..ddb98fa 100644 (file)
@@ -510,6 +510,10 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4)
         self.assertONNX(lambda x: torch.min(x), x)
 
+    def test_erf(self):
+        x = torch.randn(1, 2, 3, 4)
+        self.assertONNX(lambda x: x.erf(), x)
+
 if __name__ == '__main__':
     no_onnx_dep_flag = '--no-onnx'
     _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 2ea4352..e45a6dc 100644 (file)
@@ -589,6 +589,16 @@ class TestCaffe2Backend(unittest.TestCase):
         input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9)
         self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE)
 
+    def test_erf(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, input):
+                return input.erf()
+        input = torch.empty(BATCH_SIZE, 10, 10).uniform_(-1, 1)
+        self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE)
+
     def test_trigonometry(self):
         def test_func(name):
             class MyModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 30091ec..3007882 100644 (file)
@@ -1475,3 +1475,8 @@ def rrelu(g, input, lower, upper, training, generator):
 def log_sigmoid(g, input):
     p = g.op('Sigmoid', input)
     return g.op('Log', p)
+
+
+@parse_args('v')
+def erf(g, input):
+    return g.op('Erf', input)
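
With this symbolic registered, aten::erf lowers to a single ONNX Erf node, which the new Caffe2 operator can then execute. An end-to-end export sketch (assumes a build containing this patch; the expect file above was produced at opset 9):

    import io

    import torch

    class ErfModel(torch.nn.Module):
        def forward(self, x):
            return x.erf()

    # Serializes an ONNX graph with a single op_type "Erf" node,
    # matching TestOperators.test_erf.expect above.
    buf = io.BytesIO()
    torch.onnx.export(ErfModel(), torch.randn(1, 2, 3, 4), buf)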