void testIRParser() {
{
auto graph = std::make_shared<Graph>();
+ std::unordered_map<std::string, Value*> vmap;
script::parseIR(
R"IR(
graph(%0 : Tensor, %1 : Tensor):
%res, %3 = foo::mul(%0, %2)
%x, %y = foo::combine(%res, %2, %3)
return (%x, %y, %res))IR",
- &*graph);
+ &*graph,
+ vmap);
AT_ASSERT(graph->inputs().size() == 2);
AT_ASSERT(graph->outputs().size() == 3);
Value* res = graph->outputs()[2];
Value* t0 = graph->inputs()[0];
Value* t1 = graph->inputs()[1];
+ AT_ASSERT(vmap["x"] == x);
+ AT_ASSERT(vmap["y"] == y);
+ AT_ASSERT(vmap["res"] == res);
+ AT_ASSERT(vmap["0"] == t0);
+ AT_ASSERT(vmap["1"] == t1);
AT_ASSERT(x->node() == y->node());
Node* comb = x->node();
Value* t2 = comb->inputs()[1];
Value* t3 = comb->inputs()[2];
+ AT_ASSERT(vmap["2"] == t2);
+ AT_ASSERT(vmap["3"] == t3);
AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
struct ParsedLiteral;
class IRParser {
- friend void parseIR(const std::string& str, torch::jit::Graph* graph);
- IRParser(const std::string& str, torch::jit::Graph* graph)
+ friend void parseIR(
+ const std::string& str,
+ torch::jit::Graph* graph,
+ std::unordered_map<std::string, Value*>& vmap);
+ IRParser(
+ const std::string& str,
+ torch::jit::Graph* graph,
+ std::unordered_map<std::string, Value*>& vmap)
: L(str),
g(graph),
+ vmap(vmap),
type_parser(L, /*parse_complete_tensor_types*/ true) {}
std::string parseVar();
torch::jit::script::Lexer L;
torch::jit::Graph* g = nullptr;
- std::unordered_map<std::string, Value*> vmap;
+ std::unordered_map<std::string, Value*>& vmap;
SchemaTypeParser type_parser;
};
TypePtr type;
};
-void parseIR(const std::string& str, torch::jit::Graph* graph) {
- torch::jit::script::IRParser p(str, graph);
+void parseIR(
+ const std::string& str,
+ torch::jit::Graph* graph,
+ std::unordered_map<std::string, Value*>& vmap) {
+ torch::jit::script::IRParser p(str, graph, vmap);
p.parse();
}
+void parseIR(const std::string& str, torch::jit::Graph* graph) {
+ std::unordered_map<std::string, Value*> vmap;
+ parseIR(str, graph, vmap);
+}
+
VarWithType IRParser::parseVarWithType() {
VarWithType r;
r.name = parseVar();
-#include <string>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <string>
+#include <unordered_map>
namespace torch {
namespace jit {
struct Graph;
+struct Value;
namespace script {
// \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);
+/** \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
+ *
+ * \p VMAP is filled with String to Value pairs allowing to index Values in the
+ * newly created graph by their name in the original IR string.
+ */
+TORCH_API void parseIR(
+ const std::string& str,
+ torch::jit::Graph* graph,
+ std::unordered_map<std::string, Value*>& vmap);
+
} // namespace script
} // namespace jit
} // namespace torch