From 5860fa5dcf993ddb53080edcbfbc42cfed676a25 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Sat, 23 Mar 2019 10:01:28 -0700 Subject: [PATCH] Fix deprecated scalar type in ATen/native/Distributions.cpp Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18265 Differential Revision: D14577543 Pulled By: ezyang fbshipit-source-id: 36674530b32366c51835e4073d7ba23d455d2fda --- aten/src/ATen/native/Distributions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 1dcac33..24392ee 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -230,7 +230,7 @@ Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) { Tensor _s_dirichlet_cpu(const Tensor& alpha, Generator *gen) { Tensor ret = at::zeros(alpha.sizes(), alpha.options()); - AT_DISPATCH_FLOATING_TYPES(ret.type(), "dirichlet", [&] { + AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "dirichlet", [&] { Tensor gamma = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double)); THGenerator* generator = get_generator(gen); std::lock_guard lock(generator->mutex); -- 2.7.4