From d042914221e4308c8e1610144ac83a31ff519868 Mon Sep 17 00:00:00 2001 From: Lin Yang Date: Fri, 11 Jan 2019 14:09:50 -0800 Subject: [PATCH] FC shape inference should use int64_t (#15961) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15961 as title Reviewed By: yinghai Differential Revision: D13634427 fbshipit-source-id: ec7d168b6272f0dac8a693401cfd0bea368f929a --- caffe2/core/operator_schema.h | 2 +- caffe2/operators/fc_inference.cc | 5 +++-- caffe2/operators/fully_connected_op.cc | 5 +++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/caffe2/core/operator_schema.h b/caffe2/core/operator_schema.h index ddab956..58c08a3 100644 --- a/caffe2/core/operator_schema.h +++ b/caffe2/core/operator_schema.h @@ -521,7 +521,7 @@ inline TensorShape CreateTensorShape( vector dims, ::caffe2::TensorProto_DataType dt) { TensorShape ts; - for (int d : dims) { + for (T_I d : dims) { ts.add_dims(d); } ts.set_data_type(dt); diff --git a/caffe2/operators/fc_inference.cc b/caffe2/operators/fc_inference.cc index 2c11e8a..6b82acf 100644 --- a/caffe2/operators/fc_inference.cc +++ b/caffe2/operators/fc_inference.cc @@ -13,14 +13,15 @@ std::vector FCShapeInference( auto axis_w = helper.GetSingleArgument("axis_w", 1); const int canonical_axis_w = canonical_axis_index_(axis_w, in[1].dims().size()); - const int N = pretransposed_weight + const int64_t N = pretransposed_weight ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1])) : size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); - vector y_shape(in[0].dims().begin(), in[0].dims().end()); + vector y_shape(in[0].dims().begin(), in[0].dims().end()); CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size()); y_shape.resize(canonical_axis + 1); y_shape[canonical_axis] = N; + out[0] = CreateTensorShape(y_shape, in[0].data_type()); return out; } diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc index 141fc3a..452dac8 100644 --- a/caffe2/operators/fully_connected_op.cc +++ b/caffe2/operators/fully_connected_op.cc @@ -41,14 +41,15 @@ std::vector FCShapeInference( auto axis_w = helper.GetSingleArgument("axis_w", 1); const int canonical_axis_w = canonical_axis_index_(axis_w, in[1].dims().size()); - const int N = pretransposed_weight + const int64_t N = pretransposed_weight ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1])) : size_to_dim_(canonical_axis_w, GetDimsVector(in[1])); - vector y_shape(in[0].dims().begin(), in[0].dims().end()); + vector y_shape(in[0].dims().begin(), in[0].dims().end()); CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size()); y_shape.resize(canonical_axis + 1); y_shape[canonical_axis] = N; + out[0] = CreateTensorShape(y_shape, in[0].data_type()); return out; } -- 2.7.4