ASSERT_EQ(batch[1].target, "2");
}
+/// Test-only dataset of all-ones tensors whose leading dimension is driven by
+/// the requested index, with the index doubling as the integer target.
+struct DummyTensorDataset
+    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
+  /// Returns an example for `index`:
+  ///   - a `{4, 4}` tensor of ones when the index (as `int64_t`) is not
+  ///     positive, i.e. there is no explicit leading dimension;
+  ///   - an `{index, 4, 4}` tensor of ones otherwise.
+  /// The target is the index converted to `int`.
+  Example<torch::Tensor, int> get(size_t index) override {
+    const auto leading_dim = static_cast<int64_t>(index);
+    const auto target = static_cast<int>(leading_dim);
+    if (leading_dim <= 0) {
+      return {torch::ones({4, 4}), target};
+    }
+    return {torch::ones({leading_dim, 4, 4}), target};
+  }
+
+  /// Reports a fixed size of 100 examples.
+  torch::optional<size_t> size() const override {
+    return 100;
+  }
+};
+
+// Checks `transforms::Normalize` against all-ones inputs from
+// `DummyTensorDataset`, so every expected value is simply (1 - mean) / stddev.
+TEST(DataTest, NormalizeTransform) {
+  using Batch = std::vector<Example<torch::Tensor, int>>;
+
+  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));
+
+  // Case 1: no explicit channel dimension (a single implicit channel).
+  {
+    Batch batch = dataset.get_batch(0);
+    ASSERT_EQ(batch.size(), 1);
+    const torch::Tensor& normalized = batch[0].data;
+    // (1 - 0.5) / 0.1 = 5
+    ASSERT_TRUE(normalized.allclose(torch::ones({4, 4}) * 5)) << normalized;
+  }
+
+  // Case 2: one explicit channel, same transform as above.
+  {
+    Batch batch = dataset.get_batch(1);
+    ASSERT_EQ(batch.size(), 1);
+    const torch::Tensor& normalized = batch[0].data;
+    ASSERT_EQ(normalized.size(0), 1);
+    ASSERT_TRUE(normalized.allclose(torch::ones({1, 4, 4}) * 5)) << normalized;
+  }
+
+  // Case 3: two channels, each with its own mean and standard deviation.
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
+  {
+    Batch batch = dataset.get_batch(2);
+    ASSERT_EQ(batch.size(), 1);
+    const torch::Tensor& normalized = batch[0].data;
+    ASSERT_EQ(normalized.size(0), 2);
+
+    // Channel 0: (1 - 0.5) / 0.1 = 5
+    const torch::Tensor first_channel =
+        normalized.slice(/*dim=*/0, /*start=*/0, /*end=*/1);
+    ASSERT_TRUE(first_channel.allclose(torch::ones({1, 4, 4}) * 5))
+        << normalized;
+
+    // Channel 1: (1 - 1.5) / 0.2 = -2.5
+    const torch::Tensor second_channel =
+        normalized.slice(/*dim=*/0, /*start=*/1);
+    ASSERT_TRUE(second_channel.allclose(torch::ones({1, 4, 4}) * -2.5))
+        << normalized;
+  }
+
+  // Case 4: three channels sharing a single mean / standard deviation.
+  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
+  {
+    Batch batch = dataset.get_batch(3);
+    ASSERT_EQ(batch.size(), 1);
+    const torch::Tensor& normalized = batch[0].data;
+    ASSERT_EQ(normalized.size(0), 3);
+    // (1 - 1.5) / 0.2 = -2.5 on every channel
+    ASSERT_TRUE(normalized.allclose(torch::ones({3, 4, 4}) * -2.5))
+        << normalized;
+  }
+
+  // Case 5: three channels, each with its own mean and standard deviation.
+  dataset = DummyTensorDataset().map(
+      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
+  {
+    Batch batch = dataset.get_batch(3);
+    ASSERT_EQ(batch.size(), 1);
+    const torch::Tensor& normalized = batch[0].data;
+    ASSERT_EQ(normalized.size(0), 3);
+
+    // Channel 0: (1 - 0.5) / 0.1 = 5
+    const torch::Tensor first_channel =
+        normalized.slice(/*dim=*/0, /*start=*/0, /*end=*/1);
+    ASSERT_TRUE(first_channel.allclose(torch::ones({1, 4, 4}) * 5))
+        << normalized;
+
+    // Channel 1: (1 - 1.5) / 0.2 = -2.5
+    const torch::Tensor second_channel =
+        normalized.slice(/*dim=*/0, /*start=*/1, /*end=*/2);
+    ASSERT_TRUE(second_channel.allclose(torch::ones({1, 4, 4}) * -2.5))
+        << normalized;
+
+    // Channel 2: (1 - (-1.5)) / 0.2 = 12.5
+    const torch::Tensor third_channel =
+        normalized.slice(/*dim=*/0, /*start=*/2);
+    ASSERT_TRUE(third_channel.allclose(torch::ones({1, 4, 4}) * 12.5))
+        << normalized;
+  }
+}
+
struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
UnCopyableDataset() = default;
private:
FunctionType function_;
};
+
+/// Normalizes input tensors by subtracting the supplied mean and dividing by
+/// the given standard deviation, i.e. computes `(input - mean) / stddev`.
+template <typename Target = Tensor>
+struct Normalize : public TensorTransform<Target> {
+  /// Constructs a `Normalize` transform. The mean and standard deviation can be
+  /// anything that is broadcastable over the input tensors (like single
+  /// scalars).
+  ///
+  /// Both arguments are converted to `float32` tensors and given two trailing
+  /// singleton dimensions (via `unsqueeze` at dims 1 and 2), so that one value
+  /// per entry lines up with the leading dimension of the input during
+  /// broadcasting.
+  ///
+  /// \param mean Value(s) subtracted from the input.
+  /// \param stddev Value(s) the centered input is divided by. No check for
+  /// zero entries is performed here.
+  Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
+      : mean(torch::tensor(mean, torch::kFloat32)
+                 .unsqueeze(/*dim=*/1)
+                 .unsqueeze(/*dim=*/2)),
+        stddev(torch::tensor(stddev, torch::kFloat32)
+                   .unsqueeze(/*dim=*/1)
+                   .unsqueeze(/*dim=*/2)) {}
+
+  /// Returns `input.sub(mean).div(stddev)`.
+  ///
+  /// Marked `override` so the compiler rejects any signature drift from the
+  /// tensor-to-tensor hook of `TensorTransform`, instead of silently
+  /// introducing an unrelated overload.
+  torch::Tensor operator()(Tensor input) override {
+    return input.sub(mean).div(stddev);
+  }
+
+  /// Broadcast-ready mean and standard deviation built by the constructor.
+  torch::Tensor mean, stddev;
+};
} // namespace transforms
} // namespace data
} // namespace torch