s=("function_%s" % func_name).encode())
# pylint: enable=protected-access
+ kwargs_keys = list(kwargs.keys())
+ for key in kwargs_keys:
+ if key.startswith("experimental_"):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ del kwargs[key]
+
if kwargs:
raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
return attrs
ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
function._from_library(library)
+ def testExperimentalAttrs(self):
+
+ @function.Defun(dtypes.int32, experimental_tag="tag_value")
+ def FunctionWithAttr(i):
+ return array_ops.identity(i)
+ self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
+
@test_util.with_c_api
class FunctionOverloadTest(test.TestCase):