Allow output has a different shape from input in the image.transform (#17011).
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 19 Apr 2018 20:21:25 +0000 (13:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 19 Apr 2018 20:24:36 +0000 (13:24 -0700)
PiperOrigin-RevId: 193564222

tensorflow/contrib/image/kernels/image_ops.cc
tensorflow/contrib/image/kernels/image_ops.h
tensorflow/contrib/image/ops/image_ops.cc
tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
tensorflow/contrib/image/python/ops/image_ops.py

index c2e32da..ae4b1ba 100644 (file)
@@ -70,6 +70,7 @@ class ImageProjectiveTransform : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& images_t = ctx->input(0);
     const Tensor& transform_t = ctx->input(1);
+    const Tensor& output_dim = ctx->input(2);
     OP_REQUIRES(ctx, images_t.shape().dims() == 4,
                 errors::InvalidArgument("Input images must have rank 4"));
     OP_REQUIRES(ctx,
@@ -83,7 +84,11 @@ class ImageProjectiveTransform : public OpKernel {
     auto images = images_t.tensor<T, 4>();
     auto transform = transform_t.matrix<float>();
     Tensor* output_t;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
+    // Image is NHWC format.
+    auto output_shape = images_t.shape();
+    output_shape.set_dim(1, output_dim.vec<int>()(0));
+    output_shape.set_dim(2, output_dim.vec<int>()(1));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
     auto output = output_t->tensor<T, 4>();
     (FillProjectiveTransform<Device, T>(interpolation_))(
         ctx->eigen_device<Device>(), &output, images, transform);
index ad50133..2320329 100644 (file)
@@ -161,7 +161,7 @@ struct FillProjectiveTransform {
   void operator()(const Device& device, OutputType* output,
                   const InputType& images,
                   const TransformsType& transform) const {
-    output->device(device) = images.generate(
+    output->device(device) = output->generate(
         ProjectiveGenerator<Device, T>(images, transform, interpolation_));
   }
 };
index 68771b3..4c6d8c0 100644 (file)
@@ -19,9 +19,55 @@ limitations under the License.
 
 namespace tensorflow {
 
+using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
+namespace {
+
+// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
+// height and width come from the size_tensor.
+Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
+                             int size_input_idx, DimensionHandle channel_dim) {
+  // Verify shape of size input.
+  ShapeHandle size;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
+  DimensionHandle unused;
+  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
+
+  // Get size values from the size tensor.
+  const Tensor* size_tensor = c->input_tensor(size_input_idx);
+  DimensionHandle width;
+  DimensionHandle height;
+  if (size_tensor == nullptr) {
+    width = c->UnknownDim();
+    height = c->UnknownDim();
+  } else {
+    // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+    if (size_tensor->dtype() != DT_INT32) {
+      return errors::InvalidArgument(
+          "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+          "but got ",
+          DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+          " in ", c->DebugString());
+    }
+    auto vec = size_tensor->vec<int32>();
+    height = c->MakeDim(vec(0));
+    width = c->MakeDim(vec(1));
+  }
+  c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
+  return Status::OK();
+}
+
+Status ResizeShapeFn(InferenceContext* c) {
+  ShapeHandle input;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+  return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
+                               c->Dim(input, 3));
+}
+
+}  // namespace
+
 // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
 // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
 // TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
@@ -29,13 +75,11 @@ using shape_inference::ShapeHandle;
 REGISTER_OP("ImageProjectiveTransform")
     .Input("images: dtype")
     .Input("transforms: float32")
+    .Input("output_shape: int32")
     .Attr("dtype: {uint8, int32, int64, float32, float64}")
     .Attr("interpolation: string")
     .Output("transformed_images: dtype")
-    .SetShapeFn([](InferenceContext* c) {
-      c->set_output(0, c->input(0));
-      return Status::OK();
-    })
+    .SetShapeFn(ResizeShapeFn)
     .Doc(R"doc(
 Applies the given transform to each of the images.
 
index b50177a..c0151d3 100644 (file)
@@ -195,10 +195,40 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
           x_init_value=test_image)
       self.assertLess(left_err, 1e-10)
 
+  def _test_grad_different_shape(self, input_shape, output_shape):
+    with self.test_session():
+      test_image_shape = input_shape
+      test_image = np.random.randn(*test_image_shape)
+      test_image_tensor = constant_op.constant(
+          test_image, shape=test_image_shape)
+      test_transform = image_ops.angles_to_projective_transforms(
+          np.pi / 2, 4, 4)
+
+      if len(output_shape) == 2:
+        resize_shape = output_shape
+      elif len(output_shape) == 3:
+        resize_shape = output_shape[0:2]
+      elif len(output_shape) == 4:
+        resize_shape = output_shape[1:3]
+      output = image_ops.transform(
+          images=test_image_tensor,
+          transforms=test_transform,
+          output_shape=resize_shape)
+      left_err = gradient_checker.compute_gradient_error(
+          test_image_tensor,
+          test_image_shape,
+          output,
+          output_shape,
+          x_init_value=test_image)
+      self.assertLess(left_err, 1e-10)
+
   def test_grad(self):
     self._test_grad([16, 16])
     self._test_grad([4, 12, 12])
     self._test_grad([3, 4, 12, 12])
+    self._test_grad_different_shape([16, 16], [8, 8])
+    self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
+    self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
 
 
 class BipartiteMatchTest(test_util.TensorFlowTestCase):
index c139ae8..0cb7bdc 100644 (file)
@@ -212,7 +212,11 @@ def translations_to_projective_transforms(translations, name=None):
         axis=1)
 
 
-def transform(images, transforms, interpolation="NEAREST", name=None):
+def transform(images,
+              transforms,
+              output_shape=None,
+              interpolation="NEAREST",
+              name=None):
   """Applies the given transform(s) to the image(s).
 
   Args:
@@ -228,7 +232,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
        where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
        the transform mapping input points to output points. Note that gradients
        are not backpropagated into transformation parameters.
+    output_shape: Output dimesion after the transform, [height, width].
+       If None, output is the same size as input image.
     interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
+    name: The name of the op.
 
   Returns:
     Image(s) with the same type and shape as `images`, with the given
@@ -255,6 +262,14 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
     else:
       raise TypeError("Images should have rank between 2 and 4.")
 
+    if output_shape is None:
+      output_shape = images.get_shape()[1:3]
+    elif len(output_shape) != 2:
+      raise TypeError(
+          "output_shape must either be None or a vector of 2 elements.")
+    output_shape = ops.convert_to_tensor(
+        output_shape, name="output_shape", dtype=dtypes.int32)
+
     if len(transform_or_transforms.get_shape()) == 1:
       transforms = transform_or_transforms[None]
     elif transform_or_transforms.get_shape().ndims is None:
@@ -265,7 +280,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
     else:
       raise TypeError("Transforms should have rank 1 or 2.")
     output = gen_image_ops.image_projective_transform(
-        images, transforms, interpolation=interpolation.upper())
+        images, transforms, output_shape, interpolation=interpolation.upper())
     if len(image_or_images.get_shape()) == 2:
       return output[0, :, :, 0]
     elif len(image_or_images.get_shape()) == 3:
@@ -375,14 +390,6 @@ def _image_projective_transform_grad(op, grad):
 
   if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
     raise TypeError("Invalid dtype %s." % image_or_images.dtype)
-  if len(image_or_images.get_shape()) == 2:
-    images = image_or_images[None, :, :, None]
-  elif len(image_or_images.get_shape()) == 3:
-    images = image_or_images[None, :, :, :]
-  elif len(image_or_images.get_shape()) == 4:
-    images = image_or_images
-  else:
-    raise TypeError("Images should have rank between 2 and 4")
   if len(transform_or_transforms.get_shape()) == 1:
     transforms = transform_or_transforms[None]
   elif len(transform_or_transforms.get_shape()) == 2:
@@ -395,13 +402,11 @@ def _image_projective_transform_grad(op, grad):
   inverse = linalg_ops.matrix_inverse(transforms)
   transforms = matrices_to_flat_transforms(inverse)
   output = gen_image_ops.image_projective_transform(
-      grad, transforms, interpolation=interpolation)
-  if len(image_or_images.get_shape()) == 2:
-    return [output[0, :, :, 0], None]
-  elif len(image_or_images.get_shape()) == 3:
-    return [output[0, :, :, :], None]
-  else:
-    return [output, None]
+      images=grad,
+      transforms=transforms,
+      output_shape=image_or_images.get_shape()[1:3],
+      interpolation=interpolation)
+  return [output, None, None]
 
 
 def bipartite_match(distance_mat,