Fix deprecated scalar type in ATen/native/Distributions.cpp
authorXiang Gao <qasdfgtyuiop@gmail.com>
Sat, 23 Mar 2019 17:01:28 +0000 (10:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 23 Mar 2019 17:09:26 +0000 (10:09 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18265

Differential Revision: D14577543

Pulled By: ezyang

fbshipit-source-id: 36674530b32366c51835e4073d7ba23d455d2fda

aten/src/ATen/native/Distributions.cpp

index 1dcac33..24392ee 100644 (file)
@@ -230,7 +230,7 @@ Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) {
 
 Tensor _s_dirichlet_cpu(const Tensor& alpha, Generator *gen) {
   Tensor ret = at::zeros(alpha.sizes(), alpha.options());
-  AT_DISPATCH_FLOATING_TYPES(ret.type(), "dirichlet", [&] {
+  AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "dirichlet", [&] {
     Tensor gamma = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double));
     THGenerator* generator = get_generator(gen);
     std::lock_guard<std::mutex> lock(generator->mutex);