from __future__ import print_function
import collections
+import functools
import numpy as np
from tensorflow.contrib import graph_editor as ge
from tensorflow.contrib.graph_editor.tests import match
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
self.graph = ops.Graph()
with self.graph.as_default():
c0 = constant_op.constant(1.0, shape=[10], name="Const")
+ c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo"))
c1 = constant_op.constant(1.0, shape=[10], name="Const")
c2 = constant_op.constant(1.0, shape=[10], name="Const")
i = constant_op.constant(1.0, shape=[10], name="Input")
top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
self.assertTrue(matcher2(top))
+ def test_transform_nodedef_fn(self):
+ transformer = ge.Transformer()
+
+ def nodedef_fn(node_def):
+ if "_foo" in node_def.attr:
+ del node_def.attr["_foo"]
+ node_def.attr["_bar"].s = b"bar"
+ return node_def
+
+ my_copy_op_handler = functools.partial(
+ ge.transform.copy_op_handler, nodedef_fn=nodedef_fn)
+ transformer.transform_op_handler = my_copy_op_handler
+
+ graph = ops.Graph()
+ transformer(self.graph, graph, "", "")
+
+ c0_before = self.graph.get_operation_by_name("Const")
+ c0_after = graph.get_operation_by_name("Const")
+ self.assertEquals(c0_before.get_attr("_foo"), b"foo")
+ with self.assertRaises(ValueError):
+ c0_after.get_attr("_foo")
+
+ all_ops = graph.get_operations()
+ for op in all_ops:
+ self.assertEquals(op.get_attr("_bar"), b"bar")
+
def test_copy_with_input_replacements(self):
with self.graph.as_default():
ten = constant_op.constant(10.0, shape=[10], name="Input")
return None
-def copy_op_handler(info, op, new_inputs, copy_shape=True):
+def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
"""Copy a `tf.Operation`.
Args:
op: the `tf.Operation` to be copied.
new_inputs: The new inputs for this op.
copy_shape: also copy the shape of the tensor
+ nodedef_fn: If provided, a function that will be run on the NodeDef
+ and should return a mutated NodeDef before a new Operation is created.
+ This is useful as certain features cannot be set on the Operation and
+ must be modified in NodeDef.
+
Returns:
A `(op, op_outputs)` tuple containing the transformed op and its outputs.
"""
name_ = info.graph_.unique_name(name_)
node_def_.name = name_
+ # Mutate NodeDef if requested:
+ if nodedef_fn is not None:
+ node_def_ = nodedef_fn(node_def_)
+
# Copy the other inputs needed for initialization
output_types_ = op._output_types[:]
input_types_ = op._input_types[:]