from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+@test_util.with_c_api
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
labels = constant_op.constant([[0, 1], [2, 3]])
weights = constant_op.constant([1.2, 3.4, 5.6, 7.8])
- with self.assertRaises(errors_impl.InvalidArgumentError):
+ if ops._USE_C_API:
+ error_type = ValueError
+ else:
+ error_type = errors_impl.InvalidArgumentError
+ with self.assertRaises(error_type):
loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights=weights).eval()