Improve fusion logic of (a dot b) * alpha
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Mon, 7 May 2018 10:00:34 +0000 (03:00 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 7 May 2018 23:27:20 +0000 (16:27 -0700)
commit     6304da208116ed00ad4ee776787dfa6fe8256f4f
tree       6f0a979a054724604118f54bc0ac2b1b6f988755
parent     e0be7b7c6f70aa0c3fc1b97de049cb4ccf1e9c0f

The previous approach did not work because the scalar operand of a multiplication
is turned into an explicit broadcast.
Another issue fixed in this CL is how the constant value is retrieved from
the literal: this depends on the literal's PrimitiveType, whereas before we always assumed it to be double.
Also, when checking ImplementedAsGemm(), we should not call it recursively, but only perform the check related to kDot.
Finally, add an execution test and adjust the fusion logic test.
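A minimal C++ sketch of the first two fixes, assuming a toy IR: `Opcode`, `PrimitiveType`, `Literal`, `Instruction`, `MakeF32`, `MakeF64`, `GetScalarAsDouble` and `MatchDotTimesAlpha` are all hypothetical stand-ins, not XLA's HloInstruction/Literal API and not the code changed in this commit. It shows a matcher that looks through the explicit broadcast to reach the scalar constant, and that reads the constant according to its PrimitiveType instead of always as double.

```cpp
#include <cassert>
#include <cstring>
#include <optional>
#include <vector>

// Hypothetical, heavily simplified stand-ins for XLA's instruction and
// literal types; these are NOT the real XLA classes or functions.
enum class Opcode { kParameter, kConstant, kBroadcast, kDot, kMultiply };
enum class PrimitiveType { F32, F64 };

struct Literal {
  PrimitiveType type = PrimitiveType::F32;
  std::vector<unsigned char> bytes;  // raw storage of a single scalar
};

struct Instruction {
  Opcode opcode;
  std::vector<const Instruction*> operands;
  Literal literal;  // only meaningful for kConstant
};

Literal MakeF32(float v) {
  Literal lit{PrimitiveType::F32, std::vector<unsigned char>(sizeof v)};
  std::memcpy(lit.bytes.data(), &v, sizeof v);
  return lit;
}

Literal MakeF64(double v) {
  Literal lit{PrimitiveType::F64, std::vector<unsigned char>(sizeof v)};
  std::memcpy(lit.bytes.data(), &v, sizeof v);
  return lit;
}

// Reads the scalar according to its PrimitiveType. Treating a 4-byte F32
// payload as an 8-byte double (the old assumption) would read garbage.
double GetScalarAsDouble(const Literal& lit) {
  switch (lit.type) {
    case PrimitiveType::F32: {
      float v;
      assert(lit.bytes.size() == sizeof v);
      std::memcpy(&v, lit.bytes.data(), sizeof v);
      return static_cast<double>(v);
    }
    case PrimitiveType::F64: {
      double v;
      assert(lit.bytes.size() == sizeof v);
      std::memcpy(&v, lit.bytes.data(), sizeof v);
      return v;
    }
  }
  return 0.0;  // unreachable for the types above
}

// Matches multiply(dot(a, b), broadcast(constant)) in either operand order
// and returns alpha. The scalar is not a direct operand of the multiply but
// sits behind an explicit broadcast, so the matcher looks through it.
std::optional<double> MatchDotTimesAlpha(const Instruction& mul) {
  if (mul.opcode != Opcode::kMultiply || mul.operands.size() != 2) {
    return std::nullopt;
  }
  for (int i = 0; i < 2; ++i) {
    const Instruction* dot = mul.operands[i];
    const Instruction* other = mul.operands[1 - i];
    if (dot->opcode != Opcode::kDot) continue;
    if (other->opcode == Opcode::kBroadcast && other->operands.size() == 1 &&
        other->operands[0]->opcode == Opcode::kConstant) {
      return GetScalarAsDouble(other->operands[0]->literal);
    }
  }
  return std::nullopt;
}
```

For example, building multiply(dot(a, b), broadcast(constant 2.5f)) and passing it to `MatchDotTimesAlpha` yields 2.5 for either operand order, while a multiply whose other operand is the bare constant (no broadcast) yields `std::nullopt` in this sketch.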

PiperOrigin-RevId: 195638795
tensorflow/compiler/xla/service/gpu/BUILD
tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc