Remove degenerate batch dimensions from batch dot
author     Sanjoy Das <sanjoy@google.com>
           Fri, 11 May 2018 22:02:33 +0000 (15:02 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Fri, 11 May 2018 22:08:35 +0000 (15:08 -0700)
commit     95f12f9bd5e8f73a67d534a608a384fe73729dad
tree       64962ad401e70c573599bf235ef78e2c498c3651
parent     81a162301830a02d72184a996c2abdde9b9b149a

The way things are set up today, this specific optimization isn't particularly
important, but I want to implement a follow-on optimization in
BatchDotSimplification that transforms (non-degenerate) batch GEMV operations
into GEMM operations, which I expect to help performance somewhat.
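A "degenerate" batch dimension is a batch dimension of size 1. Because XLA requires paired lhs/rhs batch dimensions to have equal sizes, such a dimension can be reshaped out of both operands, the smaller dot computed, and the size-1 dimension restored on the result with a reshape. The sketch below models only the operand rewrite. It does not use XLA's API: `DotOperand`, `DropDims`, and `RemoveDegenerateBatchDims` are hypothetical names for a simplified stand-in of a dot's shape and dimension numbers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for one operand of a batch dot
// (not XLA's Shape / DotDimensionNumbers types).
struct DotOperand {
  std::vector<int64_t> dims;              // operand shape, e.g. {1, 4, 8}
  std::vector<int64_t> batch_dims;        // indices into `dims`
  std::vector<int64_t> contracting_dims;  // indices into `dims`
};

// Returns `op` with the dimensions listed in `drop` removed, renumbering the
// remaining batch and contracting dimension indices to match the new shape.
DotOperand DropDims(const DotOperand& op, const std::vector<int64_t>& drop) {
  // new_index[i] is old dimension i's position after removal, or -1 if dropped.
  std::vector<int64_t> new_index(op.dims.size(), -1);
  DotOperand out;
  for (int64_t i = 0; i < static_cast<int64_t>(op.dims.size()); ++i) {
    if (std::find(drop.begin(), drop.end(), i) != drop.end()) continue;
    new_index[i] = static_cast<int64_t>(out.dims.size());
    out.dims.push_back(op.dims[i]);
  }
  for (int64_t d : op.batch_dims) {
    if (new_index[d] >= 0) out.batch_dims.push_back(new_index[d]);
  }
  // Only batch dimensions are ever dropped, so contracting dims always survive.
  for (int64_t d : op.contracting_dims) {
    out.contracting_dims.push_back(new_index[d]);
  }
  return out;
}

// Removes every size-1 batch dimension from both operands of a batch dot.
void RemoveDegenerateBatchDims(DotOperand& lhs, DotOperand& rhs) {
  std::vector<int64_t> lhs_drop, rhs_drop;
  for (size_t i = 0; i < lhs.batch_dims.size(); ++i) {
    if (lhs.dims[lhs.batch_dims[i]] == 1) {
      // Batch dimensions are paired and must have equal sizes.
      assert(rhs.dims[rhs.batch_dims[i]] == 1);
      lhs_drop.push_back(lhs.batch_dims[i]);
      rhs_drop.push_back(rhs.batch_dims[i]);
    }
  }
  lhs = DropDims(lhs, lhs_drop);
  rhs = DropDims(rhs, rhs_drop);
}
```

For example, an lhs of shape `{2, 1, 4, 8}` with batch dims `{0, 1}` and contracting dim `{3}` becomes shape `{2, 4, 8}` with batch dims `{0}` and contracting dim `{2}`. Dropping the same batch positions from both operands keeps the lhs/rhs batch pairing intact.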

This would normally live in the algebraic simplifier, but we want to run this
pass to a fixed point before we run DotDecomposer. This will become more
important when we implement the (non-degenerate) batch GEMV -> GEMM transform.
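"Running a pass to a fixed point" means re-running it until it reports that it changed nothing. XLA has an `HloPassFix` wrapper for this purpose. The sketch below is a standalone model of the idea, not that class. `RunToFixpoint`, the iteration cap of 25, and the toy `RemoveOneUnitDim` pass (which deliberately rewrites only one thing per run, so several runs are needed) are all hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Re-runs `pass` on `module` until it reports no change (a fixed point), or
// until `max_iterations` runs have happened. `pass` returns true if it
// changed anything. Returns the number of runs performed.
template <typename Module, typename Pass>
int RunToFixpoint(Module& module, Pass pass, int max_iterations = 25) {
  assert(max_iterations > 0);
  int runs = 0;
  bool changed = true;
  while (changed && runs < max_iterations) {
    changed = pass(module);
    ++runs;
  }
  return runs;
}

// Toy "pass": removes at most one size-1 entry from a shape per invocation.
bool RemoveOneUnitDim(std::vector<int64_t>& dims) {
  auto it = std::find(dims.begin(), dims.end(), 1);
  if (it == dims.end()) return false;  // nothing left to simplify
  dims.erase(it);
  return true;
}
```

On `{1, 4, 1, 1, 8}` the loop runs four times: three runs that each remove a `1`, and a final run that reports no change, leaving `{4, 8}`. A later pass scheduled after this loop (DotDecomposer, in the pipeline described above) therefore only sees the fully simplified form.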

PiperOrigin-RevId: 196314230
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/batch_dot_simplification.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/batch_dot_simplification.h [new file with mode: 0644]
tensorflow/compiler/xla/service/batch_dot_simplification_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/hlo_creation_utils.cc
tensorflow/compiler/xla/service/hlo_creation_utils.h