Add basic class support:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 18 Jan 2018 20:50:29 +0000 (12:50 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 18 Jan 2018 20:54:13 +0000 (12:54 -0800)
 * frontend API extended to allow converting an entire class
 * function conversion now supports object methods
 * responsibility for renaming the top level object is now moved up from
the call tree transformer - this is because functions may or may not be renamed based on context (e.g. object members are not renamed)
Not yet implemented:
 * static analysis and side effect protection for object fields

PiperOrigin-RevId: 182423203

12 files changed:
tensorflow/contrib/py2tf/api.py
tensorflow/contrib/py2tf/conversion.py
tensorflow/contrib/py2tf/convert/builtin_functions_test.py
tensorflow/contrib/py2tf/convert/call_trees.py
tensorflow/contrib/py2tf/convert/call_trees_test.py
tensorflow/contrib/py2tf/convert/control_flow_test.py
tensorflow/contrib/py2tf/convert/print_functions_test.py
tensorflow/contrib/py2tf/convert/side_effect_guards_test.py
tensorflow/contrib/py2tf/naming.py
tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py
tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py

index 296a084a9f63e6a3df1b8c2b347d66322988e3d2..3a367209694d3210913e515ece62ad1f9e3fc3ed 100644 (file)
@@ -25,22 +25,33 @@ from tensorflow.contrib.py2tf import config
 from tensorflow.contrib.py2tf import conversion
 from tensorflow.contrib.py2tf.pyct import compiler
 from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.util import tf_inspect
 
+# TODO(mdan): Properly document the type hints.
+# TODO(mdan): Reduce the type hint information to (module, type).
+# (currently we require (module + class name, type))
 
-def to_graph(f, arg_value_hints=None):
-  """Compile a Python function into equivalent TensorFlow code.
+
+def to_graph(o, arg_value_hints=None):
+  """Compile a Python entity into equivalent TensorFlow code.
+
+  Currently supported entities:
+    * functions
+    * classes
+
+  Classes are handled by converting all their methods into a new class.
 
   Args:
-    f: A Python function with arbitrary arguments and return values.
+    o: A Python function or class.
     arg_value_hints: A dict mapping parameter names to objects that can hint
         at the type of those parameters.
 
   Returns:
-    A function with a signature identical to `f`, but which when executed it
-  creates TF a graph that has the same functionality as the original function.
+    A function with a signature identical to `o`, but which when executed it
+  creates TF a graph that has the same functionality as the original entity.
   """
   conversion_map = conversion.ConversionMap()
-  _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)
+  _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)
 
   module = gast.Module([])
   for import_line in config.COMPILED_IMPORT_STATEMENTS:
@@ -51,17 +62,20 @@ def to_graph(f, arg_value_hints=None):
 
   # The compiled code should see everything the entry function saw.
   # TODO(mdan): This might not work well if the call tree spans modules?
-  compiled_node.__dict__.update(six.get_function_globals(f))
+  if tf_inspect.isfunction(o):
+    compiled_node.__dict__.update(six.get_function_globals(o))
 
   compiled_fn = getattr(compiled_node, name)
   return compiled_fn
 
 
-def to_code(f, arg_value_hints=None, indentation='  '):
-  """Return the equivalent of a function in TensorFlow code.
+def to_code(o, arg_value_hints=None, indentation='  '):
+  """Return the equivalent of an entity in TensorFlow code.
+
+  See `to_graph` for more details.
 
   Args:
-    f: A Python function with arbitrary arguments and return values.
+    o: A Python function or class.
     arg_value_hints: A dict mapping parameter names to objects that can hint
         at the type of those parameters.
     indentation: String, when to use for each level of indentation.
@@ -70,7 +84,7 @@ def to_code(f, arg_value_hints=None, indentation='  '):
     String.
   """
   conversion_map = conversion.ConversionMap()
-  conversion.object_to_graph(f, conversion_map, arg_value_hints)
+  conversion.object_to_graph(o, conversion_map, arg_value_hints)
 
   imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
   code = '\n'.join(
index c50db4ec982ab6d6d5b856d015fba1ee459f9062..43bccae9538c4c68867764a9e433cac81bb98e78 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gast
 import six
 
 from tensorflow.contrib.py2tf import config
@@ -35,6 +36,7 @@ from tensorflow.contrib.py2tf.pyct import parser
 from tensorflow.contrib.py2tf.pyct.static_analysis import access
 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.util import tf_inspect
 
 
 class ConversionMap(object):
@@ -93,13 +95,61 @@ def object_to_graph(o, conversion_map, value_hints):
   Raises:
     ValueError: if the object is not supported.
   """
-  if callable(o):
-    return function_to_graph(o, conversion_map, value_hints)
-  raise ValueError(
-      'Unsupported object type %s. Only functions are supported for now.')
+  if value_hints is None:
+    value_hints = {}
+
+  if tf_inspect.isclass(o):
+    node, new_name = class_to_graph(o, conversion_map, value_hints)
+  elif tf_inspect.isfunction(o):
+    node, new_name = function_to_graph(o, conversion_map, value_hints)
+  else:
+    raise ValueError(
+        'Unsupported object type %s. Only functions and classes are supported'
+        ' for now.')
+
+  conversion_map.add_to_cache(o, node)
+  # Recursively convert remaining dependencies.
+  for obj in conversion_map.name_map.keys():
+    if obj not in conversion_map.dependency_cache:
+      if hasattr(obj, 'im_class'):
+        # Class members are converted with their objects.
+        continue
+      object_to_graph(obj, conversion_map, None)
+
+  return node, new_name
+
+
+def class_to_graph(c, conversion_map, param_value_hints):
+  """Specialization of `object_to_graph` for classes."""
+  converted_members = {}
+  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
+  if not members:
+    raise ValueError('Cannot convert %s: it has no member methods.')
+
+  if 'self' in param_value_hints:
+    raise ValueError('Hints may not be provided for reserved name "self".')
+  param_value_hints['self'] = (c.__name__, c)
 
+  class_globals = None
+  for _, m in members:
+    node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
+    # TODO(mdan): Do not assume all members have the same view of globals.
+    if class_globals is None:
+      class_globals = six.get_function_globals(m)
+    converted_members[m] = node
+  namer = conversion_map.new_namer(class_globals)
+  class_name = namer.compiled_class_name(c.__name__, c)
+  node = gast.ClassDef(
+      class_name,
+      bases=[],
+      keywords=[],
+      body=converted_members.values(),
+      decorator_list=[])
 
-def function_to_graph(f, conversion_map, param_value_hints):
+  return node, class_name
+
+
+def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
   """Specialization of `object_to_graph` for callable functions."""
   node = parser.parse_object(f).body[0]
   node_globals = six.get_function_globals(f)
@@ -118,15 +168,12 @@ def function_to_graph(f, conversion_map, param_value_hints):
   # Simulate a rename to ensure the top level is in the name map. This is needed
   # for top level functions, and it also helps the consistency verification made
   # by update_name_map.
-  namer.compiled_function_name(f.__name__, f)
-
-  conversion_map.add_to_cache(f, node)
+  if owner_type is not None:
+    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
+  else:
+    new_name = namer.compiled_function_name(f.__name__, f)
+  node.name = new_name
   conversion_map.update_name_map(namer)
-
-  # Recursively convert any remaining dependencies.
-  for obj in conversion_map.name_map.keys():
-    if obj not in conversion_map.dependency_cache:
-      object_to_graph(obj, conversion_map, None)
   return node, conversion_map.name_map[f]
 
 
index 9a6517321c810e289fa6e0101eb6e16b44c8fc93..633602f4d49792c45826afd8646593e280e35d12 100644 (file)
@@ -35,7 +35,7 @@ class BuiltinFunctionsTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, namespace, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
     return node
 
   def test_len(self):
index 5e6c8247bd85fad7b3c366ab1b3573af1ca91f2b..92c3439101ed9d3fe54147346be3cd6a1c0f9d8c 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Handles function calls, by generating compiled function names and calls."""
+"""Handles function calls, by generating compiled function names and calls.
+
+Note: this transformer does not rename the top level object being converted;
+that is the caller's responsibility.
+"""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -29,12 +33,28 @@ from tensorflow.contrib.py2tf.pyct import templates
 class FunctionNamer(object):
   """Describes the interface for CallTreeTransformer's namer."""
 
-  def compiled_function_name(self, original_name, live_object=None):
+  def compiled_function_name(self,
+                             original_name,
+                             live_object=None,
+                             owner_type=None):
     """Generate the name corresponding to the compiled version of a function.
 
     Args:
       original_name: String
       live_object: Callable, the actual target function, if known.
+      owner_type: Optional object. If present, it indicates that the function is
+          a member of the given type.
+    Returns:
+      String.
+    """
+    raise NotImplementedError()
+
+  def compiled_class_name(self, original_name, live_object=None):
+    """Generate the name corresponding to the compiled version of a class.
+
+    Args:
+      original_name: String
+      live_object: The actual target class, if known.
     Returns:
       String.
     """
@@ -50,11 +70,6 @@ class CallTreeTransformer(gast.NodeTransformer):
 
   # pylint:disable=invalid-name
 
-  def visit_FunctionDef(self, node):
-    self.generic_visit(node)
-    node.name = self.namer.compiled_function_name(node.name)
-    return node
-
   def _should_compile(self, fqn):
     for i in range(1, len(fqn)):
       if fqn[:i] in self.uncompiled_modules:
@@ -70,17 +85,32 @@ class CallTreeTransformer(gast.NodeTransformer):
     if not self._should_compile(target_fqn):
       return node
 
-    new_name = self.namer.compiled_function_name(
-        '.'.join(target_fqn), live_object=target_obj)
+    if anno.hasanno(node, 'is_constructor'):
+      new_name = self.namer.compiled_class_name(
+          '.'.join(target_fqn), live_object=target_obj)
+    else:
+      new_name = self.namer.compiled_function_name(
+          '.'.join(target_fqn), live_object=target_obj)
     node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
     return node
 
   def _rename_member_function_of_known_type(self, node):
-    target_fqn = anno.getanno(node.func, 'type_fqn')
-    if not self._should_compile(target_fqn):
+    assert isinstance(node.func, gast.Attribute)
+
+    type_fqn = anno.getanno(node.func, 'type_fqn')
+    assert anno.hasanno(node.func, 'type')
+    target_type = anno.getanno(node.func, 'type')
+
+    if not self._should_compile(type_fqn):
       return node
 
-    raise NotImplementedError('Member function call (of known type).')
+    # TODO(mdan): We should not assume that the namer only needs the
+    # member function name.
+    new_name = self.namer.compiled_function_name(
+        node.func.attr, live_object=None, owner_type=target_type)
+    node.func.attr = new_name
+
+    return node
 
   def _wrap_to_py_func_no_return(self, node):
     args_scope = anno.getanno(node, 'args_scope')
index 302fad5840dc06f8c61c996a2375e8f0df91297b..38c701eaadee8ad4df006a950192d51d78c799fe 100644 (file)
@@ -41,7 +41,7 @@ class CallTreesTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, namespace, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
     return node
 
   def test_basic(self):
@@ -61,7 +61,7 @@ class CallTreesTest(test.TestCase):
     # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
     setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
 
-    self.assertEquals(3, result.renamed_test_fn_2(1))
+    self.assertEquals(3, result.test_fn_2(1))
 
   def test_uncompiled_modules(self):
 
@@ -82,7 +82,9 @@ class CallTreesTest(test.TestCase):
     setattr(result, 'constant_op', constant_op)
 
     with self.test_session() as sess:
-      result_tensor = result.renamed_test_fn(constant_op.constant(1))
+      # Not renamed, because the converter doesn't rename the definition itself.
+      # (the caller is responsible for that).
+      result_tensor = result.test_fn(constant_op.constant(1))
       result_val = sess.run(result_tensor)
 
     self.assertEquals(3, result_val)
index 51237a291d30da93115be25fd85e6f829646f389..121af4ee949152cb6df7496a4a0c64f13f65a5eb 100644 (file)
@@ -46,7 +46,7 @@ class ControlFlowTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, namespace, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
     return node
 
   def test_simple_while(self):
index f8fee878495352f91f0825896348a5f33082a5aa..65e592b66e9d0c08c7d2127ff40be8a0dc28ec6c 100644 (file)
@@ -35,7 +35,7 @@ class PrintFunctionsTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, namespace, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
     return node
 
   def test_transform(self):
index e8888ab1924cb11ba74eb84109cae79c7b7173a0..d932840186034c073512cbd1e253fc7676aa83e7 100644 (file)
@@ -43,7 +43,7 @@ class SideEffectGuardsTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, namespace, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
     return node
 
   def test_transform(self):
index 7a03c8282dfa1d6cd396e60895d11b8a11e6a245..61772ec07b41d366769307982bf0376de9bb495e 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.util import tf_inspect
+
 
 class Namer(object):
   """Implementation of the namer interfaces required by various converters.
@@ -41,23 +43,49 @@ class Namer(object):
 
     self.generated_names = set()
 
-  def compiled_function_name(self, original_name, live_object=None):
-    """See call_trees.FunctionNamer.compiled_function_name."""
+  def compiled_class_name(self, original_name, live_object=None):
+    """See call_trees.FunctionNamer.compiled_class_name."""
     if live_object is not None and live_object in self.renamed_calls:
       return self.renamed_calls[live_object]
 
-    new_name_root = 'tf__%s' % original_name
-
+    new_name_root = 'Tf%s' % original_name
     new_name = new_name_root
     n = 0
     while new_name in self.global_namespace:
       n += 1
       new_name = '%s_%d' % (new_name_root, n)
-
     if live_object is not None:
       self.renamed_calls[live_object] = new_name
     self.generated_names.add(new_name)
+    return new_name
 
+  def compiled_function_name(self,
+                             original_name,
+                             live_object=None,
+                             owner_type=None):
+    """See call_trees.FunctionNamer.compiled_function_name."""
+    if live_object is not None and live_object in self.renamed_calls:
+      return self.renamed_calls[live_object]
+
+    if owner_type is None:
+      # Top level functions: rename
+      new_name_root = 'tf__%s' % original_name
+      new_name = new_name_root
+      n = 0
+      while new_name in self.global_namespace:
+        n += 1
+        new_name = '%s_%d' % (new_name_root, n)
+    else:
+      if tf_inspect.isclass(owner_type):
+        # Class members: do not rename (the entire class will be renamed)
+        new_name = original_name
+      else:
+        raise NotImplementedError('Member function "%s" of non-class type: %s' %
+                                  (original_name, owner_type))
+
+    if live_object is not None:
+      self.renamed_calls[live_object] = new_name
+    self.generated_names.add(new_name)
     return new_name
 
   def new_symbol(self, name_root, reserved_locals):
index f5542a8405090892d5fb93982e4455ad61715a2b..242e544b5286c683ee4aa97bc586751932c73815 100644 (file)
@@ -43,6 +43,11 @@ class LiveValueResolver(gast.NodeTransformer):
     self.namespace = namespace
     self.literals = literals
 
+  def visit_ClassDef(self, node):
+    self.generic_visit(node)
+    anno.setanno(node, 'live_val', self.namespace[node.name])
+    return node
+
   def visit_Name(self, node):
     self.generic_visit(node)
     if isinstance(node.ctx, gast.Load):
@@ -74,8 +79,13 @@ class LiveValueResolver(gast.NodeTransformer):
                                                          node.attr))
       anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
       anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
-    # TODO(mdan): Figure out what to do when calling attribute on local object.
-    # Maybe just leave as-is?
+    elif isinstance(node.value, gast.Name):
+      stem_name = node.value
+      # All nonlocal symbols should be fully resolved.
+      assert anno.hasanno(stem_name, 'is_local'), stem_name
+      assert anno.getanno(stem_name, 'is_local'), stem_name
+      # TODO(mdan): Figure out what to do when calling attribute on local object
+      # Maybe just leave as-is?
     return node
 
 
index c6db4fcbb10cc928cdd2862cc565c4b0b3ef0239..3e545903261a41cac4dc9ac0e23f857e0be41f96 100644 (file)
@@ -120,8 +120,9 @@ class TypeInfoResolver(gast.NodeTransformer):
     self.generic_visit(node)
     if isinstance(node.ctx, gast.Param):
       self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None))
-      if (self.function_level == 1 and self.value_hints is not None and
-          node.id in self.value_hints):
+      # TODO(mdan): Member functions should not need type hints.
+      # We could attemp to extract im_class from the live_val annotation.
+      if self.function_level == 1 and node.id in self.value_hints:
         # Forge a node to hold the type information, so that method calls on
         # it can resolve the type.
         type_holder = gast.Name(node.id, gast.Load(), None)
@@ -137,7 +138,7 @@ class TypeInfoResolver(gast.NodeTransformer):
       if anno.hasanno(func, 'live_val'):
         func_obj = anno.getanno(func, 'live_val')
         if tf_inspect.isclass(func_obj):
-          # This is then a constructor.
+          anno.setanno(source, 'is_constructor', True)
           anno.setanno(source, 'type', func_obj)
           anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
           # TODO(mdan): Raise an error if constructor has side effects.
@@ -150,8 +151,15 @@ class TypeInfoResolver(gast.NodeTransformer):
           self.scope.setval(e.id,
                             gast.Subscript(
                                 source, gast.Index(i), ctx=gast.Store()))
-      else:
+      elif isinstance(t, gast.Name):
         self.scope.setval(t.id, source)
+      elif isinstance(t, gast.Attribute):
+        if not (isinstance(t.value, gast.Name) and t.value.id == 'self'):
+          raise ValueError(
+              'Dont know how to handle assignment to attributes of objects'
+              ' other than "self": [%s].%s' % (t.value, t.attr))
+      else:
+        raise ValueError('Dont know how to handle assignment to %s' % t)
 
   def visit_With(self, node):
     for wi in node.items:
@@ -187,6 +195,7 @@ class TypeInfoResolver(gast.NodeTransformer):
         if not anno.hasanno(object_source, 'type'):
           raise ValueError('Could not determine type of "%s". Is it dynamic?' %
                            (target.value.id))
+        anno.setanno(target, 'type', anno.getanno(object_source, 'type'))
         anno.setanno(target, 'type_fqn', anno.getanno(object_source,
                                                       'type_fqn'))
       else:
@@ -198,4 +207,5 @@ class TypeInfoResolver(gast.NodeTransformer):
 
 
 def resolve(node, value_hints):
+  assert value_hints is not None
   return TypeInfoResolver(value_hints).visit(node)
index e9deaa085d44a69b0d2ff37b3c3977851106c510..8526f42413b9cca077da45195249615b55c45bc9 100644 (file)
@@ -63,7 +63,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, {'training': training}, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
 
     call_node = node.body[0].body[0].value
     self.assertEquals(training.GradientDescentOptimizer,
@@ -80,7 +80,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, {'training': training}, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
 
     attr_call_node = node.body[0].body[1].value.func
     self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
@@ -95,7 +95,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = parser.parse_object(test_fn)
     node = access.resolve(node)
     node = live_values.resolve(node, {'session': session}, {})
-    node = type_info.resolve(node, None)
+    node = type_info.resolve(node, {})
 
     constructor_call = node.body[0].body[0].items[0].context_expr
     self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
@@ -119,7 +119,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = access.resolve(node)
     node = live_values.resolve(node, {'training': training}, {})
     with self.assertRaises(ValueError):
-      node = type_info.resolve(node, None)
+      node = type_info.resolve(node, {})
 
   def test_parameter_class_members(self):
 
@@ -130,7 +130,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = access.resolve(node)
     node = live_values.resolve(node, {'training': training}, {})
     with self.assertRaises(ValueError):
-      node = type_info.resolve(node, None)
+      node = type_info.resolve(node, {})
 
   def test_parameter_class_members_with_value_hints(self):
 
@@ -164,7 +164,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = access.resolve(node)
     node = live_values.resolve(node, {'bar': bar}, {})
     with self.assertRaises(ValueError):
-      node = type_info.resolve(node, None)
+      node = type_info.resolve(node, {})
 
   def test_nested_members(self):
 
@@ -176,7 +176,7 @@ class TypeInfoResolverTest(test.TestCase):
     node = access.resolve(node)
     node = live_values.resolve(node, {'training': training}, {})
     with self.assertRaises(ValueError):
-      node = type_info.resolve(node, None)
+      node = type_info.resolve(node, {})
 
 
 if __name__ == '__main__':