Adds a nodedef_fn parameter to copy_op_handler, allowing customization by mutating
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Apr 2018 21:04:09 +0000 (14:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 21:07:52 +0000 (14:07 -0700)
NodeDef before creating the copied operation.

PiperOrigin-RevId: 192505209

tensorflow/contrib/graph_editor/tests/transform_test.py
tensorflow/contrib/graph_editor/transform.py

index 2603de6..97f38c9 100644 (file)
@@ -18,9 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import functools
 import numpy as np
 from tensorflow.contrib import graph_editor as ge
 from tensorflow.contrib.graph_editor.tests import match
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -42,6 +44,7 @@ class TransformTest(test.TestCase):
     self.graph = ops.Graph()
     with self.graph.as_default():
       c0 = constant_op.constant(1.0, shape=[10], name="Const")
+      c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo"))
       c1 = constant_op.constant(1.0, shape=[10], name="Const")
       c2 = constant_op.constant(1.0, shape=[10], name="Const")
       i = constant_op.constant(1.0, shape=[10], name="Input")
@@ -112,6 +115,32 @@ class TransformTest(test.TestCase):
     top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
     self.assertTrue(matcher2(top))
 
+  def test_transform_nodedef_fn(self):
+    transformer = ge.Transformer()
+
+    def nodedef_fn(node_def):
+      if "_foo" in node_def.attr:
+        del node_def.attr["_foo"]
+      node_def.attr["_bar"].s = b"bar"
+      return node_def
+
+    my_copy_op_handler = functools.partial(
+        ge.transform.copy_op_handler, nodedef_fn=nodedef_fn)
+    transformer.transform_op_handler = my_copy_op_handler
+
+    graph = ops.Graph()
+    transformer(self.graph, graph, "", "")
+
+    c0_before = self.graph.get_operation_by_name("Const")
+    c0_after = graph.get_operation_by_name("Const")
+    self.assertEquals(c0_before.get_attr("_foo"), b"foo")
+    with self.assertRaises(ValueError):
+      c0_after.get_attr("_foo")
+
+    all_ops = graph.get_operations()
+    for op in all_ops:
+      self.assertEquals(op.get_attr("_bar"), b"bar")
+
   def test_copy_with_input_replacements(self):
     with self.graph.as_default():
       ten = constant_op.constant(10.0, shape=[10], name="Input")
index d8a4838..a320a3f 100644 (file)
@@ -129,7 +129,7 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
       return None
 
 
-def copy_op_handler(info, op, new_inputs, copy_shape=True):
+def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
   """Copy a `tf.Operation`.
 
   Args:
@@ -137,6 +137,11 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
     op: the `tf.Operation` to be copied.
     new_inputs: The new inputs for this op.
     copy_shape: also copy the shape of the tensor
+    nodedef_fn: If provided, a function that will be run on the NodeDef
+      and should return a mutated NodeDef before a new Operation is created.
+      This is useful as certain features cannot be set on the Operation and
+      must be modified in NodeDef.
+
   Returns:
     A `(op, op_outputs)` tuple containing the transformed op and its outputs.
   """
@@ -155,6 +160,10 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True):
   name_ = info.graph_.unique_name(name_)
   node_def_.name = name_
 
+  # Mutate NodeDef if requested:
+  if nodedef_fn is not None:
+    node_def_ = nodedef_fn(node_def_)
+
   # Copy the other inputs needed for initialization
   output_types_ = op._output_types[:]
   input_types_ = op._input_types[:]