Add AdjustBatch Op (#16676)
authorYinghai Lu <yinghai@fb.com>
Thu, 7 Feb 2019 03:12:32 +0000 (19:12 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Feb 2019 03:15:41 +0000 (19:15 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16676

This op is used for changing batch size (first dimension) of the tensor.

Reviewed By: bertmaher, ipiszy

Differential Revision: D13929200

fbshipit-source-id: 4f2c3faec072d468be8301bf00c80d33adb3b5b3

caffe2/operators/adjust_batch_op.cc [new file with mode: 0644]
caffe2/operators/adjust_batch_op.h [new file with mode: 0644]
caffe2/python/operator_test/adjust_batch_op_test.py [new file with mode: 0644]

diff --git a/caffe2/operators/adjust_batch_op.cc b/caffe2/operators/adjust_batch_op.cc
new file mode 100644 (file)
index 0000000..1e29f5c
--- /dev/null
@@ -0,0 +1,20 @@
+#include "caffe2/operators/adjust_batch_op.h"
+
+namespace caffe2 {
+REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);
+OPERATOR_SCHEMA(AdjustBatch)
+    .NumInputs(1, 2)
+    .NumOutputs(1, 2)
+    .Input(0, "Input", "Input data")
+    .Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
+    .Output(0, "Output", "Data with Adjusted batch size")
+    .Output(1, "RealBatchSizeOut", "[Optional] Real batah size")
+    .Arg("max_batch_size", "(*int*): max batch size")
+    .SetDoc(R"DOC(
+Adjust the batch size of `input` tensor. When we only have 1 input, it will adjust the batch size according to `max_batch_size` argument. In this case, in addition, if it has two outputs, it will record the input batch size and record it to the second output. When we have 2 inputs, it expects the seocnd input contains the batch size to adjust to, and will truncate the input data accordingly.
+
+Github Links:
+- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/adjust_batch_op.cc
+
+  )DOC");
+} // namespace caffe2
diff --git a/caffe2/operators/adjust_batch_op.h b/caffe2/operators/adjust_batch_op.h
new file mode 100644 (file)
index 0000000..f8fbf0a
--- /dev/null
@@ -0,0 +1,75 @@
+#pragma once
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <class Context>
+class AdjustBatchOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        max_batch_size_(
+            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}
+
+  bool RunOnDevice() override {
+    auto& input = Input(0);
+    vector<int64_t> output_dims(input.sizes().vec());
+    CAFFE_ENFORCE(!output_dims.empty());
+    if (InputSize() > 1) {
+      // TODO: if we have a second input and we have max_batch_size set, check
+      // the batch size of the two inputs for consistency
+      auto& batch_size = Input(1);
+      int64_t real_batch_size = *batch_size.template data<int64_t>();
+      int64_t max_batch_size = output_dims[0];
+      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
+      output_dims[0] = real_batch_size;
+      auto* output = Output(0, output_dims, input.dtype());
+      this->context_.template CopyItems<Context, Context>(
+          input.dtype(),
+          input.numel() * real_batch_size / max_batch_size,
+          input.raw_data(),
+          output->raw_mutable_data(input.dtype()));
+    } else {
+      // Pad to max batch size
+      CAFFE_ENFORCE_GT(
+          max_batch_size_,
+          0,
+          "max_batch_size should be larger than 0. Got ",
+          max_batch_size_);
+
+      // TODO: ideally we can support the case when input batch is larger than
+      // the max_batch_size, as we can just pad to the multiple of
+      // max_batch_size.
+      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
+
+      int64_t real_batch_size = output_dims[0];
+      output_dims[0] = max_batch_size_;
+      auto* output = Output(0, output_dims, input.dtype());
+      math::Set(
+          output->nbytes(),
+          static_cast<char>(0),
+          static_cast<char*>(output->raw_data()),
+          &context_);
+      this->context_.template CopyItems<Context, Context>(
+          input.dtype(),
+          input.numel(),
+          input.raw_data(),
+          output->raw_mutable_data(input.dtype()));
+
+      if (OutputSize() > 1) {
+        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
+        real_batch_tensor->template mutable_data<int64_t>()[0] =
+            real_batch_size;
+      }
+    }
+
+    return true;
+  }
+
+ private:
+  int64_t max_batch_size_;
+};
+} // namespace caffe2
diff --git a/caffe2/python/operator_test/adjust_batch_op_test.py b/caffe2/python/operator_test/adjust_batch_op_test.py
new file mode 100644 (file)
index 0000000..f4dffd3
--- /dev/null
@@ -0,0 +1,75 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core, workspace
+from hypothesis import given, assume
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+
+import unittest
+import os
+
+
+class TestAdjustBatchOp(hu.HypothesisTestCase):
+    @given(d=st.integers(1, 4), n=st.integers(1, 20),
+           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
+    def test_pad(self, d, n, gc, dc, seed):
+        for dtype in [np.float32, np.int8, np.int64]:
+            np.random.seed(seed)
+            dims = [n] * d
+            X = np.random.rand(*dims).astype(dtype)
+            max_batch_size = n + 8
+
+            def ref_op(X):
+                shape = list(X.shape)
+                out = np.zeros((1), dtype=np.int64)
+                out[0] = shape[0]
+                shape[0] = max_batch_size
+                Y = np.zeros(shape, dtype=dtype)
+                Y[:n] = X
+                return [Y, out]
+
+            op = core.CreateOperator(
+                "AdjustBatch",
+                ["X"],
+                ["Y", "RealBatch"],
+                max_batch_size=max_batch_size,
+            )
+
+            self.assertReferenceChecks(
+                device_option=gc,
+                op=op,
+                inputs=[X],
+                reference=ref_op,
+            )
+
+    @given(d=st.integers(1, 4), n=st.integers(8, 20),
+           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
+    def test_truncate(self, d, n, gc, dc, seed):
+        for dtype in [np.float32, np.int8, np.int64]:
+            np.random.seed(seed)
+            dims = [n] * d
+            X = np.random.rand(*dims).astype(dtype)
+            real_batch_size = n - 8
+            R = np.zeros((1), dtype=np.int64)
+            R[0] = real_batch_size
+
+            def ref_op(X, R):
+                r = R[0]
+                return [X[:r]]
+
+            op = core.CreateOperator(
+                "AdjustBatch",
+                ["X", "RealBatch"],
+                ["Y"],
+            )
+
+            self.assertReferenceChecks(
+                device_option=gc,
+                op=op,
+                inputs=[X, R],
+                reference=ref_op,
+            )