From 1975917d0e6ba5580b2f14cfa8dd70b2592c6d3b Mon Sep 17 00:00:00 2001 From: James Reed Date: Thu, 29 Nov 2018 20:30:02 -0800 Subject: [PATCH] fix copy_ (#14593) Summary: Closes https://github.com/pytorch/pytorch/issues/14590 Pull Request resolved: https://github.com/pytorch/pytorch/pull/14593 Differential Revision: D13272510 Pulled By: jamesr66a fbshipit-source-id: b6921a98460c371d435277c416dad0b5ab0fec8c --- test/test_jit.py | 7 +++++++ torch/csrc/jit/register_prim_ops.cpp | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index 3e1e1fe..da0487a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8734,6 +8734,13 @@ a") return ls self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) + def test_inplace_copy_script(self): + def foo(x): + a = torch.rand(3, 4) + a.copy_(x) + return a + self.checkScript(foo, (torch.rand(3, 4),)) + def test_lhs_indexing_increment(self): def foo(a, b): a[0] += b diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 0bdd505..3e802a9 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -921,7 +921,8 @@ Operator( \ #define CREATE_COPY_OP(other_type, c_type) \ Operator( \ - "aten::copy_(Tensor(a!) t, " #other_type " other) -> Tensor(a!)", \ + "aten::copy_(Tensor(a!) self, " #other_type \ + " other) -> Tensor(a!)", \ [](const Node* node) { \ return [=](Stack& stack) { \ at::Tensor t; \ -- 2.7.4