static Node* AddIdentity(Graph* g, Endpoint input) {
DCHECK_LT(0, input.dtype());
- DCHECK_LT(input.dtype(), DT_FLOAT_REF);
NodeDef ndef;
ndef.set_name(g->NewName(kNodeLabel));
ndef.set_op("Identity");
ndef.add_input(input.name());
- AddNodeAttr("T", input.dtype(), &ndef);
+ AddNodeAttr("T", BaseType(input.dtype()), &ndef);
Status s;
Node* ret = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
from __future__ import print_function
import re
-import time
import sys
+import time
import numpy as np
with session.Session() as sess:
self.assertAllEqual([18.0], sess.run(call))
+ def testIdentityImplicitDeref(self):
+
+ @function.Defun(dtypes.float32, func_name="MyIdentity")
+ def MyIdentityFunc(a):
+ return a
+
+ with ops.Graph().as_default():
+ var = variables.Variable([18.0])
+ call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access
+ self.assertEqual("MyIdentity", call.op.name)
+ for cfg in _OptimizerOptions():
+ with session.Session(config=cfg) as sess:
+ sess.run(var.initializer)
+ self.assertAllEqual([18.0], sess.run(call))
+
def testIdentityOutputName(self):
@function.Defun(
# We added more randomness to function names in C API.
# TODO(iga): Remove this if statement when we switch to C API.
if ops._USE_C_API: # pylint: disable=protected-access
- if sys.byteorder == 'big':
+ if sys.byteorder == "big":
self.assertEqual("Foo_kEdkAG8SJvg",
Foo.instantiate([dtypes.float32] * 3).name)
else: