From 87a3c967973641d3b0d2a16d17add184ed967392 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Mon, 29 Jan 2018 18:30:59 -0800 Subject: [PATCH] TFE: Register a GPU kernel for tfe.py_func. PiperOrigin-RevId: 183765122 --- tensorflow/python/BUILD | 1 + .../python/kernel_tests/py_func_test.py | 86 +++++++++---------- tensorflow/python/lib/core/py_func.cc | 72 +++++++++++----- tensorflow/python/ops/script_ops.py | 38 +++++--- 4 files changed, 120 insertions(+), 77 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a323d5bc39..363ff6fae9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -298,6 +298,7 @@ cc_library( ":safe_ptr", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 92fb68820e..c7181497d8 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -396,66 +396,66 @@ class PyFuncTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testEagerSingleOutputFloat32(self): - a = array_ops.ones((3, 3), dtype=dtypes.float32) - x = array_ops.ones((3, 1), dtype=dtypes.float32) - output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerArrayOutput(self): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - output = script_ops.eager_py_func( - lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32]) - - with self.test_session(): + with test_util.device(use_gpu=True): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + output = script_ops.eager_py_func( + lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32]) ret = self.evaluate(output) - self.assertAllEqual(ret, [[[3], [3], [3]]]) + self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) @test_util.run_in_graph_and_eager_modes() def testEagerReturnNone(self): + with test_util.device(use_gpu=True): + def no_return_value(): + return - def no_return_value(): - return - - output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) - ret = self.evaluate(output) - if context.in_eager_mode(): - self.assertEquals(len(ret), 0) - else: - self.assertIsNone(ret) + output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) + ret = self.evaluate(output) + if context.in_eager_mode(): + self.assertEquals(len(ret), 0) + else: + self.assertIsNone(ret) @test_util.run_in_graph_and_eager_modes() def testEagerPyFuncInDefun(self): + with test_util.device(use_gpu=True): + def wrapper(): + a = array_ops.ones((3, 3), dtype=dtypes.float32) + x = array_ops.ones((3, 1), dtype=dtypes.float32) + return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) - def wrapper(): - a = array_ops.ones((3, 3), dtype=dtypes.int32) - x = array_ops.ones((3, 1), dtype=dtypes.int32) - return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32) - - wrapped = function.defun(wrapper) - ret = self.evaluate(wrapped()) - self.assertAllEqual(ret, [[3], [3], [3]]) + wrapped = function.defun(wrapper) + ret = self.evaluate(wrapped()) + self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) @test_util.run_in_graph_and_eager_modes() def testEagerExceptionHandling(self): - self._testExceptionHandling( - ValueError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - TypeError, errors.InvalidArgumentError, eager=True) - self._testExceptionHandling( - StopIteration, errors.OutOfRangeError, eager=True) - self._testExceptionHandling( - MemoryError, errors.ResourceExhaustedError, eager=True) - self._testExceptionHandling( - NotImplementedError, errors.UnimplementedError, eager=True) - - class WeirdError(Exception): - pass - - self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) + with test_util.device(use_gpu=True): + self._testExceptionHandling( + ValueError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + TypeError, errors.InvalidArgumentError, eager=True) + self._testExceptionHandling( + StopIteration, errors.OutOfRangeError, eager=True) + self._testExceptionHandling( + MemoryError, errors.ResourceExhaustedError, eager=True) + self._testExceptionHandling( + NotImplementedError, errors.UnimplementedError, eager=True) + + class WeirdError(Exception): + pass + + self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) if __name__ == "__main__": diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index d3bfa0ee33..e0422ef80a 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -19,6 +19,7 @@ limitations under the License. #include "numpy/arrayobject.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -53,6 +54,12 @@ struct PyCall { // with this "token". string token; + // The device on which Tensors are stored; only used for EagerPyFunc. + Device* device; + + // True if and only if the op has been placed on a GPU. + bool gpu; + // True if the call is associated with an EagerPyFunc. bool eager; @@ -71,7 +78,12 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - arg = EagerTensorFromHandle(TFE_NewTensorHandle(t)); + if (call->gpu) { + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device)); + } else { + // TFE_TensorHandle assumes that CPU is identified by `nullptr`. + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr)); + } if (arg == nullptr) { return errors::Internal("Unable to procure EagerTensor from Tensor."); } @@ -84,7 +96,8 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst); + *tuple = Py_BuildValue("(sON)", call->token.c_str(), + call->gpu ? Py_True : Py_False, lst); CHECK(*tuple); return Status::OK(); } @@ -150,15 +163,9 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. -Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, - Tensor* output_tensor, - TF_Status* tf_status) { - // TODO(akshayka): Lift the restriction requiring output tensors to - // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU - // tensors, so we should eventually implement a GPU kernel for EagerPyFunc. - *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory( - EagerTensor_Handle(eager_tensor), tf_status); - return StatusFromTF_Status(tf_status); +void ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + Tensor* output_tensor) { + *output_tensor = EagerTensor_Handle(eager_tensor)->t; } // Calls the registered py function through the trampoline. @@ -201,15 +208,23 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } // Process the return values and convert them to TF Tensors. - Status s; + Status s = Status::OK(); if (PyList_Check(result)) { + // `result` is a Python list; if this operation is an `EagerPyFunc`, then + // every item in the list must be an `EagerTensor`; otherwise, every element + // must be a NumPy array. call->out.clear(); for (int i = 0; i < PyList_Size(result); ++i) { Tensor t; if (call->eager) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t, - tf_status.get()); + const PyObject* item = PyList_GetItem(result, i); + if (EagerTensor_CheckExact(item)) { + ExtractTensorFromEagerTensor(item, &t); + } else { + s = errors::FailedPrecondition( + "Expected EagerTensor, found PyObject of type: ", + Py_TYPE(item)->tp_name); + } } else { s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t); } @@ -220,16 +235,15 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { call->out.push_back(t); } } else if (EagerTensor_CheckExact(result) || result == Py_None) { + // result is an `EagerTensor` or `None`. DCHECK(call->eager); Tensor t; if (result != Py_None) { - auto tf_status = tensorflow::make_safe(TF_NewStatus()); - s = ExtractTensorFromEagerTensor(result, &t, tf_status.get()); - if (s.ok()) { - call->out.push_back(t); - } + ExtractTensorFromEagerTensor(result, &t); + call->out.push_back(t); } } else if (PyArray_Check(result)) { + // `result` is a NumPy array. DCHECK(!call->eager); if (!IsSingleNone(result)) { Tensor t; @@ -239,7 +253,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { } } } else { - s = errors::Internal("Unexpected pyobject is returned: ", + s = errors::Internal("Unexpected PyObject was returned: ", Py_TYPE(result)->tp_name); } Py_DECREF(result); @@ -429,12 +443,24 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; + gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; + call.gpu = gpu_; call.eager = eager_; + if (call.eager) { + // Eager's C API uses `Device`, whereas `OpKernelContext` stores a + // `DeviceBase`; attempt to downcast. + call.device = dynamic_cast(ctx->device()); + if (call.device == nullptr) { + ctx->CtxFailureWithWarning( + errors::Internal("Unrecognized device class")); + } + } + for (int i = 0; i < ctx->num_inputs(); ++i) { call.ins.push_back(ctx->input(i)); } @@ -476,6 +502,9 @@ class PyFuncOp : public OpKernel { private: string token_; + // True if and only if this op has been placed on a GPU. + bool gpu_; + // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; @@ -486,5 +515,6 @@ class PyFuncOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp); +REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_GPU), PyFuncOp); } // end namespace tensorflow diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 4b5072fd67..1b9071ee93 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -50,19 +50,21 @@ class EagerFunc(object): self._func = func self._out_dtypes = Tout - def __call__(self, *args, **kwargs): - """Passes args, kwargs to `self._func`, which is executed eagerly.""" + def __call__(self, on_gpu, args): + """Passes `args` to `self._func`, which is executed eagerly.""" with context.eager_mode(): - ret = self._func(*args, **kwargs) + ret = self._func(*args) + maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() if isinstance(ret, (tuple, list)): return [ - ops.convert_to_tensor(x, dtype=dtype) + maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: return ret else: - return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]) + return maybe_copy_to_gpu( + ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])) class FuncRegistry(object): @@ -116,16 +118,29 @@ class FuncRegistry(object): else: return result - def __call__(self, token, args): - """Calls the registered function for `token` with args.""" + def __call__(self, token, on_gpu, args): + """Calls the registered function for `token` with args. + + Args: + token: A key into this `FuncRegistry` identifying which function to call. + on_gpu: A boolean indicating whether or not `token`'s corresponding + operation was placed on GPU; only used if the function registered for + `token` is an `EagerPyFunc`. + args: The arguments to pass to the function registered for `token`. + + Returns: + The output of the function registered for `token`. + + Raises: + ValueError: if no function is registered for `token`. + """ func = self._funcs[token] if func is None: raise ValueError("callback %s is not found" % token) - ret = func(*args) - if isinstance(func, EagerFunc): - return ret + return func(on_gpu, args) else: + ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a # list. if isinstance(ret, six.binary_type): @@ -302,8 +317,5 @@ def py_func(func, inp, Tout, stateful=True, name=None): func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name) -# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be -# differentiable, i.e., the gradient of PyFunc should propagate Nones if the -# eager attribute is not set, and otherwise, it should return the gradient. ops.NotDifferentiable("PyFunc") ops.NotDifferentiable("PyFuncStateless") -- 2.34.1