#pragma once
#include <memory>
+#include <unordered_map>
#include <vector>
#include <c10/util/irange.h>
std::vector<IValue> constant_table_;
std::vector<Operation> operator_table_;
+ // map<(op name, num inputs), index in operator table>, to avoid duplicates,
+ // not including vararg operators
+ std::unordered_map<
+ std::pair<std::string, int>,
+ int,
+ std::function<size_t(const std::pair<std::string, int>& p)>>
+ operator_table_inv_;
std::vector<Function*> function_table_;
std::vector<std::unique_ptr<GraphFunction>> forked_functions_;
std::vector<TypePtr> type_table_;
std::string function_name,
size_t remaining_bailout_depth,
bool emit_instructions = true)
- : function_name_(std::move(function_name)),
+ : operator_table_inv_(
+ 0,
+ [](const std::pair<std::string, int>& p) {
+ return std::hash<std::string>()(p.first) ^
+ std::hash<int>()(p.second);
+ }),
+ function_name_(std::move(function_name)),
preprocess_(*graph),
current_node_(preprocess_.graph->return_node()),
remaining_bailout_depth_(remaining_bailout_depth) {
virtual void emitOperator(Node* node) {
emitLoadInputs(node->inputs());
const Operator& op = node->getOperator();
- if (op.hasOperation() && op.schema().is_vararg()) {
- insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+ int num_inputs = node->inputs().size();
+ bool is_vararg = op.schema().is_vararg();
+
+ int operation_index = add_to_operator_table(
+ op.getOperation(node),
+ c10::toString(op.schema().operator_name()),
+ num_inputs,
+ is_vararg);
+
+ if (op.hasOperation() && is_vararg) {
+ insertInstruction(OPN, operation_index, num_inputs);
} else {
- insertInstruction(OP, operator_table_.size());
+ insertInstruction(OP, operation_index);
}
- operator_table_.emplace_back(op.getOperation(node));
}
void emitWait(Node* node) {
dump(out, i);
}
}
+
+ /**
+ * Add an operation to operator_table_ if not a duplicate and return its index
+ */
+ int add_to_operator_table(
+ const Operation& oper,
+ const std::string& op_name,
+ const int num_inputs,
+ const bool is_vararg) {
+ int size = operator_table_.size();
+
+ if (!is_vararg) {
+ std::pair<std::string, int> key(op_name, num_inputs);
+ auto found = operator_table_inv_.find(key);
+
+ if (found != operator_table_inv_.end()) {
+ return found->second;
+ }
+
+ operator_table_inv_.emplace(key, size);
+ }
+
+ operator_table_.emplace_back(oper);
+ return size;
+ }
};
struct MobileCodeImpl : CodeImpl {
CodeImpl::emitOperator(node);
} else {
const Operator& op = node->getOperator();
- if (op.hasOperation() && op.schema().is_vararg()) {
+ std::string unique_op_name = c10::toString(op.schema().operator_name());
+ int num_inputs = node->inputs().size();
+ bool is_vararg = op.schema().is_vararg();
+
+ if (op.hasOperation() && is_vararg) {
emitLoadInputs(node->inputs());
- insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+ int operation_index = add_to_operator_table(
+ op.getOperation(node),
+ unique_op_name,
+ num_inputs,
+ /* is_vararg */ true);
+ insertInstruction(OPN, operation_index, num_inputs);
} else {
- auto unique_op_name = c10::toString(op.schema().operator_name());
- auto num_include = node->inputs().size();
+ auto num_include = num_inputs;
auto it = op_to_num_specified_args_.find(unique_op_name);
if (it != op_to_num_specified_args_.end()) {
num_include = it->second;
} else {
emitLoadInputs(node->inputs(), num_include);
}
- insertInstruction(OP, operator_table_.size());
+ int operation_index = add_to_operator_table(
+ op.getOperation(node), unique_op_name, num_inputs, is_vararg);
+ insertInstruction(OP, operation_index);
}
- operator_table_.emplace_back(op.getOperation(node));
}
}