Add gelu op (#18992)
author Xiaomeng Yang <yangxm@fb.com>
Tue, 9 Apr 2019 04:55:43 +0000 (21:55 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 04:58:29 +0000 (21:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18992

Add Gelu and GeluGradient operators for Caffe2 (CPU and CUDA). Both the exact erf-based formulation and the tanh approximation are supported, selected with the fast_gelu argument.
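
For reference, the two modes compute the following (a minimal NumPy sketch for
illustration only, not part of the change; scipy is used just for the exact normal CDF):

    import numpy as np
    from scipy.stats import norm

    def gelu_exact(x):
        # y = x * P(X <= x) where X ~ N(0, 1)
        return x * norm.cdf(x)

    def gelu_fast(x):
        # y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
        alpha = np.sqrt(2.0 / np.pi)
        return 0.5 * x * (1.0 + np.tanh(alpha * (x + 0.044715 * x ** 3)))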

Reviewed By: houseroad

Differential Revision: D14814811

fbshipit-source-id: 00f126b8b83763c57ebbf28fbd2de5a8fab6d491

caffe2/operators/gelu_op.cc [new file with mode: 0644]
caffe2/operators/gelu_op.cu [new file with mode: 0644]
caffe2/operators/gelu_op.h [new file with mode: 0644]
caffe2/python/operator_test/activation_ops_test.py
caffe2/utils/math/elementwise.cc
caffe2/utils/math/elementwise.cu
caffe2/utils/math/elementwise.h

diff --git a/caffe2/operators/gelu_op.cc b/caffe2/operators/gelu_op.cc
new file mode 100644 (file)
index 0000000..8bd9bf0
--- /dev/null
@@ -0,0 +1,132 @@
+#include "caffe2/operators/gelu_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <>
+template <typename T>
+bool GeluFunctor<CPUContext>::
+operator()(const int N, const T* X, T* Y, CPUContext* context) const {
+  if (fast_gelu) {
+    // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+    constexpr T kAlpha = gelu_utils::kSqrt2 / gelu_utils::kSqrtPi;
+    ConstEigenVectorArrayMap<T> X_arr(X, N);
+    EigenVectorArrayMap<T> Y_arr(Y, N);
+    Y_arr = X_arr *
+        (((X_arr + X_arr.cube() * gelu_utils::kFastCoeff) * kAlpha).tanh() +
+         T(1)) *
+        static_cast<T>(0.5);
+  } else {
+    // y = x * P(X <= x) where X ~ N(0, 1)
+    math::CdfNorm<T, CPUContext>(N, X, Y, context);
+    math::Mul<T, CPUContext>(N, X, Y, Y, context);
+  }
+  return true;
+}
+
+template <>
+template <typename T>
+bool GeluGradientFunctor<CPUContext>::Forward(
+    const std::vector<int>& dY_dims,
+    const std::vector<int>& /* X_dims */,
+    const T* dY,
+    const T* X,
+    T* dX,
+    CPUContext* context) const {
+  const int N = std::accumulate(
+      dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
+  ConstEigenVectorArrayMap<T> dY_arr(dY, N);
+  ConstEigenVectorArrayMap<T> X_arr(X, N);
+  EigenVectorArrayMap<T> dX_arr(dX, N);
+  if (fast_gelu) {
+    constexpr T kAlpha = gelu_utils::kSqrt2 / gelu_utils::kSqrtPi;
+    constexpr T kBeta = kAlpha * gelu_utils::kFastCoeff * T(3);
+    dX_arr = ((X_arr + X_arr.cube() * gelu_utils::kFastCoeff) * kAlpha).tanh();
+    dX_arr =
+        (T(1) + dX_arr +
+         X_arr * (T(1) - dX_arr.square()) * (kBeta * X_arr.square() + kAlpha)) *
+        dY_arr * static_cast<T>(0.5);
+  } else {
+    constexpr T kAlpha = T(1) / (gelu_utils::kSqrt2 * gelu_utils::kSqrtPi);
+    math::CdfNorm<T, CPUContext>(N, X, dX, context);
+    dX_arr = (dX_arr +
+              X_arr * (-X_arr.square() * static_cast<T>(0.5)).exp() * kAlpha) *
+        dY_arr;
+  }
+  return true;
+}
+
+REGISTER_CPU_OPERATOR(
+    Gelu,
+    UnaryElementwiseWithArgsOp<
+        TensorTypes<float>,
+        CPUContext,
+        GeluFunctor<CPUContext>>);
+REGISTER_CPU_OPERATOR(
+    GeluGradient,
+    BinaryElementwiseWithArgsOp<
+        TensorTypes<float>,
+        CPUContext,
+        GeluGradientFunctor<CPUContext>>);
+
+namespace {
+
+OpSchema::Cost CostInferenceForGelu(
+    const OperatorDef& def,
+    const vector<TensorShape>& in) {
+  struct OpSchema::Cost cost = PointwiseCostInference<2>(def, in);
+  cost.params_bytes = 0;
+  return cost;
+}
+
+} // namespace
+
+// Input: X, output: Y
+OPERATOR_SCHEMA(Gelu)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .Arg(
+        "fast_gelu",
+        "If true, use y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3))).")
+    .CostInferenceFunction(CostInferenceForGelu)
+    .IdenticalTypeAndShape()
+    .SetDoc(R"DOC(
+Gelu takes one input data (Tensor) and produces one output data
+(Tensor) where the Gaussian error linear unit function, y = x * P(X <= x)
+where X ~ N(0, 1), is applied to the tensor elementwise.
+)DOC")
+    .Input(0, "X", "Input tensor")
+    .Output(0, "Y", "Output tensor");
+
+OPERATOR_SCHEMA(GeluGradient)
+    .NumInputs(2)
+    .NumOutputs(1)
+    .IdenticalTypeAndShape();
+
+namespace {
+
+class GetGeluGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  std::vector<OperatorDef> GetGradientDefs() override {
+    return SingleGradientDef(
+        "GeluGradient",
+        "",
+        std::vector<std::string>{GO(0), I(0)},
+        std::vector<std::string>{GI(0)});
+  }
+};
+
+} // namespace
+
+REGISTER_GRADIENT(Gelu, GetGeluGradient);
+
+} // namespace caffe2
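
Note on the CPU gradient above: the non-fast branch implements
d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where phi is the standard normal PDF and
kAlpha = 1 / sqrt(2 * Pi) is its prefactor. A quick finite-difference sanity check
of that formula (a sketch only, not part of the diff):

    import numpy as np
    from scipy.stats import norm

    x = np.linspace(-3.0, 3.0, 101)
    analytic = norm.cdf(x) + x * norm.pdf(x)  # matches the dX expression above
    eps = 1e-5
    numeric = ((x + eps) * norm.cdf(x + eps) - (x - eps) * norm.cdf(x - eps)) / (2 * eps)
    assert np.allclose(analytic, numeric, atol=1e-6)
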
diff --git a/caffe2/operators/gelu_op.cu b/caffe2/operators/gelu_op.cu
new file mode 100644 (file)
index 0000000..4c42fc4
--- /dev/null
@@ -0,0 +1,158 @@
+#include "caffe2/operators/gelu_op.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+// y = x * P(X <= x) where X ~ N(0, 1)
+template <typename T>
+__global__ void GeluCUDAKernel(const int N, const T* X, T* Y);
+
+#define DELEGATE_GELU_CUDA_KERNEL(T, CdfNormFunc)                        \
+  template <>                                                            \
+  __global__ void GeluCUDAKernel<T>(const int N, const T* X, T* Y) {     \
+    const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \
+    if (index < N) {                                                     \
+      Y[index] = X[index] * CdfNormFunc(X[index]);                       \
+    }                                                                    \
+  }
+DELEGATE_GELU_CUDA_KERNEL(float, normcdff)
+#undef DELEGATE_GELU_CUDA_KERNEL
+
+// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+template <typename T>
+__global__ void FastGeluCUDAKernel(const int N, const T* X, T* Y);
+
+#define DELEGATE_FAST_GELU_CUDA_KERNEL(T, FMAFunc, TanhFunc)             \
+  template <>                                                            \
+  __global__ void FastGeluCUDAKernel<T>(const int N, const T* X, T* Y) { \
+    constexpr T kAlpha = gelu_utils::kSqrt2 / gelu_utils::kSqrtPi;       \
+    const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \
+    if (index < N) {                                                     \
+      Y[index] = static_cast<T>(0.5) *                                   \
+          FMAFunc(X[index],                                              \
+                  TanhFunc(                                              \
+                      kAlpha *                                           \
+                      FMAFunc(                                           \
+                          gelu_utils::kFastCoeff,                        \
+                          math::utils::Cube<T>(X[index]),                \
+                          X[index])),                                    \
+                  X[index]);                                             \
+    }                                                                    \
+  }
+DELEGATE_FAST_GELU_CUDA_KERNEL(float, fmaf, tanhf)
+#undef DELEGATE_FAST_GELU_CUDA_KERNEL
+
+template <typename T>
+__global__ void
+GeluGradientCUDAKernel(const int N, const T* dY, const T* X, T* dX);
+
+#define DELEGATE_GELU_GRADIENT_CUDA_KERNEL(T, FMAFunc, CdfNormFunc, ExpFunc) \
+  template <>                                                                \
+  __global__ void GeluGradientCUDAKernel<T>(                                 \
+      const int N, const T* dY, const T* X, T* dX) {                         \
+    constexpr T kAlpha = T(1) / (gelu_utils::kSqrt2 * gelu_utils::kSqrtPi);  \
+    const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;     \
+    if (index < N) {                                                         \
+      dX[index] = dY[index] *                                                \
+          FMAFunc(kAlpha * X[index],                                         \
+                  ExpFunc(-X[index] * X[index] * static_cast<T>(0.5)),       \
+                  CdfNormFunc(X[index]));                                    \
+    }                                                                        \
+  }
+DELEGATE_GELU_GRADIENT_CUDA_KERNEL(float, fmaf, normcdff, expf)
+#undef DELEGATE_GELU_GRADIENT_CUDA_KERNEL
+
+template <typename T>
+__global__ void
+FastGeluGradientCUDAKernel(const int N, const T* dY, const T* X, T* dX);
+
+#define DELEGATE_FAST_GELU_GRADIENT_CUDA_KERNEL(T, FMAFunc, TanhFunc)    \
+  template <>                                                            \
+  __global__ void FastGeluGradientCUDAKernel<T>(                         \
+      const int N, const T* dY, const T* X, T* dX) {                     \
+    constexpr T kAlpha = gelu_utils::kSqrt2 / gelu_utils::kSqrtPi;       \
+    constexpr T kBeta = kAlpha * gelu_utils::kFastCoeff * T(3);          \
+    const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \
+    if (index < N) {                                                     \
+      const T y = TanhFunc(                                              \
+          kAlpha *                                                       \
+          FMAFunc(                                                       \
+              gelu_utils::kFastCoeff,                                    \
+              math::utils::Cube<T>(X[index]),                            \
+              X[index]));                                                \
+      dX[index] = FMAFunc(                                               \
+                      FMAFunc(-X[index], y * y, X[index]),               \
+                      FMAFunc(kBeta, X[index] * X[index], kAlpha),       \
+                      T(1) + y) *                                        \
+          dY[index] * static_cast<T>(0.5);                               \
+    }                                                                    \
+  }
+DELEGATE_FAST_GELU_GRADIENT_CUDA_KERNEL(float, fmaf, tanhf)
+#undef DELEGATE_FAST_GELU_GRADIENT_CUDA_KERNEL
+
+} // namespace
+
+template <>
+template <typename T>
+bool GeluFunctor<CUDAContext>::
+operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+  if (fast_gelu) {
+    FastGeluCUDAKernel<T>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, X, Y);
+  } else {
+    GeluCUDAKernel<T>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, X, Y);
+  }
+  return true;
+}
+
+template <>
+template <typename T>
+bool GeluGradientFunctor<CUDAContext>::Forward(
+    const std::vector<int>& dY_dims,
+    const std::vector<int>& /* X_dims */,
+    const T* dY,
+    const T* X,
+    T* dX,
+    CUDAContext* context) const {
+  const int N = std::accumulate(
+      dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
+  const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+  if (fast_gelu) {
+    // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
+    FastGeluGradientCUDAKernel<T>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N, dY, X, dX);
+  } else {
+    // y = x * P(X <= x) where X ~ N(0, 1)
+    GeluGradientCUDAKernel<T>
+        <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+            N, dY, X, dX);
+  }
+  return true;
+}
+
+REGISTER_CUDA_OPERATOR(
+    Gelu,
+    UnaryElementwiseWithArgsOp<
+        TensorTypes<float>,
+        CUDAContext,
+        GeluFunctor<CUDAContext>>);
+REGISTER_CUDA_OPERATOR(
+    GeluGradient,
+    BinaryElementwiseWithArgsOp<
+        TensorTypes<float>,
+        CUDAContext,
+        GeluGradientFunctor<CUDAContext>>);
+
+} // namespace caffe2
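
The fast-path CUDA kernels fold the polynomial into fused multiply-adds:
fma(kFastCoeff, x^3, x) = x + 0.044715 * x^3 and fma(x, tanh(u), x) = x * (1 + tanh(u)).
Each kernel is launched with DivUp(N, CAFFE_CUDA_NUM_THREADS) blocks of
CAFFE_CUDA_NUM_THREADS threads. The same arithmetic spelled out in Python for
illustration (the 128-thread block size below is an assumption; the real value
comes from CAFFE_CUDA_NUM_THREADS):

    import math

    def fma(a, b, c):
        return a * b + c

    def fast_gelu_scalar(x, coeff=0.044715):
        alpha = math.sqrt(2.0 / math.pi)
        u = alpha * fma(coeff, x ** 3, x)      # alpha * (x + coeff * x^3)
        return 0.5 * fma(x, math.tanh(u), x)   # 0.5 * x * (1 + tanh(u))

    def grid_size(n, threads_per_block=128):
        # mirrors math::DivUp(N, CAFFE_CUDA_NUM_THREADS)
        return (n + threads_per_block - 1) // threads_per_block
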
diff --git a/caffe2/operators/gelu_op.h b/caffe2/operators/gelu_op.h
new file mode 100644 (file)
index 0000000..594315e
--- /dev/null
@@ -0,0 +1,49 @@
+#ifndef CAFFE2_OPERATORS_GELU_OP_H_
+#define CAFFE2_OPERATORS_GELU_OP_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/operators/elementwise_ops.h"
+
+namespace caffe2 {
+
+namespace gelu_utils {
+
+constexpr float kSqrt2 = 1.4142135623730951f;
+constexpr float kSqrtPi = 1.7724538509055159f;
+constexpr float kFastCoeff = 0.044715f;
+
+} // namespace gelu_utils
+
+template <class Context>
+struct GeluFunctor {
+  explicit GeluFunctor(OperatorBase& op)
+      : fast_gelu(op.GetSingleArgument<bool>("fast_gelu", false)) {}
+
+  template <typename T>
+  bool operator()(const int N, const T* X, T* Y, Context* context) const;
+
+  const bool fast_gelu;
+};
+
+template <class Context>
+struct GeluGradientFunctor {
+  explicit GeluGradientFunctor(OperatorBase& op)
+      : fast_gelu(op.GetSingleArgument<bool>("fast_gelu", false)) {}
+
+  template <typename T>
+  bool Forward(
+      const std::vector<int>& dY_dims,
+      const std::vector<int>& X_dims,
+      const T* dY,
+      const T* X,
+      T* dX,
+      Context* context) const;
+
+  const bool fast_gelu;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_GELU_OP_H_
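
With the CPU and CUDA functors registered, the op can be driven from the Python
frontend; a minimal usage sketch (blob names and shapes are illustrative):

    import numpy as np
    from caffe2.python import core, workspace

    op = core.CreateOperator("Gelu", ["X"], ["Y"], fast_gelu=True)
    workspace.FeedBlob("X", np.random.randn(4, 8).astype(np.float32))
    workspace.RunOperatorOnce(op)
    Y = workspace.FetchBlob("Y")
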
diff --git a/caffe2/python/operator_test/activation_ops_test.py b/caffe2/python/operator_test/activation_ops_test.py
index d6e3562..a374c56 100644 (file)
@@ -13,12 +13,14 @@ import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.mkl_test_util as mu
 import caffe2.python.serialized_test.serialized_test_util as serial
 
+from scipy.stats import norm
+
 import unittest
 
 
 class TestActivations(serial.SerializedTestCase):
     @serial.given(X=hu.tensor(), in_place=st.booleans(),
-           engine=st.sampled_from(["", "CUDNN"]), **mu.gcs)
+                  engine=st.sampled_from(["", "CUDNN"]), **mu.gcs)
     def test_relu(self, X, in_place, engine, gc, dc):
         if gc == mu.mkl_do:
             in_place = False
@@ -78,8 +80,8 @@ class TestActivations(serial.SerializedTestCase):
             grad_reference=relu_grad_ref)
 
     @serial.given(X=hu.tensor(elements=st.floats(-3.0, 3.0)),
-           n=st.floats(min_value=0.5, max_value=2.0),
-           in_place=st.booleans(), **hu.gcs)
+                  n=st.floats(min_value=0.5, max_value=2.0),
+                  in_place=st.booleans(), **hu.gcs)
     def test_relu_n(self, X, n, in_place, gc, dc):
         op = core.CreateOperator(
             "ReluN",
@@ -104,9 +106,9 @@ class TestActivations(serial.SerializedTestCase):
         self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=0.005)
 
     @serial.given(X=hu.tensor(),
-           alpha=st.floats(min_value=0.1, max_value=2.0),
-           in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]),
-           **hu.gcs)
+                  alpha=st.floats(min_value=0.1, max_value=2.0),
+                  in_place=st.booleans(), engine=st.sampled_from(["", "CUDNN"]),
+                  **hu.gcs)
     def test_elu(self, X, alpha, in_place, engine, gc, dc):
         op = core.CreateOperator(
             "Elu",
@@ -173,9 +175,9 @@ class TestActivations(serial.SerializedTestCase):
             self.assertGradientChecks(gc, op, [X, W], 1, [0], stepsize=1e-2)
 
     @serial.given(X=hu.tensor(),
-           alpha=st.floats(min_value=0.1, max_value=2.0),
-           inplace=st.booleans(),
-           **hu.gcs)
+                  alpha=st.floats(min_value=0.1, max_value=2.0),
+                  inplace=st.booleans(),
+                  **hu.gcs)
     def test_leaky_relu(self, X, alpha, inplace, gc, dc):
         # go away from the origin point to avoid kink problems
         X += 0.04 * np.sign(X)
@@ -215,3 +217,26 @@ class TestActivations(serial.SerializedTestCase):
         self.assertReferenceChecks(gc, op, [X], leaky_relu_ref)
         # Check over multiple devices
         self.assertDeviceChecks(dc, op, [X], [0])
+
+    @given(X=hu.tensor(),
+           fast_gelu=st.booleans(),
+           **hu.gcs)
+    def test_gelu(self, X, fast_gelu, gc, dc):
+        op = core.CreateOperator(
+            "Gelu",
+            ["X"],
+            ["Y"],
+            fast_gelu=fast_gelu,
+        )
+
+        def gelu_ref(X):
+            return (X * norm.cdf(X),)
+
+        tol = 1e-3 if fast_gelu else 1e-4
+        self.assertReferenceChecks(gc, op, [X], gelu_ref, threshold=tol)
+        self.assertDeviceChecks(dc, op, [X], [0])
+        self.assertGradientChecks(gc, op, [X], 0, [0])
+
+
+if __name__ == "__main__":
+    unittest.main()
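
Both settings of fast_gelu are compared against the exact x * norm.cdf(x) reference,
so the tanh-approximation run needs the looser 1e-3 threshold to absorb its
approximation error. A quick way to see the size of that gap (sketch only):

    import numpy as np
    from scipy.stats import norm

    x = np.linspace(-5.0, 5.0, 2001)
    exact = x * norm.cdf(x)
    fast = 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
    print(np.max(np.abs(exact - fast)))  # a few 1e-4 on this range, below the 1e-3 threshold
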
diff --git a/caffe2/utils/math/elementwise.cc b/caffe2/utils/math/elementwise.cc
index c18f35b..cdb8bcd 100644 (file)
@@ -76,6 +76,8 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Inv, vsInv)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Inv, vdInv)
 DELEGATE_SIMPLE_UNARY_FUNCTION(float, Erf, vsErf)
 DELEGATE_SIMPLE_UNARY_FUNCTION(double, Erf, vdErf)
+DELEGATE_SIMPLE_UNARY_FUNCTION(float, CdfNorm, vsCdfNorm)
+DELEGATE_SIMPLE_UNARY_FUNCTION(double, CdfNorm, vdCdfNorm)
 #undef DELEGATE_SIMPLE_UNARY_FUNCTION
 
 #define DELEGATE_SINCOS(T, MKLFunc)                                     \
@@ -240,6 +242,19 @@ CAFFE2_SPECIALIZED_ERF(float)
 CAFFE2_SPECIALIZED_ERF(double)
 #undef CAFFE2_SPECIALIZED_ERF
 
+#define CAFFE2_SPECIALIZED_CDF_NORM(T)                            \
+  template <>                                                     \
+  C10_EXPORT void CdfNorm<T, CPUContext>(                         \
+      const int N, const T* X, T* Y, CPUContext* /* context */) { \
+    std::transform(X, X + N, Y, [](const T x) {                   \
+      constexpr T kRsqrt2 = 0.7071067811865475;                   \
+      return (T(1) + erf(x * kRsqrt2)) * static_cast<T>(0.5);     \
+    });                                                           \
+  }
+CAFFE2_SPECIALIZED_CDF_NORM(float)
+CAFFE2_SPECIALIZED_CDF_NORM(double)
+#undef CAFFE2_SPECIALIZED_CDF_NORM
+
 #define DELEGATE_SIMPLE_BINARY_FUNCTION_BY_EIGEN_OPERATOR(T, Func, EigenOp)   \
   template <>                                                                 \
   C10_EXPORT void Func<T, CPUContext>(                                        \
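
The non-MKL CdfNorm fallback above evaluates Phi(x) = 0.5 * (1 + erf(x / sqrt(2))),
with kRsqrt2 = 1 / sqrt(2). Equivalence with the normal CDF, checked in Python for
illustration only:

    import numpy as np
    from scipy.special import erf
    from scipy.stats import norm

    x = np.linspace(-4.0, 4.0, 101)
    phi = 0.5 * (1.0 + erf(x / np.sqrt(2.0)))
    assert np.allclose(phi, norm.cdf(x))
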
diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu
index 9b321a5..f133b28 100644 (file)
@@ -283,6 +283,8 @@ DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Cube, utils::Cube<double>)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Cbrt, cbrtf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Erf, erff)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Erf, erf)
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, CdfNorm, normcdff)
+DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, CdfNorm, normcdf)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(bool, Not, utils::Not<bool>)
 DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(
     std::int32_t,
diff --git a/caffe2/utils/math/elementwise.h b/caffe2/utils/math/elementwise.h
index 0655dc6..93b56eb 100644 (file)
@@ -55,6 +55,8 @@ template <typename T, class Context>
 CAFFE2_API void Inv(int N, const T* X, T* Y, Context* context);
 template <typename T, class Context>
 CAFFE2_API void Erf(int N, const T* X, T* Y, Context* context);
+template <typename T, class Context>
+CAFFE2_API void CdfNorm(int N, const T* X, T* Y, Context* context);
 
 template <typename T, class Context>
 CAFFE2_API void Set(std::int64_t N, T alpha, T* X, Context* context);