Insert AdjustBatchSizeOp into the predict_net. (#16811)
authorYing Zhang <yingz@fb.com>
Thu, 7 Feb 2019 08:33:29 +0000 (00:33 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Feb 2019 08:40:11 +0000 (00:40 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16811

As the title. The AdjustBatch ops will be inserted before and after the Onnxifi op to:
1) adjust batch/seq sizes to the ideal batch/seq size before these tensors are processed by the Onnxifi op;
2) adjust batch size to the original batch size for batches generated by the Onnxifi op.

Reviewed By: yinghai

Differential Revision: D13967711

fbshipit-source-id: 471b25ae6a60bf5b7ebee1de6449e0389b6cafff

caffe2/opt/onnxifi_transformer.cc

index 0af59e9..0eaea78 100644 (file)
@@ -232,6 +232,147 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) {
   opset_id->set_domain("");
   opset_id->set_version(7);
 }
+
+string MkBatchSizeBlob() {
+  return "real_batch_size";
+}
+
+string MkSeqSizeBlob(const string& blob_name) {
+  return blob_name + "_real_seq_size";
+}
+
+string MkOutputForAdjustBatchOp(const string& input) {
+  return input + "_post_adjust_batch";
+}
+
+string MkInputForAdjustBatchOp(const string& output) {
+  return output + "_pre_adjust_batch";
+}
+
+OperatorDef MkAdjustBatchOp(
+    const string& input_blob,
+    const string& output_blob,
+    int max_batch_size,
+    const string& real_batch_size_blob,
+    bool adjust_to_max_batch_size) {
+  OperatorDef adjust_batch_op;
+  adjust_batch_op.set_type("AdjustBatch");
+  auto* arg = adjust_batch_op.add_arg();
+  arg->set_name("max_batch_size");
+  arg->set_i(max_batch_size);
+  adjust_batch_op.add_input(input_blob);
+  adjust_batch_op.add_output(output_blob);
+  if (adjust_to_max_batch_size) {
+    adjust_batch_op.add_output(real_batch_size_blob);
+  } else {
+    adjust_batch_op.add_input(real_batch_size_blob);
+  }
+  return adjust_batch_op;
+}
+
+std::unordered_set<string> ToHashSet(
+    const ::google::protobuf::RepeatedPtrField<string>& strs) {
+  return std::unordered_set<string>(strs.begin(), strs.end());
+}
+
+int64_t GetBlob1stDimSize(
+    const ShapeInfo& shape_info,
+    const string& blob_name) {
+  CAFFE_ENFORCE(
+      shape_info.shape.dims_size() > 0 && shape_info.shape.dims(0) > 0,
+      "Tensor " + blob_name +
+          " is type BATCH / SEQ, however the batch_size is unknown. " +
+          "Dims size: " + to_string(shape_info.shape.dims_size()) +
+          ", dim[0] = " + to_string(shape_info.shape.dims(0)));
+  return shape_info.shape.dims(0);
+}
+
+// Generates AdjustBatchOps for external inputs / outputs with type BATCH or
+// SEQ and adds them to input_ops and output_ops.
+// Meanwhile, modifies inputs / outputs of corresponding operators in the
+// wrapper_net to use the new inputs / outputs of AdjustBatchOps.
+void AddAdjustBatchOps(
+    const ShapeInfoMap& shape_hints,
+    NetDef* wrapper_net,
+    vector<OperatorDef>* input_ops,
+    vector<OperatorDef>* output_ops) {
+  const auto external_inputs = ToHashSet(wrapper_net->external_input());
+  const auto external_outputs = ToHashSet(wrapper_net->external_output());
+
+  for (auto& op : *(wrapper_net->mutable_op())) {
+    // Add AdjustBatchOp for all external inputs with type BATCH or SEQ.
+    // This will adjust the batch/seq size to the batch/seq size inferred by
+    // bound_shape_inference.
+    for (auto& input_blob : *(op.mutable_input())) {
+      if (external_inputs.count(input_blob)) {
+        auto shape_info_it = shape_hints.find(input_blob);
+        if (shape_info_it == shape_hints.end()) {
+          LOG(WARNING) << "Cannot find shape_info for external input blob: "
+                       << input_blob;
+          continue;
+        }
+        string real_batch_size_blob = "";
+        if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
+          real_batch_size_blob = MkBatchSizeBlob();
+        } else if (shape_info_it->second.dim_type == ShapeInfo::DimType::SEQ) {
+          real_batch_size_blob = MkSeqSizeBlob(input_blob);
+        } else {
+          continue;
+        }
+        auto output_blob = MkOutputForAdjustBatchOp(input_blob);
+        input_ops->push_back(MkAdjustBatchOp(
+            input_blob,
+            output_blob,
+            GetBlob1stDimSize(shape_info_it->second, input_blob),
+            real_batch_size_blob,
+            true /* adjust_to_max_batch_size */));
+        input_blob = output_blob;
+      }
+    }
+    // Add AdjustBatchOp for all external outputs with type BATCH.
+    // This will adjust the batch size to the original batch size.
+    for (auto& output_blob : *(op.mutable_output())) {
+      if (external_outputs.count(output_blob)) {
+        auto shape_info_it = shape_hints.find(output_blob);
+        if (shape_info_it == shape_hints.end()) {
+          continue;
+        }
+        if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
+          auto input_blob = MkInputForAdjustBatchOp(output_blob);
+          output_ops->push_back(MkAdjustBatchOp(
+              input_blob,
+              output_blob,
+              GetBlob1stDimSize(shape_info_it->second, output_blob),
+              MkBatchSizeBlob(),
+              false /* adjust_to_max_batch_size */));
+          output_blob = input_blob;
+        } else {
+          CAFFE_ENFORCE(
+              shape_info_it->second.dim_type != ShapeInfo::DimType::SEQ,
+              "Output tensor " + output_blob +
+                  " should never have dim_type SEQ.");
+        }
+      }
+    }
+  }
+}
+
+NetDef ComposeResultNet(
+    const vector<OperatorDef>& input_ops,
+    const vector<OperatorDef>& output_ops,
+    const OperatorDef& onnxifi_op) {
+  NetDef net_opt;
+  for (const auto& op : input_ops) {
+    *(net_opt.add_op()) = op;
+  }
+  *(net_opt.add_op()) = onnxifi_op;
+  // Add AdjustBatch ops for output blobs to the net.
+  for (const auto& op : output_ops) {
+    *(net_opt.add_op()) = op;
+  }
+  return net_opt;
+}
+
 } // namespace
 
 OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
@@ -362,6 +503,10 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
     }
   }
 
+  vector<OperatorDef> input_ops;
+  vector<OperatorDef> output_ops;
+  AddAdjustBatchOps(shape_hints, &wrapper_net, &input_ops, &output_ops);
+
   // Figure out weights and add it to external_inputs too
   std::vector<std::string> extra_weights;
   std::unordered_set<std::string> initialization_list;
@@ -381,23 +526,17 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
         WrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
   }
 
-  // Debugging stuff
-  if (opts_.debug) {
-    WriteProtoToTextFile(wrapper_net, "debug.pb_txt");
-  }
-
-  // C2 model is ready. Build ONNXIFI Op
+  // Build ONNXIFI Op
   std::string model_str;
   wrapper_net.SerializeToString(&model_str);
-  NetDef net_opt;
-  auto* op = net_opt.add_op();
-  *op = BuildOnnxifiOp(
+  auto onnxifi_op = BuildOnnxifiOp(
       model_str, output_shape_hints, initialization_list, net_copy);
-  for (const auto& i : op->input()) {
-    net_opt.add_external_input(i);
-  }
-  for (const auto& o : op->output()) {
-    net_opt.add_external_output(o);
+  NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
+
+  // Debugging stuff
+  if (opts_.debug) {
+    WriteProtoToTextFile(wrapper_net, "debug_wrapper_net.pb_txt");
+    WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt");
   }
   return net_opt;
 }
@@ -412,6 +551,11 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
   ::ONNX_NAMESPACE::ModelProto onnx_model;
   FillModelInfo(&onnx_model);
 
+  caffe2::NetDef wrapper_net(net);
+  vector<OperatorDef> input_ops;
+  vector<OperatorDef> output_ops;
+  AddAdjustBatchOps(*shape_hints, &wrapper_net, &input_ops, &output_ops);
+
   // Convert c2 ops to onnx ops, add const weights if there are any
   DeviceOption option;
   CPUContext context(option);
@@ -511,15 +655,9 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
   // Onnx model is ready. Build ONNXIFI Op
   std::string model_str;
   onnx_model.SerializeToString(&model_str);
-  NetDef net_opt;
-  auto* op = net_opt.add_op();
-  *op = BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net);
-  for (const auto& i : op->input()) {
-    net_opt.add_external_input(i);
-  }
-  for (const auto& i : op->output()) {
-    net_opt.add_external_output(i);
-  }
+  auto onnxifi_op =
+      BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net);
+  NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
 
   return net_opt;
 }