Exposing tensorflow.contrib.proto in the pip package.
authorJiri Simsa <jsimsa@google.com>
Thu, 12 Apr 2018 23:35:47 +0000 (16:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 23:37:59 +0000 (16:37 -0700)
PiperOrigin-RevId: 192691078

tensorflow/contrib/BUILD
tensorflow/contrib/__init__.py
tensorflow/contrib/cmake/tf_python.cmake
tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py

index 9bef0d8..ae68f4a 100644 (file)
@@ -77,6 +77,7 @@ py_library(
         "//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
         "//tensorflow/contrib/periodic_resample:init_py",
         "//tensorflow/contrib/predictor",
+        "//tensorflow/contrib/proto",
         "//tensorflow/contrib/quantization:quantization_py",
         "//tensorflow/contrib/quantize:quantize_graph",
         "//tensorflow/contrib/autograph",
index aaddb06..e27ece8 100644 (file)
@@ -64,6 +64,7 @@ from tensorflow.contrib import nn
 from tensorflow.contrib import opt
 from tensorflow.contrib import periodic_resample
 from tensorflow.contrib import predictor
+from tensorflow.contrib import proto
 from tensorflow.contrib import quantization
 from tensorflow.contrib import quantize
 from tensorflow.contrib import recurrent
index ded15b4..21f59d2 100755 (executable)
@@ -330,8 +330,10 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
 GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
 GENERATE_PYTHON_OP_LIB("data_flow_ops")
 GENERATE_PYTHON_OP_LIB("dataset_ops")
-GENERATE_PYTHON_OP_LIB("decode_proto_ops")
-GENERATE_PYTHON_OP_LIB("encode_proto_ops")
+GENERATE_PYTHON_OP_LIB("decode_proto_ops"
+  DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py)
+GENERATE_PYTHON_OP_LIB("encode_proto_ops"
+  DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py)
 GENERATE_PYTHON_OP_LIB("image_ops")
 GENERATE_PYTHON_OP_LIB("io_ops")
 GENERATE_PYTHON_OP_LIB("linalg_ops")
index f019833..f8969b0 100644 (file)
@@ -21,7 +21,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib import proto
+from tensorflow.contrib.proto import decode_proto
 from tensorflow.contrib.proto.python.kernel_tests import test_case
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -46,7 +46,7 @@ class DecodeProtoFailTest(test_case.ProtoOpTestCase):
     field_types = [dtypes.int32]
 
     with self.test_session() as sess:
-      ctensor, vtensor = proto.decode_proto(
+      ctensor, vtensor = decode_proto(
           batch,
           message_type=msg_type,
           field_names=field_names,
index 30ceac5..cd5121c 100644 (file)
@@ -27,7 +27,7 @@ import numpy as np
 
 from google.protobuf import text_format
 
-from tensorflow.contrib import proto
+from tensorflow.contrib.proto import decode_proto
 from tensorflow.contrib.proto.python.kernel_tests import test_case
 from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
 from tensorflow.python.framework import dtypes
@@ -175,7 +175,7 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
     output_types = [f.dtype for f in fields]
 
     with self.test_session() as sess:
-      sizes, vtensor = proto.decode_proto(
+      sizes, vtensor = decode_proto(
           batch,
           message_type=message_type,
           field_names=field_names,
index 2a24c3b..a289ff2 100644 (file)
@@ -30,7 +30,8 @@ import numpy as np
 
 from google.protobuf import text_format
 
-from tensorflow.contrib import proto
+from tensorflow.contrib.proto import decode_proto
+from tensorflow.contrib.proto import encode_proto
 from tensorflow.contrib.proto.python.kernel_tests import test_case
 from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
 from tensorflow.python.framework import dtypes
@@ -50,7 +51,7 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
     # Invalid field name
     with self.test_session():
       with self.assertRaisesOpError('Unknown field: non_existent_field'):
-        proto.encode_proto(
+        encode_proto(
             sizes=[[1]],
             values=[np.array([[0.0]], dtype=np.int32)],
             message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
@@ -60,7 +61,7 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
     with self.test_session():
       with self.assertRaisesOpError(
           'Incompatible type for field double_value.'):
-        proto.encode_proto(
+        encode_proto(
             sizes=[[1]],
             values=[np.array([[0.0]], dtype=np.int32)],
             message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
@@ -72,7 +73,7 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
           r'sizes should be batch_size \+ \[len\(field_names\)\]'):
         sizes = array_ops.placeholder(dtypes.int32)
         values = array_ops.placeholder(dtypes.float64)
-        proto.encode_proto(
+        encode_proto(
             sizes=sizes,
             values=[values],
             message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
@@ -88,7 +89,7 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
         sizes = array_ops.placeholder(dtypes.int32)
         values1 = array_ops.placeholder(dtypes.float64)
         values2 = array_ops.placeholder(dtypes.int32)
-        (proto.encode_proto(
+        (encode_proto(
             sizes=[[1, 1]],
             values=[values1, values2],
             message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
@@ -103,13 +104,13 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
     out_types = [f.dtype for f in fields]
 
     with self.test_session() as sess:
-      sizes, field_tensors = proto.decode_proto(
+      sizes, field_tensors = decode_proto(
           in_bufs,
           message_type=message_type,
           field_names=field_names,
           output_types=out_types)
 
-      out_tensors = proto.encode_proto(
+      out_tensors = encode_proto(
           sizes,
           field_tensors,
           message_type=message_type,