From: Derek Murray Date: Fri, 2 Feb 2018 16:09:13 +0000 (-0800) Subject: Fix a bug in function inlining when the argument is an implicitly dereferenced ref... X-Git-Tag: upstream/v1.7.0~31^2~1064 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1804be16f41f5217d1ad53a8ba992d9a132d4d79;p=platform%2Fupstream%2Ftensorflow.git Fix a bug in function inlining when the argument is an implicitly dereferenced ref tensor. Previously the inliner would add an identity node with an invalid ref-type attr when the actual parameter had ref type. The changed version removes the reference. PiperOrigin-RevId: 184285084 --- diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 150fb85..e1b5404 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -97,12 +97,11 @@ static Node* AddNoOp(Graph* g) { static Node* AddIdentity(Graph* g, Endpoint input) { DCHECK_LT(0, input.dtype()); - DCHECK_LT(input.dtype(), DT_FLOAT_REF); NodeDef ndef; ndef.set_name(g->NewName(kNodeLabel)); ndef.set_op("Identity"); ndef.add_input(input.name()); - AddNodeAttr("T", input.dtype(), &ndef); + AddNodeAttr("T", BaseType(input.dtype()), &ndef); Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index a4ca3f9..b35cee0 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function import re -import time import sys +import time import numpy as np @@ -86,6 +86,21 @@ class FunctionTest(test.TestCase): with session.Session() as sess: self.assertAllEqual([18.0], sess.run(call)) + def testIdentityImplicitDeref(self): + + @function.Defun(dtypes.float32, func_name="MyIdentity") + def MyIdentityFunc(a): + return a + + with ops.Graph().as_default(): + var = variables.Variable([18.0]) + call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access + self.assertEqual("MyIdentity", call.op.name) + for cfg in _OptimizerOptions(): + with session.Session(config=cfg) as sess: + sess.run(var.initializer) + self.assertAllEqual([18.0], sess.run(call)) + def testIdentityOutputName(self): @function.Defun( @@ -771,7 +786,7 @@ class FunctionTest(test.TestCase): # We added more randomness to function names in C API. # TODO(iga): Remove this if statement when we switch to C API. if ops._USE_C_API: # pylint: disable=protected-access - if sys.byteorder == 'big': + if sys.byteorder == "big": self.assertEqual("Foo_kEdkAG8SJvg", Foo.instantiate([dtypes.float32] * 3).name) else: