From 65b49b46966abc1fd9b7ad6668a994e9a669be96 Mon Sep 17 00:00:00 2001
From: Yinghai Lu
Date: Tue, 12 Feb 2019 14:43:43 -0800
Subject: [PATCH] Ignore unknown_shaped tensor in bound shape inference (#16916)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16916

Two fixes for maximum effort bound shape inference
1. Ignore failed and unknown shape
2. Add specialization for `SparseLengthsWeightedSumFused8BitRowwise`.

Reviewed By: ipiszy

Differential Revision: D14017810

fbshipit-source-id: 25cd68d35aa20b9ed077bdb562eb7f9deff0ab96
---
 caffe2/opt/bound_shape_inferencer.cc | 32 ++++++++++++++++++++------------
 caffe2/opt/bound_shape_inferencer.h  |  1 -
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
index bf57660..adfff97 100644
--- a/caffe2/opt/bound_shape_inferencer.cc
+++ b/caffe2/opt/bound_shape_inferencer.cc
@@ -44,13 +44,13 @@ void BoundShapeInferencer::InferBoundShapeAndType(
     const NetDef& net,
     const std::unordered_map& info) {
   shape_info_ = info;
-  visited_tensors_.clear();
 
   for (const auto& op : net.op()) {
-    VLOG(1) << op.type();
+    LOG(INFO) << op.type();
     if (op.type() == "SparseLengthsSum" ||
         op.type() == "SparseLengthsSumFused8BitRowwise" ||
-        op.type() == "SparseLengthsWeightedSum") {
+        op.type() == "SparseLengthsWeightedSum" ||
+        op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
       InferSparseLengthsSum(op);
     } else if (op.type() == "FC" || op.type() == "FCTransposed") {
       InferFC(op);
@@ -72,12 +72,8 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
     ShapeInfo::DimType t,
     std::vector bound_dims,
     TensorProto::DataType type) {
-  if (!visited_tensors_.emplace(name).second) {
-    return shape_info_.at(name).shape;
-  }
   auto rt = shape_info_.emplace(name, ShapeInfo());
   ShapeInfo& shape_info = rt.first->second;
-  shape_info.dim_type = t;
   TensorShape& shape = shape_info.shape;
   if (!rt.second) {
     // Check shape consistency
@@ -107,6 +103,7 @@ TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
     return shape;
   }
 
+  shape_info.dim_type = t;
   shape.mutable_dims()->Clear();
   for (const auto d : bound_dims) {
     shape.add_dims(d);
@@ -156,7 +153,11 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       op.input(0),
       "needs to be 2D");
 
-  int weight = (op.type() == "SparseLengthsWeightedSum") ? 1 : 0;
+  int weight = (op.type() == "SparseLengthsWeightedSum" ||
+                op.type() == "SparseLengthsWeightedSumFused8BitRowwise")
+      ? 1
+      : 0;
+
   if (weight) {
     CAFFE_ENFORCE_EQ(
         op.input_size(), 4, "SparseLengthsWeightedSum must have 4 inputs");
@@ -186,7 +187,8 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
   auto output_dim1 = it->second.shape.dims(1);
   // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 for
   // scale and 4 byte for bias (https://fburl.com/t6dp9tsc)
-  if (op.type() == "SparseLengthsSumFused8BitRowwise") {
+  if (op.type() == "SparseLengthsSumFused8BitRowwise" ||
+      op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
     output_dim1 -= 8;
   }
   CheckAndSetTensorShapeAndType(
@@ -211,9 +213,12 @@ void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
     if (it != shape_info_.end()) {
       const auto& current_input_shape = it->second;
       if (ref_input_shape) {
-        CAFFE_ENFORCE(
+        CAFFE_ENFORCE_EQ(
             ref_input_shape->shape.dims_size(),
-            current_input_shape.shape.dims_size());
+            current_input_shape.shape.dims_size(),
+            ref_name,
+            " vs ",
+            i);
         for (int j = 0; j < ref_input_shape->shape.dims_size(); ++j) {
           CAFFE_ENFORCE_EQ(
               ref_input_shape->shape.dims(j),
@@ -341,9 +346,12 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
   CAFFE_ENFORCE(schema);
   auto output_shapes = schema->InferTensor(op, input_shapes);
-  CAFFE_ENFORCE_EQ(output_shapes.size(), op.output_size());
   int i = 0;
   for (const auto& shape : output_shapes) {
+    if (shape.unknown_shape()) {
+      ++i;
+      continue;
+    }
     CheckAndSetTensorShapeAndType(
         op.output(i++),
         current_dim_type_,
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
index c70e242..9699a3b 100644
--- a/caffe2/opt/bound_shape_inferencer.h
+++ b/caffe2/opt/bound_shape_inferencer.h
@@ -89,7 +89,6 @@ class CAFFE2_API BoundShapeInferencer {
   ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN};
   int64_t current_max_batch_size_{0};
   std::unordered_map shape_info_;
-  std::unordered_set visited_tensors_;
 };
 
 } // namespace caffe2
-- 
2.7.4