Make SavedModelTest.testStripDefaultAttrsInconsistentConsumerDefaults work with C...
authorSkye Wanderman-Milne <skyewm@google.com>
Fri, 9 Feb 2018 00:05:31 +0000 (16:05 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 00:09:41 +0000 (16:09 -0800)
The test originally altered the Python version of the op registry,
which is not reflected in the C API. This changes the test to alter
the serialized node def instead of the op def, and renames the test to
testInconsistentConsumerDefaultAttrs.

PiperOrigin-RevId: 185067838

tensorflow/python/framework/test_ops.cc
tensorflow/python/saved_model/BUILD
tensorflow/python/saved_model/saved_model_test.py

index c6c6c22..070b5ac 100644 (file)
@@ -76,6 +76,11 @@ REGISTER_OP("TestStringOutput")
     .Output("output2: string")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("TestAttr")
+    .Output("out: T")
+    .Attr("T: {float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 namespace {
 enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
 }  // namespace
@@ -188,6 +193,20 @@ class ResourceUsingOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU),
                         ResourceUsingOp);
 
+class TestAttrOp : public OpKernel {
+ public:
+  explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* output;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+    output->scalar<float>()() = 1.0;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp);
+
 // Various test ops without kernels. These are used to test graph construction.
 
 REGISTER_OP("A")
index e34aa7c..30e0a09 100644 (file)
@@ -148,6 +148,7 @@ py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:saver_test_utils",
         "//tensorflow/python:state_ops",
+        "//tensorflow/python:test_ops",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
     ],
index f92247d..d1f6bc2 100644 (file)
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import os
 
-from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
@@ -28,8 +27,8 @@ from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import control_flow_ops
@@ -945,61 +944,75 @@ class SavedModelTest(test.TestCase):
     self.assertIn("T", node_def.attr)
     self.assertIn("Tout", node_def.attr)
 
-  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
-    if ops._USE_C_API: return  # TODO(skyewm): get this working
-
+  # Tests the behavior of loading SavedModels that having missing attrs or attrs
+  # with incorrect types.
+  def testInconsistentConsumerDefaultAttrs(self):
     export_dir = self._get_export_dir(
         "test_strip_default_attrs_no_consumer_defaults")
     builder = saved_model_builder.SavedModelBuilder(export_dir)
 
-    # Add a graph with two float32 variables and a Complex Op composing them
-    # with strip_default_attrs enabled. This must remove the following
-    # defaults for the "Complex" Op:
-    #   o "T"    : float32.   (input type)
-    #   o "Tout" : complex64. (output type)
+    # Add a graph with a single variable and a test op with a defaultless
+    # float32 attr, "test_attr".
     with session.Session(graph=ops.Graph()) as sess:
-      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
-      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
-      math_ops.complex(real_num, imag_num, name="complex")
+      variables.Variable(1.0, dtype=dtypes.float64, name="var")
+      test_ops.test_attr(T=dtypes.float32, name="test_attr")
       sess.run(variables.global_variables_initializer())
-      builder.add_meta_graph_and_variables(
-          sess, ["foo"], strip_default_attrs=True)
+      builder.add_meta_graph_and_variables(sess, ["foo"])
 
     # Save the SavedModel to disk in text format.
     builder.save(as_text=True)
 
-    # Update the Op registry to remove defaults for all attrs("T", "Tout") from
-    # the "Complex" OpDef.
-    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
-    original_complex_op_def = op_def_pb2.OpDef()
-    original_complex_op_def.CopyFrom(complex_op_def)
-    for attr_def in complex_op_def.attr:
-      attr_def.ClearField("default_value")
+    # Rewrite the SavedModel to remove the T attr from "test_attr".
+    saved_model_file = os.path.join(
+        export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
+    with open(saved_model_file) as f:
+      original_saved_model = f.read()
+
+    no_attr_saved_model = original_saved_model.replace("""
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }""", "")
+    with open(saved_model_file, "w") as f:
+      f.write(no_attr_saved_model)
 
     # Loading the SavedModel via the loader must fail because the SavedModel
-    # does not have any attr values for the "Complex" node and the current
-    # op registry does not have have any default values for the "Complex" op.
+    # does not have any attr values for the "TestAttr" node, and there is no
+    # default specified in the TestAttr OpDef.
     sess = session.Session(graph=ops.Graph())
-    with self.assertRaisesRegexp(
-        ValueError,
-        "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+    if ops._USE_C_API:
+      error_message = "NodeDef missing attr 'T' from Op<name=TestAttr"
+    else:
+      error_message = ("Expected one attr with name .*T(out)?.* in name: "
+                       "\"test_attr\".*")
+    with self.assertRaisesRegexp(ValueError, error_message):
       loader.load(sess, ["foo"], export_dir)
 
-    # Update the Op registry to change the defaults for attr "Tout"
-    # (complex64 -> complex128).
-    complex_op_def.CopyFrom(original_complex_op_def)
-    for attr_def in complex_op_def.attr:
-      if attr_def.name == "Tout":
-        attr_def.default_value.type = types_pb2.DT_COMPLEX128
-
-    # Loading the SavedModel via the loader must set "Tout" attr_value for the
-    # "Complex" node according to the latest defaults (complex128). This is
-    # expected to fail the model import as there is no OpKernel registered to
-    # handle attrs "T" (float32) and "Tout" (complex128).
+    # Rewrite the SavedModel to change the type of the T attr in "test_attr"
+    bad_type_saved_model = original_saved_model.replace("""
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }""", """
+      attr {
+        key: "T"
+        value {
+          type: DT_DOUBLE
+        }
+      }""")
+    with open(saved_model_file, "w") as f:
+      f.write(bad_type_saved_model)
+
+    # Loading the SavedModel via the loader must fail because there is no
+    # OpKernel registered to handle T = double.
     sess = session.Session(graph=ops.Graph())
     with self.assertRaisesRegexp(
         errors.InvalidArgumentError,
-        ".*No OpKernel was registered to support Op \'Complex\' with these "
+        ".*No OpKernel was registered to support Op \'TestAttr\' with these "
         "attrs..*"):
       loader.load(sess, ["foo"], export_dir)