Add FunctionTest.testLayerInDefun
authorIgor Ganichev <iga@google.com>
Thu, 12 Apr 2018 19:04:48 +0000 (12:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 19:06:50 +0000 (12:06 -0700)
PiperOrigin-RevId: 192647818

tensorflow/python/eager/BUILD
tensorflow/python/eager/function_test.py

index 8c0d3fe..b3268c9 100644 (file)
@@ -142,6 +142,8 @@ cuda_py_test(
         ":tape",
         ":test",
         "//tensorflow/python:clip_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
     ],
index 9af1979..65dde75 100644 (file)
@@ -29,9 +29,11 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function as tf_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
@@ -104,6 +106,7 @@ class FunctionTest(test.TestCase):
     matmul = function.defun(math_ops.matmul)
 
     pair = collections.namedtuple('pair', ['a', 'b'])
+
     def a_times_b(inputs):
       return matmul(inputs.a['a'], inputs.b['b'])
 
@@ -312,6 +315,7 @@ class FunctionTest(test.TestCase):
         x = variable_scope.get_variable(
             'v', initializer=constant_op.constant(1.0))
         return x * constant_op.constant(2.0)
+
       with self.assertRaisesRegexp(ValueError,
                                    'No trainable variables were accessed'):
         backprop.implicit_val_and_grad(f)()
@@ -581,6 +585,7 @@ class FunctionTest(test.TestCase):
       with ops.name_scope('foo'):
         v = resource_variable_ops.ResourceVariable(0.0, name='bar')
       self.assertEqual(v.name, 'foo/bar:0')
+
     create_variable()
 
   def testVariableNamesRespectNameScopesWithDefunInGraph(self):
@@ -590,9 +595,25 @@ class FunctionTest(test.TestCase):
         with ops.name_scope('foo'):
           v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
         self.assertEqual(v.name, 'foo/bar:0')
+
       with ops.get_default_graph().as_default():
         create_variable()
 
+  def testLayerInDefun(self):
+    conv = convolutional.Conv2D(
+        filters=1,
+        kernel_size=2,
+        kernel_initializer=init_ops.ones_initializer(),
+        bias_initializer=init_ops.zeros_initializer())
+
+    @function.defun
+    def model(x):
+      return conv(x)
+
+    x = array_ops.ones([1, 2, 2, 1])
+    y = model(x)
+    self.assertAllEqual([[[[4.0]]]], y.numpy())
+
 
 class AutomaticControlDependenciesTest(test.TestCase):