Add CostInferenceFunction for SplitOp (#63133)
authorTanvir Zaman <motanv@fb.com>
Fri, 13 Aug 2021 19:25:16 +0000 (12:25 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 19:28:15 +0000 (12:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63133

SplitOp is costly but is missing a cost inference function, which hurts cost-based balancing. Changes are:
(1) Addition of CostInferenceFunction for SplitOp
(2) Small fix in CostInferenceFunction for ConcatOp

Test Plan:
Added unit tests:

buck test //caffe2/caffe2/python/operator_test:split_op_cost_test

buck test //caffe2/caffe2/python/operator_test:concat_op_cost_test

Reviewed By: smacke

Differential Revision: D30247360

fbshipit-source-id: 989e962f3a981acc85b73aac3fb23e603b7d1591

caffe2/operators/concat_split_op.cc
caffe2/operators/concat_split_op.h
caffe2/python/operator_test/concat_op_cost_test.py [new file with mode: 0644]
caffe2/python/operator_test/split_op_cost_test.py [new file with mode: 0644]

index 3d3b0c1..8eceb5a 100644 (file)
@@ -62,7 +62,7 @@ vector<TensorShape> TensorInferenceForSplit(
       return ret_invalid_shape();
     }
     split.resize(output_size, input_channels / output_size);
-  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
+    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
   } else if (split.size() != output_size) {
     LOG(WARNING) << "`split` size (" << split.size()
                  << ") should be equal to output size (" << output_size << ")";
@@ -94,6 +94,27 @@ vector<TensorShape> TensorInferenceForSplit(
   }
   return output_shapes;
 }
+
+OpSchema::Cost CostInferenceForSplit(
+    const OperatorDef&,
+    const vector<TensorShape>& in) {
+  CAFFE_ENFORCE_GT(in.size(), 0);
+  struct OpSchema::Cost cost;
+  cost.flops = 0;
+  auto input_bytes_count = nElemFromDim(in[0]) * sizeof(in[0].data_type());
+  auto split_bytes_count =
+      (in.size() == 1) ? 0 : nElemFromDim(in[1]) * sizeof(in[1].data_type());
+  // There can be two input blobs:
+  // (1) actual tensor to be split
+  // (2) lengths of outputs along split axis
+  // So, bytes_read is the sum of the bytes in the two blobs.
+  cost.bytes_read = input_bytes_count + split_bytes_count;
+  // Split operator only changes shape, does not change element count. So,
+  // bytes_written is same as input_bytes_count.
+  cost.bytes_written = input_bytes_count;
+  cost.params_bytes = 0;
+  return cost;
+}
 } // namespace.
 
 REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
@@ -117,6 +138,7 @@ OPERATOR_SCHEMA(Split)
         "(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
     .Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
     .TensorInferenceFunction(TensorInferenceForSplit)
+    .CostInferenceFunction(CostInferenceForSplit)
     .DeviceInferenceFunction(splitOpDevInfer)
     .SetDoc(R"DOC(
 Split an `input` tensor into a list of tensors, along the axis specified by the `axis` dimension. The lengths of the split can be specified using argument `split` or optional second input blob to the operator. Otherwise, the tensor is split to equal sized parts.
@@ -296,7 +318,7 @@ OpSchema::Cost CostInferenceForConcat(
       out_shape[canonical_axis] += in[i].dims(canonical_axis);
     }
   }
-  uint64_t nElemRead = 1;
+  uint64_t nElemRead = 0;
   // NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
   for (int i = 0; i < in.size(); ++i) {
     nElemRead += nElemFromDim(in[i]);
@@ -305,11 +327,13 @@ OpSchema::Cost CostInferenceForConcat(
   for (auto& s : out_shape) {
     size *= s;
   }
+  auto split_info_bytes_count = in.size() * sizeof(int);
 
   struct OpSchema::Cost cost;
   cost.flops = 0;
   cost.bytes_read = nElemRead * sizeof(in[0].data_type());
-  cost.bytes_written = size * sizeof(in[0].data_type());
+  cost.bytes_written =
+      size * sizeof(in[0].data_type()) + split_info_bytes_count;
   cost.params_bytes = 0;
   return cost;
 }
index 0c82076..bbe355e 100644 (file)
@@ -63,7 +63,8 @@ class SplitByLengthsOp final : public Operator<Context> {
       axis_ = GetDimFromOrderString(
           this->template GetSingleArgument<string>("order", "NCHW"));
     }
-     scaling_ = this->template GetSingleArgument<bool>("use_scaling_lengths", false);
+    scaling_ =
+        this->template GetSingleArgument<bool>("use_scaling_lengths", false);
   }
 
   bool RunOnDevice() override;
@@ -134,7 +135,11 @@ bool SplitOp<Context>::RunOnDevice() {
         input_channels % OutputSize(),
         0,
         "If you did not specify split explicitly, the number of "
-        "input channels should be divisible by the output size.");
+        "input channels:",
+        input_channels,
+        " should be divisible by the output size:",
+        OutputSize(),
+        ".");
     equal_split.resize(OutputSize(), input_channels / OutputSize());
     axis_data = equal_split.data();
   } else {
@@ -195,18 +200,22 @@ bool SplitByLengthsOp<Context>::RunOnDevice() {
   int32_t* length_data;
 
   if (this->InputIsTensorType(1, CPU)) {
-      length_data = Input(1).template data<int32_t>();
-    } else {
-      // Length input in CUDA context
-      auto& input_length = Input(1);
-      lengths_host_ = TensorCPU(input_length, CPU);
-      length_data = lengths_host_.template data<int32_t>();
+    length_data = Input(1).template data<int32_t>();
+  } else {
+    // Length input in CUDA context
+    auto& input_length = Input(1);
+    lengths_host_ = TensorCPU(input_length, CPU);
+    length_data = lengths_host_.template data<int32_t>();
   }
 
   CAFFE_ENFORCE_EQ(
       lengths_length % OutputSize(),
       0,
-      "len(Lengths) ", lengths_length, "should be divisible by OutputSize() ", OutputSize(), ".");
+      "len(Lengths) ",
+      lengths_length,
+      "should be divisible by OutputSize() ",
+      OutputSize(),
+      ".");
   int canonical_axis = input.canonical_axis_index(axis_);
   CAFFE_ENFORCE_LT(
       canonical_axis, input.dim(), "Axis not in input ndim range.");
@@ -219,21 +228,24 @@ bool SplitByLengthsOp<Context>::RunOnDevice() {
     CAFFE_ENFORCE_EQ(
         input_channels % (sum_lengths ? sum_lengths : 1),
         0,
-        "Input channels ", input_channels, " should be divisible by ",
+        "Input channels ",
+        input_channels,
+        " should be divisible by ",
         sum_lengths);
   } else {
     CAFFE_ENFORCE_EQ(
         sum_lengths,
         input_channels,
         "Input channels should be equal to split dimensions sum, ",
-        input_channels, " vs ", sum_lengths
-        );
+        input_channels,
+        " vs ",
+        sum_lengths);
   }
   vector<int64_t> output_dims(input.sizes().vec());
   int before = input.size_to_dim(canonical_axis);
   int after = input.size_from_dim(canonical_axis + 1);
   size_t input_offset = 0;
-  auto dim_multiplier = sum_lengths ? (input_channels / sum_lengths): 1;
+  auto dim_multiplier = sum_lengths ? (input_channels / sum_lengths) : 1;
 
   if (!scaling_) {
     dim_multiplier = 1;
@@ -242,8 +254,10 @@ bool SplitByLengthsOp<Context>::RunOnDevice() {
   for (int i = 0; i < OutputSize(); ++i) {
     auto* output = Output(i);
     const auto* axis_offset = axis_data + lengths_length / OutputSize() * i;
-    auto axis_dim = dim_multiplier * std::accumulate(
-        axis_offset, axis_offset + lengths_length / OutputSize(), 0);
+    auto axis_dim =
+        dim_multiplier *
+        std::accumulate(
+            axis_offset, axis_offset + lengths_length / OutputSize(), 0);
     output_dims[canonical_axis] = axis_dim;
     output->Resize(output_dims);
     math::CopyMatrix<Context>(
diff --git a/caffe2/python/operator_test/concat_op_cost_test.py b/caffe2/python/operator_test/concat_op_cost_test.py
new file mode 100644 (file)
index 0000000..996b330
--- /dev/null
@@ -0,0 +1,78 @@
+from collections import namedtuple
+
+import numpy as np
+from caffe2.python import core, workspace
+from caffe2.python.test_util import TestCase
+
+
+class TestConcatOpCost(TestCase):
+    def test_columnwise_concat(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input_1", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        workspace.FeedBlob("input_2", np.array([[7], [8]], dtype=np.int32))
+        concat_op = core.CreateOperator(
+            "Concat",
+            ["input_1", "input_2"],
+            ["output", "split_info"],
+        )
+        workspace.RunOperatorOnce(concat_op)
+
+        output = workspace.FetchBlob("output")
+        self.assertTupleEqual(output.shape, (2, 4))
+        np.testing.assert_array_equal(output, [[1, 2, 3, 7], [4, 5, 6, 8]])
+
+        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
+            concat_op, concat_op.input
+        )
+
+        self.assertEqual(flops, 0)
+        self.assertEqual(
+            bytes_read,
+            sum(workspace.FetchBlob(b).nbytes for b in concat_op.input),
+        )
+        self.assertEqual(
+            bytes_written,
+            sum(workspace.FetchBlob(b).nbytes for b in concat_op.output),
+        )
+
+    def test_split_then_concat(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        workspace.FeedBlob("split", np.array([1, 1, 1], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input", "split"],
+            ["output_1", "output_2", "output_3"],
+            axis=1,
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        concat_op = core.CreateOperator(
+            "Concat",
+            ["output_1", "output_2", "output_3"],
+            ["output", "split_info"],
+            axis=1,
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(concat_op)
+
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("input"), workspace.FetchBlob("output")
+        )
+
+        split_cost = workspace.GetOperatorCost(split_op, split_op.input)
+        self.assertTupleEqual(
+            split_cost,
+            namedtuple("expected_cost", ["flops", "bytes_written", "bytes_read"])(
+                0, 24, 36
+            ),
+        )
+
+        concat_cost = workspace.GetOperatorCost(concat_op, concat_op.input)
+        self.assertTupleEqual(
+            concat_cost,
+            namedtuple("expected_cost", ["flops", "bytes_written", "bytes_read"])(
+                0, 36, 24
+            ),
+        )
diff --git a/caffe2/python/operator_test/split_op_cost_test.py b/caffe2/python/operator_test/split_op_cost_test.py
new file mode 100644 (file)
index 0000000..97df350
--- /dev/null
@@ -0,0 +1,246 @@
+import numpy as np
+from caffe2.python import core, workspace
+from caffe2.python.test_util import TestCase
+
+
+class TestSplitOpCost(TestCase):
+    def _verify_cost(self, workspace, split_op):
+        flops, bytes_written, bytes_read = workspace.GetOperatorCost(
+            split_op, split_op.input
+        )
+        self.assertEqual(flops, 0)
+        self.assertEqual(
+            bytes_read,
+            sum(workspace.FetchBlob(b).nbytes for b in split_op.input),
+        )
+        self.assertEqual(
+            bytes_written,
+            sum(workspace.FetchBlob(b).nbytes for b in split_op.output),
+        )
+
+    def test_columnwise_equal_outputSplit(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2", "output_3"],
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (2, 1))
+        np.testing.assert_array_equal(output_1, [[1], [4]])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [[2], [5]])
+
+        output_3 = workspace.FetchBlob("output_3")
+        np.testing.assert_array_equal(output_3, [[3], [6]])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_rowwise_equal_outputSplit(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2"],
+            axis=0,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (1, 3))
+        np.testing.assert_array_equal(output_1, [[1, 2, 3]])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [[4, 5, 6]])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_columnwise_equal_outputSplit_columnRemoved(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        # To be able to use 'add_axis' (which should have been called 'remove_axis') on 'axis',
+        # the dimensions of split tensors must match on 'axis'
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2", "output_3"],
+            axis=1,
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (2,))
+        np.testing.assert_array_equal(output_1, [1, 4])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [2, 5])
+
+        output_3 = workspace.FetchBlob("output_3")
+        np.testing.assert_array_equal(output_3, [3, 6])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_rowwise_equal_outputSplit_rowRemoved(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2"],
+            axis=0,
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (3,))
+        np.testing.assert_array_equal(output_1, [1, 2, 3])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [4, 5, 6])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_rowwise_unequal_argSplit(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob(
+            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+        )
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2"],
+            axis=0,
+            split=[1, 2],
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (1, 3))
+        np.testing.assert_array_equal(output_1, [[1, 2, 3]])
+
+        output_2 = workspace.FetchBlob("output_2")
+        self.assertTupleEqual(output_2.shape, (2, 3))
+        np.testing.assert_array_equal(output_2, [[4, 5, 6], [7, 8, 9]])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_rowwise_unequal_argSplit_rowRemoved(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob(
+            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+        )
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2", "output_3"],
+            axis=0,
+            split=[1, 1, 1],
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (3,))
+        np.testing.assert_array_equal(output_1, [1, 2, 3])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [4, 5, 6])
+
+        output_3 = workspace.FetchBlob("output_3")
+        np.testing.assert_array_equal(output_3, [7, 8, 9])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_rowwise_unequal_blobSplit(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob(
+            "input", np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+        )
+        workspace.FeedBlob("split", np.array([1, 2], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input", "split"],
+            ["output_1", "output_2"],
+            axis=0,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (1, 3))
+        np.testing.assert_array_equal(output_1, [[1, 2, 3]])
+
+        output_2 = workspace.FetchBlob("output_2")
+        self.assertTupleEqual(output_2.shape, (2, 3))
+        np.testing.assert_array_equal(output_2, [[4, 5, 6], [7, 8, 9]])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_columnwise_unequal_argSplit(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2"],
+            axis=1,
+            split=[1, 2],
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (2, 1))
+        np.testing.assert_array_equal(output_1, [[1], [4]])
+
+        output_2 = workspace.FetchBlob("output_2")
+        self.assertTupleEqual(output_2.shape, (2, 2))
+        np.testing.assert_array_equal(output_2, [[2, 3], [5, 6]])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_columnWise_unequal_blobSplit_columnRemoved(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
+        workspace.FeedBlob("split", np.array([1, 1, 1], dtype=np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input", "split"],
+            ["output_1", "output_2", "output_3"],
+            axis=1,
+            add_axis=1,
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        output_1 = workspace.FetchBlob("output_1")
+        self.assertTupleEqual(output_1.shape, (2,))
+        np.testing.assert_array_equal(output_1, [1, 4])
+
+        output_2 = workspace.FetchBlob("output_2")
+        np.testing.assert_array_equal(output_2, [2, 5])
+
+        output_3 = workspace.FetchBlob("output_3")
+        np.testing.assert_array_equal(output_3, [3, 6])
+
+        self._verify_cost(workspace, split_op)
+
+    def test_equal_outputSplit_NHWC(self):
+        workspace.ResetWorkspace()
+        workspace.FeedBlob("input", np.random.rand(2, 5, 7, 9).astype(np.int32))
+        split_op = core.CreateOperator(
+            "Split",
+            ["input"],
+            ["output_1", "output_2", "output_3"],
+            order="NHWC",
+        )
+        workspace.RunOperatorOnce(split_op)
+
+        for b in split_op.output:
+            self.assertTupleEqual(workspace.FetchBlob(b).shape, (2, 5, 7, 3))
+
+        self._verify_cost(workspace, split_op)