Add test that checks all core ops have shape functions.
authorSkye Wanderman-Milne <skyewm@google.com>
Wed, 21 Feb 2018 21:14:27 +0000 (13:14 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 21:18:18 +0000 (13:18 -0800)
This is meant to be a replacement for the current Python code that
checks that core ops have shape functions registered. Some ops were
missing a shape function, so I added UnknownShape.

This also adds an OpRegistry::GetOpRegistrationData() method for
fetching all the shape functions.

PiperOrigin-RevId: 186508356

tensorflow/core/BUILD
tensorflow/core/common_runtime/function_testlib.cc
tensorflow/core/framework/op.cc
tensorflow/core/framework/op.h
tensorflow/core/graph/testlib.cc
tensorflow/core/ops/function_ops.cc
tensorflow/core/ops/shape_function_test.cc [new file with mode: 0644]
tensorflow/core/ops/spectral_ops.cc
tensorflow/core/ops/word2vec_ops.cc
tensorflow/core/user_ops/fact.cc

index 2a8aefa..04307db 100644 (file)
@@ -3515,6 +3515,7 @@ tf_cc_tests(
         "ops/parsing_ops_test.cc",
         "ops/random_ops_test.cc",
         "ops/set_ops_test.cc",
+        "ops/shape_function_test.cc",
         "ops/sparse_ops_test.cc",
         "ops/spectral_ops_test.cc",
         "ops/state_ops_test.cc",
index 87c2476..87733ed 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function_testlib.h"
 
 #include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -39,7 +40,9 @@ class FindDeviceOpKernel : public OpKernel {
 
 REGISTER_KERNEL_BUILDER(Name("FindDeviceOp").Device(tensorflow::DEVICE_CPU),
                         FindDeviceOpKernel);
-REGISTER_OP("FindDeviceOp").Output("device_name: string");
+REGISTER_OP("FindDeviceOp")
+    .Output("device_name: string")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 FunctionDef FindDevice() {
   return FDH::Define(
index fadb60d..fc5467b 100644 (file)
@@ -110,6 +110,15 @@ void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
   }
 }
 
+void OpRegistry::GetOpRegistrationData(
+    std::vector<OpRegistrationData>* op_data) {
+  mutex_lock lock(mu_);
+  MustCallDeferred();
+  for (const auto& p : registry_) {
+    op_data->push_back(*p.second);
+  }
+}
+
 Status OpRegistry::SetWatcher(const Watcher& watcher) {
   mutex_lock lock(mu_);
   if (watcher_ && watcher) {
index f7f1ed2..3ccca40 100644 (file)
@@ -89,6 +89,9 @@ class OpRegistry : public OpRegistryInterface {
   // Get all registered ops.
   void GetRegisteredOps(std::vector<OpDef>* op_defs);
 
+  // Get all `OpRegistrationData`s.
+  void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
+
   // Watcher, a function object.
   // The watcher, if set by SetWatcher(), is called every time an op is
   // registered via the Register function. The watcher is passed the Status
index 0d88d1f..67b252c 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/graph/testlib.h"
 
 #include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -50,7 +51,8 @@ REGISTER_KERNEL_BUILDER(
 REGISTER_OP("HostConst")
     .Output("output: dtype")
     .Attr("value: tensor")
-    .Attr("dtype: type");
+    .Attr("dtype: type")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 namespace test {
 namespace graph {
index ada96fa..a6914d9 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
 
@@ -55,6 +56,7 @@ REGISTER_OP("_ListToArray")
     .Attr("Tin: list(type)")
     .Attr("T: type")
     .Attr("N: int >= 1")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Doc(R"doc(
 Converts a list of tensors to an array of tensors.
 )doc");
@@ -65,6 +67,7 @@ REGISTER_OP("_ArrayToList")
     .Attr("T: type")
     .Attr("N: int >= 1")
     .Attr("out_types: list(type)")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Doc(R"doc(
 Converts an array of tensors to a list of tensors.
 )doc");
diff --git a/tensorflow/core/ops/shape_function_test.cc b/tensorflow/core/ops/shape_function_test.cc
new file mode 100644 (file)
index 0000000..120995f
--- /dev/null
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/platform/test.h"
+
+// Test to ensure that all core ops have shape functions defined. This is done
+// by looking at all ops registered in the test binary.
+
+namespace tensorflow {
+
+TEST(ShapeFunctionTest, RegisteredOpsHaveShapeFns) {
+  OpRegistry* op_registry = OpRegistry::Global();
+  std::vector<OpRegistrationData> op_data;
+  op_registry->GetOpRegistrationData(&op_data);
+  for (const OpRegistrationData& op_reg_data : op_data) {
+    EXPECT_TRUE(op_reg_data.shape_inference_fn != nullptr)
+        << op_reg_data.op_def.name();
+  }
+}
+
+}  // namespace tensorflow
index 508cea3..2790aee 100644 (file)
@@ -142,26 +142,32 @@ REGISTER_OP("IRFFT3D")
 REGISTER_OP("BatchFFT")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use FFT");
 REGISTER_OP("BatchIFFT")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use IFFT");
 REGISTER_OP("BatchFFT2D")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use FFT2D");
 REGISTER_OP("BatchIFFT2D")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use IFFT2D");
 REGISTER_OP("BatchFFT3D")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use FFT3D");
 REGISTER_OP("BatchIFFT3D")
     .Input("input: complex64")
     .Output("output: complex64")
+    .SetShapeFn(shape_inference::UnknownShape)
     .Deprecated(15, "Use IFFT3D");
 
 }  // namespace tensorflow
index ed685dc..e469771 100644 (file)
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 
 namespace tensorflow {
@@ -33,7 +34,8 @@ REGISTER_OP("Skipgram")
     .Attr("batch_size: int")
     .Attr("window_size: int = 5")
     .Attr("min_count: int = 5")
-    .Attr("subsample: float = 1e-3");
+    .Attr("subsample: float = 1e-3")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 REGISTER_OP("NegTrain")
     .Deprecated(19,
@@ -46,6 +48,7 @@ REGISTER_OP("NegTrain")
     .Input("lr: float")
     .SetIsStateful()
     .Attr("vocab_count: list(int)")
-    .Attr("num_negative_samples: int");
+    .Attr("num_negative_samples: int")
+    .SetShapeFn(shape_inference::UnknownShape);
 
 }  // end namespace tensorflow
index 3a4fc81..2e8b22a 100644 (file)
@@ -15,10 +15,13 @@ limitations under the License.
 
 // An example Op.
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
-REGISTER_OP("Fact").Output("fact: string");
+REGISTER_OP("Fact")
+    .Output("fact: string")
+    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
 
 class FactOp : public tensorflow::OpKernel {
  public: