From 9b69f21a95fa626522ef371f8557e7286f9db318 Mon Sep 17 00:00:00 2001
From: James Reed
Date: Sun, 7 Apr 2019 00:15:42 -0700
Subject: [PATCH] Improve precision of emitted code for prim::Constant (#18817)

Summary:
Stacked on https://github.com/pytorch/pytorch/pull/18815 and https://github.com/pytorch/pytorch/pull/18811.

This makes it so that we emit a higher-precision literal for float values in the fusion kernel, and assign it to a `double` variable. This prevents us from losing precision for values such as `pi`, but with the previous fixes the value will still be downcast to `float` if downstream operations require it. Therefore, we should not lose performance because of implicit promotions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18817

Differential Revision: D14820842

Pulled By: jamesr66a

fbshipit-source-id: 519671c6ca5e7adac746a4c4c72760a6d91e332f
---
 test/test_jit.py                  | 17 ++++++++++++
 torch/csrc/jit/fuser/codegen.cpp  | 54 +++++++++++++++++++++++----------------
 torch/csrc/jit/fuser/executor.cpp |  4 +--
 3 files changed, 51 insertions(+), 24 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 141e4c1..7a4105c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4551,6 +4551,23 @@ a")
             test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True)
             test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)

+    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_fuser_double_literal_precision(self):
+        code = '''
+        graph(%2 : Float(*, *)):
+            %4 : int = prim::Constant[value=1]()
+            %3 : float = prim::Constant[value=1.282549830161864]()
+            %5 : Float(*, *) = aten::add(%2, %3, %4)
+            %1 : Float(*, *) = aten::relu(%5)
+            return (%1)
+        '''
+
+        graph = parse_ir(code)
+        code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)])
+        FileCheck().check('1.282549830161864').run(code)
+
     def test_fuser_multiple_blocks(self):
         cu = torch.jit.CompilationUnit('''
             def test_fuser_multiple_blocks(this, that, theother, meme):
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index 5511a35..e7725bf 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -55,7 +55,7 @@ static std::string scalarValue(const double v) {
       out << "POS_INFINITY";
     }
   } else {
-    out << std::scientific << v << "f";
+    out << std::setprecision(16) << v;
   }
   return out.str();
 }
@@ -86,9 +86,9 @@ static const char* calcScalarTypeName(const at::ScalarType type) {

 static std::string variableType(const std::shared_ptr<c10::Type>& t) {
   if (t->kind() == TypeKind::IntType) {
-    return "int";
+    return "int64_t";
   } else if (t->kind() == TypeKind::FloatType) {
-    return "float";
+    return "double";
   } else if (t->kind() == TypeKind::BoolType) {
     return "bool";
   } else if (t->kind() == TypeKind::DimensionedTensorType) {
@@ -110,10 +110,12 @@ static std::string typeCastedValueName(
     }
     return vn;
   } else if (t->kind() == TypeKind::FloatType) {
-    if (!isFloatingType(outtype)) {
-      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
-    }
-    return vn;
+    // We don't guard this on anything because in our type system for scalars,
+    // there is not a distinction between `float` and `double`, however there
+    // *is* a distinction in tensor scalar types. We conservatively insert a
+    // cast here, which may end up being a no-op if the tensor's scalar type
+    // is `double`.
+    return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
   } else if (t->kind() == TypeKind::DimensionedTensorType) {
     auto const tt = t->cast<DimensionedTensorType>();
     if (tt->scalarType() != outtype) {
@@ -253,18 +255,6 @@ static std::string encodeRHS(const Node* n) {
       {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
   };

-  if (n->kind() == prim::Constant) {
-    const auto val = toIValue(n->output()).value();
-    if (val.isDouble()) {
-      return scalarValue(val.toDouble());
-    } else if (val.isBool()) {
-      return scalarValue(val.toBool());
-    } else {
-      AT_ASSERT(val.isInt());
-      return scalarValue(val.toInt());
-    }
-  }
-
   TemplateEnv env;

   if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
@@ -450,10 +440,30 @@ std::string generateKernel(
       AT_ASSERT(use_cuda);
       has_random = true;
     }
+    // Always emit double for prim::Constant. This will be narrowed later based
+    // on either:
+    //  - Tensor-Scalar operator type rules
+    //  - Math function rules
+    if (n->kind() == prim::Constant) {
+      const auto val = toIValue(n->output()).value();
+      std::string rhs;
+      if (val.isDouble()) {
+        rhs = scalarValue(val.toDouble());
+      } else if (val.isBool()) {
+        rhs = scalarValue(val.toBool());
+      } else {
+        AT_ASSERT(val.isInt());
+        rhs = scalarValue(val.toInt());
+      }
+      env.s("node", valueName(n->output()));
+      env.s("rhs", rhs);
+      env.s("lhs_type", variableType(n->output()->type()));
+    } else {
+      env.s("node", valueName(n->output()));
+      env.s("rhs", encodeRHS(n));
+      env.s("lhs_type", variableType(n->output()->type()));
+    }

-    env.s("node", valueName(n->output()));
-    env.s("rhs", encodeRHS(n));
-    env.s("lhs_type", variableType(n->output()->type()));
     body << format("${lhs_type} ${node} = ${rhs};\n", env);
   }
diff --git a/torch/csrc/jit/fuser/executor.cpp b/torch/csrc/jit/fuser/executor.cpp
index 51b154f..fab9043 100644
--- a/torch/csrc/jit/fuser/executor.cpp
+++ b/torch/csrc/jit/fuser/executor.cpp
@@ -224,7 +224,7 @@ void launchFusion(
   }

   // compute number of scalar inputs and convert them to float
-  std::vector<float> scalar_inputs;
+  std::vector<double> scalar_inputs;
   scalar_inputs.reserve(all_inputs.size());
   for (auto const &input: all_inputs){
     if (input.isDouble()) scalar_inputs.push_back(input.to());
@@ -283,7 +283,7 @@ void launchFusion(
     }
   }
   // Adds scalar arguments
-  for (float &s: scalar_inputs){
+  for (double &s: scalar_inputs){
     arguments.push_back(&s);
   }
-- 
2.7.4
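Note (not part of the patch): the effect of the formatting change in scalarValue() is easiest to see in a small standalone C++ sketch. The helpers oldScalarValue/newScalarValue and the main() driver below are illustrative stand-ins, not code from the PyTorch tree; they mimic the pre-patch `std::scientific << v << "f"` formatting and the post-patch `std::setprecision(16)` formatting, using the same constant the new test checks for.

#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

// Stand-in for the pre-patch formatting: scientific notation at the stream's
// default precision (6 significant digits), suffixed with "f", so the emitted
// literal is single precision.
static std::string oldScalarValue(double v) {
  std::ostringstream out;
  out << std::scientific << v << "f";
  return out.str();
}

// Stand-in for the post-patch formatting: 16 significant digits and no "f"
// suffix, so the emitted literal keeps double precision.
static std::string newScalarValue(double v) {
  std::ostringstream out;
  out << std::setprecision(16) << v;
  return out.str();
}

int main() {
  const double v = 1.282549830161864;        // constant used by the new test
  std::cout << oldScalarValue(v) << "\n";    // something like "1.282550e+00f"
  std::cout << newScalarValue(v) << "\n";    // "1.282549830161864"
  // The constant is now declared as `double` in the generated kernel; the cast
  // inserted by typeCastedValueName() narrows it to float only when the
  // consuming tensor's scalar type requires it.
  const float narrowed = static_cast<float>(v);
  std::cout << std::setprecision(16) << narrowed << "\n";  // nearest float; diverges after ~7 digits
  return 0;
}

With the old formatting the kernel would only ever see the 6-digit, float-suffixed literal, which is why constants such as pi lost precision before this change.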