TFE: Register a GPU kernel for tfe.py_func.
authorAkshay Agrawal <akshayka@google.com>
Tue, 30 Jan 2018 02:30:59 +0000 (18:30 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 30 Jan 2018 02:34:39 +0000 (18:34 -0800)
PiperOrigin-RevId: 183765122

tensorflow/python/BUILD
tensorflow/python/kernel_tests/py_func_test.py
tensorflow/python/lib/core/py_func.cc
tensorflow/python/ops/script_ops.py

index a323d5bc393f4212a04d9cd89e86675207bf95f5..363ff6fae933e7ee76de9b0b2a34bda27d141428 100644 (file)
@@ -298,6 +298,7 @@ cc_library(
         ":safe_ptr",
         "//tensorflow/c:tf_status_helper",
         "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
index 92fb68820e04c3db1385296d91d956134b8ff2d4..c7181497d891f6d35a788c90bf59a0ce5a536328 100644 (file)
@@ -396,66 +396,66 @@ class PyFuncTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerSingleOutputFloat32(self):
-    a = array_ops.ones((3, 3), dtype=dtypes.float32)
-    x = array_ops.ones((3, 1), dtype=dtypes.float32)
-    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
-    with self.test_session():
+    with test_util.device(use_gpu=True):
+      a = array_ops.ones((3, 3), dtype=dtypes.float32)
+      x = array_ops.ones((3, 1), dtype=dtypes.float32)
+      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
       ret = self.evaluate(output)
       self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerArrayOutput(self):
-    a = array_ops.ones((3, 3), dtype=dtypes.int32)
-    x = array_ops.ones((3, 1), dtype=dtypes.int32)
-    output = script_ops.eager_py_func(
-        lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
-
-    with self.test_session():
+    with test_util.device(use_gpu=True):
+      a = array_ops.ones((3, 3), dtype=dtypes.float32)
+      x = array_ops.ones((3, 1), dtype=dtypes.float32)
+      output = script_ops.eager_py_func(
+          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
       ret = self.evaluate(output)
-      self.assertAllEqual(ret, [[[3], [3], [3]]])
+      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerReturnNone(self):
+    with test_util.device(use_gpu=True):
+      def no_return_value():
+        return
 
-    def no_return_value():
-      return
-
-    output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
-    ret = self.evaluate(output)
-    if context.in_eager_mode():
-      self.assertEquals(len(ret), 0)
-    else:
-      self.assertIsNone(ret)
+      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
+      ret = self.evaluate(output)
+      if context.in_eager_mode():
+        self.assertEquals(len(ret), 0)
+      else:
+        self.assertIsNone(ret)
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerPyFuncInDefun(self):
+    with test_util.device(use_gpu=True):
+      def wrapper():
+        a = array_ops.ones((3, 3), dtype=dtypes.float32)
+        x = array_ops.ones((3, 1), dtype=dtypes.float32)
+        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
 
-    def wrapper():
-      a = array_ops.ones((3, 3), dtype=dtypes.int32)
-      x = array_ops.ones((3, 1), dtype=dtypes.int32)
-      return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
-
-    wrapped = function.defun(wrapper)
-    ret = self.evaluate(wrapped())
-    self.assertAllEqual(ret, [[3], [3], [3]])
+      wrapped = function.defun(wrapper)
+      ret = self.evaluate(wrapped())
+      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])
 
   @test_util.run_in_graph_and_eager_modes()
   def testEagerExceptionHandling(self):
-    self._testExceptionHandling(
-        ValueError, errors.InvalidArgumentError, eager=True)
-    self._testExceptionHandling(
-        TypeError, errors.InvalidArgumentError, eager=True)
-    self._testExceptionHandling(
-        StopIteration, errors.OutOfRangeError, eager=True)
-    self._testExceptionHandling(
-        MemoryError, errors.ResourceExhaustedError, eager=True)
-    self._testExceptionHandling(
-        NotImplementedError, errors.UnimplementedError, eager=True)
-
-    class WeirdError(Exception):
-      pass
-
-    self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
+    with test_util.device(use_gpu=True):
+      self._testExceptionHandling(
+          ValueError, errors.InvalidArgumentError, eager=True)
+      self._testExceptionHandling(
+          TypeError, errors.InvalidArgumentError, eager=True)
+      self._testExceptionHandling(
+          StopIteration, errors.OutOfRangeError, eager=True)
+      self._testExceptionHandling(
+          MemoryError, errors.ResourceExhaustedError, eager=True)
+      self._testExceptionHandling(
+          NotImplementedError, errors.UnimplementedError, eager=True)
+
+      class WeirdError(Exception):
+        pass
+
+      self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
 
 
 if __name__ == "__main__":
index d3bfa0ee337d1f606e5e994406969685a2986ab4..e0422ef80add42307268be2743e668eb8c8acb68 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "numpy/arrayobject.h"
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -53,6 +54,12 @@ struct PyCall {
   // with this "token".
   string token;
 
+  // The device on which Tensors are stored; only used for EagerPyFunc.
+  Device* device;
+
+  // True if and only if the op has been placed on a GPU.
+  bool gpu;
+
   // True if the call is associated with an EagerPyFunc.
   bool eager;
 
@@ -71,7 +78,12 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
     PyObject* arg = nullptr;
     const Tensor& t = call->ins[i];
     if (call->eager) {
-      arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
+      if (call->gpu) {
+        arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device));
+      } else {
+        // TFE_TensorHandle assumes that CPU is identified by `nullptr`.
+        arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr));
+      }
       if (arg == nullptr) {
         return errors::Internal("Unable to procure EagerTensor from Tensor.");
       }
@@ -84,7 +96,8 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
     }
     PyList_SetItem(lst, i, arg);
   }
-  *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
+  *tuple = Py_BuildValue("(sON)", call->token.c_str(),
+                         call->gpu ? Py_True : Py_False, lst);
   CHECK(*tuple);
   return Status::OK();
 }
@@ -150,15 +163,9 @@ bool IsSingleNone(PyObject* obj) {
 }
 
 // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
-Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
-                                    Tensor* output_tensor,
-                                    TF_Status* tf_status) {
-  // TODO(akshayka): Lift the restriction requiring output tensors to
-  // lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
-  // tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
-  *output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
-      EagerTensor_Handle(eager_tensor), tf_status);
-  return StatusFromTF_Status(tf_status);
+void ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+                                  Tensor* output_tensor) {
+  *output_tensor = EagerTensor_Handle(eager_tensor)->t;
 }
 
 // Calls the registered py function through the trampoline.
@@ -201,15 +208,23 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
   }
 
   // Process the return values and convert them to TF Tensors.
-  Status s;
+  Status s = Status::OK();
   if (PyList_Check(result)) {
+    // `result` is a Python list; if this operation is an `EagerPyFunc`, then
+    // every item in the list must be an `EagerTensor`; otherwise, every element
+    // must be a NumPy array.
     call->out.clear();
     for (int i = 0; i < PyList_Size(result); ++i) {
       Tensor t;
       if (call->eager) {
-        auto tf_status = tensorflow::make_safe(TF_NewStatus());
-        s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
-                                         tf_status.get());
+        const PyObject* item = PyList_GetItem(result, i);
+        if (EagerTensor_CheckExact(item)) {
+          ExtractTensorFromEagerTensor(item, &t);
+        } else {
+          s = errors::FailedPrecondition(
+              "Expected EagerTensor, found PyObject of type: ",
+              Py_TYPE(item)->tp_name);
+        }
       } else {
         s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
       }
@@ -220,16 +235,15 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
       call->out.push_back(t);
     }
   } else if (EagerTensor_CheckExact(result) || result == Py_None) {
+    // result is an `EagerTensor` or `None`.
     DCHECK(call->eager);
     Tensor t;
     if (result != Py_None) {
-      auto tf_status = tensorflow::make_safe(TF_NewStatus());
-      s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
-      if (s.ok()) {
-        call->out.push_back(t);
-      }
+      ExtractTensorFromEagerTensor(result, &t);
+      call->out.push_back(t);
     }
   } else if (PyArray_Check(result)) {
+    // `result` is a NumPy array.
     DCHECK(!call->eager);
     if (!IsSingleNone(result)) {
       Tensor t;
@@ -239,7 +253,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
       }
     }
   } else {
-    s = errors::Internal("Unexpected pyobject is returned: ",
+    s = errors::Internal("Unexpected PyObject was returned: ",
                          Py_TYPE(result)->tp_name);
   }
   Py_DECREF(result);
@@ -429,12 +443,24 @@ class PyFuncOp : public OpKernel {
   explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
     eager_ = type_string() == "EagerPyFunc";
+    gpu_ = ctx->device_type().type_string() == DEVICE_GPU;
   }
 
   void Compute(OpKernelContext* ctx) override {
     PyCall call;
     call.token = token_;
+    call.gpu = gpu_;
     call.eager = eager_;
+    if (call.eager) {
+      // Eager's C API uses `Device`, whereas `OpKernelContext` stores a
+      // `DeviceBase`; attempt to downcast.
+      call.device = dynamic_cast<Device*>(ctx->device());
+      if (call.device == nullptr) {
+        ctx->CtxFailureWithWarning(
+            errors::Internal("Unrecognized device class"));
+      }
+    }
+
     for (int i = 0; i < ctx->num_inputs(); ++i) {
       call.ins.push_back(ctx->input(i));
     }
@@ -476,6 +502,9 @@ class PyFuncOp : public OpKernel {
  private:
   string token_;
 
+  // True if and only if this op has been placed on a GPU.
+  bool gpu_;
+
   // True if and only if this op should execute the python function eagerly,
   // i.e., if and only if the eager attribute is set.
   bool eager_;
@@ -486,5 +515,6 @@ class PyFuncOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
 REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
 REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
+REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_GPU), PyFuncOp);
 
 }  // end namespace tensorflow
index 4b5072fd6799ae289d3c1a1b2a40878e36604bf4..1b9071ee93c21f8d6bdc9ace11dbf57f3eb3e218 100644 (file)
@@ -50,19 +50,21 @@ class EagerFunc(object):
     self._func = func
     self._out_dtypes = Tout
 
-  def __call__(self, *args, **kwargs):
-    """Passes args, kwargs to `self._func`, which is executed eagerly."""
+  def __call__(self, on_gpu, args):
+    """Passes `args` to `self._func`, which is executed eagerly."""
     with context.eager_mode():
-      ret = self._func(*args, **kwargs)
+      ret = self._func(*args)
+      maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
       if isinstance(ret, (tuple, list)):
         return [
-            ops.convert_to_tensor(x, dtype=dtype)
+            maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype))
             for (x, dtype) in zip(ret, self._out_dtypes)
         ]
       elif ret is None:
         return ret
       else:
-        return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
+        return maybe_copy_to_gpu(
+            ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]))
 
 
 class FuncRegistry(object):
@@ -116,16 +118,29 @@ class FuncRegistry(object):
     else:
       return result
 
-  def __call__(self, token, args):
-    """Calls the registered function for `token` with args."""
+  def __call__(self, token, on_gpu, args):
+    """Calls the registered function for `token` with args.
+
+    Args:
+      token: A key into this `FuncRegistry` identifying which function to call.
+      on_gpu: A boolean indicating whether or not `token`'s corresponding
+        operation was placed on GPU; only used if the function registered for
+        `token` is an `EagerPyFunc`.
+      args: The arguments to pass to the function registered for `token`.
+
+    Returns:
+      The output of the function registered for `token`.
+
+    Raises:
+      ValueError: if no function is registered for `token`.
+    """
     func = self._funcs[token]
     if func is None:
       raise ValueError("callback %s is not found" % token)
-    ret = func(*args)
-
     if isinstance(func, EagerFunc):
-      return ret
+      return func(on_gpu, args)
     else:
+      ret = func(*args)
       # Strings seem to lead to a memory leak here if they're not wrapped in a
       # list.
       if isinstance(ret, six.binary_type):
@@ -302,8 +317,5 @@ def py_func(func, inp, Tout, stateful=True, name=None):
       func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
 
 
-# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
-# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
-# eager attribute is not set, and otherwise, it should return the gradient.
 ops.NotDifferentiable("PyFunc")
 ops.NotDifferentiable("PyFuncStateless")