// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
-static inline Tensor sum_to(Tensor tensor, IntList shape) {
+static inline Tensor sum_to(Tensor tensor, const IntList shape) {
if (shape.size() == 0) {
return tensor.sum();
}
- Tensor result = tensor;
- while (result.dim() > (int64_t)shape.size()) {
- result = result.sum(0, false);
+ c10::SmallVector<int64_t, 8> reduce_dims;
+ const at::IntList sizes = tensor.sizes();
+ const int64_t leading_dims = sizes.size() - shape.size();
+ for (int64_t i = 0; i < leading_dims; ++i) {
+ reduce_dims.push_back(i);
}
- for (int64_t i = 0; i < result.dim(); ++i) {
- if (shape[i] == 1 && result.sizes()[i] > 1) {
- result = result.sum(i, true);
+ for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
+ if (shape[i - leading_dims] == 1 && sizes[i] > 1) {
+ reduce_dims.push_back(i);
}
}
- return result;
+ if (!reduce_dims.empty()) {
+ tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
+ }
+ return leading_dims > 0 ? tensor.view(shape) : tensor;
}
// True if `shape` can be broadcasted to `desired`
_(prim, requires_grad) \
_(prim, AutogradAdd) \
_(prim, GradOf) \
+ _(prim, SumToSize) \
_(prim, AnyDefined) \
_(prim, FusedConcat) \
_(prim, ConstantChunk) \
return r;
}
-void run(InterpreterState & interp, const std::vector<at::Tensor> & inputs, std::vector<at::Tensor> & outputs) {
+std::vector<at::Tensor> run(InterpreterState & interp, const std::vector<at::Tensor> & inputs) {
std::vector<IValue> stack(inputs.begin(), inputs.end());
interp.run(stack);
- outputs.clear();
- for (auto& ivalue : stack) {
- outputs.push_back(std::move(ivalue).toTensor());
- }
+ return fmap(stack, [](const IValue& i) { return i.toTensor(); });
}
std::pair<tensor_list, tensor_list> runGradient(
Gradient& grad_spec,
tensor_list& tensors_in,
tensor_list& tensor_grads_in) {
- tensor_list tensors_out, tensor_grads_out;
+ static const auto as_tensorlist = [](const Stack& stack) {
+ return fmap(stack, [](const IValue& i) { return i.toTensor(); });
+ };
Code f_code{grad_spec.f}, df_code{grad_spec.df};
InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
- run(f_interpreter, tensors_in, tensors_out);
+ auto f_stack = fmap<IValue>(tensors_in);
+ f_interpreter.run(f_stack);
- tensor_list df_inputs;
- df_inputs.insert(
- df_inputs.end(), tensor_grads_in.begin(), tensor_grads_in.end());
+ Stack df_stack;
+ df_stack.insert(
+ df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
for (auto offset : grad_spec.df_input_captured_inputs)
- df_inputs.push_back(tensors_in[offset]);
+ df_stack.push_back(tensors_in[offset]);
for (auto offset : grad_spec.df_input_captured_outputs)
- df_inputs.push_back(tensors_out[offset]);
- run(df_interpreter, df_inputs, tensor_grads_out);
+ df_stack.push_back(f_stack[offset]);
+ df_interpreter.run(df_stack);
// Outputs of f needs to be sliced
- tensors_out.erase(
- tensors_out.begin() + grad_spec.f_real_outputs, tensors_out.end());
- return std::make_pair(tensors_out, tensor_grads_out);
+ f_stack.erase(
+ f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
+ return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}
void assertAllClose(const tensor_list& a, const tensor_list& b) {
auto lstm_g = build_lstm();
Code lstm_function(lstm_g);
- std::vector<at::Tensor> outputs;
InterpreterState lstm_interp(lstm_function);
- run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}, outputs);
+ auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
// std::cout << almostEqual(outputs[0],hx) << "\n";
auto grad_spec = differentiate(graph);
std::vector<size_t> expected_captured_inputs = {0, 1};
- std::vector<size_t> expected_captured_outputs = {1};
- std::vector<size_t> expected_input_vjps = {0, 1};
+ std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
+ std::vector<size_t> expected_input_vjps = {0, 3};
std::vector<size_t> expected_output_vjps = {0, 1};
ASSERT_EQ(grad_spec.f_real_outputs, 1);
ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
PropagateRequiresGrad(graph);
auto grad_spec = differentiate(graph);
- std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
+ std::vector<size_t> expected_input_vjps = {1, 4}; // for e and %4 = (d + a)
std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
- ASSERT_EQ(grad_spec.f_real_outputs, 2); // we need one temporary %4 = (d + a)
+ ASSERT_EQ(grad_spec.f_real_outputs, 2);
ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
- ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2}));
+ ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5, 6, 7, 8}));
ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
out << "testDifferentiateWithRequiresGrad\n";
testDifferentiate
graph(%0 : Float(2, 3, 4)
%1 : Float(2, 3, 4)) {
+ %23 : int[] = aten::size(%0)
+ %26 : int[] = aten::size(%1)
%2 : Float(2, 3, 4) = aten::mul(%0, %1)
+ %15 : int[] = aten::size(%2)
+ %18 : int[] = aten::size(%0)
%3 : Float(2, 3, 4) = aten::mul(%2, %0)
%4 : int = prim::Constant[value=1]()
+ %7 : int[] = aten::size(%3)
+ %10 : int[] = aten::size(%1)
%5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
- return (%5, %2);
+ return (%5, %23, %26, %2, %15, %18, %7, %10);
}
graph(%0 : Float(2, 3, 4)
%1 : Float(2, 3, 4)
%2 : Float(2, 3, 4)
%3 : Float(2, 3, 4)
- %4 : Float(2, 3, 4)) {
- %5 : int = prim::Constant[value=1]()
- %6 : Float(2, 3, 4), %7 : Float(2, 3, 4) = prim::GradOf[name="aten::add"](%0)
+ %4 : int[]
+ %5 : int[]
+ %6 : Float(2, 3, 4)
+ %7 : int[]
+ %8 : int[]
+ %9 : int[]
+ %10 : int[]) {
+ %11 : int = prim::Constant[value=1]()
+ %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %8 : Float(2, 3, 4) = aten::mul(%0, %5)
- -> (%0, %8)
+ %14 : Tensor = prim::SumToSize(%0, %9)
+ %15 : Float(2, 3, 4) = aten::mul(%0, %11)
+ %16 : Tensor = prim::SumToSize(%15, %10)
+ -> (%14, %16)
}
- %9 : Float(2, 3, 4), %10 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%6)
+ %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
block0() {
- %11 : Float(2, 3, 4) = aten::mul(%6, %2)
- %12 : Float(2, 3, 4) = aten::mul(%6, %4)
- -> (%11, %12)
+ %19 : Tensor = aten::mul(%12, %2)
+ %20 : Tensor = prim::SumToSize(%19, %7)
+ %21 : Tensor = aten::mul(%12, %6)
+ %22 : Tensor = prim::SumToSize(%21, %8)
+ -> (%20, %22)
}
- %13 : Tensor = prim::AutogradAdd(%1, %9)
- %14 : Float(2, 3, 4), %15 : Float(2, 3, 4) = prim::GradOf[name="aten::mul"](%13)
+ %23 : Tensor = prim::AutogradAdd(%1, %17)
+ %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::mul"](%23)
block0() {
- %16 : Float(2, 3, 4) = aten::mul(%13, %3)
- %17 : Float(2, 3, 4) = aten::mul(%13, %2)
- -> (%16, %17)
+ %26 : Tensor = aten::mul(%23, %3)
+ %27 : Tensor = prim::SumToSize(%26, %4)
+ %28 : Tensor = aten::mul(%23, %2)
+ %29 : Tensor = prim::SumToSize(%28, %5)
+ -> (%27, %29)
}
- %18 : Tensor = prim::AutogradAdd(%10, %14)
- %19 : Tensor = prim::AutogradAdd(%7, %15)
- return (%18, %19);
+ %30 : Tensor = prim::AutogradAdd(%18, %24)
+ %31 : Tensor = prim::AutogradAdd(%13, %25)
+ return (%30, %31);
}
testDifferentiateWithRequiresGrad
%3 : int = prim::Constant[value=1]()
%4 : Float(*) = aten::add(%2, %1, %3)
%5 : int = prim::Constant[value=1]()
+ %26 : int[] = aten::size(%4)
+ %29 : int[] = aten::size(%0)
%6 : Float(*) = aten::add(%4, %0, %5)
+ %19 : int[] = aten::size(%6)
+ %22 : int[] = aten::size(%0)
%7 : Float(*) = aten::mul(%6, %0)
%8 : int = prim::Constant[value=1]()
+ %11 : int[] = aten::size(%7)
+ %14 : int[] = aten::size(%1)
%9 : Float(*) = aten::add(%7, %1, %8)
- return (%4, %9, %6);
+ return (%4, %9, %26, %29, %6, %19, %22, %11, %14);
}
graph(%0 : Float(*)
%1 : Float(*)
%2 : Float(*)
- %3 : Float(*)) {
- %4 : int = prim::Constant[value=1]()
- %6 : Float(*), %7 : Tensor = prim::GradOf[name="aten::add"](%0)
+ %3 : int[]
+ %4 : int[]
+ %5 : Float(*)
+ %6 : int[]
+ %7 : int[]
+ %8 : int[]
+ %9 : int[]) {
+ %10 : int = prim::Constant[value=1]()
+ %12 : Tensor, %13 : Tensor = prim::GradOf[name="aten::add"](%0)
block0() {
- %8 : Tensor = aten::mul(%0, %4)
- -> (%0, %8)
+ %14 : Tensor = prim::SumToSize(%0, %8)
+ %15 : Tensor = aten::mul(%0, %10)
+ %16 : Tensor = prim::SumToSize(%15, %9)
+ -> (%14, %16)
}
- %9 : Tensor, %10 : Tensor = prim::GradOf[name="aten::mul"](%6)
+ %17 : Tensor, %18 : Tensor = prim::GradOf[name="aten::mul"](%12)
block0() {
- %11 : Tensor = aten::mul(%6, %2)
- %12 : Tensor = aten::mul(%6, %3)
- -> (%11, %12)
+ %19 : Tensor = aten::mul(%12, %2)
+ %20 : Tensor = prim::SumToSize(%19, %6)
+ %21 : Tensor = aten::mul(%12, %5)
+ %22 : Tensor = prim::SumToSize(%21, %7)
+ -> (%20, %22)
}
- %13 : Tensor = prim::AutogradAdd(%1, %9)
- %14 : Tensor, %15 : Tensor = prim::GradOf[name="aten::add"](%13)
+ %23 : Tensor = prim::AutogradAdd(%1, %17)
+ %24 : Tensor, %25 : Tensor = prim::GradOf[name="aten::add"](%23)
block0() {
- %16 : Tensor = aten::mul(%13, %4)
- -> (%13, %16)
+ %26 : Tensor = prim::SumToSize(%23, %3)
+ %27 : Tensor = aten::mul(%23, %10)
+ %28 : Tensor = prim::SumToSize(%27, %4)
+ -> (%26, %28)
}
- %17 : Tensor = prim::AutogradAdd(%10, %15)
- return (%17);
+ %29 : Tensor = prim::AutogradAdd(%18, %25)
+ return (%29);
}
%11 : Float(*, *)
%12 : Float(*, *)
%13 : Float(*, *)
+ %14 : int[]
+ %15 : int[]
+ %16 : int[]
+ %17 : int[]
+ %18 : int[]
+ %19 : int[]
%ingate : Float(*, *)
%forgetgate : Float(*, *)
%cellgate : Float(*, *)
%outgate : Float(*, *)
- %18 : Float(*, *)) {
- %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %9, %1, %18, %0)
- %21 : Float(*, *) = aten::t(%13)
- %22 : Float(*, *) = aten::mm(%20, %21)
- %23 : Float(*, *) = aten::t(%10)
- %24 : Float(*, *) = aten::mm(%23, %20)
- %25 : Float(*, *) = aten::t(%24)
- %26 : Float(*, *) = aten::t(%12)
- %27 : Float(*, *) = aten::mm(%20, %26)
- %28 : Float(*, *) = aten::t(%11)
- %29 : Float(*, *) = aten::mm(%28, %20)
- %30 : Float(*, *) = aten::t(%29)
- return (%19, %20, %20, %22, %25, %27, %30);
+ %24 : int[]
+ %25 : int[]
+ %26 : int[]
+ %27 : int[]
+ %28 : int[]
+ %29 : int[]
+ %30 : Float(*, *)
+ %31 : int[]
+ %32 : int[]) {
+ %33 : Float(*, *) = aten::mul(%0, %30)
+ %34 : Tensor = prim::SumToSize(%33, %31)
+ %35 : Float(*, *) = aten::mul(%0, %outgate)
+ %36 : Tensor = prim::SumToSize(%35, %32)
+ %37 : Tensor = prim::FusionGroup_0(%1, %36, %30)
+ %38 : Tensor = prim::SumToSize(%37, %28)
+ %39 : Tensor = prim::SumToSize(%37, %29)
+ %40 : Tensor = aten::mul(%39, %cellgate)
+ %41 : Tensor = prim::SumToSize(%40, %26)
+ %42 : Tensor = aten::mul(%39, %ingate)
+ %43 : Tensor = prim::SumToSize(%42, %27)
+ %44 : Tensor = aten::mul(%38, %9)
+ %45 : Tensor = prim::SumToSize(%44, %24)
+ %46 : Tensor = aten::mul(%38, %forgetgate)
+ %47 : Tensor = prim::SumToSize(%46, %25)
+ %48 : Tensor = prim::FusionGroup_1(%41, %ingate, %45, %forgetgate, %43, %cellgate, %34, %outgate)
+ %49 : Tensor = prim::SumToSize(%48, %18)
+ %50 : Tensor = prim::SumToSize(%48, %19)
+ %51 : Tensor = prim::SumToSize(%49, %16)
+ %52 : Tensor = prim::SumToSize(%49, %17)
+ %53 : Tensor = prim::SumToSize(%51, %14)
+ %54 : Tensor = prim::SumToSize(%51, %15)
+ %55 : Float(*, *) = aten::t(%13)
+ %56 : Float(*, *) = aten::mm(%54, %55)
+ %57 : Float(*, *) = aten::t(%10)
+ %58 : Float(*, *) = aten::mm(%57, %54)
+ %59 : Float(*, *) = aten::t(%58)
+ %60 : Float(*, *) = aten::t(%12)
+ %61 : Float(*, *) = aten::mm(%53, %60)
+ %62 : Float(*, *) = aten::t(%11)
+ %63 : Float(*, *) = aten::mm(%62, %53)
+ %64 : Float(*, *) = aten::t(%63)
+ return (%47, %50, %52, %56, %59, %61, %64);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+ %1 : Tensor
+ %2 : Float(*, *)) {
+ %3 : Float(*, *) = aten::mul(%2, %2)
+ %4 : Float(*, *) = aten::neg(%3)
+ %5 : int = prim::Constant[value=1]()
+ %6 : Float(*, *) = aten::add(%4, %5, %5)
+ %7 : Tensor = aten::mul(%1, %6)
+ %8 : int = prim::Constant[value=1]()
+ %9 : Tensor = aten::add(%0, %7, %8)
+ return (%9);
+}
+with prim::FusionGroup_1 = graph(%0 : Tensor
%1 : Float(*, *)
- %2 : Float(*, *)
+ %2 : Tensor
%3 : Float(*, *)
- %4 : Float(*, *)
+ %4 : Tensor
%5 : Float(*, *)
- %6 : Float(*, *)
+ %6 : Tensor
%7 : Float(*, *)) {
- %8 : Float(*, *) = aten::mul(%7, %6)
- %9 : Float(*, *) = aten::mul(%7, %3)
- %10 : Float(*, *) = aten::mul(%6, %6)
- %11 : Float(*, *) = aten::neg(%10)
- %12 : int = prim::Constant[value=1]()
- %13 : int = prim::Constant[value=1]()
- %14 : Float(*, *) = aten::add(%11, %13, %13)
- %15 : Float(*, *) = aten::mul(%9, %14)
- %16 : int = prim::Constant[value=1]()
- %17 : int = prim::Constant[value=1]()
- %18 : Float(*, *) = aten::add(%5, %15, %17)
- %19 : Float(*, *) = aten::mul(%18, %2)
- %20 : Float(*, *) = aten::mul(%18, %1)
- %21 : Float(*, *) = aten::mul(%18, %4)
- %22 : Float(*, *) = aten::neg(%3)
- %23 : int = prim::Constant[value=1]()
+ %8 : Float(*, *) = aten::neg(%7)
+ %9 : int = prim::Constant[value=1]()
+ %10 : Float(*, *) = aten::add(%8, %9, %9)
+ %11 : Float(*, *) = aten::mul(%10, %7)
+ %12 : Tensor = aten::mul(%11, %6)
+ %13 : Float(*, *) = aten::mul(%5, %5)
+ %14 : Float(*, *) = aten::neg(%13)
+ %15 : int = prim::Constant[value=1]()
+ %16 : Float(*, *) = aten::add(%14, %15, %15)
+ %17 : Tensor = aten::mul(%4, %16)
+ %18 : Float(*, *) = aten::neg(%3)
+ %19 : int = prim::Constant[value=1]()
+ %20 : Float(*, *) = aten::add(%18, %19, %19)
+ %21 : Float(*, *) = aten::mul(%20, %3)
+ %22 : Tensor = aten::mul(%21, %2)
+ %23 : Float(*, *) = aten::neg(%1)
%24 : int = prim::Constant[value=1]()
- %25 : Float(*, *) = aten::add(%22, %24, %24)
- %26 : Float(*, *) = aten::mul(%25, %3)
- %27 : Float(*, *) = aten::mul(%26, %8)
- %28 : Float(*, *) = aten::mul(%2, %2)
- %29 : Float(*, *) = aten::neg(%28)
- %30 : int = prim::Constant[value=1]()
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%29, %31, %31)
- %33 : Float(*, *) = aten::mul(%20, %32)
- %34 : Float(*, *) = aten::neg(%0)
- %35 : int = prim::Constant[value=1]()
- %36 : int = prim::Constant[value=1]()
- %37 : Float(*, *) = aten::add(%34, %36, %36)
- %38 : Float(*, *) = aten::mul(%37, %0)
- %39 : Float(*, *) = aten::mul(%38, %21)
- %40 : Float(*, *) = aten::neg(%1)
- %41 : int = prim::Constant[value=1]()
- %42 : int = prim::Constant[value=1]()
- %43 : Float(*, *) = aten::add(%40, %42, %42)
- %44 : Float(*, *) = aten::mul(%43, %1)
- %45 : Float(*, *) = aten::mul(%44, %19)
- %46 : Float(*, *) = prim::FusedConcat[dim=1](%45, %39, %33, %27)
- %47 : Float(*, *) = aten::mul(%18, %0)
- return (%47, %46);
+ %25 : Float(*, *) = aten::add(%23, %24, %24)
+ %26 : Float(*, *) = aten::mul(%25, %1)
+ %27 : Tensor = aten::mul(%26, %0)
+ %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
+ return (%28);
}
%9 : Float(*, *) = aten::t(%4)
%10 : Float(*, *) = aten::mm(%3, %9)
%11 : int = prim::Constant[value=1]()
- %12 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
- %13 : Tensor[] = aten::broadcast_tensors(%12)
- %14 : Tensor, %15 : Tensor, %16 : Tensor, %17 : Tensor = prim::ListUnpack(%13)
- %hy : Float(*, *), %19 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %17, %16, %15, %14)
- return (%hy, %cy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %19);
+ %12 : int[] = aten::size(%8)
+ %13 : int[] = aten::size(%10)
+ %14 : int[] = aten::size(%2)
+ %15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
+ %17 : int[] = aten::size(%16)
+ %18 : int[] = aten::size(%15)
+ %19 : int[] = aten::size(%1)
+ %20 : Tensor[] = prim::ListConstruct(%15, %1)
+ %21 : Tensor[] = aten::broadcast_tensors(%20)
+ %22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
+ %24 : int[] = aten::size(%0)
+ %hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
+ %34 : int[] = aten::size(%ingate.1)
+ %35 : int[] = aten::size(%forgetgate.1)
+ %36 : int[] = aten::size(%cellgate.1)
+ %37 : int[] = aten::size(%outgate.1)
+ %38 : int[] = aten::size(%29)
+ %39 : int[] = aten::size(%28)
+ %40 : int[] = aten::size(%26)
+ return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
}
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*)
+ %1 : Float(*, *)
+ %2 : Float(*, *)) {
+ %3 : int = prim::Constant[value=1]()
+ %4 : Float(*, *) = aten::add(%1, %2, %3)
+ %5 : int = prim::Constant[value=1]()
+ %6 : Float(*, *) = aten::add(%4, %0, %5)
+ return (%6, %4);
+}
+with prim::FusionGroup_1 = graph(%0 : Float(*, *)
%1 : Tensor
- %2 : Tensor
- %3 : Tensor
- %4 : Tensor) {
- %5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *) = aten::add(%13, %17, %21)
- %23 : int = prim::Constant[value=1]()
- %24 : Float(*, *) = aten::add(%14, %18, %23)
+ %2 : Tensor) {
+ %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+ %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+ %11 : int = prim::Constant[value=1]()
+ %12 : Float(*, *) = aten::add(%3, %7, %11)
+ %13 : int = prim::Constant[value=1]()
+ %14 : Float(*, *) = aten::add(%4, %8, %13)
+ %15 : int = prim::Constant[value=1]()
+ %16 : Float(*, *) = aten::add(%5, %9, %15)
+ %17 : int = prim::Constant[value=1]()
+ %18 : Float(*, *) = aten::add(%6, %10, %17)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%12)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
+ %cellgate.1 : Float(*, *) = aten::tanh(%16)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%18)
+ %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]()
- %26 : Float(*, *) = aten::add(%15, %19, %25)
- %27 : int = prim::Constant[value=1]()
- %28 : Float(*, *) = aten::add(%16, %20, %27)
- %29 : int = prim::Constant[value=1]()
- %30 : Float(*, *) = aten::add(%22, %9, %29)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%24, %10, %31)
- %33 : int = prim::Constant[value=1]()
- %34 : Float(*, *) = aten::add(%26, %11, %33)
- %35 : int = prim::Constant[value=1]()
- %36 : Float(*, *) = aten::add(%28, %12, %35)
- %37 : int = prim::Constant[value=1]()
- %38 : Float(*, *) = aten::add(%30, %5, %37)
- %39 : int = prim::Constant[value=1]()
- %40 : Float(*, *) = aten::add(%32, %6, %39)
- %41 : int = prim::Constant[value=1]()
- %42 : Float(*, *) = aten::add(%34, %7, %41)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%36, %8, %43)
- %ingate.1 : Float(*, *) = aten::sigmoid(%38)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
- %cellgate.1 : Float(*, *) = aten::tanh(%42)
- %outgate.1 : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %51 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %53)
- return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %cy : Float(*, *) = aten::add(%23, %24, %25)
+ %27 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %27)
+ return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
%Wx : Float(*, *)
%20 : Float(*, *)
%Uz : Float(*, *)
- %22 : Float(*, *)
+ %22 : int[]
+ %23 : int[]
+ %24 : Float(*, *)
+ %25 : int[]
+ %26 : int[]
+ %27 : int[]
+ %28 : int[]
+ %29 : int[]
+ %30 : int[]
+ %31 : int[]
+ %32 : int[]
+ %33 : int[]
+ %34 : int[]
+ %35 : int[]
+ %36 : int[]
%ingate : Float(*, *)
%forgetgate : Float(*, *)
%cellgate : Float(*, *)
%outgate : Float(*, *)
- %27 : Float(*, *)) {
- %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0)
- %29 : Float(*, *) = aten::mul(%28, %Wx)
- %30 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
- %31 : Float(*, *), %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %14, %Uz)
- %34 : Float(*, *) = aten::t(%16)
- %35 : Float(*, *) = aten::mm(%34, %30)
- %36 : Float(*, *) = aten::t(%35)
- %37 : Float(*, *) = aten::t(%17)
- %38 : Float(*, *) = aten::mm(%37, %32)
- %39 : Float(*, *) = aten::t(%38)
- return (%28, %33, %29, %31, %36, %39);
+ %41 : int[]
+ %42 : int[]
+ %43 : int[]
+ %44 : int[]
+ %45 : int[]
+ %46 : int[]
+ %47 : Float(*, *)
+ %48 : int[]
+ %49 : int[]) {
+ %50 : int = prim::Constant[value=1]()
+ %51 : Float(*, *) = aten::mul(%0, %47)
+ %52 : Tensor = prim::SumToSize(%51, %48)
+ %53 : Float(*, *) = aten::mul(%0, %outgate)
+ %54 : Tensor = prim::SumToSize(%53, %49)
+ %55 : Tensor = prim::FusionGroup_0(%1, %54, %47)
+ %56 : Tensor = prim::SumToSize(%55, %45)
+ %57 : Tensor = prim::SumToSize(%55, %46)
+ %58 : Tensor = aten::mul(%57, %cellgate)
+ %59 : Tensor = prim::SumToSize(%58, %43)
+ %60 : Tensor = aten::mul(%57, %ingate)
+ %61 : Tensor = prim::SumToSize(%60, %44)
+ %62 : Tensor = aten::mul(%56, %12)
+ %63 : Tensor = prim::SumToSize(%62, %41)
+ %64 : Tensor = prim::FusionGroup_1(%59, %ingate, %63, %forgetgate, %61, %cellgate, %52, %outgate)
+ %65 : Tensor = prim::SumToSize(%64, %35)
+ %66 : Tensor = prim::SumToSize(%64, %36)
+ %67 : Tensor = prim::SumToSize(%65, %33)
+ %68 : Tensor = prim::SumToSize(%65, %34)
+ %69 : Tensor = aten::mul(%68, %Uz)
+ %70 : Tensor = prim::SumToSize(%69, %31)
+ %71 : Tensor = aten::mul(%68, %13)
+ %72 : Tensor = prim::SumToSize(%71, %32)
+ %73 : Tensor = prim::SumToSize(%67, %29)
+ %74 : Tensor = prim::SumToSize(%67, %30)
+ %75 : Tensor = aten::mul(%74, %Wx)
+ %76 : Tensor = prim::SumToSize(%75, %27)
+ %77 : Tensor = aten::mul(%74, %14)
+ %78 : Tensor = prim::SumToSize(%77, %28)
+ %79 : Tensor = aten::mul(%73, %Uz)
+ %80 : Tensor = prim::SumToSize(%79, %25)
+ %81 : Tensor = aten::mul(%73, %24)
+ %82 : Tensor = prim::SumToSize(%81, %26)
+ %83 : Tensor = aten::add(%72, %82, %50)
+ %84 : Tensor = aten::mul(%80, %Wx)
+ %85 : Tensor = prim::SumToSize(%84, %22)
+ %86 : Tensor = aten::mul(%80, %15)
+ %87 : Tensor = prim::SumToSize(%86, %23)
+ %88 : Tensor = aten::add(%78, %87, %50)
+ %89 : Float(*, *) = aten::t(%16)
+ %90 : Float(*, *) = aten::mm(%89, %83)
+ %91 : Float(*, *) = aten::t(%90)
+ %92 : Float(*, *) = aten::t(%17)
+ %93 : Float(*, *) = aten::mm(%92, %88)
+ %94 : Float(*, *) = aten::t(%93)
+ return (%66, %70, %76, %85, %91, %94);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+ %1 : Tensor
+ %2 : Float(*, *)) {
+ %3 : Float(*, *) = aten::mul(%2, %2)
+ %4 : Float(*, *) = aten::neg(%3)
+ %5 : int = prim::Constant[value=1]()
+ %6 : Float(*, *) = aten::add(%4, %5, %5)
+ %7 : Tensor = aten::mul(%1, %6)
+ %8 : int = prim::Constant[value=1]()
+ %9 : Tensor = aten::add(%0, %7, %8)
+ return (%9);
+}
+with prim::FusionGroup_1 = graph(%0 : Tensor
%1 : Float(*, *)
- %2 : Float(*, *)
+ %2 : Tensor
%3 : Float(*, *)
- %4 : Float(*, *)
+ %4 : Tensor
%5 : Float(*, *)
- %6 : Float(*, *)
+ %6 : Tensor
%7 : Float(*, *)) {
- %8 : Float(*, *) = aten::mul(%7, %6)
- %9 : Float(*, *) = aten::mul(%7, %3)
- %10 : Float(*, *) = aten::mul(%6, %6)
- %11 : Float(*, *) = aten::neg(%10)
- %12 : int = prim::Constant[value=1]()
- %13 : Float(*, *) = aten::add(%11, %12, %12)
- %14 : Float(*, *) = aten::mul(%9, %13)
- %15 : int = prim::Constant[value=1]()
- %16 : Float(*, *) = aten::add(%5, %14, %15)
- %17 : Float(*, *) = aten::mul(%16, %2)
- %18 : Float(*, *) = aten::mul(%16, %0)
- %19 : Float(*, *) = aten::mul(%16, %4)
- %20 : Float(*, *) = aten::neg(%3)
- %21 : int = prim::Constant[value=1]()
- %22 : Float(*, *) = aten::add(%20, %21, %21)
- %23 : Float(*, *) = aten::mul(%22, %3)
- %24 : Float(*, *) = aten::mul(%23, %8)
- %25 : Float(*, *) = aten::mul(%2, %2)
- %26 : Float(*, *) = aten::neg(%25)
- %27 : int = prim::Constant[value=1]()
- %28 : Float(*, *) = aten::add(%26, %27, %27)
- %29 : Float(*, *) = aten::mul(%18, %28)
- %30 : Float(*, *) = aten::neg(%1)
- %31 : int = prim::Constant[value=1]()
- %32 : Float(*, *) = aten::add(%30, %31, %31)
- %33 : Float(*, *) = aten::mul(%32, %1)
- %34 : Float(*, *) = aten::mul(%33, %19)
- %35 : Float(*, *) = aten::neg(%0)
- %36 : int = prim::Constant[value=1]()
- %37 : Float(*, *) = aten::add(%35, %36, %36)
- %38 : Float(*, *) = aten::mul(%37, %0)
- %39 : Float(*, *) = aten::mul(%38, %17)
- %40 : Float(*, *) = prim::FusedConcat[dim=1](%39, %34, %29, %24)
- return (%40);
-}
-with prim::FusionGroup_1 = graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*)) {
- %3 : Float(*, *) = aten::mul(%0, %2)
- %4 : Float(*, *) = aten::mul(%0, %1)
- %5 : int = prim::Constant[value=1]()
- %6 : Float(*, *) = aten::add(%3, %4, %5)
- return (%6);
-}
-with prim::FusionGroup_2 = graph(%0 : Float(*, *)
- %1 : Float(*)
- %2 : Float(*, *)
- %3 : Float(*)
- %4 : Float(*, *)) {
- %5 : Float(*, *) = aten::mul(%2, %4)
- %6 : Float(*, *) = aten::mul(%2, %3)
- %7 : Float(*, *) = aten::mul(%5, %1)
- %8 : int = prim::Constant[value=1]()
+ %8 : Float(*, *) = aten::neg(%7)
%9 : int = prim::Constant[value=1]()
- %10 : Float(*, *) = aten::add(%6, %7, %9)
- %11 : Float(*, *) = aten::mul(%5, %0)
- return (%11, %10, %5);
+ %10 : Float(*, *) = aten::add(%8, %9, %9)
+ %11 : Float(*, *) = aten::mul(%10, %7)
+ %12 : Tensor = aten::mul(%11, %6)
+ %13 : Float(*, *) = aten::mul(%5, %5)
+ %14 : Float(*, *) = aten::neg(%13)
+ %15 : int = prim::Constant[value=1]()
+ %16 : Float(*, *) = aten::add(%14, %15, %15)
+ %17 : Tensor = aten::mul(%4, %16)
+ %18 : Float(*, *) = aten::neg(%3)
+ %19 : int = prim::Constant[value=1]()
+ %20 : Float(*, *) = aten::add(%18, %19, %19)
+ %21 : Float(*, *) = aten::mul(%20, %3)
+ %22 : Tensor = aten::mul(%21, %2)
+ %23 : Float(*, *) = aten::neg(%1)
+ %24 : int = prim::Constant[value=1]()
+ %25 : Float(*, *) = aten::add(%23, %24, %24)
+ %26 : Float(*, *) = aten::mul(%25, %1)
+ %27 : Tensor = aten::mul(%26, %0)
+ %28 : Tensor = prim::FusedConcat[dim=1](%27, %22, %17, %12)
+ return (%28);
}
%Wx.1 : Float(*, *) = aten::mm(%7, %9)
%11 : Float(*, *) = aten::t(%6)
%Uz.1 : Float(*, *) = aten::mm(%5, %11)
- %13 : Float(*, *) = aten::mul(%4, %Wx.1)
- %14 : int = prim::Constant[value=1]()
- %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1)
- %16 : Tensor[] = aten::broadcast_tensors(%15)
- %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16)
- %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17)
- return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %24);
+ %13 : int[] = aten::size(%4)
+ %14 : int[] = aten::size(%Wx.1)
+ %15 : int[] = aten::size(%Uz.1)
+ %16 : int[] = aten::size(%3)
+ %17 : int = prim::Constant[value=1]()
+ %18 : int[] = aten::size(%2)
+ %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
+ %25 : int[] = aten::size(%24)
+ %26 : int[] = aten::size(%23)
+ %27 : int[] = aten::size(%22)
+ %28 : int[] = aten::size(%21)
+ %29 : int[] = aten::size(%20)
+ %30 : int[] = aten::size(%19)
+ %31 : int[] = aten::size(%1)
+ %32 : Tensor[] = prim::ListConstruct(%19, %1)
+ %33 : Tensor[] = aten::broadcast_tensors(%32)
+ %34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
+ %36 : int[] = aten::size(%0)
+ %hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
+ %46 : int[] = aten::size(%ingate.1)
+ %47 : int[] = aten::size(%forgetgate.1)
+ %48 : int[] = aten::size(%cellgate.1)
+ %49 : int[] = aten::size(%outgate.1)
+ %50 : int[] = aten::size(%41)
+ %51 : int[] = aten::size(%40)
+ %52 : int[] = aten::size(%38)
+ return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
}
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)
+with prim::FusionGroup_0 = graph(%0 : Float(*)
+ %1 : Float(*, *)
+ %2 : Float(*)
+ %3 : Float(*, *)
+ %4 : Float(*)) {
+ %5 : Float(*, *) = aten::mul(%4, %3)
+ %6 : Float(*, *) = aten::mul(%5, %1)
+ %7 : Float(*, *) = aten::mul(%2, %3)
+ %8 : int = prim::Constant[value=1]()
+ %9 : Float(*, *) = aten::add(%6, %7, %8)
+ %10 : Float(*, *) = aten::mul(%0, %1)
+ %11 : int = prim::Constant[value=1]()
+ %12 : Float(*, *) = aten::add(%9, %10, %11)
+ return (%12, %10, %9, %7, %6, %5);
+}
+with prim::FusionGroup_1 = graph(%0 : Float(*, *)
%1 : Tensor
- %2 : Tensor
- %3 : Tensor
- %4 : Tensor
- %5 : Tensor
- %6 : Tensor) {
- %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
- %11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
- %15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
- %19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
- %23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
- %27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
- %31 : Float(*, *) = aten::mul(%23, %27)
- %32 : Float(*, *) = aten::mul(%24, %28)
- %33 : Float(*, *) = aten::mul(%25, %29)
- %34 : Float(*, *) = aten::mul(%26, %30)
- %35 : Float(*, *) = aten::mul(%19, %15)
- %36 : Float(*, *) = aten::mul(%20, %16)
- %37 : Float(*, *) = aten::mul(%21, %17)
- %38 : Float(*, *) = aten::mul(%22, %18)
- %39 : Float(*, *) = aten::mul(%11, %15)
- %40 : Float(*, *) = aten::mul(%12, %16)
- %41 : Float(*, *) = aten::mul(%13, %17)
- %42 : Float(*, *) = aten::mul(%14, %18)
- %43 : int = prim::Constant[value=1]()
- %44 : Float(*, *) = aten::add(%35, %31, %43)
- %45 : int = prim::Constant[value=1]()
- %46 : Float(*, *) = aten::add(%36, %32, %45)
- %47 : int = prim::Constant[value=1]()
- %48 : Float(*, *) = aten::add(%37, %33, %47)
- %49 : int = prim::Constant[value=1]()
- %50 : Float(*, *) = aten::add(%38, %34, %49)
- %51 : int = prim::Constant[value=1]()
- %52 : Float(*, *) = aten::add(%44, %39, %51)
- %53 : int = prim::Constant[value=1]()
- %54 : Float(*, *) = aten::add(%46, %40, %53)
- %55 : int = prim::Constant[value=1]()
- %56 : Float(*, *) = aten::add(%48, %41, %55)
- %57 : int = prim::Constant[value=1]()
- %58 : Float(*, *) = aten::add(%50, %42, %57)
- %59 : int = prim::Constant[value=1]()
- %60 : Float(*, *) = aten::add(%52, %7, %59)
- %61 : int = prim::Constant[value=1]()
- %62 : Float(*, *) = aten::add(%54, %8, %61)
- %63 : int = prim::Constant[value=1]()
- %64 : Float(*, *) = aten::add(%56, %9, %63)
- %65 : int = prim::Constant[value=1]()
- %66 : Float(*, *) = aten::add(%58, %10, %65)
- %ingate.1 : Float(*, *) = aten::sigmoid(%60)
- %forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
- %cellgate.1 : Float(*, *) = aten::tanh(%64)
- %outgate.1 : Float(*, *) = aten::sigmoid(%66)
- %71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
- %72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
- %73 : int = prim::Constant[value=1]()
- %cy : Float(*, *) = aten::add(%71, %72, %73)
- %75 : Float(*, *) = aten::tanh(%cy)
- %hy : Float(*, *) = aten::mul(%outgate.1, %75)
- return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
+ %2 : Tensor) {
+ %3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
+ %7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
+ %11 : int = prim::Constant[value=1]()
+ %12 : Float(*, *) = aten::add(%3, %7, %11)
+ %13 : int = prim::Constant[value=1]()
+ %14 : Float(*, *) = aten::add(%4, %8, %13)
+ %15 : int = prim::Constant[value=1]()
+ %16 : Float(*, *) = aten::add(%5, %9, %15)
+ %17 : int = prim::Constant[value=1]()
+ %18 : Float(*, *) = aten::add(%6, %10, %17)
+ %ingate.1 : Float(*, *) = aten::sigmoid(%12)
+ %forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
+ %cellgate.1 : Float(*, *) = aten::tanh(%16)
+ %outgate.1 : Float(*, *) = aten::sigmoid(%18)
+ %23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
+ %24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
+ %25 : int = prim::Constant[value=1]()
+ %cy : Float(*, *) = aten::add(%23, %24, %25)
+ %27 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate.1, %27)
+ return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
vs = vs[:-drop]
# we don't want all the grad for all the outputs to be the same
# so we multiply each by a constant
- return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
+ return sum([math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None])
if input_tensors is None:
input_tensors = reference_tensors
return ge
- def assertAllFused(self, graph):
+ def assertAllFused(self, graph, except_for=()):
if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
graph = next(graph.nodes()).g('Subgraph')
- self.assertTrue(all(node.kind() in {'prim::Constant', 'prim::FusionGroup'} for node in graph.nodes()),
+ allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
+ self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
'got {}'.format(graph))
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
funcs = (func2, funcInf)
for f in funcs:
s = self.checkScript(f, (a, b))
- self.assertAllFused(s.graph_for(a, b))
+ self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
c = s(a, b)
c.sum().backward()
graph = backward_graph(s)
- self.assertAllFused(graph)
+ self.assertAllFused(graph, except_for={'prim::SumToSize'})
def test_mul(self):
def func(a, b):
inputs = get_lstm_inputs('cuda', training=True)
module = self.checkScript(LSTMCellS, inputs)
forward_graph = module.graph_for(*inputs)
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+ with self.assertRaises(AssertionError):
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')
hy, cy = module(*inputs)
inputs = get_milstm_inputs('cuda', training=True)
module = self.checkScript(MiLSTMCell, inputs)
forward_graph = module.graph_for(*inputs)
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+ with self.assertRaises(AssertionError):
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')
hy, cy = module(*inputs)
ss << metadata.shape();
AT_ERROR(format_error(ss.str()));
}
- grads[i] = at::sum_to(grads[i], metadata.shape());
+ grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
}
if (!is_compatible_type(metadata.type(), grads[i].type())) {
std::stringstream ss;
size_t size() const {
return args.size();
}
+ const ArgumentInfo& at(size_t i) const {
+ return args[i];
+ }
size_t hashCode() const {
return hash_code;
}
return fillType(subtype, offset);
}));
} else {
+ offset++;
return original;
}
}
const int i;
};
+inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) {
+ if(!info.defined()) {
+ return out << "<undefined>";
+ }
+ out << "Tensor(device=" << info.device()
+ << ", type=" << toString(info.type())
+ << ", requires_grad=" << info.requires_grad()
+ << ", dims=" << info.dim() << ")";
+ return out;
+}
+
+inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) {
+ out << "{";
+ for(size_t i = 0; i < spec.size(); ++i) {
+ if (i > 0)
+ out << ", ";
+ out << spec.at(i);
+ }
+ out << "}";
+ return out;
+}
+
inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) {
if(!info.defined()) {
return out << "<undefined>";
"aten::sinh(Tensor self) -> Tensor",
"aten::tan(Tensor self) -> Tensor",
"aten::trunc(Tensor self) -> Tensor",
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
"aten::log_softmax(Tensor self, int dim) -> Tensor",
"aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
"aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
"aten::eq(Tensor self, Tensor other) -> Tensor",
"aten::ne(Tensor self, Tensor other) -> Tensor"
};
- const auto build_sym_grad = [node](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
+ const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable {
+ Value * size;
+ {
+ WithInsertPoint insert_guard {node};
+ size = SymbolicVariable(node->namedInput(input_name)).size();
+ }
+ return v.sumToSize(size);
+ };
+ const auto build_sym_grad = [node, &sumToSizeOf](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
auto inputs = fmap<SymbolicVariable>(node->inputs());
auto outputs = fmap<SymbolicVariable>(node->outputs());
if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
- return {grads.at(0), grads.at(0) * node->namedInput(attr::alpha), nullptr};
+ return {sumToSizeOf(grads.at(0), attr::self),
+ sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
+ nullptr};
} else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
return {grads.at(0), nullptr, nullptr};
} else if (node->kind() == prim::AutogradAdd) {
+ // NB: AutogradAdds don't broadcast
return {grads.at(0), grads.at(0)};
} else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
- return {grads.at(0), -grads.at(0) * node->namedInput(attr::alpha), nullptr};
+ return {sumToSizeOf(grads.at(0), attr::self),
+ sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other),
+ nullptr};
} else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
return {grads.at(0), nullptr, nullptr};
} else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
- return {grads.at(0) * inputs.at(1), grads.at(0) * inputs.at(0)};
+ return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
+ sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
} else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0) * inputs.at(1), nullptr};
} else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
- return {grads.at(0) / inputs.at(1), -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1))};
+ return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
+ sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)};
} else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) {
return {grads.at(0) / inputs.at(1), nullptr};
} else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) {
- return {grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
- grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0))};
+ return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self),
+ sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)};
} else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) {
- return {grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
- grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0))};
+ return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self),
+ sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)};
} else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
- return {nullptr, grads.at(0) * inputs.at(0).type_as(grads.at(0)),
- grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0))};
+ return {nullptr,
+ sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
+ sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)};
} else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
// TODO: The order of operations matter in this case. This
} else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
+ } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+ Value * self_size;
+ {
+ WithInsertPoint insert_guard { node };
+ self_size = inputs.at(0).size();
+ }
+ return {grads.at(0).expand(self_size), nullptr};
+
} else if (node->matches("aten::ceil(Tensor self) -> Tensor")) {
return {SymbolicVariable::zeros_like(grads.at(0))};
return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
} else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
- return {grads.at(0) * node->namedInput(attr::beta),
+ return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
nullptr, nullptr};
} else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
- grads.at(0).value(),
- inputs.at(0).value(),
- inputs.at(1).value(),
- node->namedInput(attr::kernel_size),
- node->namedInput(attr::stride),
- node->namedInput(attr::padding),
- outputs.at(1).value(),
- outputs.at(2).value(),
- graph->insertConstant(std::vector<bool>{true, true, true})
+ grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ outputs.at(1).value(),
+ outputs.at(2).value(),
+ graph->insertConstant(std::vector<bool>{true, true, true})
});
// graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
} else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
auto backward_value = graph->insert(aten::native_batch_norm_backward, {
- grads.at(0).value(),
- inputs.at(0).value(),
- inputs.at(1).value(),
- inputs.at(3).value(),
- inputs.at(4).value(),
- outputs.at(1).value(),
- outputs.at(2).value(),
- inputs.at(5).value(),
- inputs.at(7).value(),
- graph->insertConstant(std::vector<bool>{true, true, true})
+ grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ inputs.at(3).value(),
+ inputs.at(4).value(),
+ outputs.at(1).value(),
+ outputs.at(2).value(),
+ inputs.at(5).value(),
+ inputs.at(7).value(),
+ graph->insertConstant(std::vector<bool>{true, true, true})
});
// graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
auto rev_info = addReverseInline(grad_desc);
// Lift constants captured for the reverse graph into it
liftConstants(grad_desc, rev_info);
- // addReverseInline has to call gradientForNode if *any* of the outputs
- // require grad, but it will emit vjps for *all* outputs. Use DCE to remove
- // unnecessary nodes.
- EliminateDeadCode(rev_info.reverse_block);
// Fills in f, df, f_real_outputs, df_input_captures,
// modifies df_input_vjps (new vjps are added for temporaries)
lambdaLiftReverse(grad_desc, rev_info);
// It's possible the we've cloned the same constants many times, so
// de-duplicate them
ConstantPooling(grad_desc.df);
+ // addReverseInline has to call gradientForNode if *any* of the inputs
+ // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
+ // unnecessary nodes.
+ EliminateDeadCode(grad_desc.df->block());
return grad_desc;
}
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/script/jit_exception.h"
-
#include "torch/csrc/variable_tensor_functions.h"
+#include <ATen/ExpandUtils.h>
+#include <c10/util/SmallVector.h>
+
#include <exception>
#include <iostream>
#include <limits>
};
}),
Operator(
+ "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
+ [](const Node* node) {
+ return [=](Stack& stack) {
+ at::Tensor self;
+ Shared<IntList> desired_sizes;
+ pop(stack, self, desired_sizes);
+ push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
+ return 0;
+ };
+ }),
+ Operator(
prim::TupleUnpack,
[](const Node* node) {
size_t num_elems = node->outputs().size();
"aten::size(Tensor self) -> int[]",
[](Stack& stack) {
autograd::profiler::RecordFunction record("sizes");
- auto result = (std::move(pop(stack))).toTensor().sizes();
- pack(stack, std::move(result));
+ auto t = std::move(pop(stack)).toTensor();
+ pack(stack, t.sizes().vec());
return 0;
}),
Operator(
SymbolicVariable operator%(at::Scalar rhs) const {
return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this);
}
+ Value* size() const {
+ return v->owningGraph()->insert(aten::size, {v});
+ }
+ SymbolicVariable sumToSize(Value * size) const {
+ return create(prim::SumToSize, {*this, size})[0];
+ }
+ SymbolicVariable expand(Value * size) const {
+ return v->owningGraph()->insert(aten::expand, {v, size});
+ }
SymbolicVariable isnan() const {
return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(*this, at::kByte);
}