[hwasan] support hwasan-match-all-tag flag for hwasan meminstrinsic calls
authorEnna1 <xumingjie.enna1@bytedance.com>
Sat, 27 May 2023 02:35:18 +0000 (10:35 +0800)
committerEnna1 <xumingjie.enna1@bytedance.com>
Sat, 27 May 2023 02:35:52 +0000 (10:35 +0800)
This patch implements `__hwasan_memset_match_all`, `__hwasan_memcpy_match_all` and `__hwasan_memmove_match_all`, making hwasan-match-all-tag flag working for hwasan versions of memset, memcpy and memmove.

Reviewed By: vitalybuka

Differential Revision: https://reviews.llvm.org/D149943

compiler-rt/lib/hwasan/hwasan_interface_internal.h
compiler-rt/lib/hwasan/hwasan_memintrinsics.cpp
llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
llvm/test/Instrumentation/HWAddressSanitizer/mem-intrinsics.ll

index 48ff3d5..e7804cc 100644 (file)
@@ -236,6 +236,13 @@ SANITIZER_INTERFACE_ATTRIBUTE
 void *__hwasan_memmove(void *dest, const void *src, uptr n);
 
 SANITIZER_INTERFACE_ATTRIBUTE
+void *__hwasan_memcpy_match_all(void *dst, const void *src, uptr size, u8);
+SANITIZER_INTERFACE_ATTRIBUTE
+void *__hwasan_memset_match_all(void *s, int c, uptr n, u8);
+SANITIZER_INTERFACE_ATTRIBUTE
+void *__hwasan_memmove_match_all(void *dest, const void *src, uptr n, u8);
+
+SANITIZER_INTERFACE_ATTRIBUTE
 void __hwasan_set_error_report_callback(void (*callback)(const char *));
 }  // extern "C"
 
index ea7f5ce..16d6f90 100644 (file)
@@ -42,3 +42,33 @@ void *__hwasan_memmove(void *to, const void *from, uptr size) {
       reinterpret_cast<uptr>(from), size);
   return memmove(to, from, size);
 }
+
+void *__hwasan_memset_match_all(void *block, int c, uptr size,
+                                u8 match_all_tag) {
+  if (GetTagFromPointer(reinterpret_cast<uptr>(block)) != match_all_tag)
+    CheckAddressSized<ErrorAction::Recover, AccessType::Store>(
+        reinterpret_cast<uptr>(block), size);
+  return memset(block, c, size);
+}
+
+void *__hwasan_memcpy_match_all(void *to, const void *from, uptr size,
+                                u8 match_all_tag) {
+  if (GetTagFromPointer(reinterpret_cast<uptr>(to)) != match_all_tag)
+    CheckAddressSized<ErrorAction::Recover, AccessType::Store>(
+        reinterpret_cast<uptr>(to), size);
+  if (GetTagFromPointer(reinterpret_cast<uptr>(from)) != match_all_tag)
+    CheckAddressSized<ErrorAction::Recover, AccessType::Load>(
+        reinterpret_cast<uptr>(from), size);
+  return memcpy(to, from, size);
+}
+
+void *__hwasan_memmove_match_all(void *to, const void *from, uptr size,
+                                 u8 match_all_tag) {
+  if (GetTagFromPointer(reinterpret_cast<uptr>(to)) != match_all_tag)
+    CheckAddressSized<ErrorAction::Recover, AccessType::Store>(
+        reinterpret_cast<uptr>(to), size);
+  if (GetTagFromPointer(reinterpret_cast<uptr>(from)) != match_all_tag)
+    CheckAddressSized<ErrorAction::Recover, AccessType::Load>(
+        reinterpret_cast<uptr>(from), size);
+  return memmove(to, from, size);
+}
index 954de89..2ad090b 100644 (file)
@@ -623,24 +623,33 @@ void HWAddressSanitizer::initializeModule() {
 
 void HWAddressSanitizer::initializeCallbacks(Module &M) {
   IRBuilder<> IRB(*C);
+  const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : "";
+  FunctionType *HwasanMemoryAccessCallbackSizedFnTy,
+      *HwasanMemoryAccessCallbackFnTy, *HWAsanMemTransferFnTy,
+      *HWAsanMemsetFnTy;
+  if (UseMatchAllCallback) {
+    HwasanMemoryAccessCallbackSizedFnTy =
+        FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false);
+    HwasanMemoryAccessCallbackFnTy =
+        FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false);
+    HWAsanMemTransferFnTy = FunctionType::get(
+        Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false);
+    HWAsanMemsetFnTy = FunctionType::get(
+        Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false);
+  } else {
+    HwasanMemoryAccessCallbackSizedFnTy =
+        FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false);
+    HwasanMemoryAccessCallbackFnTy =
+        FunctionType::get(VoidTy, {IntptrTy}, false);
+    HWAsanMemTransferFnTy =
+        FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false);
+    HWAsanMemsetFnTy =
+        FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false);
+  }
+
   for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) {
     const std::string TypeStr = AccessIsWrite ? "store" : "load";
     const std::string EndingStr = Recover ? "_noabort" : "";
-    const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : "";
-
-    FunctionType *HwasanMemoryAccessCallbackSizedFnTy,
-        *HwasanMemoryAccessCallbackFnTy;
-    if (UseMatchAllCallback) {
-      HwasanMemoryAccessCallbackSizedFnTy =
-          FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false);
-      HwasanMemoryAccessCallbackFnTy =
-          FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false);
-    } else {
-      HwasanMemoryAccessCallbackSizedFnTy =
-          FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false);
-      HwasanMemoryAccessCallbackFnTy =
-          FunctionType::get(VoidTy, {IntptrTy}, false);
-    }
 
     HwasanMemoryAccessCallbackSized[AccessIsWrite] = M.getOrInsertFunction(
         ClMemoryAccessCallbackPrefix + TypeStr + "N" + MatchAllStr + EndingStr,
@@ -656,6 +665,18 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) {
     }
   }
 
+  const std::string MemIntrinCallbackPrefix =
+      (CompileKernel && !ClKasanMemIntrinCallbackPrefix)
+          ? std::string("")
+          : ClMemoryAccessCallbackPrefix;
+
+  HWAsanMemmove = M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memmove" + MatchAllStr, HWAsanMemTransferFnTy);
+  HWAsanMemcpy = M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memcpy" + MatchAllStr, HWAsanMemTransferFnTy);
+  HWAsanMemset = M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memset" + MatchAllStr, HWAsanMemsetFnTy);
+
   HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy,
                                               Int8PtrTy, Int8Ty, IntptrTy);
   HwasanGenerateTagFunc =
@@ -667,19 +688,6 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) {
   ShadowGlobal =
       M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(Int8Ty, 0));
 
-  const std::string MemIntrinCallbackPrefix =
-      (CompileKernel && !ClKasanMemIntrinCallbackPrefix)
-          ? std::string("")
-          : ClMemoryAccessCallbackPrefix;
-  HWAsanMemmove =
-      M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", Int8PtrTy,
-                            Int8PtrTy, Int8PtrTy, IntptrTy);
-  HWAsanMemcpy =
-      M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", Int8PtrTy,
-                            Int8PtrTy, Int8PtrTy, IntptrTy);
-  HWAsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset",
-                                       Int8PtrTy, Int8PtrTy, Int32Ty, IntptrTy);
-
   HWAsanHandleVfork =
       M.getOrInsertFunction("__hwasan_handle_vfork", VoidTy, IntptrTy);
 }
@@ -949,15 +957,35 @@ bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) {
 void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
   IRBuilder<> IRB(MI);
   if (isa<MemTransferInst>(MI)) {
-    IRB.CreateCall(isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy,
-                   {IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy),
-                    IRB.CreatePointerCast(MI->getOperand(1), Int8PtrTy),
-                    IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+    if (UseMatchAllCallback) {
+      IRB.CreateCall(
+          isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy,
+          {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+           IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
+           IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
+           ConstantInt::get(Int8Ty, *MatchAllTag)});
+    } else {
+      IRB.CreateCall(
+          isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy,
+          {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+           IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()),
+           IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+    }
   } else if (isa<MemSetInst>(MI)) {
-    IRB.CreateCall(HWAsanMemset,
-                   {IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy),
-                    IRB.CreateIntCast(MI->getOperand(1), Int32Ty, false),
-                    IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+    if (UseMatchAllCallback) {
+      IRB.CreateCall(
+          HWAsanMemset,
+          {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+           IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
+           IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false),
+           ConstantInt::get(Int8Ty, *MatchAllTag)});
+    } else {
+      IRB.CreateCall(
+          HWAsanMemset,
+          {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()),
+           IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false),
+           IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)});
+    }
   }
   MI->eraseFromParent();
 }
index aea3a34..c13ca6c 100644 (file)
@@ -1,6 +1,7 @@
 ; RUN: opt -S -passes=hwasan -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-PREFIX %s
 ; RUN: opt -S -passes=hwasan -hwasan-kernel -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-NOPREFIX %s
 ; RUN: opt -S -passes=hwasan -hwasan-kernel -hwasan-kernel-mem-intrinsic-prefix -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-PREFIX %s
+; RUN: opt -S -passes=hwasan -hwasan-use-stack-safety=0 -hwasan-match-all-tag=0 %s | FileCheck --check-prefixes=CHECK,CHECK-MATCH-ALL-TAG %s
 
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -15,19 +16,22 @@ entry:
   store i32 0, ptr %retval, align 4
 
   call void @llvm.memset.p0.i64(ptr align 1 %Q, i8 0, i64 10, i1 false)
-; CHECK-PREFIX: call ptr @__hwasan_memset
-; CHECK-NOPREFIX: call ptr @memset
+; CHECK-PREFIX: call ptr @__hwasan_memset(
+; CHECK-NOPREFIX: call ptr @memset(
+; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memset_match_all(ptr %Q.hwasan, i32 0, i64 10, i8 0)
 
   %add.ptr = getelementptr inbounds i8, ptr %Q, i64 5
 
   call void @llvm.memmove.p0.p0.i64(ptr align 1 %Q, ptr align 1 %add.ptr, i64 5, i1 false)
-; CHECK-PREFIX: call ptr @__hwasan_memmove
-; CHECK-NOPREFIX: call ptr @memmove
+; CHECK-PREFIX: call ptr @__hwasan_memmove(
+; CHECK-NOPREFIX: call ptr @memmove(
+; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memmove_match_all(ptr %Q.hwasan, ptr %add.ptr, i64 5, i8 0)
 
 
   call void @llvm.memcpy.p0.p0.i64(ptr align 1 %P, ptr align 1 %Q, i64 10, i1 false)
-; CHECK-PREFIX: call ptr @__hwasan_memcpy
-; CHECK-NOPREFIX: call ptr @memcpy
+; CHECK-PREFIX: call ptr @__hwasan_memcpy(
+; CHECK-NOPREFIX: call ptr @memcpy(
+; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memcpy_match_all(ptr %P.hwasan, ptr %Q.hwasan, i64 10, i8 0)
   ret i32 0
 }