[XLA] Add a test for map with static operands in the local Python client.
authorRoy Frostig <frostig@google.com>
Wed, 14 Feb 2018 20:18:48 +0000 (12:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 20:22:51 +0000 (12:22 -0800)
PiperOrigin-RevId: 185725205

tensorflow/compiler/xla/python/xla_client_test.py

index 65720c6ef9ec1cd7a816bcf719960fa803dd45a1..7565e8146b68ac9e1ba5d12dd4b35741e0f91b2a 100644 (file)
@@ -881,6 +881,13 @@ class EmbeddedComputationsTest(LocalComputationTest):
     c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
     return c.Build()
 
+  def _CreateMulF32ByParamComputation(self):
+    """Computation (f32) -> f32 that multiplies one parameter by the other."""
+    c = self._NewComputation("mul_f32_by_param")
+    c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)),
+          c.ParameterFromNumpy(NumpyArrayF32(0)))
+    return c.Build()
+
   def _CreateMulF64By2Computation(self):
     """Computation (f64) -> f64 that multiplies its parameter by 2."""
     c = self._NewComputation("mul_f64_by2")
@@ -1021,6 +1028,14 @@ class EmbeddedComputationsTest(LocalComputationTest):
           self._CreateBinaryDivF64Computation(), [0])
     self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
 
+  def DISABLED_testMapWithStaticOperands(self):
+    c = self._NewComputation()
+    factor = c.ConstantF32Scalar(3.0)
+    c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
+          self._CreateMulF32ByParamComputation(), [0],
+          static_operands=[factor])
+    self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0])
+
   def testSelectAndScatterF32(self):
     c = self._NewComputation()
     c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),