Remove NoneGenerator
authorDavid Riazati <davidriazati@fb.com>
Sat, 22 Dec 2018 00:30:35 +0000 (16:30 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 22 Dec 2018 00:33:37 +0000 (16:33 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15335

Differential Revision: D13540357

Pulled By: driazati

fbshipit-source-id: a289e5944b65872103f68faac74e18f10e7c6fff

aten/src/ATen/core/interned_strings.h
tools/jit/gen_jit_dispatch.py
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/schema_matching.cpp
torch/csrc/jit/tracer.cpp

index 29d0062..2e6ff32 100644 (file)
@@ -67,7 +67,6 @@ namespace c10 {
   _(prim, AnyDefined)              \
   _(prim, FusedConcat)             \
   _(prim, ConstantChunk)           \
-  _(prim, NoneGenerator)           \
   _(prim, MMTreeReduce)            \
   _(prim, MMBatchSide)             \
   _(aten, warn)                    \
index 2f53e67..58e6570 100644 (file)
@@ -59,7 +59,7 @@ TYPE_MAP = {
     'int64_t?': 'int?',
     'double': 'float',
     'bool': 'bool',
-    'Generator': 'Generator',
+    'Generator': 'Generator?',
 }
 
 
index 3ff7340..f30e92a 100644 (file)
@@ -650,17 +650,17 @@ const FunctionSchema* Node::maybeSchema() const {
 bool Node::isNondeterministic() const {
   static const OperatorSet nondeterministic_ops = {
     "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-    "aten::_fused_dropout(Tensor self, float p, Generator generator) -> (Tensor, Tensor)",
-    "aten::_standard_gamma(Tensor self, Generator generator) -> Tensor",
-    "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor",
-    "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
-    "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator generator) -> Tensor",
-    "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
-    "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
-    "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
-    "aten::poisson(Tensor self, Generator generator) -> Tensor",
-    "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
-    "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
+    "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
+    "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
+    "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
+    "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
+    "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
+    "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
+    "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
+    "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
+    "aten::poisson(Tensor self, Generator? generator) -> Tensor",
+    "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
+    "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
     "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor",
     "aten::rand_like(Tensor self) -> Tensor",
     "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
@@ -1344,12 +1344,6 @@ Node* Graph::createNone(TypePtr typ) {
   return n;
 }
 
-Node * Graph::createNoneGenerator() {
-  auto n = create(prim::NoneGenerator);
-  n->output()->setType(GeneratorType::get());
-  return n;
-}
-
 Node * Graph::createFusionGroup() {
   auto n = create(prim::FusionGroup, 0);
   n->g_(attr::Subgraph,std::make_shared<Graph>(current_scope()));
index 251033d..5a8f4f0 100644 (file)
@@ -858,7 +858,6 @@ public:
 
   TORCH_API Node* createNone(TypePtr typ); // value of None with type Optional[typ]
   TORCH_API Node* createUndefined();
-  TORCH_API Node* createNoneGenerator();
   TORCH_API Node* createFusionGroup();
   TORCH_API Node* createDifferentiableSubgraph();
   TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
index 2446759..6d66dc2 100644 (file)
@@ -18,7 +18,6 @@ std::unordered_set<Symbol> skip_list = {
   prim::Loop, //TODO: handle Loop
   prim::Constant,
   prim::Undefined,
-  prim::NoneGenerator,
   prim::None, // it is already a constant and propagating it will lose
               // important type information about which Optional type it is
   // TODO (zach): we should consider skipping tensor factories in the cases
index 638ff95..1f28cc0 100644 (file)
@@ -237,7 +237,6 @@ struct PythonPrintPass {
   bool isConstantLike(Node* n) {
     switch(n->kind()) {
       case prim::Constant:
-      case prim::NoneGenerator:
       case prim::Undefined:
       case prim::None:
         return true;
@@ -679,7 +678,6 @@ struct PythonPrintPass {
         IValue v = toIValue(node->output()).value();
         printConstant(stmt, v);
       } break;
-      case prim::NoneGenerator:
       case prim::Undefined:
       case prim::None: {
         if (node->output()->type()->isSubtypeOf(NoneType::get())) {
@@ -1004,7 +1002,6 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
     prim::ListConstruct,
     prim::ListUnpack,
     prim::None,
-    prim::NoneGenerator,
     prim::Print,
     prim::PythonOp,
     prim::TupleConstruct,
index 69e3170..e4f321d 100644 (file)
@@ -546,13 +546,13 @@ class ShapePropagator {
             "aten::ceil(Tensor self) -> Tensor",
             "aten::clone(Tensor self) -> Tensor",
             "aten::contiguous(Tensor self) -> Tensor",
-            "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor",
+            "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
             "aten::celu(Tensor self, Scalar alpha) -> Tensor",
             "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
             "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
             "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
             "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
-            "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
+            "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
             "aten::cos(Tensor self) -> Tensor",
             "aten::cosh(Tensor self) -> Tensor",
             "aten::digamma(Tensor self) -> Tensor",
@@ -581,15 +581,15 @@ class ShapePropagator {
             "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
             "aten::lgamma(Tensor self) -> Tensor",
             "aten::mvlgamma(Tensor self, int p) -> Tensor",
-            "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
-            "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
+            "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
+            "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
             "aten::permute(Tensor self, int[] dims) -> Tensor",
             "aten::pin_memory(Tensor self) -> Tensor",
             "aten::pinverse(Tensor self, float rcond) -> Tensor",
             "aten::reciprocal(Tensor self) -> Tensor",
             "aten::relu(Tensor self) -> Tensor",
             "aten::round(Tensor self) -> Tensor",
-            "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
+            "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
             "aten::rsqrt(Tensor self) -> Tensor",
             "aten::selu(Tensor self) -> Tensor",
             "aten::sigmoid(Tensor self) -> Tensor",
@@ -723,7 +723,7 @@ class ShapePropagator {
     //   tensor outputs : 1
     static const register_formula_for binary_ops_strict_match{
         {
-            "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
+            "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
             "aten::mm(Tensor self, Tensor mat2) -> Tensor",
             "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
         },
index 6c9e549..e9ebd1f 100644 (file)
@@ -260,14 +260,6 @@ RegisterOperators reg({
         };
       }),
     Operator(
-        "prim::NoneGenerator() -> Generator",
-        [](const Node* node) {
-          return [](Stack& stack) {
-            stack.emplace_back();
-            return 0;
-          };
-        }),
-    Operator(
         prim::Print,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
index e27f5ea..4ec2651 100644 (file)
@@ -83,9 +83,7 @@ Value* tryConvertToType(
   }
 
   if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
-    if (concrete_type->isSubtypeOf(GeneratorType::get())) {
-      value = graph.insertNode(graph.createNoneGenerator())->output();
-    } else if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
+    if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
       // create undefined tensor when None pass to a optional[tensor] formal arg
       value = graph.insertNode(graph.createUndefined())->output();
     } else if (auto optional_type = concrete_type->cast<OptionalType>()) {
index 874fb4d..ae358be 100644 (file)
@@ -107,7 +107,7 @@ void addInputs(Node *n, const char * name, at::Generator * value)            {
     detail::badArgType(value);
   }
   Graph * g = n->owningGraph();
-  Value * undef_gen = g->insertNode(g->createNoneGenerator())->output();
+  Value * undef_gen = g->insertNode(g->createNone(GeneratorType::get()))->output();
   n->addInput(undef_gen);
 }
 void addInputs(Node *n, const char * name, at::Device value) {