(#16825)
authorJie <jiej@nvidia.com>
Fri, 15 Feb 2019 14:44:49 +0000 (06:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 15:02:30 +0000 (07:02 -0800)
commita771a6ba6722f51f5781971843f8babc042b7bef
tree7ee7ce4a1f1b880c7dd1152526443b7c7593234c
parentacf5ec07af5c17d52bcf020df2f8055f798219b3
(#16825)

Summary:
setting the correct math type for cudnn rnn, which is enforced starting from cudnn 7.5+

1. Updating persistent rnn check with input data type instead of rnn math type;
2. Updating rnn type promotion to set correct math type for accumulation;
3. Replace datatype check for filter descriptor from rnn.datatype to input.datatype;
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16825

Differential Revision: D14071190

Pulled By: ezyang

fbshipit-source-id: 1c9a1531ccf510cb0619e830be444c20c5e72f3f
aten/src/ATen/cudnn/Descriptors.h
aten/src/ATen/native/cudnn/RNN.cpp