Add the normalize transform to the core library (#15891)
author Peter Goldsborough <psag@fb.com>
Sat, 12 Jan 2019 03:45:40 +0000 (19:45 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 12 Jan 2019 03:50:18 +0000 (19:50 -0800)
Summary:
Adds the `Normalize` transform to the core C++ frontend library.

ebetica ezyang soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15891

Differential Revision: D13642167

Pulled By: goldsborough

fbshipit-source-id: 573428e626d6106cf2aadf3dc2e2aecb9a85efc3

test/cpp/api/dataloader.cpp
torch/csrc/api/include/torch/data/transforms/tensor.h

index 461dfe5..dbb6349 100644 (file)
@@ -450,6 +450,80 @@ TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
   ASSERT_EQ(batch[1].target, "2");
 }
 
+struct DummyTensorDataset // Test fixture: example `index` is an all-ones tensor with `index` channels.
+    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
+  Example<torch::Tensor, int> get(size_t index) override { // `index` doubles as the channel count
+    const auto channels = static_cast<int64_t>(index);
+    torch::Tensor tensor = // index 0 yields a {4, 4} tensor with no explicit channel dimension
+        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
+    return {tensor, static_cast<int>(channels)}; // target records the channel count
+  }
+
+  torch::optional<size_t> size() const override {
+    return 100; // fixed size reported by this fixture
+  }
+};
+
+TEST(DataTest, NormalizeTransform) { // Checks Normalize on all-ones inputs with 0 to 3 channels.
+  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));
+
+  // Works for zero (one implicit) channels: index 0 yields a plain {4, 4} tensor
+  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
+  ASSERT_EQ(output.size(), 1);
+  // Every input element is 1, so the expected value is (1 - 0.5) / 0.1 = 5
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
+      << output[0].data;
+
+  // Works for one explicit channel; same moments, so again (1 - 0.5) / 0.1 = 5
+  output = dataset.get_batch(1);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 1);
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
+      << output[0].data;
+
+  // Works for two channels with different moments (one mean/stddev pair per channel)
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
+  output = dataset.get_batch(2);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 2);
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * 5)) // channel 0: (1 - 0.5) / 0.1 = 5
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * -2.5)) // channel 1: (1 - 1.5) / 0.2 = -2.5
+      << output[0].data;
+
+  // Works for three channels with one moment value: (1 - 1.5) / 0.2 = -2.5 everywhere
+  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
+  output = dataset.get_batch(3);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 3);
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
+      << output[0].data;
+
+  // Works for three channels with different moments (one mean/stddev pair per channel)
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
+  output = dataset.get_batch(3);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 3);
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * 5)) // channel 0: (1 - 0.5) / 0.1 = 5
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
+                  .allclose(torch::ones({1, 4, 4}) * -2.5)) // channel 1: (1 - 1.5) / 0.2 = -2.5
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/2)
+                  .allclose(torch::ones({1, 4, 4}) * 12.5)) // channel 2: (1 - (-1.5)) / 0.2 = 12.5
+      << output[0].data;
+}
+
 struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
   UnCopyableDataset() = default;
 
index c1fed20..cb49ee7 100644 (file)
@@ -50,6 +50,28 @@ class TensorLambda : public TensorTransform<Target> {
  private:
   FunctionType function_;
 };
+
+/// Normalizes input tensors by subtracting the supplied mean and dividing by
+/// the given standard deviation, i.e. computes `(input - mean) / stddev`.
+template <typename Target = Tensor>
+struct Normalize : public TensorTransform<Target> {
+  /// Constructs a `Normalize` transform. Each list of N values is stored as a
+  /// float32 tensor unsqueezed to shape `{N, 1, 1}`, so it must be
+  /// broadcastable over the input tensors (a single scalar gives `{1, 1, 1}`).
+  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
+      : mean(torch::tensor(mean, torch::kFloat32)
+                 .unsqueeze(/*dim=*/1)
+                 .unsqueeze(/*dim=*/2)),
+        stddev(torch::tensor(stddev, torch::kFloat32)
+                   .unsqueeze(/*dim=*/1)
+                   .unsqueeze(/*dim=*/2)) {}
+
+  torch::Tensor operator()(Tensor input) { // NOTE(review): presumably overrides TensorTransform's operator(); consider `override` - confirm
+    return input.sub(mean).div(stddev); // no guard against zero entries in stddev
+  }
+
+  torch::Tensor mean, stddev; // moments kept in broadcast-ready {N, 1, 1} form
+};
 } // namespace transforms
 } // namespace data
 } // namespace torch