std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
- caffe2::NetDef* pred_net) {
+ caffe2::NetDef* pred_net,
+ const std::unordered_set<string>& exceptions) {
std::unordered_map<std::string, std::string> input_mapping;
std::unordered_map<std::string, int> blob_versions;
#define REWRITE_EXTERNAL_IO(net, name) \
for (auto& name : *net->mutable_external_##name()) { \
+ if (exceptions.count(name)) { \
+ continue; \
+ } \
auto version = blob_versions.at(name); \
auto new_##name = SsaName(name, version); \
name##_mapping.emplace(new_##name, name); \
op.set_output(0, SsaName(output, 0));
}
for (const auto& input : init_net->external_input()) {
+ if (exceptions.count(input)) {
+ continue;
+ }
blob_versions.emplace(input, 0);
}
for (const auto& output : init_net->external_output()) {
+ if (exceptions.count(output)) {
+ continue;
+ }
blob_versions.emplace(output, 0);
}
REWRITE_EXTERNAL_IO(init_net, input);
if (pred_net) {
for (const auto& input : pred_net->external_input()) {
+ if (exceptions.count(input)) {
+ continue;
+ }
blob_versions.emplace(input, 0);
}
REWRITE_EXTERNAL_IO(pred_net, input);
for (auto& op : *pred_net->mutable_op()) {
for (auto& input : *op.mutable_input()) {
+ if (exceptions.count(input)) {
+ continue;
+ }
const auto it = blob_versions.find(input);
if (it != blob_versions.end()) {
input = SsaName(input, it->second);
}
}
for (auto& output : *op.mutable_output()) {
+ if (exceptions.count(output)) {
+ continue;
+ }
auto it = blob_versions.find(output);
if (it != blob_versions.end()) {
it->second += 1;
}
for (auto& op : *pred_net->mutable_op()) {
for (auto& output : *op.mutable_output()) {
+ if (exceptions.count(output)) {
+ continue;
+ }
auto pos = output.find_last_of('_');
CAFFE_ENFORCE_NE(pos, 0);
auto basename = output.substr(0, pos);
// output names for predict net.
// `exceptions` (introduced by this change) lists blob names the rewrite leaves
// alone: the definition skips every external input/output and every op
// input/output whose name is in this set. It defaults to an empty set, so
// calls that pass only the two nets still compile.
CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
- caffe2::NetDef* pred_net);
+ caffe2::NetDef* pred_net,
+ const std::unordered_set<std::string>& exceptions =
+ std::unordered_set<std::string>());
// Judging by the name and signature, returns the ONNX TensorProto data type
// corresponding to the Caffe2 tensor data type `t`.
// NOTE(review): only the declaration is visible here; confirm against the
// definition how Caffe2 types without an ONNX counterpart are handled.
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
caffe2::TensorProto::DataType t);
CaffeMap<std::string, TensorShape> OnnxifiTransformer::SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
+ const std::unordered_set<std::string>& weights,
const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
- input_mapping_ = onnx::SsaRewrite(nullptr, pred_net);
+ // Make sure weights do not contain output of any op.
+ for (const auto& op : pred_net->op()) {
+ for (const auto& output : op.output()) {
+ CAFFE_ENFORCE_EQ(weights.count(output), 0);
+ }
+ }
+ input_mapping_ = onnx::SsaRewrite(nullptr, pred_net, weights);
// Annotate the ops with net position
AnnotateOpIndex(pred_net);
std::vector<std::string> external_inputs;
+ // Need to add mapping for weights. This will be used to create new workspace
+ // with mapped weights.
+ for (const auto& w : weights) {
+ input_mapping_.emplace(w, w);
+ }
for (const auto kv : input_mapping_) {
reverse_input_mapping_.emplace(kv.second, kv.first);
if (!ws->HasBlob(kv.second)) {
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& external_inputs,
+ const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& input_shape_hints,
const std::unordered_set<int>& blacklisted_ops) {
CAFFE_ENFORCE(ws);
model_id_ = GetModelId(*pred_net);
onnxifi_op_id_ = 0;
+ std::unordered_set<std::string> weights(
+ weight_names.begin(), weight_names.end());
+
// SSA Rewrite the net
auto shape_hints_ordered =
- SsaRewriteAndMapNames(ws, pred_net, input_shape_hints);
+ SsaRewriteAndMapNames(ws, pred_net, weights, input_shape_hints);
// Populate shape info
Workspace mapped_ws(ws, input_mapping_);
opts_.infer_shapes,
opts_.bound_shape_spec);
- // Figure out what are the weights
- std::unordered_set<std::string> weights;
- std::unordered_set<std::string> input_set;
- for (const auto& i : external_inputs) {
- const auto it = reverse_input_mapping_.find(i);
- if (it != reverse_input_mapping_.end()) {
- input_set.emplace(it->second);
- }
- }
- const std::vector<string>& ws_blobs = mapped_ws.Blobs();
- for (const auto& s : ws_blobs) {
- if (!input_set.count(s)) {
- weights.emplace(s);
- }
- }
-
// Transform the net
NetDef net_opt = opts_.use_onnx
? TransformViaOnnx(ws, pred_net, weights, blacklisted_ops, &shape_hints)
Workspace* ws,
NetDef* pred_net,
const std::vector<std::string>& external_inputs,
+ const std::vector<std::string>& weight_names,
const std::unordered_map<std::string, TensorShape>& shape_hints,
const std::unordered_set<int>& blacklisted_ops);
// SSA-rewrites `pred_net` via onnx::SsaRewrite and records the resulting name
// mapping in input_mapping_ and reverse_input_mapping_.
// `weights` is forwarded to onnx::SsaRewrite as its exception set. The
// definition first enforces that no op in `pred_net` outputs a name in
// `weights`, and afterwards adds a self-mapping (w -> w) for each weight to
// input_mapping_.
// NOTE(review): Transform stores the return value as ordered shape hints; how
// it is derived from `input_shape_hints` is not visible in this chunk.
CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
Workspace* ws,
NetDef* pred_net,
+ const std::unordered_set<std::string>& weights,
const std::unordered_map<std::string, TensorShape>& input_shape_hints);
// Transform by passing C2 proto to backend
opts.debug = debug_builder;
opts.use_onnx = use_onnx;
OnnxifiTransformer ts(opts);
+ Workspace* curr_ws = GetCurrentWorkspace();
+ auto weight_names = curr_ws->Blobs();
ts.Transform(
- GetCurrentWorkspace(),
+ curr_ws,
&pred_net,
external_inputs,
+ weight_names,
tensor_shapes,
std::unordered_set<int>());
std::string pred_net_str2;