#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/type_parser.h>
+#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/custom_class_detail.h>
namespace mobile {
-namespace {} // namespace
+namespace {
+// Number of opcodes, computed at compile time. FORALL_OPCODES is
+// expected to invoke COUNT_OPCODE once per opcode, so the expression
+// expands to `1 + 1 + ... + 0`.
+#define COUNT_OPCODE(_, _a) 1 +
+constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
+#undef COUNT_OPCODE
+
+// Pickled strings are memoized, so we can cache a mapping from
+// pointers to parsed OpCodes to speed up parsing.
+//
+// The cache is a fixed-capacity (numOpcodes entries) flat table that
+// is searched linearly by string *address*, not by string contents.
+// Once the table is full, unknown addresses are still parsed
+// correctly but are no longer remembered.
+//
+// NOTE(review): the cache does not own the strings it has seen. If a
+// cached string were destroyed and a different one later allocated at
+// the same address while this cache is still alive, the stale entry
+// would match. Presumably callers keep the strings alive for the
+// lifetime of the cache -- confirm.
+class OpCodeCache {
+ private:
+  // We store as void* to emphasize that we care only about the
+  // address and should not be dereferencing these pointers.
+  // The `{}` initializers value-initialize both arrays, so no
+  // user-provided constructor (or memset) is needed.
+  std::array<const void*, numOpcodes> keys_{};
+  std::array<OpCode, numOpcodes> values_{};
+  // Number of leading slots of keys_/values_ that hold valid entries.
+  size_t usedEntries_ = 0;
+
+ public:
+  // Returns the OpCode named by `s`. On a cache miss the string is
+  // parsed with parseOpCode() and, if there is still room, the result
+  // is recorded under the address of `s`.
+  OpCode parse(const c10::ivalue::ConstantString& s) {
+    const void* const key = &s;
+    // Only the first usedEntries_ slots are populated; search just those.
+    const auto endIt = keys_.begin() + usedEntries_;
+    const auto it = std::find(keys_.begin(), endIt, key);
+    if (it == endIt) {
+      OpCode result = parseOpCode(s.string().c_str());
+      if (usedEntries_ < numOpcodes) {
+        keys_[usedEntries_] = key;
+        values_[usedEntries_++] = result;
+      }
+      return result;
+    }
+    // NOTE: I tried implementing the transpose heuristic here to
+    // speed up the search, but it removed the benefit of this cache.
+    return values_[it - keys_.begin()];
+  }
+};
+} // namespace
void parseInstructions(
const std::string& function_name,
"The numbers of instructions and debug handles strings do not match.");
}
+ // NOTE: this won't perform particularly well if the ins_list IValue
+ // didn't come from the unpickler, because then its strings won't be
+ // interned and the address-keyed cache will rarely hit. Consider
+ // adding a flag to bypass the cache if that becomes an important
+ // use case.
+ OpCodeCache opCodeCache;
for (const auto j : c10::irange(ins_list.size())) {
std::vector<IValue> ins_item =
std::move(*std::move(ins_list[j]).toTuple()).elements();
ins_item.size() == 3,
"There should be three parts in an instruction. The function name is ",
function_name);
- OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
+ OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
int X = ins_item[1].toInt();
int N = ins_item[2].toInt();
if (!debug_handles_list.empty()) {