From c7b1fdb76723a2d0f1fa560b3ff25d088b3b349c Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Wed, 17 Apr 2019 23:41:22 -0700 Subject: [PATCH] Fixing function schema parser for Android (#19281) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19281 String<->Number conversions aren't available in the STL used in our Android environment. This diff adds workarounds for that so that the function schema parser can be compiled for android Reviewed By: dzhulgakov Differential Revision: D14931649 fbshipit-source-id: d5d386f2c474d3742ed89e52dff751513142efad --- c10/util/C++17.h | 2 + c10/util/string_utils.h | 19 +- tools/build_variables.py | 1 + torch/CMakeLists.txt | 2 + torch/csrc/jit/script/function_schema_parser.cpp | 7 +- torch/csrc/jit/script/lexer.cpp | 3 +- torch/csrc/jit/script/lexer.h | 21 +- torch/csrc/jit/script/schema_type_parser.cpp | 5 +- torch/csrc/jit/script/strtod.cpp | 252 +++++++++++++++++++++++ torch/csrc/jit/script/strtod.h | 14 ++ torch/csrc/jit/script/tree_views.h | 9 +- 11 files changed, 310 insertions(+), 25 deletions(-) create mode 100644 torch/csrc/jit/script/strtod.cpp create mode 100644 torch/csrc/jit/script/strtod.h diff --git a/c10/util/C++17.h b/c10/util/C++17.h index f115c81..f93eb00 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include /* * This header adds some polyfills with C++14 and C++17 functionality diff --git a/c10/util/string_utils.h b/c10/util/string_utils.h index 36a4eb3..60f1092 100644 --- a/c10/util/string_utils.h +++ b/c10/util/string_utils.h @@ -17,11 +17,18 @@ std::string to_string(T value) { return os.str(); } -inline int stoi(const std::string& str) { +inline int stoi(const std::string& str, std::size_t* pos = 0) { std::stringstream ss; int n = 0; ss << str; ss >> n; + if (pos) { + if (ss.tellg() == std::streampos(-1)) { + *pos = str.size(); + } else { + *pos = ss.tellg(); + } + } return n; } @@ -47,11 +54,21 @@ inline double stod(const std::string& str, std::size_t* pos = 0) { } return val; } + +inline long long stoll(const std::string& str) { + // std::stoll doesn't exist in our Android environment, we need to implement + // it ourselves. + std::istringstream s(str); + long long result = 0; + s >> result; + return result; +} #else #define CAFFE2_TESTONLY_WE_ARE_USING_CUSTOM_STRING_FUNCTIONS 0 using std::stod; using std::stoi; using std::stoull; +using std::stoll; using std::to_string; #endif // defined(__ANDROID__) || defined(CAFFE2_FORCE_STD_STRING_FALLBACK_TEST) diff --git a/tools/build_variables.py b/tools/build_variables.py index 3cc5f62..db521a0 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -113,6 +113,7 @@ libtorch_sources = [ "torch/csrc/jit/hooks_for_testing.cpp", "torch/csrc/jit/script/builtin_functions.cpp", "torch/csrc/jit/script/lexer.cpp", + "torch/csrc/jit/script/strtod.cpp", "torch/csrc/jit/script/module.cpp", "torch/csrc/jit/tracer.cpp", "torch/csrc/utils/tensor_flatten.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 2c906e9..b0390fc 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -190,6 +190,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp + ${TORCH_SRC_DIR}/csrc/jit/script/strtod.cpp ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp @@ -531,6 +532,7 @@ if (BUILD_PYTHON) ${TORCH_SRC_DIR}/csrc/jit/python_tracer.cpp ${TORCH_SRC_DIR}/csrc/jit/script/init.cpp ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp + ${TORCH_SRC_DIR}/csrc/jit/script/strtod.cpp ${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp ${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp ${TORCH_SRC_DIR}/csrc/nn/THNN.cpp diff --git a/torch/csrc/jit/script/function_schema_parser.cpp b/torch/csrc/jit/script/function_schema_parser.cpp index dfc0080..f6b751c 100644 --- a/torch/csrc/jit/script/function_schema_parser.cpp +++ b/torch/csrc/jit/script/function_schema_parser.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -86,7 +87,7 @@ struct SchemaParser { if (L.nextIf('[')) { // note: an array with a size hint can only occur at the Argument level type = ListType::create(type); - N = std::stoll(L.expect(TK_NUMBER).text()); + N = c10::stoll(L.expect(TK_NUMBER).text()); L.expect(']'); auto container = type_parser.parseAliasAnnotation(); if (container && alias_info) { @@ -153,9 +154,9 @@ struct SchemaParser { n = L.expect(TK_NUMBER).text(); if (kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) { - return std::stod(n); + return c10::stod(n); } else { - int64_t v = std::stoll(n); + int64_t v = c10::stoll(n); return v; } } diff --git a/torch/csrc/jit/script/lexer.cpp b/torch/csrc/jit/script/lexer.cpp index eaf94d6..653607b 100644 --- a/torch/csrc/jit/script/lexer.cpp +++ b/torch/csrc/jit/script/lexer.cpp @@ -89,7 +89,7 @@ std::string kindToString(int kind) { TC_FORALL_TOKEN_KINDS(DEFINE_CASE) #undef DEFINE_CASE default: - throw std::runtime_error("Unknown kind: " + std::to_string(kind)); + throw std::runtime_error("Unknown kind: " + c10::guts::to_string(kind)); } } @@ -97,6 +97,7 @@ SharedParserData& sharedParserData() { static SharedParserData data; // safely handles multi-threaded init return data; } + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index c9b4851..e21a2c6 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -2,8 +2,11 @@ #include #include #include +#include #include #include +#include +#include #include #include #include @@ -160,19 +163,7 @@ struct SharedParserData { TC_FORALL_TOKEN_KINDS(ADD_CASE) #undef ADD_CASE } -#ifdef _WIN32 - static double strtod_c(const char* str, char** end) { - /// NOLINTNEXTLINE(hicpp-signed-bitwise) - static _locale_t loc = _create_locale(LC_ALL, "C"); - return _strtod_l(str, end, loc); - } -#else - static double strtod_c(const char* str, char** end) { - /// NOLINTNEXTLINE(hicpp-signed-bitwise) - static locale_t loc = newlocale(LC_ALL_MASK, "C", nullptr); - return strtod_l(str, end, loc); - } -#endif + // 1. skip whitespace // 2. handle comment or newline // @@ -186,7 +177,7 @@ struct SharedParserData { return false; const char* startptr = str.c_str() + start; char* endptr; - strtod_c(startptr, &endptr); + torch::jit::script::strtod_c(startptr, &endptr); *len = endptr - startptr; return *len > 0; } @@ -478,7 +469,7 @@ struct Lexer { indent_stack.pop_back(); next_tokens.emplace_back(TK_DEDENT, r.range); if (indent_stack.size() == 0) { - reportError("invalid indent level " + std::to_string(depth), r); + reportError("invalid indent level " + c10::guts::to_string(depth), r); } } return; // We've already queued the tokens diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index eb8a3e3..1897be8 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ b/torch/csrc/jit/script/schema_type_parser.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include using c10::Symbol; @@ -105,7 +106,7 @@ c10::optional SchemaTypeParser::parseAliasAnnotation() { L.expect(')'); } else if (L.nextIf('!')) { alias_info.addBeforeSet( - Symbol::fromQualString("alias::$" + std::to_string(next_id++))); + Symbol::fromQualString("alias::$" + c10::guts::to_string(next_id++))); alias_info.setIsWrite(true); } else { return c10::nullopt; @@ -147,7 +148,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { parseList(TK_NOTHING, ',', ')', [&] { const std::string& num = L.expect(TK_NUMBER).text(); std::string::size_type num_len; - size_t dim = std::stoi(num, &num_len); + size_t dim = c10::stoi(num, &num_len); AT_ASSERTM( num_len == num.size(), "Bad tensor dimension size. Strides not yet supported in parsing", diff --git a/torch/csrc/jit/script/strtod.cpp b/torch/csrc/jit/script/strtod.cpp new file mode 100644 index 0000000..4d40725 --- /dev/null +++ b/torch/csrc/jit/script/strtod.cpp @@ -0,0 +1,252 @@ +// Taken from https://github.com/JuliaLang/julia/blob/v1.1.0/src/support/strtod.c + +#include +#include +#include + +#if defined(__APPLE__) || defined(__FreeBSD__) +#include +#endif + +// The following code is derived from the Python function _PyOS_ascii_strtod +// see http://hg.python.org/cpython/file/default/Python/pystrtod.c +// +// Copyright © 2001-2014 Python Software Foundation; All Rights Reserved +// +// The following modifications have been made: +// - Leading spaces are ignored +// - Parsing of hex floats is supported in the derived version +// - Python functions for tolower, isdigit and malloc have been replaced by the respective +// C stdlib functions + +#include +#include +#include +#include +#include + +#define D_PNAN ((double)+NAN) +#define D_PINF ((double)+INFINITY) + +namespace { +int case_insensitive_match(const char *s, const char *t) +{ + while (*t && tolower(*s) == *t) { + s++; + t++; + } + return *t ? 0 : 1; +} + +double parse_inf_or_nan(const char *p, char **endptr) +{ + double retval; + const char *s; + int negate = 0; + + s = p; + if (*s == '-') { + negate = 1; + s++; + } + else if (*s == '+') { + s++; + } + if (case_insensitive_match(s, "inf")) { + s += 3; + if (case_insensitive_match(s, "inity")) + s += 5; + retval = negate ? -D_PINF : D_PINF; + } + else if (case_insensitive_match(s, "nan")) { + s += 3; + retval = negate ? -D_PNAN : D_PNAN; + } + else { + s = p; + retval = -1.0; + } + *endptr = (char *)s; + return retval; +} + +} + +namespace torch { +namespace jit { +namespace script { + +C10_EXPORT double strtod_c(const char *nptr, char **endptr) +{ + char *fail_pos; + double val; + const char *p, *decimal_point_pos; + const char *end = NULL; /* Silence gcc */ + const char *digits_pos = NULL; + int negate = 0; + + fail_pos = NULL; + + char decimal_point = std::use_facet>(std::locale()).decimal_point(); + + decimal_point_pos = NULL; + + /* Parse infinities and nans */ + val = parse_inf_or_nan(nptr, endptr); + if (*endptr != nptr) + return val; + + /* Set errno to zero, so that we can distinguish zero results + and underflows */ + errno = 0; + + /* We process the optional sign manually, then pass the remainder to + the system strtod. This ensures that the result of an underflow + has the correct sign. */ + p = nptr; + + /* parse leading spaces */ + while (isspace((unsigned char)*p)) { + p++; + } + + /* Process leading sign, if present */ + if (*p == '-') { + negate = 1; + p++; + } + else if (*p == '+') { + p++; + } + + /* This code path is used for hex floats */ + if (*p == '0' && (*(p+1) == 'x' || *(p+1) == 'X')) { + digits_pos = p; + p += 2; + /* Check that what's left begins with a digit or decimal point */ + if (!isxdigit(*p) && *p != '.') + goto invalid_string; + + + if (decimal_point != '.') { + /* Look for a '.' in the input; if present, it'll need to be + swapped for the current locale's decimal point before we + call strtod. On the other hand, if we find the current + locale's decimal point then the input is invalid. */ + while (isxdigit(*p)) + p++; + + if (*p == '.') { + decimal_point_pos = p++; + + /* locate end of number */ + while (isxdigit(*p)) + p++; + + if (*p == 'p' || *p == 'P') + p++; + if (*p == '+' || *p == '-') + p++; + while (isdigit(*p)) + p++; + end = p; + } + else if (*p == decimal_point) + goto invalid_string; + /* For the other cases, we need not convert the decimal point */ + } + } + else { + /* Check that what's left begins with a digit or decimal point */ + if (!isdigit(*p) && *p != '.') + goto invalid_string; + + digits_pos = p; + if (decimal_point != '.') { + /* Look for a '.' in the input; if present, it'll need to be + swapped for the current locale's decimal point before we + call strtod. On the other hand, if we find the current + locale's decimal point then the input is invalid. */ + while (isdigit(*p)) + p++; + + if (*p == '.') { + decimal_point_pos = p++; + + /* locate end of number */ + while (isdigit(*p)) + p++; + + if (*p == 'e' || *p == 'E') + p++; + if (*p == '+' || *p == '-') + p++; + while (isdigit(*p)) + p++; + end = p; + } + else if (*p == decimal_point) + goto invalid_string; + /* For the other cases, we need not convert the decimal point */ + } + } + + if (decimal_point_pos) { + char *copy, *c; + /* Create a copy of the input, with the '.' converted to the + locale-specific decimal point */ + copy = (char *)malloc(end - digits_pos + 2); + if (copy == NULL) { + *endptr = (char *)nptr; + errno = ENOMEM; + return val; + } + + c = copy; + memcpy(c, digits_pos, decimal_point_pos - digits_pos); + c += decimal_point_pos - digits_pos; + memcpy(c, &decimal_point, 1); + c += 1; + memcpy(c, decimal_point_pos + 1, + end - (decimal_point_pos + 1)); + c += end - (decimal_point_pos + 1); + *c = 0; + + val = strtod(copy, &fail_pos); + + if (fail_pos) + { + fail_pos = (char *)digits_pos + + (fail_pos - copy); + } + + free(copy); + } + else { + val = strtod(digits_pos, &fail_pos); + } + + if (fail_pos == digits_pos) + goto invalid_string; + + if (negate && fail_pos != nptr) + val = -val; + *endptr = fail_pos; + + return val; + +invalid_string: + *endptr = (char*)nptr; + errno = EINVAL; + return -1.0; +} + + +C10_EXPORT float strtof_c(const char *nptr, char **endptr) +{ + return (float) strtod_c(nptr, endptr); +} + +} +} +} diff --git a/torch/csrc/jit/script/strtod.h b/torch/csrc/jit/script/strtod.h new file mode 100644 index 0000000..c74a0ed --- /dev/null +++ b/torch/csrc/jit/script/strtod.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch { +namespace jit { +namespace script { + +CAFFE2_API double strtod_c(const char *nptr, char **endptr); +CAFFE2_API float strtof_c(const char *nptr, char **endptr); + +} +} +} diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 67b73dd..22a31a3 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -1,6 +1,8 @@ #pragma once #include #include +#include +#include #include #include @@ -746,11 +748,12 @@ struct Const : public Expr { return !isFloatingPoint(); } int64_t asIntegral() const { - return std::stoll(subtree(0)->stringValue()); + return c10::stoll(subtree(0)->stringValue()); } double asFloatingPoint() const { - return SharedParserData::strtod_c( - subtree(0)->stringValue().c_str(), nullptr); + char* dummy; + return torch::jit::script::strtod_c( + subtree(0)->stringValue().c_str(), &dummy); } const std::string& text() const { return subtree(0)->stringValue(); -- 2.7.4