_(aten, __round_to_zero_floordiv)\
_(prim, fork) \
_(prim, RaiseException) \
+ _(prim, Function) \
_(aten, append) \
_(aten, format) \
_(aten, __not__) \
%26 : Float(*, *)) {
%27 : Float(*, *) = aten::mul(%0, %26)
%28 : int[] = aten::size(%outgate)
- %29 : Tensor = aten::sum_to_size(%27, %28)
+ %grad_self.1 : Tensor = aten::sum_to_size(%27, %28)
%30 : Float(*, *) = aten::mul(%0, %outgate)
%31 : int[] = aten::size(%26)
- %32 : Tensor = aten::sum_to_size(%30, %31)
- %33 : Tensor = prim::FusionGroup_0(%1, %32, %26)
+ %grad_other.1 : Tensor = aten::sum_to_size(%30, %31)
+ %33 : Tensor = prim::FusionGroup_0(%1, %grad_other.1, %26)
%34 : Tensor = prim::SumToSize(%33, %24)
%35 : Tensor = prim::SumToSize(%33, %25)
%36 : Tensor = aten::mul(%35, %cellgate)
%37 : int[] = aten::size(%ingate)
- %38 : Tensor = aten::sum_to_size(%36, %37)
+ %grad_self.3 : Tensor = aten::sum_to_size(%36, %37)
%39 : Tensor = aten::mul(%35, %ingate)
%40 : int[] = aten::size(%cellgate)
- %41 : Tensor = aten::sum_to_size(%39, %40)
+ %grad_other.3 : Tensor = aten::sum_to_size(%39, %40)
%42 : Tensor = aten::mul(%34, %9)
%43 : int[] = aten::size(%forgetgate)
- %44 : Tensor = aten::sum_to_size(%42, %43)
+ %grad_self.5 : Tensor = aten::sum_to_size(%42, %43)
%45 : Tensor = aten::mul(%34, %forgetgate)
%46 : int[] = aten::size(%9)
- %47 : Tensor = aten::sum_to_size(%45, %46)
- %48 : Tensor = prim::FusionGroup_1(%38, %ingate, %44, %forgetgate, %41, %cellgate, %29, %outgate)
+ %grad_other.5 : Tensor = aten::sum_to_size(%45, %46)
+ %48 : Tensor = prim::FusionGroup_1(%grad_self.3, %ingate, %grad_self.5, %forgetgate, %grad_other.3, %cellgate, %grad_self.1, %outgate)
%49 : Tensor = prim::SumToSize(%48, %19)
%50 : Tensor = prim::SumToSize(%48, %17)
%51 : Tensor = prim::SumToSize(%48, %14)
%60 : Float(*, *) = aten::t(%11)
%61 : Float(*, *) = aten::mm(%60, %51)
%62 : Float(*, *) = aten::t(%61)
- return (%47, %49, %50, %54, %57, %59, %62);
+ return (%grad_other.5, %49, %50, %54, %57, %59, %62);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%32 : int = prim::Constant[value=1]()
%33 : Float(*, *) = aten::mul(%0, %31)
%34 : int[] = aten::size(%outgate)
- %35 : Tensor = aten::sum_to_size(%33, %34)
+ %grad_self.1 : Tensor = aten::sum_to_size(%33, %34)
%36 : Float(*, *) = aten::mul(%0, %outgate)
%37 : int[] = aten::size(%31)
- %38 : Tensor = aten::sum_to_size(%36, %37)
- %39 : Tensor = prim::FusionGroup_0(%1, %38, %31)
+ %grad_other.1 : Tensor = aten::sum_to_size(%36, %37)
+ %39 : Tensor = prim::FusionGroup_0(%1, %grad_other.1, %31)
%40 : Tensor = prim::SumToSize(%39, %29)
%41 : Tensor = prim::SumToSize(%39, %30)
%42 : Tensor = aten::mul(%41, %cellgate)
%43 : int[] = aten::size(%ingate)
- %44 : Tensor = aten::sum_to_size(%42, %43)
+ %grad_self.3 : Tensor = aten::sum_to_size(%42, %43)
%45 : Tensor = aten::mul(%41, %ingate)
%46 : int[] = aten::size(%cellgate)
- %47 : Tensor = aten::sum_to_size(%45, %46)
+ %grad_other.3 : Tensor = aten::sum_to_size(%45, %46)
%48 : Tensor = aten::mul(%40, %10)
%49 : int[] = aten::size(%forgetgate)
- %50 : Tensor = aten::sum_to_size(%48, %49)
- %51 : Tensor = prim::FusionGroup_1(%44, %ingate, %50, %forgetgate, %47, %cellgate, %35, %outgate)
+ %grad_self.5 : Tensor = aten::sum_to_size(%48, %49)
+ %51 : Tensor = prim::FusionGroup_1(%grad_self.3, %ingate, %grad_self.5, %forgetgate, %grad_other.3, %cellgate, %grad_self.1, %outgate)
%52 : Tensor = prim::SumToSize(%51, %24)
%53 : Tensor = prim::SumToSize(%51, %22)
%54 : Tensor = aten::mul(%53, %Uz)
%55 : int[] = aten::size(%11)
- %56 : Tensor = aten::sum_to_size(%54, %55)
+ %grad_self.7 : Tensor = aten::sum_to_size(%54, %55)
%57 : Tensor = aten::mul(%53, %11)
%58 : int[] = aten::size(%Uz)
- %59 : Tensor = aten::sum_to_size(%57, %58)
+ %grad_other.7 : Tensor = aten::sum_to_size(%57, %58)
%60 : Tensor = prim::SumToSize(%51, %19)
%61 : Tensor = prim::SumToSize(%51, %20)
%62 : Tensor = aten::mul(%61, %Wx)
%63 : int[] = aten::size(%12)
- %64 : Tensor = aten::sum_to_size(%62, %63)
+ %grad_self.9 : Tensor = aten::sum_to_size(%62, %63)
%65 : Tensor = aten::mul(%61, %12)
%66 : int[] = aten::size(%Wx)
- %67 : Tensor = aten::sum_to_size(%65, %66)
+ %grad_other.9 : Tensor = aten::sum_to_size(%65, %66)
%68 : Tensor = aten::mul(%60, %Uz)
%69 : int[] = aten::size(%18)
- %70 : Tensor = aten::sum_to_size(%68, %69)
+ %grad_self.11 : Tensor = aten::sum_to_size(%68, %69)
%71 : Tensor = aten::mul(%60, %18)
- %72 : Tensor = aten::sum_to_size(%71, %58)
- %73 : Tensor = aten::add(%59, %72, %32)
- %74 : Tensor = aten::mul(%70, %Wx)
+ %grad_other.11 : Tensor = aten::sum_to_size(%71, %58)
+ %73 : Tensor = aten::add(%grad_other.7, %grad_other.11, %32)
+ %74 : Tensor = aten::mul(%grad_self.11, %Wx)
%75 : int[] = aten::size(%13)
- %76 : Tensor = aten::sum_to_size(%74, %75)
- %77 : Tensor = aten::mul(%70, %13)
- %78 : Tensor = aten::sum_to_size(%77, %66)
- %79 : Tensor = aten::add(%67, %78, %32)
+ %grad_self.13 : Tensor = aten::sum_to_size(%74, %75)
+ %77 : Tensor = aten::mul(%grad_self.11, %13)
+ %grad_other.13 : Tensor = aten::sum_to_size(%77, %66)
+ %79 : Tensor = aten::add(%grad_other.9, %grad_other.13, %32)
%80 : Float(*, *) = aten::t(%14)
%81 : Float(*, *) = aten::mm(%80, %73)
%82 : Float(*, *) = aten::t(%81)
%83 : Float(*, *) = aten::t(%15)
%84 : Float(*, *) = aten::mm(%83, %79)
%85 : Float(*, *) = aten::t(%84)
- return (%52, %56, %64, %76, %82, %85);
+ return (%52, %grad_self.7, %grad_self.9, %grad_self.13, %82, %85);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%12 : Tensor = prim::SumToSize(%11, %6)
-> (%10, %12)
}
- %13 : Tensor, %14 : Tensor = prim::GradOf[name="aten::mul"](%8)
+ %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8)
block0() {
%15 : Tensor = aten::mul(%8, %2)
%16 : int[] = aten::size(%4)
- %17 : Tensor = aten::sum_to_size(%15, %16)
+ %grad_self.1 : Tensor = aten::sum_to_size(%15, %16)
%18 : Tensor = aten::mul(%8, %4)
%19 : int[] = aten::size(%2)
- %20 : Tensor = aten::sum_to_size(%18, %19)
- -> (%17, %20)
+ %grad_other.1 : Tensor = aten::sum_to_size(%18, %19)
+ -> (%grad_self.1, %grad_other.1)
}
- %21 : Tensor = prim::AutogradAdd(%1, %13)
- %22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
+ %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2)
+ %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21)
block0() {
%24 : Tensor = aten::mul(%21, %3)
%25 : int[] = aten::size(%2)
- %26 : Tensor = aten::sum_to_size(%24, %25)
+ %grad_self.3 : Tensor = aten::sum_to_size(%24, %25)
%27 : Tensor = aten::mul(%21, %2)
%28 : int[] = aten::size(%3)
- %29 : Tensor = aten::sum_to_size(%27, %28)
- -> (%26, %29)
+ %grad_other.3 : Tensor = aten::sum_to_size(%27, %28)
+ -> (%grad_self.3, %grad_other.3)
}
- %30 : Tensor = prim::AutogradAdd(%14, %22)
- %31 : Tensor = prim::AutogradAdd(%9, %23)
+ %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self)
+ %31 : Tensor = prim::AutogradAdd(%9, %grad_other)
return (%30, %31);
}
%8 : Tensor = prim::SumToSize(%0, %4)
-> (%8)
}
- %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%7)
+ %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7)
block0() {
%11 : Tensor = aten::mul(%7, %2)
%12 : int[] = aten::size(%3)
- %13 : Tensor = aten::sum_to_size(%11, %12)
+ %grad_self.1 : Tensor = aten::sum_to_size(%11, %12)
%14 : Tensor = aten::mul(%7, %3)
%15 : int[] = aten::size(%2)
- %16 : Tensor = aten::sum_to_size(%14, %15)
- -> (%13, %16)
+ %grad_other.1 : Tensor = aten::sum_to_size(%14, %15)
+ -> (%grad_self.1, %grad_other.1)
}
- %17 : Tensor = prim::AutogradAdd(%1, %9)
+ %17 : Tensor = prim::AutogradAdd(%1, %grad_self)
%18 : Tensor = prim::GradOf[name="aten::add"](%17)
block0() {
%19 : Tensor = aten::mul(%17, %6)
%20 : Tensor = prim::SumToSize(%19, %5)
-> (%20)
}
- %21 : Tensor = prim::AutogradAdd(%10, %18)
+ %21 : Tensor = prim::AutogradAdd(%grad_other, %18)
return (%21);
}
try:
pp, constant_table = module._python_print()
except RuntimeError as e:
- if "could not export python function" not in str(e):
+ se = str(e)
+ if "could not export python function" not in se and \
+ "closures are not exportable" not in se:
raise
else:
return
case prim::None:
case prim::BroadcastSizes:
case prim::ChunkSizes:
+ case prim::Function:
return analyzeCreator(node);
case prim::TupleUnpack:
case prim::TupleIndex:
}
addAlias(value, getFreshAlias());
}
-
} // namespace jit
} // namespace torch
}
stmt << ")";
} break;
+ case prim::Function: {
+ if (enforce_importable_) {
+ throw script::ErrorReport(node->getSourceLocation()) << "closures are not exportable";
+ }
+ auto name = genMethodName("__lambda");
+ std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
+ worklist.emplace_back([graph, name, this] {
+ printFunctionDefinition(*graph, name);
+ });
+ stmt << "self." << name;
+ } break;
default: {
Symbol kind = node->kind();
if (kind.is_aten()) {
struct to_ir {
to_ir(
- Def def_,
+ Def def,
Resolver resolver_,
- SugaredValuePtr self_,
+ const SugaredValuePtr& self,
Method& method) // method being constructed
: method(method)
, graph(method.graph())
- , def(std::move(def_))
, resolver(std::move(resolver_))
- , self(std::move(self_))
, environment_stack(nullptr) {
JIT_ASSERT(resolver);
pushFrame(graph->block());
if (self && def.decl().params().size() == 0) {
throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
}
- auto schema = extractSchemaFromDef(def);
- std::vector<Argument> arguments = emitFormalArguments(self, schema);
- // body
- auto stmts = def.statements();
- auto stmts_begin = stmts.begin();
- auto stmts_end = stmts.end();
- c10::optional<Return> return_stmt;
- if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
- --stmts_end;
- return_stmt = Return(*stmts_end);
- }
- emitStatements(stmts_begin, stmts_end);
- std::vector<Argument> returns = {emitReturn(
- return_stmt ? return_stmt->range() : def.range(), return_stmt, schema)};
-
- method.setSchema({def.name().name(), std::move(arguments), std::move(returns)});
- // remove any uses of tuples that we inserted that are not needed
- LowerSimpleTuples(graph);
- ConstantPooling(graph);
+ method.setSchema(emitDef(def, self, graph->block()));
+ runCleanupPasses(graph);
}
private:
Method& method;
std::shared_ptr<Graph> graph;
- Def def;
Resolver resolver;
- SugaredValuePtr self;
std::unordered_map<int64_t, Value*> integral_constants;
std::unordered_map<double, Value*> fp_constants;
return old_frame;
}
+ void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
+ // remove any uses of tuples that we inserted that are not needed
+ LowerSimpleTuples(to_clean);
+ ConstantPooling(to_clean);
+ }
+
+ FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
+ auto schema = extractSchemaFromDef(def, self);
+ std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
+
+ // body
+ auto stmts = def.statements();
+ auto stmts_begin = stmts.begin();
+ auto stmts_end = stmts.end();
+ c10::optional<Return> return_stmt;
+ if (stmts_begin != stmts_end && (*std::prev(stmts_end)).kind() == TK_RETURN) {
+ --stmts_end;
+ return_stmt = Return(*stmts_end);
+ }
+ emitStatements(stmts_begin, stmts_end);
+ const SourceRange& range = return_stmt ? return_stmt->range() : def.range();
+ std::vector<Argument> returns = {emitReturn(range, return_stmt, schema, block)};
+ return {def.name().name(), std::move(arguments), std::move(returns)};
+ }
+
std::vector<IValue> evaluateDefaults(const SourceRange& r, const std::vector<Expr>& default_types, const std::vector<Expr>& default_exprs) {
std::vector<IValue> default_values;
if (default_exprs.empty())
return stack.at(0).toTuple()->elements();
}
- std::vector<Argument> parseArgsFromDecl(const Decl& decl) {
+ std::vector<Argument> parseArgsFromDecl(const Decl& decl, const SugaredValuePtr& self) {
auto params_begin = decl.params().begin();
auto params_end = decl.params().end();
if (self)
/*default_value =*/c10::nullopt,
/*kwarg_only =*/false)};
}
- FunctionSchema extractSchemaFromDef(const Def &def) {
+ FunctionSchema extractSchemaFromDef(const Def &def, const SugaredValuePtr& self) {
auto name = def.name().name();
- std::vector<Argument> args = parseArgsFromDecl(def.decl());
+ std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
std::vector<Argument> returns = parseReturnFromDecl(def.decl());
return FunctionSchema(name, std::move(args), std::move(returns), false, false);
}
- std::vector<Argument> emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) {
+ std::vector<Argument> emitFormalArguments(const Def& def, const SugaredValuePtr& self, const FunctionSchema& schema, Block* block) {
std::vector<Argument> arguments; // for schema
// inputs
auto it = def.decl().params().begin();
for(;it != end; ++it) {
auto& name = (*it).ident().name();
// Add the input to the graph
- Value *new_input = graph->addInput();
+ Value *new_input = block->addInput();
if (meaningfulName(name)) {
new_input->setUniqueName(name);
}
return arguments;
}
- Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema) {
+ Argument emitReturn(const SourceRange& range, c10::optional<Return> return_stmt, const FunctionSchema& schema, Block* block) {
JIT_ASSERT(schema.returns().size() <= 1);
// outputs
Value* result = return_stmt ? emitExpr(return_stmt->expr())
throw ErrorReport(range) << "Return value was annotated as having type " << result_type->python_str()
<< " but is actually of type " << result->type()->python_str();
}
- graph->registerOutput(result);
+ block->registerOutput(result);
return Argument("", result_type);
}
void emitStatements(const List<Stmt>& statements) {
return emitStatements(statements.begin(), statements.end());
}
+ std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
+ auto subgraph = std::make_shared<Graph>();
+ // note: type is set later on pack_context and context when we know it
+ Node* pack_context = graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
+ Value* context = subgraph->addInput("context");
+ // cannot use createTupleUnpack because the type is not known yet
+ Node* unpack_context = subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
+
+ std::unordered_map<Value*, Value*> captures;
+ auto env = [&](Value* v) -> Value* {
+ auto it = captures.find(v);
+ if (it != captures.end()) {
+ return it->second;
+ }
+ pack_context->addInput(v);
+ Value* r = unpack_context->addOutput()->copyMetadata(v);
+ captures[v] = r;
+ return r;
+ };
+ subgraph->block()->cloneFrom(block, env);
+ auto context_type = TupleType::create(
+ fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
+ pack_context->output()->setType(context_type);
+ context->setType(context_type);
+ return std::make_pair(std::move(subgraph), pack_context->output());
+ }
+ // XXX - right now closures are used _only_ for defining gradients internally
+ // There are several unfinished aspects that make them unusable generally
+ // 1. We do not have a type, ivalue, operator to represent prim::Function, so closure_node has type None
+ // and any graphs that contain it cannot be run
+ // 2. There is no export logic for it yet, so it cannot be exported/python_printed
+ // 3. There is nothing preventing the assignment of already existing variables inside the closures
+ // the changes to those variables will just get forgotten.
+ // 4. There is no parsing support in frontend.py, this is intentional since it
+ // prevents people from accidentally using this feature.
+ void emitClosure(const Def& def) {
+ Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
+ closure_node->output()->setType(NoneType::get()); //it is not a real thing yet, so just say the type is none.
+ Block* block = closure_node->addBlock();
+ {
+ WithInsertPoint guard(block);
+ pushFrame(block);
+ emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
+ popFrame();
+ }
+ std::shared_ptr<Graph> subgraph;
+ Value* context;
+ std::tie(subgraph, context) = lambdaLift(block);
+ runCleanupPasses(subgraph);
+ closure_node->eraseBlock(0);
+ closure_node->g_(attr::Subgraph, std::move(subgraph));
+ auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
+ environment_stack->setVar(def.name().range(), def.name().name(), tup);
+ }
void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
case TK_PASS:
// Emit nothing for pass
break;
+ case TK_DEF:
+ emitClosure(Def(stmt));
+ break;
default:
throw ErrorReport(stmt)
<< "Unrecognized statement kind " << kindToString(stmt.kind());
L.expect(TK_NEWLINE);
return Pass::create(range);
}
+ case TK_DEF: {
+ return parseFunction(/*is_method=*/false);
+ }
default: {
auto lhs = parseExpOrExpTuple();
if (L.cur().kind != TK_NEWLINE) {
// | Return(List<Expr> values) TK_RETURN
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
// | Raise(Expr expr) TK_RAISE
+// | Def TK_DEF
//
// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
// | BinOp(Expr lhs, Expr rhs)
case TK_RAISE:
case TK_ASSERT:
case TK_PASS:
+ case TK_DEF:
return;
default:
throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
#include <torch/csrc/jit/symbolic_script.h>
+
+
namespace torch { namespace jit {
namespace {
std::mutex lock;
- const std::unordered_map<std::string, std::string> symbolic_scripts({
- {"aten::mul(Tensor self, Tensor other) -> Tensor",
-R"(
-def forward(self, other):
- return self * other, (self, other)
-def backward(ctx, grad_output):
- # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor]
- self, other = ctx
- return (grad_output * other).sum_to_size(self.size()), (grad_output * self).sum_to_size(other.size())
-)"},
- });
+ const std::vector<std::string> functions = {
+ R"(
+ def mul(self, other):
+ def backward(grad_output):
+ grad_self = (grad_output * other).sum_to_size(self.size())
+ grad_other = (grad_output * self).sum_to_size(other.size())
+ return grad_self, grad_other
+ return self * other, backward
+ )"
+ };
+ std::unordered_map<std::string, GradientPair> schema_to_graphs;
// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
// should be compiled only once and saved in Operator structure.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
} // anonymous namespace
+ std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
+ AT_CHECK(closure->node()->kind() == prim::TupleConstruct, "closure must be a literal tuple construct");
+ Value* fn = closure->node()->inputs().at(0);
+ Value* context = closure->node()->inputs().at(1);
+
+ AT_CHECK(fn->node()->kind() == prim::Function, "closure tuple must contain a prim::Function");
+ return std::make_pair(fn->node()->g(attr::Subgraph), context);
+ }
+
+ Argument originalReturnType(const TupleTypePtr& tup) {
+ AT_CHECK(tup->elements().size() > 1);
+ if(tup->elements().size() == 2)
+ return Argument("", tup->elements().at(0));
+ std::vector<TypePtr> types = tup->elements().vec();
+ types.pop_back();
+ return Argument("", TupleType::create(std::move(types)));
+ }
+
+ void loadModule(const std::shared_ptr<script::Module>& module) {
+ for(const auto& method_ : module->get_methods()) {
+ const auto& method = method_.value();
+ GradientPair pair;
+ pair.forward = method->graph();
+
+ // lookup the backward function
+ Node* forward_tuple = pair.forward->outputs().at(0)->node();
+
+ if (forward_tuple->kind() != prim::TupleConstruct) {
+ throw script::ErrorReport(forward_tuple->getSourceLocation()) << "gradient must return literal a tuple";
+ }
+
+ Value* context;
+ std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back());
+
+ // do surgery on the forward function to remove the closure tuple and replace it with the
+ // context variable:
+ // backward = (<lambda>, context_tuple)
+ // return original, backward
+ // -----
+ // return original, context_tuple
+ std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+ new_inputs.back() = context;
+ Value* new_tuple = pair.forward->appendNode(pair.forward->createTuple(new_inputs))->output();
+ pair.forward->eraseOutput(0);
+ pair.forward->registerOutput(new_tuple);
+ forward_tuple->destroy();
+
+ // derive schema from original function's schema:
+ const FunctionSchema& loaded_schema = method->getSchema();
+ FunctionSchema actual_schema(Symbol::aten(loaded_schema.name()),
+ loaded_schema.arguments(),
+ {originalReturnType(new_tuple->type()->expect<TupleType>())}
+ );
+ std::string key = canonicalSchemaString(actual_schema);
+ schema_to_graphs[key] = std::move(pair);
+ }
+ }
+
+ void loadFunctions() {
+ for(const std::string& str : functions) {
+ auto cu = std::make_shared<script::Module>();
+ script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
+ loadModule(cu);
+ }
+ }
+
c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
std::lock_guard<std::mutex> guard(lock);
+ if (schema_to_graphs.size() == 0) {
+ loadFunctions();
+ }
auto cache_it = cached_gradient_pairs.find(&schema);
if (cache_it != cached_gradient_pairs.end()) {
return cache_it->second;
} else {
auto schema_str = canonicalSchemaString(schema);
- auto sym_script_it = symbolic_scripts.find(schema_str);
-
- if (sym_script_it != symbolic_scripts.end()) {
- // Compile the python code to a script module
- auto cu = std::make_shared<script::Module>();
- script::defineMethodsInModule(cu, symbolic_scripts.at(schema_str), script::nativeResolver, nullptr);
- auto fw_graph = cu->find_method("forward")->graph();
- auto bw_graph = cu->find_method("backward")->graph();
-
- GradientPair compiled_graphs{fw_graph, bw_graph};
- cached_gradient_pairs.emplace_hint(cache_it, &schema, compiled_graphs);
- return compiled_graphs;
+ auto sym_script_it = schema_to_graphs.find(schema_str);
+ if (sym_script_it != schema_to_graphs.end()) {
+ cached_gradient_pairs.emplace_hint(cache_it, &schema, sym_script_it->second);
+ return sym_script_it->second;
}
}
return c10::nullopt;
}
bool hasGradientInfoForSchema(const FunctionSchema& schema) {
- std::lock_guard<std::mutex> guard(lock);
- auto cache_it = cached_gradient_pairs.find(&schema);
- if (cache_it == cached_gradient_pairs.end()) {
- auto schema_str = canonicalSchemaString(schema);
- auto sym_script_it = symbolic_scripts.find(schema_str);
- return !(sym_script_it == symbolic_scripts.end());
- }
- return true;
+ return gradientInfoForSchema(schema).has_value();
}
-}}
+}}