[libc] Add custom operator new to handle allocation failures gracefully.
authorSiva Chandra Reddy <sivachandra@google.com>
Wed, 7 Dec 2022 21:35:38 +0000 (21:35 +0000)
committerSiva Chandra Reddy <sivachandra@google.com>
Sun, 11 Dec 2022 00:29:04 +0000 (00:29 +0000)
This patch adds the implementation of the custom operator new functions.
The implementation of the internal strdup has been updated to use
operator new for allocation.

We will make it a policy and document that all allocations have to go
through the libc's own operator new. A future change will also add
operator delete replacements and make it a policy that deallocations in
libc internal code have to go through those replacements.

Reviewed By: lntue

Differential Revision: https://reviews.llvm.org/D139584

libc/src/__support/CPP/new.h [new file with mode: 0644]
libc/src/string/CMakeLists.txt
libc/src/string/allocating_string_utils.h
libc/src/string/strdup.cpp
libc/src/unistd/linux/getcwd.cpp
libc/test/src/string/CMakeLists.txt
libc/test/src/string/strdup_test.cpp

diff --git a/libc/src/__support/CPP/new.h b/libc/src/__support/CPP/new.h
new file mode 100644 (file)
index 0000000..3becb9f
--- /dev/null
@@ -0,0 +1,72 @@
+//===-- Libc specific custom operator new and delete ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
+#define LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
+
+#include <stddef.h> // For size_t
+#include <stdlib.h> // For malloc, free etc.
+
+// Defining members in the std namespace is not preferred. But, we do it here
+// so that we can use it to define the operator new which takes std::align_val_t
+// argument.
+namespace std {
+
+enum class align_val_t : size_t {};
+
+} // namespace std
+
+namespace __llvm_libc {
+
+class AllocChecker {
+  bool success = false;
+  AllocChecker &operator=(bool status) {
+    success = status;
+    return *this;
+  }
+
+public:
+  AllocChecker() = default;
+  operator bool() const { return success; }
+
+  static void *alloc(size_t s, AllocChecker &ac) {
+    void *mem = ::malloc(s);
+    ac = (mem != nullptr);
+    return mem;
+  }
+
+  static void *aligned_alloc(size_t s, std::align_val_t align,
+                             AllocChecker &ac) {
+    void *mem = ::aligned_alloc(static_cast<size_t>(align), s);
+    ac = (mem != nullptr);
+    return mem;
+  }
+};
+
+} // namespace __llvm_libc
+
+inline void *operator new(size_t size, __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::alloc(size, ac);
+}
+
+inline void *operator new(size_t size, std::align_val_t align,
+                          __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::aligned_alloc(size, align, ac);
+}
+
+inline void *operator new[](size_t size,
+                            __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::alloc(size, ac);
+}
+
+inline void *operator new[](size_t size, std::align_val_t align,
+                            __llvm_libc::AllocChecker &ac) noexcept {
+  return __llvm_libc::AllocChecker::aligned_alloc(size, align, ac);
+}
+
+#endif // LLVM_LIBC_SRC_SUPPORT_CPP_NEW_H
index 422d24c..caeddc8 100644 (file)
@@ -16,8 +16,9 @@ add_header_library(
   HDRS
     allocating_string_utils.h
   DEPENDS
-    libc.include.stdlib
     .memory_utils.memcpy_implementation
+    libc.include.stdlib
+    libc.src.__support.CPP.optional
 )
 
 add_entrypoint_object(
@@ -150,7 +151,9 @@ add_entrypoint_object(
   DEPENDS
     .memory_utils.memcpy_implementation
     .string_utils
+    libc.include.errno
     libc.include.stdlib
+    libc.src.errno.errno
 )
 
 add_entrypoint_object(
index 3220400..93d1c7a 100644 (file)
@@ -9,24 +9,24 @@
 #ifndef LIBC_SRC_STRING_ALLOCATING_STRING_UTILS_H
 #define LIBC_SRC_STRING_ALLOCATING_STRING_UTILS_H
 
-#include "src/__support/CPP/bitset.h"
-#include "src/__support/common.h"
-#include "src/string/memory_utils/bzero_implementations.h"
-#include "src/string/memory_utils/memcpy_implementations.h"
+#include "src/__support/CPP/new.h"
+#include "src/__support/CPP/optional.h"
+#include "src/string/memory_utils/memcpy_implementations.h" // For string_length
 #include "src/string/string_utils.h"
+
 #include <stddef.h> // For size_t
-#include <stdlib.h> // For malloc
 
 namespace __llvm_libc {
 namespace internal {
 
-inline char *strdup(const char *src) {
+cpp::optional<char *> strdup(const char *src) {
   if (src == nullptr)
-    return nullptr;
+    return cpp::nullopt;
   size_t len = string_length(src) + 1;
-  char *newstr = reinterpret_cast<char *>(::malloc(len));
-  if (newstr == nullptr)
-    return nullptr;
+  AllocChecker ac;
+  char *newstr = new (ac) char[len];
+  if (!ac)
+    return cpp::nullopt;
   inline_memcpy(newstr, src, len);
   return newstr;
 }
index 9a52b29..9aa0d50 100644 (file)
 
 #include "src/__support/common.h"
 
+#include <errno.h>
 #include <stdlib.h>
 
 namespace __llvm_libc {
 
 LLVM_LIBC_FUNCTION(char *, strdup, (const char *src)) {
-  return internal::strdup(src);
+  auto dup = internal::strdup(src);
+  if (dup)
+    return *dup;
+  if (src != nullptr)
+    errno = ENOMEM;
+  return nullptr;
 }
 
 } // namespace __llvm_libc
index 84a16b3..67dea37 100644 (file)
@@ -44,12 +44,12 @@ LLVM_LIBC_FUNCTION(char *, getcwd, (char *buf, size_t size)) {
     char pathbuf[PATH_MAX];
     if (!getcwd_syscall(pathbuf, PATH_MAX))
       return nullptr;
-    char *cwd = internal::strdup(pathbuf);
-    if (cwd == nullptr) {
+    auto cwd = internal::strdup(pathbuf);
+    if (!cwd) {
       errno = ENOMEM;
       return nullptr;
     }
-    return cwd;
+    return *cwd;
   } else if (size == 0) {
     errno = EINVAL;
     return nullptr;
index 8b427cc..b5de08b 100644 (file)
@@ -141,8 +141,10 @@ add_libc_unittest(
   SRCS
     strdup_test.cpp
   DEPENDS
+    libc.include.errno
     libc.include.stdlib
     libc.src.string.strdup
+    libc.src.errno.errno
 )
 
 add_libc_unittest(
index 7820c1e..6576003 100644 (file)
@@ -8,12 +8,17 @@
 
 #include "src/string/strdup.h"
 #include "utils/UnitTest/Test.h"
+
+#include <errno.h>
 #include <stdlib.h>
 
 TEST(LlvmLibcStrDupTest, EmptyString) {
   const char *empty = "";
 
+  errno = 0;
   char *result = __llvm_libc::strdup(empty);
+  ASSERT_EQ(errno, 0);
+
   ASSERT_NE(result, static_cast<char *>(nullptr));
   ASSERT_NE(empty, const_cast<const char *>(result));
   ASSERT_STREQ(empty, result);
@@ -23,7 +28,9 @@ TEST(LlvmLibcStrDupTest, EmptyString) {
 TEST(LlvmLibcStrDupTest, AnyString) {
   const char *abc = "abc";
 
+  errno = 0;
   char *result = __llvm_libc::strdup(abc);
+  ASSERT_EQ(errno, 0);
 
   ASSERT_NE(result, static_cast<char *>(nullptr));
   ASSERT_NE(abc, const_cast<const char *>(result));
@@ -32,8 +39,9 @@ TEST(LlvmLibcStrDupTest, AnyString) {
 }
 
 TEST(LlvmLibcStrDupTest, NullPtr) {
-
+  errno = 0;
   char *result = __llvm_libc::strdup(nullptr);
+  ASSERT_EQ(errno, 0);
 
   ASSERT_EQ(result, static_cast<char *>(nullptr));
 }