Expose Interpreter to tensorflow.contrib.lite
author    Andrew Selle <aselle@google.com>
          Thu, 3 May 2018 04:15:01 +0000 (21:15 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Thu, 3 May 2018 04:17:42 +0000 (21:17 -0700)
PiperOrigin-RevId: 195198645

tensorflow/contrib/lite/BUILD
tensorflow/contrib/lite/python/BUILD
tensorflow/contrib/lite/python/interpreter.py
tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
tensorflow/contrib/lite/python/lite.py
tensorflow/tools/pip_package/pip_smoke_test.py

index 1534f97..10065e8 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -92,6 +92,8 @@ cc_library(
     deps = [":context"],
 )
 
+exports_files(["builtin_ops.h"])
+
 cc_library(
     name = "string",
     hdrs = [
index e6dcc7a..4920e83 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -44,6 +44,7 @@ py_library(
     deps = [
         ":convert",
         ":convert_saved_model",
+        ":interpreter",
         ":op_hint",
     ],
 )
index cb9c0d3..5fbc551 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -17,7 +17,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_interpreter_wrapper as interpreter_wrapper
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# Lazy load since some of the performance benchmark skylark rules
+# break dependencies. Must use double quotes to match code internal rewrite
+# rule.
+# pylint: disable=g-inconsistent-quotes
+_interpreter_wrapper = LazyLoader(
+    "_interpreter_wrapper", globals(),
+    "tensorflow.contrib.lite.python.interpreter_wrapper."
+    "tensorflow_wrap_interpreter_wrapper")
+# pylint: enable=g-inconsistent-quotes
+
+del LazyLoader
 
 
 class Interpreter(object):
@@ -35,13 +47,13 @@ class Interpreter(object):
     """
     if model_path and not model_content:
       self._interpreter = (
-          interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
+          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
               model_path))
       if not self._interpreter:
         raise ValueError('Failed to open {}'.format(model_path))
     elif model_content and not model_path:
       self._interpreter = (
-          interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
+          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
               model_content, len(model_content)))
       if not self._interpreter:
         raise ValueError(
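For context, the LazyLoader used above simply records the dotted module path and performs the real import on first attribute access, so merely depending on interpreter.py no longer pulls in the SWIG-generated wrapper at import time. A minimal sketch of the pattern, using numpy as a hypothetical stand-in for tensorflow_wrap_interpreter_wrapper:

    from tensorflow.python.util.lazy_loader import LazyLoader

    # Nothing is imported yet; LazyLoader only remembers the module path.
    _np = LazyLoader("_np", globals(), "numpy")

    # The first attribute access triggers the real import and replaces the
    # placeholder in globals() with the loaded module.
    print(_np.arange(3))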
index 04fc098..16f4f30 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -116,7 +116,7 @@ PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
 PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
   PyObject* result = PyTuple_New(2);
   PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
-  PyTuple_SET_ITEM(result, 1, PyInt_FromLong(param.zero_point));
+  PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
   return result;
 }
 
index 4ea4020..86b25e6 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -19,6 +19,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
 @@toco_convert
 @@toco_convert_protos
 @@tflite_from_saved_model
+@@Interpreter
 @@OpHint
 @@convert_op_hints_to_stubs
 
@@ -31,6 +32,7 @@ from __future__ import print_function
 from tensorflow.contrib.lite.python.convert import toco_convert
 from tensorflow.contrib.lite.python.convert import toco_convert_protos
 from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model
+from tensorflow.contrib.lite.python.interpreter import Interpreter
 from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
 from tensorflow.contrib.lite.python.op_hint import OpHint
 # pylint: enable=unused-import
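With the import above, Interpreter is re-exported from the tf.contrib.lite surface alongside toco_convert and OpHint. A hedged usage sketch: the model_path/model_content constructor arguments come from the interpreter.py hunk above, while allocate_tensors, get_input_details, set_tensor, invoke, and get_tensor are assumed from the pre-existing Interpreter class and are not part of this change; the model path is illustrative.

    import numpy as np
    from tensorflow.contrib.lite.python import lite

    # Load a TFLite flatbuffer from disk (model_content= accepts raw
    # flatbuffer bytes instead).
    interpreter = lite.Interpreter(model_path="/tmp/model.tflite")
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Feed a dummy input (assuming a float32 model), run it, read the result.
    interpreter.set_tensor(input_details[0]["index"],
                           np.zeros(input_details[0]["shape"], dtype=np.float32))
    interpreter.invoke()
    result = interpreter.get_tensor(output_details[0]["index"])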
index b23dde2..401f833 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -30,15 +30,42 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
 PIP_PACKAGE_QUERY_EXPRESSION = (
     "deps(//tensorflow/tools/pip_package:build_pip_package)")
 
-# pylint: disable=g-backslash-continuation
-PY_TEST_QUERY_EXPRESSION = 'deps(\
-  filter("^((?!benchmark).)*$",\
-  kind(py_test,\
-  //tensorflow/python/... \
-  + //tensorflow/contrib/... \
-  - //tensorflow/contrib/tensorboard/... \
-  - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
-# pylint: enable=g-backslash-continuation
+
+def GetBuild(dir_base):
+  """Get the list of BUILD file all targets recursively startind at dir_base."""
+  items = []
+  for root, _, files in os.walk(dir_base):
+    for name in files:
+      if (name == "BUILD" and
+          root.find("tensorflow/contrib/lite/examples/android") == -1):
+        items.append("//" + root + ":all")
+  return items
+
+
+def BuildPyTestDependencies():
+  python_targets = GetBuild("tensorflow/python")
+  contrib_targets = GetBuild("tensorflow/contrib")
+  tensorboard_targets = GetBuild("tensorflow/contrib/tensorboard")
+  tensorflow_targets = GetBuild("tensorflow")
+  # Build the list of test targets:
+  # python + contrib - tensorboard - attr(manual|no_pip)
+  targets = " + ".join(python_targets)
+  for t in contrib_targets:
+    targets += " + " + t
+  for t in tensorboard_targets:
+    targets += " - " + t
+  targets += ' - attr(tags, "manual|no_pip", %s)' % " + ".join(
+      tensorflow_targets)
+  query_kind = "kind(py_test, %s)" % targets
+  # Skip benchmarks etc.
+  query_filter = 'filter("^((?!benchmark).)*$", %s)' % query_kind
+  # Get the dependencies
+  query_deps = "deps(%s, 1)" % query_filter
+
+  return python_targets, query_deps
+
+
+PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies()
 
 # Hard-coded blacklist of files if not included in pip package
 # TODO(amitpatankar): Clean up blacklist.
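The query that BuildPyTestDependencies assembles is functionally equivalent to the deleted hard-coded constant, except that each recursive //path/... pattern is expanded into the per-directory :all targets collected by GetBuild. Roughly, with the long target lists abbreviated as placeholders:

    deps(filter("^((?!benchmark).)*$",
         kind(py_test,
              //tensorflow/python:all + <more python targets>
              + //tensorflow/contrib:all + <more contrib targets>
              - //tensorflow/contrib/tensorboard:all - <more tensorboard targets>
              - attr(tags, "manual|no_pip", //tensorflow:all + <more tensorflow targets>))), 1)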
@@ -79,16 +106,6 @@ BLACKLIST = [
 ]
 
 
-def bazel_query(query_target):
-  """Run bazel query on target."""
-  try:
-    output = subprocess.check_output(
-        ["bazel", "query", "--keep_going", query_target])
-  except subprocess.CalledProcessError as e:
-    output = e.output
-  return output
-
-
 def main():
   """This script runs the pip smoke test.
 
@@ -103,14 +120,22 @@ def main():
   """
 
   # pip_package_dependencies_list is the list of included files in pip packages
-  pip_package_dependencies = bazel_query(PIP_PACKAGE_QUERY_EXPRESSION)
+  pip_package_dependencies = subprocess.check_output(
+      ["bazel", "cquery", PIP_PACKAGE_QUERY_EXPRESSION])
   pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
+  pip_package_dependencies_list = [
+      x.split()[0] for x in pip_package_dependencies_list
+  ]
   print("Pip package superset size: %d" % len(pip_package_dependencies_list))
 
   # tf_py_test_dependencies is the list of dependencies for all python
   # tests in tensorflow
-  tf_py_test_dependencies = bazel_query(PY_TEST_QUERY_EXPRESSION)
+  tf_py_test_dependencies = subprocess.check_output(
+      ["bazel", "cquery", PY_TEST_QUERY_EXPRESSION])
   tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
+  tf_py_test_dependencies_list = [
+      x.split()[0] for x in tf_py_test_dependencies.strip().split("\n")
+  ]
   print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
 
   missing_dependencies = []
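The extra split()[0] post-processing above exists because bazel cquery, unlike bazel query, prints each label followed by a configuration identifier, so only the first whitespace-separated token is the target name. A small illustration; the sample output line and its hash are invented:

    # One line of `bazel cquery` output, roughly "<label> (<config hash>)".
    raw_line = "//tensorflow/python:framework_ops (a1b2c3d)"

    # Keep only the label; drop the configuration suffix.
    label = raw_line.split()[0]
    assert label == "//tensorflow/python:framework_ops"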
@@ -141,9 +166,9 @@ def main():
     for missing_dependency in missing_dependencies:
       print("\nMissing dependency: %s " % missing_dependency)
       print("Affected Tests:")
-      rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
-                    missing_dependency)
-      affected_tests = bazel_query(rdep_query)
+      rdep_query = ("rdeps(kind(py_test, %s), %s)" %
+                    (" + ".join(PYTHON_TARGETS), missing_dependency))
+      affected_tests = subprocess.check_output(["bazel", "cquery", rdep_query])
       affected_tests_list = affected_tests.split("\n")[:-2]
       print("\n".join(affected_tests_list))