From a9034babfd0100f0b998e900d074dd3d839d18bc Mon Sep 17 00:00:00 2001 From: Dan Ringwalt Date: Mon, 5 Feb 2018 10:36:24 -0800 Subject: [PATCH] Make flat_transforms_to_matrices and matrices_to_flat_transforms public (#781). PiperOrigin-RevId: 184549704 --- tensorflow/contrib/image/python/ops/image_ops.py | 87 ++++++++++++++++++------ 1 file changed, 66 insertions(+), 21 deletions(-) diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 6122ee5..c139ae8 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -290,31 +290,76 @@ def compose_transforms(*transforms): """ assert transforms, "transforms cannot be empty" with ops.name_scope("compose_transforms"): - composed = _flat_transforms_to_matrices(transforms[0]) + composed = flat_transforms_to_matrices(transforms[0]) for tr in transforms[1:]: # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr)) + return matrices_to_flat_transforms(composed) -def _flat_transforms_to_matrices(transforms): - # Make the transform(s) 2D in case the input is a single transform. - transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) - num_transforms = array_ops.shape(transforms)[0] - # Add a column of ones for the implicit last entry in the matrix. - return array_ops.reshape( - array_ops.concat( - [transforms, array_ops.ones([num_transforms, 1])], axis=1), - constant_op.constant([-1, 3, 3])) +def flat_transforms_to_matrices(transforms): + """Converts `tf.contrib.image` projective transforms to affine matrices. + Note that the output matrices map output coordinates to input coordinates. For + the forward transformation matrix, call `tf.linalg.inv` on the result. -def _transform_matrices_to_flat(transform_matrices): - # Flatten each matrix. - transforms = array_ops.reshape(transform_matrices, - constant_op.constant([-1, 9])) - # Divide each matrix by the last entry (normally 1). - transforms /= transforms[:, 8:9] - return transforms[:, :8] + Args: + transforms: Vector of length 8, or batches of transforms with shape + `(N, 8)`. + + Returns: + 3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the + *output coordinates* (in homogeneous coordinates) of each transform to the + corresponding *input coordinates*. + + Raises: + ValueError: If `transforms` have an invalid shape. + """ + with ops.name_scope("flat_transforms_to_matrices"): + transforms = ops.convert_to_tensor(transforms, name="transforms") + if transforms.shape.ndims not in (1, 2): + raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms) + # Make the transform(s) 2D in case the input is a single transform. + transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) + num_transforms = array_ops.shape(transforms)[0] + # Add a column of ones for the implicit last entry in the matrix. + return array_ops.reshape( + array_ops.concat( + [transforms, array_ops.ones([num_transforms, 1])], axis=1), + constant_op.constant([-1, 3, 3])) + + +def matrices_to_flat_transforms(transform_matrices): + """Converts affine matrices to `tf.contrib.image` projective transforms. + + Note that we expect matrices that map output coordinates to input coordinates. + To convert forward transformation matrices, call `tf.linalg.inv` on the + matrices and use the result here. + + Args: + transform_matrices: One or more affine transformation matrices, for the + reverse transformation in homogeneous coordinates. Shape `(3, 3)` or + `(N, 3, 3)`. + + Returns: + 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into + `tf.contrib.image.transform`. + + Raises: + ValueError: If `transform_matrices` have an invalid shape. + """ + with ops.name_scope("matrices_to_flat_transforms"): + transform_matrices = ops.convert_to_tensor( + transform_matrices, name="transform_matrices") + if transform_matrices.shape.ndims not in (2, 3): + raise ValueError( + "Matrices should be 2D or 3D, got: %s" % transform_matrices) + # Flatten each matrix. + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) + # Divide each matrix by the last entry (normally 1). + transforms /= transforms[:, 8:9] + return transforms[:, :8] @ops.RegisterGradient("ImageProjectiveTransform") @@ -346,9 +391,9 @@ def _image_projective_transform_grad(op, grad): raise TypeError("Transforms should have rank 1 or 2.") # Invert transformations - transforms = _flat_transforms_to_matrices(transforms=transforms) + transforms = flat_transforms_to_matrices(transforms=transforms) inverse = linalg_ops.matrix_inverse(transforms) - transforms = _transform_matrices_to_flat(inverse) + transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( grad, transforms, interpolation=interpolation) if len(image_or_images.get_shape()) == 2: -- 2.7.4