Improvements to ResourceVariable + Variant code.
authorEugene Brevdo <ebrevdo@google.com>
Sat, 7 Apr 2018 04:00:42 +0000 (21:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 04:03:10 +0000 (21:03 -0700)
* Works in graph + eager modes
* Fixed shape inference
* Updated shape inference + refiner + constant eval code to support static shape tensor of `-1` meaning unknown shape.
* Gather and Scatter for Variants now properly supported.
* Variable copy-on-write for Variants now does a more shallow copy (as Variants are not expected to be updated "in-place" inside a variable; instead Variants will be updated via read-update-write inside a CriticalSection)

PiperOrigin-RevId: 191975898

16 files changed:
tensorflow/contrib/makefile/tf_op_files.txt
tensorflow/core/common_runtime/shape_refiner.cc
tensorflow/core/framework/shape_inference.cc
tensorflow/core/framework/shape_inference.h
tensorflow/core/framework/shape_inference_test.cc
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/dense_update_functor.cc
tensorflow/core/kernels/dense_update_functor.h
tensorflow/core/kernels/gather_functor.h
tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/kernels/scatter_functor.h
tensorflow/core/kernels/training_op_helpers.h
tensorflow/core/ops/list_ops.cc
tensorflow/python/framework/tensor_util.py
tensorflow/python/kernel_tests/list_ops_test.py
tensorflow/python/ops/list_ops.py

index 0bc4c5d..d4c3f2e 100644 (file)
@@ -151,6 +151,7 @@ tensorflow/core/kernels/decode_bmp_op.cc
 tensorflow/core/kernels/depthtospace_op.cc
 tensorflow/core/kernels/data_format_ops.cc
 tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/dense_update_functor.cc
 tensorflow/core/kernels/dense_update_ops.cc
 tensorflow/core/kernels/deep_conv2d.cc
 tensorflow/core/kernels/decode_wav_op.cc
index 1b7e313..06dbe04 100644 (file)
@@ -431,6 +431,32 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
   InferenceContext* src_context = GetContext(input_edge->src());
   if (src_context == nullptr) return errors::Internal("Missing src context");
   ShapeHandle src_shape = src_context->output(input_edge->src_output());
+
+  if (src_context->Value(src_context->Rank(src_shape)) == 0) {
+    Tensor t;
+    bool evaluated = false;
+    TF_RETURN_IF_ERROR(
+        EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
+    if (!evaluated) {
+      return errors::InvalidArgument(
+          "Received a shape scalar with unknown static value.  A static value "
+          "of '-1' is required to represent an unknown shape.");
+    }
+    if (t.dims() == 0) {
+      if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
+        *result = target_context->UnknownShape();
+        return Status::OK();
+      } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
+        *result = target_context->UnknownShape();
+        return Status::OK();
+      }
+    }
+    return errors::InvalidArgument(
+        "Received an invalid shape scalar with a static value that is not "
+        "'-1': ",
+        t.DebugString());
+  }
+
   TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
 
   const string& src_op = input_edge->src()->type_string();
index 54ecaa5..cc1ec47 100644 (file)
@@ -726,6 +726,24 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
   return MakeShape({dim1, dim2});
 }
 
+Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+    int input_idx, ShapeHandle* out) {
+  ShapeHandle input_shape;
+  TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
+
+  requested_input_tensor_as_partial_shape_[input_idx] = true;
+  if (input_idx < input_tensors_as_shapes_.size() &&
+      input_tensors_as_shapes_[input_idx].IsSet() &&
+      RankKnown(input_tensors_as_shapes_[input_idx])) {
+    *out = input_tensors_as_shapes_[input_idx];
+    return Status::OK();
+  }
+
+  return InternalMakeShapeFromTensor(
+      true /* treat_unknown_scalar_tensor_as_unknown_shape */,
+      input_tensor(input_idx), input_shape, out);
+}
+
 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
                                                   ShapeHandle* out) {
   ShapeHandle input_shape;
@@ -739,13 +757,31 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
     return Status::OK();
   }
 
-  return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
+  return InternalMakeShapeFromTensor(
+      false /* treat_unknown_scalar_tensor_as_unknown_shape */,
+      input_tensor(input_idx), input_shape, out);
 }
 
 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
                                              ShapeHandle tensor_shape,
                                              ShapeHandle* out) {
+  return InternalMakeShapeFromTensor(
+      false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
+      out);
+}
+
+Status InferenceContext::InternalMakeShapeFromTensor(
+    bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+    ShapeHandle tensor_shape, ShapeHandle* out) {
+  // Only callers who have set
+  if (!treat_unknown_scalar_tensor_as_unknown_shape) {
+    TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
+  }
   if (t == nullptr) {
+    // This is guarded by the check above.
+    if (Rank(tensor_shape) == 0) {
+      return ReturnUnknownShape(out);
+    }
     // Shape tensor is not known, but if the shape of the shape tensor is then
     // the right number of unknown dims can be created.
     DimensionHandle shape_dim = Dim(tensor_shape, 0);
@@ -759,10 +795,46 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
     return ReturnCreatedShape(dims, out);
   }
 
+  if (t->shape().dims() == 0) {
+    if (t->dtype() == DataType::DT_INT32) {
+      auto flat_t = t->scalar<int32>();
+      if (flat_t() != -1) {
+        *out = nullptr;
+        return errors::InvalidArgument(
+            "Input tensor must be rank 1, or if its rank 0 it must have value "
+            "-1 "
+            "(representing an unknown shape).  Saw value: ",
+            flat_t());
+      }
+      return ReturnUnknownShape(out);
+    } else if (t->dtype() == DataType::DT_INT64) {
+      auto flat_t = t->scalar<int64>();
+      if (flat_t() != -1) {
+        *out = nullptr;
+        return errors::InvalidArgument(
+            "Input tensor must be rank 1, or if its rank 0 it must have value "
+            "-1 "
+            "(representing an unknown shape).  Saw value: ",
+            flat_t());
+      }
+      return ReturnUnknownShape(out);
+    } else {
+      *out = nullptr;
+      return errors::InvalidArgument(
+          "Input tensor must be int32 or int64, but was ",
+          DataTypeString(t->dtype()));
+    }
+  }
+
   if (t->shape().dims() != 1) {
     *out = nullptr;
-    return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
-                                   t->shape().dims());
+    return errors::InvalidArgument(
+        "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
+        ((t->shape().dims() == 0)
+             ? "If it is rank 0 rank 0 it must have statically known value -1 "
+               "(representing an unknown shape). "
+             : " "),
+        "Saw tensor shape ", t->shape().DebugString());
   }
   std::vector<DimensionHandle> dims;
   if (t->dtype() == DataType::DT_INT32) {
index accc587..cdb4bd7 100644 (file)
@@ -463,6 +463,12 @@ class InferenceContext {
   // the input tensor is NULL, then an unknown shape is returned.
   Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
 
+  // Like the function above, but treats scalar values as unknown
+  // shapes.  **NOTE** If the scalar is statically known, its value
+  // must be -1 or an error is returned.
+  Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
+                                                           ShapeHandle* out);
+
   // Returns in <out> a new shape corresponding to <proto>.
   Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
                                  ShapeHandle* out);
@@ -708,6 +714,11 @@ class InferenceContext {
     merged_dims_.clear();
   }
 
+  // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
+  Status InternalMakeShapeFromTensor(
+      bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+      ShapeHandle tensor_shape, ShapeHandle* out);
+
   ShapeManager shape_manager_;
 
   // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
index da103bf..586c38e 100644 (file)
@@ -1081,17 +1081,26 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
   t = ::tensorflow::test::AsTensor<int64>({});
   EXPECT_EQ("[]", create(&t));
 
+  // Test negative scalar
+  t = ::tensorflow::test::AsScalar<int32>(-1);
+  EXPECT_EQ("?", create(&t));
+
   t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
   EXPECT_TRUE(str_util::StrContains(
       create(&t), "Input tensor must be int32 or int64, but was float"));
 
   t = ::tensorflow::test::AsScalar<int32>(1);
+  auto s_scalar = create(&t);
   EXPECT_TRUE(str_util::StrContains(
-      create(&t), "Input tensor must be rank 1, but was rank 0"));
+      s_scalar,
+      "Input tensor must be rank 1, or if its rank 0 it must have value -1"))
+      << s_scalar;
 
   t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
+  auto s_matrix = create(&t);
   EXPECT_TRUE(str_util::StrContains(
-      create(&t), "Input tensor must be rank 1, but was rank 2"));
+      s_matrix, "Input tensor must be rank 1, but was rank 2"))
+      << s_matrix;
 
   // Test negative values for the dims.
   t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
index 783de6a..b931f79 100644 (file)
@@ -1395,6 +1395,7 @@ tf_kernel_library(
     visibility = [":friends"],
     deps = [
         ":bounds_check",
+        ":dense_update_functor",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//third_party/eigen3",
index a878fe9..3ed3794 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/dense_update_functor.h"
 
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
@@ -70,4 +71,59 @@ struct DenseUpdate<CPUDevice, string, ASSIGN> {
 
 }  // namespace functor
 
+#define CPU_DENSE_COPY(T)                                                \
+  case DataTypeToEnum<T>::value: {                                       \
+    functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_;            \
+    copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
+                  from.flat<T>());                                       \
+    break;                                                               \
+  }
+
+#define INSTANTIATE_GET_VARIANT_COPY_FN(DEVICE, TYPE_CALLER, TYPE_DENSE_COPY) \
+  template <>                                                                 \
+  Status VariantCopyFn<DEVICE>(OpKernelContext * context, const Tensor& from, \
+                               Tensor* to) {                                  \
+    PersistentTensor tmp;                                                     \
+    Tensor* tensor;                                                           \
+    AllocatorAttributes attr;                                                 \
+    attr.set_gpu_compatible(true);                                            \
+    attr.set_nic_compatible(true);                                            \
+    TF_RETURN_IF_ERROR(context->allocate_persistent(                          \
+        from.dtype(), from.shape(), &tmp, &tensor, attr));                    \
+    switch (from.dtype()) {                                                   \
+      TYPE_CALLER(TYPE_DENSE_COPY);                                           \
+      default:                                                                \
+        return errors::InvalidArgument(                                       \
+            "VariantCopyFn: Could not perform a deep copy of variant "        \
+            "element of type: ",                                              \
+            DataTypeString(from.dtype()),                                     \
+            " using device: ", context->device()->name());                    \
+    }                                                                         \
+    *to = *tensor;                                                            \
+    return Status::OK();                                                      \
+  }
+
+INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
+
+#if GOOGLE_CUDA
+#define GPU_DENSE_COPY(T)                                                \
+  case DataTypeToEnum<T>::value: {                                       \
+    functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_;            \
+    copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
+                  from.flat<T>());                                       \
+    break;                                                               \
+  }
+#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
+  TF_CALL_GPU_ALL_TYPES(T);                 \
+  TF_CALL_int32(T);                         \
+  TF_CALL_int64(T);
+INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
+                                GPU_DENSE_COPY);
+#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
+#undef GPU_DENSE_COPY
+#endif  // GOOGLE_CUDA
+
+#undef CPU_DENSE_COPY
+#undef INSTANTIATE_GET_VARIANT_COPY_FN
+
 }  // namespace tensorflow
index 4aefe26..240c132 100644 (file)
@@ -19,11 +19,14 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
@@ -89,6 +92,17 @@ struct DenseUpdate<SYCLDevice, T, ASSIGN> {
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // end namespace functor
+
+template <typename Device>
+Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
+
+template <>
+Status VariantCopyFn<CPUDevice>(OpKernelContext* context, const Tensor& from,
+                                Tensor* to);
+template <>
+Status VariantCopyFn<GPUDevice>(OpKernelContext* context, const Tensor& from,
+                                Tensor* to);
+
 }  // end namespace tensorflow
 
 #endif  // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
index 16ccb03..2c6e8bf 100644 (file)
@@ -28,6 +28,7 @@ limitations under the License.
 
 namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
 
@@ -50,7 +51,7 @@ SliceIndex HandleCopies(OpKernelContext* ctx,
   }
   // Compute slice_bytes here so that static knowledge is available
   const size_t slice_bytes = slice_elems * sizeof(T);
-  auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
   mutex mu;
   // Store the value of invalidate index for printing error information, it's a
   // shared variable.
@@ -162,6 +163,16 @@ struct GatherFunctor<CPUDevice, T, Index> {
   }
 };
 
+template <typename Index>
+struct GatherFunctor<GPUDevice, Variant, Index> {
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<Variant, 3>::ConstTensor params,
+                   typename TTypes<Index>::ConstFlat indices,
+                   typename TTypes<Variant, 3>::Tensor out) {
+    return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out);
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
 
index f49a05c..7250420 100644 (file)
@@ -280,64 +280,6 @@ class AssignVariableOp : public OpKernel {
 };
 
 template <typename Device>
-Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
-
-#define CPU_DENSE_COPY(T)                                                \
-  case DataTypeToEnum<T>::value: {                                       \
-    functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_;            \
-    copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
-                  from.flat<T>());                                       \
-    break;                                                               \
-  }
-
-#define INSTANTIATE_GET_VARIANT_COPY_FN(Device, TYPE_CALLER, TYPE_DENSE_COPY) \
-  template <>                                                                 \
-  Status VariantCopyFn<Device>(OpKernelContext * context, const Tensor& from, \
-                               Tensor* to) {                                  \
-    PersistentTensor tmp;                                                     \
-    Tensor* tensor;                                                           \
-    AllocatorAttributes attr;                                                 \
-    attr.set_gpu_compatible(true);                                            \
-    attr.set_nic_compatible(true);                                            \
-    TF_RETURN_IF_ERROR(context->allocate_persistent(                          \
-        from.dtype(), from.shape(), &tmp, &tensor, attr));                    \
-    switch (from.dtype()) {                                                   \
-      TYPE_CALLER(TYPE_DENSE_COPY);                                           \
-      default:                                                                \
-        return errors::InvalidArgument(                                       \
-            "VariantCopyFn: Could not perform a deep copy of variant "        \
-            "element of type: ",                                              \
-            DataTypeString(from.dtype()),                                     \
-            " using device: ", context->device()->name());                    \
-    }                                                                         \
-    *to = *tensor;                                                            \
-    return Status::OK();                                                      \
-  }
-
-INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
-
-#if GOOGLE_CUDA
-#define GPU_DENSE_COPY(T)                                                \
-  case DataTypeToEnum<T>::value: {                                       \
-    functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_;            \
-    copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
-                  from.flat<T>());                                       \
-    break;                                                               \
-  }
-#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
-  TF_CALL_GPU_ALL_TYPES(T);                 \
-  TF_CALL_int32(T);                         \
-  TF_CALL_int64(T);
-INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
-                                GPU_DENSE_COPY);
-#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
-#undef GPU_DENSE_COPY
-#endif  // GOOGLE_CUDA
-
-#undef CPU_DENSE_COPY
-#undef INSTANTIATE_GET_VARIANT_COPY_FN
-
-template <typename Device>
 class AssignVariableOp<Device, Variant> : public OpKernel {
  public:
   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
@@ -370,9 +312,16 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
     // Copying is unnecessary if we are the last user of the value
     // tensor, we can just adopt the input tensor's buffer instead.
     // Note that Variant objects themselves always reside on host.
+    //
+    // We nevertheless want to signal to the runtime that the tensor
+    // should reside in memory of the associated device, as Variant
+    // tensors may be marked as sitting on either CPU or GPU.  This
+    // helps to elide one or more copies.
     std::unique_ptr<Tensor> input_alias = context->forward_input(
         1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
-        value.shape(), HOST_MEMORY, attr);
+        value.shape(),
+        std::is_same<Device, CPUDevice>::value ? HOST_MEMORY : DEVICE_MEMORY,
+        attr);
 
     mutex_lock ml(*variable->mu());
     variable->is_initialized = true;
@@ -396,12 +345,8 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
 
     const auto elements_in = value.flat<Variant>();
     auto elements_out = variable->tensor()->flat<Variant>();
-    auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
-                             std::placeholders::_1, std::placeholders::_2);
     for (int64 i = 0; i < elements_in.size(); ++i) {
-      OP_REQUIRES_OK(context, VariantDeviceCopy(
-                                  VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
-                                  elements_in(i), &elements_out(i), copy_fn));
+      elements_out(i) = elements_in(i);
     }
   }
 
@@ -560,7 +505,14 @@ class ResourceGatherOp : public OpKernel {
     }
 
     Tensor* out = nullptr;
-    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+    Tensor tmp;
+    if (params.dtype() == DT_VARIANT) {
+      tmp = Tensor(DT_VARIANT, result_shape);
+      c->set_output(0, tmp);
+      out = &tmp;
+    } else {
+      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+    }
     if (N > 0) {
       const int64 gather_dim_size = params.dim_size(0);
       int64 inner_size = 1;
@@ -607,6 +559,23 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
 
+// Variant objects themselves sit on CPU, even if they contain data
+// pointing to a device.
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("resource")
+                            .HostMemory("indices")
+                            .TypeConstraint<Variant>("dtype")
+                            .TypeConstraint<int32>("Tindices"),
+                        ResourceGatherOp<GPUDevice, Variant, int32>)
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("resource")
+                            .HostMemory("indices")
+                            .TypeConstraint<Variant>("dtype")
+                            .TypeConstraint<int64>("Tindices"),
+                        ResourceGatherOp<GPUDevice, Variant, int64>)
+
 #endif  // GOOGLE_CUDA
 
 #undef REGISTER_GATHER_CPU
@@ -721,6 +690,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
 
 REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                         scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
+                        scatter_op::UpdateOp::ASSIGN);
 
 // Registers GPU kernels.
 #if GOOGLE_CUDA
@@ -733,6 +704,23 @@ REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
 
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("resource")
+                            .HostMemory("indices")
+                            .TypeConstraint<Variant>("dtype")
+                            .TypeConstraint<int32>("Tindices"),
+                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
+                                                scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("resource")
+                            .HostMemory("indices")
+                            .TypeConstraint<Variant>("dtype")
+                            .TypeConstraint<int64>("Tindices"),
+                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
+                                                scatter_op::UpdateOp::ASSIGN>)
+
 #endif  // GOOGLE_CUDA
 
 #undef REGISTER_SCATTER_ARITHMETIC
index 5266664..ebaa2bd 100644 (file)
@@ -20,8 +20,11 @@ limitations under the License.
 
 #include "third_party/eigen3/Eigen/Core"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -203,9 +206,9 @@ struct ScatterFunctorBase {
     const Index N = static_cast<Index>(indices.size());
     const Index limit = static_cast<Index>(params.dimension(0));
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
       if (!FastBoundsCheck(index, limit)) return i;
       // Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -216,6 +219,42 @@ struct ScatterFunctorBase {
   }
 };
 
+template <typename Device, typename Index>
+struct ScatterFunctorVariantAssignBase {
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<Variant>::Matrix params,
+                   typename TTypes<Variant>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    const Index cols = static_cast<Index>(params.dimension(1));
+    DCHECK_EQ(N, updates.dimension(0));
+    DCHECK_EQ(cols, updates.dimension(1));
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Copy last Ndim-1 dimensions of updates[i] to params[index]
+      for (int j = 0; j < cols; ++j) {
+        const Variant& to_scatter = updates(i, j);
+        params(index, j) = to_scatter;
+      }
+    }
+    return -1;
+  }
+};
+
+template <typename Index>
+struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};
+
+template <typename Index>
+struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};
+
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
@@ -227,9 +266,9 @@ struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
     const Index N = static_cast<Index>(indices.size());
     const Index limit = static_cast<Index>(params.dimension(0));
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
       if (!FastBoundsCheck(index, limit)) return i;
       // Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -252,9 +291,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
     const Index limit = static_cast<Index>(params.dimension(0));
     if (!std::is_same<T, string>::value) {
       for (Index i = 0; i < N; i++) {
-        // Grab the index and check its validity.  An earlier version of the
-        // code checked it and then grabbed it from memory a second time, which
-        // was a security risk since it could have changed in between.
+        // Grab the index and check its validity.  Do this carefully,
+        // to avoid checking the value and grabbing it again from
+        // memory a second time (a security risk since it may change in
+        // between).
         const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
         if (!FastBoundsCheck(index, limit)) return i;
         memmove(params.data() + index * params.dimension(1),
@@ -263,9 +303,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
       }
     } else {
       for (Index i = 0; i < N; i++) {
-        // Grab the index and check its validity.  An earlier version of the
-        // code checked it and then grabbed it from memory a second time, which
-        // was a security risk since it could have changed in between.
+        // Grab the index and check its validity.  Do this carefully,
+        // to avoid checking the value and grabbing it again from
+        // memory a second time (a security risk since it may change in
+        // between).
         const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
         if (!FastBoundsCheck(index, limit)) return i;
         // Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -321,9 +362,9 @@ struct ScatterScalarFunctorBase {
     const Index N = static_cast<Index>(indices.size());
     const Index limit = static_cast<Index>(params.dimension(0));
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
       if (!FastBoundsCheck(index, limit)) return i;
       // Broadcast update to params[index]
@@ -334,6 +375,41 @@ struct ScatterScalarFunctorBase {
   }
 };
 
+template <typename Device, typename Index>
+struct ScatterScalarFunctorVariantAssignBase {
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<Variant>::Matrix params,
+                   const typename TTypes<Variant>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    const Index cols = static_cast<Index>(params.dimension(1));
+    const Variant& to_scatter = update();
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Broadcast update to params[index]
+      for (Index j = 0; j < cols; ++j) {
+        params(index, j) = to_scatter;
+      }
+    }
+    return -1;
+  }
+};
+
+template <typename Index>
+struct ScatterScalarFunctor<CPUDevice, Variant, Index,
+                            scatter_op::UpdateOp::ASSIGN>
+    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
+template <typename Index>
+struct ScatterScalarFunctor<GPUDevice, Variant, Index,
+                            scatter_op::UpdateOp::ASSIGN>
+    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};
+
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
@@ -345,9 +421,9 @@ struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
     const Index N = static_cast<Index>(indices.size());
     const Index limit = static_cast<Index>(params.dimension(0));
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
       if (!FastBoundsCheck(index, limit)) return i;
       // Broadcast update to params[index]
@@ -370,9 +446,9 @@ struct ScatterScalarFunctorBase<CPUDevice, T, Index,
     const Index N = static_cast<Index>(indices.size());
     const Index limit = static_cast<Index>(params.dimension(0));
     for (Index i = 0; i < N; i++) {
-      // Grab the index and check its validity.  An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
+      // Grab the index and check its validity.  Do this carefully,
+      // to avoid checking the value and grabbing it again from
+      // memory a second time (a security risk since it may change in between).
       const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
       if (!FastBoundsCheck(index, limit)) return i;
       // Broadcast update to params[index]
index f6e2a5a..857daae 100644 (file)
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
 
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/variable_ops.h"
 
@@ -40,14 +41,27 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
     // updating.
     PersistentTensor unused;
     Tensor* tmp;
-    AllocatorAttributes attr;
-    attr.set_gpu_compatible(true);
-    attr.set_nic_compatible(true);
-    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
-        tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
-    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
-    copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
-                 const_cast<const Tensor*>(tensor)->flat<T>());
+    if (std::is_same<T, Variant>::value) {
+      AllocatorAttributes attr;
+      attr.set_on_host(true);
+      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+
+      const auto elements_in = tensor->flat<Variant>();
+      auto elements_out = tmp->flat<Variant>();
+      for (int64 i = 0; i < elements_in.size(); ++i) {
+        elements_out(i) = elements_in(i);
+      }
+    } else {
+      AllocatorAttributes attr;
+      attr.set_gpu_compatible(true);
+      attr.set_nic_compatible(true);
+      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+          tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
+      copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
+                   const_cast<const Tensor*>(tensor)->flat<T>());
+    }
     *tensor = *tmp;
   }
   return Status::OK();
index cad6176..c151055 100644 (file)
@@ -30,7 +30,8 @@ REGISTER_OP("EmptyTensorList")
       DataType t;
       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
       shape_inference::ShapeHandle s;
-      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+      TF_RETURN_IF_ERROR(
+          c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
       c->set_output_handle_shapes_and_types(
           0, std::vector<shape_inference::ShapeAndType>{{s, t}});
       return Status::OK();
@@ -193,6 +194,7 @@ REGISTER_OP("TensorListReserve")
     .Attr("element_dtype: type")
     .Attr("shape_type: {int32, int64}")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
       shape_inference::ShapeHandle s;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
       DataType t;
index 64b0fa6..8cf2420 100644 (file)
@@ -822,17 +822,32 @@ def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
   all-or-nothing.
 
   Args:
-    tensor: The rank-1 Tensor to be evaluated.
+    tensor: The rank-0 or rank-1 Tensor to be evaluated.
 
   Returns:
     A `TensorShape` based on the constant value of the given `tensor`.
+
+  Raises:
+    ValueError: If the shape is rank-0 and is not statically known to be -1.
   """
   if isinstance(tensor, ops.EagerTensor):
     return tensor_shape.as_shape(
         [dim if dim != -1 else None for dim in tensor.numpy()])
 
+  if tensor.get_shape().ndims == 0:
+    value = constant_value(tensor)
+    if value is None:
+      raise ValueError(
+          "Received a scalar with unknown value as shape; require a statically "
+          "known scalar with value '-1' to describe an unknown shape.")
+    if value != -1:
+      raise ValueError(
+          "Received a scalar value '%s' as shape; require a statically known "
+          "scalar with value '-1' to describe an unknown shape." % value)
+    return tensor_shape.unknown_shape()
+
   shape = tensor.get_shape().with_rank(1)
-  if tensor.get_shape() == [0]:
+  if shape == [0]:
     return tensor_shape.scalar()
   elif tensor.op.type == "Shape":
     return tensor.op.inputs[0].get_shape()
index dbbed39..d969f0e 100644 (file)
@@ -31,8 +31,11 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
@@ -43,71 +46,83 @@ def scalar_shape():
 
 class ListOpsTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testPushPop(self):
     l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                    element_shape=scalar_shape())
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(e, 1.0)
+    self.assertAllEqual(self.evaluate(e), 1.0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testPushPopGPU(self):
     if not context.num_gpus():
       return
     with context.device("gpu:0"):
       self.testPushPop()
 
+  @test_util.run_in_graph_and_eager_modes()
   def testStack(self):
     l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                    element_shape=scalar_shape())
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
     t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(t, [1.0, 2.0])
+    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testStackGPU(self):
     if not context.num_gpus():
       return
     with context.device("gpu:0"):
       self.testStack()
 
+  @test_util.run_in_graph_and_eager_modes()
   def testTensorListFromTensor(self):
     t = constant_op.constant([1.0, 2.0])
     l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(e, 2.0)
+    self.assertAllEqual(self.evaluate(e), 2.0)
     l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(e, 1.0)
-    self.assertAllEqual(list_ops.tensor_list_length(l), 0)
+    self.assertAllEqual(self.evaluate(e), 1.0)
+    self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFromTensorGPU(self):
     if not context.num_gpus():
       return
     with context.device("gpu:0"):
       self.testTensorListFromTensor()
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetSetItem(self):
     t = constant_op.constant([1.0, 2.0])
     l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
     e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
-    self.assertAllEqual(e0, 1.0)
+    self.assertAllEqual(self.evaluate(e0), 1.0)
     l = list_ops.tensor_list_set_item(l, 0, 3.0)
     t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(t, [3.0, 2.0])
+    self.assertAllEqual(self.evaluate(t), [3.0, 2.0])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetSetGPU(self):
     if not context.num_gpus():
       return
     with context.device("gpu:0"):
       self.testGetSetItem()
 
+  @test_util.run_in_graph_and_eager_modes()
   def testUnknownShape(self):
-    l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
-                                   element_shape=-1)
+    l = list_ops.empty_tensor_list(
+        element_dtype=dtypes.float32, element_shape=-1)
     l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
     l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
-    _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
-    self.assertAllEqual(e, [1.0, 2.0])
+    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+    self.assertAllEqual(self.evaluate(e), [1.0, 2.0])
+    l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+    self.assertAllEqual(self.evaluate(e), 1.0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testCPUGPUCopy(self):
     if not context.num_gpus():
       return
@@ -116,15 +131,16 @@ class ListOpsTest(test_util.TensorFlowTestCase):
     with context.device("gpu:0"):
       l_gpu = array_ops.identity(l)
       self.assertAllEqual(
-          list_ops.tensor_list_pop_back(
-              l_gpu, element_dtype=dtypes.float32)[1],
-          2.0)
+          self.evaluate(
+              list_ops.tensor_list_pop_back(
+                  l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
     l_cpu = array_ops.identity(l_gpu)
     self.assertAllEqual(
-        list_ops.tensor_list_pop_back(
-            l_cpu, element_dtype=dtypes.float32)[1],
-        2.0)
+        self.evaluate(
+            list_ops.tensor_list_pop_back(
+                l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGraphStack(self):
     with context.graph_mode(), self.test_session():
       tl = list_ops.empty_tensor_list(
@@ -132,9 +148,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
           element_dtype=dtypes.int32)
       tl = list_ops.tensor_list_push_back(tl, [1])
       self.assertAllEqual(
-          list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
+          self.evaluate(
+              list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
           [[1]])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGraphStackInLoop(self):
     with context.graph_mode(), self.test_session():
       t1 = list_ops.empty_tensor_list(
@@ -149,9 +167,10 @@ class ListOpsTest(test_util.TensorFlowTestCase):
 
       i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4),
                                           body, [i, t1])
-      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval()
-      self.assertAllEqual(s1, [0, 1, 2, 3])
+      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
+      self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGraphStackSwitchDtype(self):
     with context.graph_mode(), self.test_session():
       list_ = list_ops.empty_tensor_list(
@@ -169,11 +188,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
       for _ in range(2):
         list_, m = body(list_, m)
 
-      s1 = list_ops.tensor_list_stack(
-          list_, element_dtype=dtypes.float32).eval()
+      s1 = list_ops.tensor_list_stack(list_, element_dtype=dtypes.float32)
       np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
-      self.assertAllEqual(s1, np_s1)
+      self.assertAllEqual(self.evaluate(s1), np_s1)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGraphStackInLoopSwitchDtype(self):
     with context.graph_mode(), self.test_session():
       t1 = list_ops.empty_tensor_list(
@@ -193,10 +212,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
 
       i, m, t1 = control_flow_ops.while_loop(
           lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
-      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval()
+      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)
       np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
-      self.assertAllEqual(s1, np_s1)
+      self.assertAllEqual(self.evaluate(s1), np_s1)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testSerialize(self):
     # pylint: disable=g-import-not-at-top
     try:
@@ -226,8 +246,9 @@ class ListOpsTest(test_util.TensorFlowTestCase):
               l_ps, element_dtype=dtypes.float32)
         with ops.device("/job:worker"):
           worker_e = array_ops.identity(e)
-        self.assertAllEqual(worker_e.eval(), [2.0])
+        self.assertAllEqual(self.evaluate(worker_e), [2.0])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testPushPopGradients(self):
     with backprop.GradientTape() as tape:
       l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
@@ -237,18 +258,24 @@ class ListOpsTest(test_util.TensorFlowTestCase):
       l = list_ops.tensor_list_push_back(l, c)
       l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
       e = 2 * e
-    self.assertAllEqual(tape.gradient(e, [c])[0], 2.0)
+    self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testStackFromTensorGradients(self):
     with backprop.GradientTape() as tape:
       c = constant_op.constant([1.0, 2.0])
       tape.watch(c)
       l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
-      c2 = list_ops.tensor_list_stack(
-          l, element_dtype=dtypes.float32)
+      c2 = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
       result = c2 * 2.0
-    self.assertAllEqual(tape.gradient(result, [c])[0], [2.0, 2.0])
-
+    if context.in_eager_mode():
+      # TODO(b/77609620): Fix this in graph mode.
+      grad = tape.gradient(result, [c])[0]
+    else:
+      grad = gradients.gradients(result, [c])[0]
+    self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
+
+  @test_util.run_in_graph_and_eager_modes()
   def testGetSetGradients(self):
     with backprop.GradientTape() as tape:
       c = constant_op.constant([1.0, 2.0])
@@ -261,16 +288,40 @@ class ListOpsTest(test_util.TensorFlowTestCase):
       ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
       y = e * e + ee * ee
     grad_c, grad_c2 = tape.gradient(y, [c, c2])
-    self.assertAllEqual(grad_c, [0.0, 4.0])
-    self.assertAllEqual(grad_c2, 6.0)
+    self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
+    self.assertAllEqual(self.evaluate(grad_c2), 6.0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testSetOutOfBounds(self):
     c = constant_op.constant([1.0, 2.0])
     l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
     with self.assertRaises(errors.InvalidArgumentError):
-      list_ops.tensor_list_set_item(l, 20, 3.0)
+      self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testResourceVariableScatterGather(self):
+    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
+    l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+    v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
+    v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.float32)
+    self.evaluate(v.initializer)
+    self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_0_stacked))
+    v_r_sparse_stacked = list_ops.tensor_list_stack(
+        v.sparse_read(0), dtypes.float32)
+    self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_sparse_stacked))
+    l_new_0 = list_ops.tensor_list_from_tensor(
+        [3.0, 4.0], element_shape=scalar_shape())
+    l_new_1 = list_ops.tensor_list_from_tensor(
+        [5.0, 6.0], element_shape=scalar_shape())
+    updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
+    updated_v_elems = array_ops.unstack(updated_v)
+    updated_v_stacked = [
+        list_ops.tensor_list_stack(el, dtypes.float32) for el in updated_v_elems
+    ]
+    expected = ([[1.0, 2.0]] * 3 + [[3.0, 4.0], [1.0, 2.0], [5.0, 6.0]] +
+                [[1.0, 2.0]] * 4)
+    self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
 
 
 if __name__ == "__main__":
-  ops.enable_eager_execution()
   test.main()
index bba59eb..bdf0774 100644 (file)
@@ -54,8 +54,8 @@ def _TensorListStackGrad(unused_op, dtensor):
 @ops.RegisterGradient("TensorListFromTensor")
 def _TensorListFromTensorGrad(op, dlist):
   """Gradient for TensorListFromTensor."""
-  if op.inputs[0].shape[0] is not None:
-    num_elements = op.inputs[0].shape[0]
+  if op.inputs[0].shape[0].value is not None:
+    num_elements = op.inputs[0].shape[0].value
   else:
     num_elements = None
   if dlist is None:
@@ -63,9 +63,10 @@ def _TensorListFromTensorGrad(op, dlist):
         element_dtype=op.inputs[0].dtype,
         element_shape=gen_list_ops.tensor_list_element_shape(
             op.outputs[0], shape_type=dtypes.int32))
-  return gen_list_ops.tensor_list_stack(
-      dlist, element_dtype=op.inputs[0].dtype,
-      num_elements=num_elements)
+  tensor_grad = gen_list_ops.tensor_list_stack(
+      dlist, element_dtype=op.inputs[0].dtype, num_elements=num_elements)
+  shape_grad = None
+  return tensor_grad, shape_grad
 
 
 @ops.RegisterGradient("TensorListGetItem")