From a0c80b9a54dc9669c0f5d151bee9f0b3a4fd71a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 5 Apr 2018 09:57:58 -0700 Subject: [PATCH] Expand activity analysis to the test nodes of if and while statements. PiperOrigin-RevId: 191756234 --- .../contrib/autograph/pyct/static_analysis/activity.py | 18 +++++++++++++++--- .../autograph/pyct/static_analysis/activity_test.py | 2 ++ .../contrib/autograph/pyct/static_analysis/annos.py | 1 + 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index da6a2f6..6dd5309 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -265,10 +265,10 @@ class ActivityAnalizer(transformer.Base): qn = QN(node.name) self.scope.mark_write(qn) current_scope = self.scope - fndef_scope = Scope(current_scope, isolated=True) - self.scope = fndef_scope + body_scope = Scope(current_scope, isolated=True) + self.scope = body_scope self.generic_visit(node) - anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope) + anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope) self.scope = current_scope return node @@ -282,7 +282,13 @@ class ActivityAnalizer(transformer.Base): return node def visit_If(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) @@ -297,7 +303,13 @@ class ActivityAnalizer(transformer.Base): return node def visit_While(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index 37c2887..1e6c686 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -204,6 +204,8 @@ class ActivityAnalizerTest(test.TestCase): self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), ('b', 'c'), ('a', 'b', 'c')) + self.assertScopeIsRmc( + anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ()) def test_for(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py index 5254b83..d6d9f7e 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -43,6 +43,7 @@ class NodeAnno(NoValue): # Scopes # Scopes are represented by objects of type activity.Scope. ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' BODY_SCOPE = ( 'The scope for the main body of a statement (True branch for if ' 'statements, main body for loops).') -- 2.7.4