from tensorflow.contrib.py2tf import conversion
from tensorflow.contrib.py2tf.pyct import compiler
from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.util import tf_inspect
+# TODO(mdan): Properly document the type hints.
+# TODO(mdan): Reduce the type hint information to (module, type).
+# (currently we require (module + class name, type))
-def to_graph(f, arg_value_hints=None):
- """Compile a Python function into equivalent TensorFlow code.
+
+def to_graph(o, arg_value_hints=None):
+ """Compile a Python entity into equivalent TensorFlow code.
+
+ Currently supported entities:
+ * functions
+ * classes
+
+ Classes are handled by converting all their methods into a new class.
Args:
- f: A Python function with arbitrary arguments and return values.
+ o: A Python function or class.
arg_value_hints: A dict mapping parameter names to objects that can hint
at the type of those parameters.
Returns:
- A function with a signature identical to `f`, but which when executed it
- creates TF a graph that has the same functionality as the original function.
+ A function with a signature identical to `o`, but which when executed it
+ creates TF a graph that has the same functionality as the original entity.
"""
conversion_map = conversion.ConversionMap()
- _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)
+ _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)
module = gast.Module([])
for import_line in config.COMPILED_IMPORT_STATEMENTS:
# The compiled code should see everything the entry function saw.
# TODO(mdan): This might not work well if the call tree spans modules?
- compiled_node.__dict__.update(six.get_function_globals(f))
+ if tf_inspect.isfunction(o):
+ compiled_node.__dict__.update(six.get_function_globals(o))
compiled_fn = getattr(compiled_node, name)
return compiled_fn
-def to_code(f, arg_value_hints=None, indentation=' '):
- """Return the equivalent of a function in TensorFlow code.
+def to_code(o, arg_value_hints=None, indentation=' '):
+ """Return the equivalent of an entity in TensorFlow code.
+
+ See `to_graph` for more details.
Args:
- f: A Python function with arbitrary arguments and return values.
+ o: A Python function or class.
arg_value_hints: A dict mapping parameter names to objects that can hint
at the type of those parameters.
indentation: String, when to use for each level of indentation.
String.
"""
conversion_map = conversion.ConversionMap()
- conversion.object_to_graph(f, conversion_map, arg_value_hints)
+ conversion.object_to_graph(o, conversion_map, arg_value_hints)
imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
code = '\n'.join(
from __future__ import division
from __future__ import print_function
+import gast
import six
from tensorflow.contrib.py2tf import config
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.util import tf_inspect
class ConversionMap(object):
Raises:
ValueError: if the object is not supported.
"""
- if callable(o):
- return function_to_graph(o, conversion_map, value_hints)
- raise ValueError(
- 'Unsupported object type %s. Only functions are supported for now.')
+ if value_hints is None:
+ value_hints = {}
+
+ if tf_inspect.isclass(o):
+ node, new_name = class_to_graph(o, conversion_map, value_hints)
+ elif tf_inspect.isfunction(o):
+ node, new_name = function_to_graph(o, conversion_map, value_hints)
+ else:
+ raise ValueError(
+ 'Unsupported object type %s. Only functions and classes are supported'
+ ' for now.')
+
+ conversion_map.add_to_cache(o, node)
+ # Recursively convert remaining dependencies.
+ for obj in conversion_map.name_map.keys():
+ if obj not in conversion_map.dependency_cache:
+ if hasattr(obj, 'im_class'):
+ # Class members are converted with their objects.
+ continue
+ object_to_graph(obj, conversion_map, None)
+
+ return node, new_name
+
+
+def class_to_graph(c, conversion_map, param_value_hints):
+ """Specialization of `object_to_graph` for classes."""
+ converted_members = {}
+ members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
+ if not members:
+ raise ValueError('Cannot convert %s: it has no member methods.')
+
+ if 'self' in param_value_hints:
+ raise ValueError('Hints may not be provided for reserved name "self".')
+ param_value_hints['self'] = (c.__name__, c)
+ class_globals = None
+ for _, m in members:
+ node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
+ # TODO(mdan): Do not assume all members have the same view of globals.
+ if class_globals is None:
+ class_globals = six.get_function_globals(m)
+ converted_members[m] = node
+ namer = conversion_map.new_namer(class_globals)
+ class_name = namer.compiled_class_name(c.__name__, c)
+ node = gast.ClassDef(
+ class_name,
+ bases=[],
+ keywords=[],
+ body=converted_members.values(),
+ decorator_list=[])
-def function_to_graph(f, conversion_map, param_value_hints):
+ return node, class_name
+
+
+def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
"""Specialization of `object_to_graph` for callable functions."""
node = parser.parse_object(f).body[0]
node_globals = six.get_function_globals(f)
# Simulate a rename to ensure the top level is in the name map. This is needed
# for top level functions, and it also helps the consistency verification made
# by update_name_map.
- namer.compiled_function_name(f.__name__, f)
-
- conversion_map.add_to_cache(f, node)
+ if owner_type is not None:
+ new_name = namer.compiled_function_name(f.__name__, f, owner_type)
+ else:
+ new_name = namer.compiled_function_name(f.__name__, f)
+ node.name = new_name
conversion_map.update_name_map(namer)
-
- # Recursively convert any remaining dependencies.
- for obj in conversion_map.name_map.keys():
- if obj not in conversion_map.dependency_cache:
- object_to_graph(obj, conversion_map, None)
return node, conversion_map.name_map[f]
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
return node
def test_len(self):
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Handles function calls, by generating compiled function names and calls."""
+"""Handles function calls, by generating compiled function names and calls.
+
+Note: this transformer does not rename the top level object being converted;
+that is the caller's responsibility.
+"""
from __future__ import absolute_import
from __future__ import division
class FunctionNamer(object):
"""Describes the interface for CallTreeTransformer's namer."""
- def compiled_function_name(self, original_name, live_object=None):
+ def compiled_function_name(self,
+ original_name,
+ live_object=None,
+ owner_type=None):
"""Generate the name corresponding to the compiled version of a function.
Args:
original_name: String
live_object: Callable, the actual target function, if known.
+ owner_type: Optional object. If present, it indicates that the function is
+ a member of the given type.
+ Returns:
+ String.
+ """
+ raise NotImplementedError()
+
+ def compiled_class_name(self, original_name, live_object=None):
+ """Generate the name corresponding to the compiled version of a class.
+
+ Args:
+ original_name: String
+ live_object: The actual target class, if known.
Returns:
String.
"""
# pylint:disable=invalid-name
- def visit_FunctionDef(self, node):
- self.generic_visit(node)
- node.name = self.namer.compiled_function_name(node.name)
- return node
-
def _should_compile(self, fqn):
for i in range(1, len(fqn)):
if fqn[:i] in self.uncompiled_modules:
if not self._should_compile(target_fqn):
return node
- new_name = self.namer.compiled_function_name(
- '.'.join(target_fqn), live_object=target_obj)
+ if anno.hasanno(node, 'is_constructor'):
+ new_name = self.namer.compiled_class_name(
+ '.'.join(target_fqn), live_object=target_obj)
+ else:
+ new_name = self.namer.compiled_function_name(
+ '.'.join(target_fqn), live_object=target_obj)
node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
return node
def _rename_member_function_of_known_type(self, node):
- target_fqn = anno.getanno(node.func, 'type_fqn')
- if not self._should_compile(target_fqn):
+ assert isinstance(node.func, gast.Attribute)
+
+ type_fqn = anno.getanno(node.func, 'type_fqn')
+ assert anno.hasanno(node.func, 'type')
+ target_type = anno.getanno(node.func, 'type')
+
+ if not self._should_compile(type_fqn):
return node
- raise NotImplementedError('Member function call (of known type).')
+ # TODO(mdan): We should not assume that the namer only needs the
+ # member function name.
+ new_name = self.namer.compiled_function_name(
+ node.func.attr, live_object=None, owner_type=target_type)
+ node.func.attr = new_name
+
+ return node
def _wrap_to_py_func_no_return(self, node):
args_scope = anno.getanno(node, 'args_scope')
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
return node
def test_basic(self):
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
- self.assertEquals(3, result.renamed_test_fn_2(1))
+ self.assertEquals(3, result.test_fn_2(1))
def test_uncompiled_modules(self):
setattr(result, 'constant_op', constant_op)
with self.test_session() as sess:
- result_tensor = result.renamed_test_fn(constant_op.constant(1))
+ # Not renamed, because the converter doesn't rename the definition itself.
+ # (the caller is responsible for that).
+ result_tensor = result.test_fn(constant_op.constant(1))
result_val = sess.run(result_tensor)
self.assertEquals(3, result_val)
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
return node
def test_simple_while(self):
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
return node
def test_transform(self):
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
return node
def test_transform(self):
from __future__ import division
from __future__ import print_function
+from tensorflow.python.util import tf_inspect
+
class Namer(object):
"""Implementation of the namer interfaces required by various converters.
self.generated_names = set()
- def compiled_function_name(self, original_name, live_object=None):
- """See call_trees.FunctionNamer.compiled_function_name."""
+ def compiled_class_name(self, original_name, live_object=None):
+ """See call_trees.FunctionNamer.compiled_class_name."""
if live_object is not None and live_object in self.renamed_calls:
return self.renamed_calls[live_object]
- new_name_root = 'tf__%s' % original_name
-
+ new_name_root = 'Tf%s' % original_name
new_name = new_name_root
n = 0
while new_name in self.global_namespace:
n += 1
new_name = '%s_%d' % (new_name_root, n)
-
if live_object is not None:
self.renamed_calls[live_object] = new_name
self.generated_names.add(new_name)
+ return new_name
+ def compiled_function_name(self,
+ original_name,
+ live_object=None,
+ owner_type=None):
+ """See call_trees.FunctionNamer.compiled_function_name."""
+ if live_object is not None and live_object in self.renamed_calls:
+ return self.renamed_calls[live_object]
+
+ if owner_type is None:
+ # Top level functions: rename
+ new_name_root = 'tf__%s' % original_name
+ new_name = new_name_root
+ n = 0
+ while new_name in self.global_namespace:
+ n += 1
+ new_name = '%s_%d' % (new_name_root, n)
+ else:
+ if tf_inspect.isclass(owner_type):
+ # Class members: do not rename (the entire class will be renamed)
+ new_name = original_name
+ else:
+ raise NotImplementedError('Member function "%s" of non-class type: %s' %
+ (original_name, owner_type))
+
+ if live_object is not None:
+ self.renamed_calls[live_object] = new_name
+ self.generated_names.add(new_name)
return new_name
def new_symbol(self, name_root, reserved_locals):
self.namespace = namespace
self.literals = literals
+ def visit_ClassDef(self, node):
+ self.generic_visit(node)
+ anno.setanno(node, 'live_val', self.namespace[node.name])
+ return node
+
def visit_Name(self, node):
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
node.attr))
anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
- # TODO(mdan): Figure out what to do when calling attribute on local object.
- # Maybe just leave as-is?
+ elif isinstance(node.value, gast.Name):
+ stem_name = node.value
+ # All nonlocal symbols should be fully resolved.
+ assert anno.hasanno(stem_name, 'is_local'), stem_name
+ assert anno.getanno(stem_name, 'is_local'), stem_name
+ # TODO(mdan): Figure out what to do when calling attribute on local object
+ # Maybe just leave as-is?
return node
self.generic_visit(node)
if isinstance(node.ctx, gast.Param):
self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None))
- if (self.function_level == 1 and self.value_hints is not None and
- node.id in self.value_hints):
+ # TODO(mdan): Member functions should not need type hints.
+ # We could attemp to extract im_class from the live_val annotation.
+ if self.function_level == 1 and node.id in self.value_hints:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
type_holder = gast.Name(node.id, gast.Load(), None)
if anno.hasanno(func, 'live_val'):
func_obj = anno.getanno(func, 'live_val')
if tf_inspect.isclass(func_obj):
- # This is then a constructor.
+ anno.setanno(source, 'is_constructor', True)
anno.setanno(source, 'type', func_obj)
anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
# TODO(mdan): Raise an error if constructor has side effects.
self.scope.setval(e.id,
gast.Subscript(
source, gast.Index(i), ctx=gast.Store()))
- else:
+ elif isinstance(t, gast.Name):
self.scope.setval(t.id, source)
+ elif isinstance(t, gast.Attribute):
+ if not (isinstance(t.value, gast.Name) and t.value.id == 'self'):
+ raise ValueError(
+ 'Dont know how to handle assignment to attributes of objects'
+ ' other than "self": [%s].%s' % (t.value, t.attr))
+ else:
+ raise ValueError('Dont know how to handle assignment to %s' % t)
def visit_With(self, node):
for wi in node.items:
if not anno.hasanno(object_source, 'type'):
raise ValueError('Could not determine type of "%s". Is it dynamic?' %
(target.value.id))
+ anno.setanno(target, 'type', anno.getanno(object_source, 'type'))
anno.setanno(target, 'type_fqn', anno.getanno(object_source,
'type_fqn'))
else:
def resolve(node, value_hints):
+ assert value_hints is not None
return TypeInfoResolver(value_hints).visit(node)
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, {'training': training}, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
call_node = node.body[0].body[0].value
self.assertEquals(training.GradientDescentOptimizer,
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, {'training': training}, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
attr_call_node = node.body[0].body[1].value.func
self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
node = parser.parse_object(test_fn)
node = access.resolve(node)
node = live_values.resolve(node, {'session': session}, {})
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
constructor_call = node.body[0].body[0].items[0].context_expr
self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
node = access.resolve(node)
node = live_values.resolve(node, {'training': training}, {})
with self.assertRaises(ValueError):
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
def test_parameter_class_members(self):
node = access.resolve(node)
node = live_values.resolve(node, {'training': training}, {})
with self.assertRaises(ValueError):
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
def test_parameter_class_members_with_value_hints(self):
node = access.resolve(node)
node = live_values.resolve(node, {'bar': bar}, {})
with self.assertRaises(ValueError):
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
def test_nested_members(self):
node = access.resolve(node)
node = live_values.resolve(node, {'training': training}, {})
with self.assertRaises(ValueError):
- node = type_info.resolve(node, None)
+ node = type_info.resolve(node, {})
if __name__ == '__main__':