Fix concat dimension check bug (#17343)
authorChandler Zuo <chandlerzuo@fb.com>
Fri, 22 Feb 2019 03:31:21 +0000 (19:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 03:34:30 +0000 (19:34 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17343

See [post](https://fb.workplace.com/groups/1405155842844877/permalink/2630764056950710/)

Reviewed By: dzhulgakov

Differential Revision: D14163001

fbshipit-source-id: 038f15d6a58b3bc31910e7bfa47c335e25739f12

caffe2/operators/concat_split_op.cc

index d630f44..57ef45c 100644 (file)
@@ -179,7 +179,10 @@ OPERATOR_SCHEMA(Concat)
           : GetDimFromOrderString(
                 helper.GetSingleArgument<string>("order", "NCHW"));
       bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
-      const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
+      int adj_size = in[0].dims_size() + (add_axis ? 1 : 0);
+      const int canonical_axis = canonical_axis_index_(axis, adj_size);
+      CAFFE_ENFORCE_LT(
+          canonical_axis, adj_size, "Axis not in input ndim range.");
       CAFFE_ENFORCE_GT(in.size(), 0);
       vector<int> split_shape(1, in.size());
       vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());