[TF-XLA] Disable Tensorflow's CSE in xla compiler
authorYunxing Dai <yunxing@google.com>
Tue, 13 Feb 2018 22:51:25 +0000 (14:51 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 22:55:12 +0000 (14:55 -0800)
No need to do CSE in TF-XLA bridge, as XLA already has its own CSE pass later in the compilation pipeline. This removes one source of nondeterminism.

RELNOTES: CSE pass from Tensorflow is now disabled in XLA.
PiperOrigin-RevId: 185592383

tensorflow/compiler/tf2xla/xla_compiler.cc

index c5b4ec5..59e8830 100644 (file)
@@ -153,7 +153,8 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
   CopyGraph(*fbody->graph, graph.get());
   OptimizerOptions opts;
-  opts.set_do_common_subexpression_elimination(true);
+  opts.set_opt_level(OptimizerOptions::L0);
+  opts.set_do_common_subexpression_elimination(false);
   opts.set_do_function_inlining(true);
   opts.set_do_constant_folding(true);
   GraphOptimizer optimizer(opts);
@@ -184,8 +185,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
       CheckSignature(fbody->arg_types, args),
       "Signature check failure while compiling: ", function.name());
 
-  std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
-  CopyGraph(*fbody->graph, graph.get());
+  std::unique_ptr<Graph> graph = GetGraph(fbody);
 
   // _Arg and _Retval nodes don't exist in the stored subgraph for the function;
   // they are added by the function body looked up.  Therefore, they don't have
@@ -213,15 +213,6 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
                    *graph);
   }
 
-  // Optimize the graph before running the compiler.
-  OptimizerOptions opts;
-  opts.set_do_common_subexpression_elimination(true);
-  opts.set_do_function_inlining(true);
-  opts.set_do_constant_folding(true);
-  GraphOptimizer optimizer(opts);
-  optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
-                     /*device=*/nullptr, &graph, /*shape_map=*/nullptr);
-
   VLOG(1) << "====================================================";
   TF_RETURN_IF_ERROR(
       CompileGraph(options, function_id, std::move(graph), args, result));