Make loss_ops_test.py work with C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 30 Jan 2018 19:18:03 +0000 (11:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 30 Jan 2018 20:42:17 +0000 (12:42 -0800)
PiperOrigin-RevId: 183861779

tensorflow/contrib/losses/python/losses/loss_ops_test.py

index 9d0f95e6f3e7fa9666a99e31578b38d52e0b6b4a..1417772e0496cb571488e5b30bd4f3fb1b591730 100644 (file)
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
@@ -274,6 +275,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
       self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
 
 
+@test_util.with_c_api
 class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
 
   def testNoneWeightRaisesValueError(self):
@@ -471,7 +473,11 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
       labels = constant_op.constant([[0, 1], [2, 3]])
       weights = constant_op.constant([1.2, 3.4, 5.6, 7.8])
 
-      with self.assertRaises(errors_impl.InvalidArgumentError):
+      if ops._USE_C_API:
+        error_type = ValueError
+      else:
+        error_type = errors_impl.InvalidArgumentError
+      with self.assertRaises(error_type):
         loss_ops.sparse_softmax_cross_entropy(
             logits, labels, weights=weights).eval()