[dfsan] Add utils that get/set origins
authorJianzhou Zhao <jianzhouzh@google.com>
Fri, 19 Feb 2021 21:32:37 +0000 (21:32 +0000)
committerJianzhou Zhao <jianzhouzh@google.com>
Sat, 20 Feb 2021 00:52:33 +0000 (00:52 +0000)
This is a part of https://reviews.llvm.org/D95835.

Reviewed-by: morehouse
Differential Revision: https://reviews.llvm.org/D97087

llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp

index 6a02212..ba66a71 100644 (file)
@@ -382,6 +382,7 @@ class DataFlowSanitizer {
   IntegerType *OriginTy;
   PointerType *OriginPtrTy;
   ConstantInt *OriginBase;
+  ConstantInt *ZeroOrigin;
   /// The shadow type for all primitive types and vector types.
   IntegerType *PrimitiveShadowTy;
   PointerType *PrimitiveShadowPtrTy;
@@ -503,6 +504,7 @@ struct DFSanFunction {
   bool IsNativeABI;
   AllocaInst *LabelReturnAlloca = nullptr;
   DenseMap<Value *, Value *> ValShadowMap;
+  DenseMap<Value *, Value *> ValOriginMap;
   DenseMap<AllocaInst *, AllocaInst *> AllocaShadowMap;
   std::vector<std::pair<PHINode *, PHINode *>> PHIFixups;
   DenseSet<Instruction *> SkipInsts;
@@ -535,9 +537,19 @@ struct DFSanFunction {
   /// Shadow = ArgTLS+ArgOffset.
   Value *getArgTLS(Type *T, unsigned ArgOffset, IRBuilder<> &IRB);
 
-  /// Computes the shadow address for a retval.
+  /// Computes the shadow address for a return value.
   Value *getRetvalTLS(Type *T, IRBuilder<> &IRB);
 
+  /// Computes the origin address for a given function argument.
+  ///
+  /// Origin = ArgOriginTLS[ArgNo].
+  Value *getArgOriginTLS(unsigned ArgNo, IRBuilder<> &IRB);
+
+  /// Computes the origin address for a return value.
+  Value *getRetvalOriginTLS();
+
+  Value *getOrigin(Value *V);
+  void setOrigin(Instruction *I, Value *Origin);
   Value *getShadow(Value *V);
   void setShadow(Instruction *I, Value *Shadow);
   /// Generates IR to compute the union of the two given shadows, inserting it
@@ -877,6 +889,7 @@ bool DataFlowSanitizer::init(Module &M) {
   ZeroPrimitiveShadow = ConstantInt::getSigned(PrimitiveShadowTy, 0);
   ShadowPtrMul = ConstantInt::getSigned(IntptrTy, ShadowWidthBytes);
   OriginBase = ConstantInt::get(IntptrTy, 0x200000000000LL);
+  ZeroOrigin = ConstantInt::getSigned(OriginTy, 0);
   if (IsX86_64)
     ShadowPtrMask = ConstantInt::getSigned(IntptrTy, ~0x700000000000LL);
   else if (IsMIPS64)
@@ -1453,6 +1466,55 @@ Value *DFSanFunction::getRetvalTLS(Type *T, IRBuilder<> &IRB) {
       DFS.RetvalTLS, PointerType::get(DFS.getShadowTy(T), 0), "_dfsret");
 }
 
+Value *DFSanFunction::getRetvalOriginTLS() { return DFS.RetvalOriginTLS; }
+
+Value *DFSanFunction::getArgOriginTLS(unsigned ArgNo, IRBuilder<> &IRB) {
+  return IRB.CreateConstGEP2_64(DFS.ArgOriginTLSTy, DFS.ArgOriginTLS, 0, ArgNo,
+                                "_dfsarg_o");
+}
+
+Value *DFSanFunction::getOrigin(Value *V) {
+  assert(DFS.shouldTrackOrigins());
+  if (!isa<Argument>(V) && !isa<Instruction>(V))
+    return DFS.ZeroOrigin;
+  Value *&Origin = ValOriginMap[V];
+  if (!Origin) {
+    if (Argument *A = dyn_cast<Argument>(V)) {
+      if (IsNativeABI)
+        return DFS.ZeroOrigin;
+      switch (IA) {
+      case DataFlowSanitizer::IA_TLS: {
+        if (A->getArgNo() < DFS.kNumOfElementsInArgOrgTLS) {
+          Instruction *ArgOriginTLSPos = &*F->getEntryBlock().begin();
+          IRBuilder<> IRB(ArgOriginTLSPos);
+          Value *ArgOriginPtr = getArgOriginTLS(A->getArgNo(), IRB);
+          Origin = IRB.CreateLoad(DFS.OriginTy, ArgOriginPtr);
+        } else {
+          // Overflow
+          Origin = DFS.ZeroOrigin;
+        }
+        break;
+      }
+      case DataFlowSanitizer::IA_Args: {
+        Origin = DFS.ZeroOrigin;
+        break;
+      }
+      }
+    } else {
+      Origin = DFS.ZeroOrigin;
+    }
+  }
+  return Origin;
+}
+
+void DFSanFunction::setOrigin(Instruction *I, Value *Origin) {
+  if (!DFS.shouldTrackOrigins())
+    return;
+  assert(!ValOriginMap.count(I));
+  assert(Origin->getType() == DFS.OriginTy);
+  ValOriginMap[I] = Origin;
+}
+
 Value *DFSanFunction::getShadowForTLSArgument(Argument *A) {
   unsigned ArgOffset = 0;
   const DataLayout &DL = F->getParent()->getDataLayout();