Add backward pass to infer single missing input shape for Concat opportunitiscally...
authorYinghai Lu <yinghai@fb.com>
Fri, 5 Apr 2019 17:09:14 +0000 (10:09 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 17:11:58 +0000 (10:11 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18911

Att.

Reviewed By: bddppq

Differential Revision: D14791295

fbshipit-source-id: 4b7a775924f0eadb0cb73aa6c434a6a5be8b92be

caffe2/operators/concat_split_op.h
caffe2/opt/bound_shape_inference_test.cc
caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/bound_shape_inferencer.h
caffe2/opt/onnxifi_transformer.cc
caffe2/utils/string_utils.h

index 47ed663..74d74e4 100644 (file)
@@ -5,24 +5,10 @@
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math.h"
+#include "caffe2/utils/string_utils.h"
 
 namespace caffe2 {
 
-namespace {
-inline int GetDimFromOrderString(const string& str) {
-  auto order = StringToStorageOrder(str);
-  switch (order) {
-    case StorageOrder::NHWC:
-      return 3;
-    case StorageOrder::NCHW:
-      return 1;
-    default:
-      CAFFE_THROW("Unsupported storage order: ", str);
-      return -1;
-  }
-}
-} // namespace
-
 template <class Context>
 class SplitOp final : public Operator<Context> {
  public:
index a148b0e..0efa878 100644 (file)
@@ -214,6 +214,40 @@ TEST(BoundShapeInference, ConcatMissingInput) {
       {spec.max_batch_size, 2, 60});
 }
 
+TEST(BoundShapeInference, ConcatInferInputBackwards) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Concat",
+      "",
+      {"I0", "I1"},
+      {"Cout", "split_info"},
+      {MakeArgument<int>("axis", 1)}));
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FCTransposed", "", {"Cout", "W0", "B0"}, {"Y"}, {}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "I0",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
+  shape_map.emplace(
+      "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {101, 16}));
+  shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
+  verifyShapeInfo(
+      out_shape, "Cout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 101});
+  verifyShapeInfo(
+      out_shape, "Y", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+  verifyShapeInfo(
+      out_shape,
+      "I1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 101 - 60});
+}
+
 TEST(BoundShapeInference, Split) {
   NetDef net;
   net.add_op()->CopyFrom(CreateOperatorDef(
index b7c20d5..1d2f940 100644 (file)
@@ -79,6 +79,14 @@ void BoundShapeInferencer::InferBoundShapeAndType(
     }
   }
 
+  // Doing a reverse pass to infer the input shapes if applicable
+  for (int i = net.op_size() - 1; i >= 0; --i) {
+    const auto& op = net.op(i);
+    if (op.type() == "Concat") {
+      InferConcatInputs(op);
+    }
+  }
+
   // Make sure shape has name
   EnsureShapeNames(&shape_info_);
 }
@@ -251,6 +259,55 @@ void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
     shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
   }
 }
+
+void BoundShapeInferencer::InferConcatInputs(const OperatorDef& op) {
+  ArgumentHelper helper(op);
+  const auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
+  if (add_axis) {
+    return;
+  } else if (op.output_size() == 0 || !shape_info_.count(op.output(0))) {
+    return;
+  }
+
+  const auto axis = helper.HasArgument("axis")
+      ? helper.GetSingleArgument<int32_t>("axis", -1)
+      : GetDimFromOrderString(
+            helper.GetSingleArgument<string>("order", "NCHW"));
+
+  const auto& shape_info = shape_info_.at(op.output(0));
+  int output_channel = shape_info.shape.dims(axis);
+  int missing_shape_infos = 0;
+  int channel_acc = 0;
+  std::string input_to_infer;
+  for (const auto& i : op.input()) {
+    const auto it = shape_info_.find(i);
+    if (it != shape_info_.end()) {
+      const auto& current_input_shape = it->second;
+      channel_acc += current_input_shape.shape.dims(axis);
+    } else if (missing_shape_infos) {
+      LOG(INFO) << "More than one missing shapes, previous one: "
+                << input_to_infer;
+      // We can only infer one missing input shape info
+      return;
+    } else {
+      ++missing_shape_infos;
+      input_to_infer = i;
+    }
+  }
+
+  if (missing_shape_infos && !input_to_infer.empty()) {
+    auto input_shape_info = shape_info;
+    input_shape_info.shape.set_dims(axis, output_channel - channel_acc);
+    shape_info_.emplace(input_to_infer, std::move(input_shape_info));
+
+    // Infer the shape of the second output of Concat
+    InferCommonOp(op);
+    if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
+      shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
+    }
+  }
+}
+
 // For concat net, if some inputs are missing and we have add_axis argument, it
 // means that all the inputs should be of the same dimension. In this case, we
 // can infer the shape of the missing inputs
@@ -399,7 +456,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
       !(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
   TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
   if (is_quantized) {
-    const static std::map<string, int> type_info_from_input = {
+    const static std::map<std::string, int> type_info_from_input = {
         {"Int8Quantize", -1}, // Force this op's output to be uint8
         {"Int8ConvRelu", 1},
         {"Int8MaxPool", 0},
@@ -420,6 +477,7 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   } else if (op.type() == "Int8Dequantize") {
     infered_data_type = TensorProto::FLOAT;
   }
+
   for (const auto& shape : output_shapes) {
     if (infered_data_type == TensorProto::UNDEFINED) {
       infered_data_type = shape.data_type();
index ee1c670..216534e 100644 (file)
@@ -64,6 +64,8 @@ class CAFFE2_API BoundShapeInferencer {
       TensorProto::DataType type,
       bool is_quantized);
 
+  void InferConcatInputs(const OperatorDef& op);
+
   void InferGivenTensorFill(const OperatorDef& op);
   void InferSparseLengthsSum(const OperatorDef& op);
   void InferFC(const OperatorDef& op);
index 797c3f4..8ec572b 100644 (file)
@@ -826,6 +826,8 @@ bool OnnxifiTransformer::supportOpC2(
     for (const auto& i : op.input()) {
       const auto it = shape_hints.find(i);
       if (it == shape_hints.end()) {
+        VLOG(1) << "Skipping " << op.type() << " (" << pos
+                << ") due to missing shape info for input " << i;
         return false;
       }
       if ((it->second).is_quantized == false) {
@@ -844,6 +846,8 @@ bool OnnxifiTransformer::supportOpC2(
     for (const auto& i : op.output()) {
       const auto it = shape_hints.find(i);
       if (it == shape_hints.end()) {
+        VLOG(1) << "Skipping " << op.type() << " (" << pos
+                << ") due to missing shape info for output " << i;
         return false;
       }
       if ((it->second).is_quantized == false) {
index 3591866..ada947e 100644 (file)
@@ -6,6 +6,7 @@
 #include <vector>
 
 #include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
 
 namespace caffe2 {
 
@@ -33,6 +34,19 @@ CAFFE2_API inline bool EndsWith(
   }
 }
 
+CAFFE2_API inline int32_t GetDimFromOrderString(const std::string& str) {
+  auto order = StringToStorageOrder(str);
+  switch (order) {
+    case StorageOrder::NHWC:
+      return 3;
+    case StorageOrder::NCHW:
+      return 1;
+    default:
+      CAFFE_THROW("Unsupported storage order: ", str);
+      return -1;
+  }
+}
+
 CAFFE2_API int32_t editDistanceHelper(const char* s1,
   size_t s1_len,
   const char* s2,