Add functions to extract the basic symbols on which a composite name relies. This...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:37:50 +0000 (15:37 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:40:25 +0000 (15:40 -0700)
PiperOrigin-RevId: 191809965

tensorflow/contrib/autograph/pyct/qual_names.py
tensorflow/contrib/autograph/pyct/qual_names_test.py

index 4d5764a..583cf7e 100644 (file)
@@ -112,6 +112,29 @@ class QN(object):
       raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
     return self._parent
 
+  @property
+  def support_set(self):
+    """Returns the set of simple symbols that this QN relies on.
+
+    This would be the smallest set of symbols necessary for the QN to
+    statically resolve (assuming properties and index ranges are verified
+    at runtime).
+
+    Examples:
+      'a.b' has only one support symbol, 'a'
+      'a[i]' has two roots, 'a' and 'i'
+    """
+    # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
+    roots = set()
+    if self.has_attr():
+      roots.update(self.parent.support_set)
+    elif self.has_subscript():
+      roots.update(self.parent.support_set)
+      roots.update(self.qn[1].support_set)
+    else:
+      roots.add(self)
+    return roots
+
   def __hash__(self):
     return hash(self.qn + (self._has_attr, self._has_subscript))
 
index 103bd25..264afd5 100644 (file)
@@ -154,6 +154,21 @@ class QNTest(test.TestCase):
     a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3)))
     self.assertEqual(a_sub_three.ast().slice.value.n, 3)
 
+  def test_support_set(self):
+    a = QN('a')
+    b = QN('b')
+    c = QN('c')
+    a_sub_b = QN(a, subscript=b)
+    a_dot_b = QN(a, attr='b')
+    a_dot_b_dot_c = QN(a_dot_b, attr='c')
+    a_dot_b_sub_c = QN(a_dot_b, subscript=c)
+
+    self.assertSetEqual(a.support_set, set((a,)))
+    self.assertSetEqual(a_sub_b.support_set, set((a, b)))
+    self.assertSetEqual(a_dot_b.support_set, set((a,)))
+    self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,)))
+    self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c)))
+
 
 class QNResolverTest(test.TestCase):