Add LLVM IR based sscanf routine, to be run by the JIT
authorBrenden Blanco <bblanco@plumgrid.com>
Sat, 8 Aug 2015 04:04:35 +0000 (21:04 -0700)
committerBrenden Blanco <bblanco@plumgrid.com>
Sat, 8 Aug 2015 04:04:35 +0000 (21:04 -0700)
After the modules have been created, create a helper function for each
table leaf/key type.

Signed-off-by: Brenden Blanco <bblanco@plumgrid.com>
CMakeLists.txt
src/cc/bpf_module.cc
src/cc/bpf_module.h
src/cc/frontends/b/codegen_llvm.cc
tests/cc/test_clang.py

index a8bb4f6..4813589 100644 (file)
@@ -3,7 +3,7 @@
 cmake_minimum_required(VERSION 2.8.7)
 
 project(bcc)
-set(CMAKE_BUILD_TYPE Release)
+set(CMAKE_BUILD_TYPE Debug)
 
 enable_testing()
 
index 95e6bbd..5f615b5 100644 (file)
@@ -29,6 +29,7 @@
 #include <llvm/ExecutionEngine/MCJIT.h>
 #include <llvm/ExecutionEngine/SectionMemoryManager.h>
 #include <llvm/IRReader/IRReader.h>
+#include <llvm/IR/IRBuilder.h>
 #include <llvm/IR/IRPrintingPasses.h>
 #include <llvm/IR/LegacyPassManager.h>
 #include <llvm/IR/LLVMContext.h>
@@ -110,22 +111,86 @@ BPFModule::~BPFModule() {
   ctx_.reset();
 }
 
-unique_ptr<ExecutionEngine> BPFModule::make_reader(LLVMContext &ctx) {
-  auto m = make_unique<Module>("scanf_reader", ctx);
-  Module *mod = &*m;
-  auto structs = mod->getIdentifiedStructTypes();
-  for (auto s : structs) {
-    fprintf(stderr, "struct %s\n", s->getName().str().c_str());
+// recursive helper to capture the arguments
+void parse_type(IRBuilder<> &B, vector<Value *> *args, string *fmt, Type *type, Value *out) {
+  if (StructType *st = dyn_cast<StructType>(type)) {
+    *fmt += "{ ";
+    unsigned idx = 0;
+    for (auto field : st->elements()) {
+      parse_type(B, args, fmt, field, B.CreateStructGEP(type, out, idx++));
+      *fmt += " ";
+    }
+    *fmt += "}";
+  } else if (dyn_cast<IntegerType>(type)) {
+    *fmt += "%lli";
+    args->push_back(out);
   }
+}
+
+int BPFModule::make_reader(Module *mod, Type *type) {
+  if (readers_.find(type) != readers_.end()) return 0;
+
+  // int read(const char *in, Type *out) {
+  //   int n = sscanf(in, "{ %i ... }", &out->field1, ...);
+  //   if (n != num_fields) return -1;
+  //   return 0;
+  // }
+
+  IRBuilder<> B(*ctx_);
+
+  vector<Type *> fn_args({B.getInt8PtrTy(), PointerType::getUnqual(type)});
+  FunctionType *fn_type = FunctionType::get(B.getInt32Ty(), fn_args, /*isVarArg=*/false);
+  Function *fn = Function::Create(fn_type, GlobalValue::ExternalLinkage,
+                                  "reader" + std::to_string(readers_.size()), mod);
+  auto arg_it = fn->arg_begin();
+  Argument *arg_in = arg_it++;
+  arg_in->setName("in");
+  Argument *arg_out = arg_it++;
+  arg_out->setName("out");
+
+  BasicBlock *label_entry = BasicBlock::Create(*ctx_, "entry", fn);
+  BasicBlock *label_exit = BasicBlock::Create(*ctx_, "exit", fn);
+  B.SetInsertPoint(label_entry);
+
+  vector<Value *> args;
+  string fmt;
+  parse_type(B, &args, &fmt, type, arg_out);
+
+  GlobalVariable *fmt_gvar = B.CreateGlobalString(fmt, "fmt");
+
+  args.insert(args.begin(), B.CreateInBoundsGEP(fmt_gvar, vector<Value *>({B.getInt64(0), B.getInt64(0)})));
+  args.insert(args.begin(), arg_in);
+
+  vector<Type *> sscanf_fn_args({B.getInt8PtrTy(), B.getInt8PtrTy()});
+  FunctionType *sscanf_fn_type = FunctionType::get(B.getInt32Ty(), sscanf_fn_args, /*isVarArg=*/true);
+  Function *sscanf_fn = mod->getFunction("__isoc99_sscanf");
+  if (!sscanf_fn)
+    sscanf_fn = Function::Create(sscanf_fn_type, GlobalValue::ExternalLinkage, "__isoc99_sscanf", mod);
+
+  CallInst *call = B.CreateCall(sscanf_fn, args);
+  BasicBlock *label_then = BasicBlock::Create(*ctx_, "then", fn);
+
+  Value *is_neq = B.CreateICmpNE(call, B.getInt32(args.size() - 2));
+  B.CreateCondBr(is_neq, label_then, label_exit);
+
+  B.SetInsertPoint(label_then);
+  B.CreateRet(B.getInt32(-1));
+
+  B.SetInsertPoint(label_exit);
+  B.CreateRet(B.getInt32(0));
+
+  readers_[type] = fn;
+  return 0;
+}
+
+unique_ptr<ExecutionEngine> BPFModule::finalize_reader(unique_ptr<Module> m) {
+  Module *mod = &*m;
 
-  dump_ir(*mod);
   run_pass_manager(*mod);
 
   string err;
-  map<string, tuple<uint8_t *, uintptr_t>> sections;
   EngineBuilder builder(move(m));
   builder.setErrorStr(&err);
-  builder.setMCJITMemoryManager(make_unique<MyMemoryManager>(&sections));
   builder.setUseOrcMCJITReplacement(true);
   auto engine = unique_ptr<ExecutionEngine>(builder.create());
   if (!engine)
@@ -157,34 +222,36 @@ int BPFModule::annotate() {
   for (auto fn = mod_->getFunctionList().begin(); fn != mod_->getFunctionList().end(); ++fn)
     fn->addFnAttr(Attribute::AlwaysInline);
 
+  // separate module to hold the reader functions
+  auto m = make_unique<Module>("sscanf", *ctx_);
+
   for (auto table : *tables_) {
     table_names_.push_back(table.first);
     GlobalValue *gvar = mod_->getNamedValue(table.first);
     if (!gvar) continue;
-    llvm::errs() << "table " << gvar->getName() << "\n";
-  }
-  //for (auto s : mod_->getIdentifiedStructTypes()) {
-  //  llvm::errs() << "struct " << s->getName() << "\n";
-  //  for (auto e : s->elements()) {
-  //    llvm::errs() << " ";
-  //    e->print(llvm::errs());
-  //    llvm::errs() << "\n";
-  //  }
-  //}
-
-  if (1) {
-    auto engine = make_reader(*ctx_);
-    if (engine)
-      engine->finalizeObject();
+    if (PointerType *pt = dyn_cast<PointerType>(gvar->getType())) {
+      if (StructType *st = dyn_cast<StructType>(pt->getElementType())) {
+        if (st->getNumElements() < 2) continue;
+        Type *key_type = st->elements()[0];
+        Type *leaf_type = st->elements()[1];
+        if (int rc = make_reader(&*m, key_type))
+          return rc;
+        if (int rc = make_reader(&*m, leaf_type))
+          return rc;
+      }
+    }
   }
 
+  auto engine = finalize_reader(move(m));
+  if (engine)
+    engine->finalizeObject();
 
   return 0;
 }
 
 void BPFModule::dump_ir(Module &mod) {
   legacy::PassManager PM;
-  PM.add(createPrintModulePass(outs()));
+  PM.add(createPrintModulePass(errs()));
   PM.run(mod);
 }
 
index 8f0c5e4..8f57dd0 100644 (file)
 
 namespace llvm {
 class ExecutionEngine;
+class Function;
 class LLVMContext;
 class Module;
+class Type;
 }
 
 namespace ebpf {
@@ -40,7 +42,8 @@ class BPFModule {
   int parse(llvm::Module *mod);
   int finalize();
   int annotate();
-  std::unique_ptr<llvm::ExecutionEngine> make_reader(llvm::LLVMContext &ctx);
+  std::unique_ptr<llvm::ExecutionEngine> finalize_reader(std::unique_ptr<llvm::Module> mod);
+  int make_reader(llvm::Module *mod, llvm::Type *type);
   void dump_ir(llvm::Module &mod);
   int load_file_module(std::unique_ptr<llvm::Module> *mod, const std::string &file, bool in_memory);
   int load_includes(const std::string &tmpfile);
@@ -86,6 +89,7 @@ class BPFModule {
   std::unique_ptr<std::map<std::string, TableDesc>> tables_;
   std::vector<std::string> table_names_;
   std::vector<std::string> function_names_;
+  std::map<llvm::Type *, llvm::Function *> readers_;
 };
 
 }  // namespace ebpf
index e6a7102..26cc063 100644 (file)
@@ -1087,20 +1087,17 @@ StatusTuple CodegenLLVM::visit_table_decl_stmt_node(TableDeclStmtNode *n) {
     else
       return mkstatus_(n, "Table type %s not implemented", n->type_id()->name_.c_str());
 
+    StructType *key_stype, *leaf_stype;
+    TRY2(lookup_struct_type(n->key_type_, &key_stype));
+    TRY2(lookup_struct_type(n->leaf_type_, &leaf_stype));
     StructType *decl_struct = mod_->getTypeByName("_struct." + n->id_->name_);
     if (!decl_struct)
       decl_struct = StructType::create(ctx(), "_struct." + n->id_->name_);
     if (decl_struct->isOpaque())
-      decl_struct->setBody(std::vector<Type *>({Type::getInt32Ty(ctx()), Type::getInt32Ty(ctx()),
-                                                Type::getInt32Ty(ctx()), Type::getInt32Ty(ctx())}),
-                           /*isPacked=*/false);
+      decl_struct->setBody(vector<Type *>({key_stype, leaf_stype}), /*isPacked=*/false);
     GlobalVariable *decl_gvar = new GlobalVariable(*mod_, decl_struct, false,
                                                    GlobalValue::ExternalLinkage, 0, n->id_->name_);
     decl_gvar->setSection("maps");
-    vector<Constant *> struct_init = { B.getInt32(map_type), B.getInt32(key->bit_width_ / 8),
-                                       B.getInt32(leaf->bit_width_ / 8), B.getInt32(n->size_)};
-    Constant *const_struct = ConstantStruct::get(decl_struct, struct_init);
-    decl_gvar->setInitializer(const_struct);
     tables_[n] = decl_gvar;
 
     int map_fd = bpf_create_map(map_type, key->bit_width_ / 8, leaf->bit_width_ / 8, n->size_);
@@ -1170,7 +1167,7 @@ StatusTuple CodegenLLVM::visit_func_decl_stmt_node(FuncDeclStmtNode *n) {
 
   BasicBlock *label_entry = BasicBlock::Create(ctx(), "entry", fn);
   B.SetInsertPoint(label_entry);
-  string scoped_entry_label = std::to_string((uintptr_t)fn) + "::entry";
+  string scoped_entry_label = to_string((uintptr_t)fn) + "::entry";
   labels_[scoped_entry_label] = label_entry;
   BasicBlock *label_return = resolve_label("DONE");
   retval_ = new AllocaInst(fn->getReturnType(), "ret", label_entry);
@@ -1274,7 +1271,7 @@ StatusTuple CodegenLLVM::print_header() {
   return mkstatus(0);
 }
 
-int CodegenLLVM::get_table_fd(const std::string &name) const {
+int CodegenLLVM::get_table_fd(const string &name) const {
   TableDeclStmtNode *table = scopes_->top_table()->lookup(name);
   if (!table)
     return -1;
@@ -1302,7 +1299,7 @@ Value * CodegenLLVM::pop_expr() {
 
 BasicBlock * CodegenLLVM::resolve_label(const string &label) {
   Function *parent = B.GetInsertBlock()->getParent();
-  string scoped_label = std::to_string((uintptr_t)parent) + "::" + label;
+  string scoped_label = to_string((uintptr_t)parent) + "::" + label;
   auto it = labels_.find(scoped_label);
   if (it != labels_.end()) return it->second;
   BasicBlock *label_new = BasicBlock::Create(ctx(), label, parent);
index 9c0bfbd..812c349 100755 (executable)
@@ -46,9 +46,9 @@ int count_foo(struct pt_regs *ctx, unsigned long a, unsigned long b) {
         b = BPF(text=text, debug=0)
         fn = b.load_func("count_foo", BPF.KPROBE)
 
-    def test_scanf(self):
+    def test_sscanf(self):
         text = """
-BPF_TABLE("hash", int, struct { int a; int b; }, stats, 10);
+BPF_TABLE("hash", int, struct { u64 a; u64 b; u64 c:31; u64 d:33; struct { u32 a; u32 b; } s; }, stats, 10);
 int foo(void *ctx) {
     return 0;
 }