#include <memory>
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace grappler {
+GrapplerTest::GrapplerTest() {
+ // Turn off all the automatic optimizations to ensure that we run the graph
+ // exactly as it is given to us. This ensures that we can compare the results
+ // before and after manual optimization, without any of the automatic
+ // optimizations interfering in the comparison.
+ RewriterConfig* cfg =
+ options_.config.mutable_graph_options()->mutable_rewrite_options();
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_dependency_optimization(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_function_optimization(RewriterConfig::OFF);
+ cfg->set_layout_optimizer(RewriterConfig::OFF);
+}
+
std::vector<Tensor> GrapplerTest::EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names) const {
- SessionOptions options;
- std::unique_ptr<tensorflow::Session> session(NewSession(options));
+ std::unique_ptr<tensorflow::Session> session(NewSession(options_));
TF_CHECK_OK(session->Create(graph));
RunOptions run_options;
std::vector<Tensor> output_tensors;
std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
const GrapplerItem& item) const {
- SessionOptions options;
- std::unique_ptr<tensorflow::Session> session(NewSession(options));
+ std::unique_ptr<tensorflow::Session> session(NewSession(options_));
TF_CHECK_OK(session->Create(item.graph));
RunOptions run_options;
if (!item.init_ops.empty()) {
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace grappler {
class GrapplerTest : public ::testing::Test {
+ public:
+ GrapplerTest();
+
protected:
std::vector<Tensor> EvaluateNodes(
const GraphDef& graph, const std::vector<string>& node_names) const;
// Count nodes of the given op-type in a graph.
int CountOpNodes(const GraphDef& graph, const string& op);
+
+ private:
+ SessionOptions options_;
};
} // end namespace grappler