}
} // namespace
-OnnxifiTransformer::OnnxifiTransformer(bool infer_shapes, bool debug)
- : infer_shapes_(infer_shapes), debug_(debug) {
+OnnxifiTransformer::OnnxifiTransformer(const OnnxifiTransformerOptions& opts)
+ : opts_(opts) {
lib_ = onnx::initOnnxifiLibrary();
CAFFE_ENFORCE(lib_, "Cannot initialize ONNXIFI library");
CAFFE_ENFORCE_EQ(
}
// Debugging stuff
- if (debug_) {
+ if (opts_.debug) {
DumpModel(onnx_model, "debug.onnxtxt");
}
return shape_hints_ordered;
}
-// Cutting off the runnable part and replace with ONNXIFI ops. Asssume the nets
-// were topologically sorted
-void OnnxifiTransformer::Transform(
+NetDef OnnxifiTransformer::TransformViaOnnx(
Workspace* ws,
NetDef* pred_net,
- const std::vector<std::string>& external_inputs,
- const std::unordered_map<std::string, TensorShape>& input_shape_hints,
- const std::unordered_set<int>& blacklisted_ops) {
- CAFFE_ENFORCE(ws);
- auto shape_hints_ordered =
- SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
- Workspace mapped_ws(ws, input_mapping_);
- std::unordered_map<std::string, TensorShape> shape_hints =
- InferShapes(&mapped_ws, pred_net, &shape_hints_ordered, infer_shapes_);
-
- CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
- onnx::OnnxExporter exporter(nullptr);
-
- // function to tell whether the ONNXIFI backend supports a given C2 op or not
- // TODO: choose backend id
+ const std::unordered_set<std::string>& weights,
+ const std::unordered_set<int>& blacklisted_ops,
+ std::unordered_map<std::string, TensorShape>* shape_hints) {
onnxifi_library* backend = lib_;
onnxBackendID backend_id = backend_ids_[0];
- auto supports = [&exporter,
- &shape_hints,
- &blacklisted_ops,
- backend,
- backend_id](const caffe2::OperatorDef& op) {
+ // function to tell whether the ONNXIFI backend supports a given C2 op or not
+ onnx::OnnxExporter exporter(nullptr);
+ auto onnx_supports = [&exporter,
+ shape_hints,
+ &blacklisted_ops,
+ backend,
+ backend_id](const caffe2::OperatorDef& op) {
try {
int pos =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(op, kNetPos, -1);
::ONNX_NAMESPACE::ModelProto onnx_model;
FillModelInfo(&onnx_model);
- auto results = exporter.Caffe2OpToOnnxNodes(op, shape_hints);
+ auto results = exporter.Caffe2OpToOnnxNodes(op, *shape_hints);
std::unordered_set<std::string> used_inputs;
std::unordered_set<std::string> used_outputs;
std::vector<std::string> boundary_inputs;
// Add input/output shape info
auto io_vec =
- ConvertToValueInfo(boundary_inputs, shape_hints, extra_shape_hints);
+ ConvertToValueInfo(boundary_inputs, *shape_hints, extra_shape_hints);
for (const auto& i : io_vec) {
onnx_model.mutable_graph()->add_input()->CopyFrom(i);
}
io_vec =
- ConvertToValueInfo(boundary_outputs, shape_hints, extra_shape_hints);
+ ConvertToValueInfo(boundary_outputs, *shape_hints, extra_shape_hints);
for (const auto& i : io_vec) {
onnx_model.mutable_graph()->add_output()->CopyFrom(i);
}
return true;
}
} catch (const std::exception& ex) {
- LOG(ERROR) << "Gaught exception when converting op " << op.type()
+ LOG(ERROR) << "Caught exception when converting op " << op.type()
<< ", what: " << ex.what();
return false;
}
};
- // function to convert runnable subgraph into a trt op. Note that to keep the
- // interface clean, we do the double conversion from C2 op to Onnx ops here
- // but it should be OK as the cost is really small. We also need to keep the
- // same exporter throughout the process to avoid duplicated dummy name
+ // function to convert runnable subgraph into an onnxifi op. We need to keep
+ // the same exporter throughout the process to avoid duplicated dummy name
// generation
onnx::OnnxExporter exporter2(nullptr);
+ auto onnx_converter = [this, ws, &weights, shape_hints, &exporter2](
+ const caffe2::NetDef& net) mutable {
+ return SubnetToOnnxifiOp(net, weights, ws, &exporter2, shape_hints);
+ };
+
+ return opt::OptimizeForBackend(*pred_net, onnx_supports, onnx_converter);
+}
+
+// Cut off the runnable part and replace it with ONNXIFI ops. Assume the nets
+// were topologically sorted.
+void OnnxifiTransformer::Transform(
+ Workspace* ws,
+ NetDef* pred_net,
+ const std::vector<std::string>& external_inputs,
+ const std::unordered_map<std::string, TensorShape>& input_shape_hints,
+ const std::unordered_set<int>& blacklisted_ops) {
+ CAFFE_ENFORCE(ws);
+ CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
+
+ // SSA Rewrite the net
+ auto shape_hints_ordered =
+ SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
+
+ // Populate shape info
+ Workspace mapped_ws(ws, input_mapping_);
+ std::unordered_map<std::string, TensorShape> shape_hints = InferShapes(
+ &mapped_ws, pred_net, &shape_hints_ordered, opts_.infer_shapes);
+
+ // Figure out what are the weights
std::unordered_set<std::string> weights;
std::unordered_set<std::string> input_set;
for (const auto& i : external_inputs) {
weights.emplace(s);
}
}
- auto trt_converter = [this, ws, &weights, &shape_hints, &exporter2](
- const caffe2::NetDef& net) mutable {
- return SubnetToOnnxifiOp(net, weights, ws, &exporter2, &shape_hints);
- };
- NetDef net_opt = opt::OptimizeForBackend(*pred_net, supports, trt_converter);
+ // Transform the net
+ NetDef net_opt =
+ TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints);
// Need to figure out a proper place to handle device option
net_opt.mutable_device_option()->CopyFrom(pred_net->device_option());