Optimize channel_stats_op (#16243)
authorXiaomeng Yang <yangxm@fb.com>
Tue, 12 Mar 2019 18:52:01 +0000 (11:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 19:08:00 +0000 (12:08 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16243

Optimize channel_stats_op and add NHWC impl

Reviewed By: takatosp1

Differential Revision: D13775515

fbshipit-source-id: decb889e646f5316d4afefdf9f9b6bc6343613cd

caffe2/operators/channel_stats_op.cc
caffe2/operators/channel_stats_op.cu
caffe2/operators/channel_stats_op.h
caffe2/python/operator_test/channel_stats_op_test.py
caffe2/python/operator_test/group_norm_op_test.py
caffe2/python/serialized_test/SerializedTestCoverage.md
caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip [new file with mode: 0644]
caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip [new file with mode: 0644]

index c4ed32c..af736c7 100644 (file)
@@ -1,39 +1,51 @@
 #include "caffe2/operators/channel_stats_op.h"
+
 #include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
 template <>
-bool ChannelStatsOp<CPUContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
-
-  const int sampleSize = H * W * D;
-
-  Output(SUM)->Resize(C);
-  Output(SUMSQ)->Resize(C);
-  EigenVectorArrayMap<float> sum(
-      Output(SUM)->template mutable_data<float>(), C);
-  EigenVectorArrayMap<float> sumsq(
-      Output(SUMSQ)->template mutable_data<float>(), C);
-
-  sum.setZero();
-  sumsq.setZero();
-  ConstEigenArrayMap<float> X_arr(X.data<float>(), sampleSize, N * C);
-  auto index = 0;
-  for (int n = 0; n < N; ++n) {
-    for (int c = 0; c < C; ++c) {
-      sum(c) += X_arr.col(index).sum();
-      sumsq(c) += X_arr.col(index).matrix().squaredNorm();
-      index++;
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
+  for (int i = 0; i < C; ++i) {
+    sum[i] = X_arr.col(i).sum();
+    sumsq[i] = X_arr.col(i).square().sum();
+  }
+  for (int i = 1; i < N; ++i) {
+    for (int j = 0; j < C; ++j) {
+      const int c = i * C + j;
+      sum[j] += X_arr.col(c).sum();
+      sumsq[j] += X_arr.col(c).square().sum();
     }
   }
+  return true;
+}
 
+template <>
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, C, N * HxW);
+  EigenVectorArrayMap<float> sum_arr(sum, C);
+  EigenVectorArrayMap<float> sumsq_arr(sumsq, C);
+  sum_arr = X_arr.col(0);
+  sumsq_arr = X_arr.col(0).square();
+  for (int i = 1; i < N * HxW; ++i) {
+    sum_arr += X_arr.col(i);
+    sumsq_arr += X_arr.col(i).square();
+  }
   return true;
 }
 
@@ -49,7 +61,6 @@ reduced across multiple batches and used to obtain the mean and variance across
 the full set of batches. Using the new mean and variance as input to SpatialBN
 has the effect of changing the batch size over which SpatialBN is applied.
 )DOC")
-
     .Input(0, "X", "The input 4-dimensional tensor of shape NCHW")
     .Output(
         0,
@@ -61,5 +72,7 @@ has the effect of changing the batch size over which SpatialBN is applied.
         "sumsq",
         "The output 1-dimensional tensor of size C containing the sum of "
         "elements squared per channel.");
+
 SHOULD_NOT_DO_GRADIENT(ChannelStats);
+
 } // namespace caffe2
index 5243005..ae7f7d2 100644 (file)
-#include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/channel_stats_op.h"
 
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math/reduce.cuh"
+
 namespace caffe2 {
 
 namespace {
 
-// based on "Optimizing Parallel Reduction in CUDA" by Mark Harris
-
-// note - volatile keyword is needed to allow doing a warp reduction without
-// synchronization on recent architectures
-template <unsigned int blockSize>
-__device__ void warpReduce(volatile float* sdata, unsigned int tid) {
-  // note - the if statements are "free" as they are resolved at compile time
-  if (blockSize >= 64)
-    sdata[tid] += sdata[tid + 32];
-  if (blockSize >= 32)
-    sdata[tid] += sdata[tid + 16];
-  if (blockSize >= 16)
-    sdata[tid] += sdata[tid + 8];
-  if (blockSize >= 8)
-    sdata[tid] += sdata[tid + 4];
-  if (blockSize >= 4)
-    sdata[tid] += sdata[tid + 2];
-  if (blockSize >= 2)
-    sdata[tid] += sdata[tid + 1];
-}
-
-template <unsigned int blockSize>
-__global__ void ChannelStatsBlockKernel(
-    int N,
-    int C,
-    int valsPerChannel,
-    const float* inputData,
-    float* sums,
-    float* sumsq) {
-  __shared__ float sumData[blockSize];
-  __shared__ float sumSqData[blockSize];
-
-  auto tid = threadIdx.x;
-  auto numBlocksPerChannel = (valsPerChannel + blockSize - 1) / blockSize;
-  auto localBlockIndex = blockIdx.x % numBlocksPerChannel;
-  auto inputIndex = (blockIdx.x / numBlocksPerChannel) * valsPerChannel +
-      localBlockIndex * blockSize + tid;
-
-  sumData[tid] = 0;
-  sumSqData[tid] = 0;
-
-  if (localBlockIndex * blockSize + tid < valsPerChannel) {
-    sumData[tid] += inputData[inputIndex];
-    sumSqData[tid] += inputData[inputIndex] * inputData[inputIndex];
-  }
-
-  __syncthreads();
-  if (blockSize >= 512) {
-    if (tid < 256) {
-      sumData[tid] += sumData[tid + 256];
-      sumSqData[tid] += sumSqData[tid + 256];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 256) {
-    if (tid < 128) {
-      sumData[tid] += sumData[tid + 128];
-      sumSqData[tid] += sumSqData[tid + 128];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 128) {
-    if (tid < 64) {
-      sumData[tid] += sumData[tid + 64];
-      sumSqData[tid] += sumSqData[tid + 64];
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void ChannelStatsNCHWCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* sum,
+    T* sumsq) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int n = threadIdx.x; n < N; n += blockDim.x) {
+    for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) {
+      const int index = (n * C + c) * HxW + hw;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+      m_val += __ldg(X + index);
+      v_val += __ldg(X + index) * __ldg(X + index);
+#else
+      m_val += X[index];
+      v_val += X[index] * X[index];
+#endif
     }
-    __syncthreads();
   }
-
-  if (tid < 32) {
-    warpReduce<blockSize>(sumData, tid);
-    warpReduce<blockSize>(sumSqData, tid);
-  }
-
-  // output block data sorted by C to simplify second reduction
-  if (tid == 0) {
-    auto n = blockIdx.x / numBlocksPerChannel / C;
-    auto c = (blockIdx.x / numBlocksPerChannel) % C;
-    auto outputIndex = (c * N + n) * numBlocksPerChannel + localBlockIndex;
-    sums[outputIndex] = sumData[0];
-    sumsq[outputIndex] = sumSqData[0];
+  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
+  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
   }
 }
 
-template <unsigned int blockSize>
-__global__ void ChannelStatsFinalSumsKernel(
-    int N,
-    int C,
-    int numSumsPerChannel,
-    const float* sumsScratch,
-    const float* sumsqScratch,
-    float* channelSums,
-    float* channelSumsq) {
-  __shared__ float sumData[blockSize];
-  __shared__ float sumSqData[blockSize];
-
-  auto tid = threadIdx.x;
-  auto inputIndex = blockIdx.x * N * numSumsPerChannel + tid;
-  sumData[tid] = 0;
-  sumSqData[tid] = 0;
-  for (auto i = inputIndex; i < (blockIdx.x + 1) * N * numSumsPerChannel;
-       i += blockSize) {
-    sumData[tid] += sumsScratch[i];
-    sumSqData[tid] += sumsqScratch[i];
-  }
-  __syncthreads();
-  if (blockSize >= 512) {
-    if (tid < 256) {
-      sumData[tid] += sumData[tid + 256];
-      sumSqData[tid] += sumSqData[tid + 256];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 256) {
-    if (tid < 128) {
-      sumData[tid] += sumData[tid + 128];
-      sumSqData[tid] += sumSqData[tid + 128];
-    }
-    __syncthreads();
+template <typename T>
+__global__ void ChannelStatsNHWCCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* sum,
+    T* sumsq) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const int inner_size = N * HxW;
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
+    const int index = i * C + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + index);
+    v_val += __ldg(X + index) * __ldg(X + index);
+#else
+    m_val += X[index];
+    v_val += X[index] * X[index];
+#endif
   }
-  if (blockSize >= 128) {
-    if (tid < 64) {
-      sumData[tid] += sumData[tid + 64];
-      sumSqData[tid] += sumSqData[tid + 64];
-    }
-    __syncthreads();
-  }
-  if (tid < 32) {
-    warpReduce<blockSize>(sumData, tid);
-    warpReduce<blockSize>(sumSqData, tid);
-  }
-
-  if (tid == 0) {
-    channelSums[blockIdx.x] = sumData[0];
-    channelSumsq[blockIdx.x] = sumSqData[0];
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
   }
 }
+
 } // namespace
 
 template <>
-bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
-
-  const auto X_arr = X.data<float>();
-  const auto valsPerChannel = H * W * D;
-
-  const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
-  const auto numBlocksTotal = numBlocksPerChannel * N * C;
-
-  ReinitializeTensor(
-      &sumScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-  ReinitializeTensor(
-      &sumsqScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-
-  auto sum = Output(SUM, {C}, at::dtype<float>());
-  auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());
-
-  ChannelStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
-      <<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          valsPerChannel,
-          X_arr,
-          sumScratch_.mutable_data<float>(),
-          sumsqScratch_.mutable_data<float>());
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
+      HxW,
+      ChannelStatsNCHWCUDAKernel,
+      float,
+      C,
+      context_.cuda_stream(),
+      N,
+      C,
+      HxW,
+      X,
+      sum,
+      sumsq);
+  return true;
+}
 
-  ChannelStatsFinalSumsKernel<CAFFE_CUDA_NUM_THREADS>
+template <>
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ChannelStatsNHWCCUDAKernel<float>
       <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          numBlocksPerChannel,
-          sumScratch_.data<float>(),
-          sumsqScratch_.data<float>(),
-          sum->template mutable_data<float>(),
-          sumsq->template mutable_data<float>());
-
+          N, C, HxW, X, sum, sumsq);
   return true;
 }
 
index 9bdd3d3..17ff43b 100644 (file)
@@ -1,5 +1,7 @@
-#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
-#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+
+#include <string>
 
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 namespace caffe2 {
 
 template <class Context>
-class ChannelStatsOp : public Operator<Context> {
+class ChannelStatsOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
+
   template <class... Args>
   explicit ChannelStatsOp(Args&&... args)
-      : Operator<Context>(std::forward<Args>(args)...) {}
-  ~ChannelStatsOp() {}
+      : Operator<Context>(std::forward<Args>(args)...),
+        order_(StringToStorageOrder(
+            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
+    CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
+  }
 
   bool RunOnDevice() override {
-    return true;
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
   }
 
- protected:
-  INPUT_TAGS(INPUT);
-  OUTPUT_TAGS(SUM, SUMSQ);
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    const int ndim = X.dim();
+    const int N = X.dim32(0);
+    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
+    const int HxW = X.numel() / (N * C);
+    auto* sum = Output(0, {C}, at::dtype<T>());
+    auto* sumsq = Output(1, {C}, at::dtype<T>());
+    const T* X_data = X.template data<T>();
+    T* sum_data = sum->template mutable_data<T>();
+    T* sumsq_data = sumsq->template mutable_data<T>();
+    return order_ == StorageOrder::NCHW
+        ? ComputeChannelStatsNCHW<T>(N, C, HxW, X_data, sum_data, sumsq_data)
+        : ComputeChannelStatsNHWC<T>(N, C, HxW, X_data, sum_data, sumsq_data);
+  }
+
+ private:
+  template <typename T>
+  bool
+  ComputeChannelStatsNCHW(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
+
+  template <typename T>
+  bool
+  ComputeChannelStatsNHWC(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
 
-  Tensor sumScratch_;
-  Tensor sumsqScratch_;
+  const StorageOrder order_;
 };
 
 } // namespace caffe2
 
-#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
index f1dadde..d793b5f 100644 (file)
@@ -1,42 +1,85 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from __future__ import unicode_literals
 
 from caffe2.python import core
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
+
 from hypothesis import assume, given
 import hypothesis.strategies as st
 import numpy as np
+
 import unittest
 
 
-class TestChannelStats(serial.SerializedTestCase):
+class TestChannelStatsOp(serial.SerializedTestCase):
+    def channel_stats_nchw_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[1]
+        X = X.reshape(N, C, -1)
+        sum1 = np.sum(X, axis=(0, 2), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 2), keepdims=False)
+        return (sum1, sum2)
+
+    def channel_stats_nhwc_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[-1]
+        X = X.reshape(N, -1, C)
+        sum1 = np.sum(X, axis=(0, 1), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 1), keepdims=False)
+        return (sum1, sum2)
+
     @serial.given(
-        size=st.integers(7, 10),
-        inputChannels=st.integers(1, 10),
-        batchSize=st.integers(1, 3),
-        **hu.gcs
-    )
-    def testChannelStats(self, size, inputChannels, batchSize, gc, dc):
+        N=st.integers(1, 5), C=st.integers(1, 10), H=st.integers(1, 12),
+        W=st.integers(1, 12), order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_2d(self, N, C, H, W, order, gc, dc):
+        op = core.CreateOperator(
+            "ChannelStats",
+            ["X"],
+            ["sum", "sumsq"],
+            order=order,
+        )
+
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
 
+        X = np.random.randn(N, C, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 1])
+
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
+
+    @serial.given(
+        N=st.integers(1, 5), C=st.integers(1, 10), D=st.integers(1, 6),
+        H=st.integers(1, 6), W=st.integers(1, 6),
+        order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_3d(self, N, C, D, H, W, order, gc, dc):
         op = core.CreateOperator(
             "ChannelStats",
             ["X"],
             ["sum", "sumsq"],
+            order=order,
         )
 
-        def referenceChannelStatsTest(X):
-            sums = np.sum(X, axis=(0, 2, 3), keepdims=False)
-            sumsq = np.zeros(inputChannels)
-            sumsq = np.sum(X**2, axis=(0, 2, 3), keepdims=False)
-            return sums, sumsq
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
 
-        X = np.random.rand(batchSize, inputChannels, size, size)\
-                .astype(np.float32) - 0.5
-        self.assertReferenceChecks(gc, op, [X], referenceChannelStatsTest)
+        X = np.random.randn(N, C, D, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 4, 1])
 
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
 
 if __name__ == "__main__":
     unittest.main()
index febf051..5507cd7 100644 (file)
@@ -10,6 +10,8 @@ from hypothesis import given
 import hypothesis.strategies as st
 import numpy as np
 
+import unittest
+
 
 class TestGroupNormOp(serial.SerializedTestCase):
     def group_norm_nchw_ref(self, X, gamma, beta, group, epsilon):
@@ -144,3 +146,7 @@ class TestGroupNormOp(serial.SerializedTestCase):
         inputs = [X, gamma, beta]
         for i in range(len(inputs)):
             self.assertGradientChecks(gc, op, inputs, i, [0])
+
+
+if __name__ == "__main__":
+    unittest.main()
index ae31d43..4db6c37 100644 (file)
@@ -1,11 +1,11 @@
 # Serialized Test Coverage Report
 This is an automatically generated file. Please see `caffe2/python/serialized_test/README.md` for details. In the case of merge conflicts, please rebase and regenerate.
 ## Summary
-Serialized tests have covered 217/684 (31.7%) operators
+Serialized tests have covered 219/688 (31.8%) operators
 
 ## Not covered operators
 <details>
-<summary>There are 467 not covered operators</summary>
+<summary>There are 469 not covered operators</summary>
 
 * APMeter
 * ATen
@@ -17,6 +17,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * Adam
 * Add
 * AddGradient
+* AdjustBatch
 * Alias
 * Allgather
 * Allreduce
@@ -96,6 +97,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * CubeGradient
 * DBExists
 * DataCouple
+* DenseVectorToIdList
 * DepthConcat
 * DepthSplit
 * DequeueBlobs
@@ -478,7 +480,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 
 ## Covered operators
 <details>
-<summary>There are 217 covered operators</summary>
+<summary>There are 219 covered operators</summary>
 
 * Acos
 * AcosGradient
@@ -543,6 +545,8 @@ Serialized tests have covered 217/684 (31.7%) operators
 * ElementwiseLinearGradient
 * Elu
 * EluGradient
+* Erf
+* ErfGradient
 * Expand
 * ExpandGradient
 * FC
@@ -702,7 +706,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 ## Excluded from coverage statistics
 ### Schemaless operators
 <details>
-<summary>There are 21 schemaless operators</summary>
+<summary>There are 22 schemaless operators</summary>
 
 * C10Add_DontUseThisOpYet
 * C10AveragedLoss_DontUseThisOpYet
@@ -718,6 +722,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * C10GivenTensorFill_DontUseThisOpYet
 * C10GivenTensorInt64Fill_DontUseThisOpYet
 * C10GivenTensorIntFill_DontUseThisOpYet
+* C10LayerNorm_DontUseThisOpYet
 * C10Mul_DontUseThisOpYet
 * C10Relu_DontUseThisOpYet
 * C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet
diff --git a/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip
new file mode 100644 (file)
index 0000000..9e19369
Binary files /dev/null and b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip differ
diff --git a/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip
new file mode 100644 (file)
index 0000000..a8a237a
Binary files /dev/null and b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip differ