Add shape inference function for Split (#18838)
authorYinghai Lu <yinghai@fb.com>
Thu, 4 Apr 2019 07:19:21 +0000 (00:19 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 07:22:22 +0000 (00:22 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18838

It turns out that we don't have shape inference function of `Split` op at all. This diff adds that.

Reviewed By: bertmaher

Differential Revision: D14766871

fbshipit-source-id: 535cb4f24bdada603c76579e00e7a39aee93e19f

caffe2/operators/concat_split_op.cc
caffe2/opt/bound_shape_inference_test.cc

index ff66578..3b4bf97 100644 (file)
@@ -16,6 +16,76 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
   }
   return std::make_pair(in_dev, out_dev);
 }
+
+vector<TensorShape> TensorInferenceForSplit(
+    const OperatorDef& def,
+    const vector<TensorShape>& in) {
+  auto ret_invalid_shape = [&def]() {
+    vector<TensorShape> out(def.output().size());
+    for (auto& out_ts : out) {
+      out_ts.set_unknown_shape(true);
+    }
+    return out;
+  };
+  // We only support shape inference of Split with 1 input
+  if (def.input_size() != 1 || in.empty() || in.front().unknown_shape()) {
+    return ret_invalid_shape();
+  } else if (def.output_size() == 0) {
+    return vector<TensorShape>();
+  }
+  ArgumentHelper helper(def);
+  const int axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int>("axis", -1)
+      : GetDimFromOrderString(
+            helper.GetSingleArgument<string>("order", "NCHW"));
+  const int add_axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int>("add_axis", 0)
+      : 0;
+  const auto& input = in[0];
+  const int canonical_axis = canonical_axis_index_(axis, input.dims_size());
+  const int input_channels = input.dims(canonical_axis);
+  auto split = helper.GetRepeatedArgument<int>("split");
+  // Equally split the input into outputs
+  const int output_size = def.output_size();
+  if (split.empty()) {
+    if (!input_channels % output_size) {
+      LOG(WARNING) << "Input channels (" << input_channels
+                   << ") should be divisible by number of outputs ("
+                   << output_size << ")";
+      return ret_invalid_shape();
+    }
+    split.resize(output_size, input_channels / output_size);
+  } else if (split.size() != output_size) {
+    LOG(WARNING) << "`split` size (" << split.size()
+                 << ") should be equal to output size (" << output_size << ")";
+    return ret_invalid_shape();
+  }
+
+  // Check validity of the split
+  const int total_channels = add_axis
+      ? def.output_size()
+      : std::accumulate(split.begin(), split.begin() + output_size, 0);
+  if (total_channels != input_channels) {
+    LOG(WARNING) << "Input channels (" << input_channels
+                 << ") is not equal to total output channels ("
+                 << total_channels << ")";
+    return ret_invalid_shape();
+  }
+
+  vector<int> output_dims(input.dims().begin(), input.dims().end());
+  if (add_axis) {
+    output_dims.erase(output_dims.begin() + canonical_axis);
+  }
+  vector<TensorShape> output_shapes;
+  for (int i = 0; i < output_size; ++i) {
+    if (!add_axis) {
+      output_dims[canonical_axis] = split[i];
+    }
+    output_shapes.emplace_back(
+        CreateTensorShape(output_dims, input.data_type()));
+  }
+  return output_shapes;
+}
 } // namespace.
 
 REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
@@ -29,11 +99,15 @@ OPERATOR_SCHEMA(Split)
         "split",
         "(*Tensor`<int>`*): [OPTIONAL] list of output lengths (see also arg `split`)")
     .Arg("axis", "(*int*): axis to split on")
+    .Arg(
+        "add_axis",
+        "*(type: int)* Pass non-zero integer to remove the axis specified in `axis` to all input tensors.")
     .Arg("split", "(*Tuple(int)*): length of each output")
     .Arg(
         "order",
         "(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
     .Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
+    .TensorInferenceFunction(TensorInferenceForSplit)
     .DeviceInferenceFunction(splitOpDevInfer)
     .SetDoc(R"DOC(
 Split an `input` tensor into a list of tensors, along the axis specified by the `axis` dimension. The lengths of the split can be specified using argument `split` or optional second input blob to the operator. Otherwise, the tensor is split to equal sized parts.
index d8f77cf..a148b0e 100644 (file)
@@ -214,6 +214,60 @@ TEST(BoundShapeInference, ConcatMissingInput) {
       {spec.max_batch_size, 2, 60});
 }
 
+TEST(BoundShapeInference, Split) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split", "", {"X"}, {"Y0", "Y1"}, {MakeArgument<int>("axis", 1)}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split",
+      "",
+      {"X"},
+      {"Y2", "Y3", "Y4"},
+      {MakeArgument<int>("axis", 1),
+       MakeArgument<std::vector<int>>("split", {4, 30, 14})}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Split",
+      "",
+      {"X1"},
+      {"Y5", "Y6"},
+      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "X",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48}));
+  shape_map.emplace(
+      "X1",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48}));
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "X", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+  verifyShapeInfo(
+      out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2, 48});
+  verifyShapeInfo(
+      out_shape,
+      "Y0",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 48 / 2});
+  verifyShapeInfo(
+      out_shape,
+      "Y1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 48 / 2});
+  verifyShapeInfo(
+      out_shape, "Y2", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 4});
+  verifyShapeInfo(
+      out_shape, "Y3", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 30});
+  verifyShapeInfo(
+      out_shape, "Y4", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 14});
+  verifyShapeInfo(
+      out_shape, "Y5", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+  verifyShapeInfo(
+      out_shape, "Y6", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 48});
+}
+
 TEST(BoundShapeInference, FC) {
   NetDef net;
   net.add_op()->CopyFrom(