[TF] Add TensorListPushBackBatch.
author     Eugene Brevdo <ebrevdo@google.com>
           Thu, 12 Apr 2018 17:35:41 +0000 (10:35 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 12 Apr 2018 17:38:25 +0000 (10:38 -0700)
Also modify code so that aliased forwarding happens whenever possible
for DT_VARIANT objects, both in ResourceVariables and in the new op.

PiperOrigin-RevId: 192632202

tensorflow/core/api_def/base_api/api_def_TensorListGetItem.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_TensorListPushBackBatch.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_TensorListSetItem.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/list_kernels.cc
tensorflow/core/kernels/list_kernels.cu.cc
tensorflow/core/kernels/list_kernels.h
tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/ops/list_ops.cc
tensorflow/python/kernel_tests/list_ops_test.py

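Before the diffs, a minimal usage sketch of the new op, adapted from the test added to list_ops_test.py below. The scalar_shape() helper is assumed to mirror the one defined in that test file (an empty int32 shape tensor); in graph mode the results would be read back through a session, as the test does via self.evaluate.

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import list_ops

    def scalar_shape():
      # Element shape for scalar list elements: an empty int32 shape vector.
      return ops.convert_to_tensor([], dtype=dtypes.int32)

    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())

    # Stack the two list handles into a batch of handles, then push one new
    # scalar onto each list with a single TensorListPushBackBatch op.
    l_batch = array_ops.stack([l0, l1])
    l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])

    # Unstack and read back: the first list now holds [1.0, 2.0, 3.0] and the
    # second holds [-1.0, 4.0].
    l_unstack = array_ops.unstack(l_push)
    l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
    l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
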
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGetItem.pbtxt
new file mode 100644
index 0000000..2c47208
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListGetItem.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "TensorListGetItem"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000..1f33d49
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListPushBackBatch.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "TensorListPushBackBatch"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListSetItem.pbtxt
new file mode 100644
index 0000000..002e2a9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListSetItem.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "TensorListSetItem"
+}
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 9e7786f..d1e481d 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -475,6 +475,22 @@ REGISTER_KERNEL_BUILDER(
 
 #endif  // GOOGLE_CUDA
 
+#define REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(T)               \
+  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")         \
+                              .TypeConstraint<T>("element_dtype") \
+                              .Device(DEVICE_CPU),                \
+                          TensorListPushBackBatch<CPUDevice, T>)
+
+TF_CALL_ALL_TYPES(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(quint8);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint8);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(quint16);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint16);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint32);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
+
+#undef REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU
+
 #define REGISTER_TENSOR_LIST_STACK_CPU(T)                         \
   REGISTER_KERNEL_BUILDER(Name("TensorListStack")                 \
                               .TypeConstraint<T>("element_dtype") \
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 935f892..0ea9362 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -51,6 +51,21 @@ REGISTER_TENSOR_LIST_STACK_GPU(bool);
 
 #undef REGISTER_TENSOR_LIST_STACK_GPU
 
+#define REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(T)               \
+  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")         \
+                              .TypeConstraint<T>("element_dtype") \
+                              .Device(DEVICE_GPU),                \
+                          TensorListPushBackBatch<GPUDevice, T>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bfloat16);
+TF_CALL_complex64(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+TF_CALL_complex128(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+TF_CALL_int64(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
+
+#undef REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU
+
 #define REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(T)                   \
   REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")            \
                               .TypeConstraint<T>("element_dtype") \
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index f3bbf3b..42871c6 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -34,6 +34,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
 // Variant compatible type for a list of tensors. This is mutable but instances
 // should never be mutated after stored in a variant tensor.
 struct TensorList {
@@ -146,6 +148,10 @@ class TensorListFromTensor : public OpKernel {
     TensorList output_list;
     const Tensor& t = c->input(0);
     output_list.element_dtype = t.dtype();
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+                errors::InvalidArgument(
+                    "Tensor must be at least a vector, but saw shape: ",
+                    t.shape().DebugString()));
     TensorShape output_shape(t.shape());
     output_shape.RemoveDim(0);
     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
@@ -267,6 +273,121 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
   return Status::OK();
 }
 
+template <typename Device, typename T>
+class TensorListPushBackBatch : public OpKernel {
+ public:
+  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
+    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+  }
+
+  ~TensorListPushBackBatch() override {}
+
+  void Compute(OpKernelContext* c) override {
+    const Tensor& input = c->input(1);
+    OP_REQUIRES(c, element_dtype_ == input.dtype(),
+                errors::InvalidArgument("Invalid data types; list elements ",
+                                        DataTypeString(element_dtype_),
+                                        " but tried to append ",
+                                        DataTypeString(input.dtype())));
+    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+                errors::InvalidArgument(
+                    "Expected tensor to be at least a vector, but saw shape: ",
+                    input.shape().DebugString()));
+
+    const TensorShape& tls_shape = c->input(0).shape();
+
+    // For purposes of input forwarding, we want the least restrictive
+    // AllocatorAttributes possible.  If we need to allocate later,
+    // we'll request the DT_VARIANT be allocated on host.
+    AllocatorAttributes attr;
+
+    std::unique_ptr<Tensor> tls_alias = c->forward_input(
+        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
+        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
+
+    const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+
+    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
+                errors::InvalidArgument(
+                    "Expected input_handles dtype to be Variant, but saw: ",
+                    DataTypeString(tls.dtype())));
+    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
+                errors::InvalidArgument(
+                    "Expected input_handles to be a vector, but saw shape: ",
+                    tls_shape.DebugString()));
+    const int64 batch_size = tls.NumElements();
+    OP_REQUIRES(c, input.dim_size(0) == batch_size,
+                errors::InvalidArgument(
+                    "Expected tensor.shape[0] == input_handles.size, but saw ",
+                    input.dim_size(0), " vs. ", batch_size));
+    auto tls_t = tls.vec<Variant>();
+
+    TensorShape input_element_shape = input.shape();
+    input_element_shape.RemoveDim(0);
+    std::vector<const TensorList*> tl_batch;
+    for (int64 b = 0; b < batch_size; ++b) {
+      const TensorList* l = tls_t(b).get<TensorList>();
+      OP_REQUIRES(c, l != nullptr,
+                  errors::InvalidArgument("Input handle at index ", b,
+                                          " is not a list. Saw: '",
+                                          tls_t(b).DebugString(), "'"));
+      OP_REQUIRES(
+          c, l->element_shape.IsCompatibleWith(input_element_shape),
+          errors::InvalidArgument(
+              "Tried to append a tensor with incompatible shape to a "
+              "list at index ",
+              b, ". Op element shape: ", input_element_shape.DebugString(),
+              " list shape: ", l->element_shape.DebugString()));
+      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+                  errors::InvalidArgument(
+                      "Invalid data type at index ", b, "; op elements ",
+                      DataTypeString(element_dtype_), " but list elements ",
+                      DataTypeString(l->element_dtype)));
+      tl_batch.push_back(l);
+    }
+
+    Tensor* result;
+
+    if (tls_alias) {
+      result = tls_alias.get();
+      c->set_output(0, *result);
+    } else {
+      // DT_VARIANT tensors always allocated on host.
+      AllocatorAttributes attr;
+      attr.set_on_host(true);
+      OP_REQUIRES_OK(
+          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
+    }
+
+    if (batch_size == 0) {
+      return;
+    }
+
+    auto input_t = input.flat_outer_dims<T, 2>();
+    auto result_t = result->vec<Variant>();
+
+    for (int64 b = 0; b < batch_size; ++b) {
+      if (!tls_alias) {
+        result_t(b) = *tl_batch[b];
+      }
+      TensorList* output = result_t(b).get<TensorList>();
+      DCHECK(output != nullptr);
+      Tensor* frame;
+      PersistentTensor tmp;
+      OP_REQUIRES_OK(c, c->allocate_persistent(
+                            element_dtype_, input_element_shape, &tmp, &frame));
+      if (input_element_shape.num_elements() > 0) {
+        auto frame_t = frame->flat<T>();
+        frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
+      }
+      output->tensors.push_back(std::move(*frame));
+    }
+  }
+
+ private:
+  DataType element_dtype_;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
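A note on the alias-or-copy logic in TensorListPushBackBatch::Compute above: when forward_input succeeds, the output reuses the input's variant buffer and the per-list copies are skipped; when it fails (for example because another op still consumes the input handles), each TensorList is first copied into the output before the new frame is appended. Either way, a caller that still reads the original handles should see the un-pushed lists, which is what the new test checks under a control dependency. Continuing the earlier sketch:

    # Reading the original batch after the push still yields the un-pushed
    # lists, whether or not the kernel managed to alias the input internally.
    l_orig = array_ops.unstack(l_batch)
    l0_orig = list_ops.tensor_list_stack(l_orig[0], dtypes.float32)  # [1.0, 2.0]
    l1_orig = list_ops.tensor_list_stack(l_orig[1], dtypes.float32)  # [-1.0]
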
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 7250420..916869f 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -306,8 +306,9 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
                     DataTypeString(variable->tensor()->dtype()), " got ",
                     DataTypeString(DT_VARIANT)));
 
+    // For purposes of forwarding DT_VARIANT, we want the least
+    // restrictive attr; we already know the input is on host.
     AllocatorAttributes attr;
-    attr.set_on_host(true);
 
     // Copying is unnecessary if we are the last user of the value
     // tensor, we can just adopt the input tensor's buffer instead.
@@ -320,7 +321,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
     std::unique_ptr<Tensor> input_alias = context->forward_input(
         1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
         value.shape(),
-        std::is_same<Device, CPUDevice>::value ? HOST_MEMORY : DEVICE_MEMORY,
+        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
         attr);
 
     mutex_lock ml(*variable->mu());
@@ -337,6 +338,8 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
         !variable->tensor()->shape().IsSameSize(value.shape())) {
       PersistentTensor unused;
       Tensor* tmp;
+      // Allocation of DT_VARIANT is always on host.
+      attr.set_on_host(true);
       OP_REQUIRES_OK(context,
                      context->allocate_persistent(DT_VARIANT, value.shape(),
                                                   &unused, &tmp, attr));
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index c151055..7af7011 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -71,6 +71,50 @@ REGISTER_OP("TensorListPushBack")
       return Status::OK();
     });
 
+REGISTER_OP("TensorListPushBackBatch")
+    .Input("input_handles: variant")
+    .Input("tensor: element_dtype")
+    .Output("output_handles: variant")
+    .Attr("element_dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle input_handles;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input_handles));
+
+      shape_inference::ShapeHandle tensor;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &tensor));
+
+      TF_RETURN_IF_ERROR(
+          c->MergePrefix(tensor, input_handles, &tensor, &input_handles));
+
+      c->set_output(0, input_handles);
+
+      DataType t;
+      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+      shape_inference::ShapeHandle s = c->UnknownShape();
+
+      auto* handle_data = c->input_handle_shapes_and_types(0);
+      if (handle_data != nullptr && handle_data->size() != 1) {
+        return errors::InvalidArgument(
+            "Trying to push to list with wrong variant data.");
+      }
+      if (handle_data != nullptr) {
+        const shape_inference::ShapeAndType& list_shape_type =
+            (*handle_data)[0];
+        if (list_shape_type.dtype != t) {
+          return errors::InvalidArgument(
+              "Trying to push to list with wrong element dtype. List has type ",
+              DataTypeString(list_shape_type.dtype),
+              " but trying to push element with type ", DataTypeString(t));
+        }
+        shape_inference::ShapeHandle ignored;
+        TF_RETURN_IF_ERROR(c->Merge(s, list_shape_type.shape, &ignored));
+        s = list_shape_type.shape;
+      }
+      c->set_output_handle_shapes_and_types(
+          0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+      return Status::OK();
+    });
+
 REGISTER_OP("TensorListLength")
     .Input("input_handle: variant")
     .Output("length: int32")
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 6173a1d..2084599 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -318,6 +318,48 @@ class ListOpsTest(test_util.TensorFlowTestCase):
                 [[1.0, 2.0]] * 4)
     self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testPushBackBatch(self):
+    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
+    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
+    l_batch = array_ops.stack([l0, l1])
+    l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
+    l_unstack = array_ops.unstack(l_push)
+    l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
+    l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
+    self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
+    self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))
+
+    with ops.control_dependencies([l_push]):
+      l_unstack_orig = array_ops.unstack(l_batch)
+      l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
+                                               dtypes.float32)
+      l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
+                                               dtypes.float32)
+
+    # Check that without aliasing, push_back_batch still works; and
+    # that it doesn't modify the input.
+    l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
+        (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
+    self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
+    self.assertAllClose([-1.0, 4.0], l1_r_v)
+    self.assertAllClose([1.0, 2.0], l0_orig_v)
+    self.assertAllClose([-1.0], l1_orig_v)
+
+    # Pushing back mismatched shapes fails.
+    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
+      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))
+
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 "incompatible shape to a list at index 0"):
+      self.evaluate(
+          list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))
+
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 "Invalid data type at index 0"):
+      self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
+
 
 if __name__ == "__main__":
   test.main()