--- /dev/null
+#include "caffe2/operators/adjust_batch_op.h"
+
+namespace caffe2 {
+REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);
+OPERATOR_SCHEMA(AdjustBatch)
+    .NumInputs(1, 2)
+    .NumOutputs(1, 2)
+    .Input(0, "Input", "Input data")
+    .Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
+    .Output(0, "Output", "Data with adjusted batch size")
+    .Output(1, "RealBatchSizeOut", "[Optional] Real batch size")
+    .Arg("max_batch_size", "(*int*): max batch size")
+    .SetDoc(R"DOC(
+Adjust the batch size of the `input` tensor. With a single input, the operator
+pads the batch dimension up to the `max_batch_size` argument; if a second
+output is given, the original input batch size is recorded in it. With two
+inputs, the second input is expected to hold the target batch size, and the
+input data is truncated to that size.
+
+Github Links:
+- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/adjust_batch_op.cc
+
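+A minimal usage sketch with the caffe2 Python API (blob names are illustrative):
+
+  workspace.FeedBlob("X", np.random.rand(3, 5).astype(np.float32))
+  op = core.CreateOperator(
+      "AdjustBatch",
+      ["X"],
+      ["Y", "RealBatch"],
+      max_batch_size=8,
+  )
+  workspace.RunOperatorOnce(op)
+  # workspace.FetchBlob("Y") has shape (8, 5); rows 3..7 are zero-padded.
+  # workspace.FetchBlob("RealBatch") contains [3].
+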
+ )DOC");
+} // namespace caffe2
--- /dev/null
+#pragma once
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <class Context>
+class AdjustBatchOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  AdjustBatchOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        max_batch_size_(
+            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}
+
+  bool RunOnDevice() override {
+    auto& input = Input(0);
+    vector<int64_t> output_dims(input.sizes().vec());
+    CAFFE_ENFORCE(!output_dims.empty());
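+    // Two modes: with a second input, truncate to the real batch size given
+    // there; with a single input, zero-pad the batch dim up to max_batch_size_.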
+    if (InputSize() > 1) {
+      // TODO: if we have a second input and we have max_batch_size set, check
+      // the batch size of the two inputs for consistency
+      auto& batch_size = Input(1);
+      int64_t real_batch_size = *batch_size.template data<int64_t>();
+      int64_t max_batch_size = output_dims[0];
+      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
+      output_dims[0] = real_batch_size;
+      auto* output = Output(0, output_dims, input.dtype());
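+      // input.numel() / max_batch_size elements make up one batch row; copy
+      // only the first real_batch_size rows into the truncated output.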
+      this->context_.template CopyItems<Context, Context>(
+          input.dtype(),
+          input.numel() * real_batch_size / max_batch_size,
+          input.raw_data(),
+          output->raw_mutable_data(input.dtype()));
+    } else {
+      // Pad to max batch size
+      CAFFE_ENFORCE_GT(
+          max_batch_size_,
+          0,
+          "max_batch_size should be larger than 0. Got ",
+          max_batch_size_);
+
+      // TODO: ideally we can support the case when input batch is larger than
+      // the max_batch_size, as we can just pad to the multiple of
+      // max_batch_size.
+      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
+
+      int64_t real_batch_size = output_dims[0];
+      output_dims[0] = max_batch_size_;
+      auto* output = Output(0, output_dims, input.dtype());
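+      // Zero-fill the padded output before copying the real rows to the front.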
+      math::Set(
+          output->nbytes(),
+          static_cast<char>(0),
+          static_cast<char*>(output->raw_mutable_data(input.dtype())),
+          &context_);
+      this->context_.template CopyItems<Context, Context>(
+          input.dtype(),
+          input.numel(),
+          input.raw_data(),
+          output->raw_mutable_data(input.dtype()));
+
+      if (OutputSize() > 1) {
+        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
+        real_batch_tensor->template mutable_data<int64_t>()[0] =
+            real_batch_size;
+      }
+    }
+
+    return true;
+  }
+
+ private:
+  int64_t max_batch_size_;
+};
+} // namespace caffe2
--- /dev/null
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core
+from hypothesis import given
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+
+import unittest
+
+
+class TestAdjustBatchOp(hu.HypothesisTestCase):
+    @given(d=st.integers(1, 4), n=st.integers(1, 20),
+           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
+    def test_pad(self, d, n, gc, dc, seed):
+        for dtype in [np.float32, np.int8, np.int64]:
+            np.random.seed(seed)
+            dims = [n] * d
+            X = np.random.rand(*dims).astype(dtype)
+            max_batch_size = n + 8
+
+            def ref_op(X):
+                shape = list(X.shape)
+                out = np.zeros(1, dtype=np.int64)
+                out[0] = shape[0]
+                shape[0] = max_batch_size
+                Y = np.zeros(shape, dtype=dtype)
+                Y[:n] = X
+                return [Y, out]
+
+            op = core.CreateOperator(
+                "AdjustBatch",
+                ["X"],
+                ["Y", "RealBatch"],
+                max_batch_size=max_batch_size,
+            )
+
+            self.assertReferenceChecks(
+                device_option=gc,
+                op=op,
+                inputs=[X],
+                reference=ref_op,
+            )
+
+    @given(d=st.integers(1, 4), n=st.integers(8, 20),
+           seed=st.integers(0, 1000), **hu.gcs_cpu_only)
+    def test_truncate(self, d, n, gc, dc, seed):
+        for dtype in [np.float32, np.int8, np.int64]:
+            np.random.seed(seed)
+            dims = [n] * d
+            X = np.random.rand(*dims).astype(dtype)
+            real_batch_size = n - 8
+            R = np.zeros(1, dtype=np.int64)
+            R[0] = real_batch_size
+
+            def ref_op(X, R):
+                r = R[0]
+                return [X[:r]]
+
+            op = core.CreateOperator(
+                "AdjustBatch",
+                ["X", "RealBatch"],
+                ["Y"],
+            )
+
+            self.assertReferenceChecks(
+                device_option=gc,
+                op=op,
+                inputs=[X, R],
+                reference=ref_op,
+            )
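+
+
+# Allow running this test file directly.
+if __name__ == "__main__":
+    unittest.main()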