Catch exceptions in bound_shape_inference (#17775)
authorYinghai Lu <yinghai@fb.com>
Fri, 8 Mar 2019 21:15:05 +0000 (13:15 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 8 Mar 2019 21:18:28 +0000 (13:18 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17775

Handles use input shape hint properly.

Reviewed By: zrphercule

Differential Revision: D14368735

fbshipit-source-id: 504cd96589e47aa432617e56362aa6b01a25ba9b

caffe2/opt/backend_transformer_base.cc
caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/bound_shape_inferencer.h
caffe2/utils/string_utils.h

index f23a7a0..a4db8b6 100644 (file)
@@ -78,11 +78,15 @@ ShapeInfoMap BackendTransformerBase::inferShapes(
       shape_map[s] = shape_info;
     }
   }
+  // We treat hinted shapes as BATCH. If there are shape hints on blobs in the
+  // workspace, since they are already inserted as CONSTANT, it will take effect
+  // here. For SEQ typed tensors, there are only a few of them and they will be
+  // handled by BoundShapeInferencer.
   for (const auto& kv : shape_hints_mapped) {
     shape_map.emplace(
         std::piecewise_construct,
         std::forward_as_tuple(kv.first),
-        std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, kv.second));
+        std::forward_as_tuple(ShapeInfo::DimType::BATCH, kv.second));
   }
   BoundShapeInferencer eng(spec);
   eng.InferBoundShapeAndType(*pred_net, shape_map);
index 990220a..e56c8d2 100644 (file)
@@ -2,6 +2,7 @@
 #include "caffe2/core/operator_schema.h"
 #include "caffe2/core/tensor_impl.h"
 #include "caffe2/utils/proto_utils.h"
+#include "caffe2/utils/string_utils.h"
 
 namespace caffe2 {
 
@@ -60,6 +61,10 @@ void BoundShapeInferencer::InferBoundShapeAndType(
       InferReshape(op);
     } else if (op.type() == "LengthsRangeFill") {
       InferLengthsRangeFill(op);
+    } else if (
+        caffe2::StartsWith(op.type(), "GivenTensor") &&
+        caffe2::EndsWith(op.type(), "Fill")) {
+      InferGivenTensorFill(op);
     } else {
       InferCommonOp(op);
     }
@@ -122,6 +127,15 @@ std::vector<TensorShape> InferOutput(
   return schema->InferTensor(op, input_shapes);
 }
 
+void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) {
+  CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(), " must have 1 output");
+  InferCommonOp(op);
+  auto it = shape_info_.find(op.output(0));
+  if (it != shape_info_.end()) {
+    it->second.dim_type = ShapeInfo::DimType::CONSTANT;
+  }
+}
+
 void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
   CAFFE_ENFORCE_EQ(op.input_size(), 1, "LengthsRangeFill must have 1 input");
   CAFFE_ENFORCE_EQ(op.output_size(), 1, "LengthsRangeFill must have 1 output");
@@ -342,6 +356,7 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
 void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   // First, we need to check that all the input shape/types are already
   // presented
+  try {
   std::vector<TensorShape> input_shapes;
   for (const auto& input : op.input()) {
     const auto it = shape_info_.find(input);
@@ -356,11 +371,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
   CAFFE_ENFORCE(schema);
   std::vector<TensorShape> output_shapes;
-  try {
     output_shapes = schema->InferTensor(op, input_shapes);
-  } catch (const std::exception& e) {
-    LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
-  }
   int i = 0;
   for (const auto& shape : output_shapes) {
     if (shape.unknown_shape()) {
@@ -373,6 +384,13 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
         ConvertToVec(shape.dims()),
         shape.data_type());
   }
+  } catch (const caffe2::EnforceNotMet& e) {
+    LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
+               << ": " << e.msg();
+  } catch (const std::exception& e) {
+    LOG(WARNING) << "Caught exception while inferring shapes for " << op.type()
+                 << ": " << e.what();
+  }
 }
 
 } // namespace caffe2
index dafac5b..ef6fa04 100644 (file)
@@ -63,6 +63,7 @@ class CAFFE2_API BoundShapeInferencer {
       std::vector<int64_t> bound_dims,
       TensorProto::DataType type);
 
+  void InferGivenTensorFill(const OperatorDef& op);
   void InferSparseLengthsSum(const OperatorDef& op);
   void InferFC(const OperatorDef& op);
   void InferConcat(const OperatorDef& op);
@@ -74,7 +75,7 @@ class CAFFE2_API BoundShapeInferencer {
   void InferCommonOp(const OperatorDef& op);
 
   const BoundShapeSpec spec_;
-  ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN};
+  ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
   int64_t current_max_batch_size_{0};
   std::unordered_map<std::string, ShapeInfo> shape_info_;
 };
index d004eac..3591866 100644 (file)
@@ -21,6 +21,18 @@ CAFFE2_API inline bool StartsWith(const std::string& str, const std::string& pre
       prefix.end();
 }
 
+CAFFE2_API inline bool EndsWith(
+    const std::string& full,
+    const std::string& ending) {
+  if (full.length() >= ending.length()) {
+    return (
+        0 ==
+        full.compare(full.length() - ending.length(), ending.length(), ending));
+  } else {
+    return false;
+  }
+}
+
 CAFFE2_API int32_t editDistanceHelper(const char* s1,
   size_t s1_len,
   const char* s2,