Enable fusion of element-wise ops using XLA (Off by default, can be enabled by setting the tf_xla_fusion_only flag)
author    A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 06:25:37 +0000 (23:25 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Mar 2018 06:28:35 +0000 (23:28 -0700)
PiperOrigin-RevId: 190179459

tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
tensorflow/compiler/jit/mark_for_compilation_pass.cc
tensorflow/compiler/tests/jit_test.py
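
Usage note (a hedged sketch, not part of the commit): the new flag is read from
the TF_XLA_FLAGS environment variable, and per its description it only takes
effect when the session's global_jit_level is ON_1 or ON_2. Assuming the
TensorFlow 1.x Python API of this era:

    import os
    # Must be set before the TensorFlow runtime parses its XLA flags.
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true"

    import tensorflow as tf

    # fusion_only is gated on global_jit_level being ON*; with it set,
    # only fusable element-wise ops are clustered for XLA.
    config = tf.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)
    with tf.Session(config=config) as sess:
        pass  # build and run the graph here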

diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
index 51384ac..7277a1d 100644
--- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
+++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
@@ -41,6 +41,7 @@ static void AllocateFlags() {
   flags->tf_xla_clustering_debug = false;
   flags->tf_xla_cpu_global_jit = false;
   flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
+  flags->tf_xla_fusion_only = false;
   flag_list = new std::vector<Flag>(
       {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
             "Control compilation of operators into XLA computations on CPU and "
@@ -59,7 +60,10 @@ static void AllocateFlags() {
             "Enables global JIT compilation for CPU via SessionOptions."),
        Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
             "Places an artificial limit on the number of ops marked as "
-            "eligible for clustering.")});
+            "eligible for clustering."),
+       Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
+            "enable fusion of element-wise operations only using XLA when "
+            "global_jit_level is ON*.")});
   xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
 }
 
diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
index 170b89c..2affda6 100644
--- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
+++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
@@ -51,6 +51,10 @@ typedef struct {
   int64 tf_xla_clustering_fuel;   // "Compiler fuel" for clustering.  Only this
                                   // many ops will be marked as eligible for
                                   // clustering.
+  bool tf_xla_fusion_only;  // This flag takes effect only when
+                            // global_jit_level is set to ON* and narrows
+                            // its scope: if true, only fusable element-wise
+                            // operations are clustered for XLA.
 } MarkForCompilationPassFlags;
 
 // Return a pointer to the MarkForCompilationPassFlags struct;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 57fb8d2..f651768 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -180,6 +180,158 @@ struct NodeCompare {
 };
 using OrderedNodeSet = std::set<Node*, NodeCompare>;
 
+// Returns true if the op can be decomposed into XLA ops for which
+// there are fusable elemental implementations.
+//
+// TODO(hpucha): Consider a blacklist instead of the whitelist
+// implemented below.
+bool IsXlaFusable(const NodeDef& node) {
+  static const std::unordered_set<std::string>* elementwise_ops =
+      new std::unordered_set<std::string>(
+          {// tf2xla/kernels/aggregate_ops.cc
+           "AddN",
+           // tf2xla/kernels/batchtospace_op.cc
+           "BatchToSpace", "BatchToSpaceND",
+           // tf2xla/kernels/bcast_ops.cc
+           "BroadcastArgs", "BroadcastGradientArgs",
+           // tf2xla/kernels/bias_ops.cc
+           "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
+           // tf2xla/kernels/binary_ops.cc
+           "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
+           "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
+           "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
+           "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
+           "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
+           "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
+           "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
+           // tf2xla/kernels/cast_op.cc
+           "Cast",
+           // tf2xla/kernels/categorical_op.cc
+           "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/,
+           // tf2xla/kernels/concat_op.cc
+           "Concat", "ConcatV2", "ConcatOffset",
+           // tf2xla/kernels/const_op.cc
+           "Const",
+           // tf2xla/kernels/cross_op.cc
+           "Cross",
+           // tf2xla/kernels/depthtospace_op.cc
+           "DepthToSpace",
+           // tf2xla/kernels/diag_op.cc
+           "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart",
+           // tf2xla/kernels/dynamic_stitch_op.cc
+           "DynamicStitch", "ParallelDynamicStitch",
+           // tf2xla/kernels/elu_op.cc
+           "Elu", "EluGrad", "Selu", "SeluGrad",
+           // tf2xla/kernels/fake_quantize_ops.cc
+           "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient",
+           "FakeQuantWithMinMaxVars",
+           "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/,
+           // tf2xla/kernels/fill_op.cc
+           "Fill",
+           // tf2xla/kernels/gather_op.cc
+           "Gather", "GatherV2", "GatherNd",
+           // tf2xla/kernels/identity_op.cc
+           "Identity", "IdentityN", "PreventGradient", "StopGradient",
+           "Snapshot",
+           // tf2xla/kernels/image_ops.cc
+           "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/,
+           "AdjustSaturation", "AdjustHue",
+           // tf2xla/kernels/index_ops.cc
+           "ArgMax", "ArgMin",
+           // tf2xla/kernels/l2loss_op.cc
+           "L2Loss" /*(Reduce)*/,
+           // tf2xla/kernels/lrn_ops.cc (ReduceWindow)
+           "LRN", "LRNGrad",
+           // tf2xla/kernels/matrix_band_part_op.cc
+           "MatrixBandPart",
+           // tf2xla/kernels/matrix_set_diag_op.cc
+           "MatrixSetDiag",
+           // tf2xla/kernels/mirror_pad_op.cc
+           "MirrorPad",
+           // tf2xla/kernels/no_op.cc
+           "NoOp", "ControlTrigger",
+           // tf2xla/kernels/one_hot_op.cc
+           "OneHot",
+           // tf2xla/kernels/pack_op.cc
+           "Pack",
+           // tf2xla/kernels/pad_op.cc
+           "Pad", "PadV2",
+           // tf2xla/kernels/pooling_ops.cc
+           "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool",
+           "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/
+           "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad",
+           "AvgPool3DGrad",
+           // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce)
+           "QuantizeAndDequantizeV2",
+           // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend
+           // currently)
+           "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
+           "TruncatedNormal",
+           // tf2xla/kernels/reduction_ops.cc (Reduce)
+           "Sum", "Prod", "Min", "Max", "Mean", "All", "Any",
+           // tf2xla/kernels/relu_op.cc
+           "Relu", "Relu6", "ReluGrad", "Relu6Grad",
+           // tf2xla/kernels/reshape_op.cc
+           "Reshape",
+           // tf2xla/kernels/reverse_op.cc
+           "Reverse", "ReverseV2",
+           // tf2xla/kernels/reverse_sequence_op.cc
+           "ReverseSequence",
+           // tf2xla/kernels/scan_ops.cc (ReduceWindow)
+           "Cumsum", "Cumprod",
+           // tf2xla/kernels/scatter_nd_op.cc (Reduce)
+           "ScatterNd",
+           // tf2xla/kernels/segment_reduction_ops.cc (Reduce)
+           "UnsortedSegmentSum",
+           // tf2xla/kernels/select_op.cc
+           "Select",
+           // tf2xla/kernels/sequence_ops.cc
+           "Range", "LinSpace",
+           // tf2xla/kernels/shape_op.cc
+           "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
+           "ZerosLike", "OnesLike",
+           // tf2xla/kernels/slice_op.cc
+           "Slice",
+           // tf2xla/kernels/softmax_op.cc (Reduce)
+           "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits",
+           "SparseSoftmaxCrossEntropyWithLogits",
+           // tf2xla/kernels/spacetobatch_op.cc
+           "SpaceToBatchND", "SpaceToBatch",
+           // tf2xla/kernels/spacetodepth_op.cc
+           "SpaceToDepth",
+           // tf2xla/kernels/split_op.cc
+           "Split", "SplitV",
+           // tf2xla/kernels/stack_ops.cc
+           "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2",
+           // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on
+           // GPU
+           // backend currently)
+           "StatelessRandomUniform",
+           "StatelessRandomNormal"
+           // tf2xla/kernels/strided_slice_op.cc
+           "StridedSlice",
+           "StridedSliceGrad", "ResourceStridedSliceAssign",
+           // tf2xla/kernels/tile_ops.cc
+           "Tile",
+           // tf2xla/kernels/training_ops.cc
+           "ResourceApplyGradientDescent", "ResourceApplyMomentum",
+           "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp",
+           "ResourceApplyFtrl", "ResourceApplyFtrlV2",
+           // tf2xla/kernels/transpose_op.cc
+           "Transpose", "InvertPermutation",
+           // tf2xla/kernels/unary_ops.cc
+           "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
+           "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
+           "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
+           "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
+           "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
+           "Square", "Tan", "Tanh", "Real", "Imag",
+           // tf2xla/kernels/unpack_op.cc
+           "Unpack"});
+
+  return elementwise_ops->count(node.op()) > 0;
+}
+
 Status FindCompilationCandidates(
     const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
     const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
@@ -338,10 +490,13 @@ Status MarkForCompilationPass::Run(
         static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
   }
   bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
+  bool fusion_only = flags->tf_xla_fusion_only;
+
   VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
+  VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
   const FunctionLibraryDefinition* fld = options.flib_def;
 
-  auto is_compilable = [global_jit_level, cpu_global_jit, fld](
+  auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
                            const Node* node, const DeviceType& device_type) {
     const XlaOpRegistry::DeviceRegistration* registration;
     if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
@@ -364,6 +519,11 @@ Status MarkForCompilationPass::Run(
     status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
     if (status.ok()) return compile;
 
+    // Check for fusable ops only if requested.
+    if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
+      return false;
+    }
+
     // Otherwise use the value of global_jit_level.
     // Ignore enable_jit_by_default if global jit compilation for CPU
     // is explicitly requested via tf_xla_cpu_global_jit flag
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 2d8236e..f9d87c2 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.contrib.compiler import jit
@@ -436,5 +437,55 @@ class XlaCompilationTest(test.TestCase):
     self.assertTrue(InLabels(labels, "_XlaLaunch"))
 
 
+class ElementWiseFusionTest(test.TestCase):
+
+  # Runs a simple graph at the given global_jit_level; fusion_only is
+  # controlled via the TF_XLA_FLAGS environment variable.
+  def simpleTest(self, arg0, arg1, global_jit_level):
+    config = config_pb2.ConfigProto()
+    config.graph_options.optimizer_options.global_jit_level = global_jit_level
+
+    with session_lib.Session(config=config) as sess:
+      a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
+      a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
+      # Two element-wise ops. We need at least two ops since single
+      # element clusters are not passed to XLA in fusion_only mode.
+      a3 = a1 * a2
+      a4 = a3 + a1
+      # A matmul to break XLA clustering.
+      a5 = math_ops.matmul(a4, a1)
+      # Two more element-wise ops.
+      a6 = a5 - a4
+      a7 = a6 + a2
+
+      run_metadata = config_pb2.RunMetadata()
+      output = sess.run(
+          a7, {
+              a1: arg0,
+              a2: arg1
+          },
+          run_metadata=run_metadata,
+          options=config_pb2.RunOptions(
+              trace_level=config_pb2.RunOptions.FULL_TRACE))
+
+      labels = RunMetadataLabels(run_metadata)
+      count = sum("_XlaLaunch(" in x for x in labels)
+
+      return output, count
+
+  def testElementWiseClustering(self):
+    arg0 = np.random.rand(2, 2).astype(np.float32)
+    arg1 = np.random.rand(2, 2).astype(np.float32)
+    os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true"
+    tf_op, tf_count = self.simpleTest(arg0, arg1,
+                                      config_pb2.OptimizerOptions.OFF)
+    self.assertEqual(0, tf_count)
+
+    tfef_op, tfef_count = self.simpleTest(arg0, arg1,
+                                          config_pb2.OptimizerOptions.ON_1)
+    self.assertEqual(2, tfef_count)
+
+    self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
+
+
 if __name__ == "__main__":
   test.main()
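
Verification note (a hedged sketch): the RunMetadata-based counting used in the
test generalizes to user code. The helper below mirrors the test's
RunMetadataLabels/InLabels helpers (defined earlier in jit_test.py and not
shown in this diff); field names follow the standard StepStats proto:

    # Count executed _XlaLaunch clusters in a run traced with FULL_TRACE.
    def xla_launch_count(run_metadata):
        count = 0
        for dev_stats in run_metadata.step_stats.dev_stats:
            for node_stats in dev_stats.node_stats:
                # timeline_label names the executed op, e.g. "_XlaLaunch(...)".
                if "_XlaLaunch(" in node_stats.timeline_label:
                    count += 1
        return count

With tf_xla_fusion_only=true and global_jit_level ON_1, the test above expects
exactly two such clusters: the matmul is excluded from clustering, splitting
the element-wise ops into two groups.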