Added a debug mode to the model analyzer to make it easier to figure out why shapes...
authorBenoit Steiner <bsteiner@google.com>
Tue, 12 Dec 2017 22:15:09 +0000 (14:15 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 12 Dec 2017 22:19:09 +0000 (14:19 -0800)
PiperOrigin-RevId: 178813305

tensorflow/python/BUILD
tensorflow/python/grappler/model_analyzer.cc
tensorflow/python/grappler/model_analyzer.h
tensorflow/python/grappler/model_analyzer.i
tensorflow/python/grappler/model_analyzer.py
tensorflow/python/grappler/model_analyzer_test.py

index 20944d16789ecbc2b3a604f3f173d44e0f27b286..4012197bce4355265fb589d1f35366c60dd571ad 100644 (file)
@@ -204,11 +204,11 @@ cc_library(
     srcs = ["grappler/model_analyzer.cc"],
     hdrs = ["grappler/model_analyzer.h"],
     deps = [
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/costs:graph_properties",
-        "//tensorflow/core/grappler/costs:utils",
     ],
 )
 
index da5b03234e9bf806727f05c20ec6aa4270f843a7..d23eb811ac2b0a6a8802979b4d966b5617c8a8d9 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/python/grappler/model_analyzer.h"
 
 #include <iomanip>
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -25,26 +26,26 @@ namespace grappler {
 
 ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {}
 
-Status ModelAnalyzer::GenerateReport(std::ostream& os) {
+Status ModelAnalyzer::GenerateReport(bool debug, std::ostream& os) {
   GraphProperties properties(item_);
   TF_RETURN_IF_ERROR(properties.InferStatically(false));
 
   for (const auto& node : item_.MainOpsFanin()) {
-    PrintNodeInfo(node, properties, os);
+    PrintNodeInfo(node, properties, debug, os);
   }
   for (const auto& node : item_.EnqueueOpsFanin()) {
-    PrintNodeInfo(node, properties, os);
+    PrintNodeInfo(node, properties, debug, os);
   }
 
   return Status::OK();
 }
 
 void ModelAnalyzer::PrintNodeInfo(const NodeDef* node,
-                                  const GraphProperties& properties,
+                                  const GraphProperties& properties, bool debug,
                                   std::ostream& os) const {
   os << node->name() << " [" << node->op() << "]" << std::endl;
   if (properties.HasOutputProperties(node->name())) {
-    std::vector<OpInfo::TensorProperties> props =
+    const std::vector<OpInfo::TensorProperties>& props =
         properties.GetOutputProperties(node->name());
     for (int i = 0; i < props.size(); ++i) {
       const OpInfo::TensorProperties& prop = props[i];
@@ -75,6 +76,27 @@ void ModelAnalyzer::PrintNodeInfo(const NodeDef* node,
       os << std::endl;
     }
   }
+
+  if (debug) {
+    const OpRegistrationData* op_reg_data;
+    Status status = OpRegistry::Global()->LookUp(node->op(), &op_reg_data);
+    if (!status.ok()) {
+      os << "\tCouldn't find op registration for " << node->op() << std::endl;
+    } else if (!op_reg_data->shape_inference_fn) {
+      os << "\tCouldn't find shape function for op " << node->op() << std::endl;
+    } else if (properties.HasInputProperties(node->name())) {
+      const std::vector<OpInfo::TensorProperties>& props =
+          properties.GetInputProperties(node->name());
+      for (int i = 0; i < props.size(); ++i) {
+        const OpInfo::TensorProperties& prop = props[i];
+        if (prop.has_value()) {
+          os << "\t"
+             << "input " << i << " (" << DataTypeString(prop.dtype())
+             << ") has known value" << std::endl;
+        }
+      }
+    }
+  }
 }
 
 }  // end namespace grappler
index a14034103ca70e59ac24d88318edc198e7d1c5f4..5bc551927d88db723e21b29903d6f5b941048139 100644 (file)
@@ -31,11 +31,11 @@ class GraphProperties;
 class ModelAnalyzer {
  public:
   explicit ModelAnalyzer(const GrapplerItem& item);
-  Status GenerateReport(std::ostream& os);
+  Status GenerateReport(bool debug, std::ostream& os);
 
  private:
   void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties,
-                     std::ostream& os) const;
+                     bool debug, std::ostream& os) const;
 
   const GrapplerItem& item_;
 };
index 726143a0bb4db28538f4338eb3773d85332dc122..7c3a692d0efc501341ff1dff3cf24b8a4830ec84 100644 (file)
@@ -40,7 +40,7 @@ limitations under the License.
 %}
 
 %{
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph) {
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug) {
   tensorflow::grappler::ItemConfig cfg;
   cfg.apply_optimizations = false;
   std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
@@ -53,10 +53,10 @@ string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph) {
   tensorflow::grappler::ModelAnalyzer analyzer(*item);
 
   std::stringstream os;
-  analyzer.GenerateReport(os);
+  analyzer.GenerateReport(debug, os);
   return os.str();
 }
 
 %}
 
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph);
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug);
index c852d71ad8b047f5437ca62c49a5500bc29cec60..535889e1c4034952562a05e4d044fcafeddbc0ca 100644 (file)
@@ -22,16 +22,18 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap
 from tensorflow.python.framework import errors
 
 
-def GenerateModelReport(metagraph):
+def GenerateModelReport(metagraph, debug=False):
   """Report what's known statically about each node in the provided metagraph.
 
   Args:
     metagraph: A TensorFlow MetaGraphDef.
+    debug: Add some information useful for debugging.
 
   Returns:
     A string containing the report.
   """
   with errors.raise_exception_on_not_ok_status():
-    ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString())
+    ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
+                                                debug)
 
   return ret_from_swig
index b59d1650f4b5e4c7239c2275213e9a26c3aafafe..ec172755f1ae43fc7581e97c6a18471da45f9100 100644 (file)
@@ -49,6 +49,24 @@ class PyWrapOptimizeGraphTest(test.TestCase):
     # Also print the report to make it easier to debug
     print("{}".format(report))
 
+  def testDebugMode(self):
+    """Make sure arguments can be passed correctly."""
+    a = constant_op.constant([10, 11], name="a")
+    b = constant_op.constant([10], name="b")
+    c = math_ops.add(a, b, name="c")
+    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    train_op.append(c)
+    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+
+    report = model_analyzer.GenerateModelReport(mg, debug=True)
+
+    # Check the report headers
+    self.assertTrue(b"input 0 (int32) has known value" in report)
+    self.assertTrue(b"input 1 (int32) has known value" in report)
+
+    # Also print the report to make it easier to debug
+    print("{}".format(report))
+
 
 if __name__ == "__main__":
   test.main()