Expand activity analysis to the test nodes of if and while statements.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 16:57:58 +0000 (09:57 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 17:07:37 +0000 (10:07 -0700)
PiperOrigin-RevId: 191756234

tensorflow/contrib/autograph/pyct/static_analysis/activity.py
tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
tensorflow/contrib/autograph/pyct/static_analysis/annos.py

index da6a2f6..6dd5309 100644 (file)
@@ -265,10 +265,10 @@ class ActivityAnalizer(transformer.Base):
       qn = QN(node.name)
       self.scope.mark_write(qn)
     current_scope = self.scope
-    fndef_scope = Scope(current_scope, isolated=True)
-    self.scope = fndef_scope
+    body_scope = Scope(current_scope, isolated=True)
+    self.scope = body_scope
     self.generic_visit(node)
-    anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope)
+    anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
     self.scope = current_scope
     return node
 
@@ -282,7 +282,13 @@ class ActivityAnalizer(transformer.Base):
     return node
 
   def visit_If(self, node):
+    current_scope = self.scope
+    cond_scope = Scope(current_scope, isolated=False)
+    self.scope = cond_scope
     self.visit(node.test)
+    anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+    self.scope = current_scope
+
     node = self._process_parallel_blocks(node,
                                          ((node.body, NodeAnno.BODY_SCOPE),
                                           (node.orelse, NodeAnno.ORELSE_SCOPE)))
@@ -297,7 +303,13 @@ class ActivityAnalizer(transformer.Base):
     return node
 
   def visit_While(self, node):
+    current_scope = self.scope
+    cond_scope = Scope(current_scope, isolated=False)
+    self.scope = cond_scope
     self.visit(node.test)
+    anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+    self.scope = current_scope
+
     node = self._process_parallel_blocks(node,
                                          ((node.body, NodeAnno.BODY_SCOPE),
                                           (node.orelse, NodeAnno.ORELSE_SCOPE)))
index 37c2887..1e6c686 100644 (file)
@@ -204,6 +204,8 @@ class ActivityAnalizerTest(test.TestCase):
     self.assertScopeIsRmc(
         anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
         ('b', 'c'), ('a', 'b', 'c'))
+    self.assertScopeIsRmc(
+        anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
 
   def test_for(self):
 
index 5254b83..d6d9f7e 100644 (file)
@@ -43,6 +43,7 @@ class NodeAnno(NoValue):
   # Scopes
   # Scopes are represented by objects of type activity.Scope.
   ARGS_SCOPE = 'The scope for the argument list of a function call.'
+  COND_SCOPE = 'The scope for the test node of a conditional statement.'
   BODY_SCOPE = (
       'The scope for the main body of a statement (True branch for if '
       'statements, main body for loops).')