Porting tests for `rpc_op` to OS.
authorJiri Simsa <jsimsa@google.com>
Fri, 13 Apr 2018 00:32:36 +0000 (17:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 00:35:08 +0000 (17:35 -0700)
PiperOrigin-RevId: 192698931

12 files changed:
tensorflow/contrib/BUILD
tensorflow/contrib/__init__.py
tensorflow/contrib/cmake/tf_python.cmake
tensorflow/contrib/rpc/BUILD
tensorflow/contrib/rpc/python/kernel_tests/BUILD [new file with mode: 0644]
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py [new file with mode: 0644]
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py [new file with mode: 0644]
tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py [new file with mode: 0644]
tensorflow/contrib/rpc/python/kernel_tests/test_example.proto [new file with mode: 0644]
tensorflow/core/platform/default/build_config.bzl
tensorflow/tools/pip_package/BUILD
tensorflow/workspace.bzl

index ae68f4a..7e47516 100644 (file)
@@ -87,6 +87,7 @@ py_library(
         "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
         "//tensorflow/contrib/resampler:resampler_py",
         "//tensorflow/contrib/rnn:rnn_py",
+        "//tensorflow/contrib/rpc",
         "//tensorflow/contrib/saved_model:saved_model_py",
         "//tensorflow/contrib/seq2seq:seq2seq_py",
         "//tensorflow/contrib/signal:signal_py",
index e27ece8..36cc514 100644 (file)
@@ -71,6 +71,7 @@ from tensorflow.contrib import recurrent
 from tensorflow.contrib import reduce_slice_ops
 from tensorflow.contrib import resampler
 from tensorflow.contrib import rnn
+from tensorflow.contrib import rpc
 from tensorflow.contrib import saved_model
 from tensorflow.contrib import seq2seq
 from tensorflow.contrib import signal
index 21f59d2..f6aaf41 100755 (executable)
@@ -347,7 +347,8 @@ GENERATE_PYTHON_OP_LIB("random_ops")
 GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
 GENERATE_PYTHON_OP_LIB("resource_variable_ops")
-GENERATE_PYTHON_OP_LIB("rpc_ops")
+GENERATE_PYTHON_OP_LIB("rpc_ops"
+  DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py)
 GENERATE_PYTHON_OP_LIB("script_ops")
 GENERATE_PYTHON_OP_LIB("sdca_ops")
 GENERATE_PYTHON_OP_LIB("set_ops")
index 597f18c..dbd311a 100644 (file)
@@ -4,6 +4,8 @@ licenses(["notice"])  # Apache 2.0
 
 exports_files(["LICENSE"])
 
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+
 py_library(
     name = "rpc",
     srcs = [
@@ -11,3 +13,17 @@ py_library(
     ],
     deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
 )
+
+py_library(
+    name = "rpc_pip",
+    data = if_static(
+        [],
+        otherwise = ["//tensorflow/contrib/rpc/python/kernel_tests:libtestexample.so"],
+    ),
+    deps = [
+        ":rpc",
+        "//tensorflow/contrib/rpc/python/kernel_tests:py_test_deps",
+        "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_base",
+        "//tensorflow/contrib/rpc/python/kernel_tests:rpc_op_test_servicer",
+    ],
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
new file mode 100644 (file)
index 0000000..08ec1e6
--- /dev/null
@@ -0,0 +1,76 @@
+# TODO(b/76425722): Port everything in here to OS (currently excluded).
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# Placeholder for loading internal BUILD rule.
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+tf_proto_library(
+    name = "test_example_proto",
+    srcs = ["test_example.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    protodeps = ["//tensorflow/core:protos_all"],
+)
+
+py_library(
+    name = "py_test_deps",
+    deps = [":test_example_proto_py"],
+)
+
+py_library(
+    name = "rpc_op_test_base",
+    srcs = ["rpc_op_test_base.py"],
+    deps = [
+        ":test_example_proto_py",
+        "//tensorflow/contrib/proto",
+        "//tensorflow/contrib/rpc",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
+    name = "rpc_op_test_servicer",
+    srcs = ["rpc_op_test_servicer.py"],
+    deps = [
+        ":py_test_deps",
+        ":rpc_op_test_base",
+        "//tensorflow/core:protos_all_py",
+        "//third_party/py/numpy",
+    ],
+)
+
+tf_cc_shared_object(
+    name = "libtestexample.so",
+    linkstatic = 1,
+    deps = [
+        ":test_example_proto_cc",
+    ],
+)
+
+tf_py_test(
+    name = "rpc_op_test",
+    size = "small",
+    srcs = ["rpc_op_test.py"],
+    additional_deps = [
+        ":py_test_deps",
+        ":rpc_op_test_base",
+        ":rpc_op_test_servicer",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+    ],
+    data = if_static(
+        [],
+        otherwise = [":libtestexample.so"],
+    ),
+)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
new file mode 100644 (file)
index 0000000..e2e0dbc
--- /dev/null
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for RpcOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+import grpc
+from grpc.framework.foundation import logging_pool
+import portpicker
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_servicer
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+from tensorflow.python.platform import test
+
+
+class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
+  _protocol = 'grpc'
+
+  invalid_method_string = 'Method not found'
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    super(RpcOpTest, self).__init__(methodName)
+    lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
+    if os.path.isfile(lib):
+      ct.cdll.LoadLibrary(lib)
+
+  def get_method_name(self, suffix):
+    return '/tensorflow.contrib.rpc.TestCaseService/%s' % suffix
+
+  def setUp(self):
+    super(RpcOpTest, self).setUp()
+
+    service_port = portpicker.pick_unused_port()
+
+    server = grpc.server(logging_pool.pool(max_workers=25))
+    servicer = rpc_op_test_servicer.RpcOpTestServicer()
+    test_example_pb2_grpc.add_TestCaseServiceServicer_to_server(
+        servicer, server)
+    self._address = 'localhost:%d' % service_port
+    server.add_insecure_port(self._address)
+    server.start()
+    self._server = server
+
+  def tearDown(self):
+    # TODO(ebrevdo): Figure out why this sometimes times out.
+    #    self._service.ExitLoop()
+    #    self._service_thread.join()
+    # self._server.stop()
+    super(RpcOpTest, self).tearDown()
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
new file mode 100644 (file)
index 0000000..aa03a10
--- /dev/null
@@ -0,0 +1,337 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Base class for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.contrib.proto import decode_proto
+from tensorflow.contrib.proto import encode_proto
+from tensorflow.contrib.rpc import rpc
+from tensorflow.contrib.rpc import try_rpc
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+
+__all__ = ['I_WARNED_YOU', 'RpcOpTestBase']
+
+I_WARNED_YOU = 'I warned you!'
+
+
+class RpcOpTestBase(object):
+  # pylint: disable=missing-docstring,invalid-name
+  """Base class for RpcOp tests."""
+
+  def get_method_name(self, suffix):
+    raise NotImplementedError
+
+  def rpc(self, *args, **kwargs):
+    return rpc(*args, protocol=self._protocol, **kwargs)
+
+  def try_rpc(self, *args, **kwargs):
+    return try_rpc(*args, protocol=self._protocol, **kwargs)
+
+  def testScalarHostPortRpc(self):
+    with self.test_session() as sess:
+      request_tensors = (
+          test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+      response_tensors = self.rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=self._address,
+          request=request_tensors)
+      self.assertEqual(response_tensors.shape, ())
+      response_values = sess.run(response_tensors)
+    response_message = test_example_pb2.TestCase()
+    self.assertTrue(response_message.ParseFromString(response_values))
+    self.assertAllEqual([2, 3, 4], response_message.shape)
+
+  def testScalarHostPortTryRpc(self):
+    with self.test_session() as sess:
+      request_tensors = (
+          test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString())
+      response_tensors, status_code, status_message = self.try_rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=self._address,
+          request=request_tensors)
+      self.assertEqual(status_code.shape, ())
+      self.assertEqual(status_message.shape, ())
+      self.assertEqual(response_tensors.shape, ())
+      response_values, status_code_values, status_message_values = (
+          sess.run((response_tensors, status_code, status_message)))
+    response_message = test_example_pb2.TestCase()
+    self.assertTrue(response_message.ParseFromString(response_values))
+    self.assertAllEqual([2, 3, 4], response_message.shape)
+    # For the base Rpc op, don't expect to get error status back.
+    self.assertEqual(errors.OK, status_code_values)
+    self.assertEqual(b'', status_message_values)
+
+  def testEmptyHostPortRpc(self):
+    with self.test_session() as sess:
+      request_tensors = []
+      response_tensors = self.rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=self._address,
+          request=request_tensors)
+      self.assertAllEqual(response_tensors.shape, [0])
+      response_values = sess.run(response_tensors)
+    self.assertAllEqual(response_values.shape, [0])
+
+  def testInvalidAddresses(self):
+    with self.test_session() as sess:
+      with self.assertRaisesOpError(self.invalid_method_string):
+        sess.run(
+            self.rpc(
+                method='/InvalidService.IncrementTestShapes',
+                address=self._address,
+                request=''))
+
+      with self.assertRaisesOpError(self.invalid_method_string):
+        sess.run(
+            self.rpc(
+                method=self.get_method_name('InvalidMethodName'),
+                address=self._address,
+                request=''))
+
+      # This also covers the case of address=''
+      # and address='localhost:293874293874'
+      with self.assertRaises(errors.UnavailableError):
+        sess.run(
+            self.rpc(
+                method=self.get_method_name('IncrementTestShapes'),
+                address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+                request=''))
+
+      # Test invalid method with the TryRpc op
+      _, status_code_value, status_message_value = sess.run(
+          self.try_rpc(
+              method=self.get_method_name('InvalidMethodName'),
+              address=self._address,
+              request=''))
+      self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+      self.assertTrue(
+          self.invalid_method_string in status_message_value.decode('ascii'))
+
+  def testAlwaysFailingMethod(self):
+    with self.test_session() as sess:
+      response_tensors = self.rpc(
+          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+          address=self._address,
+          request='')
+      self.assertEqual(response_tensors.shape, ())
+      with self.assertRaisesOpError(I_WARNED_YOU):
+        sess.run(response_tensors)
+
+  def testSometimesFailingMethodWithManyRequests(self):
+    with self.test_session() as sess:
+      # Fail hard by default.
+      response_tensors = self.rpc(
+          method=self.get_method_name('SometimesFailWithInvalidArgument'),
+          address=self._address,
+          request=[''] * 20)
+      self.assertEqual(response_tensors.shape, (20,))
+      with self.assertRaisesOpError(I_WARNED_YOU):
+        sess.run(response_tensors)
+
+      # Don't fail hard, use TryRpc - return the failing status instead.
+      response_tensors, status_code, status_message = self.try_rpc(
+          method=self.get_method_name('SometimesFailWithInvalidArgument'),
+          address=self._address,
+          request=[''] * 20)
+      self.assertEqual(response_tensors.shape, (20,))
+      self.assertEqual(status_code.shape, (20,))
+      self.assertEqual(status_message.shape, (20,))
+      status_code_values, status_message_values = sess.run((status_code,
+                                                            status_message))
+      self.assertTrue([
+          x in (errors.OK, errors.INVALID_ARGUMENT) for x in status_code_values
+      ])
+      expected_message_values = np.where(
+          status_code_values == errors.INVALID_ARGUMENT,
+          I_WARNED_YOU.encode('ascii'), b'')
+      self.assertAllEqual(expected_message_values, status_message_values)
+
+  def testVecHostPortRpc(self):
+    with self.test_session() as sess:
+      request_tensors = [
+          test_example_pb2.TestCase(
+              shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+      ]
+      response_tensors = self.rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=self._address,
+          request=request_tensors)
+      self.assertEqual(response_tensors.shape, (20,))
+      response_values = sess.run(response_tensors)
+    self.assertEqual(response_values.shape, (20,))
+    for i in range(20):
+      response_message = test_example_pb2.TestCase()
+      self.assertTrue(response_message.ParseFromString(response_values[i]))
+      self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+  def testVecHostPortManyParallelRpcs(self):
+    with self.test_session() as sess:
+      request_tensors = [
+          test_example_pb2.TestCase(
+              shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+      ]
+      many_response_tensors = [
+          self.rpc(
+              method=self.get_method_name('IncrementTestShapes'),
+              address=self._address,
+              request=request_tensors) for _ in range(10)
+      ]
+      # Launch parallel 10 calls to the RpcOp, each containing
+      # 20 rpc requests.
+      many_response_values = sess.run(many_response_tensors)
+    self.assertEqual(10, len(many_response_values))
+    for response_values in many_response_values:
+      self.assertEqual(response_values.shape, (20,))
+      for i in range(20):
+        response_message = test_example_pb2.TestCase()
+        self.assertTrue(response_message.ParseFromString(response_values[i]))
+        self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
+
+  def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
+    with self.test_session() as sess:
+      request_tensors = encode_proto(
+          message_type='tensorflow.contrib.rpc.TestCase',
+          field_names=['shape'],
+          sizes=[[3]] * 20,
+          values=[
+              [[i, i + 1, i + 2] for i in range(20)],
+          ])
+      response_tensor_strings = self.rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=self._address,
+          request=request_tensors)
+      _, (response_shape,) = decode_proto(
+          bytes=response_tensor_strings,
+          message_type='tensorflow.contrib.rpc.TestCase',
+          field_names=['shape'],
+          output_types=[dtypes.int32])
+      response_shape_values = sess.run(response_shape)
+    self.assertAllEqual([[i + 1, i + 2, i + 3]
+                         for i in range(20)], response_shape_values)
+
+  def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
+    with self.test_session() as sess:
+      request_tensors = [''] * 25  # This will launch 25 RPC requests.
+      response_tensors = self.rpc(
+          method=self.get_method_name('SleepForever'),
+          address=self._address,
+          request=request_tensors)
+      for timeout_ms in [1, 500, 1000]:
+        options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
+        with self.assertRaises((errors.UnavailableError,
+                                errors.DeadlineExceededError)):
+          sess.run(response_tensors, options=options)
+
+  def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
+    with self.test_session() as sess:
+      request_tensors = [''] * 25  # This will launch 25 RPC requests.
+      response_tensors = self.rpc(
+          method=self.get_method_name('SleepForever'),
+          address=self._address,
+          timeout_in_ms=1000,
+          request=request_tensors)
+      with self.assertRaises(errors.DeadlineExceededError):
+        sess.run(response_tensors)
+
+  def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
+    with self.test_session() as sess:
+      response_tensors, status_code, status_message = self.try_rpc(
+          method=self.get_method_name('SometimesSleepForever'),
+          timeout_in_ms=1000,
+          address=self._address,
+          request=[''] * 20)
+      self.assertEqual(response_tensors.shape, (20,))
+      self.assertEqual(status_code.shape, (20,))
+      self.assertEqual(status_message.shape, (20,))
+      status_code_values = sess.run(status_code)
+      self.assertTrue([
+          x in (errors.OK, errors.DEADLINE_EXCEEDED) for x in status_code_values
+      ])
+
+  def testTryRpcWithMultipleAddressesSingleRequest(self):
+    flatten = lambda x: list(itertools.chain.from_iterable(x))
+    with self.test_session() as sess:
+      addresses = flatten([[
+          self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+      ] for _ in range(10)])
+      request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+      response_tensors, status_code, _ = self.try_rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=addresses,
+          request=request)
+      response_tensors_values, status_code_values = sess.run((response_tensors,
+                                                              status_code))
+      self.assertAllEqual(
+          flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+          status_code_values)
+      for i in range(10):
+        self.assertTrue(response_tensors_values[2 * i])
+        self.assertFalse(response_tensors_values[2 * i + 1])
+
+  def testTryRpcWithMultipleMethodsSingleRequest(self):
+    flatten = lambda x: list(itertools.chain.from_iterable(x))
+    with self.test_session() as sess:
+      methods = flatten(
+          [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
+           for _ in range(10)])
+      request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
+      response_tensors, status_code, _ = self.try_rpc(
+          method=methods, address=self._address, request=request)
+      response_tensors_values, status_code_values = sess.run((response_tensors,
+                                                              status_code))
+      self.assertAllEqual(
+          flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
+          status_code_values)
+      for i in range(10):
+        self.assertTrue(response_tensors_values[2 * i])
+        self.assertFalse(response_tensors_values[2 * i + 1])
+
+  def testTryRpcWithMultipleAddressesAndRequests(self):
+    flatten = lambda x: list(itertools.chain.from_iterable(x))
+    with self.test_session() as sess:
+      addresses = flatten([[
+          self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+      ] for _ in range(10)])
+      requests = [
+          test_example_pb2.TestCase(
+              shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
+      ]
+      response_tensors, status_code, _ = self.try_rpc(
+          method=self.get_method_name('IncrementTestShapes'),
+          address=addresses,
+          request=requests)
+      response_tensors_values, status_code_values = sess.run((response_tensors,
+                                                              status_code))
+      self.assertAllEqual(
+          flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
+          status_code_values)
+      for i in range(20):
+        if i % 2 == 1:
+          self.assertFalse(response_tensors_values[i])
+        else:
+          response_message = test_example_pb2.TestCase()
+          self.assertTrue(
+              response_message.ParseFromString(response_tensors_values[i]))
+          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py
new file mode 100644 (file)
index 0000000..7cbd636
--- /dev/null
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Test servicer for RpcOp tests."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+import grpc
+
+from tensorflow.contrib.rpc.python.kernel_tests import rpc_op_test_base
+from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc
+
+
+class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer):
+  """Test servicer for RpcOp tests."""
+
+  def IncrementTestShapes(self, request, context):
+    """Increment the entries in the shape attribute of request.
+
+    Args:
+      request: input TestCase.
+      context: the rpc context.
+
+    Returns:
+      output TestCase.
+    """
+    for i in range(len(request.shape)):
+      request.shape[i] += 1
+    return request
+
+  def AlwaysFailWithInvalidArgument(self, request, context):
+    """Always fails with an InvalidArgument status.
+
+    Args:
+      request: input TestCase.
+      context: the rpc context.
+
+    Returns:
+      output TestCase.
+    """
+    del request
+    context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+    context.set_details(rpc_op_test_base.I_WARNED_YOU)
+
+  def SometimesFailWithInvalidArgument(self, request, context):
+    """Sometimes fails with an InvalidArgument status.
+
+    Args:
+      request: input TestCase.
+      context: the rpc context.
+
+    Returns:
+      output TestCase.
+    """
+    if random.randint(0, 1) == 1:
+      context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+      context.set_details(rpc_op_test_base.I_WARNED_YOU)
+    return request
+
+  def SleepForever(self, request, context):
+    """Sleeps forever.
+
+    Args:
+      request: input TestCase.
+      context: the rpc context.
+
+    Returns:
+      output TestCase.
+    """
+    # TODO(ebrevdo): Make this async wait like the stubby version.
+    time.sleep(5)
+
+  def SometimesSleepForever(self, request, context):
+    """Sometimes sleeps forever.
+
+    Args:
+      request: input TestCase.
+      context: the rpc context.
+
+    Returns:
+      output TestCase.
+    """
+    if random.randint(0, 1) == 1:
+      time.sleep(5)
+    return request
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto
new file mode 100644 (file)
index 0000000..96f4550
--- /dev/null
@@ -0,0 +1,171 @@
+// Test description and protos to work with it.
+//
+// Many of the protos in this file are for unit tests that haven't been written yet.
+
+syntax = "proto2";
+
+import "tensorflow/core/framework/types.proto";
+
+package tensorflow.contrib.rpc;
+
+// A TestCase holds a proto and a bunch of assertions
+// about how it should decode.
+message TestCase {
+  // A batch of primitives to be serialized and decoded.
+  repeated RepeatedPrimitiveValue primitive = 1;
+  // The shape of the batch.
+  repeated int32 shape = 2;
+  // Expected sizes for each field.
+  repeated int32 sizes = 3;
+  // Expected values for each field.
+  repeated FieldSpec field = 4;
+};
+
+service TestCaseService {
+  // Copy input, and increment each entry in 'shape' by 1.
+  rpc IncrementTestShapes(TestCase) returns (TestCase) {
+  }
+
+  // Sleep forever.
+  rpc SleepForever(TestCase) returns (TestCase) {
+  }
+
+  // Sleep forever 50% of the time, return immediately the other 50%.
+  rpc SometimesSleepForever(TestCase) returns (TestCase) {
+  }
+
+  // Always fails with InvalidArgument.
+  rpc AlwaysFailWithInvalidArgument(TestCase) returns (TestCase) {
+  }
+
+  // Fails with InvalidArgument 50% of the time.
+  rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) {
+  }
+};
+
+// FieldSpec describes the expected output for a single field.
+message FieldSpec {
+  optional string name = 1;
+  optional tensorflow.DataType dtype = 2;
+  optional RepeatedPrimitiveValue expected = 3;
+};
+
+message TestValue {
+  optional PrimitiveValue primitive_value = 1;
+  optional EnumValue enum_value = 2;
+  optional MessageValue message_value = 3;
+  optional RepeatedMessageValue repeated_message_value = 4;
+  optional RepeatedPrimitiveValue repeated_primitive_value = 6;
+}
+
+message PrimitiveValue {
+  optional double double_value = 1;
+  optional float float_value = 2;
+  optional int64 int64_value = 3;
+  optional uint64 uint64_value = 4;
+  optional int32 int32_value = 5;
+  optional fixed64 fixed64_value = 6;
+  optional fixed32 fixed32_value = 7;
+  optional bool bool_value = 8;
+  optional string string_value = 9;
+  optional bytes bytes_value = 12;
+  optional uint32 uint32_value = 13;
+  optional sfixed32 sfixed32_value = 15;
+  optional sfixed64 sfixed64_value = 16;
+  optional sint32 sint32_value = 17;
+  optional sint64 sint64_value = 18;
+}
+
+// NOTE: This definition must be kept in sync with PackedPrimitiveValue.
+message RepeatedPrimitiveValue {
+  repeated double double_value = 1;
+  repeated float float_value = 2;
+  repeated int64 int64_value = 3;
+  repeated uint64 uint64_value = 4;
+  repeated int32 int32_value = 5;
+  repeated fixed64 fixed64_value = 6;
+  repeated fixed32 fixed32_value = 7;
+  repeated bool bool_value = 8;
+  repeated string string_value = 9;
+  repeated bytes bytes_value = 12;
+  repeated uint32 uint32_value = 13;
+  repeated sfixed32 sfixed32_value = 15;
+  repeated sfixed64 sfixed64_value = 16;
+  repeated sint32 sint32_value = 17;
+  repeated sint64 sint64_value = 18;
+  repeated PrimitiveValue message_value = 19;
+}
+
+// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue
+// in the text format, but the binary serializion is different.
+// We test the packed representations by loading the same test cases
+// using this definition instead of RepeatedPrimitiveValue.
+// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue
+// in every way except the packed=true declaration.
+message PackedPrimitiveValue {
+  repeated double double_value = 1 [packed = true];
+  repeated float float_value = 2 [packed = true];
+  repeated int64 int64_value = 3 [packed = true];
+  repeated uint64 uint64_value = 4 [packed = true];
+  repeated int32 int32_value = 5 [packed = true];
+  repeated fixed64 fixed64_value = 6 [packed = true];
+  repeated fixed32 fixed32_value = 7 [packed = true];
+  repeated bool bool_value = 8 [packed = true];
+  repeated string string_value = 9;
+  repeated bytes bytes_value = 12;
+  repeated uint32 uint32_value = 13 [packed = true];
+  repeated sfixed32 sfixed32_value = 15 [packed = true];
+  repeated sfixed64 sfixed64_value = 16 [packed = true];
+  repeated sint32 sint32_value = 17 [packed = true];
+  repeated sint64 sint64_value = 18 [packed = true];
+  repeated PrimitiveValue message_value = 19;
+}
+
+message EnumValue {
+  enum Color {
+    RED = 0;
+    ORANGE = 1;
+    YELLOW = 2;
+    GREEN = 3;
+    BLUE = 4;
+    INDIGO = 5;
+    VIOLET = 6;
+  };
+  optional Color enum_value = 14;
+  repeated Color repeated_enum_value = 15;
+}
+
+
+message InnerMessageValue {
+  optional float float_value = 2;
+  repeated bytes bytes_values = 8;
+}
+
+message MiddleMessageValue {
+  repeated int32 int32_values = 5;
+  optional InnerMessageValue message_value = 11;
+  optional uint32 uint32_value = 13;
+}
+
+message MessageValue {
+  optional double double_value = 1;
+  optional MiddleMessageValue message_value = 11;
+}
+
+message RepeatedMessageValue {
+  message NestedMessageValue {
+    optional float float_value = 2;
+    repeated bytes bytes_values = 8;
+  }
+
+  repeated NestedMessageValue message_values = 11;
+}
+
+// Message containing fields with field numbers higher than any field above. An
+// instance of this message is prepended to each binary message in the test to
+// exercise the code path that handles fields encoded out of order of field
+// number.
+message ExtraFields {
+  optional string string_value = 1776;
+  optional bool bool_value = 1777;
+}
index 4cfa25b..44356e3 100644 (file)
@@ -1,7 +1,6 @@
 # Platform-specific build configurations.
 
 load("@protobuf_archive//:protobuf.bzl", "proto_gen")
-load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
 load("//tensorflow:tensorflow.bzl", "if_not_mobile")
 load("//tensorflow:tensorflow.bzl", "if_windows")
 load("//tensorflow:tensorflow.bzl", "if_not_windows")
@@ -110,6 +109,12 @@ def _proto_cc_srcs(srcs, use_grpc_plugin=False):
     ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
   return ret
 
+def _proto_py_outs(srcs, use_grpc_plugin=False):
+  ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
+  if use_grpc_plugin:
+    ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
+  return ret
+
 # Re-defined protocol buffer rule to allow building "header only" protocol
 # buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
 # containing select() statements.
@@ -217,6 +222,80 @@ def cc_proto_library(
       hdrs=gen_hdrs,
       **kargs)
 
+# Re-defined protocol buffer rule to bring in the change introduced in commit
+# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
+# which was not part of a stable protobuf release in 04/2018.
+# TODO(jsimsa): Remove this once the protobuf dependency version is updated
+# to include the above commit.
+def py_proto_library(
+        name,
+        srcs=[],
+        deps=[],
+        py_libs=[],
+        py_extra_srcs=[],
+        include=None,
+        default_runtime="@protobuf_archive//:protobuf_python",
+        protoc="@protobuf_archive//:protoc",
+        use_grpc_plugin=False,
+        **kargs):
+  """Bazel rule to create a Python protobuf library from proto source files
+
+  NOTE: the rule is only an internal workaround to generate protos. The
+  interface may change and the rule may be removed when bazel has introduced
+  the native rule.
+
+  Args:
+    name: the name of the py_proto_library.
+    srcs: the .proto files of the py_proto_library.
+    deps: a list of dependency labels; must be py_proto_library.
+    py_libs: a list of other py_library targets depended by the generated
+        py_library.
+    py_extra_srcs: extra source files that will be added to the output
+        py_library. This attribute is used for internal bootstrapping.
+    include: a string indicating the include path of the .proto files.
+    default_runtime: the implicitly default runtime which will be depended on by
+        the generated py_library target.
+    protoc: the label of the protocol compiler to generate the sources.
+    use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
+        when processing the proto files.
+    **kargs: other keyword arguments that are passed to cc_library.
+  """
+  outs = _proto_py_outs(srcs, use_grpc_plugin)
+
+  includes = []
+  if include != None:
+    includes = [include]
+
+  grpc_python_plugin = None
+  if use_grpc_plugin:
+    grpc_python_plugin = "//external:grpc_python_plugin"
+    # Note: Generated grpc code depends on Python grpc module. This dependency
+    # is not explicitly listed in py_libs. Instead, host system is assumed to
+    # have grpc installed.
+
+  proto_gen(
+      name=name + "_genproto",
+      srcs=srcs,
+      deps=[s + "_genproto" for s in deps],
+      includes=includes,
+      protoc=protoc,
+      gen_py=1,
+      outs=outs,
+      visibility=["//visibility:public"],
+      plugin=grpc_python_plugin,
+      plugin_language="grpc"
+  )
+
+  if default_runtime and not default_runtime in py_libs + deps:
+    py_libs = py_libs + [default_runtime]
+
+  native.py_library(
+      name=name,
+      srcs=outs+py_extra_srcs,
+      deps=py_libs+deps,
+      imports=includes,
+      **kargs)
+
 def tf_proto_library_cc(name, srcs = [], has_services = None,
                         protodeps = [],
                         visibility = [], testonly = 0,
@@ -261,8 +340,7 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
   )
 
 def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
-                        testonly=0,
-                        srcs_version="PY2AND3"):
+                        testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
   py_proto_library(
       name = name + "_py",
       srcs = srcs,
@@ -272,6 +350,7 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
       default_runtime = "@protobuf_archive//:protobuf_python",
       visibility = visibility,
       testonly = testonly,
+      use_grpc_plugin = use_grpc_plugin,
   )
 
 def tf_jspb_proto_library(**kwargs):
@@ -310,6 +389,7 @@ def tf_proto_library(name, srcs = [], has_services = None,
       srcs_version = "PY2AND3",
       testonly = testonly,
       visibility = visibility,
+      use_grpc_plugin = has_services,
   )
 
 def tf_additional_lib_hdrs(exclude = []):
index a0bae23..2ef1057 100644 (file)
@@ -76,6 +76,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/contrib/predictor:predictor_pip",
     "//tensorflow/contrib/proto:proto_pip",
     "//tensorflow/contrib/receptive_field:receptive_field_pip",
+    "//tensorflow/contrib/rpc:rpc_pip",
     "//tensorflow/contrib/session_bundle:session_bundle_pip",
     "//tensorflow/contrib/signal:signal_py",
     "//tensorflow/contrib/signal:test_util",
index 72f446d..dee2fcd 100644 (file)
@@ -763,6 +763,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
       name = "grpc_cpp_plugin",
       actual = "@grpc//:grpc_cpp_plugin",
   )
+  native.bind(
+      name = "grpc_python_plugin",
+      actual = "@grpc//:grpc_python_plugin",
+  )
 
   # gRPC has three empty C++ functions which it wants the user to define
   # at build time. https://github.com/grpc/grpc/issues/13590