FC shape inference should use int64_t (#15961)
authorLin Yang <ylin1@fb.com>
Fri, 11 Jan 2019 22:09:50 +0000 (14:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 22:28:39 +0000 (14:28 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15961

as title

Reviewed By: yinghai

Differential Revision: D13634427

fbshipit-source-id: ec7d168b6272f0dac8a693401cfd0bea368f929a

caffe2/core/operator_schema.h
caffe2/operators/fc_inference.cc
caffe2/operators/fully_connected_op.cc

index ddab956..58c08a3 100644 (file)
@@ -521,7 +521,7 @@ inline TensorShape CreateTensorShape(
     vector<T_I> dims,
     ::caffe2::TensorProto_DataType dt) {
   TensorShape ts;
-  for (int d : dims) {
+  for (T_I d : dims) {
     ts.add_dims(d);
   }
   ts.set_data_type(dt);
index 2c11e8a..6b82acf 100644 (file)
@@ -13,14 +13,15 @@ std::vector<TensorShape> FCShapeInference(
   auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
   const int canonical_axis_w =
       canonical_axis_index_(axis_w, in[1].dims().size());
-  const int N = pretransposed_weight
+  const int64_t N = pretransposed_weight
       ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
       : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
 
-  vector<int> y_shape(in[0].dims().begin(), in[0].dims().end());
+  vector<int64_t> y_shape(in[0].dims().begin(), in[0].dims().end());
   CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
   y_shape.resize(canonical_axis + 1);
   y_shape[canonical_axis] = N;
+
   out[0] = CreateTensorShape(y_shape, in[0].data_type());
   return out;
 }
index 141fc3a..452dac8 100644 (file)
@@ -41,14 +41,15 @@ std::vector<TensorShape> FCShapeInference(
   auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
   const int canonical_axis_w =
       canonical_axis_index_(axis_w, in[1].dims().size());
-  const int N = pretransposed_weight
+  const int64_t N = pretransposed_weight
       ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
       : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
 
-  vector<int> y_shape(in[0].dims().begin(), in[0].dims().end());
+  vector<int64_t> y_shape(in[0].dims().begin(), in[0].dims().end());
   CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
   y_shape.resize(canonical_axis + 1);
   y_shape[canonical_axis] = N;
+
   out[0] = CreateTensorShape(y_shape, in[0].data_type());
   return out;
 }