Disable all the automatic optimizations when testing, to ensure that we can
authorBenoit Steiner <bsteiner@google.com>
Thu, 22 Mar 2018 20:24:51 +0000 (13:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:27:34 +0000 (13:27 -0700)
properly compare the results of the original graph against that of the hand
optimized graph.

PiperOrigin-RevId: 190115606

tensorflow/core/grappler/utils/grappler_test.cc
tensorflow/core/grappler/utils/grappler_test.h

index 6b6cece..1c15ea6 100644 (file)
@@ -17,15 +17,30 @@ limitations under the License.
 #include <memory>
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/public/session.h"
 
 namespace tensorflow {
 namespace grappler {
 
+GrapplerTest::GrapplerTest() {
+  // Turn off all the automatic optimizations to ensure that we run the graph
+  // exactly as it is given to us. This ensures that we can compare the results
+  // before and after manual optimization, without any of the automatic
+  // optimizations interfering in the comparison.
+  RewriterConfig* cfg =
+      options_.config.mutable_graph_options()->mutable_rewrite_options();
+  cfg->set_constant_folding(RewriterConfig::OFF);
+  cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+  cfg->set_dependency_optimization(RewriterConfig::OFF);
+  cfg->set_loop_optimization(RewriterConfig::OFF);
+  cfg->set_function_optimization(RewriterConfig::OFF);
+  cfg->set_layout_optimizer(RewriterConfig::OFF);
+}
+
 std::vector<Tensor> GrapplerTest::EvaluateNodes(
     const GraphDef& graph, const std::vector<string>& node_names) const {
-  SessionOptions options;
-  std::unique_ptr<tensorflow::Session> session(NewSession(options));
+  std::unique_ptr<tensorflow::Session> session(NewSession(options_));
   TF_CHECK_OK(session->Create(graph));
   RunOptions run_options;
   std::vector<Tensor> output_tensors;
@@ -37,8 +52,7 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes(
 
 std::vector<Tensor> GrapplerTest::EvaluateFetchNodes(
     const GrapplerItem& item) const {
-  SessionOptions options;
-  std::unique_ptr<tensorflow::Session> session(NewSession(options));
+  std::unique_ptr<tensorflow::Session> session(NewSession(options_));
   TF_CHECK_OK(session->Create(item.graph));
   RunOptions run_options;
   if (!item.init_ops.empty()) {
index c7f0655..e0c6738 100644 (file)
@@ -24,11 +24,15 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 namespace grappler {
 
 class GrapplerTest : public ::testing::Test {
+ public:
+  GrapplerTest();
+
  protected:
   std::vector<Tensor> EvaluateNodes(
       const GraphDef& graph, const std::vector<string>& node_names) const;
@@ -48,6 +52,9 @@ class GrapplerTest : public ::testing::Test {
 
   // Count nodes of the given op-type in a graph.
   int CountOpNodes(const GraphDef& graph, const string& op);
+
+ private:
+  SessionOptions options_;
 };
 
 }  // end namespace grappler