Allow experimental string attrs for functions.
authorPatrick Nguyen <drpng@google.com>
Thu, 29 Mar 2018 17:41:36 +0000 (10:41 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 17:46:31 +0000 (10:46 -0700)
PiperOrigin-RevId: 190951605

tensorflow/python/framework/function.py
tensorflow/python/framework/function_test.py

index 14d72d8..82dd2a3 100644 (file)
@@ -934,6 +934,12 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
           s=("function_%s" % func_name).encode())
     # pylint: enable=protected-access
 
+  kwargs_keys = list(kwargs.keys())
+  for key in kwargs_keys:
+    if key.startswith("experimental_"):
+      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+      del kwargs[key]
+
   if kwargs:
     raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
   return attrs
index 65ca801..83d256f 100644 (file)
@@ -1227,6 +1227,15 @@ class FunctionsFromProtos(test.TestCase):
         ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
       function._from_library(library)
 
+  def testExperimentalAttrs(self):
+
+    @function.Defun(dtypes.int32, experimental_tag="tag_value")
+    def FunctionWithAttr(i):
+      return array_ops.identity(i)
+    self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
+    self.assertEqual(
+        FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+
 
 @test_util.with_c_api
 class FunctionOverloadTest(test.TestCase):