fix dropout AD & rename range to rangelist (#17691)
authorAiling Zhang <ailzhang@fb.com>
Wed, 6 Mar 2019 04:47:02 +0000 (20:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 04:50:10 +0000 (20:50 -0800)
Summary:
fixes #17669
Address apaszke 's comments in #17523
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17691

Differential Revision: D14328083

Pulled By: ailzhang

fbshipit-source-id: 9ec4a54f13bfd1aaf4b1821dd00c31793ac07a44

aten/src/ATen/core/interned_strings.h
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/symbolic_script.cpp

index 2bd28fb..77eabe7 100644 (file)
@@ -73,7 +73,7 @@ namespace c10 {
   _(prim, MMBatchSide)             \
   _(prim, min)                     \
   _(prim, max)                     \
-  _(prim, range)                   \
+  _(prim, rangelist)               \
   _(aten, _grad_sum_to_size)       \
   _(aten, _ncf_unsqueeze)          \
   _(aten, warn)                    \
index b1b1da2..cfe244f 100644 (file)
@@ -129,9 +129,8 @@ bool isDifferentiable(Node* n) {
     return true;
 
   if (n->matches(
-          "aten::dropout(Tensor input, float p, bool train) -> Tensor")) {
-    auto train = n->get<bool>(attr::train).value();
-    return train;
+          "aten::dropout(Tensor input, float p, bool train) -> Tensor", attr::train)) {
+    return n->get<bool>(attr::train).value();
   }
 
   auto schema = n->maybeSchema();
index 257e09a..e825fd7 100644 (file)
@@ -122,7 +122,7 @@ RegisterOperators reg(
            };
          }),
      Operator(
-         "prim::range(int n) -> int[]",
+         "prim::rangelist(int n) -> int[]",
          [](Stack& stack) {
            int64_t n;
            pop(stack, n);
index 6ca5442..1614bb3 100644 (file)
@@ -401,7 +401,7 @@ struct Environment {
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
-          {"range", std::make_shared<BuiltinFunction>(prim::range, at::nullopt)},
+          {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
       };
       auto it = globals.find(ident);
       if (it != globals.end())
index d6ee9fb..d1983d9 100644 (file)
@@ -338,7 +338,7 @@ const std::vector<std::string> functions = {
             elif dim == 2:
                 out = mat.t()
             else:
-                dims = range(dim)
+                dims = rangelist(dim)
                 dims[-1] = dim - 2
                 dims[-2] = dim - 1
                 out = mat.permute(dims)