From 3652556dab3ebfe0152232facc7304fe5754aecb Mon Sep 17 00:00:00 2001
From: Scott Zhu
Date: Fri, 13 Apr 2018 17:52:20 -0700
Subject: [PATCH] Merge changes from github.

PiperOrigin-RevId: 192850372
---
 tensorflow/BUILD | 5 +-
 tensorflow/compiler/jit/BUILD | 1 +
 .../compiler/jit/mark_for_compilation_pass.cc | 4 +
 tensorflow/contrib/cmake/external/grpc.cmake | 1 +
 .../copy_graph/python/util/copy_elements.py | 4 +-
 tensorflow/contrib/data/__init__.py | 2 +
 tensorflow/contrib/data/python/kernel_tests/BUILD | 1 +
 .../python/kernel_tests/batch_dataset_op_test.py | 70 ++++
 .../kernel_tests/sequence_dataset_op_test.py | 10 +
 tensorflow/contrib/data/python/ops/BUILD | 1 +
 tensorflow/contrib/data/python/ops/batching.py | 41 +++
 tensorflow/contrib/distribute/python/values.py | 2 +-
 tensorflow/contrib/kernel_methods/python/losses.py | 6 +-
 .../python/mappers/random_fourier_features.py | 44 ++-
 .../python/mappers/random_fourier_features_test.py | 2 +-
 .../contrib/kfac/python/ops/fisher_blocks.py | 82 ++---
 tensorflow/contrib/lite/build_ios_universal_lib.sh | 15 +-
 .../contrib/metrics/python/ops/metric_ops.py | 29 +-
 tensorflow/contrib/rnn/python/ops/rnn_cell.py | 2 +-
 .../seq2seq/python/ops/attention_wrapper.py | 4 +-
 tensorflow/contrib/sparsemax/__init__.py | 2 +-
 .../contrib/sparsemax/python/ops/sparsemax.py | 2 +-
 .../contrib/tensorrt/convert/convert_graph.cc | 10 +-
 .../contrib/tensorrt/convert/convert_nodes.cc | 68 +++-
 .../api_def/base_api/api_def_ClipByValue.pbtxt | 36 ++
 .../api_def/python_api/api_def_ClipByValue.pbtxt | 4 +
 tensorflow/core/common_runtime/process_util.cc | 21 +-
 tensorflow/core/grappler/optimizers/BUILD | 23 +-
 tensorflow/core/kernels/BUILD | 2 +
 tensorflow/core/kernels/cwise_op_abs.cc | 2 -
 tensorflow/core/kernels/cwise_op_clip.cc | 225 ++++++++++++
 tensorflow/core/kernels/cwise_op_clip.h | 61 ++++
 tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc | 134 +++++++
 tensorflow/core/kernels/maxpooling_op.cc | 93 ++++-
 tensorflow/core/kernels/segment_reduction_ops.h | 6 +
 tensorflow/core/ops/dataset_ops.cc | 12 +-
 tensorflow/core/ops/math_ops.cc | 8 +
 tensorflow/core/platform/macros.h | 9 +-
 tensorflow/docs_src/community/documentation.md | 18 +-
 tensorflow/docs_src/extend/adding_an_op.md | 159 +++++----
 .../docs_src/get_started/custom_estimators.md | 2 +-
 tensorflow/docs_src/install/install_c.md | 2 +-
 .../docs_src/performance/performance_guide.md | 8 +-
 tensorflow/docs_src/programmers_guide/debugger.md | 61 ++--
 tensorflow/python/BUILD | 1 +
 tensorflow/python/framework/dtypes.py | 10 +
 tensorflow/python/framework/dtypes_test.py | 5 +
 tensorflow/python/framework/function_test.py | 3 +-
 tensorflow/python/framework/tensor_shape.py | 3 +
 tensorflow/python/framework/tensor_shape_test.py | 5 +
 .../python/keras/_impl/keras/utils/io_utils.py | 14 +-
 tensorflow/python/kernel_tests/clip_ops_test.py | 124 ++++++-
 tensorflow/python/kernel_tests/pooling_ops_test.py | 6 -
 tensorflow/python/ops/clip_ops.py | 30 ++
 tensorflow/python/ops/hidden_ops.txt | 395 +++++++++++++++++++++
 tensorflow/python/util/tf_inspect.py | 43 ++-
 tensorflow/tensorflow.bzl | 53 ++-
 .../tools/api/generator/create_python_api.py | 3 +-
 tensorflow/tools/docker/Dockerfile | 2 +-
 tensorflow/tools/docker/Dockerfile.devel | 2 +
 tensorflow/tools/docker/Dockerfile.devel-gpu | 2 +
 tensorflow/tools/docker/Dockerfile.gpu | 2 +-
 .../docker/notebooks/3_mnist_from_scratch.ipynb | 6 +-
 .../tools/docker/parameterized_docker_build.sh | 4 +-
 tensorflow/tools/docs/BUILD | 2 +-
 tensorflow/tools/docs/build_docs_test.py | 5 -
 tensorflow/tools/docs/generate_lib.py | 19 +-
 tensorflow/tools/docs/generate_lib_test.py | 3 -
 tensorflow/tools/docs/parser.py | 56 +--
 tensorflow/tools/docs/parser_test.py | 80 ++++-
 tensorflow/tools/docs/pretty_docs.py | 12 +-
 tensorflow/tools/docs/py_guide_parser.py | 2 +-
 tensorflow/workspace.bzl | 13 +-
 73 files changed, 1797 insertions(+), 402 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_ClipByValue.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_ClipByValue.pbtxt
 create mode 100644 tensorflow/core/kernels/cwise_op_clip.cc
 create mode 100644 tensorflow/core/kernels/cwise_op_clip.h
 create mode 100644 tensorflow/core/kernels/cwise_op_clip_gpu.cu.cc
 create mode 100644 tensorflow/python/ops/hidden_ops.txt

diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index cfafffd..f2ad16f 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -450,11 +450,12 @@ tf_cc_shared_object(
     linkstatic = 1,
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/core:core_cpu_impl",
         "//tensorflow/core:framework_internal_impl",
+        "//tensorflow/core:gpu_runtime_impl",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
         "//tensorflow/core:lib_internal_impl",
-        "//tensorflow/core:core_cpu_impl",
         "//tensorflow/stream_executor:stream_executor_impl",
-        "//tensorflow/core:gpu_runtime_impl",
     ] + tf_additional_binary_deps(),
 )

diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 6edeb70..50fa95c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -318,6 +318,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/kernels:bounds_check",
     ],
 )

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 0c9fbf3..8e2ee0f 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/public/version.h"

@@ -441,6 +442,9 @@ string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
   }

   auto node_name = [&cycles, &graph](int node_id) {
+    if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
+      return string("(null)");
+    }
     auto* node = graph.FindNodeId(node_id);
     if (node == nullptr) {
       return string("(null)");

diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake
index bec8177..35c2a29 100644
--- a/tensorflow/contrib/cmake/external/grpc.cmake
+++ b/tensorflow/contrib/cmake/external/grpc.cmake
@@ -35,6 +35,7 @@ else()
   set(grpc_STATIC_LIBRARIES
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
+      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
       ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
 endif()

diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index b806799..102bc46 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -201,7 +201,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
     #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
     #stores String-based info such as name, device and type of the op.
     #Unique to every Operation instance.
-    new_node_def = deepcopy(op._node_def)
+    new_node_def = deepcopy(op.node_def)
     #Change the name
     new_node_def.name = new_name

@@ -211,7 +211,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):

     #Make a copy of the op_def too.
     #Its unique to every _type_ of Operation.
-    op_def = deepcopy(op._op_def)
+    op_def = deepcopy(op.op_def)
     #Initialize a new Operation instance
     new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index f58e5ec..637b1dc 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,6 +25,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@Counter
 @@SqlDataset
+@@assert_element_shape
 @@batch_and_drop_remainder
 @@bucket_by_sequence_length
 @@dense_to_sparse_batch
@@ -55,6 +56,7 @@ from __future__ import print_function

 # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops.batching import assert_element_shape
 from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
 from tensorflow.contrib.data.python.ops.batching import map_and_batch

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index a8481dc..b475c9f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -21,6 +21,7 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:script_ops",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
         "//tensorflow/python:tensor_shape",

diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 75482f6..413d873 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -28,8 +28,10 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test

@@ -579,5 +581,73 @@ class PaddedBatchDatasetSerializationTest(
         lambda: build_dataset(seq_lens2), 8)


+class RestructuredDatasetTest(test.TestCase):
+
+  def test_assert_element_shape(self):
+
+    def create_unknown_shape_dataset(x):
+      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+                                           np.zeros((3, 4), dtype=np.int32)),
+                                [x],
+                                [dtypes.float32, dtypes.int32])
+
+    dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+    unknown_shapes = (tensor_shape.TensorShape(None),
+                      tensor_shape.TensorShape(None))
+    self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+    expected_shapes = (tensor_shape.TensorShape(2),
+                       tensor_shape.TensorShape((3, 4)))
+    result = dataset.apply(batching.assert_element_shape(expected_shapes))
+    self.assertEqual(expected_shapes, result.output_shapes)
+
+    iterator = result.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for _ in range(5):
+        sess.run(get_next)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def test_assert_wrong_element_shape(self):
+
+    def create_dataset(_):
+      return (array_ops.ones(2, dtype=dtypes.float32),
+              array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+    dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+    wrong_shapes = (tensor_shape.TensorShape(2),
+                    tensor_shape.TensorShape((3, 10)))
+    with self.assertRaises(ValueError):
+      dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+  def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+    def create_unknown_shape_dataset(x):
+      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
+                                           np.zeros((3, 4), dtype=np.int32)),
+                                [x],
+                                [dtypes.float32, dtypes.int32])
+
+    dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+    unknown_shapes = (tensor_shape.TensorShape(None),
+                      tensor_shape.TensorShape(None))
+    self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+    wrong_shapes = (tensor_shape.TensorShape(2),
+                    tensor_shape.TensorShape((3, 10)))
+    iterator = (
+        dataset.apply(batching.assert_element_shape(wrong_shapes))
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+    with self.test_session() as sess:
+      sess.run(init_op)
+      with self.assertRaises(errors.InvalidArgumentError):
+        sess.run(get_next)
+
+
 if __name__ == "__main__":
   test.main()

diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
index b044ff1..d0cb203 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
@@ -47,6 +47,11 @@ class SequenceDatasetSerializationTest(
     # Skip nothing
     self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)

+  def testInvalidSkip(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'Shape must be rank 0 but is rank 1'):
+      self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
+
   def _build_take_dataset(self, count):
     components = (np.arange(10),)
     return dataset_ops.Dataset.from_tensor_slices(components).take(count)
@@ -69,6 +74,11 @@ class SequenceDatasetSerializationTest(
     # Take nothing
     self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)

+  def testInvalidTake(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'Shape must be rank 0 but is rank 1'):
+      self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
+
   def _build_repeat_dataset(self, count, take_count=3):
     components = (np.arange(10),)
     return dataset_ops.Dataset.from_tensor_slices(components).take(

diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 7c28d1f..0e45908 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -112,6 +112,7 @@ py_library(
     srcs = ["batching.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/python:array_ops",
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",

diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index a212adf..28db949 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.contrib.framework import with_shape
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
@@ -345,6 +346,46 @@ class _RestructuredDataset(dataset_ops.Dataset):
     return self._output_shapes


+def assert_element_shape(expected_shapes):
+  """Assert the shape of this `Dataset`.
+
+  ```python
+  shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+  result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
+  print(result.output_shapes)  # ==> "((16, 256), <unknown>)"
+  ```
+
+  If dataset shapes and `expected_shapes` are fully defined, assert they match.
+  Otherwise, add an assert op that will validate the shapes when tensors are
+  evaluated, and set shapes on tensors, respectively.
+
+  Args:
+    expected_shapes: A nested structure of `tf.TensorShape` objects.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}
+  """
+
+  def _check_shape(*elements):
+    flatten_tensors = nest.flatten(elements)
+    flatten_shapes = nest.flatten(expected_shapes)
+    checked_tensors = [
+        with_shape(shape, tensor)
+        for shape, tensor in zip(flatten_shapes, flatten_tensors)
+    ]
+    return nest.pack_sequence_as(elements, checked_tensors)
+
+  def _apply_fn(dataset):
+    return _RestructuredDataset(
+        dataset.map(_check_shape),
+        dataset.output_types,
+        output_shapes=expected_shapes,
+        output_classes=dataset.output_classes)
+
+  return _apply_fn
+
+
 class _MapAndBatchDataset(dataset_ops.MapDataset):
   """A `Dataset` that maps a function over a batch of elements."""
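A minimal sketch of the new transformation in use, mirroring the tests above. This is an editorial illustration, not part of the patch; it assumes the TF 1.x `tf.contrib.data` surface exported by this change, and the printed shapes are approximate:

```python
import numpy as np
import tensorflow as tf

# Element shapes are statically unknown because py_func erases shape info.
dataset = tf.data.Dataset.range(5).map(
    lambda x: tf.py_func(
        lambda _: (np.ones(2, np.float32), np.zeros((3, 4), np.int32)),
        [x], [tf.float32, tf.int32]))
print(dataset.output_shapes)   # ==> roughly (<unknown>, <unknown>)

# Declare the per-element shapes: they are set statically and also checked
# by an assert op when the tensors are evaluated.
shapes = (tf.TensorShape([2]), tf.TensorShape([3, 4]))
checked = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
print(checked.output_shapes)   # ==> ((2,), (3, 4))
```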
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 9acb6a9..87bf059 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -73,7 +73,7 @@ class DistributedValues(object):

   @property
   def devices(self):
-    return self._index.keys()
+    return list(self._index.keys())

   def __str__(self):
     return "%s:%s" % (self.__class__.__name__, self._index)

diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py
index f182fef..4ef0a66 100644
--- a/tensorflow/contrib/kernel_methods/python/losses.py
+++ b/tensorflow/contrib/kernel_methods/python/losses.py
@@ -43,10 +43,10 @@ def sparse_multiclass_hinge_loss(
   This is a generalization of standard (binary) hinge loss. For a given instance
   with correct label c*, the loss is given by:
-    loss = max_{c != c*} logits_c - logits_{c*} + 1.
+    $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$
   or equivalently
-    loss = max_c { logits_c - logits_{c*} + I_{c != c*} }
-    where I_{c != c*} = 1 if c != c* and 0 otherwise.
+    $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$
+    where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise.

   Args:
     labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to
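To make the reformatted formula concrete, a small worked example in plain NumPy (editorial; the values are chosen arbitrarily):

```python
import numpy as np

# One instance with logits for 3 classes and true label c* = 0.
logits = np.array([2.0, 1.0, 2.5])
c_star = 0

# loss = max_c { logits_c - logits_{c*} + I_{c != c*} }, with I = 1 for c != c*.
indicator = np.ones_like(logits)
indicator[c_star] = 0.0
loss = np.max(logits - logits[c_star] + indicator)
print(loss)  # ==> 1.5, driven by class 2: 2.5 - 2.0 + 1
```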
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py
index 9dc0112..9a721a9 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py
@@ -34,33 +34,31 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper):
   r"""Class that implements Random Fourier Feature Mapping (RFFM) in TensorFlow.

   The RFFM mapping is used to approximate the Gaussian (RBF) kernel:
-  ```
-  exp(-||x-y||_2^2 / (2 * sigma^2))
-  ```
+  $$exp(-||x-y||_2^2 / (2 * \sigma^2))$$

   The implementation of RFFM is based on the following paper:
   "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht.
   (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)

-  The mapping uses a matrix `Omega \in R^{d x D}` and a bias vector `b \in R^D`
-  where `d` is the input dimension (number of dense input features) and `D` is
-  the output dimension (i.e., dimension of the feature space the input is mapped
-  to). Each entry of `Omega` is sampled i.i.d. from a (scaled) Gaussian
-  distribution and each entry of `b` is sampled independently and uniformly from
-  [0, 2 * pi].
-
-  For a single input feature vector x in R^d, its RFFM is defined as:
-  ```
-  sqrt(2/D) * cos(x * Omega + b)
-  ```
-  where `cos` is the element-wise cosine function and `x, b` are represented as
-  row vectors. The aforementioned paper shows that the linear kernel of
-  RFFM-mapped vectors approximates the Gaussian kernel of the initial vectors.
+  The mapping uses a matrix \\(\Omega \in R^{d x D}\\) and a bias vector
+  \\(b \in R^D\\) where \\(d\\) is the input dimension (number of dense input
+  features) and \\(D\\) is the output dimension (i.e., dimension of the feature
+  space the input is mapped to). Each entry of \\(\Omega\\) is sampled i.i.d.
+  from a (scaled) Gaussian distribution and each entry of \\(b\\) is sampled
+  independently and uniformly from [0, \\(2 * \pi\\)].
+
+  For a single input feature vector \\(x \in R^d\\), its RFFM is defined as:
+  $$\sqrt(2/D) * cos(x * \Omega + b)$$
+
+  where \\(cos\\) is the element-wise cosine function and \\(x, b\\) are
+  represented as row vectors. The aforementioned paper shows that the linear
+  kernel of RFFM-mapped vectors approximates the Gaussian kernel of the initial
+  vectors.
   """

   def __init__(self, input_dim, output_dim, stddev=1.0, seed=1, name=None):
-    """Constructs a RandomFourierFeatureMapper instance.
+    r"""Constructs a RandomFourierFeatureMapper instance.

     Args:
       input_dim: The dimension (number of features) of the tensors to be mapped.
@@ -68,11 +66,11 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper):
       stddev: The standard deviation of the Gaussian kernel to be approximated.
         The error of the classifier trained using this approximation is very
         sensitive to this parameter.
-      seed: An integer used to initialize the parameters (`Omega` and `b`) of
-        the mapper. For repeatable sequences across different invocations of the
-        mapper object (for instance, to ensure consistent mapping both at
-        training and eval/inference if these happen in different invocations),
-        set this to the same integer.
+      seed: An integer used to initialize the parameters (\\(\Omega\\) and
+        \\(b\\)) of the mapper. For repeatable sequences across different
+        invocations of the mapper object (for instance, to ensure consistent
+        mapping both at training and eval/inference if these happen in
+        different invocations), set this to the same integer.
       name: name for the mapper object.
     """
     # TODO(sibyl-vie3Poto): Maybe infer input_dim and/or output_dim (if not explicitly

diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 6f4a264..9192918 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -34,7 +34,7 @@ def _inner_product(x, y):
   """Inner product between tensors x and y.

   The input tensors are assumed to be in ROW representation, that is, the method
-  returns x * y^T.
+  returns \\(x * y^T\\).

   Args:
     x: input tensor in row format
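The mapping described in the docstring above can be sketched directly in NumPy (editorial illustration, not part of the patch; `d`, `D`, and `stddev` follow the docstring's definitions, and Omega is drawn with scale 1/stddev so that the linear kernel of the mapped vectors approximates the RBF kernel):

```python
import numpy as np

rng = np.random.RandomState(1)
d, D, stddev = 4, 2000, 1.0

# Omega ~ N(0, 1/stddev^2) entrywise, b ~ Uniform[0, 2*pi], per the docstring.
omega = rng.normal(scale=1.0 / stddev, size=(d, D))
b = rng.uniform(0.0, 2.0 * np.pi, size=D)

def rffm(x):
  return np.sqrt(2.0 / D) * np.cos(x.dot(omega) + b)

x, y = rng.randn(d), rng.randn(d)
exact = np.exp(-np.linalg.norm(x - y) ** 2 / (2.0 * stddev ** 2))
approx = rffm(x).dot(rffm(y))
print(exact, approx)  # the two values should be close for large D
```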
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index e0d9cb5..00b3673 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -19,11 +19,11 @@ Information matrix.
 Suppose one has a model that parameterizes a posterior distribution over 'y'
 given 'x' with parameters 'params', p(y | x, params). Its Fisher Information
 matrix is given by,
-  F(params) = E[ v(x, y, params) v(x, y, params)^T ]
+  $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$

 where,
-  v(x, y, params) = (d / d params) log p(y | x, params)
+  $$v(x, y, params) = (d / d params) log p(y | x, params)$$

 and the expectation is taken with respect to the data's distribution for 'x' and
 the model's posterior distribution for 'y',
@@ -85,7 +85,7 @@ def normalize_damping(damping, num_replications):
 def compute_pi_tracenorm(left_cov, right_cov):
   """Computes the scalar constant pi for Tikhonov regularization/damping.

-  pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )
+  $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$

   See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.

   Args:
@@ -462,14 +462,14 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock):
   Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
   into it. We are interested in Fisher(params)[i, i]. This is,
-    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
-                         = E[ v(x, y, params)[i] ^ 2 ]
+    $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+                           = E[ v(x, y, params)[i] ^ 2 ]$$

   Consider fully connected layer in this model with (unshared) weight matrix
   'w'. For an example 'x' that produces layer inputs 'a' and output
   preactivations 's',
-    v(x, y, w) = vec( a (d loss / d s)^T )
+    $$v(x, y, w) = vec( a (d loss / d s)^T )$$

   This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
   to the layer's parameters 'w'.
@@ -532,14 +532,14 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock):
   Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
   into it. We are interested in Fisher(params)[i, i]. This is,
-    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
-                         = E[ v(x, y, params)[i] ^ 2 ]
+    $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
+                           = E[ v(x, y, params)[i] ^ 2 ]$$

   Consider a convoluational layer in this model with (unshared) filter matrix
   'w'. For an example image 'x' that produces layer inputs 'a' and output
   preactivations 's',
-    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )
+    $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$

   where 'loc' is a single (x, y) location in an image.

@@ -805,12 +805,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
   'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
   this FisherBlock estimates,
-    F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
-                                  E[flat(ds) flat(ds)^T])
+    $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
+                                     E[flat(ds) flat(ds)^T])$$

   where
-    ds = (d / ds) log p(y | x, w)
+    $$ds = (d / ds) log p(y | x, w)$$
     #locations = number of (x, y) locations where 'w' is applied.

   where the expectation is taken over all examples and locations and flat()
@@ -1567,7 +1567,7 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,

     if self._option == SeriesFBApproximation.option1:

-      # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G.
+      # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
       L_A, psi_A = self._input_factor.get_option1quants(
           self._input_damping_func)
       L_G, psi_G = self._output_factor.get_option1quants(
@@ -1581,33 +1581,33 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
         T = self._num_timesteps
         return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

-      # Y = gamma( psi_G*psi_A^T ) (computed element-wise)
+      # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
       # Even though Y is Z-independent we are recomputing it from the psi's
       # each since Y depends on both A and G quantities, and it is relatively
       # cheap to compute.
       Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)

-      # Z = L_G^T * Z * L_A
+      # \\(Z = L_G^T * Z * L_A\\)
       # This is equivalent to the following computation from the original
       # pseudo-code:
-      # Z = G0^(-1/2) * Z * A0^(-1/2)
-      # Z = U_G^T * Z * U_A
+      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+      # \\(Z = U_G^T * Z * U_A\\)
       Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)

-      # Z = Z .* Y
+      # \\(Z = Z .* Y\\)
       Z *= Y

-      # Z = L_G * Z * L_A^T
+      # \\(Z = L_G * Z * L_A^T\\)
       # This is equivalent to the following computation from the original
       # pseudo-code:
-      # Z = U_G * Z * U_A^T
-      # Z = G0^(-1/2) * Z * A0^(-1/2)
+      # \\(Z = U_G * Z * U_A^T\\)
+      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
       Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))

     elif self._option == SeriesFBApproximation.option2:

-      # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1),
-      # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G.
+      # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
+      # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
       P_A, K_A, mu_A = self._input_factor.get_option2quants(
           self._input_damping_func)
       P_G, K_G, mu_G = self._output_factor.get_option2quants(
@@ -1616,26 +1616,26 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,

       # Our approach differs superficially from the pseudo-code in the paper
       # in order to reduce the total number of matrix-matrix multiplies.
       # In particular, the first three computations in the pseudo code are
-      # Z = G0^(-1/2) * Z * A0^(-1/2)
-      # Z = Z - hPsi_G^T * Z * hPsi_A
-      # Z = E_G^T * Z * E_A
-      # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that
-      # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2)
+      # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
+      # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
+      # \\(Z = E_G^T * Z * E_A\\)
+      # Noting that \\(hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
+      # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
       # the entire computation can be written as
-      # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
-      #     - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A
-      #   = E_G^T * (G0^(-1/2) * Z * A0^(-1/2)
-      #     - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A
-      #   = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A
-      #     - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A
-      #   = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A
+      # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+      # \\(    - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
+      # \\(  = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
+      # \\(    - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
+      # \\(  = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
+      # \\(    - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
+      # \\(  = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
       # This final expression is computed by the following two lines:
-      # Z = Z - P_G * Z * P_A^T
+      # \\(Z = Z - P_G * Z * P_A^T\\)
       Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))

-      # Z = K_G^T * Z * K_A
+      # \\(Z = K_G^T * Z * K_A\\)
       Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)

-      # Z = Z ./ (1*1^T - mu_G*mu_A^T)
+      # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
       # Be careful with the outer product. We don't want to accidentally
       # make it an inner-product instead.
       tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
@@ -1646,13 +1646,13 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
       # We now perform the transpose/reverse version of the operations
       # derived above, whose derivation from the original pseudo-code is
       # analgous.
-      # Z = K_G * Z * K_A^T
+      # \\(Z = K_G * Z * K_A^T\\)
       Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))

-      # Z = Z - P_G^T * Z * P_A
+      # \\(Z = Z - P_G^T * Z * P_A\\)
       Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)

-      # Z = normalize (1/E[T]) * Z
+      # \\(Z = normalize (1/E[T]) * Z\\)
       # Note that this normalization is done because we compute the statistics
       # by averaging, not summing, over time. (And the gradient is presumably
       # summed over time, not averaged, and thus their scales are different.)
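A small numeric sketch of the `compute_pi_tracenorm` formula documented above (editorial; a plain NumPy stand-in, not the KFAC implementation):

```python
import numpy as np

# pi = sqrt((trace(A) / dim(A)) / (trace(B) / dim(B))), per the docstring.
def compute_pi_tracenorm(left_cov, right_cov):
  left = np.trace(left_cov) / left_cov.shape[0]
  right = np.trace(right_cov) / right_cov.shape[0]
  return np.sqrt(left / right)

A = np.diag([2.0, 4.0])        # trace 6, dim 2
B = np.diag([1.0, 1.0, 1.0])   # trace 3, dim 3
print(compute_pi_tracenorm(A, B))  # ==> sqrt(3.0) ~= 1.732
```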
diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh
index 4a9023f..9f398f4 100755
--- a/tensorflow/contrib/lite/build_ios_universal_lib.sh
+++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh
@@ -19,11 +19,16 @@ set -e
 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 cd "$SCRIPT_DIR/../../.."

-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8
+make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \
+$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a
+make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \
+$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a
+make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \
+$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a
+make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \
+$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a
+make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \
+$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a

 lipo \
 tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \

diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 81f05e7..9c8ae48 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -63,6 +63,8 @@ def _safe_div(numerator, denominator, name):
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the '
+            'order of the labels and predictions arguments has been switched.')
 def streaming_true_positives(predictions,
                              labels,
                              weights=None,
@@ -107,6 +109,8 @@ def streaming_true_positives(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.true_negatives. Note that the '
+            'order of the labels and predictions arguments has been switched.')
 def streaming_true_negatives(predictions,
                              labels,
                              weights=None,
@@ -151,6 +155,8 @@ def streaming_true_negatives(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.false_positives. Note that the '
+            'order of the labels and predictions arguments has been switched.')
 def streaming_false_positives(predictions,
                               labels,
                               weights=None,
@@ -195,6 +201,8 @@ def streaming_false_positives(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.false_negatives. Note that the '
+            'order of the labels and predictions arguments has been switched.')
 def streaming_false_negatives(predictions,
                               labels,
                               weights=None,
@@ -238,6 +246,7 @@ def streaming_false_negatives(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.mean')
 def streaming_mean(values,
                    weights=None,
                    metrics_collections=None,
@@ -287,6 +296,7 @@ def streaming_mean(values,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.mean_tensor')
 def streaming_mean_tensor(values,
                           weights=None,
                           metrics_collections=None,
@@ -340,9 +350,8 @@ def streaming_mean_tensor(values,
       name=name)


-@deprecated(None,
-            'Please switch to tf.metrics.accuracy. Note that the order of the '
-            'labels and predictions arguments has been switched.')
+@deprecated(None, 'Please switch to tf.metrics.accuracy. Note that the order '
+            'of the labels and predictions arguments has been switched.')
 def streaming_accuracy(predictions,
                        labels,
                        weights=None,
@@ -400,6 +409,8 @@ def streaming_accuracy(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.precision. Note that the order '
+            'of the labels and predictions arguments has been switched.')
 def streaming_precision(predictions,
                         labels,
                         weights=None,
@@ -456,6 +467,8 @@ def streaming_precision(predictions,
       name=name)


+@deprecated(None, 'Please switch to tf.metrics.recall. Note that the order '
+            'of the labels and predictions arguments has been switched.')
 def streaming_recall(predictions,
                      labels,
                      weights=None,
@@ -975,8 +988,8 @@ def streaming_curve_points(labels=None,
   return points, update_op


-@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
-            'labels and predictions arguments has been switched.')
+@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of '
+            'the labels and predictions arguments has been switched.')
 def streaming_auc(predictions,
                   labels,
                   weights=None,
@@ -1797,9 +1810,9 @@ def streaming_sensitivity_at_specificity(predictions,
       name=name)


-@deprecated(
-    None, 'Please switch to tf.metrics.precision_at_thresholds. Note that the '
-    'order of the labels and predictions arguments has been switched.')
+@deprecated(None,
+            'Please switch to tf.metrics.precision_at_thresholds. Note that '
+            'the order of the labels and predictions arguments has been switched.')
 def streaming_precision_at_thresholds(predictions,
                                       labels,
                                       thresholds,
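All of the deprecation messages added above describe the same mechanical migration; a minimal before/after sketch (editorial, assuming TF 1.x):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 1, 1, 0])

# Deprecated contrib API: predictions first, labels second.
acc_old, update_old = tf.contrib.metrics.streaming_accuracy(
    predictions, labels)

# Core replacement: note the switched argument order (labels first).
acc_new, update_new = tf.metrics.accuracy(labels, predictions)
```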
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 2f6ae9f..b12e2cd 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2891,7 +2891,7 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
       output_size = weight.get_shape().as_list()[1]
       g = vs.get_variable(name, [output_size], dtype=weight.dtype)

-      return nn_impl.l2_normalize(weight, dim=0) * g
+      return nn_impl.l2_normalize(weight, axis=0) * g

   def _linear(self,
               args,

diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 9e0d695..f0f143d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -610,8 +610,8 @@ def monotonic_attention(p_choose_i, previous_attention, mode):
   addition, once an input sequence element is attended to at a given output
   timestep, elements occurring before it cannot be attended to at subsequent
   output timesteps. This function generates attention distributions according
-  to these assumptions. For more information, see ``Online and Linear-Time
-  Attention by Enforcing Monotonic Alignments''.
+  to these assumptions. For more information, see `Online and Linear-Time
+  Attention by Enforcing Monotonic Alignments`.

   Args:
     p_choose_i: Probability of choosing input sequence/memory element i. Should

diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py
index 19d213f..7bc726f 100644
--- a/tensorflow/contrib/sparsemax/__init__.py
+++ b/tensorflow/contrib/sparsemax/__init__.py
@@ -14,7 +14,7 @@
 # ==============================================================================
 """Module that implements sparsemax and sparsemax loss, see [1].

-[1] https://arxiv.org/abs/1602.02068
+[1]: https://arxiv.org/abs/1602.02068

 ## Sparsemax

diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index 890ca20..e617af2 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -31,7 +31,7 @@ def sparsemax(logits, name=None):
   """Computes sparsemax activations [1].

   For each batch `i` and class `j` we have
-    sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)
+    $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$

   [1]: https://arxiv.org/abs/1602.02068
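For reference, an editorial NumPy sketch of the threshold tau appearing in the formula above, following the construction in the cited paper [1] rather than the TF kernel itself:

```python
import numpy as np

def sparsemax(z):
  # Sort in decreasing order and find the support size k(z): the largest k
  # with 1 + k * z_sorted[k-1] > cumsum(z_sorted)[k-1].
  z_sorted = np.sort(z)[::-1]
  cumsum = np.cumsum(z_sorted)
  k = np.arange(1, len(z) + 1)
  support = k[1 + k * z_sorted > cumsum]
  k_z = support[-1]
  # tau(z) = (sum of the k(z) largest entries - 1) / k(z).
  tau = (cumsum[k_z - 1] - 1.0) / k_z
  return np.maximum(z - tau, 0.0)

print(sparsemax(np.array([2.0, 1.0, 0.1])))  # ==> [1. 0. 0.], sparse, sums to 1
```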
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index ff8cc63..b412b29 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -405,7 +405,13 @@ tensorflow::Status ConvertGraphDefToTensorRT(
         max_mem_per_engine, static_graph_properties, &output_edge_map,
         precision_mode);
     if (precision_mode == INT8MODE) {
-      TF_RETURN_IF_ERROR(GetCalibNode(&p));
+      tensorflow::Status status = GetCalibNode(&p);
+      if (status != tensorflow::Status::OK()) {
+        LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
+                     << " due to: \"" << status.ToString()
+                     << "\" SKIPPING......( " << subgraph_node_names.size()
+                     << " nodes)";
+      }
     } else {
       tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
       if (status != tensorflow::Status::OK()) {
@@ -414,8 +420,8 @@ tensorflow::Status ConvertGraphDefToTensorRT(
                      << "\" SKIPPING......( " << subgraph_node_names.size()
                      << " nodes)";
       }
-      count++;
     }
+    count++;
   }
   graph.ToGraphDef(new_graph_def);
   return tensorflow::Status::OK();

diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index e920a79..b81ae9d 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -443,7 +443,9 @@ class Converter {
    *   2) Control dependency inputs contain caret at the beginning and we
    *      remove this and annotate the edge as a control dependency.
    ************************************************************************/
-    string name = input_name[0] == '^' ? input_name.substr(1) : input_name;
+    // skip control nodes
+    if (input_name[0] == '^') continue;
+    string name = input_name;
     auto first = name.find_first_of(':');
     if (first != string::npos && first + 2 == name.size() &&
         name[first + 1] == '0')
@@ -2262,6 +2264,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
   auto ws = new tensorflow::tensorrt::TRTWeightStore();
   TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
   Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);
+
   std::vector<string> input_names;
   std::vector<tensorflow::DataType> input_dtypes;
   for (const std::pair<int, int>& input : s.input_inds) {
@@ -2270,20 +2273,41 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
     int output_idx = input.second;
     tensorflow::Node* node = s.graph.FindNodeId(node_id);
     auto node_name = node->name();
-    input_names.push_back(node_name);  // insert original node name without port
-    // TODO(jie): alternative :)
-    if (!s.graph_properties.HasOutputProperties(node_name))
+    // input_names should use the node name in the graph
+    // here it should be the input tensor name -> matching the binding
+    // insert original node name without port
+    auto tensor_name = node_name;
+    if (output_idx != 0) {
+      tensor_name = StrCat(tensor_name, ":", output_idx);
+    }
+
+    VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
+            << " idx: " << output_idx;
+
+    auto shape_inference_node_name = node_name;
+    auto shape_inference_output_idx = output_idx;
+    // rewire the shape inference to original node in the graph
+    if (s.output_edge_map->count(tensor_name)) {
+      shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
+      shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
+    }
+    if (shape_inference_output_idx < 0) continue;
+    VLOG(2) << "shapeinference name: " << shape_inference_node_name
+            << " idx: " << shape_inference_output_idx;
+
+    if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
       return tensorflow::errors::Internal("failed to find input node: " +
-                                          node_name);
+                                          shape_inference_node_name);

-    auto op_info_vec = s.graph_properties.GetOutputProperties(node_name);
-    if (static_cast<int>(op_info_vec.size()) < output_idx)
+    auto op_info_vec =
+        s.graph_properties.GetOutputProperties(shape_inference_node_name);
+    if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
       return tensorflow::errors::Internal(
-          "accessing output index of: ", output_idx, ", at node: ", node_name,
-          "with output entry from shape_map: ", op_info_vec.size());
-
-    auto op_info = op_info_vec.at(output_idx);
+          "accessing output index of: ", shape_inference_output_idx,
+          ", at node: ", shape_inference_node_name,
+          " with output entry from shape_map: ", op_info_vec.size());
+    auto op_info = op_info_vec.at(shape_inference_output_idx);
     tensorflow::DataType tf_dtype = op_info.dtype();
     input_dtypes.push_back(tf_dtype);

@@ -2294,16 +2318,23 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
               << "' failed";
       return type_status;
     }
-    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

     VLOG(2) << "accessing output index of: " << output_idx
             << ", at node: " << node_name
             << "with output entry from shape_map: " << op_info_vec.size();
-
     // TODO(ben,jie): update TRT input format/dimension
     nvinfer1::DimsCHW input_dim_psuedo_chw;
     for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;

+    // TODO(jie): TRT 3.x only supports 4-dimensional input tensor.
+    //            update the code once TRT 4.0 comes out.
+    if (op_info.shape().dim_size() != 4) {
+      string err_str = "Require 4 dimensional input.";
+      StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
+                shape_inference_node_name);
+      return tensorflow::errors::Unimplemented(err_str);
+    }
+
     for (int i = 1; i < op_info.shape().dim_size(); i++) {
       VLOG(2) << "dimension: " << i
               << " , size: " << op_info.shape().dim(i).size();
@@ -2312,8 +2343,11 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {

     // TODO(ben,jie): proper way to restore input tensor name?
     auto input_tensor_name = node_name;
-    if (output_idx != 0) input_tensor_name = StrCat(node_name, ":", output_idx);
+    if (output_idx != 0) {
+      input_tensor_name = StrCat(node_name, ":", output_idx);
+    }

+    input_names.push_back(input_tensor_name);
     nvinfer1::ITensor* input_tensor = converter.network()->addInput(
         input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);

@@ -2377,11 +2411,13 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
     tensor->setType(trt_dtype);
   }

-  VLOG(2) << "finished output";
+  VLOG(2) << "Finished processing outputs";

   // Build the engine
   op_res->builder_->setMaxBatchSize(s.max_batch_size);
   op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
+  VLOG(0) << "Max batch size= " << s.max_batch_size
+          << " max workspace size= " << s.max_workspace_size_bytes;

   // Build the TRT op
   // TODO(sami,ben,jie): proper naming!
@@ -2475,7 +2511,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
   std::vector<string> input_names;
   std::vector<tensorflow::DataType> input_dtypes;
   for (const std::pair<int, int>& input : s.input_inds) {
-    VLOG(2) << "parsing input!!!!!";
+    VLOG(2) << "parsing input. Node id= " << input.first;
     int node_id = input.first;
     int output_idx = input.second;
     tensorflow::Node* node = s.graph.FindNodeId(node_id);

diff --git a/tensorflow/core/api_def/base_api/api_def_ClipByValue.pbtxt b/tensorflow/core/api_def/base_api/api_def_ClipByValue.pbtxt
new file mode 100644
index 0000000..803d897
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ClipByValue.pbtxt
@@ -0,0 +1,36 @@
+op {
+  graph_op_name: "ClipByValue"
+  in_arg {
+    name: "t"
+    description: <
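The `ClipByValue` op registered above (with kernels in the new `cwise_op_clip*` files) sits behind the existing `tf.clip_by_value` Python API, which this patch also touches in `clip_ops.py`. A minimal usage sketch (editorial, TF 1.x):

```python
import tensorflow as tf

# Values are clamped element-wise to the closed interval
# [clip_value_min, clip_value_max].
t = tf.constant([-2.0, 0.5, 3.0])
clipped = tf.clip_by_value(t, clip_value_min=-1.0, clip_value_max=1.0)

with tf.Session() as sess:
  print(sess.run(clipped))  # ==> [-1.  0.5  1. ]
```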