From 26cb7de9c03a9d73703decec8c917651369ee9ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 26 Feb 2018 14:25:37 -0800 Subject: [PATCH] Add a function that allows to dynamically verify whether a function is white listed for graph mode. PiperOrigin-RevId: 187080654 --- tensorflow/contrib/py2tf/impl/conversion.py | 18 ++++++++++++++++++ tensorflow/contrib/py2tf/impl/conversion_test.py | 11 +++++++++++ 2 files changed, 29 insertions(+) diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py index 044de33..d95469e 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/py2tf/impl/conversion.py @@ -97,6 +97,24 @@ class ConversionMap(object): self.dependency_cache[original_entity] = converted_ast +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + return False + + def entity_to_graph(o, conversion_map, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. diff --git a/tensorflow/contrib/py2tf/impl/conversion_test.py b/tensorflow/contrib/py2tf/impl/conversion_test.py index 7816f95..9ff256a 100644 --- a/tensorflow/contrib/py2tf/impl/conversion_test.py +++ b/tensorflow/contrib/py2tf/impl/conversion_test.py @@ -20,12 +20,23 @@ from __future__ import print_function import gast +from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.impl import conversion +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ConversionTest(test.TestCase): + def test_is_whitelisted_for_graph(self): + + def test_fn(): + return constant_op.constant(1) + + self.assertFalse(conversion.is_whitelisted_for_graph(test_fn)) + self.assertTrue(conversion.is_whitelisted_for_graph(utils)) + self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) + def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): conversion_map = conversion.ConversionMap(True, (), (), None) -- 2.7.4