Add tf.math.polyval that evaluates an element-wise polynomial using Horner's method...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Apr 2018 22:39:35 +0000 (15:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 22:41:59 +0000 (15:41 -0700)
PiperOrigin-RevId: 191359241

tensorflow/python/kernel_tests/cwise_ops_test.py
tensorflow/python/ops/math_ops.py
tensorflow/tools/api/generator/BUILD
tensorflow/tools/api/golden/tensorflow.math.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.pbtxt

index 8db0bb6..34e7751 100644 (file)
@@ -2165,5 +2165,47 @@ class AccumulateTest(test.TestCase):
         math_ops.accumulate_n([a], tensor_dtype=np.int32)
 
 
+class PolyvalTest(test.TestCase):
+
+  def _runtest(self, dtype, degree):
+    x = np.random.rand(2, 2).astype(dtype)
+    coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)]
+    np_val = np.polyval(coeffs, x)
+    with self.test_session():
+      tf_val = math_ops.polyval(coeffs, x)
+      self.assertAllClose(np_val, tf_val.eval())
+
+  def testSimple(self):
+    for dtype in [
+        np.int32, np.float32, np.float64, np.complex64, np.complex128
+    ]:
+      for degree in range(5):
+        self._runtest(dtype, degree)
+
+  def testBroadcast(self):
+    dtype = np.float32
+    degree = 3
+    shapes = [(1,), (2, 1), (1, 2), (2, 2)]
+    for x_shape in shapes:
+      for coeff_shape in shapes:
+        x = np.random.rand(*x_shape).astype(dtype)
+        coeffs = [
+            np.random.rand(*coeff_shape).astype(dtype)
+            for _ in range(degree + 1)
+        ]
+        np_val = np.polyval(coeffs, x)
+        with self.test_session():
+          tf_val = math_ops.polyval(coeffs, x)
+          self.assertAllClose(np_val, tf_val.eval())
+
+  def testEmpty(self):
+    x = np.random.rand(2, 2).astype(np.float32)
+    coeffs = []
+    np_val = np.polyval(coeffs, x)
+    with self.test_session():
+      tf_val = math_ops.polyval(coeffs, x)
+      self.assertAllClose(np_val, tf_val.eval())
+
+
 if __name__ == "__main__":
   test.main()
index 276897a..1c20d00 100644 (file)
@@ -71,6 +71,7 @@ See the @{$python/math_ops} guide.
 @@igammac
 @@zeta
 @@polygamma
+@@polyval
 @@betainc
 @@rint
 @@diag
@@ -174,6 +175,7 @@ from tensorflow.python.ops.gen_math_ops import *
 # pylint: enable=wildcard-import
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 # Aliases for some automatically-generated names.
@@ -184,7 +186,6 @@ arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min)  # pylin
 tf_export("arg_max")(arg_max)
 tf_export("arg_min")(arg_min)
 
-
 # This is set by resource_variable_ops.py. It is included in this way since
 # there is a circular dependency between math_ops and resource_variable_ops
 _resource_variable_type = None
@@ -1343,8 +1344,7 @@ def _ReductionDims(x, axis, reduction_indices):
   else:
     # Fast path: avoid creating Rank and Range ops if ndims is known.
     if isinstance(x, ops.Tensor) and x._rank() is not None:  # pylint: disable=protected-access
-      return constant_op.constant(
-          np.arange(x._rank()), dtype=dtypes.int32)  # pylint: disable=protected-access
+      return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32)  # pylint: disable=protected-access
     if (isinstance(x, sparse_tensor.SparseTensor) and
         x.dense_shape.get_shape().is_fully_defined()):
       rank = x.dense_shape.get_shape()[0].value  # sparse.dense_shape is 1-D.
@@ -2273,10 +2273,11 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
     ValueError: If `inputs` don't all have same shape and dtype or the shape
     cannot be inferred.
   """
+
   def _input_error():
-    return ValueError(
-        "inputs must be a list of at least one Tensor with the "
-        "same dtype and shape")
+    return ValueError("inputs must be a list of at least one Tensor with the "
+                      "same dtype and shape")
+
   if not inputs or not isinstance(inputs, (list, tuple)):
     raise _input_error()
   inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
@@ -2294,8 +2295,8 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
 
   # tensor_dtype is for safety only; operator's output type computed in C++
   if tensor_dtype is not None and tensor_dtype != inputs[0].dtype:
-    raise TypeError("tensor_dtype is {}, but input is of type {}"
-                    .format(tensor_dtype, inputs[0].dtype))
+    raise TypeError("tensor_dtype is {}, but input is of type {}".format(
+        tensor_dtype, inputs[0].dtype))
 
   if len(inputs) == 1 and name is None:
     return inputs[0]
@@ -2761,14 +2762,14 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
         name=name)
   else:
     return gen_math_ops.sparse_segment_sum(
-        data=data,
-        indices=indices,
-        segment_ids=segment_ids,
-        name=name)
+        data=data, indices=indices, segment_ids=segment_ids, name=name)
 
 
 @tf_export("sparse_segment_mean")
-def sparse_segment_mean(data, indices, segment_ids, name=None,
+def sparse_segment_mean(data,
+                        indices,
+                        segment_ids,
+                        name=None,
                         num_segments=None):
   r"""Computes the mean along sparse segments of a tensor.
 
@@ -2805,14 +2806,14 @@ def sparse_segment_mean(data, indices, segment_ids, name=None,
         name=name)
   else:
     return gen_math_ops.sparse_segment_mean(
-        data=data,
-        indices=indices,
-        segment_ids=segment_ids,
-        name=name)
+        data=data, indices=indices, segment_ids=segment_ids, name=name)
 
 
 @tf_export("sparse_segment_sqrt_n")
-def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
+def sparse_segment_sqrt_n(data,
+                          indices,
+                          segment_ids,
+                          name=None,
                           num_segments=None):
   r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
 
@@ -2842,10 +2843,7 @@ def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
         name=name)
   else:
     return gen_math_ops.sparse_segment_sqrt_n(
-        data=data,
-        indices=indices,
-        segment_ids=segment_ids,
-        name=name)
+        data=data, indices=indices, segment_ids=segment_ids, name=name)
 
 
 @tf_export("tensordot", "linalg.tensordot")
@@ -3016,6 +3014,47 @@ def tensordot(a, b, axes, name=None):
       return product
 
 
+@tf_export("math.polyval")
+def polyval(coeffs, x, name=None):
+  r"""Computes the elementwise value of a polynomial.
+
+  If `x` is a tensor and `coeffs` is a list n + 1 tensors, this function returns
+  the value of the n-th order polynomial
+
+     p(x) = coeffs[n-1] + coeffs[n-2] * x + ...  + coeffs[0] * x**(n-1)
+
+  evaluated using Horner's method, i.e.
+
+     p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] +
+            x * coeffs[0]))
+
+  Args:
+    coeffs: A list of `Tensor` representing the coefficients of the polynomial.
+    x: A `Tensor` representing the variable of the polynomial.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `tensor` of the shape as the expression p(x) with usual broadcasting rules
+    for element-wise addition and multiplication applied.
+
+  @compatibility(numpy)
+  Equivalent to numpy.polyval.
+  @end_compatibility
+  """
+
+  with ops.name_scope(name, "polyval", nest.flatten(coeffs) + [x]) as name:
+    x = ops.convert_to_tensor(x, name="x")
+    if len(coeffs) < 1:
+      return array_ops.zeros_like(x, name=name)
+    coeffs = [
+        ops.convert_to_tensor(coeff, name=("coeff_%d" % index))
+        for index, coeff in enumerate(coeffs)
+    ]
+    p = coeffs[0]
+    for c in coeffs[1:]:
+      p = c + p * x
+    return p
+
 # FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
 # 1.0 API so we leave these here for backwards compatibility.
 fft = gen_spectral_ops.fft
index f8063ae..a1c5699 100644 (file)
@@ -94,6 +94,7 @@ genrule(
         "api/logging/__init__.py",
         "api/losses/__init__.py",
         "api/manip/__init__.py",
+        "api/math/__init__.py",
         "api/metrics/__init__.py",
         "api/nn/__init__.py",
         "api/nn/rnn_cell/__init__.py",
diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
new file mode 100644 (file)
index 0000000..897718c
--- /dev/null
@@ -0,0 +1,7 @@
+path: "tensorflow.math"
+tf_module {
+  member_method {
+    name: "polyval"
+    argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
index 937044a..afa3b78 100644 (file)
@@ -405,6 +405,10 @@ tf_module {
     mtype: "<type \'module\'>"
   }
   member {
+    name: "math"
+    mtype: "<type \'module\'>"
+  }
+  member {
     name: "metrics"
     mtype: "<type \'module\'>"
   }