Speed up gemm by reordering the for loops (#17730)
Summary:
Optimize the order of the "for" loops.
Note: For the "transa = true" cases, the order of the "for" loops was already optimized in the original code. Therefore, no significant improvement is observed in those cases (i.e. "transa && transb" and "transa && !transb").
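To illustrate what reordering the loops means for the "!transa" cases, here is a minimal sketch. It assumes BLAS-style column-major storage, where element (r, c) of a matrix with leading dimension ld sits at index r + c * ld, and computes C = alpha * op(A) * op(B) + beta * C. The function names `gemm_naive` and `gemm_reordered` are hypothetical, and the j-l-i order shown is one common reordering, not necessarily the exact code changed in this PR. Python only shows the index pattern; the speedup itself comes from the unit-stride inner loop when this is written in C.

```python
# Hypothetical sketch (not the PyTorch source): column-major GEMM,
# C = alpha * op(A) @ op(B) + beta * C, for the transa == False cases.
# Matrices are flat lists; element (r, c) with leading dimension ld
# lives at index r + c * ld.


def gemm_naive(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc):
    """i-j-l order: the innermost loop strides through A by lda."""
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for l in range(k):
                # op(B)[l, j]: B is stored n x k when transposed, k x n otherwise.
                b_lj = b[j + l * ldb] if transb else b[l + j * ldb]
                acc += a[i + l * lda] * b_lj
            # As in BLAS, beta == 0 means C is not read.
            old = 0.0 if beta == 0 else beta * c[i + j * ldc]
            c[i + j * ldc] = old + alpha * acc


def gemm_reordered(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc):
    """j-l-i order: the innermost loop walks one column of A and one
    column of C with unit stride, and op(B)[l, j] is hoisted out."""
    for j in range(n):
        # Scale column j of C by beta first.
        for i in range(m):
            c[i + j * ldc] = 0.0 if beta == 0 else beta * c[i + j * ldc]
        for l in range(k):
            t = alpha * (b[j + l * ldb] if transb else b[l + j * ldb])
            for i in range(m):
                c[i + j * ldc] += a[i + l * lda] * t


a = [1.0, 3.0, 2.0, 4.0]  # A = [[1, 2], [3, 4]], column-major
b = [5.0, 7.0, 6.0, 8.0]  # B = [[5, 6], [7, 8]], column-major
c = [0.0] * 4
gemm_reordered(False, 2, 2, 2, 1.0, a, 2, b, 2, 0.0, c, 2)
print(c)  # [19.0, 43.0, 22.0, 50.0]
```

In the "transa" cases, op(A)[i, l] is stored at A[l + i * lda], so the i-j-l inner loop over l already reads A contiguously, which is consistent with the note that those cases gain little from reordering.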
mode/opt (i.e. static library)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops: 2229 x: 128 y: 128 z: 128 time: 2243ns => acceleration multiplier: 0.90
loops: 124 x: 128 y: 1024 z: 128 time: 40381ns => acceleration multiplier: 0.97
loops: 121 x: 1024 y: 128 z: 128 time: 41651ns => acceleration multiplier: 0.96
loops: 15 x: 1024 y: 1024 z: 128 time: 333771ns => acceleration multiplier: 0.98
loops: 4610 x: 128 y: 128 z: 64 time: 1084ns => acceleration multiplier: 0.95
loops: 252 x: 128 y: 1024 z: 64 time: 19860ns => acceleration multiplier: 0.98
loops: 248 x: 1024 y: 128 z: 64 time: 20232ns => acceleration multiplier: 0.98
loops: 30 x: 1024 y: 1024 z: 64 time: 167338ns => acceleration multiplier: 0.99
before:
loops: 2468 x: 128 y: 128 z: 128 time: 2026ns
loops: 128 x: 128 y: 1024 z: 128 time: 39338ns
loops: 126 x: 1024 y: 128 z: 128 time: 39930ns
loops: 16 x: 1024 y: 1024 z: 128 time: 327549ns
loops: 4840 x: 128 y: 128 z: 64 time: 1033ns
loops: 258 x: 128 y: 1024 z: 64 time: 19441ns
loops: 252 x: 1024 y: 128 z: 64 time: 19854ns
loops: 31 x: 1024 y: 1024 z: 64 time: 166254ns
//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops: 4880 x: 128 y: 128 z: 128 time: 1024ns => acceleration multiplier: 0.98
loops: 638 x: 128 y: 1024 z: 128 time: 7839ns => acceleration multiplier: 1.04
loops: 605 x: 1024 y: 128 z: 128 time: 8276ns => acceleration multiplier: 1.01
loops: 77 x: 1024 y: 1024 z: 128 time: 65713ns => acceleration multiplier: 1.00
loops: 9935 x: 128 y: 128 z: 64 time: 503ns => acceleration multiplier: 1.00
loops: 1252 x: 128 y: 1024 z: 64 time: 3994ns => acceleration multiplier: 1.00
loops: 1183 x: 1024 y: 128 z: 64 time: 4226ns => acceleration multiplier: 0.98
loops: 153 x: 1024 y: 1024 z: 64 time: 32766ns => acceleration multiplier: 0.99
before:
loops: 4985 x: 128 y: 128 z: 128 time: 1003ns
loops: 615 x: 128 y: 1024 z: 128 time: 8140ns
loops: 599 x: 1024 y: 128 z: 128 time: 8357ns
loops: 76 x: 1024 y: 1024 z: 128 time: 65934ns
loops: 9897 x: 128 y: 128 z: 64 time: 505ns
loops: 1248 x: 128 y: 1024 z: 64 time: 4008ns
loops: 1203 x: 1024 y: 128 z: 64 time: 4159ns
loops: 154 x: 1024 y: 1024 z: 64 time: 32499ns
//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops: 3919 x: 128 y: 128 z: 128 time: 1276ns => acceleration multiplier: 2.97
loops: 497 x: 128 y: 1024 z: 128 time: 10069ns => acceleration multiplier: 7.85
loops: 449 x: 1024 y: 128 z: 128 time: 11145ns => acceleration multiplier: 4.77
loops: 57 x: 1024 y: 1024 z: 128 time: 88595ns => acceleration multiplier: 7.12
loops: 7575 x: 128 y: 128 z: 64 time: 660ns => acceleration multiplier: 3.00
loops: 967 x: 128 y: 1024 z: 64 time: 5173ns => acceleration multiplier: 7.66
loops: 877 x: 1024 y: 128 z: 64 time: 5702ns => acceleration multiplier: 4.76
loops: 111 x: 1024 y: 1024 z: 64 time: 45232ns => acceleration multiplier: 7.03
before:
loops: 1320 x: 128 y: 128 z: 128 time: 3789ns
loops: 64 x: 128 y: 1024 z: 128 time: 79061ns
loops: 95 x: 1024 y: 128 z: 128 time: 53107ns
loops: 8 x: 1024 y: 1024 z: 128 time: 631161ns
loops: 2521 x: 128 y: 128 z: 64 time: 1983ns
loops: 127 x: 128 y: 1024 z: 64 time: 39604ns
loops: 185 x: 1024 y: 128 z: 64 time: 27128ns
loops: 16 x: 1024 y: 1024 z: 64 time: 318155ns
//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops: 3895 x: 128 y: 128 z: 128 time: 1283ns => acceleration multiplier: 1.73
loops: 393 x: 128 y: 1024 z: 128 time: 12746ns => acceleration multiplier: 3.36
loops: 411 x: 1024 y: 128 z: 128 time: 12170ns => acceleration multiplier: 1.93
loops: 46 x: 1024 y: 1024 z: 128 time: 110116ns => acceleration multiplier: 3.17
loops: 7404 x: 128 y: 128 z: 64 time: 675ns => acceleration multiplier: 1.58
loops: 636 x: 128 y: 1024 z: 64 time: 7872ns => acceleration multiplier: 2.70
loops: 724 x: 1024 y: 128 z: 64 time: 6911ns => acceleration multiplier: 1.32
loops: 73 x: 1024 y: 1024 z: 64 time: 68502ns => acceleration multiplier: 2.49
before:
loops: 2253 x: 128 y: 128 z: 128 time: 2219ns
loops: 117 x: 128 y: 1024 z: 128 time: 42788ns
loops: 214 x: 1024 y: 128 z: 128 time: 23465ns
loops: 15 x: 1024 y: 1024 z: 128 time: 349076ns
loops: 4694 x: 128 y: 128 z: 64 time: 1065ns
loops: 236 x: 128 y: 1024 z: 64 time: 21251ns
loops: 549 x: 1024 y: 128 z: 64 time: 9108ns
loops: 30 x: 1024 y: 1024 z: 64 time: 170799ns
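The "acceleration multiplier" values are consistent with the "before" time divided by the "after" time for the same shape, so values above 1 mean the new loop order is faster. This reading is inferred from the numbers rather than stated above; a quick check against the "!transa && transb" rows:

```python
# Per-call times in ns copied from the "!transa && transb" block,
# in the same (x, y, z) order as listed there.
before_ns = [3789, 79061, 53107, 631161, 1983, 39604, 27128, 318155]
after_ns = [1276, 10069, 11145, 88595, 660, 5173, 5702, 45232]
reported = [2.97, 7.85, 4.77, 7.12, 3.00, 7.66, 4.76, 7.03]


def multiplier(before, after):
    """Speedup of the new code: > 1 means faster after the change."""
    return before / after


computed = [round(multiplier(b, a), 2) for b, a in zip(before_ns, after_ns)]
print(computed == reported)  # True
```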
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17730
Differential Revision: D14325149
Pulled By: zhangguanheng66
fbshipit-source-id: a7a5a83890fdf99fee6eb87a3a5060b7b6bd862f