From 64bb1de61377f12859a719448b65b452b03047a7 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Mon, 30 Apr 2018 17:11:40 -0700
Subject: [PATCH] Faster reduce_logsumexp (especially in eager) and bugfixes in
 broadcast_to

PiperOrigin-RevId: 194870645
---
 tensorflow/core/kernels/broadcast_to_op.h          |  34 +++++--
 tensorflow/core/ops/array_ops.cc                   |   2 +-
 tensorflow/python/kernel_tests/BUILD               |  16 +++
 .../python/kernel_tests/reduce_benchmark_test.py   | 107 +++++++++++++++++++++
 tensorflow/python/ops/math_ops.py                  |  11 ++-
 5 files changed, 161 insertions(+), 9 deletions(-)
 create mode 100644 tensorflow/python/kernel_tests/reduce_benchmark_test.py

diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h
index 608e9b6..73fdd5d 100644
--- a/tensorflow/core/kernels/broadcast_to_op.h
+++ b/tensorflow/core/kernels/broadcast_to_op.h
@@ -34,14 +34,37 @@ struct BroadcastTo {
                   const TensorShape &input_shape) {
 #define BROADCAST_SHAPE(broadcast, reshape, NDIMS, input_shape, output_shape) \
   for (int i = 0; i < NDIMS; i++) {                                           \
-    OP_REQUIRES(ctx, (broadcast[i] % reshape[i] == 0),                        \
-                errors::InvalidArgument("invalid shape to broadcast from ",   \
-                                        input_shape.DebugString(), " to ",    \
-                                        output_shape.DebugString()));         \
-    broadcast[i] = broadcast[i] / reshape[i];                                 \
+    if (reshape[i] != broadcast[i]) {                                         \
+      OP_REQUIRES(ctx,                                                        \
+                  ((reshape[i] != 0) && (broadcast[i] % reshape[i] == 0)),    \
+                  errors::InvalidArgument("invalid shape to broadcast from ", \
+                                          input_shape.DebugString(), " to ",  \
+                                          output_shape.DebugString()));       \
+      broadcast[i] = broadcast[i] / reshape[i];                               \
+    } else {                                                                  \
+      broadcast[i] = 1;                                                       \
+    }                                                                         \
   }
 
+    if (output_shape.num_elements() == 0) {
+      return;
+    }
+    if (output_shape == input_shape) {
+      output_tensor.flat<T>().device(d) = input_tensor.flat<T>();
+      return;
+    }
+
     switch (output_shape.dims()) {
+      case 0: {
+        if (input_shape.dims() > 0) {
+          ctx->CtxFailure(errors::InvalidArgument(
+              "invalid shape to broadcast from ", input_shape.DebugString(),
+              " to ", output_shape.DebugString()));
+          break;
+        }
+        output_tensor.scalar<T>().device(d) = input_tensor.scalar<T>();
+        break;
+      }
       case 1: {
         auto reshape = AsEigenDSizesWithPrefix<1>(input_shape);
         auto broadcast = output_shape.AsEigenDSizes<1>();
@@ -125,7 +148,6 @@ struct BroadcastTo {
         auto broadcast = output_shape.AsEigenDSizes<4>();
 
         BROADCAST_SHAPE(broadcast, reshape, 4, input_shape, output_shape);
-
         auto output = output_tensor.tensor<T, 4>();
         switch (input_shape.dims()) {
           case 0: {
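The macro change above does two things: dimensions that already match skip the division entirely (they always broadcast with factor 1), and a zero-sized input dimension is rejected explicitly instead of feeding a zero into the modulo. A minimal plain-Python sketch of the per-dimension factor computation the kernel hands to Eigen (hypothetical helper, not a TensorFlow API; both shapes assumed padded to equal rank):

    def broadcast_factors(in_dims, out_dims):
      """How many times each input dimension is tiled to fill the output."""
      factors = []
      for in_d, out_d in zip(in_dims, out_dims):
        if in_d == out_d:
          factors.append(1)  # Equal extents: the new fast path, no division.
        elif in_d != 0 and out_d % in_d == 0:
          factors.append(out_d // in_d)
        else:
          # Covers in_d == 0, which previously reached `out_d % 0`.
          raise ValueError("invalid shape to broadcast from %s to %s" %
                           (in_dims, out_dims))
      return factors

    print(broadcast_factors([3, 1], [3, 4]))  # [1, 4]
    broadcast_factors([0], [5])  # raises instead of dividing by zero

The early-outs for empty outputs and for output_shape == input_shape, plus the new case 0 for scalars, mean the Eigen broadcast expression is only built when real tiling is needed.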
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 88fc038..fce0b93 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -466,7 +466,7 @@ REGISTER_OP("BroadcastTo")
         // so no check needed.
         if (i >= in_offset) {
           DimensionHandle in_dim = c->Dim(in, i - in_offset);
-          if (c->ValueKnown(in_dim)) {
+          if (c->ValueKnown(in_dim) && c->Value(in_dim) != 0) {
             if (c->Value(dim) % c->Value(in_dim) != 0) {
               return errors::InvalidArgument(
                   "Cannot broadcast a tensor with shape ", c->DebugString(in),
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b4ff094..c892b6e 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -112,6 +112,22 @@ cuda_py_test(
     tags = ["no_windows"],
 )
 
+cuda_py_test(
+    name = "reduce_benchmark_test",
+    srcs = ["reduce_benchmark_test.py"],
+    additional_deps = [
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:platform_benchmark",
+    ],
+)
+
 tf_py_test(
     name = "bincount_op_test",
     size = "small",
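The one-line change to BroadcastTo's shape function above applies the same zero-size fix at graph-construction time: a dimension whose value is known to be 0 no longer participates in the divisibility check, leaving empty tensors to the kernel's num_elements() == 0 early-out. A rough sketch of the rule (plain Python, not the C++ InferenceContext API):

    def check_broadcast_dim(out_dim, in_dim):
      """in_dim is None when unknown at graph-construction time."""
      if in_dim is not None and in_dim != 0:
        if out_dim % in_dim != 0:
          raise ValueError("Cannot broadcast %d to %d" % (in_dim, out_dim))

    check_broadcast_dim(6, 3)  # ok, factor 2
    check_broadcast_dim(6, 0)  # ok after this change; used to hit `6 % 0`

The cuda_py_test target above registers the benchmark file that follows.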
diff --git a/tensorflow/python/kernel_tests/reduce_benchmark_test.py b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
new file mode 100644
index 0000000..3a2fb81
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
@@ -0,0 +1,107 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple benchmarks for reductions and their gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+from six.moves import range  # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceBenchmarks(test.Benchmark):
+  """Benchmarks for reductions."""
+
+  def _run(self, func, num_iters):
+    # call func to maybe warm up the GPU
+    func()
+    start = time.time()
+    for _ in range(num_iters):
+      func()
+    end = time.time()
+    mean_us = (end - start) * 1e6 / num_iters
+    self.report_benchmark(
+        iters=num_iters,
+        wall_time=mean_us,
+        extras={"examples_per_sec": num_iters / (end - start)})
+
+  def benchmark_reduce_sum_grad_eager(self):
+    with context.eager_mode():
+      tensor = array_ops.zeros([100, 1000])
+
+      def fn():
+        backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+      self._run(fn, 10000)
+
+  def benchmark_reduce_sum_grad_eager_cpu(self):
+    with context.eager_mode(), ops.device("/cpu:0"):
+      tensor = array_ops.zeros([100, 1000])
+
+      def fn():
+        backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+      self._run(fn, 10000)
+
+  def benchmark_reduce_sum_grad_graph(self):
+    config = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L0)))
+    with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+      tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+      reduction = math_ops.reduce_sum(tensor)
+      grad, = gradients_impl.gradients(reduction, tensor)
+
+      def fn():
+        sess.run(grad.op)
+
+      self._run(fn, 10000)
+
+  def benchmark_reduce_sum_grad_graph_cpu(self):
+    config = config_pb2.ConfigProto(
+        graph_options=config_pb2.GraphOptions(
+            optimizer_options=config_pb2.OptimizerOptions(
+                opt_level=config_pb2.OptimizerOptions.L0)))
+    with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+      with ops.device("/cpu:0"):
+        tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+        reduction = math_ops.reduce_sum(tensor)
+        grad, = gradients_impl.gradients(reduction, tensor)
+
+      def fn():
+        sess.run(grad.op)
+
+      self._run(fn, 10000)
+
+
+if __name__ == "__main__":
+  test.main()
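The benchmarks above are not run as ordinary tests; TensorFlow's benchmark runner picks them up via a --benchmarks regex flag when the file is executed directly, along the lines of:

    python reduce_benchmark_test.py --benchmarks=reduce_sum_grad_eager

Each eager iteration rebuilds and differentiates reduce_sum from scratch (the gradient of a sum is simply a tensor of ones shaped like the input), so the reported mean wall time in microseconds is dominated by exactly the Python and runtime overhead that the math_ops.py changes below are trimming.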
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b937273..5766057 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1757,6 +1757,7 @@ def reduce_logsumexp(input_tensor,
                                                   "keep_dims", keep_dims)
   if keepdims is None:
     keepdims = False
+  input_tensor = ops.convert_to_tensor(input_tensor)
   with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
     raw_max = reduce_max(
         input_tensor,
@@ -1769,13 +1770,13 @@ def reduce_logsumexp(input_tensor,
             array_ops.zeros_like(raw_max)))
     result = gen_math_ops.log(
         reduce_sum(
-            gen_math_ops.exp(input_tensor - my_max),
+            gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
             axis,
             keepdims=keepdims,
             reduction_indices=reduction_indices))
     if not keepdims:
       my_max = array_ops.reshape(my_max, array_ops.shape(result))
-    result += my_max
+    result = gen_math_ops.add(result, my_max)
     return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
 
 
@@ -2475,6 +2476,12 @@ def reduced_shape(input_shape, axes):
   """
   # Example:
   # cast needed for SparseTensor reductions
+  if context.executing_eagerly():
+    input_shape = input_shape.numpy()
+    axes = axes.numpy()
+    input_shape[axes] = 1
+    return input_shape
+
   input_shape = to_int32(input_shape)  # [2, 3, 5, 7]
   axes = to_int32(axes)  # [1, 2]
 
-- 
2.7.4
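For reference, the reduce_logsumexp change keeps the usual max-shift identity, logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x) (zeroed where not finite), and only swaps the Python `-` and `+=` operators for direct gen_math_ops calls to avoid operator-dispatch overhead in eager mode. A NumPy sketch of the identity itself, not the TensorFlow implementation:

    import numpy as np

    def logsumexp(x, axis=None, keepdims=False):
      m = np.max(x, axis=axis, keepdims=True)
      m = np.where(np.isfinite(m), m, 0.0)  # mirror the guard that zeroes non-finite maxima
      out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=keepdims))
      if not keepdims:
        m = np.reshape(m, np.shape(out))
      return out + m

    x = np.array([1000.0, 1000.0])
    print(logsumexp(x))               # ~1000.693, stable
    print(np.log(np.sum(np.exp(x))))  # inf: the naive form overflows

Likewise, the eager branch added to reduced_shape avoids building four ops for a value that is already concrete: with input_shape [2, 3, 5, 7] and axes [1, 2] it simply writes ones into the NumPy array and returns [2, 1, 1, 7].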