};
using OrderedNodeSet = std::set<Node*, NodeCompare>;
+// Returns true if the op can be decomposed into XLA ops for which
+// there are fusable elemental implementations.
+//
+// TODO(hpucha): Consider a black list instead of a white list as
+// implemented below.
+bool IsXlaFusable(const NodeDef& node) {
+ static const std::unordered_set<std::string>* elementwise_ops =
+ new std::unordered_set<std::string>(
+ {// tf2xla/kernels/aggregate_ops.cc
+ "AddN",
+ // tf2xla/kernels/batchtospace_op.cc
+ "BatchToSpace", "BatchToSpaceND",
+ // tf2xla/kernels/bcast_ops.cc
+ "BroadcastArgs", "BroadcastGradientArgs",
+ // tf2xla/kernels/bias_ops.cc
+ "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
+ // tf2xla/kernels/binary_ops.cc
+ "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
+ "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
+ "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
+ "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
+ "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
+ "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
+ "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
+ // tf2xla/kernels/cast_op.cc
+ "Cast",
+ // tf2xla/kernels/categorical_op.cc
+ "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/,
+ // tf2xla/kernels/concat_op.cc
+ "Concat", "ConcatV2", "ConcatOffset",
+ // tf2xla/kernels/const_op.cc
+ "Const",
+ // tf2xla/kernels/cross_op.cc
+ "Cross",
+ // tf2xla/kernels/depthtospace_op.cc
+ "DepthToSpace",
+ // tf2xla/kernels/diag_op.cc
+ "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart",
+ // tf2xla/kernels/dynamic_stitch_op.cc
+ "DynamicStitch", "ParallelDynamicStitch",
+ // tf2xla/kernels/elu_op.cc
+ "Elu", "EluGrad", "Selu", "SeluGrad",
+ // tf2xla/kernels/fake_quantize_ops.cc
+ "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient",
+ "FakeQuantWithMinMaxVars",
+ "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/,
+ // tf2xla/kernels/fill_op.cc
+ "Fill",
+ // tf2xla/kernels/gather_op.cc
+ "Gather", "GatherV2", "GatherNd",
+ // tf2xla/kernels/identity_op.cc
+ "Identity", "IdentityN", "PreventGradient", "StopGradient",
+ "Snapshot",
+ // tf2xla/kernels/image_ops.cc
+ "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/,
+ "AdjustSaturation", "AdjustHue",
+ // tf2xla/kernels/index_ops.cc
+ "ArgMax", "ArgMin",
+ // tf2xla/kernels/l2loss_op.cc
+ "L2Loss" /*(Reduce)*/,
+ // tf2xla/kernels/lrn_ops.cc (ReduceWindow)
+ "LRN", "LRNGrad",
+ // tf2xla/kernels/matrix_band_part_op.cc
+ "MatrixBandPart",
+ // tf2xla/kernels/matrix_set_diag_op.cc
+ "MatrixSetDiag",
+ // tf2xla/kernels/mirror_pad_op.cc
+ "MirrorPad",
+ // tf2xla/kernels/no_op.cc
+ "NoOp", "ControlTrigger",
+ // tf2xla/kernels/one_hot_op.cc
+ "OneHot",
+ // tf2xla/kernels/pack_op.cc
+ "Pack",
+ // tf2xla/kernels/pad_op.cc
+ "Pad", "PadV2",
+ // tf2xla/kernels/pooling_ops.cc
+ "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool",
+ "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/
+ "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad",
+ "AvgPool3DGrad",
+ // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce)
+ "QuantizeAndDequantizeV2",
+ // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend
+ // currently)
+ "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
+ "TruncatedNormal",
+ // tf2xla/kernels/reduction_ops.cc (Reduce)
+ "Sum", "Prod", "Min", "Max", "Mean", "All", "Any",
+ // tf2xla/kernels/relu_op.cc
+ "Relu", "Relu6", "ReluGrad", "Relu6Grad",
+ // tf2xla/kernels/reshape_op.cc
+ "Reshape",
+ // tf2xla/kernels/reverse_op.cc
+ "Reverse", "ReverseV2",
+ // tf2xla/kernels/reverse_sequence_op.cc
+ "ReverseSequence",
+ // tf2xla/kernels/scan_ops.cc (ReduceWindow)
+ "Cumsum", "Cumprod",
+ // tf2xla/kernels/scatter_nd_op.cc (Reduce)
+ "ScatterNd",
+ // tf2xla/kernels/segment_reduction_ops.cc (Reduce)
+ "UnsortedSegmentSum",
+ // tf2xla/kernels/select_op.cc
+ "Select",
+ // tf2xla/kernels/sequence_ops.cc
+ "Range", "LinSpace",
+ // tf2xla/kernels/shape_op.cc
+ "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
+ "ZerosLike", "OnesLike",
+ // tf2xla/kernels/slice_op.cc
+ "Slice",
+ // tf2xla/kernels/softmax_op.cc (Reduce)
+ "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits",
+ "SparseSoftmaxCrossEntropyWithLogits",
+ // tf2xla/kernels/spacetobatch_op.cc
+ "SpaceToBatchND", "SpaceToBatch",
+ // tf2xla/kernels/spacetodepth_op.cc
+ "SpaceToDepth",
+ // tf2xla/kernels/split_op.cc
+ "Split", "SplitV",
+ // tf2xla/kernels/stack_ops.cc
+ "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2",
+ // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on
+ // GPU
+ // backend currently)
+ "StatelessRandomUniform",
+ "StatelessRandomNormal"
+ // tf2xla/kernels/strided_slice_op.cc
+ "StridedSlice",
+ "StridedSliceGrad", "ResourceStridedSliceAssign",
+ // tf2xla/kernels/tile_ops.cc
+ "Tile",
+ // tf2xla/kernels/training_ops.cc
+ "ResourceApplyGradientDescent", "ResourceApplyMomentum",
+ "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp",
+ "ResourceApplyFtrl", "ResourceApplyFtrlV2",
+ // tf2xla/kernels/transpose_op.cc
+ "Transpose", "InvertPermutation",
+ // tf2xla/kernels/unary_ops.cc
+ "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
+ "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
+ "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
+ "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
+ "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
+ "Square", "Tan", "Tanh", "Real", "Imag",
+ // tf2xla/kernels/unpack_op.cc
+ "Unpack"});
+
+ return elementwise_ops->count(node.op()) > 0;
+}
+
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
}
bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
+ bool fusion_only = flags->tf_xla_fusion_only;
+
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
+ VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
const FunctionLibraryDefinition* fld = options.flib_def;
- auto is_compilable = [global_jit_level, cpu_global_jit, fld](
+ auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
+ // Check for fusable ops only if requested.
+ if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
+ return false;
+ }
+
// Otherwise use the value of global_jit_level.
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
from tensorflow.contrib.compiler import jit
self.assertTrue(InLabels(labels, "_XlaLaunch"))
+class ElementWiseFusionTest(test.TestCase):
+
+ # Runs a simple test with the input jit_level and fusion_only flag.
+ def simpleTest(self, arg0, arg1, global_jit_level):
+ config = config_pb2.ConfigProto()
+ config.graph_options.optimizer_options.global_jit_level = global_jit_level
+
+ with session_lib.Session(config=config) as sess:
+ a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
+ a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
+ # Two element-wise ops. We need at least two ops since single
+ # element clusters are not passed to XLA in fusion_only mode.
+ a3 = a1 * a2
+ a4 = a3 + a1
+ # A matmul to break XLA clustering.
+ a5 = math_ops.matmul(a4, a1)
+ # Two more element-wise ops.
+ a6 = a5 - a4
+ a7 = a6 + a2
+
+ run_metadata = config_pb2.RunMetadata()
+ output = sess.run(
+ a7, {
+ a1: arg0,
+ a2: arg1
+ },
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+
+ labels = RunMetadataLabels(run_metadata)
+ count = sum("_XlaLaunch(" in x for x in labels)
+
+ return output, count
+
+ def testElementWiseClustering(self):
+ arg0 = np.random.rand(2, 2).astype(np.float32)
+ arg1 = np.random.rand(2, 2).astype(np.float32)
+ os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true"
+ tf_op, tf_count = self.simpleTest(arg0, arg1,
+ config_pb2.OptimizerOptions.OFF)
+ self.assertEqual(0, tf_count)
+
+ tfef_op, tfef_count = self.simpleTest(arg0, arg1,
+ config_pb2.OptimizerOptions.ON_1)
+ self.assertEqual(2, tfef_count)
+
+ self.assertAllClose(tf_op, tfef_op, rtol=1e-1)
+
+
if __name__ == "__main__":
test.main()