--- /dev/null
+op {
+ graph_op_name: "TensorListGetItem"
+}
--- /dev/null
+op {
+ graph_op_name: "TensorListPushBackBatch"
+}
--- /dev/null
+op {
+ graph_op_name: "TensorListSetItem"
+}
#endif // GOOGLE_CUDA
+#define REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListPushBackBatch<CPUDevice, T>)
+
+TF_CALL_ALL_TYPES(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU);
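+// The quantized and bfloat16 types below are not covered by
+// TF_CALL_ALL_TYPES, so they are registered explicitly.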
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(quint8);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint8);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(quint16);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint16);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(qint32);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
+
+#undef REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU
+
#define REGISTER_TENSOR_LIST_STACK_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
#undef REGISTER_TENSOR_LIST_STACK_GPU
+#define REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU), \
+ TensorListPushBackBatch<GPUDevice, T>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bfloat16);
+TF_CALL_complex64(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+TF_CALL_complex128(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+TF_CALL_int64(REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU);
+REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
+
+#undef REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU
+
#define REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(T) \
REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
.TypeConstraint<T>("element_dtype") \
namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
// Variant compatible type for a list of tensors. This is mutable but instances
// should never be mutated after stored in a variant tensor.
struct TensorList {
TensorList output_list;
const Tensor& t = c->input(0);
output_list.element_dtype = t.dtype();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument(
+ "Tensor must be at least a vector, but saw shape: ",
+ t.shape().DebugString()));
TensorShape output_shape(t.shape());
output_shape.RemoveDim(0);
OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
return Status::OK();
}
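+
+// TensorListPushBackBatch appends one element to each list in a batch of
+// lists.  Input 0 is a vector of DT_VARIANT TensorList handles and input 1 is
+// a tensor whose leading dimension equals the number of handles; the b-th
+// slice of the tensor is appended to the b-th list.  When possible, the input
+// handle tensor is forwarded to the output to avoid copying the variant
+// vector.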
+template <typename Device, typename T>
+class TensorListPushBackBatch : public OpKernel {
+ public:
+ explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ ~TensorListPushBackBatch() override {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& input = c->input(1);
+ OP_REQUIRES(c, element_dtype_ == input.dtype(),
+ errors::InvalidArgument("Invalid data types; list elements ",
+ DataTypeString(element_dtype_),
+ " but tried to append ",
+ DataTypeString(input.dtype())));
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
+ errors::InvalidArgument(
+ "Expected tensor to be at least a vector, but saw shape: ",
+ input.shape().DebugString()));
+
+ const TensorShape& tls_shape = c->input(0).shape();
+
+ // For purposes of input forwarding, we want the least restrictive
+ // AllocatorAttributes possible. If we need to allocate later,
+ // we'll request the DT_VARIANT be allocated on host.
+ AllocatorAttributes attr;
+
+ std::unique_ptr<Tensor> tls_alias = c->forward_input(
+ 0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
+ DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
+
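+ // forward_input returns nullptr when the input buffer cannot be reused
+ // (e.g. the handle tensor has other readers), so fall back to reading the
+ // un-forwarded input in that case.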
+ const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
+
+ OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
+ errors::InvalidArgument(
+ "Expected input_handles dtype to be Variant, but saw: ",
+ DataTypeString(tls.dtype())));
+ OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
+ errors::InvalidArgument(
+ "Expected input_handles to be a vector, but saw shape: ",
+ tls_shape.DebugString()));
+ const int64 batch_size = tls.NumElements();
+ OP_REQUIRES(c, input.dim_size(0) == batch_size,
+ errors::InvalidArgument(
+ "Expected tensor.shape[0] == input_handles.size, but saw ",
+ input.dim_size(0), " vs. ", batch_size));
+ auto tls_t = tls.vec<Variant>();
+
+ TensorShape input_element_shape = input.shape();
+ input_element_shape.RemoveDim(0);
+ std::vector<const TensorList*> tl_batch;
+ for (int64 b = 0; b < batch_size; ++b) {
+ const TensorList* l = tls_t(b).get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument("Input handle at index ", b,
+ " is not a list. Saw: '",
+ tls_t(b).DebugString(), "'"));
+ OP_REQUIRES(
+ c, l->element_shape.IsCompatibleWith(input_element_shape),
+ errors::InvalidArgument(
+ "Tried to append a tensor with incompatible shape to a "
+ "list at index ",
+ b, ". Op element shape: ", input_element_shape.DebugString(),
+ " list shape: ", l->element_shape.DebugString()));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument(
+ "Invalid data type at index ", b, "; op elements ",
+ DataTypeString(element_dtype_), " but list elements ",
+ DataTypeString(l->element_dtype)));
+ tl_batch.push_back(l);
+ }
+
+ Tensor* result;
+
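+ // If forwarding succeeded, the output aliases the input handle tensor and
+ // the lists can be appended to in place; otherwise allocate a new handle
+ // vector and copy each TensorList into it in the loop below.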
+ if (tls_alias) {
+ result = tls_alias.get();
+ c->set_output(0, *result);
+ } else {
+ // DT_VARIANT tensors are always allocated on host.
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(
+ c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
+ }
+
+ if (batch_size == 0) {
+ return;
+ }
+
+ auto input_t = input.flat_outer_dims<T, 2>();
+ auto result_t = result->vec<Variant>();
+
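+ // Append the b-th row of the input to the b-th list.  Each appended
+ // element gets its own persistent tensor, filled on this op's device.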
+ for (int64 b = 0; b < batch_size; ++b) {
+ if (!tls_alias) {
+ result_t(b) = *tl_batch[b];
+ }
+ TensorList* output = result_t(b).get<TensorList>();
+ DCHECK(output != nullptr);
+ Tensor* frame;
+ PersistentTensor tmp;
+ OP_REQUIRES_OK(c, c->allocate_persistent(
+ element_dtype_, input_element_shape, &tmp, &frame));
+ if (input_element_shape.num_elements() > 0) {
+ auto frame_t = frame->flat<T>();
+ frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
+ }
+ output->tensors.push_back(std::move(*frame));
+ }
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(DT_VARIANT)));
+ // For purposes of forwarding DT_VARIANT, we want the least
+ // restrictive attr; we already know the input is on host.
AllocatorAttributes attr;
- attr.set_on_host(true);
// Copying is unnecessary if we are the last user of the value
// tensor, we can just adopt the input tensor's buffer instead.
std::unique_ptr<Tensor> input_alias = context->forward_input(
1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
value.shape(),
- std::is_same<Device, CPUDevice>::value ? HOST_MEMORY : DEVICE_MEMORY,
+ DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
attr);
mutex_lock ml(*variable->mu());
!variable->tensor()->shape().IsSameSize(value.shape())) {
PersistentTensor unused;
Tensor* tmp;
+ // Allocation of DT_VARIANT is always on host.
+ attr.set_on_host(true);
OP_REQUIRES_OK(context,
context->allocate_persistent(DT_VARIANT, value.shape(),
&unused, &tmp, attr));
return Status::OK();
});
+REGISTER_OP("TensorListPushBackBatch")
+ .Input("input_handles: variant")
+ .Input("tensor: element_dtype")
+ .Output("output_handles: variant")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
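+ // input_handles must be a vector of list handles and tensor must have
+ // rank >= 1; MergePrefix below additionally requires
+ // tensor.shape[0] == input_handles.shape[0].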
+ shape_inference::ShapeHandle input_handles;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input_handles));
+
+ shape_inference::ShapeHandle tensor;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &tensor));
+
+ TF_RETURN_IF_ERROR(
+ c->MergePrefix(tensor, input_handles, &tensor, &input_handles));
+
+ c->set_output(0, input_handles);
+
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ shape_inference::ShapeHandle s = c->UnknownShape();
+
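+ // If shape/type metadata was recorded for the input handles, verify that
+ // the element dtype matches and forward the element shape to the output
+ // handles; otherwise the output element shape is unknown.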
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data != nullptr && handle_data->size() != 1) {
+ return errors::InvalidArgument(
+ "Trying to push to list with wrong variant data.");
+ }
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument(
+ "Trying to push to list with wrong element dtype. List has type ",
+ DataTypeString(list_shape_type.dtype),
+ " but trying to push element with type ", DataTypeString(t));
+ }
+ shape_inference::ShapeHandle ignored;
+ TF_RETURN_IF_ERROR(c->Merge(s, list_shape_type.shape, &ignored));
+ s = list_shape_type.shape;
+ }
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+ return Status::OK();
+ });
+
REGISTER_OP("TensorListLength")
.Input("input_handle: variant")
.Output("length: int32")
[[1.0, 2.0]] * 4)
self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPushBackBatch(self):
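+ # l0 holds [1.0, 2.0] and l1 holds [-1.0]; stacking them yields a batch of
+ # two list handles.  Pushing [3.0, 4.0] should append 3.0 to l0 and 4.0 to
+ # l1 without touching the original lists.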
+ c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
+ l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+ l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
+ l_batch = array_ops.stack([l0, l1])
+ l_push = list_ops.tensor_list_push_back_batch(l_batch, [3.0, 4.0])
+ l_unstack = array_ops.unstack(l_push)
+ l0_ret = list_ops.tensor_list_stack(l_unstack[0], dtypes.float32)
+ l1_ret = list_ops.tensor_list_stack(l_unstack[1], dtypes.float32)
+ self.assertAllClose([1.0, 2.0, 3.0], self.evaluate(l0_ret))
+ self.assertAllClose([-1.0, 4.0], self.evaluate(l1_ret))
+
+ with ops.control_dependencies([l_push]):
+ l_unstack_orig = array_ops.unstack(l_batch)
+ l0_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[0],
+ dtypes.float32)
+ l1_orig_ret = list_ops.tensor_list_stack(l_unstack_orig[1],
+ dtypes.float32)
+
+ # Check that push_back_batch still works when the input cannot be
+ # aliased, and that it does not modify the original lists.
+ l0_r_v, l1_r_v, l0_orig_v, l1_orig_v = self.evaluate(
+ (l0_ret, l1_ret, l0_orig_ret, l1_orig_ret))
+ self.assertAllClose([1.0, 2.0, 3.0], l0_r_v)
+ self.assertAllClose([-1.0, 4.0], l1_r_v)
+ self.assertAllClose([1.0, 2.0], l0_orig_v)
+ self.assertAllClose([-1.0], l1_orig_v)
+
+ # Pushing back mismatched shapes fails.
+ with self.assertRaises((errors.InvalidArgumentError, ValueError)):
+ self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, []))
+
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "incompatible shape to a list at index 0"):
+ self.evaluate(
+ list_ops.tensor_list_push_back_batch(l_batch, [[3.0], [4.0]]))
+
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Invalid data type at index 0"):
+ self.evaluate(list_ops.tensor_list_push_back_batch(l_batch, [3, 4]))
+
if __name__ == "__main__":
test.main()