Add support for static helper functions
authorBrenden Blanco <bblanco@plumgrid.com>
Wed, 16 Sep 2015 21:59:35 +0000 (14:59 -0700)
committerBrenden Blanco <bblanco@plumgrid.com>
Thu, 17 Sep 2015 20:08:49 +0000 (13:08 -0700)
This adds support for static helper functions that can be reused. It is
not necessary to include pt_regs in the helper functions, even though
external pointers may be dereferenced. Arguments in the helpers can also
be reordered.

Signed-off-by: Brenden Blanco <bblanco@plumgrid.com>
src/cc/frontends/clang/b_frontend_action.cc
src/cc/frontends/clang/b_frontend_action.h
tests/cc/test_brb.c
tests/cc/test_clang.py

index eb0406d..7d3ac2c 100644 (file)
@@ -21,6 +21,7 @@
 #include <clang/AST/ASTContext.h>
 #include <clang/AST/RecordLayout.h>
 #include <clang/Frontend/CompilerInstance.h>
+#include <clang/Frontend/MultiplexConsumer.h>
 #include <clang/Rewrite/Core/Rewriter.h>
 
 #include "b_frontend_action.h"
@@ -36,6 +37,7 @@ const char *calling_conv_regs_x86[] = {
 const char **calling_conv_regs = calling_conv_regs_x86;
 
 using std::map;
+using std::set;
 using std::string;
 using std::to_string;
 using std::unique_ptr;
@@ -90,27 +92,107 @@ bool BMapDeclVisitor::VisitBuiltinType(const BuiltinType *T) {
   return true;
 }
 
-class BProbeChecker : public clang::RecursiveASTVisitor<BProbeChecker> {
+class ProbeChecker : public clang::RecursiveASTVisitor<ProbeChecker> {
  public:
+  explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs)
+      : needs_probe_(false), ptregs_(ptregs) {
+    if (arg)
+      TraverseStmt(arg);
+  }
   bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
-    if (E->getDecl()->hasAttr<UnavailableAttr>())
-      return false;
+    if (ptregs_.find(E->getDecl()) != ptregs_.end())
+      needs_probe_ = true;
     return true;
   }
+  bool needs_probe() const { return needs_probe_; }
+ private:
+  bool needs_probe_;
+  const set<Decl *> &ptregs_;
 };
 
 // Visit a piece of the AST and mark it as needing probe reads
-class BProbeSetter : public clang::RecursiveASTVisitor<BProbeSetter> {
+class ProbeSetter : public clang::RecursiveASTVisitor<ProbeSetter> {
  public:
-  explicit BProbeSetter(ASTContext &C) : C(C) {}
+  explicit ProbeSetter(set<Decl *> *ptregs) : ptregs_(ptregs) {}
   bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
-    E->getDecl()->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs"));
+    ptregs_->insert(E->getDecl());
     return true;
   }
  private:
-  ASTContext &C;
+  set<Decl *> *ptregs_;
 };
 
+ProbeVisitor::ProbeVisitor(Rewriter &rewriter) : rewriter_(rewriter) {}
+
+bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
+  if (Expr *E = Decl->getInit()) {
+    if (ProbeChecker(E, ptregs_).needs_probe())
+      set_ptreg(Decl);
+  }
+  return true;
+}
+bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
+  if (FunctionDecl *F = dyn_cast<FunctionDecl>(Call->getCalleeDecl())) {
+    if (F->hasBody()) {
+      unsigned i = 0;
+      for (auto arg : Call->arguments()) {
+        if (ProbeChecker(arg, ptregs_).needs_probe())
+          ptregs_.insert(F->getParamDecl(i));
+        ++i;
+      }
+      if (fn_visited_.find(F) == fn_visited_.end()) {
+        fn_visited_.insert(F);
+        TraverseDecl(F);
+      }
+    }
+  }
+  return true;
+}
+bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
+  if (!E->isAssignmentOp())
+    return true;
+  // copy probe attribute from RHS to LHS if present
+  if (ProbeChecker(E->getRHS(), ptregs_).needs_probe()) {
+    ProbeSetter setter(&ptregs_);
+    setter.TraverseStmt(E->getLHS());
+  }
+  return true;
+}
+bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
+  if (memb_visited_.find(E) != memb_visited_.end()) return true;
+
+  // Checks to see if the expression references something that needs to be run
+  // through bpf_probe_read.
+  if (!ProbeChecker(E, ptregs_).needs_probe())
+    return true;
+
+  Expr *base;
+  SourceLocation rhs_start, op;
+  bool found = false;
+  for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
+    memb_visited_.insert(M);
+    rhs_start = M->getLocEnd();
+    base = M->getBase();
+    op = M->getOperatorLoc();
+    if (M->isArrow()) {
+      found = true;
+      break;
+    }
+  }
+  if (!found)
+    return true;
+  string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
+  string base_type = base->getType()->getPointeeType().getAsString();
+  string pre, post;
+  pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
+  pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
+  post = " + offsetof(" + base_type + ", " + rhs + ")";
+  post += "); _val; })";
+  rewriter_.InsertText(E->getLocStart(), pre);
+  rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
+  return true;
+}
+
 BTypeVisitor::BTypeVisitor(ASTContext &C, Rewriter &rewriter, vector<TableDesc> &tables)
     : C(C), rewriter_(rewriter), out_(llvm::errs()), tables_(tables) {
 }
@@ -141,6 +223,11 @@ bool BTypeVisitor::VisitFunctionDecl(FunctionDecl *D) {
     // for each trace argument, convert the variable from ptregs to something on stack
     if (CompoundStmt *S = dyn_cast<CompoundStmt>(D->getBody()))
       rewriter_.ReplaceText(S->getLBracLoc(), 1, preamble);
+  } else if (D->hasBody() &&
+             rewriter_.getSourceMgr().getFileID(D->getLocStart())
+               == rewriter_.getSourceMgr().getMainFileID()) {
+    // rewritable functions that are static should be always treated as helper
+    rewriter_.InsertText(D->getLocStart(), "__attribute__((always_inline))\n");
   }
   return true;
 }
@@ -282,37 +369,6 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
   return true;
 }
 
-bool BTypeVisitor::VisitMemberExpr(MemberExpr *E) {
-  if (visited_.find(E) != visited_.end()) return true;
-
-  // Checks to see if the expression references something that needs to be run
-  // through bpf_probe_read.
-  BProbeChecker checker;
-  if (checker.TraverseStmt(E))
-    return true;
-
-  Expr *base;
-  SourceLocation rhs_start, op;
-  for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
-    visited_.insert(M);
-    rhs_start = M->getLocEnd();
-    base = M->getBase();
-    op = M->getOperatorLoc();
-    if (M->isArrow())
-      break;
-  }
-  string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
-  string base_type = base->getType()->getPointeeType().getAsString();
-  string pre, post;
-  pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
-  pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
-  post = " + offsetof(" + base_type + ", " + rhs + ")";
-  post += "); _val; })";
-  rewriter_.InsertText(E->getLocStart(), pre);
-  rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
-  return true;
-}
-
 bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
   if (!E->isAssignmentOp())
     return true;
@@ -340,12 +396,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
       }
     }
   }
-  // copy probe attribute from RHS to LHS if present
-  BProbeChecker checker;
-  if (!checker.TraverseStmt(E->getRHS())) {
-    BProbeSetter setter(C);
-    setter.TraverseStmt(E->getLHS());
-  }
   return true;
 }
 bool BTypeVisitor::VisitImplicitCastExpr(ImplicitCastExpr *E) {
@@ -453,11 +503,6 @@ bool BTypeVisitor::VisitVarDecl(VarDecl *Decl) {
       }
     }
   }
-  if (Expr *E = Decl->getInit()) {
-    BProbeChecker checker;
-    if (!checker.TraverseStmt(E))
-      Decl->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs"));
-  }
   return true;
 }
 
@@ -465,9 +510,27 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, Rewriter &rewriter, vector<TableDesc
     : visitor_(C, rewriter, tables) {
 }
 
-bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef D) {
-  for (auto it : D)
-    visitor_.TraverseDecl(it);
+bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) {
+  for (auto D : Group)
+    visitor_.TraverseDecl(D);
+  return true;
+}
+
+ProbeConsumer::ProbeConsumer(clang::ASTContext &C, Rewriter &rewriter)
+    : visitor_(rewriter) {}
+
+bool ProbeConsumer::HandleTopLevelDecl(clang::DeclGroupRef Group) {
+  for (auto D : Group) {
+    if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
+      if (F->isExternallyVisible() && F->hasBody()) {
+        for (auto arg : F->parameters()) {
+          if (arg != F->getParamDecl(0))
+            visitor_.set_ptreg(arg);
+        }
+        visitor_.TraverseDecl(D);
+      }
+    }
+  }
   return true;
 }
 
@@ -476,7 +539,6 @@ BFrontendAction::BFrontendAction(llvm::raw_ostream &os, unsigned flags)
 }
 
 void BFrontendAction::EndSourceFileAction() {
-  // uncomment to see rewritten source
   if (flags_ & 0x4)
     rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(llvm::errs());
   rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(os_);
@@ -485,7 +547,10 @@ void BFrontendAction::EndSourceFileAction() {
 
 unique_ptr<ASTConsumer> BFrontendAction::CreateASTConsumer(CompilerInstance &Compiler, llvm::StringRef InFile) {
   rewriter_->setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts());
-  return unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_));
+  vector<unique_ptr<ASTConsumer>> consumers;
+  consumers.push_back(unique_ptr<ASTConsumer>(new ProbeConsumer(Compiler.getASTContext(), *rewriter_)));
+  consumers.push_back(unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_)));
+  return unique_ptr<ASTConsumer>(new MultiplexConsumer(move(consumers)));
 }
 
 }
index 92bdca9..1cfcc00 100644 (file)
@@ -66,7 +66,6 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
   bool VisitFunctionDecl(clang::FunctionDecl *D);
   bool VisitCallExpr(clang::CallExpr *Call);
   bool VisitVarDecl(clang::VarDecl *Decl);
-  bool VisitMemberExpr(clang::MemberExpr *E);
   bool VisitBinaryOperator(clang::BinaryOperator *E);
   bool VisitImplicitCastExpr(clang::ImplicitCastExpr *E);
 
@@ -79,16 +78,41 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
   std::set<clang::Expr *> visited_;
 };
 
+// Do a depth-first search to rewrite all pointers that need to be probed
+class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
+ public:
+  explicit ProbeVisitor(clang::Rewriter &rewriter);
+  bool VisitVarDecl(clang::VarDecl *Decl);
+  bool VisitCallExpr(clang::CallExpr *Call);
+  bool VisitBinaryOperator(clang::BinaryOperator *E);
+  bool VisitMemberExpr(clang::MemberExpr *E);
+  void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
+ private:
+  clang::Rewriter &rewriter_;
+  std::set<clang::Decl *> fn_visited_;
+  std::set<clang::Expr *> memb_visited_;
+  std::set<clang::Decl *> ptregs_;
+};
+
 // A helper class to the frontend action, walks the decls
 class BTypeConsumer : public clang::ASTConsumer {
  public:
   explicit BTypeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter,
                          std::vector<TableDesc> &tables);
-  bool HandleTopLevelDecl(clang::DeclGroupRef D) override;
+  bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
  private:
   BTypeVisitor visitor_;
 };
 
+// A helper class to the frontend action, walks the decls
+class ProbeConsumer : public clang::ASTConsumer {
+ public:
+  ProbeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter);
+  bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
+ private:
+  ProbeVisitor visitor_;
+};
+
 // Create a B program in 2 phases (everything else is normal C frontend):
 // 1. Catch the map declarations and open the fd's
 // 2. Capture the IR
index db4432a..0617854 100644 (file)
@@ -104,7 +104,6 @@ int pem(struct __sk_buff *skb) {
     return 1;
 }
 
-static int br_common(struct __sk_buff *skb, int which_br) __attribute__((always_inline));
 static int br_common(struct __sk_buff *skb, int which_br) {
     u8 *cursor = 0;
     u16 proto;
index a62f80d..e119104 100755 (executable)
@@ -172,5 +172,40 @@ int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) {
     return 0;
 }""")
 
+    def test_probe_read_helper(self):
+        b = BPF(text="""
+#include <linux/fs.h>
+static void print_file_name(struct file *file) {
+    if (!file) return;
+    const char *name = file->f_path.dentry->d_name.name;
+    bpf_trace_printk("%s\\n", name);
+}
+int trace_entry(struct pt_regs *ctx, struct file *file) {
+    print_file_name(file);
+    return 0;
+}
+""")
+        fn = b.load_func("trace_entry", BPF.KPROBE)
+
+    def test_probe_struct_assign(self):
+        b = BPF(text = """
+#include <uapi/linux/ptrace.h>
+struct args_t {
+    const char *filename;
+    int flags;
+    int mode;
+};
+int kprobe__sys_open(struct pt_regs *ctx, const char *filename,
+        int flags, int mode) {
+    struct args_t args = {};
+    args.filename = filename;
+    args.flags = flags;
+    args.mode = mode;
+    bpf_trace_printk("%s\\n", args.filename);
+    return 0;
+};
+""")
+
+
 if __name__ == "__main__":
     main()