Fix a bug in function inlining when the argument is an implicitly dereferenced ref...
authorDerek Murray <mrry@google.com>
Fri, 2 Feb 2018 16:09:13 +0000 (08:09 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 18:42:04 +0000 (10:42 -0800)
Previously the inliner would add an identity node with an invalid ref-type attr when the actual parameter had ref type. The changed version removes the reference.

PiperOrigin-RevId: 184285084

tensorflow/core/common_runtime/function.cc
tensorflow/python/framework/function_test.py

index 150fb85..e1b5404 100644 (file)
@@ -97,12 +97,11 @@ static Node* AddNoOp(Graph* g) {
 
 static Node* AddIdentity(Graph* g, Endpoint input) {
   DCHECK_LT(0, input.dtype());
-  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
   NodeDef ndef;
   ndef.set_name(g->NewName(kNodeLabel));
   ndef.set_op("Identity");
   ndef.add_input(input.name());
-  AddNodeAttr("T", input.dtype(), &ndef);
+  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
   Status s;
   Node* ret = g->AddNode(ndef, &s);
   TF_CHECK_OK(s);
index a4ca3f9..b35cee0 100644 (file)
@@ -19,8 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 import re
-import time
 import sys
+import time
 
 import numpy as np
 
@@ -86,6 +86,21 @@ class FunctionTest(test.TestCase):
       with session.Session() as sess:
         self.assertAllEqual([18.0], sess.run(call))
 
+  def testIdentityImplicitDeref(self):
+
+    @function.Defun(dtypes.float32, func_name="MyIdentity")
+    def MyIdentityFunc(a):
+      return a
+
+    with ops.Graph().as_default():
+      var = variables.Variable([18.0])
+      call = MyIdentityFunc(var._ref())  # pylint: disable=protected-access
+      self.assertEqual("MyIdentity", call.op.name)
+      for cfg in _OptimizerOptions():
+        with session.Session(config=cfg) as sess:
+          sess.run(var.initializer)
+          self.assertAllEqual([18.0], sess.run(call))
+
   def testIdentityOutputName(self):
 
     @function.Defun(
@@ -771,7 +786,7 @@ class FunctionTest(test.TestCase):
     # We added more randomness to function names in C API.
     # TODO(iga): Remove this if statement when we switch to C API.
     if ops._USE_C_API:  # pylint: disable=protected-access
-      if sys.byteorder == 'big':
+      if sys.byteorder == "big":
         self.assertEqual("Foo_kEdkAG8SJvg",
                          Foo.instantiate([dtypes.float32] * 3).name)
       else: