From a4c1aa4bc542c7ff6e600b67e9a0aeb233718514 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Fri, 11 Jan 2019 19:45:40 -0800
Subject: [PATCH] Add the normalize transform to the core library (#15891)

Summary:
Adds the `Normalize` transform to the core C++ frontend library.

ebetica ezyang soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15891

Differential Revision: D13642167

Pulled By: goldsborough

fbshipit-source-id: 573428e626d6106cf2aadf3dc2e2aecb9a85efc3
---
 test/cpp/api/dataloader.cpp                        | 74 ++++++++++++++++++++++
 .../api/include/torch/data/transforms/tensor.h     | 22 +++++++
 2 files changed, 96 insertions(+)

diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp
index 461dfe5..dbb6349 100644
--- a/test/cpp/api/dataloader.cpp
+++ b/test/cpp/api/dataloader.cpp
@@ -450,6 +450,80 @@ TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
   ASSERT_EQ(batch[1].target, "2");
 }
 
+struct DummyTensorDataset
+    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
+  Example<torch::Tensor, int> get(size_t index) override {
+    const auto channels = static_cast<int64_t>(index);
+    torch::Tensor tensor =
+        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
+    return {tensor, static_cast<int>(channels)};
+  }
+
+  torch::optional<size_t> size() const override {
+    return 100;
+  }
+};
+
+TEST(DataTest, NormalizeTransform) {
+  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));
+
+  // Works for zero (one implicit) channels
+  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
+  ASSERT_EQ(output.size(), 1);
+  // (1 - 0.5) / 0.1 = 5
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
+      << output[0].data;
+
+  // Works for one explicit channel
+  output = dataset.get_batch(1);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 1);
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
+      << output[0].data;
+
+  // Works for two channels with different moments
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
+  output = dataset.get_batch(2);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 2);
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * 5))
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * -2.5))
+      << output[0].data;
+
+  // Works for three channels with one moment value
+  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
+  output = dataset.get_batch(3);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 3);
+  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
+      << output[0].data;
+
+  // Works for three channels with different moments
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
+  output = dataset.get_batch(3);
+  ASSERT_EQ(output.size(), 1);
+  ASSERT_EQ(output[0].data.size(0), 3);
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+                  .allclose(torch::ones({1, 4, 4}) * 5))
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
+                  .allclose(torch::ones({1, 4, 4}) * -2.5))
+      << output[0].data;
+  ASSERT_TRUE(output[0]
+                  .data.slice(/*dim=*/0, /*start=*/2)
+                  .allclose(torch::ones({1, 4, 4}) * 12.5))
+      << output[0].data;
+}
+
 struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
   UnCopyableDataset() = default;
 
diff --git a/torch/csrc/api/include/torch/data/transforms/tensor.h b/torch/csrc/api/include/torch/data/transforms/tensor.h
index c1fed20..cb49ee7 100644
--- a/torch/csrc/api/include/torch/data/transforms/tensor.h
+++ b/torch/csrc/api/include/torch/data/transforms/tensor.h
@@ -50,6 +50,28 @@ class TensorLambda : public TensorTransform<Target> {
  private:
   FunctionType function_;
 };
+
+/// Normalizes input tensors by subtracting the supplied mean and dividing by
+/// the given standard deviation.
+template <typename Target = Tensor>
+struct Normalize : public TensorTransform<Target> {
+  /// Constructs a `Normalize` transform. The mean and standard deviation can be
+  /// anything that is broadcastable over the input tensors (like single
+  /// scalars).
+  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
+      : mean(torch::tensor(mean, torch::kFloat32)
+                 .unsqueeze(/*dim=*/1)
+                 .unsqueeze(/*dim=*/2)),
+        stddev(torch::tensor(stddev, torch::kFloat32)
+                   .unsqueeze(/*dim=*/1)
+                   .unsqueeze(/*dim=*/2)) {}
+
+  torch::Tensor operator()(Tensor input) {
+    return input.sub(mean).div(stddev);
+  }
+
+  torch::Tensor mean, stddev;
+};
 } // namespace transforms
 } // namespace data
 } // namespace torch
-- 
2.7.4
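
Usage sketch (for illustration, not part of the diff above): the snippet shows how the new `Normalize` transform is intended to compose with the rest of the C++ data API. It assumes the `datasets::MNIST` dataset, the `Stack<>` collation transform, and `torch::data::make_data_loader` already available in the frontend; the path "./mnist" and the moments 0.1307/0.3081 are illustrative values, not something prescribed by this patch.

// Illustrative only: compose Normalize with an existing dataset and loader.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto dataset =
      torch::data::datasets::MNIST("./mnist")
          // Subtract the mean and divide by the stddev of each channel.
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          // Collate individual examples into one batched data/target pair.
          .map(torch::data::transforms::Stack<>());
  auto loader =
      torch::data::make_data_loader(std::move(dataset), /*batch_size=*/64);

  for (auto& batch : *loader) {
    // batch.data is an [N, 1, 28, 28] tensor of normalized images,
    // batch.target an [N] tensor of labels.
    std::cout << "first batch: " << batch.data.size(0) << " examples\n";
    break; // one batch is enough for this sketch
  }
}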