Introduce an "indexed array" analysis
authorSanjoy Das <sanjoy@google.com>
Thu, 17 May 2018 22:47:30 +0000 (15:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 22:52:10 +0000 (15:52 -0700)
commitb669510b115b5c726fd5e69b5062a1072c034a57
treeda4cbc996725ecc1dbc6d61477a8b43c933d1b61
parent317f3e09109dcb6f4fc70718d1ad2be70e4d2bf8
Introduce an "indexed array" analysis

Context: we want to optimize computations hanging off of a embedding lookup from
a constant array.  For instance, consider:

  embedding = gather from a constant array using non-constant indices
  embedding_reshaped = reshape embedding
  embedding_reshaped_transposed = transpose embedding_reshaped
  result = dot(embedding_reshaped_transposed, constant)

In the graph above, depending on how the details work out, we may be able to
fold `result` into a gather from a precomputed constant array.  However, it is
inconvenient to get there by incremental rewrites -- it is probably not
profitable to rewrite embedding_reshaped or embedding_reshaped_transposed [0] as
embedding lookups but we get to "see" that the dot can be rewritten only after
rewriting the reshape and the transpose.

This analysis aims to make the optimization above more straightforward by
allowing a transformation pass (that uses this analysis) to query the analysis
to see if if `result` _can_ be represented as an embedding lookup.  If yes it
can then apply some profitability heuristics to decide if it is worth it to
rewrite it as one.  This suggested workflow gives us separation of concerns (the
legality of the rewrite is computed separately from its profitability) and, more
importantly, lets us "look ahead" and analyze the dot without rewriting its
operands.

The implementation is far from complete (most of the interesting bits are TODO)
but I wanted to get an early design review before I spent too much time on this.

[0] Under the assumption that transposing or reshaping are not expensive enough
to pay the price of keeping around a new potentially large constant (in
particular, some of these may have been equivalent to free bitcasts).

PiperOrigin-RevId: 197064648
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/cpu/BUILD
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/indexed_array_analysis.cc [new file with mode: 0644]
tensorflow/compiler/xla/service/indexed_array_analysis.h [new file with mode: 0644]
tensorflow/compiler/xla/service/indexed_array_analysis_test.cc [new file with mode: 0644]