#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
template <typename Device, typename T>
static void SpatialMaxPoolWithArgMaxHelper(
OpKernelContext* context, Tensor* output, Tensor* output_arg_max,
Tensor* input_backprop, const Tensor& tensor_in, const Tensor& out_backprop,
- const PoolParameters& params, const Padding& padding) {
+ const PoolParameters& params) {
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
}
}
- {
+ if (input_backprop != nullptr) {
auto input_backprop_flat = input_backprop->flat<T>();
auto out_arg_max_flat = output_arg_max->flat<int64>();
auto out_backprop_flat = out_backprop.flat<T>();
// Although this check is in the inner loop, it is worth keeping so we
// don't end up with memory corruption. Our benchmarks show that the
// performance impact is quite small.
- CHECK(input_backprop_index >= in_start && input_backprop_index < in_end)
- << "Invalid input backprop index: " << input_backprop_index << ", "
- << in_start << ", " << in_end;
+ CHECK(FastBoundsCheck(input_backprop_index - in_start, in_end - in_start))
+ << "Invalid input backprop index: " << input_backprop_index << ", "
+ << in_start << ", " << in_end;
input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
}
}
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
- out_backprop, params, padding_);
+ out_backprop, params);
}
private:
template <typename Device, typename T>
struct LaunchMaxPoolingWithArgmax;
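+// CPU specialization of the forward launcher: it reuses the shared helper,
+// which fills both the pooled output and the argmax tensor in one pass.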
+template <typename T>
+struct LaunchMaxPoolingWithArgmax<CPUDevice, T> {
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& input, Tensor* output, Tensor* argmax,
+ bool propagate_nans) {
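+ // out_backprop is only read when input_backprop is non-null, so an empty
+ // placeholder tensor is safe to pass here.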
+ Tensor unused;
+ SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
+ context, output, argmax, nullptr, input, unused, params);
+ }
+};
+
template <typename Device, typename T>
class MaxPoolingWithArgmaxOp : public OpKernel {
public:
template <typename Device, typename T>
struct LaunchMaxPoolingGradWithArgmax;
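+// CPU specialization of the backward launcher: each incoming gradient is
+// scatter-added into grad_out at the position recorded in argmax, with the
+// work sharded across batches on the CPU worker threads.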
+template <typename T>
+struct LaunchMaxPoolingGradWithArgmax<CPUDevice, T> {
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& grad_in, const Tensor& argmax,
+ Tensor* grad_out) {
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+
+ auto shard = [&grad_in, &argmax, &grad_out](int64 start, int64 limit) {
+ const int64 batch_size =
+ GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
+ const int64 output_size_per_batch = grad_out->NumElements() / batch_size;
+ const int64 input_size_per_batch = grad_in.NumElements() / batch_size;
+
+ {
+ auto grad_out_flat = grad_out->flat<T>();
+ auto argmax_flat = argmax.flat<int64>();
+ auto grad_in_flat = grad_in.flat<T>();
+
+ const int64 output_start = start * output_size_per_batch;
+ const int64 output_end = limit * output_size_per_batch;
+ // Zero-initialize this shard's slice of the input gradient before
+ // accumulating into it.
+ EigenMatrixMap grad_out_shard(grad_out_flat.data() + output_start, 1,
+ output_end - output_start);
+ grad_out_shard.setConstant(T(0));
+
+ const int64 input_start = start * input_size_per_batch;
+ const int64 input_end = limit * input_size_per_batch;
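+ // Route each pooled gradient back to the input element that produced the
+ // max, using the flattened index stored in argmax.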
+ for (int64 index = input_start; index < input_end; index++) {
+ const int64 grad_out_index = argmax_flat(index);
+ CHECK(grad_out_index >= output_start && grad_out_index < output_end)
+ << "Invalid output gradient index: " << grad_out_index << ", "
+ << output_start << ", " << output_end;
+ grad_out_flat(grad_out_index) += grad_in_flat(index);
+ }
+ }
+ };
+
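+ // Shard the scatter over batches; shard_cost approximates the per-batch
+ // work so the thread pool can balance the load.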
+ const int64 batch_size = GetTensorDim(grad_out->shape(), FORMAT_NHWC, 'N');
+ const int64 shard_cost = grad_out->NumElements() / batch_size;
+ Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
+ shard_cost, shard);
+ }
+};
+
template <typename Device, typename T>
class MaxPoolingGradWithArgmaxOp : public OpKernel {
public:
.HostMemory("ksize") \
.HostMemory("strides") \
.TypeConstraint<T>("T"), \
- MaxPoolingGradGradOp<D##Device, T>);
+ MaxPoolingGradGradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<int64>("Targmax") \
+ .TypeConstraint<T>("T"), \
+ MaxPoolingWithArgmaxOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int64>("Targmax"), \
+ MaxPoolingGradWithArgmaxOp<D##Device, T>);
// Below kernels implemented only for CPU device.
#define REGISTER_CPU_ONLY_POOL_KERNELS(T) \
.HostMemory("strides") \
.TypeConstraint<T>("T"), \
MaxPoolingNoMaskV2Op<GPUDevice, T>); \
- REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<int64>("Targmax") \
- .TypeConstraint<T>("T"), \
- MaxPoolingWithArgmaxOp<GPUDevice, T>); \
- REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<int64>("Targmax"), \
- MaxPoolingGradWithArgmaxOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
cpu_val, gpu_val, half_rtol=0.01, half_atol=0.01)
def testMaxPoolingWithArgmax(self):
- # MaxPoolWithArgMax is implemented only on CUDA.
- if not test.is_gpu_available(cuda_only=True):
- return
tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
with self.test_session(use_gpu=True) as sess:
t = constant_op.constant(tensor_input, shape=[1, 3, 3, 1])
self.assertAllEqual(argmax.ravel(), [0, 1, 3, 5])
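# A minimal sketch (not part of this change) of running the op on CPU now
# that the kernels above are registered; nn_ops.max_pool_with_argmax and the
# shapes below are assumptions for illustration:
#   t = constant_op.constant([1.0, 2.0, 3.0, 4.0], shape=[1, 2, 2, 1])
#   out, argmax = nn_ops.max_pool_with_argmax(
#       t, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
#   # The 2x2 window selects 4.0 at [0, 1, 1, 0], whose flattened index is
#   # ((0 * 2 + 1) * 2 + 1) * 1 + 0 == 3, so argmax == [3].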
def testMaxPoolingGradWithArgmax(self):
- # MaxPoolWithArgMax is implemented only on CUDA.
- if not test.is_gpu_available(cuda_only=True):
- return
orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
tensor_input = [11.0, 12.0, 13.0, 14.0]
tensor_argmax = list(np.array([0, 1, 3, 5], dtype=np.int64))