Don't require shape functions when creating ops from Python using the C API.
authorSkye Wanderman-Milne <skyewm@google.com>
Wed, 21 Feb 2018 22:56:13 +0000 (14:56 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 23:00:20 +0000 (15:00 -0800)
There are many ops out there without shape functions, and it's very
onerous to add UnknownShape to all of them.

PiperOrigin-RevId: 186524294

tensorflow/c/python_api.cc
tensorflow/c/python_api.h
tensorflow/python/client/tf_session.i
tensorflow/python/framework/ops.py

index 6e37cdb..f553142 100644 (file)
@@ -99,4 +99,9 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
   }
 }
 
+void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
+  mutex_lock l(graph->mu);
+  graph->refiner.set_require_shape_inference_fns(require);
+}
+
 }  // namespace tensorflow
index aa9d9e0..542d70f 100644 (file)
@@ -37,6 +37,10 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
 
 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
 
+// Sets whether ops missing a shape inference function should trigger an
+// error. The default is true.
+void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_PYTHON_API_H_
index 1fd488e..f305cd2 100644 (file)
@@ -719,6 +719,8 @@ def TF_Reset(target, containers=None, config=None):
   $1 = &types_local;
 }
 
+%unignore SetRequireShapeInferenceFns;
+
 %include "tensorflow/python/client/tf_session_helper.h"
 
 %unignoreall
index afd553b..013a4df 100644 (file)
@@ -2770,6 +2770,10 @@ class Graph(object):
     # implementation
     if self._use_c_api_hack():
       self._scoped_c_graph = c_api_util.ScopedTFGraph()
+      # The C API requires all ops to have shape functions. Disable this
+      # requirement (many custom ops do not have shape functions, and we don't
+      # want to break these existing cases).
+      c_api.SetRequireShapeInferenceFns(self._c_graph, False)
     else:
       self._scoped_c_graph = None
     self._variable_creator_stack = []