Tensor sum(IntList dim, bool keepdim, ScalarType dtype) const;
Tensor sum(IntList dim, bool keepdim=false) const;
Tensor sum(IntList dim, ScalarType dtype) const;
+ Tensor sum_to_size(IntList size) const;
Tensor sqrt() const;
Tensor & sqrt_();
Tensor std(bool unbiased=true) const;
inline Tensor Tensor::sum(IntList dim, ScalarType dtype) const {
return type().sum(*this, dim, dtype);
}
+inline Tensor Tensor::sum_to_size(IntList size) const {
+ return type().sum_to_size(*this, size);
+}
inline Tensor Tensor::sqrt() const {
return type().sqrt(*this);
}
virtual Tensor sum(const Tensor & self, IntList dim, bool keepdim, ScalarType dtype) const = 0;
virtual Tensor sum(const Tensor & self, IntList dim, bool keepdim) const = 0;
virtual Tensor sum(const Tensor & self, IntList dim, ScalarType dtype) const = 0;
+ virtual Tensor sum_to_size(const Tensor & self, IntList size) const = 0;
virtual Tensor sqrt(const Tensor & self) const = 0;
virtual Tensor & sqrt_(Tensor & self) const = 0;
virtual Tensor std(const Tensor & self, bool unbiased) const = 0;
_(aten, sub_) \
_(aten, rsub) \
_(aten, sum) \
+_(aten, sum_to_size) \
_(aten, svd) \
_(aten, symeig) \
_(aten, t) \
}
Tensor cat(TensorList tensors, int64_t dim) {
- if (tensors.size() > 0 &&
+ if (tensors.size() > 0 &&
tensors[0].is_sparse()) {
return cat_sparse(tensors, dim);
}
return self.expand(other.sizes());
}
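+// E.g. a {4, 3} tensor summed to size {1, 3} reduces dim 0 (keeping it as
+// size 1), since {1, 3} is expandable to {4, 3}; summing the same tensor
+// to {2, 3} would fail the check below.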
+Tensor sum_to_size(const Tensor& self, IntList size) {
+ AT_CHECK(is_expandable_to(size, self.sizes()),
+ "size {", size, "} is not expandable to size {", self.sizes(), "}.");
+
+ return sum_to(self, size);
+}
+
Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
auto tid = self.type_id();
AT_CHECK(
- func: sum_out(Tensor result, Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
+- func: sum_to_size(Tensor self, IntList size) -> Tensor
+ variants: method
+ device_guard: false
+
- func: sqrt(Tensor self) -> Tensor
variants: function, method
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/symbolic_variable.h"
+#include "torch/csrc/jit/symbolic_script.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/utils/hash.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
%24 : int[]
%25 : int[]
%26 : Float(*, *)) {
- %27 : int[] = aten::size(%26)
+ %27 : Float(*, *) = aten::mul(%0, %26)
%28 : int[] = aten::size(%outgate)
- %29 : int[] = aten::size(%cellgate)
- %30 : int[] = aten::size(%ingate)
- %31 : int[] = aten::size(%9)
- %32 : int[] = aten::size(%forgetgate)
- %33 : Float(*, *) = aten::mul(%0, %26)
- %34 : Tensor = prim::SumToSize(%33, %28)
- %35 : Float(*, *) = aten::mul(%0, %outgate)
- %36 : Tensor = prim::SumToSize(%35, %27)
- %37 : Tensor = prim::FusionGroup_0(%1, %36, %26)
- %38 : Tensor = prim::SumToSize(%37, %24)
- %39 : Tensor = prim::SumToSize(%37, %25)
- %40 : Tensor = aten::mul(%39, %cellgate)
- %41 : Tensor = prim::SumToSize(%40, %30)
- %42 : Tensor = aten::mul(%39, %ingate)
- %43 : Tensor = prim::SumToSize(%42, %29)
- %44 : Tensor = aten::mul(%38, %9)
- %45 : Tensor = prim::SumToSize(%44, %32)
- %46 : Tensor = aten::mul(%38, %forgetgate)
- %47 : Tensor = prim::SumToSize(%46, %31)
- %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
+ %29 : Tensor = aten::sum_to_size(%27, %28)
+ %30 : Float(*, *) = aten::mul(%0, %outgate)
+ %31 : int[] = aten::size(%26)
+ %32 : Tensor = aten::sum_to_size(%30, %31)
+ %33 : Tensor = prim::FusionGroup_0(%1, %32, %26)
+ %34 : Tensor = prim::SumToSize(%33, %24)
+ %35 : Tensor = prim::SumToSize(%33, %25)
+ %36 : Tensor = aten::mul(%35, %cellgate)
+ %37 : int[] = aten::size(%ingate)
+ %38 : Tensor = aten::sum_to_size(%36, %37)
+ %39 : Tensor = aten::mul(%35, %ingate)
+ %40 : int[] = aten::size(%cellgate)
+ %41 : Tensor = aten::sum_to_size(%39, %40)
+ %42 : Tensor = aten::mul(%34, %9)
+ %43 : int[] = aten::size(%forgetgate)
+ %44 : Tensor = aten::sum_to_size(%42, %43)
+ %45 : Tensor = aten::mul(%34, %forgetgate)
+ %46 : int[] = aten::size(%9)
+ %47 : Tensor = aten::sum_to_size(%45, %46)
+ %48 : Tensor = prim::FusionGroup_1(%38, %ingate, %44, %forgetgate, %41, %cellgate, %29, %outgate)
%49 : Tensor = prim::SumToSize(%48, %19)
%50 : Tensor = prim::SumToSize(%48, %17)
%51 : Tensor = prim::SumToSize(%48, %14)
%30 : int[]
%31 : Float(*, *)) {
%32 : int = prim::Constant[value=1]()
- %33 : int[] = aten::size(%31)
+ %33 : Float(*, *) = aten::mul(%0, %31)
%34 : int[] = aten::size(%outgate)
- %35 : int[] = aten::size(%cellgate)
- %36 : int[] = aten::size(%ingate)
- %37 : int[] = aten::size(%forgetgate)
- %38 : int[] = aten::size(%11)
- %39 : int[] = aten::size(%12)
- %40 : int[] = aten::size(%Uz)
- %41 : int[] = aten::size(%18)
- %42 : int[] = aten::size(%Wx)
- %43 : int[] = aten::size(%13)
- %44 : Float(*, *) = aten::mul(%0, %31)
- %45 : Tensor = prim::SumToSize(%44, %34)
- %46 : Float(*, *) = aten::mul(%0, %outgate)
- %47 : Tensor = prim::SumToSize(%46, %33)
- %48 : Tensor = prim::FusionGroup_0(%1, %47, %31)
- %49 : Tensor = prim::SumToSize(%48, %29)
- %50 : Tensor = prim::SumToSize(%48, %30)
- %51 : Tensor = aten::mul(%50, %cellgate)
- %52 : Tensor = prim::SumToSize(%51, %36)
- %53 : Tensor = aten::mul(%50, %ingate)
- %54 : Tensor = prim::SumToSize(%53, %35)
- %55 : Tensor = aten::mul(%49, %10)
- %56 : Tensor = prim::SumToSize(%55, %37)
- %57 : Tensor = prim::FusionGroup_1(%52, %ingate, %56, %forgetgate, %54, %cellgate, %45, %outgate)
- %58 : Tensor = prim::SumToSize(%57, %24)
- %59 : Tensor = prim::SumToSize(%57, %22)
- %60 : Tensor = aten::mul(%59, %Uz)
- %61 : Tensor = prim::SumToSize(%60, %38)
- %62 : Tensor = aten::mul(%59, %11)
- %63 : Tensor = prim::SumToSize(%62, %40)
- %64 : Tensor = prim::SumToSize(%57, %19)
- %65 : Tensor = prim::SumToSize(%57, %20)
- %66 : Tensor = aten::mul(%65, %Wx)
- %67 : Tensor = prim::SumToSize(%66, %39)
- %68 : Tensor = aten::mul(%65, %12)
- %69 : Tensor = prim::SumToSize(%68, %42)
- %70 : Tensor = aten::mul(%64, %Uz)
- %71 : Tensor = prim::SumToSize(%70, %41)
- %72 : Tensor = aten::mul(%64, %18)
- %73 : Tensor = prim::SumToSize(%72, %40)
- %74 : Tensor = aten::add(%63, %73, %32)
- %75 : Tensor = aten::mul(%71, %Wx)
- %76 : Tensor = prim::SumToSize(%75, %43)
- %77 : Tensor = aten::mul(%71, %13)
- %78 : Tensor = prim::SumToSize(%77, %42)
- %79 : Tensor = aten::add(%69, %78, %32)
+ %35 : Tensor = aten::sum_to_size(%33, %34)
+ %36 : Float(*, *) = aten::mul(%0, %outgate)
+ %37 : int[] = aten::size(%31)
+ %38 : Tensor = aten::sum_to_size(%36, %37)
+ %39 : Tensor = prim::FusionGroup_0(%1, %38, %31)
+ %40 : Tensor = prim::SumToSize(%39, %29)
+ %41 : Tensor = prim::SumToSize(%39, %30)
+ %42 : Tensor = aten::mul(%41, %cellgate)
+ %43 : int[] = aten::size(%ingate)
+ %44 : Tensor = aten::sum_to_size(%42, %43)
+ %45 : Tensor = aten::mul(%41, %ingate)
+ %46 : int[] = aten::size(%cellgate)
+ %47 : Tensor = aten::sum_to_size(%45, %46)
+ %48 : Tensor = aten::mul(%40, %10)
+ %49 : int[] = aten::size(%forgetgate)
+ %50 : Tensor = aten::sum_to_size(%48, %49)
+ %51 : Tensor = prim::FusionGroup_1(%44, %ingate, %50, %forgetgate, %47, %cellgate, %35, %outgate)
+ %52 : Tensor = prim::SumToSize(%51, %24)
+ %53 : Tensor = prim::SumToSize(%51, %22)
+ %54 : Tensor = aten::mul(%53, %Uz)
+ %55 : int[] = aten::size(%11)
+ %56 : Tensor = aten::sum_to_size(%54, %55)
+ %57 : Tensor = aten::mul(%53, %11)
+ %58 : int[] = aten::size(%Uz)
+ %59 : Tensor = aten::sum_to_size(%57, %58)
+ %60 : Tensor = prim::SumToSize(%51, %19)
+ %61 : Tensor = prim::SumToSize(%51, %20)
+ %62 : Tensor = aten::mul(%61, %Wx)
+ %63 : int[] = aten::size(%12)
+ %64 : Tensor = aten::sum_to_size(%62, %63)
+ %65 : Tensor = aten::mul(%61, %12)
+ %66 : int[] = aten::size(%Wx)
+ %67 : Tensor = aten::sum_to_size(%65, %66)
+ %68 : Tensor = aten::mul(%60, %Uz)
+ %69 : int[] = aten::size(%18)
+ %70 : Tensor = aten::sum_to_size(%68, %69)
+ %71 : Tensor = aten::mul(%60, %18)
+ %72 : Tensor = aten::sum_to_size(%71, %58)
+ %73 : Tensor = aten::add(%59, %72, %32)
+ %74 : Tensor = aten::mul(%70, %Wx)
+ %75 : int[] = aten::size(%13)
+ %76 : Tensor = aten::sum_to_size(%74, %75)
+ %77 : Tensor = aten::mul(%70, %13)
+ %78 : Tensor = aten::sum_to_size(%77, %66)
+ %79 : Tensor = aten::add(%67, %78, %32)
%80 : Float(*, *) = aten::t(%14)
- %81 : Float(*, *) = aten::mm(%80, %74)
+ %81 : Float(*, *) = aten::mm(%80, %73)
%82 : Float(*, *) = aten::t(%81)
%83 : Float(*, *) = aten::t(%15)
%84 : Float(*, *) = aten::mm(%83, %79)
%85 : Float(*, *) = aten::t(%84)
- return (%58, %61, %67, %76, %82, %85);
+ return (%52, %56, %64, %76, %82, %85);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%3 : Float(2, 3, 4)
%4 : Float(2, 3, 4)
%5 : int[]) {
- %9 : int = prim::Constant[value=1]()
- %6 : int[] = aten::size(%4)
- %7 : int[] = aten::size(%3)
- %8 : int[] = aten::size(%2)
- %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %7 : int = prim::Constant[value=1]()
+ %6 : int[] = aten::size(%3)
+ %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %12 : Tensor = prim::SumToSize(%0, %5)
- %13 : Float(2, 3, 4) = aten::mul(%0, %9)
- %14 : Tensor = prim::SumToSize(%13, %7)
- -> (%12, %14)
+ %10 : Tensor = prim::SumToSize(%0, %5)
+ %11 : Float(2, 3, 4) = aten::mul(%0, %7)
+ %12 : Tensor = prim::SumToSize(%11, %6)
+ -> (%10, %12)
}
- %15 : Tensor, %16 : Tensor = prim::GradOf[name="aten::mul"](%10)
+ %13 : Tensor, %14 : Tensor = prim::GradOf[name="aten::mul"](%8)
block0() {
- %17 : Tensor = aten::mul(%10, %2)
- %18 : Tensor = prim::SumToSize(%17, %6)
- %19 : Tensor = aten::mul(%10, %4)
- %20 : Tensor = prim::SumToSize(%19, %8)
- -> (%18, %20)
+ %15 : Tensor = aten::mul(%8, %2)
+ %16 : int[] = aten::size(%4)
+ %17 : Tensor = aten::sum_to_size(%15, %16)
+ %18 : Tensor = aten::mul(%8, %4)
+ %19 : int[] = aten::size(%2)
+ %20 : Tensor = aten::sum_to_size(%18, %19)
+ -> (%17, %20)
}
- %21 : Tensor = prim::AutogradAdd(%1, %15)
+ %21 : Tensor = prim::AutogradAdd(%1, %13)
%22 : Tensor, %23 : Tensor = prim::GradOf[name="aten::mul"](%21)
block0() {
%24 : Tensor = aten::mul(%21, %3)
- %25 : Tensor = prim::SumToSize(%24, %8)
- %26 : Tensor = aten::mul(%21, %2)
- %27 : Tensor = prim::SumToSize(%26, %7)
- -> (%25, %27)
+ %25 : int[] = aten::size(%2)
+ %26 : Tensor = aten::sum_to_size(%24, %25)
+ %27 : Tensor = aten::mul(%21, %2)
+ %28 : int[] = aten::size(%3)
+ %29 : Tensor = aten::sum_to_size(%27, %28)
+ -> (%26, %29)
}
- %28 : Tensor = prim::AutogradAdd(%16, %22)
- %29 : Tensor = prim::AutogradAdd(%11, %23)
- return (%28, %29);
+ %30 : Tensor = prim::AutogradAdd(%14, %22)
+ %31 : Tensor = prim::AutogradAdd(%9, %23)
+ return (%30, %31);
}
testDifferentiateWithRequiresGrad
%2 : Float(*) = aten::mul(%1, %1)
%3 : int = prim::Constant[value=1]()
%4 : Float(*) = aten::add(%2, %1, %3)
- %26 : int[] = aten::size(%4)
%6 : Float(*) = aten::add(%4, %0, %3)
%7 : Float(*) = aten::mul(%6, %0)
%11 : int[] = aten::size(%7)
- %14 : int[] = aten::size(%1)
%9 : Float(*) = aten::add(%7, %1, %3)
return (%4, %9, %6, %11);
}
%2 : Float(*)
%3 : Float(*)
%4 : int[]) {
- %7 : int = prim::Constant[value=1]()
- %5 : int[] = aten::size(%3)
- %6 : int[] = aten::size(%2)
- %8 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %6 : int = prim::Constant[value=1]()
+ %5 : int[] = aten::size(%2)
+ %7 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %9 : Tensor = prim::SumToSize(%0, %4)
- -> (%9)
+ %8 : Tensor = prim::SumToSize(%0, %4)
+ -> (%8)
}
- %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::mul"](%8)
+ %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%7)
block0() {
- %12 : Tensor = aten::mul(%8, %2)
- %13 : Tensor = prim::SumToSize(%12, %5)
- %14 : Tensor = aten::mul(%8, %3)
- %15 : Tensor = prim::SumToSize(%14, %6)
- -> (%13, %15)
+ %11 : Tensor = aten::mul(%7, %2)
+ %12 : int[] = aten::size(%3)
+ %13 : Tensor = aten::sum_to_size(%11, %12)
+ %14 : Tensor = aten::mul(%7, %3)
+ %15 : int[] = aten::size(%2)
+ %16 : Tensor = aten::sum_to_size(%14, %15)
+ -> (%13, %16)
}
- %16 : Tensor = prim::AutogradAdd(%1, %10)
- %17 : Tensor = prim::GradOf[name="aten::add"](%16)
+ %17 : Tensor = prim::AutogradAdd(%1, %9)
+ %18 : Tensor = prim::GradOf[name="aten::add"](%17)
block0() {
- %18 : Tensor = aten::mul(%16, %7)
- %19 : Tensor = prim::SumToSize(%18, %6)
- -> (%19)
+ %19 : Tensor = aten::mul(%17, %6)
+ %20 : Tensor = prim::SumToSize(%19, %5)
+ -> (%20)
}
- %20 : Tensor = prim::AutogradAdd(%11, %17)
- return (%20);
+ %21 : Tensor = prim::AutogradAdd(%10, %18)
+ return (%21);
}
"torch/csrc/jit/import.cpp",
"torch/csrc/jit/interpreter.cpp",
"torch/csrc/jit/ir.cpp",
+ "torch/csrc/jit/symbolic_script.cpp",
"torch/csrc/jit/operator.cpp",
"torch/csrc/jit/passes/alias_analysis.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",
${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
${TORCH_SRC_DIR}/csrc/jit/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/operator.cpp
- ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/canonicalize.cpp
as :attr:`other`.
""")
+add_docstr_all('sum_to_size',
+ r"""
+sum_to_size(*size) -> Tensor
+
+Sum ``this`` tensor to :attr:`size`.
+:attr:`size` must be broadcastable to ``this`` tensor size.
+
+Args:
+    size (int...): a sequence of integers defining the size of the
+        output tensor.
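+
+Example::
+
+    >>> t = torch.ones(2, 3)
+    >>> t.sum_to_size(1, 3)
+    tensor([[2., 2., 2.]])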
+""")
+
+
add_docstr_all('zero_',
r"""
zero_() -> Tensor
#include <torch/csrc/jit/autodiff.h>
+#include "torch/csrc/jit/passes/lower_tuples.h"
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
+#include "torch/csrc/jit/symbolic_script.h"
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>
+#include "torch/csrc/jit/script/compiler.h"
#include <torch/csrc/jit/assertions.h>
if (differentiable_ops.find(n))
return true;
+ auto schema = n->maybeSchema();
+ if (schema && hasGradientInfoForSchema(*schema)) {
+ return true;
+ }
+
if (n->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
return n->get<std::vector<int64_t>>(attr::size) && n->is_constant(attr::implicit) &&
n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
static_cast<bool(*)(Node*)>(isDifferentiable));
}
+// TODO: Remove this after #15355.
+namespace {
+ std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+ auto outputs = script::inlineCallTo(g, callee, inputs);
+ if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+ auto tc = script::createTupleUnpack(outputs.at(0));
+ outputs = std::vector<Value*>(tc.begin(), tc.end());
+ }
+ return outputs;
+ }
+} // anonymous namespace
+
+
+// NB: Gradients can be defined in torchscript.
+// For example, the gradient of aten::mul() is defined as follows:
+// def forward(x, y):
+//     return x * y, (x, y)
+// def backward(ctx, grad_output):
+//     x, y = ctx
+//     return (y * grad_output).sum_to_size(x.size()), (x * grad_output).sum_to_size(y.size())
+//
+// Here ctx is a tuple that carries all inputs/intermediate results needed in
+// backward from the forward pass.
+// This python code is compiled into a GradientPair, which includes a forward
+// graph and a backward graph. The forward graph is used to replace the node
+// in grad_desc.f, and the backward graph is used to construct GradOf(node)
+// in reverse_block.
+// Grad_values (a.k.a. gradOutputs) are propagated through node->owningGraph()
+// in **reversed** order, so GradientPair.forward should be inserted **after**
+// the node being replaced; otherwise we would traverse the graph infinitely.
+// The output of the compiled forward graph is [real_outputs, ctx].
+// The input of the compiled backward graph is [ctx, grad_values].
+// We run LowerSimpleTuples afterwards to eliminate all tuples generated in
+// this process. The original node and the TupleConstruct nodes in the forward
+// graph will be cleaned up later using EliminateDeadCode(block).
+// The TupleUnpack node in the backward graph will be removed in
+// eliminateDeadcode(ReverseDetails), defined in this file.
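+//
+// Concretely, for the aten::mul example above (a sketch of the data flow,
+// not the exact IR):
+//   forward(x, y)                 -> (x * y, (x, y))   i.e. [real_outputs, ctx]
+//   backward((x, y), grad_output) -> (dx, dy)          i.e. inputs are [ctx, grad_values]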
+static c10::optional<std::vector<Value*>> build_script_grad(
+ Node* node,
+ const ArrayRef<Value*>& grads) {
+ auto graph = node->owningGraph();
+
+ auto compiled_graphs = gradientInfoForSchema(node->schema());
+ if (!compiled_graphs) {
+ return c10::nullopt;
+ }
+ // Use forward graph to replace node in grad_desc.f
+ value_list new_outputs;
+ {
+ WithInsertPoint guard(node->next());
+ auto fw_graph = compiled_graphs->forward;
+ new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
+ for (size_t i = 0; i < node->outputs().size(); ++i) {
+ new_outputs.at(i)->setType(node->outputs()[i]->type());
+ new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
+ }
+ }
+
+ // Use backward graph to construct reverse_block
+ auto bw_graph = compiled_graphs->backward;
+ auto grad_vec = grads.vec();
+ auto it = grad_vec.begin();
+ grad_vec.insert(it, new_outputs.back());
+ ArrayRef<Value*> grad(grad_vec);
+ auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
+ return grad_inputs;
+}
static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
static const OperatorSet comparison_ops = {
throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
"is not supported, or it is missing necessary type information");
}
+  // If AD is defined using torchscript, use it instead of the symbolic
+  // gradient below.
+ auto script_grads = build_script_grad(node, grad_values);
+ if (script_grads)
+ return *script_grads;
+  // No torchscript definition was found; fall back to build_sym_grad.
+  // TODO: migrate all gradients to torchscript.
auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
}
}
value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
+ LowerSimpleTuples(reverse_block);
+
JIT_ASSERT(grad_inputs.size() == node->inputs().size());
for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
if (!inputs[i]->requires_grad()) continue;
reverse_block->registerOutput(get_grad(input));
grad_desc.df_output_vjps.push_back(i);
}
+
return ReverseDetails(std::move(grad_map), reverse_block);
}
// Fills in df_input_vjps and df_output_vjps
auto rev_info = addReverseInline(grad_desc);
Optimize(grad_desc, rev_info);
+  // Clean up old nodes that have been replaced by the torchscript forward graphs.
+ EliminateDeadCode(grad_desc.f->block());
+
// Fills in f, df, f_real_outputs, df_input_captures,
// modifies df_input_vjps (new vjps are added for temporaries)
lambdaLiftReverse(grad_desc, rev_info);
} // namespace script
namespace {
-
-std::string canonicalSchemaString(const FunctionSchema& schema) {
- std::ostringstream out;
-
- out << schema.name();
- out << "(";
-
- bool seen_kwarg_only = false;
- for(size_t i = 0; i < schema.arguments().size(); ++i) {
- if (i > 0) out << ", ";
- if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
- out << "*, ";
- seen_kwarg_only = true;
- }
- const auto & arg = schema.arguments()[i];
- out << arg.type()->str() << " " << arg.name();
- }
-
- out << ") -> ";
- if (schema.returns().size() == 1) {
- out << schema.returns().at(0).type()->str();
- } else if (schema.returns().size() > 1) {
- out << "(";
- for (size_t i = 0; i < schema.returns().size(); ++i) {
- if (i > 0) out << ", ";
- out << schema.returns()[i].type()->str();
- }
- out << ")";
- }
- return out.str();
-}
-
using OperatorMap = std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
struct OperatorRegistry {
private:
return script::SchemaParser(schema).parseDeclarations().at(0);
}
+std::string canonicalSchemaString(const FunctionSchema& schema) {
+ std::ostringstream out;
+
+ out << schema.name();
+ out << "(";
+
+ bool seen_kwarg_only = false;
+ for(size_t i = 0; i < schema.arguments().size(); ++i) {
+ if (i > 0) out << ", ";
+ if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
+ out << "*, ";
+ seen_kwarg_only = true;
+ }
+ const auto & arg = schema.arguments()[i];
+ out << arg.type()->str() << " " << arg.name();
+ }
+
+ out << ") -> ";
+ if (schema.returns().size() == 1) {
+ out << schema.returns().at(0).type()->str();
+ } else if (schema.returns().size() > 1) {
+ out << "(";
+ for (size_t i = 0; i < schema.returns().size(); ++i) {
+ if (i > 0) out << ", ";
+ out << schema.returns()[i].type()->str();
+ }
+ out << ")";
+ }
+ return out.str();
+}
+
bool Operator::matches(const Node* node) const {
// wrong name
if (node->kind().toQualString() != schema().name()) {
OperationCreator op_creator_;
};
+TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
+
TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name);
std::shared_ptr<Operator> findOperatorFor(const Node* node);
const Operator& getOperatorFor(const Node* node);
sweep(block, true);
}
}
- if (!marked_.count(node)) {
+    // NB: Checking hasUses() is required. AD graphs are not perfectly
+    // valid, as a node in grad_desc.f might be used in reverse_block.
+    // Reverse_block is inlined in grad_desc.f before it is split off
+    // into grad_desc.df.
+ if (!(marked_.count(node) || node->hasUses())) {
it.destroyCurrent();
}
}
EnsureNoTuples(graph->block());
}
-static void LowerSimpleTuples(Block* block) {
+void LowerSimpleTuples(Block* block) {
for(auto n : block->nodes()) {
removeTupleNodes(n, /*must_remove_tuples*/false);
for(auto b : n->blocks()) {
// but will not work on graphs whose inputs contain tuples.
TORCH_API void LowerAllTuples(std::shared_ptr<Graph>& graph);
+TORCH_API void LowerSimpleTuples(Block* block);
}}
--- /dev/null
+#include <torch/csrc/jit/symbolic_script.h>
+
+namespace torch { namespace jit {
+ namespace {
+ std::mutex lock;
+ const std::unordered_map<std::string, std::string> symbolic_scripts({
+ {"aten::mul(Tensor self, Tensor other) -> Tensor",
+R"(
+def forward(self, other):
+ return self * other, (self, other)
+def backward(ctx, grad_output):
+ # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor]
+ self, other = ctx
+ return (grad_output * other).sum_to_size(self.size()), (grad_output * self).sum_to_size(other.size())
+)"},
+ });
+
+    // This map is a workaround to cache compiled GradientPairs. Ideally these
+    // graphs should be compiled only once and saved in the Operator structure.
+    // This should be done along with merging into native_functions.yaml.
+ std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+ } // anonymous namespace
+
+ c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
+ std::lock_guard<std::mutex> guard(lock);
+ auto cache_it = cached_gradient_pairs.find(&schema);
+ if (cache_it != cached_gradient_pairs.end()) {
+ return cache_it->second;
+ } else {
+ auto schema_str = canonicalSchemaString(schema);
+ auto sym_script_it = symbolic_scripts.find(schema_str);
+
+ if (sym_script_it != symbolic_scripts.end()) {
+ // Compile the python code to a script module
+ auto cu = std::make_shared<script::Module>();
+ script::defineMethodsInModule(cu, symbolic_scripts.at(schema_str), script::nativeResolver, nullptr);
+ auto fw_graph = cu->find_method("forward")->graph();
+ auto bw_graph = cu->find_method("backward")->graph();
+
+ GradientPair compiled_graphs{fw_graph, bw_graph};
+ cached_gradient_pairs.emplace_hint(cache_it, &schema, compiled_graphs);
+ return compiled_graphs;
+ }
+ }
+ return c10::nullopt;
+ }
+
+ bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+ std::lock_guard<std::mutex> guard(lock);
+ auto cache_it = cached_gradient_pairs.find(&schema);
+ if (cache_it == cached_gradient_pairs.end()) {
+ auto schema_str = canonicalSchemaString(schema);
+ auto sym_script_it = symbolic_scripts.find(schema_str);
+      return sym_script_it != symbolic_scripts.end();
+ }
+ return true;
+ }
+}}
+
--- /dev/null
+#pragma once
+// This file is temporary until native_functions.yaml and derivatives.yaml are merged.
+// Ideally this should all go into native_functions.yaml
+
+#include <c10/util/Optional.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/operator.h>
+
+namespace torch { namespace jit {
+ struct GradientPair {
+ std::shared_ptr<Graph> forward;
+ std::shared_ptr<Graph> backward;
+ };
+
+ TORCH_API c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema);
+ TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
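+
+  // Typical use (sketch): autodiff first checks hasGradientInfoForSchema(schema)
+  // to decide differentiability, then calls gradientInfoForSchema(schema) to
+  // fetch the compiled forward/backward graphs.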
+}}