self.assertEqual(expected, self.evaluate(unwrapped[0]))
-@test_util.with_c_api
class MirroredStrategyVariableCreationTest(test.TestCase):
config = config_pb2.ConfigProto()
from tensorflow.python.training import distribute as distribute_lib
-@test_util.with_c_api
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
def _get_distribution_strategy(self):
self._test_call_and_merge_exceptions(self._get_distribution_strategy())
-@test_util.with_c_api
class VariableCreatorStackTest(test.TestCase):
def testCreatorStacksAreThreadLocal(self):
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
from tensorflow.python.training import server_lib
-@test_util.with_c_api
class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
from tensorflow.python.framework import test_util
-@test_util.with_c_api
class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
def _get_distribution_strategy(self):
self.assertEquals("foo_a", self._canonicalize("foo_a"))
-@test_util.with_c_api
class SharedVariableCreatorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
from tensorflow.python.util import nest
-@test_util.with_c_api
class DistributedValuesTest(test.TestCase):
def testGetEager(self):
v = values.DistributedValues({"/device:cpu:0": 42})
-@test_util.with_c_api
class DistributedDelegateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
return v, devices, mirrored
-@test_util.with_c_api
class RegroupAndSelectDeviceTest(test.TestCase):
def _is_per_device(self, result, expected, klass=values.PerDevice):
merged_estimator_spec))
-@test_util.with_c_api
class PerDeviceDatasetTest(test.TestCase):
config = config_pb2.ConfigProto()
multi_worker_iterator.get_next()
-@test_util.with_c_api
class MirroredVariableTest(test.TestCase):
config = config_pb2.ConfigProto()
return v, tower_local
-@test_util.with_c_api
class TowerLocalVariableTest(test.TestCase):
config = config_pb2.ConfigProto()
def Foo(x, y, z):
return math_ops.tanh(math_ops.matmul(x, y) + z)
- # We added more randomness to function names in C API.
- # TODO(iga): Remove this if statement when we switch to C API.
- if ops._USE_C_API: # pylint: disable=protected-access
- if sys.byteorder == "big":
- self.assertEqual("Foo_kEdkAG8SJvg",
- Foo.instantiate([dtypes.float32] * 3).name)
- else:
- self.assertEqual("Foo_aCYSbwBkR5A",
- Foo.instantiate([dtypes.float32] * 3).name)
+ if sys.byteorder == "big":
+ self.assertEqual("Foo_kEdkAG8SJvg",
+ Foo.instantiate([dtypes.float32] * 3).name)
else:
- self.assertEqual("Foo_d643acf7",
+ self.assertEqual("Foo_aCYSbwBkR5A",
Foo.instantiate([dtypes.float32] * 3).name)
def testSignatureHash(self):
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
-from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
-@test_util.with_c_api
class ImportGraphDefTest(test.TestCase):
def _MakeGraphDef(self,
return_elements=["foo"],
name="")
- if ops._USE_C_API:
- self.assertEqual(op.name, "foo")
- else:
- self.assertEqual(op.name, "foo_1")
+ self.assertEqual(op.name, "foo")
def testInputMap(self):
with ops.Graph().as_default():
self.assertEqual(sess.run(imported_r), 10)
def testTypeMismatchInGraphDef(self):
- if ops._USE_C_API:
- # TODO(skyewm): improve error message
- error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 "
- "incompatible with expected float.")
- else:
- error_msg = ("Cannot convert a tensor of type int32 to an input of type "
- "float")
-
+ # TODO(skyewm): improve error message
+ error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 "
+ "incompatible with expected float.")
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, error_msg):
importer.import_graph_def(
"Shapes () and (43,) are not compatible" in str(e.exception))
def testInvalidSignatureTooManyInputsInGraphDef(self):
- if ops._USE_C_API:
- # TODO(skyewm): improve error message
- error_msg = "NodeDef expected inputs '' do not match 1 inputs specified"
- else:
- error_msg = r"More inputs specified \('A:0'\) than the op expects"
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ # TODO(skyewm): improve error message
+ with self.assertRaisesRegexp(
+ ValueError,
+ "NodeDef expected inputs '' do not match 1 inputs specified"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
"""))
def testInvalidSignatureNotEnoughInputsInGraphDef(self):
- if ops._USE_C_API:
- # TODO(skyewm): improve error message
- error_msg = ("NodeDef expected inputs 'int32, float' do not match 1 "
- "inputs specified")
- else:
- error_msg = (r"Input types mismatch \(expected 'int32, float32' but "
- r"got 'int32'\)")
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ # TODO(skyewm): improve error message
+ with self.assertRaisesRegexp(
+ ValueError,
+ "NodeDef expected inputs 'int32, float' do not match 1 inputs "
+ "specified"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
"""))
def testMissingInputOpInGraphDef(self):
- if ops._USE_C_API:
- error_msg = "Node 'B': Unknown input node 'A:0'"
- else:
- error_msg = "Input tensor 'A:0' not found"
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(ValueError,
+ "Node 'B': Unknown input node 'A:0'"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'B' op: 'FloatInput' input: 'A:0' }
self.assertEqual(b.inputs[0], feed_a_0)
def testMissingInputTensorInGraphDef(self):
- if ops._USE_C_API:
- error_msg = ("Node 'B': Connecting to invalid output 1 of source node A "
- "which has 1 outputs")
- else:
- error_msg = "Input tensor 'A:1' not found"
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Node 'B': Connecting to invalid output 1 of source node A "
+ "which has 1 outputs"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'FloatOutput' }
"""))
def testMissingControlInputInGraphDef(self):
- if ops._USE_C_API:
- error_msg = r"Node 'B': Unknown input node '\^A'"
- else:
- error_msg = r"Control input '\^A' not found"
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(ValueError,
+ r"Node 'B': Unknown input node '\^A'"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'B' op: 'None' input: '^A' }
"""))
def testInvalidTensorNameOutputIndexInGraphDef(self):
- if ops._USE_C_API:
- error_msg = "Node 'B': Unknown input node 'A:B'"
- else:
- error_msg = "Cannot convert 'A:B' to a tensor name."
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(ValueError,
+ "Node 'B': Unknown input node 'A:B'"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'B' op: 'None' input: 'A:B' }
"""))
def testInvalidTensorNameInGraphDef(self):
- if ops._USE_C_API:
- error_msg = "Node 'B': Unknown input node 'A:B:0'"
- else:
- error_msg = "Cannot convert 'A:B:0' to a tensor name."
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(ValueError,
+ "Node 'B': Unknown input node 'A:B:0'"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'B' op: 'None' input: 'A:B:0' }
"""))
def testMissingReturnOperation(self):
- if ops._USE_C_API:
- error_msg = "Requested return node 'B' not found in graph def"
- else:
- error_msg = "return_element 'B' not found in graph_def."
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Requested return node 'B' not found in graph def"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'None' }
return_elements=["B"])
def testMissingReturnTensor(self):
- if ops._USE_C_API:
- error_msg = (r"Invalid return output 1 of node 'A', which has 1 "
- r"output\(s\)")
- else:
- error_msg = "return_element 'A:1' not found in graph_def."
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Invalid return output 1 of node 'A', which has 1 output\(s\)"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
"""),
return_elements=["A:1"])
- if ops._USE_C_API:
- error_msg = "Requested return tensor 'B:0' not found in graph def"
- else:
- error_msg = "return_element 'B:0' not found in graph_def."
-
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Requested return tensor 'B:0' not found in graph def"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
"""),
return_elements=["B:0"])
- if ops._USE_C_API:
- error_msg = "Cannot convert 'A:B:0' to a tensor name."
- else:
- error_msg = "return_element 'A:B:0' not found in graph_def."
-
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot convert 'A:B:0' to a tensor name."):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
input_map={"A:2": constant_op.constant(5.0)})
def testInputMapTypeMismatch(self):
- if ops._USE_C_API:
- error_msg = ("Input 0 of node import/B was passed float from Const:0 "
- "incompatible with expected int32.")
- else:
- error_msg = ("Cannot convert a tensor of type float32 to an input of "
- "type int32.")
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Input 0 of node import/B was passed float from Const:0 "
+ "incompatible with expected int32."):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
value { list { s: 'loc:@A' } }
} }""")
- if ops._USE_C_API:
- error_msg = "Node 'B' expects to be colocated with unknown node 'A'"
- else:
- error_msg = "does not exist during import"
-
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Node 'B' expects to be colocated with unknown node 'A'"):
importer.import_graph_def(
original_graph_def, return_elements=["B"], name="imported_graph")
TypeError, "return_elements must be a list of strings."):
importer.import_graph_def(self._MakeGraphDef(""), return_elements=[7])
- if ops._USE_C_API:
- error_msg = "Cannot convert 'a:b:c' to a tensor name."
- else:
- error_msg = "Requested return_element 'a:b:c' not found in graph_def."
- with self.assertRaisesRegexp(ValueError, error_msg):
- importer.import_graph_def(self._MakeGraphDef(""),
- return_elements=["a:b:c"])
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot convert 'a:b:c' to a tensor name."):
+ importer.import_graph_def(
+ self._MakeGraphDef(""), return_elements=["a:b:c"])
def testDuplicateOperationNames(self):
- if ops._USE_C_API:
- error_msg = "Node 'A' is not unique"
- else:
- error_msg = "Duplicate name 'A' in GraphDef."
-
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, error_msg):
- importer.import_graph_def(
- self._MakeGraphDef("""
- node { name: 'A' op: 'IntOutput' }
- node { name: 'B' op: 'IntOutput' }
- node { name: 'A' op: 'IntOutput' }
- """))
+ with self.assertRaisesRegexp(ValueError, "Node 'A' is not unique"):
+ importer.import_graph_def(
+ self._MakeGraphDef("""
+ node { name: 'A' op: 'IntOutput' }
+ node { name: 'B' op: 'IntOutput' }
+ node { name: 'A' op: 'IntOutput' }
+ """))
def testWithExtensionAndAttr(self):
with ops.Graph().as_default() as g:
min_consumer)
def testVersionLow(self):
- with ops.Graph().as_default() as g:
- pat = (r"GraphDef producer version -1 below min producer %d supported "
- r"by TensorFlow \S+\. Please regenerate your graph.$" %
- versions.GRAPH_DEF_VERSION_MIN_PRODUCER)
- # C API throws error during import, Python-only throws error during run
- if ops._USE_C_API:
- with self.assertRaisesRegexp(Exception, pat):
- importer.import_graph_def(self._MakeGraphDef("", producer=-1))
- else:
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ Exception,
+ r"GraphDef producer version -1 below min producer %d supported "
+ r"by TensorFlow \S+\. Please regenerate your graph.$" %
+ versions.GRAPH_DEF_VERSION_MIN_PRODUCER):
importer.import_graph_def(self._MakeGraphDef("", producer=-1))
- x = constant_op.constant(
- 7) # Need at least one op to get a C++ graph generated
- with self.test_session(graph=g) as sess:
- with self.assertRaisesRegexp(Exception, pat):
- sess.run(x)
def testVersionHigh(self):
- with ops.Graph().as_default() as g:
- pat = (r"GraphDef min consumer version %d above current version %d "
- r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
- (1 << 30, versions.GRAPH_DEF_VERSION))
-
- if ops._USE_C_API:
- with self.assertRaisesRegexp(ValueError, pat):
- importer.import_graph_def(self._MakeGraphDef("",
- min_consumer=1 << 30))
- else:
- # Python API only throws when graph is run
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"GraphDef min consumer version %d above current version %d "
+ r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
+ (1 << 30, versions.GRAPH_DEF_VERSION)):
importer.import_graph_def(self._MakeGraphDef("", min_consumer=1 << 30))
- x = constant_op.constant(
- 7) # Need at least one op to get a C++ graph generated
- with self.test_session(graph=g) as sess:
- with self.assertRaisesRegexp(Exception, pat):
- sess.run(x)
def testVersionAppliesToOpConstruction(self):
"""These tests rely on shape fns in test_ops.cc."""
"""),
return_elements=["A"],
producer_op_list=producer_op_list)
- if ops._USE_C_API:
- error_msg = "Operation 'import/A' has no attr named 'default_int'."
- else:
- error_msg = "No attr named 'default_int'"
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Operation 'import/A' has no attr named 'default_int'."):
a[0].get_attr("default_int")
- # Unknown attrs cannot be imported using C API. This test will eventually be
- # deleted.
- if not ops._USE_C_API:
- # Attr only in producer_op_list with non-default value is preserved.
- with ops.Graph().as_default():
- a = importer.import_graph_def(
- self._MakeGraphDef("""
- node { name: 'A' op: 'OpWithFutureDefaultAttr'
- attr { key: 'default_int' value { i: 987 } } }
- """),
- return_elements=["A"],
- producer_op_list=producer_op_list)
- self.assertEqual(987, a[0].get_attr("default_int"))
-
def testFunctions(self):
dtype = dtypes.float32
+
@function.Defun(dtype, dtype, dtype, dtype)
def Grad(x, y, dout1, dout2): # pylint: disable=unused-argument
# Return the inputs for simplicity of testing. The correct return value
def testImportInsideDefun(self):
g = ops.Graph()
with g.as_default():
+
@function.Defun()
def Add2(x, y):
return math_ops.add(x, y)
def testImportGraphWithFunctionTwice(self):
g = ops.Graph()
with g.as_default():
+
@function.Defun()
def Add2(x, y):
return math_ops.add(x, y)
# pylint: enable=invalid-name
-@test_util.with_c_api
class SimpleMetaGraphTest(test.TestCase):
def testNoVariables(self):
self.assertIs(global_vars[0], trainable_vars[0])
-@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):
self.assertEqual("", str(graph2.as_graph_element("matmul").device))
-@test_util.with_c_api
class MetaGraphWithVariableScopeTest(test.TestCase):
def testMetricsCollection(self):
initializer = variables.local_variables_initializer()
-@test_util.with_c_api
class ExportImportAcrossScopesTest(test.TestCase):
def testPartionedVariables(self):
return [tensor_shape.unknown_shape() for _ in op.outputs]
-@test_util.with_c_api
class OpDefLibraryTest(test_util.TensorFlowTestCase):
def setUp(self):
self.assertEqual(t_c, [x.dtype for x in c])
-@test_util.with_c_api
class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
import weakref
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
-@test_util.with_c_api
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
resources.shared_resources()).eval()), 0)
-@test_util.with_c_api
class TensorAndShapeTest(test_util.TensorFlowTestCase):
def testShape(self):
_ = a + b
-@test_util.with_c_api
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
self.assertAllEqual(x.indices.eval(), [0, 2])
-@test_util.with_c_api
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
def testNoArgs(self):
return op.outputs
-@test_util.with_c_api
class OperationTest(test_util.TensorFlowTestCase):
def testNoInputs(self):
attr_value_pb2.NameAttrList(name="MyFunc"))
# Try fetching missing attr
- if ops._USE_C_API:
- error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
- else:
- error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""
-
- with self.assertRaisesRegexp(ValueError, error_msg):
+ with self.assertRaisesRegexp(
+ ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
op.get_attr("FakeAttr")
# TODO(b/65162920): remove this test when users who are directly mutating the
# TODO(nolivia): test all error cases
def testAddControlInput(self):
- # The C API dedups redundant control edges, pure Python does not
- if ops._USE_C_API: return
- with ops.Graph().as_default():
- x = constant_op.constant(1).op
- y = constant_op.constant(2).op
- z = constant_op.constant(3).op
- z._add_control_input(x) # pylint: disable=protected-access
- self.assertEqual(z.control_inputs, [x])
- z._add_control_input(x) # pylint: disable=protected-access
- self.assertEqual(z.control_inputs, [x, x])
- z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
- self.assertEqual(z.control_inputs, [x, x, x, y, y])
- self.assertEqual(x._control_outputs, [z])
-
- def testAddControlInputC(self):
- # The C API dedups redundant control edges, pure Python does not
- if not ops._USE_C_API: return
with ops.Graph().as_default():
x = constant_op.constant(1).op
y = constant_op.constant(2).op
self.assertEqual(list(f.op.inputs), [d, e])
def testControlInputCycle(self):
- # Non-C API path has a different error message
- if not ops._USE_C_API: return
graph = ops.Graph()
with graph.as_default():
z = constant_op.constant(0)
sess.run(z)
def testUpdateInputShapeError(self):
- # C-API throws the error differently.
- if ops._USE_C_API:
- return
- g = ops.Graph()
- with g.as_default():
- w = constant_op.constant(2, shape=[3, 1])
- x = constant_op.constant(0, shape=[3, 1])
- y = constant_op.constant(1, shape=[2, 2])
- z = w + x
- z.op._update_input(0, y) # pylint: disable=protected-access
-
- with session.Session(graph=g) as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"Incompatible shapes: \[2,2\] vs. \[3,1\]"):
- sess.run(z)
-
- def testUpdateInputShapeErrorC(self):
- if not ops._USE_C_API:
- return
g = ops.Graph()
with g.as_default():
w = constant_op.constant(2, shape=[3, 1])
z.op._update_input(0, y) # pylint: disable=protected-access
def testUpdateInputOutOfRange(self):
- # C-API throws the error differently.
- if ops._USE_C_API: return
- g = ops.Graph()
- with g.as_default():
- x = constant_op.constant(1)
- with self.assertRaisesRegexp(IndexError, "list index out of range"):
- x.op._update_input(1, x) # pylint: disable=protected-access
-
- def testUpdateInputOutOfRangeC(self):
- # C-API throws the error differently.
- if not ops._USE_C_API: return
g = ops.Graph()
with g.as_default():
x = constant_op.constant(1)
y = constant_op.constant(1)
z = x + y
- # Pure Python mode doesn't create OpDefs for constants
- if ops._USE_C_API:
- self.assertEqual(x.op.op_def.name, "Const")
- self.assertEqual(len(x.op.op_def.input_arg), 0)
- self.assertEqual(len(x.op.op_def.output_arg), 1)
+ self.assertEqual(x.op.op_def.name, "Const")
+ self.assertEqual(len(x.op.op_def.input_arg), 0)
+ self.assertEqual(len(x.op.op_def.output_arg), 1)
self.assertEqual(z.op.op_def.name, "Add")
self.assertEqual(len(z.op.op_def.input_arg), 2)
op.inputs.append(None)
-@test_util.with_c_api
class CreateOpTest(test_util.TensorFlowTestCase):
def testNodeDefArgs(self):
# the control flow context isn't set properly, but a more complicated use case
# that might not be obvious to test will fail). Thus we instead explicitly test
# the low-level behavior.
-@test_util.with_c_api
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
def testBasic(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
- if ops._USE_C_API:
- c_op = ops._create_c_op(
- g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
- op = g._create_op_from_tf_operation(c_op)
- else:
- # Test pure-Python version to make sure C API has same behavior.
- op = test_ops.int_input_int_output(x, name="myop").op
+ c_op = ops._create_c_op(
+ g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
+ op = g._create_op_from_tf_operation(c_op)
self.assertEqual(op.name, "myop")
self.assertEqual(op.type, "IntInputIntOutput")
g = ops.Graph()
with g.as_default():
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- if ops._USE_C_API:
- c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
- op = g._create_op_from_tf_operation(c_op)
- else:
- # Test pure-Python version to make sure C API has same behavior.
- op = array_ops.identity(x, name="myop").op
+ c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
+ op = g._create_op_from_tf_operation(c_op)
self.assertEqual(op.name, "myop")
self.assertEqual(op.type, "Identity")
def testUniqueName(self):
g = ops.Graph()
with g.as_default():
- if ops._USE_C_API:
- c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
- c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
- op = g._create_op_from_tf_operation(c_op)
- op2 = g._create_op_from_tf_operation(c_op2)
- else:
- # Test pure-Python version to make sure C API has same behavior.
- op = test_ops.int_output(name="myop").op
- op2 = test_ops.int_output(name="myop_1").op
+ c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
+ c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
+ op = g._create_op_from_tf_operation(c_op)
+ op2 = g._create_op_from_tf_operation(c_op2)
# Create ops with same names as op1 and op2. We expect the new names to be
# uniquified.
x = test_ops.int_output()
def true_fn():
- if ops._USE_C_API:
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "cond/myop"), [x], [])
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- else:
- # Test pure-Python version to make sure C API has same behavior.
- test_ops.int_input(x, name="myop")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "cond/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
x = test_ops.int_output()
def body(i):
- if ops._USE_C_API:
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- else:
- # Test pure-Python version to make sure C API has same behavior.
- test_ops.int_input(x, name="myop")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
def body(i):
c = constant_op.constant(1.0, name="c")
- if ops._USE_C_API:
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- with ops.control_dependencies([c]):
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- else:
- with ops.control_dependencies([c]):
- test_ops.int_input(x, name="myop")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
c = constant_op.constant(1.0)
def body(i):
- if ops._USE_C_API:
- ops._create_c_op(ops.get_default_graph(),
- ops._NodeDef("IntInput", "myloop/myop"), [x], [])
- with ops.control_dependencies([c]):
- new_ops = g._add_new_tf_operations()
- self.assertEqual(len(new_ops), 1)
- else:
- with ops.control_dependencies([c]):
- test_ops.int_input(x, name="myop")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
-@test_util.with_c_api
class ApplyOpTest(test_util.TensorFlowTestCase):
def testNodeDefArgs(self):
out_3.op.node_def)
-@test_util.with_c_api
class NameStackTest(test_util.TensorFlowTestCase):
def testBasics(self):
pass
-@test_util.with_c_api
class NameTest(test_util.TensorFlowTestCase):
def testGenerateName(self):
g.create_op("FloatOutput", [], [dtypes.float32]).name)
-@test_util.with_c_api
class DeviceTest(test_util.TensorFlowTestCase):
def testNoDevice(self):
""", gd)
-@test_util.with_c_api
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
class TestThread(threading.Thread):
self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name)
-@test_util.with_c_api
class ObjectWithName(object):
def __init__(self, name):
return self._name
-@test_util.with_c_api
class CollectionTest(test_util.TensorFlowTestCase):
def test_get_collections(self):
return x_grad
-@test_util.with_c_api
class RegistrationTest(test_util.TensorFlowTestCase):
def testRegisterGradients(self):
ops.get_gradient_function(y.op)
-@test_util.with_c_api
class ComparisonTest(test_util.TensorFlowTestCase):
def testMembershipAllowed(self):
self.assertTrue(t1 not in [t2])
-@test_util.with_c_api
class ControlDependenciesTest(test_util.TensorFlowTestCase):
- @test_util.enable_c_api
def testBasic(self):
g = ops.Graph()
with g.as_default():
self.assertEqual(b.op.control_inputs, [])
-@test_util.with_c_api
class OpScopeTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes()
self.assertEqual(ops.get_name_scope(), "")
-@test_util.with_c_api
class GraphTest(test_util.TensorFlowTestCase):
def setUp(self):
sess.run(a)
-@test_util.with_c_api
class AttrScopeTest(test_util.TensorFlowTestCase):
def _get_test_attrs(self):
ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
-@test_util.with_c_api
class KernelLabelTest(test_util.TensorFlowTestCase):
- @test_util.enable_c_api
def testNoLabel(self):
with self.test_session():
self.assertAllEqual(b"My label is: default",
self.assertAllEqual(b"My label is: overload_2", overload_2.eval())
-@test_util.with_c_api
class AsGraphDefTest(test_util.TensorFlowTestCase):
def testGraphDefVersion(self):
return ops.OpStats("flops", 20)
-@test_util.with_c_api
class StatisticsTest(test_util.TensorFlowTestCase):
def testRegisteredNode(self):
self.assertEqual(3, flops_total.value)
-@test_util.with_c_api
class ColocationGroupTest(test_util.TensorFlowTestCase):
def testBasic(self):
self.assertEqual("/device:CPU:0", b.device)
-@test_util.with_c_api
class DeprecatedTest(test_util.TensorFlowTestCase):
def testSuccess(self):
- # TODO(skyewm): make g.graph_def_versions work with the C API enabled
- if ops._USE_C_API: return
-
with ops.Graph().as_default() as g:
- g.graph_def_versions.producer = 7
+ test_util.set_producer_version(g, 7)
old = test_ops.old()
with self.test_session(graph=g):
old.run()
with self.assertRaisesRegexp(NotImplementedError, self._error()):
test_ops.old()
- def testGraphExecutionFail(self):
- # TODO(skyewm): make g.graph_def_versions work with the C API enabled
- if ops._USE_C_API: return
-
- with ops.Graph().as_default() as g:
- g.graph_def_versions.producer = 7
- old = test_ops.old()
- g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
- with self.test_session(graph=g):
- with self.assertRaisesRegexp(errors.UnimplementedError, self._error()):
- old.run()
-
-@test_util.with_c_api
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
def testSuccess(self):
DenseTensorLikeTypeTest.BadClassBadDtype)
-@test_util.with_c_api
class NameScopeTest(test_util.TensorFlowTestCase):
def testStripAndPrependScope(self):
self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)
-@test_util.with_c_api
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
self.assertEquals(frame, frame_with_start_line[:-1])
-@test_util.with_c_api
-class OutputTypesTest(test_util.TensorFlowTestCase):
- """Tests Operation._output_types property.
-
- This test should not exist as _output_types is a private property.
- This property is used by util.copy_elements and its tests would normally
- cover Operation._output_types. However, we can't yet run these tests in C
- API mode because their use _set_device method. This test will be deleted
- once we port _set_device and run the copy tests with C API on.
- """
- # TODO(iga): Remove this test
-
- def setUp(self):
- self.prev_use_c_api = ops._USE_C_API # pylint: disable=protected-access
- ops._USE_C_API = True # pylint: disable=protected-access
-
- def tearDown(self):
- ops._USE_C_API = self.prev_use_c_api # pylint: disable=protected-access
-
- def testOneOutput(self):
- g = ops.Graph()
- with g.as_default():
- # Using a constant because creating unregistered ops
- # doesn't work with the C API.
- op = constant_op.constant(12, dtype=dtypes.uint16).op
- # pylint: disable=protected-access
- self.assertEqual([types_pb2.DT_UINT16], op._output_types)
- # pylint: enable=protected-access
-
- def testTwoDifferentOutputs(self):
- g = ops.Graph()
- with g.as_default():
- x = constant_op.constant([1, 1, 2, 4, 4, 4, 7, 8, 8],
- dtype=dtypes.double)
- y, _ = gen_array_ops.unique(x)
- self.assertEqual([types_pb2.DT_DOUBLE, types_pb2.DT_INT32],
- y.op._output_types) # pylint: disable=protected-access
-
- def testThreeOutputs(self):
- g = ops.Graph()
- with g.as_default():
- # Using a split operationt because creating unregistered ops
- # doesn't work with the C API.
- a = constant_op.constant("abc", dtype=dtypes.string, shape=[5, 30])
- split0, _, _ = array_ops.split(a, [4, 15, 11], 1)
- # pylint: disable=protected-access
- self.assertEqual([types_pb2.DT_STRING] * 3, split0.op._output_types)
- # pylint: enable=protected-access
-
-
-@test_util.with_c_api
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
def testBadArgumentsToEnableEagerExecution(self):
raise RuntimeError("did not expect to be called")
-@test_util.with_c_api
class SmartCondTest(test_util.TensorFlowTestCase):
def testTrue(self):
self.assertEqual(y.eval(feed_dict={x: -1}), 2)
def testEval(self):
- # Constant expression evaluation only works with the C API enabled.
- if not ops._USE_C_API: return
-
with ops.Graph().as_default():
with session.Session():
x = constant_op.constant(1)
smart_cond.smart_cond(True, lambda: x)
-@test_util.with_c_api
class SmartCaseTest(test_util.TensorFlowTestCase):
def testTrue(self):
self.assertEqual(sess.run(z), 1)
def testMix(self):
- # Constant expression evaluation only works with the C API enabled.
- if not ops._USE_C_API: return
-
x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
y = constant_op.constant(10)
conditions = [(x > 1, lambda: constant_op.constant(1)),
self.assertEqual(sess.run(z, feed_dict={x: 0}), 3)
-@test_util.with_c_api
class SmartConstantValueTest(test_util.TensorFlowTestCase):
# TODO(skyewm): this is essentially a regression test for
from tensorflow.python.platform import googletest
-@test_util.with_c_api
class SubscribeTest(test_util.TensorFlowTestCase):
def _ExpectSubscribedIdentities(self, container):
from tensorflow.python.platform import googletest
-@test_util.with_c_api
class TestUtilTest(test_util.TensorFlowTestCase):
def test_assert_ops_in_graph(self):
self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
-@test_util.with_c_api
class GarbageCollectionTest(test_util.TensorFlowTestCase):
def test_no_reference_cycle_decorator(self):
self.evaluate(array_ops.size(tensor, out_type=dtypes.int64)).dtype)
-@test_util.with_c_api
class SequenceMaskTest(test_util.TensorFlowTestCase):
def testExceptions(self):
# test dtype and default maxlen:
res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]),
dtype=dtypes.float32)
- if ops._USE_C_API:
- self.assertAllEqual(res.get_shape().as_list(), [3, 4])
- else:
- self.assertAllEqual(res.get_shape().as_list(), [3, None])
+ self.assertAllEqual(res.get_shape().as_list(), [3, 4])
self.assertAllEqual(
res.eval(),
[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
with self.test_session():
res = array_ops.sequence_mask(
constant_op.constant([0, 1, 4]))
- if ops._USE_C_API:
- self.assertAllEqual(res.get_shape().as_list(), [3, 4])
- else:
- self.assertAllEqual(res.get_shape().as_list(), [3, None])
+ self.assertAllEqual(res.get_shape().as_list(), [3, 4])
self.assertAllEqual(
res.eval(),
[[False, False, False, False],
# test dtype and default maxlen:
res = array_ops.sequence_mask(
constant_op.constant([[0, 1, 4], [1, 2, 3]]), dtype=dtypes.float32)
- if ops._USE_C_API:
- self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
- else:
- self.assertAllEqual(res.get_shape().as_list(), [2, 3, None])
+ self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
self.assertAllEqual(
res.eval(),
[[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
return r_s
-@test_util.with_c_api
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
1)
-@test_util.with_c_api
class ControlFlowContextCheckTest(test.TestCase):
def _getWhileTensor(self):
math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
-@test_util.with_c_api
class TupleTest(test.TestCase):
def testTensors(self):
self.assertEquals(1, var.eval())
-@test_util.with_c_api
class AssertTest(test.TestCase):
def testGuardedAssertDoesNotCopyWhenTrue(self):
self.assertEqual([], guarded_memcpy_nodestat_names)
-@test_util.with_c_api
class WhileOpBenchmark(test.Benchmark):
"""Evaluate the performance of while_loop op."""
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_c_api
class EagerTest(test.TestCase):
def testCond(self):
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-@test_util.with_c_api
class LargeConcatOpTest(test.TestCase):
"""Tests that belong in concat_op_test.py, but run over large tensors."""
from tensorflow.python.util import compat
-@test_util.with_c_api
class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def tearDown(self):
from tensorflow.python.platform import test
-@test_util.with_c_api
class ScalarTest(test.TestCase):
def check(self, op, args, error, correct=None):
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
-@test_util.with_c_api
class SoftmaxTest(test.TestCase):
def _npSoftmax(self, features, dim=-1, log=False):