Clean up onnxifi transformation code (#15453)
authorYinghai Lu <yinghai@fb.com>
Fri, 21 Dec 2018 06:04:09 +0000 (22:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 21 Dec 2018 06:06:47 +0000 (22:06 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15453

Just move things around to facilitate further development. No logic change.

Reviewed By: rdzhabarov

Differential Revision: D13533959

fbshipit-source-id: eebab1306939e802aacffb24a711d372fd67916c

caffe2/opt/onnxifi_transformer.cc
caffe2/opt/onnxifi_transformer.h
caffe2/python/pybind_state.cc

index 25bb5c6..5292a30 100644 (file)
@@ -134,8 +134,8 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) {
 }
 } // namespace
 
-OnnxifiTransformer::OnnxifiTransformer(bool infer_shapes, bool debug)
-    : infer_shapes_(infer_shapes), debug_(debug) {
+OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
+    : opts_(opts) {
   lib_ = onnx::initOnnxifiLibrary();
   CAFFE_ENFORCE(lib_, "Cannot initialize ONNXIFI library");
   CAFFE_ENFORCE_EQ(
@@ -334,7 +334,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp(
   }
 
   // Debugging stuff
-  if (debug_) {
+  if (opts_.debug) {
     DumpModel(onnx_model, "debug.onnxtxt");
   }
 
@@ -383,33 +383,21 @@ CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
   return shape_hints_ordered;
 }
 
-// Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets
-// were topologically sorted
-void OnnxifiTransformer::Transform(
+NetDef OnnxifiTransformer::TransformViaOnnx(
     Workspace* ws,
     NetDef* pred_net,
-    const std::vector<std::string>& external_inputs,
-    const std::unordered_map<std::string, TensorShape>& input_shape_hints,
-    const std::unordered_set<int>& blacklisted_ops) {
-  CAFFE_ENFORCE(ws);
-  auto shape_hints_ordered =
-      SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
-  Workspace mapped_ws(ws, input_mapping_);
-  std::unordered_map<std::string, TensorShape> shape_hints =
-      InferShapes(&mapped_ws, pred_net, &shape_hints_ordered, infer_shapes_);
-
-  CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
-  onnx::OnnxExporter exporter(nullptr);
-
-  // function to tell whether the ONNXIFI backend supports a given C2 op or not
-  // TODO: choose backend id
+    const std::unordered_set<std::string>& weights,
+    const std::unordered_set<int>& blacklisted_ops,
+    std::unordered_map<std::string, TensorShape>* shape_hints) {
   onnxifi_library* backend = lib_;
   onnxBackendID backend_id = backend_ids_[0];
-  auto supports = [&exporter,
-                   &shape_hints,
-                   &blacklisted_ops,
-                   backend,
-                   backend_id](const caffe2::OperatorDef& op) {
+  // function to tell whether the ONNXIFI backend supports a given C2 op or not
+  onnx::OnnxExporter exporter(nullptr);
+  auto onnx_supports = [&exporter,
+                        shape_hints,
+                        &blacklisted_ops,
+                        backend,
+                        backend_id](const caffe2::OperatorDef& op) {
     try {
       int pos =
           ArgumentHelper::GetSingleArgument<OperatorDef, int>(op, kNetPos, -1);
@@ -427,7 +415,7 @@ void OnnxifiTransformer::Transform(
 
       ::ONNX_NAMESPACE::ModelProto onnx_model;
       FillModelInfo(&onnx_model);
-      auto results = exporter.Caffe2OpToOnnxNodes(op, shape_hints);
+      auto results = exporter.Caffe2OpToOnnxNodes(op, *shape_hints);
       std::unordered_set<std::string> used_inputs;
       std::unordered_set<std::string> used_outputs;
       std::vector<std::string> boundary_inputs;
@@ -476,12 +464,12 @@ void OnnxifiTransformer::Transform(
 
       // Add input/output shape info
       auto io_vec =
-          ConvertToValueInfo(boundary_inputs, shape_hints, extra_shape_hints);
+          ConvertToValueInfo(boundary_inputs, *shape_hints, extra_shape_hints);
       for (const auto& i : io_vec) {
         onnx_model.mutable_graph()->add_input()->CopyFrom(i);
       }
       io_vec =
-          ConvertToValueInfo(boundary_outputs, shape_hints, extra_shape_hints);
+          ConvertToValueInfo(boundary_outputs, *shape_hints, extra_shape_hints);
       for (const auto& i : io_vec) {
         onnx_model.mutable_graph()->add_output()->CopyFrom(i);
       }
@@ -498,18 +486,45 @@ void OnnxifiTransformer::Transform(
         return true;
       }
     } catch (const std::exception& ex) {
-      LOG(ERROR) << "Gaught exception when converting op " << op.type()
+      LOG(ERROR) << "Caught exception when converting op " << op.type()
                  << ", what: " << ex.what();
       return false;
     }
   };
 
-  // function to convert runnable subgraph into a trt op. Note that to keep the
-  // interface clean, we do the double conversion from C2 op to Onnx ops here
-  // but it should be OK as the cost is really small. We also need to keep the
-  // same exporter throughout the process to avoid duplicated dummy name
+  // function to convert runnable subgraph into an onnxifi op. We need to keep
+  // the same exporter throughout the process to avoid duplicated dummy name
   // generation
   onnx::OnnxExporter exporter2(nullptr);
+  auto onnx_converter = [this, ws, &weights, shape_hints, &exporter2](
+                            const caffe2::NetDef& net) mutable {
+    return SubnetToOnnxifiOp(net, weights, ws, &exporter2, shape_hints);
+  };
+
+  return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter);
+}
+
+// Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets
+// were topologically sorted
+void OnnxifiTransformer::Transform(
+    Workspace* ws,
+    NetDef* pred_net,
+    const std::vector<std::string>& external_inputs,
+    const std::unordered_map<std::string, TensorShape>& input_shape_hints,
+    const std::unordered_set<int>& blacklisted_ops) {
+  CAFFE_ENFORCE(ws);
+  CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
+
+  // SSA Rewrite the net
+  auto shape_hints_ordered =
+      SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
+
+  // Populate shape info
+  Workspace mapped_ws(ws, input_mapping_);
+  std::unordered_map<std::string, TensorShape> shape_hints = InferShapes(
+      &mapped_ws, pred_net, &shape_hints_ordered, opts_.infer_shapes);
+
+  // Figure out what are the weights
   std::unordered_set<std::string> weights;
   std::unordered_set<std::string> input_set;
   for (const auto& i : external_inputs) {
@@ -524,12 +539,10 @@ void OnnxifiTransformer::Transform(
       weights.emplace(s);
     }
   }
-  auto trt_converter = [this, ws, &weights, &shape_hints, &exporter2](
-                           const caffe2::NetDef& net) mutable {
-    return SubnetToOnnxifiOp(net, weights, ws, &exporter2, &shape_hints);
-  };
 
-  NetDef net_opt = opt::OptimizeForBackend(*pred_net, supports, trt_converter);
+  // Transform the net
+  NetDef net_opt =
+      TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints);
 
   // Need to figure out a proper place to handle device option
   net_opt.mutable_device_option()->CopyFrom(pred_net->device_option());
index e3eee2c..6286535 100644 (file)
@@ -18,9 +18,18 @@ namespace onnx {
 class OnnxExporter;
 }
 
+struct OnnxifiTransformerOptions {
+  // Run shape inference
+  bool infer_shapes{false};
+  // Dump onnx model for debugging
+  bool debug{false};
+  // Pass serialized onnx model if true, otherwise pass serialized c2 model
+  bool use_onnx{true};
+};
+
 class CAFFE2_API OnnxifiTransformer final {
  public:
-  explicit OnnxifiTransformer(bool infer_shapes, bool debug);
+  explicit OnnxifiTransformer(const OnnxifiTransformerOptions& opts);
   ~OnnxifiTransformer();
 
   void Transform(
@@ -60,11 +69,16 @@ class CAFFE2_API OnnxifiTransformer final {
       NetDef* pred_net,
       const std::unordered_map<std::string, TensorShape>& input_shape_hints);
 
-  // Run shape inference
-  bool infer_shapes_{false};
+  // Transform by passing ONNX proto to backend
+  NetDef TransformViaOnnx(
+      Workspace* ws,
+      NetDef* pred_net,
+      const std::unordered_set<std::string>& weights,
+      const std::unordered_set<int>& blacklisted_ops,
+      std::unordered_map<std::string, TensorShape>* shape_hints);
 
-  // Dump onnx model for debugging
-  bool debug_{false};
+  // Options
+  OnnxifiTransformerOptions opts_;
 
   // Pointer to loaded onnxifi library
   onnxifi_library* lib_{nullptr};
index a4a1509..52fe08c 100644 (file)
@@ -1618,7 +1618,10 @@ void addGlobalMethods(py::module& m) {
           tensor_shapes.emplace(
               it.first, CreateTensorShape(it.second, TensorProto::FLOAT));
         }
-        OnnxifiTransformer ts(infer_shapes, debug_builder);
+        OnnxifiTransformerOptions opts;
+        opts.infer_shapes = infer_shapes;
+        opts.debug = debug_builder;
+        OnnxifiTransformer ts(opts);
         ts.Transform(
             GetCurrentWorkspace(),
             &pred_net,