Fix SoftmaxOps (#16049)
authorJerry Zhang <jerryzh@fb.com>
Fri, 18 Jan 2019 20:14:34 +0000 (12:14 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 20:30:59 +0000 (12:30 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16049

We might see the pattern
```
if (scale_.numel() != N) {
   scale_->Resize(N);
   // set initial value for scale_
}

// In class:
Tensor scale_{CPU};
```
before in the code, where `scale_` is a member variable of Type `caffe2::Tensor`
This pattern actually serves two purposes, if `scale_` is partially initialized with device type but not size, this call will
initialize Tensor with the correct size, or if `scale_` is already initialized with size, it will check whether the size
matches a runtime value `N` and if not it will Resize. To rewrite this we'll do the following:
```
if (!scale_.defined() || scale_.numel() != N) {
  ReinitializeTensor(&scale_, {N}, at::dtype<float>().device(CPU));
  // set initial value for scale_
}

```
There are some variants, if `scale_` is resized to a constant size, we can call `ReinitializeTensor` instead
```
if (scale_.numel() != 1) {
  scale_->Resize(1);
}
```
-->
```
ReinitializeTensor(&scale_, {1}, at::dtype<float>().device(CPU));
```

Normal Resize will be refactored directly into ReinitializeTensor:
```
scale_->Resize(N);
```
-->
```
ReinitializeTensor(&scale_, {N}, at::dtype<float>().device(CPU));
```

Reviewed By: dzhulgakov

Differential Revision: D13667883

fbshipit-source-id: 2c7cb61544b72765b594011b99150eb5a1b50836

caffe2/operators/softmax_op.cc
caffe2/operators/softmax_ops.cu
caffe2/operators/softmax_with_loss_op.cc
caffe2/operators/softmax_with_loss_op.h
caffe2/operators/spatial_softmax_with_loss_op.cc
caffe2/operators/spatial_softmax_with_loss_op.h

index f925c35..2a021ab 100644 (file)
@@ -13,19 +13,26 @@ bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
   const int D = X.size_from_dim(canonical_axis);
   auto* Y = Output(0, X.sizes(), at::dtype<float>());
   float* Ydata = Y->template mutable_data<float>();
-  // ReinitializeTensor itself has the effect of caching, so there is no need to check for numel of Tensor
   // First, get scales
-  ReinitializeTensor(
-      &scale_, {N}, at::dtype<float>().device(CPU));
+  if (!scale_.defined()) {
+    scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
+  } else if (scale_.numel() != N) {
+    scale_.Resize(N);
+  }
 
-  ReinitializeTensor(
-      &rowmax_, {N}, at::dtype<float>().device(CPU));
+  if (!rowmax_.defined()) {
+    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
+  } else if (rowmax_.numel() != N) {
+    rowmax_.Resize(N);
+  }
 
-  ReinitializeTensor(
-      &sum_multiplier_,
-      {D},
-      at::dtype<float>().device(CPU));
-  math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  }
 
   SoftmaxCPU(
       context_,
@@ -50,18 +57,20 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
   const int64_t N = Y.size_to_dim(canonical_axis);
   const int64_t D = Y.size_from_dim(canonical_axis);
   // First, get scales
-  if (scale_.numel() != N) {
-    ReinitializeTensor(
-        &scale_, {N}, at::dtype<float>().device(CPU));
+  if (!scale_.defined()) {
+    scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
+  } else if (scale_.numel() != N) {
+    scale_.Resize(N);
   }
-  if (sum_multiplier_.numel() != D) {
-    ReinitializeTensor(
-        &sum_multiplier_,
-        {D},
-        at::dtype<float>().device(CPU));
-    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(),
-                                 &context_);
+
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }
+
   auto* dX = Output(0, Y.sizes(), at::dtype<float>());
   const float* Ydata = Y.data<float>();
   const float* dYdata = dY.data<float>();
index 0876ad8..d1dd96d 100644 (file)
@@ -311,17 +311,26 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
 
   auto* avg_loss =
       Output(1, vector<int64_t>(), at::dtype<float>()); // Average loss
-  if (losses_.size() != N) {
-    ReinitializeTensor(&losses_, {N}, at::dtype<float>().device(CUDA));
+  if (!losses_.defined()) {
+    losses_ = caffe2::empty({N}, at::dtype<float>().device(CUDA));
+  } else if (losses_.numel() != N) {
+    losses_.Resize(N);
   }
-  if (rowmax_.size() != N) {
-    ReinitializeTensor(&rowmax_, {N}, at::dtype<float>().device(CUDA));
+
+  if (!rowmax_.defined()) {
+    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CUDA));
+  } else if (rowmax_.numel() != N) {
+    rowmax_.Resize(N);
   }
-  if (sum_multiplier_.size() != D) {
-    ReinitializeTensor(&sum_multiplier_, {D}, at::dtype<float>().device(CUDA));
-    math::Set<float, CUDAContext>(
-        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CUDA));
+    math::Set<float, CUDAContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
+    math::Set<float, CUDAContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }
+
   Softmax(
       N,
       D,
@@ -379,7 +388,7 @@ bool SoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
   // Sum of all losses
   float* avg_loss_data = avg_loss->template mutable_data<float>();
   math::Sum<float, CUDAContext>(
-      losses_.size(), losses_.data<float>(), avg_loss_data, &context_, &scratch_);
+      losses_.numel(), losses_.data<float>(), avg_loss_data, &context_, &scratch_);
   // Average of input batch size
   if (total_weight > 0) {
     math::Scale<float, float, CUDAContext>(
@@ -409,11 +418,16 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
 
   int H = X.dim32(2);
   int W = X.dim32(3);
-  if (losses_.size() != N * W * H) {
-    ReinitializeTensor(&losses_, {N * W * H}, at::dtype<float>().device(CUDA));
+  if (!losses_.defined()) {
+    losses_ = caffe2::empty({N * W * H}, at::dtype<float>().device(CUDA));
+  } else if (losses_.numel() != N * W * H) {
+    losses_.Resize(N * W * H);
   }
-  if (weights_.size() != N * W * H) {
-    ReinitializeTensor(&weights_, {N * W * H}, at::dtype<float>().device(CUDA));
+
+  if (!weights_.defined()) {
+    weights_ = caffe2::empty({N * W * H}, at::dtype<float>().device(CUDA));
+  } else if (weights_.numel() != N * W * H) {
+    weights_.Resize(N * W * H);
   }
 
   const float* Xdata = X.data<float>();
@@ -454,7 +468,7 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
   // Somewhat awkward scalar passing from device to host
   float h_total_weight;
   math::Sum<float, CUDAContext>(
-      weights_.size(),
+      weights_.numel(),
       weights_.data<float>(),
       total_weight_ptr_.mutable_data<float>(),
       &context_,
@@ -467,7 +481,7 @@ bool SpatialSoftmaxWithLossOp<float, CUDAContext>::RunOnDevice() {
       context_.cuda_stream()));
 
   math::Sum<float, CUDAContext>(
-      losses_.size(), losses_.data<float>(), avg_loss_data, &context_, &scratch_);
+      losses_.numel(), losses_.data<float>(), avg_loss_data, &context_, &scratch_);
 
   // Final scaling
   if (h_total_weight > 0) {
@@ -624,8 +638,10 @@ bool SpatialSoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
   int H = X.dim32(2);
   int W = X.dim32(3);
   dX->ResizeLike(X);
-  if (weights_.size() != N * W * H) {
-    ReinitializeTensor(&weights_, {N * W * H}, at::dtype<float>().device(CUDA));
+  if (!weights_.defined()) {
+    weights_ = caffe2::empty({N * W * H}, at::dtype<float>().device(CUDA));
+  } else if (weights_.numel() != N * W * H) {
+    weights_.Resize(N * W * H);
   }
 
   const float* Pdata = P.data<float>();
@@ -649,7 +665,7 @@ bool SpatialSoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
       N, D, W, H, label_data, weights, dX_data, weights_.mutable_data<float>());
 
   math::Sum<float, CUDAContext>(
-      weights_.size(),
+      weights_.numel(),
       weights_.data<float>(),
       total_weight_ptr_.mutable_data<float>(),
       &context_,
@@ -696,17 +712,27 @@ bool SoftmaxOp<float, CUDAContext>::RunOnDevice() {
   if (N == 0) {
     return true;
   }
-  if (sum_multiplier_.size() != D) {
-    ReinitializeTensor(&sum_multiplier_, {D}, at::dtype<float>().device(CUDA));
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CUDA));
+    math::Set<float, CUDAContext>(
+        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
     math::Set<float, CUDAContext>(
         D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }
-  if (scale_.size() != N) {
-    ReinitializeTensor(&scale_, {N}, at::dtype<float>().device(CUDA));
+  if (!scale_.defined()) {
+    scale_ = caffe2::empty({N}, at::dtype<float>().device(CUDA));
+  } else if (scale_.numel() != N) {
+    scale_.Resize(N);
   }
-  if (rowmax_.size() != N) {
-    ReinitializeTensor(&rowmax_, {N}, at::dtype<float>().device(CUDA));
+
+  if (!rowmax_.defined()) {
+    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CUDA));
+  } else if (rowmax_.numel() != N) {
+    rowmax_.Resize(N);
   }
+
   Softmax(
       N,
       D,
index 81a6c7e..36a7740 100644 (file)
@@ -178,19 +178,25 @@ bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
     }
   }
 
-  if (sum_multiplier_.numel() != D) {
-    ReinitializeTensor(
-        &sum_multiplier_,
-        {D},
-        at::dtype<float>().device(CPU));
-    math::Set<float, CPUContext>(
-        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
+    math::Set<float, CPUContext>(D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }
 
-  ReinitializeTensor(
-      &rowmax_, {N}, at::dtype<float>().device(CPU));
-  ReinitializeTensor(
-      &losses_, {N}, at::dtype<float>().device(CPU));
+  if (!losses_.defined()) {
+    losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
+  } else if (losses_.numel() != N) {
+    losses_.Resize(N);
+  }
+
+  if (!rowmax_.defined()) {
+    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
+  } else if (rowmax_.numel() != N) {
+    rowmax_.Resize(N);
+  }
 
   SoftmaxCPU(
       context_,
index ce47072..d72e305 100644 (file)
@@ -34,9 +34,10 @@ class SoftmaxWithLossOp final : public Operator<Context> {
 
   Tensor losses_; // Per example loss
   Tensor rowmax_; // per example row max
-  Tensor weights_{Context::GetDeviceType()}; // unignored weights
+  Tensor weights_; // unignored weights
   Tensor sum_multiplier_; // Vector of ones for summing via dot prod
   Tensor total_weight_ptr_;
+  // passed to a function
   Tensor scratch_{Context::GetDeviceType()};
 };
 
@@ -62,8 +63,9 @@ class SoftmaxWithLossGradientOp final : public Operator<Context> {
  protected:
   float scale_;
   int label_prob_mode_;
+  // not used?
   Tensor sum_multiplier_{Context::GetDeviceType()};
-  Tensor weights_{Context::GetDeviceType()}; // unignored weights
+  Tensor weights_; // unignored weights
   Tensor total_weight_ptr_;
   StorageOrder order_;
   bool only_loss_;
index 9c650bf..09464b0 100644 (file)
@@ -72,11 +72,12 @@ bool SpatialSoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
   auto* P =
       Output(0, X.sizes(), at::dtype<float>()); // Probabilities from softmax
 
-  if (sum_multiplier_.numel() != D) {
-    ReinitializeTensor(
-        &sum_multiplier_,
-        {D},
-        at::dtype<float>().device(CPU));
+  if (!sum_multiplier_.defined()) {
+    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
+    math::Set<float, CPUContext>(
+        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
+  } else if (sum_multiplier_.numel() != D) {
+    sum_multiplier_.Resize(D);
     math::Set<float, CPUContext>(
         D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
   }
index c728a45..97d3818 100644 (file)
@@ -29,7 +29,7 @@ class SpatialSoftmaxWithLossOp final : public Operator<Context> {
   StorageOrder order_;
 
   Tensor losses_; // Per example loss
-  Tensor rowmax_{Context::GetDeviceType()}; // per example row max
+  Tensor rowmax_; // per example row max
   Tensor weights_; // unignored weights
   Tensor sum_multiplier_; // Vector of ones for summing via dot prod
   Tensor total_weight_ptr_;
@@ -55,7 +55,7 @@ class SpatialSoftmaxWithLossGradientOp final : public Operator<Context> {
 
  protected:
   float scale_;
-  Tensor sum_multiplier_{Context::GetDeviceType()};
+  Tensor sum_multiplier_;
   Tensor weights_; // unignored weights
   Tensor total_weight_ptr_;
   StorageOrder order_;