const TensorShape &input_shape) {
#define BROADCAST_SHAPE(broadcast, reshape, NDIMS, input_shape, output_shape) \
for (int i = 0; i < NDIMS; i++) { \
- OP_REQUIRES(ctx, (broadcast[i] % reshape[i] == 0), \
- errors::InvalidArgument("invalid shape to broadcast from ", \
- input_shape.DebugString(), " to ", \
- output_shape.DebugString())); \
- broadcast[i] = broadcast[i] / reshape[i]; \
+ if (reshape[i] != broadcast[i]) { \
+ OP_REQUIRES(ctx, \
+ ((reshape[i] != 0) && (broadcast[i] % reshape[i] == 0)), \
+ errors::InvalidArgument("invalid shape to broadcast from ", \
+ input_shape.DebugString(), " to ", \
+ output_shape.DebugString())); \
+ broadcast[i] = broadcast[i] / reshape[i]; \
+ } else { \
+ broadcast[i] = 1; \
+ } \
}
+ if (output_shape.num_elements() == 0) {
+ return;
+ }
+ if (output_shape == input_shape) {
+ output_tensor.flat<T>().device(d) = input_tensor.flat<T>();
+ return;
+ }
+
switch (output_shape.dims()) {
+ case 0: {
+ if (input_shape.dims() > 0) {
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ output_tensor.scalar<T>().device(d) = input_tensor.scalar<T>();
+ break;
+ }
case 1: {
auto reshape = AsEigenDSizesWithPrefix<1>(input_shape);
auto broadcast = output_shape.AsEigenDSizes<1>();
auto broadcast = output_shape.AsEigenDSizes<4>();
BROADCAST_SHAPE(broadcast, reshape, 4, input_shape, output_shape);
-
auto output = output_tensor.tensor<T, 4>();
switch (input_shape.dims()) {
case 0: {
// so no check needed.
if (i >= in_offset) {
DimensionHandle in_dim = c->Dim(in, i - in_offset);
- if (c->ValueKnown(in_dim)) {
+ if (c->ValueKnown(in_dim) && c->Value(in_dim) != 0) {
if (c->Value(dim) % c->Value(in_dim) != 0) {
return errors::InvalidArgument(
"Cannot broadcast a tensor with shape ", c->DebugString(in),
tags = ["no_windows"],
)
+cuda_py_test(
+ name = "reduce_benchmark_test",
+ srcs = ["reduce_benchmark_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_benchmark",
+ ],
+)
+
tf_py_test(
name = "bincount_op_test",
size = "small",
--- /dev/null
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple benchmarks for reductions and their gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+from six.moves import range # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceBenchmarks(test.Benchmark):
+ """Benchmarks for reductions."""
+
+ def _run(self, func, num_iters):
+ # call func to maybe warm up the GPU
+ func()
+ start = time.time()
+ for _ in range(num_iters):
+ func()
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmark_reduce_sum_grad_eager(self):
+ with context.eager_mode():
+ tensor = array_ops.zeros([100, 1000])
+
+ def fn():
+ backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_eager_cpu(self):
+ with context.eager_mode(), ops.device("/cpu:0"):
+ tensor = array_ops.zeros([100, 1000])
+
+ def fn():
+ backprop.gradients_function(math_ops.reduce_sum, [0])(tensor)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_graph(self):
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0)))
+ with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+ tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+ reduction = math_ops.reduce_sum(tensor)
+ grad, = gradients_impl.gradients(reduction, tensor)
+
+ def fn():
+ sess.run(grad.op)
+
+ self._run(fn, 10000)
+
+ def benchmark_reduce_sum_grad_graph_cpu(self):
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0)))
+ with ops.Graph().as_default(), session.Session(config=config) as sess:
+
+ with ops.device("/cpu:0"):
+ tensor = constant_op.constant(np.zeros([100, 1000], dtype=np.float32))
+ reduction = math_ops.reduce_sum(tensor)
+ grad, = gradients_impl.gradients(reduction, tensor)
+
+ def fn():
+ sess.run(grad.op)
+
+ self._run(fn, 10000)
+
+
+if __name__ == "__main__":
+ test.main()
"keep_dims", keep_dims)
if keepdims is None:
keepdims = False
+ input_tensor = ops.convert_to_tensor(input_tensor)
with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
raw_max = reduce_max(
input_tensor,
array_ops.zeros_like(raw_max)))
result = gen_math_ops.log(
reduce_sum(
- gen_math_ops.exp(input_tensor - my_max),
+ gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
axis,
keepdims=keepdims,
reduction_indices=reduction_indices))
if not keepdims:
my_max = array_ops.reshape(my_max, array_ops.shape(result))
- result += my_max
+ result = gen_math_ops.add(result, my_max)
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
"""
# Example:
# cast needed for SparseTensor reductions
+ if context.executing_eagerly():
+ input_shape = input_shape.numpy()
+ axes = axes.numpy()
+ input_shape[axes] = 1
+ return input_shape
+
input_shape = to_int32(input_shape) # [2, 3, 5, 7]
axes = to_int32(axes) # [1, 2]