From cb79e1b3a5489cf3030a1588903c880fc4441b76 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Thu, 20 Dec 2018 22:04:09 -0800 Subject: [PATCH] Clean up onnxifi transformation code (#15453) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15453 Just move things around to facilitate further development. No logic change. Reviewed By: rdzhabarov Differential Revision: D13533959 fbshipit-source-id: eebab1306939e802aacffb24a711d372fd67916c --- caffe2/opt/onnxifi_transformer.cc | 91 ++++++++++++++++++++++----------------- caffe2/opt/onnxifi_transformer.h | 24 ++++++++--- caffe2/python/pybind_state.cc | 5 ++- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 25bb5c6..5292a30 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -134,8 +134,8 @@ void FillModelInfo(::ONNX_NAMESPACE::ModelProto* model) { } } // namespace -OnnxifiTransformer::OnnxifiTransformer(bool infer_shapes, bool debug) - : infer_shapes_(infer_shapes), debug_(debug) { +OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts) + : opts_(opts) { lib_ = onnx::initOnnxifiLibrary(); CAFFE_ENFORCE(lib_, "Cannot initialize ONNXIFI library"); CAFFE_ENFORCE_EQ( @@ -334,7 +334,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp( } // Debugging stuff - if (debug_) { + if (opts_.debug) { DumpModel(onnx_model, "debug.onnxtxt"); } @@ -383,33 +383,21 @@ CaffeMap OnnxifiTransformer::SsaRewriteAndMapNames( return shape_hints_ordered; } -// Cutting off the runnable part and replace with ONNXIFI ops. 
Asssume the nets -// were topologically sorted -void OnnxifiTransformer::Transform( +NetDef OnnxifiTransformer::TransformViaOnnx( Workspace* ws, NetDef* pred_net, - const std::vector& external_inputs, - const std::unordered_map& input_shape_hints, - const std::unordered_set& blacklisted_ops) { - CAFFE_ENFORCE(ws); - auto shape_hints_ordered = - SsaRewriteAndMapNames(ws, pred_net, input_shape_hints); - Workspace mapped_ws(ws, input_mapping_); - std::unordered_map shape_hints = - InferShapes(&mapped_ws, pred_net, &shape_hints_ordered, infer_shapes_); - - CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); - onnx::OnnxExporter exporter(nullptr); - - // function to tell whether the ONNXIFI backend supports a given C2 op or not - // TODO: choose backend id + const std::unordered_set& weights, + const std::unordered_set& blacklisted_ops, + std::unordered_map* shape_hints) { onnxifi_library* backend = lib_; onnxBackendID backend_id = backend_ids_[0]; - auto supports = [&exporter, - &shape_hints, - &blacklisted_ops, - backend, - backend_id](const caffe2::OperatorDef& op) { + // function to tell whether the ONNXIFI backend supports a given C2 op or not + onnx::OnnxExporter exporter(nullptr); + auto onnx_supports = [&exporter, + shape_hints, + &blacklisted_ops, + backend, + backend_id](const caffe2::OperatorDef& op) { try { int pos = ArgumentHelper::GetSingleArgument(op, kNetPos, -1); @@ -427,7 +415,7 @@ void OnnxifiTransformer::Transform( ::ONNX_NAMESPACE::ModelProto onnx_model; FillModelInfo(&onnx_model); - auto results = exporter.Caffe2OpToOnnxNodes(op, shape_hints); + auto results = exporter.Caffe2OpToOnnxNodes(op, *shape_hints); std::unordered_set used_inputs; std::unordered_set used_outputs; std::vector boundary_inputs; @@ -476,12 +464,12 @@ void OnnxifiTransformer::Transform( // Add input/output shape info auto io_vec = - ConvertToValueInfo(boundary_inputs, shape_hints, extra_shape_hints); + ConvertToValueInfo(boundary_inputs, *shape_hints, extra_shape_hints); 
for (const auto& i : io_vec) { onnx_model.mutable_graph()->add_input()->CopyFrom(i); } io_vec = - ConvertToValueInfo(boundary_outputs, shape_hints, extra_shape_hints); + ConvertToValueInfo(boundary_outputs, *shape_hints, extra_shape_hints); for (const auto& i : io_vec) { onnx_model.mutable_graph()->add_output()->CopyFrom(i); } @@ -498,18 +486,45 @@ void OnnxifiTransformer::Transform( return true; } } catch (const std::exception& ex) { - LOG(ERROR) << "Gaught exception when converting op " << op.type() + LOG(ERROR) << "Caught exception when converting op " << op.type() << ", what: " << ex.what(); return false; } }; - // function to convert runnable subgraph into a trt op. Note that to keep the - // interface clean, we do the double conversion from C2 op to Onnx ops here - // but it should be OK as the cost is really small. We also need to keep the - // same exporter throughout the process to avoid duplicated dummy name + // function to convert runnable subgraph into an onnxifi op. We need to keep + // the same exporter throughout the process to avoid duplicated dummy name // generation onnx::OnnxExporter exporter2(nullptr); + auto onnx_converter = [this, ws, &weights, shape_hints, &exporter2]( + const caffe2::NetDef& net) mutable { + return SubnetToOnnxifiOp(net, weights, ws, &exporter2, shape_hints); + }; + + return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter); +} + +// Cutting off the runnable part and replace with ONNXIFI ops. 
Assume the nets +// were topologically sorted +void OnnxifiTransformer::Transform( + Workspace* ws, + NetDef* pred_net, + const std::vector& external_inputs, + const std::unordered_map& input_shape_hints, + const std::unordered_set& blacklisted_ops) { + CAFFE_ENFORCE(ws); + CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); + + // SSA Rewrite the net + auto shape_hints_ordered = + SsaRewriteAndMapNames(ws, pred_net, input_shape_hints); + + // Populate shape info + Workspace mapped_ws(ws, input_mapping_); + std::unordered_map shape_hints = InferShapes( + &mapped_ws, pred_net, &shape_hints_ordered, opts_.infer_shapes); + + // Figure out what are the weights std::unordered_set weights; std::unordered_set input_set; for (const auto& i : external_inputs) { @@ -524,12 +539,10 @@ void OnnxifiTransformer::Transform( weights.emplace(s); } } - auto trt_converter = [this, ws, &weights, &shape_hints, &exporter2]( - const caffe2::NetDef& net) mutable { - return SubnetToOnnxifiOp(net, weights, ws, &exporter2, &shape_hints); - }; - NetDef net_opt = opt::OptimizeForBackend(*pred_net, supports, trt_converter); + // Transform the net + NetDef net_opt = + TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints); // Need to figure out a proper place to handle device option net_opt.mutable_device_option()->CopyFrom(pred_net->device_option()); diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index e3eee2c..6286535 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -18,9 +18,18 @@ namespace onnx { class OnnxExporter; } +struct OnnxifiTransformerOptions { + // Run shape inference + bool infer_shapes{false}; + // Dump onnx model for debugging + bool debug{false}; + // Pass serialized onnx model if true, otherwise pass serialized c2 model + bool use_onnx{true}; +}; + class CAFFE2_API OnnxifiTransformer final { public: - explicit OnnxifiTransformer(bool infer_shapes, bool debug); + explicit
OnnxifiTransformer(const OnnxifiTransformerOptions& opts); ~OnnxifiTransformer(); void Transform( @@ -60,11 +69,16 @@ class CAFFE2_API OnnxifiTransformer final { NetDef* pred_net, const std::unordered_map& input_shape_hints); - // Run shape inference - bool infer_shapes_{false}; + // Transform by passing ONNX proto to backend + NetDef TransformViaOnnx( + Workspace* ws, + NetDef* pred_net, + const std::unordered_set& weights, + const std::unordered_set& blacklisted_ops, + std::unordered_map* shape_hints); - // Dump onnx model for debugging - bool debug_{false}; + // Options + OnnxifiTransformerOptions opts_; // Pointer to loaded onnxifi library onnxifi_library* lib_{nullptr}; diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index a4a1509..52fe08c 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -1618,7 +1618,10 @@ void addGlobalMethods(py::module& m) { tensor_shapes.emplace( it.first, CreateTensorShape(it.second, TensorProto::FLOAT)); } - OnnxifiTransformer ts(infer_shapes, debug_builder); + OnnxifiTransformerOptions opts; + opts.infer_shapes = infer_shapes; + opts.debug = debug_builder; + OnnxifiTransformer ts(opts); ts.Transform( GetCurrentWorkspace(), &pred_net, -- 2.7.4