Make flat_transforms_to_matrices and matrices_to_flat_transforms public (#781).
authorDan Ringwalt <ringwalt@google.com>
Mon, 5 Feb 2018 18:36:24 +0000 (10:36 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 18:40:26 +0000 (10:40 -0800)
PiperOrigin-RevId: 184549704

tensorflow/contrib/image/python/ops/image_ops.py

index 6122ee5..c139ae8 100644 (file)
@@ -290,31 +290,76 @@ def compose_transforms(*transforms):
   """
   assert transforms, "transforms cannot be empty"
   with ops.name_scope("compose_transforms"):
-    composed = _flat_transforms_to_matrices(transforms[0])
+    composed = flat_transforms_to_matrices(transforms[0])
     for tr in transforms[1:]:
       # Multiply batches of matrices.
-      composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr))
-    return _transform_matrices_to_flat(composed)
+      composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr))
+    return matrices_to_flat_transforms(composed)
 
 
-def _flat_transforms_to_matrices(transforms):
-  # Make the transform(s) 2D in case the input is a single transform.
-  transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
-  num_transforms = array_ops.shape(transforms)[0]
-  # Add a column of ones for the implicit last entry in the matrix.
-  return array_ops.reshape(
-      array_ops.concat(
-          [transforms, array_ops.ones([num_transforms, 1])], axis=1),
-      constant_op.constant([-1, 3, 3]))
+def flat_transforms_to_matrices(transforms):
+  """Converts `tf.contrib.image` projective transforms to affine matrices.
 
+  Note that the output matrices map output coordinates to input coordinates. For
+  the forward transformation matrix, call `tf.linalg.inv` on the result.
 
-def _transform_matrices_to_flat(transform_matrices):
-  # Flatten each matrix.
-  transforms = array_ops.reshape(transform_matrices,
-                                 constant_op.constant([-1, 9]))
-  # Divide each matrix by the last entry (normally 1).
-  transforms /= transforms[:, 8:9]
-  return transforms[:, :8]
+  Args:
+    transforms: Vector of length 8, or batches of transforms with shape
+      `(N, 8)`.
+
+  Returns:
+    3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the
+      *output coordinates* (in homogeneous coordinates) of each transform to the
+      corresponding *input coordinates*.
+
+  Raises:
+    ValueError: If `transforms` have an invalid shape.
+  """
+  with ops.name_scope("flat_transforms_to_matrices"):
+    transforms = ops.convert_to_tensor(transforms, name="transforms")
+    if transforms.shape.ndims not in (1, 2):
+      raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
+    # Make the transform(s) 2D in case the input is a single transform.
+    transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
+    num_transforms = array_ops.shape(transforms)[0]
+    # Add a column of ones for the implicit last entry in the matrix.
+    return array_ops.reshape(
+        array_ops.concat(
+            [transforms, array_ops.ones([num_transforms, 1])], axis=1),
+        constant_op.constant([-1, 3, 3]))
+
+
+def matrices_to_flat_transforms(transform_matrices):
+  """Converts affine matrices to `tf.contrib.image` projective transforms.
+
+  Note that we expect matrices that map output coordinates to input coordinates.
+  To convert forward transformation matrices, call `tf.linalg.inv` on the
+  matrices and use the result here.
+
+  Args:
+    transform_matrices: One or more affine transformation matrices, for the
+      reverse transformation in homogeneous coordinates. Shape `(3, 3)` or
+      `(N, 3, 3)`.
+
+  Returns:
+    2D tensor of flat transforms with shape `(N, 8)`, which may be passed into
+      `tf.contrib.image.transform`.
+
+  Raises:
+    ValueError: If `transform_matrices` have an invalid shape.
+  """
+  with ops.name_scope("matrices_to_flat_transforms"):
+    transform_matrices = ops.convert_to_tensor(
+        transform_matrices, name="transform_matrices")
+    if transform_matrices.shape.ndims not in (2, 3):
+      raise ValueError(
+          "Matrices should be 2D or 3D, got: %s" % transform_matrices)
+    # Flatten each matrix.
+    transforms = array_ops.reshape(transform_matrices,
+                                   constant_op.constant([-1, 9]))
+    # Divide each matrix by the last entry (normally 1).
+    transforms /= transforms[:, 8:9]
+    return transforms[:, :8]
 
 
 @ops.RegisterGradient("ImageProjectiveTransform")
@@ -346,9 +391,9 @@ def _image_projective_transform_grad(op, grad):
     raise TypeError("Transforms should have rank 1 or 2.")
 
   # Invert transformations
-  transforms = _flat_transforms_to_matrices(transforms=transforms)
+  transforms = flat_transforms_to_matrices(transforms=transforms)
   inverse = linalg_ops.matrix_inverse(transforms)
-  transforms = _transform_matrices_to_flat(inverse)
+  transforms = matrices_to_flat_transforms(inverse)
   output = gen_image_ops.image_projective_transform(
       grad, transforms, interpolation=interpolation)
   if len(image_or_images.get_shape()) == 2: