Use bound shape inference in SparseNN tests (#16834)
authorYinghai Lu <yinghai@fb.com>
Thu, 7 Feb 2019 22:11:44 +0000 (14:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Feb 2019 22:51:32 +0000 (14:51 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16834

Inserting AdjustBatch ops will possibly change the names of the input/output, so we need to create a mapping and use the renamed names for external_inputs/outputs and input_shape_info for the onnxifi_net.

Reviewed By: ipiszy

Differential Revision: D13982731

fbshipit-source-id: c18b8a03d01490162929b2ca30c182d166001626

caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/onnxifi_transformer.cc
caffe2/opt/onnxifi_transformer.h

index 1ea1609..166532d 100644 (file)
@@ -47,7 +47,7 @@ void BoundShapeInferencer::InferBoundShapeAndType(
   visited_tensors_.clear();
 
   for (const auto& op : net.op()) {
-    LOG(INFO) << op.type();
+    VLOG(1) << op.type();
     if (op.type() == "SparseLengthsSum" ||
         op.type() == "SparseLengthsSumFused8BitRowwise") {
       InferSparseLengthsSum(op);
@@ -215,6 +215,10 @@ void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
     }
   }
   InferCommonOp(op);
+  // split_info should be a constant
+  if (op.output_size() > 1) {
+    shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
+  }
 }
 
 void BoundShapeInferencer::InferFC(const OperatorDef& op) {
index 0eaea78..7b8ee34 100644 (file)
@@ -21,6 +21,7 @@ using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
 
 const std::string kNetPos("net_pos");
 const std::string kModelId("model_id");
+const std::string kRealBatchSizeBlob("real_batch_size");
 constexpr size_t kBufferSize = 64;
 
 void AnnotateOpIndex(NetDef* net) {
@@ -118,8 +119,7 @@ ShapeInfoMap InferShapes(
       shape_map.emplace(
           std::piecewise_construct,
           std::forward_as_tuple(kv.first),
-          std::forward_as_tuple(
-              ShapeInfo::DimType::CONSTANT, kv.second.shape));
+          std::forward_as_tuple(kv.second.dim_type, kv.second.shape));
     }
   } else {
     // TODO: deprecate this path
@@ -233,27 +233,23 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) {
   opset_id->set_version(7);
 }
 
-string MkBatchSizeBlob() {
-  return "real_batch_size";
-}
-
-string MkSeqSizeBlob(const string& blob_name) {
+std::string MakeSeqSizeBlob(const std::string& blob_name) {
   return blob_name + "_real_seq_size";
 }
 
-string MkOutputForAdjustBatchOp(const string& input) {
+std::string MakeOutputForAdjustBatchOp(const std::string& input) {
   return input + "_post_adjust_batch";
 }
 
-string MkInputForAdjustBatchOp(const string& output) {
+std::string MakeInputForAdjustBatchOp(const std::string& output) {
   return output + "_pre_adjust_batch";
 }
 
-OperatorDef MkAdjustBatchOp(
-    const string& input_blob,
-    const string& output_blob,
+OperatorDef MakeAdjustBatchOp(
+    const std::string& input_blob,
+    const std::string& output_blob,
     int max_batch_size,
-    const string& real_batch_size_blob,
+    const std::string& real_batch_size_blob,
     bool adjust_to_max_batch_size) {
   OperatorDef adjust_batch_op;
   adjust_batch_op.set_type("AdjustBatch");
@@ -263,7 +259,9 @@ OperatorDef MkAdjustBatchOp(
   adjust_batch_op.add_input(input_blob);
   adjust_batch_op.add_output(output_blob);
   if (adjust_to_max_batch_size) {
-    adjust_batch_op.add_output(real_batch_size_blob);
+    if (!real_batch_size_blob.empty()) {
+      adjust_batch_op.add_output(real_batch_size_blob);
+    }
   } else {
     adjust_batch_op.add_input(real_batch_size_blob);
   }
@@ -290,19 +288,22 @@ int64_t GetBlob1stDimSize(
 // Generates AdjustBatchOps for external inputs / outputs with type BATCH or
 // SEQ and adds them to input_ops and output_ops.
 // Meanwhile, modifies inputs / outputs of corresponding operators in the
-// wrapper_net to use the new inputs / outputs of AdjustBatchOps.
-void AddAdjustBatchOps(
+// onnxifi_net to use the new inputs / outputs of AdjustBatchOps.
+std::unordered_map<std::string, std::string> AddAdjustBatchOps(
     const ShapeInfoMap& shape_hints,
-    NetDef* wrapper_net,
+    NetDef* onnxifi_net,
     vector<OperatorDef>* input_ops,
     vector<OperatorDef>* output_ops) {
-  const auto external_inputs = ToHashSet(wrapper_net->external_input());
-  const auto external_outputs = ToHashSet(wrapper_net->external_output());
+  std::unordered_map<std::string, std::string> renaming_map;
+  const auto external_inputs = ToHashSet(onnxifi_net->external_input());
+  const auto external_outputs = ToHashSet(onnxifi_net->external_output());
+  std::unordered_set<std::string> real_batch_size_blobs;
 
-  for (auto& op : *(wrapper_net->mutable_op())) {
+  for (auto& op : *(onnxifi_net->mutable_op())) {
     // Add AdjustBatchOp for all external inputs with type BATCH or SEQ.
     // This will adjust the batch/seq size to the batch/seq size inferred by
-    // bound_shape_inference.
+    // bound_shape_inference. Note that we only produce real batch size tensor
+    // once to avoid data race
     for (auto& input_blob : *(op.mutable_input())) {
       if (external_inputs.count(input_blob)) {
         auto shape_info_it = shape_hints.find(input_blob);
@@ -313,24 +314,27 @@ void AddAdjustBatchOps(
         }
         string real_batch_size_blob = "";
         if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
-          real_batch_size_blob = MkBatchSizeBlob();
+          real_batch_size_blob = kRealBatchSizeBlob;
         } else if (shape_info_it->second.dim_type == ShapeInfo::DimType::SEQ) {
-          real_batch_size_blob = MkSeqSizeBlob(input_blob);
+          real_batch_size_blob = MakeSeqSizeBlob(input_blob);
         } else {
           continue;
         }
-        auto output_blob = MkOutputForAdjustBatchOp(input_blob);
-        input_ops->push_back(MkAdjustBatchOp(
+        auto output_blob = MakeOutputForAdjustBatchOp(input_blob);
+        auto ret = real_batch_size_blobs.emplace(real_batch_size_blob);
+        input_ops->push_back(MakeAdjustBatchOp(
             input_blob,
             output_blob,
             GetBlob1stDimSize(shape_info_it->second, input_blob),
-            real_batch_size_blob,
+            ret.second ? real_batch_size_blob : "",
             true /* adjust_to_max_batch_size */));
+        renaming_map[input_blob] = output_blob;
         input_blob = output_blob;
       }
     }
-    // Add AdjustBatchOp for all external outputs with type BATCH.
-    // This will adjust the batch size to the original batch size.
+    // Add AdjustBatchOp for all external outputs with type BATCH if the real
+    // batch size is presented. This will adjust the batch size to the original
+    // batch size.
     for (auto& output_blob : *(op.mutable_output())) {
       if (external_outputs.count(output_blob)) {
         auto shape_info_it = shape_hints.find(output_blob);
@@ -338,13 +342,17 @@ void AddAdjustBatchOps(
           continue;
         }
         if (shape_info_it->second.dim_type == ShapeInfo::DimType::BATCH) {
-          auto input_blob = MkInputForAdjustBatchOp(output_blob);
-          output_ops->push_back(MkAdjustBatchOp(
+          if (!real_batch_size_blobs.count(kRealBatchSizeBlob)) {
+            continue;
+          }
+          auto input_blob = MakeInputForAdjustBatchOp(output_blob);
+          output_ops->push_back(MakeAdjustBatchOp(
               input_blob,
               output_blob,
               GetBlob1stDimSize(shape_info_it->second, output_blob),
-              MkBatchSizeBlob(),
+              kRealBatchSizeBlob,
               false /* adjust_to_max_batch_size */));
+          renaming_map[output_blob] = input_blob;
           output_blob = input_blob;
         } else {
           CAFFE_ENFORCE(
@@ -355,6 +363,8 @@ void AddAdjustBatchOps(
       }
     }
   }
+
+  return renaming_map;
 }
 
 NetDef ComposeResultNet(
@@ -363,12 +373,12 @@ NetDef ComposeResultNet(
     const OperatorDef& onnxifi_op) {
   NetDef net_opt;
   for (const auto& op : input_ops) {
-    *(net_opt.add_op()) = op;
+    net_opt.add_op()->CopyFrom(op);
   }
-  *(net_opt.add_op()) = onnxifi_op;
+  net_opt.add_op()->CopyFrom(onnxifi_op);
   // Add AdjustBatch ops for output blobs to the net.
   for (const auto& op : output_ops) {
-    *(net_opt.add_op()) = op;
+    net_opt.add_op()->CopyFrom(op);
   }
   return net_opt;
 }
@@ -402,7 +412,8 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
     const std::string& onnx_model_str,
     const std::unordered_map<std::string, TensorShape>& output_shape_hints,
     const std::unordered_set<std::string>& initialization_list,
-    const caffe2::NetDef& net) {
+    const std::vector<std::string>& external_inputs,
+    const std::vector<std::string>& external_outputs) {
   OperatorDef op;
   op.set_type("Onnxifi");
   auto* onnx_model_arg = op.add_arg();
@@ -421,7 +432,7 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
   // Add the input/output
   auto* input_names = op.add_arg();
   input_names->set_name("input_names");
-  for (const auto& input : net.external_input()) {
+  for (const auto& input : external_inputs) {
     if (!initialization_list.count(input)) {
       op.add_input(input);
       input_names->add_strings(input);
@@ -429,7 +440,7 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
   }
   auto* output_names = op.add_arg();
   output_names->set_name("output_names");
-  for (const auto& output : net.external_output()) {
+  for (const auto& output : external_outputs) {
     op.add_output(output);
     output_names->add_strings(output);
   }
@@ -469,17 +480,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
     const std::unordered_set<std::string>& weights_in_ws,
     const ShapeInfoMap& shape_hints) {
   // We already have all the ops and external inputs and outputs!
-  NetDef wrapper_net(net);
-
-  // Compute output shape hints
-  std::unordered_map<std::string, TensorShape> output_shape_hints;
-  for (const auto& o : wrapper_net.external_output()) {
-    const auto it = shape_hints.find(o);
-    CAFFE_ENFORCE(
-        it != shape_hints.end(), "Cannot find shape info for output ", o);
-    const auto& shape = it->second.shape;
-    output_shape_hints.emplace(o, shape);
-  }
+  NetDef onnxifi_net(net);
 
   // Remove the second output of Concat from external_output. In addition, we
   // remove those outputs from the Onnxifi op too.
@@ -488,54 +489,84 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
   // where we statically computes the split_info given input shape and insert a
   // GivenTensorIntFill op
   std::unordered_set<std::string> split_infos;
-  NetDef net_copy(net);
-  for (auto& op : *wrapper_net.mutable_op()) {
+  for (auto& op : *onnxifi_net.mutable_op()) {
     if (op.type() == "Concat" && op.output_size() == 2) {
       split_infos.emplace(op.output(1));
     }
   }
-  wrapper_net.clear_external_output();
-  net_copy.clear_external_output();
+  onnxifi_net.clear_external_output();
   for (const auto& o : net.external_output()) {
     if (!split_infos.count(o)) {
-      wrapper_net.add_external_output(o);
-      net_copy.add_external_output(o);
+      onnxifi_net.add_external_output(o);
     }
   }
 
+  // Insert AdjustBatch ops, note that this step will possibly change the names
+  // of the input/output, so we need to create a mapping and use the renamed
+  // names for external_inputs/outputs and input_shape_info for the onnxifi_net.
   vector<OperatorDef> input_ops;
   vector<OperatorDef> output_ops;
-  AddAdjustBatchOps(shape_hints, &wrapper_net, &input_ops, &output_ops);
+  auto renaming_map =
+      AddAdjustBatchOps(shape_hints, &onnxifi_net, &input_ops, &output_ops);
 
   // Figure out weights and add it to external_inputs too
-  std::vector<std::string> extra_weights;
   std::unordered_set<std::string> initialization_list;
   std::vector<std::string> total_inputs_vec;
   GetWeightsAndInputs(
       net,
       weights_in_ws,
-      extra_weights,
+      std::vector<std::string>(),
       &initialization_list,
       &total_inputs_vec);
-  auto* shape_arg = wrapper_net.add_arg();
+  auto* shape_arg = onnxifi_net.add_arg();
   shape_arg->set_name("input_shape_info");
-  wrapper_net.clear_external_input();
+  onnxifi_net.clear_external_input();
   for (const auto& i : total_inputs_vec) {
-    wrapper_net.add_external_input(i);
+    auto input = i;
+    const auto it = renaming_map.find(i);
+    if (it != renaming_map.end()) {
+      input = it->second;
+    }
+    onnxifi_net.add_external_input(input);
     shape_arg->mutable_tensors()->Add()->CopyFrom(
-        WrapShapeInfoIntoTensorProto(i, shape_hints.at(i)));
+        WrapShapeInfoIntoTensorProto(input, shape_hints.at(i)));
+  }
+
+  // Compute output shape hints
+  std::unordered_map<std::string, TensorShape> output_shape_hints;
+  for (auto& o : *onnxifi_net.mutable_external_output()) {
+    auto output = o;
+    const auto rit = renaming_map.find(o);
+    if (rit != renaming_map.end()) {
+      output = rit->second;
+    }
+    const auto it = shape_hints.find(o);
+    CAFFE_ENFORCE(
+        it != shape_hints.end(), "Cannot find shape info for output ", o);
+    const auto& shape = it->second.shape;
+    output_shape_hints.emplace(output, shape);
+    o = output;
   }
 
   // Build ONNXIFI Op
+  std::vector<std::string> onnxifi_net_inputs(
+      onnxifi_net.external_input().begin(), onnxifi_net.external_input().end());
+  std::vector<std::string> onnxifi_net_outputs(
+      onnxifi_net.external_output().begin(),
+      onnxifi_net.external_output().end());
   std::string model_str;
-  wrapper_net.SerializeToString(&model_str);
+  onnxifi_net.SerializeToString(&model_str);
   auto onnxifi_op = BuildOnnxifiOp(
-      model_str, output_shape_hints, initialization_list, net_copy);
+      model_str,
+      output_shape_hints,
+      initialization_list,
+      onnxifi_net_inputs,
+      onnxifi_net_outputs);
   NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
 
   // Debugging stuff
   if (opts_.debug) {
-    WriteProtoToTextFile(wrapper_net, "debug_wrapper_net.pb_txt");
+    WriteProtoToTextFile(onnxifi_net, "debug_onnxifi_net.pb_txt");
     WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt");
   }
   return net_opt;
@@ -551,17 +582,21 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
   ::ONNX_NAMESPACE::ModelProto onnx_model;
   FillModelInfo(&onnx_model);
 
-  caffe2::NetDef wrapper_net(net);
+  caffe2::NetDef onnxifi_net(net);
   vector<OperatorDef> input_ops;
   vector<OperatorDef> output_ops;
-  AddAdjustBatchOps(*shape_hints, &wrapper_net, &input_ops, &output_ops);
+  auto renaming_map =
+      AddAdjustBatchOps(*shape_hints, &onnxifi_net, &input_ops, &output_ops);
+  for (const auto& kv : renaming_map) {
+    shape_hints_onnx->emplace(kv.second, shape_hints_onnx->at(kv.first));
+  }
 
   // Convert c2 ops to onnx ops, add const weights if there are any
   DeviceOption option;
   CPUContext context(option);
   context.SwitchToDevice();
   std::vector<std::string> extra_weights;
-  for (const auto& op : net.op()) {
+  for (const auto& op : onnxifi_net.op()) {
     const auto results = exporter->Caffe2OpToOnnxNodes(op, *shape_hints_onnx);
     for (const auto& n : results.first) {
       onnx_model.mutable_graph()->add_node()->CopyFrom(n);
@@ -610,12 +645,17 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
   }
 
   // Convert outputs and compute output shape hints
-  std::vector<std::string> io_names;
-  for (const auto& output : net.external_output()) {
-    io_names.emplace_back(output);
+  std::vector<std::string> onnxifi_net_outputs;
+  for (const auto& o : net.external_output()) {
+    auto output = o;
+    const auto it = renaming_map.find(o);
+    if (it != renaming_map.end()) {
+      output = it->second;
+    }
+    onnxifi_net_outputs.emplace_back(output);
   }
   auto io_vec = ConvertToValueInfo(
-      io_names,
+      onnxifi_net_outputs,
       *shape_hints_onnx,
       std::unordered_map<std::string, ::ONNX_NAMESPACE::TypeProto>());
   std::unordered_map<std::string, TensorShape> output_shape_hints;
@@ -632,33 +672,43 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaOnnx(
 
   // Convert inputs and figure out weights
   std::unordered_set<std::string> initialization_list;
-  std::vector<std::string> total_inputs_vec;
+  std::vector<std::string> onnxifi_net_inputs;
   GetWeightsAndInputs(
       net,
       weights_in_ws,
       extra_weights,
       &initialization_list,
-      &total_inputs_vec);
+      &onnxifi_net_inputs);
+  for (auto& i : onnxifi_net_inputs) {
+    const auto it = renaming_map.find(i);
+    if (it != renaming_map.end()) {
+      i = it->second;
+    }
+  }
   io_vec = ConvertToValueInfo(
-      total_inputs_vec,
+      onnxifi_net_inputs,
       *shape_hints_onnx,
       std::unordered_map<std::string, ::ONNX_NAMESPACE::TypeProto>());
   for (const auto& i : io_vec) {
     onnx_model.mutable_graph()->add_input()->CopyFrom(i);
   }
 
-  // Debugging stuff
-  if (opts_.debug) {
-    WriteProtoToTextFile(onnx_model, "debug.onnx_txt");
-  }
-
   // Onnx model is ready. Build ONNXIFI Op
   std::string model_str;
   onnx_model.SerializeToString(&model_str);
-  auto onnxifi_op =
-      BuildOnnxifiOp(model_str, output_shape_hints, initialization_list, net);
+  auto onnxifi_op = BuildOnnxifiOp(
+      model_str,
+      output_shape_hints,
+      initialization_list,
+      onnxifi_net_inputs,
+      onnxifi_net_outputs);
   NetDef net_opt = ComposeResultNet(input_ops, output_ops, onnxifi_op);
 
+  // Debugging stuff
+  if (opts_.debug) {
+    WriteProtoToTextFile(onnx_model, "debug_onnxifi_net.onnx_txt");
+    WriteProtoToTextFile(net_opt, "debug_optimized_net.pb_txt");
+  }
   return net_opt;
 }
 
index b178909..a7ba90a 100644 (file)
@@ -79,7 +79,8 @@ class CAFFE2_API OnnxifiTransformer final {
       const std::string& onnx_model_str,
       const std::unordered_map<std::string, TensorShape>& output_size_hints,
       const std::unordered_set<std::string>& initialization_list,
-      const caffe2::NetDef& net);
+      const std::vector<std::string>& external_inputs,
+      const std::vector<std::string>& external_outputs);
 
   CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
       Workspace* ws,