Speed up gemm by reordering the for loops (#17730)
authorGuanheng Zhang <zhangguanheng@fb.com>
Wed, 13 Mar 2019 15:22:56 +0000 (08:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Mar 2019 15:57:26 +0000 (08:57 -0700)
commit26a4c2ada6aadace7b0c62658eb9f567ff313951
tree7111f618866cc1a82d772312a122c61bcf513dfe
parentecc5e623a29a825937ab6e8e9c4b4156b5a48626
Speed up gemm by reordering the for loops (#17730)

Summary:
Optimize the order of the "for" loops.

Note: For "transa = true" cases, the order of the "for" loops has been optimzied in the original code. Therefore, no significant improvement is observed in those case (i.e. "transa && transb" and "transa && !transb")

mode/opt (i.e. static libary)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops:  2229     x:     128      y:     128      z:     128      time:  2243ns      =>  acceleration multiplier:  0.90
loops:  124      x:     128      y:     1024     z:     128      time:  40381ns      =>  acceleration multiplier:  0.97
loops:  121      x:     1024     y:     128      z:     128      time:  41651ns      =>  acceleration multiplier:  0.96
loops:  15       x:     1024     y:     1024     z:     128      time:  333771ns       =>  acceleration multiplier:  0.98
loops:  4610     x:     128      y:     128      z:     64       time:  1084ns       =>  acceleration multiplier:  0.95
loops:  252      x:     128      y:     1024     z:     64       time:  19860ns      =>  acceleration multiplier:  0.98
loops:  248      x:     1024     y:     128      z:     64       time:  20232ns      =>  acceleration multiplier:  0.98
loops:  30       x:     1024     y:     1024     z:     64       time:  167338ns      =>  acceleration multiplier:  0.99

before:
loops:  2468     x:     128      y:     128      z:     128      time:  2026ns
loops:  128      x:     128      y:     1024     z:     128      time:  39338ns
loops:  126      x:     1024     y:     128      z:     128      time:  39930ns
loops:  16       x:     1024     y:     1024     z:     128      time:  327549ns
loops:  4840     x:     128      y:     128      z:     64       time:  1033ns
loops:  258      x:     128      y:     1024     z:     64       time:  19441ns
loops:  252      x:     1024     y:     128      z:     64       time:  19854ns
loops:  31       x:     1024     y:     1024     z:     64       time:  166254ns

//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops:  4880     x:     128      y:     128      z:     128      time:  1024ns      =>  acceleration multiplier:  0.98
loops:  638      x:     128      y:     1024     z:     128      time:  7839ns      =>  acceleration multiplier:  1.04
loops:  605      x:     1024     y:     128      z:     128      time:  8276ns      =>  acceleration multiplier:  1.01
loops:  77       x:     1024     y:     1024     z:     128      time:  65713ns      =>  acceleration multiplier:  1.00
loops:  9935     x:     128      y:     128      z:     64       time:  503ns      =>  acceleration multiplier:  1.00
loops:  1252     x:     128      y:     1024     z:     64       time:  3994ns      =>  acceleration multiplier:  1.00
loops:  1183     x:     1024     y:     128      z:     64       time:  4226ns      =>  acceleration multiplier:  0.98
loops:  153      x:     1024     y:     1024     z:     64       time:  32766ns      =>  acceleration multiplier:  0.99

before:
loops:  4985     x:     128      y:     128      z:     128      time:  1003ns
loops:  615      x:     128      y:     1024     z:     128      time:  8140ns
loops:  599      x:     1024     y:     128      z:     128      time:  8357ns
loops:  76       x:     1024     y:     1024     z:     128      time:  65934ns
loops:  9897     x:     128      y:     128      z:     64       time:  505ns
loops:  1248     x:     128      y:     1024     z:     64       time:  4008ns
loops:  1203     x:     1024     y:     128      z:     64       time:  4159ns
loops:  154      x:     1024     y:     1024     z:     64       time:  32499ns

//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops:  3919     x:     128      y:     128      z:     128      time:  1276ns      =>  acceleration multiplier:  2.97
loops:  497      x:     128      y:     1024     z:     128      time:  10069ns      =>  acceleration multiplier:  7.85
loops:  449      x:     1024     y:     128      z:     128      time:  11145ns      =>  acceleration multiplier:  4.77
loops:  57       x:     1024     y:     1024     z:     128      time:  88595ns      =>  acceleration multiplier:  7.12
loops:  7575     x:     128      y:     128      z:     64       time:  660ns      =>  acceleration multiplier:  3.00
loops:  967      x:     128      y:     1024     z:     64       time:  5173ns      =>  acceleration multiplier:  7.66
loops:  877      x:     1024     y:     128      z:     64       time:  5702ns      =>  acceleration multiplier:  4.76
loops:  111      x:     1024     y:     1024     z:     64       time:  45232ns      =>  acceleration multiplier:  7.03

before:
loops:  1320     x:     128      y:     128      z:     128      time:  3789ns
loops:  64       x:     128      y:     1024     z:     128      time:  79061ns
loops:  95       x:     1024     y:     128      z:     128      time:  53107ns
loops:  8        x:     1024     y:     1024     z:     128      time:  631161ns
loops:  2521     x:     128      y:     128      z:     64       time:  1983ns
loops:  127      x:     128      y:     1024     z:     64       time:  39604ns
loops:  185      x:     1024     y:     128      z:     64       time:  27128ns
loops:  16       x:     1024     y:     1024     z:     64       time:  318155ns

//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops:  3895     x:     128      y:     128      z:     128      time:  1283ns      =>  acceleration multiplier:  1.73
loops:  393      x:     128      y:     1024     z:     128      time:  12746ns      =>  acceleration multiplier:  3.36
loops:  411      x:     1024     y:     128      z:     128      time:  12170ns      =>  acceleration multiplier:  1.93
loops:  46       x:     1024     y:     1024     z:     128      time:  110116ns      =>  acceleration multiplier:  3.17
loops:  7404     x:     128      y:     128      z:     64       time:  675ns      =>  acceleration multiplier:  1.58
loops:  636      x:     128      y:     1024     z:     64       time:  7872ns      =>  acceleration multiplier:  2.70
loops:  724      x:     1024     y:     128      z:     64       time:  6911ns      =>  acceleration multiplier:  1.32
loops:  73       x:     1024     y:     1024     z:     64       time:  68502ns      =>  acceleration multiplier:  2.49

before:
loops:  2253     x:     128      y:     128      z:     128      time:  2219ns
loops:  117      x:     128      y:     1024     z:     128      time:  42788ns
loops:  214      x:     1024     y:     128      z:     128      time:  23465ns
loops:  15       x:     1024     y:     1024     z:     128      time:  349076ns
loops:  4694     x:     128      y:     128      z:     64       time:  1065ns
loops:  236      x:     128      y:     1024     z:     64       time:  21251ns
loops:  549      x:     1024     y:     128      z:     64       time:  9108ns
loops:  30       x:     1024     y:     1024     z:     64       time:  170799ns
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17730

Differential Revision: D14325149

Pulled By: zhangguanheng66

fbshipit-source-id: a7a5a83890fdf99fee6eb87a3a5060b7b6bd862f
aten/src/TH/generic/THBlas.cpp
test/test_torch.py