From 979db037221d91f9dbd08df58a261d7448e6240d Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Thu, 21 Mar 2019 15:28:20 -0700 Subject: [PATCH] Blacklist certain op types when doing bound shape inference (#18290) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18290 Some such as `Tile` will mess up our tracking of batch size and for now it makes sense to stop the shape inference on these ops so that we don't lower it and downstream ops without proper batch info. Reviewed By: zrphercule Differential Revision: D14463550 fbshipit-source-id: 2792481efa540f2a7dd310e677c213860c3053ca --- caffe2/opt/bound_shape_inference_test.cc | 10 +++++++--- caffe2/opt/bound_shape_inferencer.cc | 31 ++++++++++++------------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index ee59fe7..d8f77cf 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -244,8 +244,7 @@ TEST(BoundShapeInference, FC) { {spec.max_batch_size, 1024}); } -// We don't support inference input shape when Weight is not 2D -TEST(BoundShapeInference, UnsupportedFC) { +TEST(BoundShapeInference, FC3D) { NetDef net; net.add_op()->CopyFrom( CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {})); @@ -255,7 +254,12 @@ TEST(BoundShapeInference, UnsupportedFC) { shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16})); BoundShapeSpec spec(20, 1000); BoundShapeInferencer eng(spec); - EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet); + eng.InferBoundShapeAndType(net, shape_map); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024}); + verifyShapeInfo( + out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16}); } TEST(BoundShapeInference, Combo0) { diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index dea1eeb..717f843 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -44,10 +44,15 @@ void EnsureShapeNames(std::unordered_map* info) { void BoundShapeInferencer::InferBoundShapeAndType( const NetDef& net, const std::unordered_map& info) { + const static std::unordered_set unsupported{"Tile"}; shape_info_ = info; for (const auto& op : net.op()) { VLOG(1) << op.type(); + if (unsupported.count(op.type())) { + continue; + } + if (op.type() == "SparseLengthsSum" || op.type() == "SparseLengthsSumFused8BitRowwise" || op.type() == "SparseLengthsWeightedSum" || @@ -316,34 +321,22 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { ArgumentHelper helper(op); auto axis = helper.GetSingleArgument("axis", 1); auto axis_w = helper.GetSingleArgument("axis_w", 1); - CAFFE_ENFORCE_EQ( - axis, - 1, - "Don't know how to deduce input of FC with axis not equal to 1: ", - op.input(0)); - CAFFE_ENFORCE_EQ( - axis_w, - 1, - "Don't know how to deduce input of FC with axis_w not equal to 1: ", - op.input(0)); const TensorShape w_shape = w_shape_info.shape; - CAFFE_ENFORCE_EQ( - w_shape.dims_size(), - 2, - "Don't know how to deduce input of FC other than of dim size 2: ", - op.input(0)); bool transposed = (op.type() == "FC") ? false : true; const int canonical_axis_w = canonical_axis_index_(axis_w, w_shape.dims().size()); const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w) : SizeFromDim(w_shape, canonical_axis_w); + std::vector dims; + for (int i = 0; i < axis - 1; ++i) { + dims.push_back(1); + } + dims.push_back(spec_.max_batch_size); + dims.push_back(K); current_dim_type_ = ShapeInfo::DimType::BATCH; current_max_batch_size_ = spec_.max_batch_size; CheckAndSetTensorShapeAndType( - op.input(0), - ShapeInfo::DimType::BATCH, - {spec_.max_batch_size, K}, - w_shape.data_type()); + op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type()); } else { ShapeInfo& x_shape_info = x_it->second; if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) { -- 2.7.4