srcs = ["grappler/model_analyzer.cc"],
hdrs = ["grappler/model_analyzer.h"],
deps = [
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/costs:graph_properties",
- "//tensorflow/core/grappler/costs:utils",
],
)
#include "tensorflow/python/grappler/model_analyzer.h"
#include <iomanip>
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {}
-Status ModelAnalyzer::GenerateReport(std::ostream& os) {
+Status ModelAnalyzer::GenerateReport(bool debug, std::ostream& os) {
GraphProperties properties(item_);
TF_RETURN_IF_ERROR(properties.InferStatically(false));
for (const auto& node : item_.MainOpsFanin()) {
- PrintNodeInfo(node, properties, os);
+ PrintNodeInfo(node, properties, debug, os);
}
for (const auto& node : item_.EnqueueOpsFanin()) {
- PrintNodeInfo(node, properties, os);
+ PrintNodeInfo(node, properties, debug, os);
}
return Status::OK();
}
void ModelAnalyzer::PrintNodeInfo(const NodeDef* node,
- const GraphProperties& properties,
+ const GraphProperties& properties, bool debug,
std::ostream& os) const {
os << node->name() << " [" << node->op() << "]" << std::endl;
if (properties.HasOutputProperties(node->name())) {
- std::vector<OpInfo::TensorProperties> props =
+ const std::vector<OpInfo::TensorProperties>& props =
properties.GetOutputProperties(node->name());
for (int i = 0; i < props.size(); ++i) {
const OpInfo::TensorProperties& prop = props[i];
os << std::endl;
}
}
+
+ if (debug) {
+ const OpRegistrationData* op_reg_data;
+ Status status = OpRegistry::Global()->LookUp(node->op(), &op_reg_data);
+ if (!status.ok()) {
+ os << "\tCouldn't find op registration for " << node->op() << std::endl;
+ } else if (!op_reg_data->shape_inference_fn) {
+ os << "\tCouldn't find shape function for op " << node->op() << std::endl;
+ } else if (properties.HasInputProperties(node->name())) {
+ const std::vector<OpInfo::TensorProperties>& props =
+ properties.GetInputProperties(node->name());
+ for (int i = 0; i < props.size(); ++i) {
+ const OpInfo::TensorProperties& prop = props[i];
+ if (prop.has_value()) {
+ os << "\t"
+ << "input " << i << " (" << DataTypeString(prop.dtype())
+ << ") has known value" << std::endl;
+ }
+ }
+ }
+ }
}
} // end namespace grappler
class ModelAnalyzer {
public:
explicit ModelAnalyzer(const GrapplerItem& item);
- Status GenerateReport(std::ostream& os);
+ Status GenerateReport(bool debug, std::ostream& os);
private:
void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties,
- std::ostream& os) const;
+ bool debug, std::ostream& os) const;
const GrapplerItem& item_;
};
%}
%{
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph) {
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
tensorflow::grappler::ModelAnalyzer analyzer(*item);
std::stringstream os;
- analyzer.GenerateReport(os);
+ analyzer.GenerateReport(debug, os);
return os.str();
}
%}
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph);
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug);
from tensorflow.python.framework import errors
-def GenerateModelReport(metagraph):
+def GenerateModelReport(metagraph, debug=False):
"""Report what's known statically about each node in the provided metagraph.
Args:
metagraph: A TensorFlow MetaGraphDef.
+ debug: Add some information useful for debugging.
Returns:
A string containing the report.
"""
with errors.raise_exception_on_not_ok_status():
- ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString())
+ ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
+ debug)
return ret_from_swig
# Also print the report to make it easier to debug
print("{}".format(report))
+ def testDebugMode(self):
+ """Make sure arguments can be passed correctly."""
+ a = constant_op.constant([10, 11], name="a")
+ b = constant_op.constant([10], name="b")
+ c = math_ops.add(a, b, name="c")
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(c)
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+
+ report = model_analyzer.GenerateModelReport(mg, debug=True)
+
+ # Check the report headers
+ self.assertTrue(b"input 0 (int32) has known value" in report)
+ self.assertTrue(b"input 1 (int32) has known value" in report)
+
+ # Also print the report to make it easier to debug
+ print("{}".format(report))
+
if __name__ == "__main__":
test.main()