Enable recursive scanf support for char[] as string
authorBrenden Blanco <bblanco@gmail.com>
Tue, 9 May 2017 20:52:42 +0000 (13:52 -0700)
committerBrenden Blanco <bblanco@gmail.com>
Mon, 15 May 2017 17:26:03 +0000 (10:26 -0700)
When a bpf table contains i8[] in one of its keys/leaves, use "" to
enclose the value, rather than [ %i %i %i ... ] format. This simplifies
the code that is generated for cases such as #1154, and brings it back
under ~200ms code generation, instead of >30s. This change of format is
not particularly robust (it doesn't handle escaping the doublequote
character itself), but it should make more sense for the common case,
such as tracing files and pathnames.

The test case included tests both the functionality of the format string
handling as well as the compile time, since test_clang already has an
implicit 10second timeout limit.

Fixes: #1154
Signed-off-by: Brenden Blanco <bblanco@gmail.com>
src/cc/bpf_module.cc
tests/python/test_clang.py

index 4cc8191de792f69f889bbdb37a62cd8d07d3e9ea..69da9cde03d574e2f781ace1527b416cdbdeb8ab 100644 (file)
@@ -153,24 +153,90 @@ static void debug_printf(Module *mod, IRBuilder<> &B, const string &fmt, vector<
   B.CreateCall(fprintf_fn, args);
 }
 
+static void finish_sscanf(IRBuilder<> &B, vector<Value *> *args, string *fmt,
+                          const map<string, Value *> &locals, bool exact_args) {
+  // fmt += "%n";
+  // int nread = 0;
+  // int n = sscanf(s, fmt, args..., &nread);
+  // if (n < 0) return -1;
+  // s = &s[nread];
+  Value *sptr = locals.at("sptr");
+  Value *nread = locals.at("nread");
+  Function *cur_fn = B.GetInsertBlock()->getParent();
+  Function *sscanf_fn = B.GetInsertBlock()->getModule()->getFunction("sscanf");
+  *fmt += "%n";
+  B.CreateStore(B.getInt32(0), nread);
+  GlobalVariable *fmt_gvar = B.CreateGlobalString(*fmt, "fmt");
+  (*args)[1] = B.CreateInBoundsGEP(fmt_gvar, {B.getInt64(0), B.getInt64(0)});
+  (*args)[0] = B.CreateLoad(sptr);
+  args->push_back(nread);
+  CallInst *call = B.CreateCall(sscanf_fn, *args);
+  call->setTailCall(true);
+
+  BasicBlock *label_true = BasicBlock::Create(B.getContext(), "", cur_fn);
+  BasicBlock *label_false = BasicBlock::Create(B.getContext(), "", cur_fn);
+
+  // exact_args means fail if don't consume exact number of "%" inputs
+  // exact_args is disabled for string parsing (empty case)
+  Value *cond = exact_args ? B.CreateICmpNE(call, B.getInt32(args->size() - 3))
+                           : B.CreateICmpSLT(call, B.getInt32(0));
+  B.CreateCondBr(cond, label_true, label_false);
+
+  B.SetInsertPoint(label_true);
+  B.CreateRet(B.getInt32(-1));
+
+  B.SetInsertPoint(label_false);
+  // s = &s[nread];
+  B.CreateStore(
+      B.CreateInBoundsGEP(B.CreateLoad(sptr), B.CreateLoad(nread, true)), sptr);
+
+  args->resize(2);
+  fmt->clear();
+}
+
 // recursive helper to capture the arguments
 static void parse_type(IRBuilder<> &B, vector<Value *> *args, string *fmt,
-                       Type *type, Value *out, bool is_writer) {
+                       Type *type, Value *out,
+                       const map<string, Value *> &locals, bool is_writer) {
   if (StructType *st = dyn_cast<StructType>(type)) {
     *fmt += "{ ";
     unsigned idx = 0;
     for (auto field : st->elements()) {
-      parse_type(B, args, fmt, field, B.CreateStructGEP(type, out, idx++), is_writer);
+      parse_type(B, args, fmt, field, B.CreateStructGEP(type, out, idx++),
+                 locals, is_writer);
       *fmt += " ";
     }
     *fmt += "}";
   } else if (ArrayType *at = dyn_cast<ArrayType>(type)) {
-    *fmt += "[ ";
-    for (size_t i = 0; i < at->getNumElements(); ++i) {
-      parse_type(B, args, fmt, at->getElementType(), B.CreateStructGEP(type, out, i), is_writer);
-      *fmt += " ";
+    if (at->getElementType() == B.getInt8Ty()) {
+      // treat i8[] as a char string instead of as an array of u8's
+      if (is_writer) {
+        *fmt += "\"%s\"";
+        args->push_back(out);
+      } else {
+        // Scan a single "" enclosed string. Passing multiple %[^"] arguments
+        // doesn't work because scanf stops parsing the string when an empty
+        // string is encountered, so here we individually call scanf and mask
+        // the empty string case. A scan failure (e.g. no enclosing "") should
+        // still return an error.
+        *fmt += "\"";
+        finish_sscanf(B, args, fmt, locals, true);
+
+        *fmt = "%[^\"]";
+        args->push_back(out);
+        finish_sscanf(B, args, fmt, locals, false);
+
+        *fmt += "\"";
+      }
+    } else {
+      *fmt += "[ ";
+      for (size_t i = 0; i < at->getNumElements(); ++i) {
+        parse_type(B, args, fmt, at->getElementType(),
+                   B.CreateStructGEP(type, out, i), locals, is_writer);
+        *fmt += " ";
+      }
+      *fmt += "]";
     }
-    *fmt += "]";
   } else if (isa<PointerType>(type)) {
     *fmt += "0xl";
     if (is_writer)
@@ -209,6 +275,16 @@ string BPFModule::make_reader(Module *mod, Type *type) {
 
   IRBuilder<> B(*ctx_);
 
+  FunctionType *sscanf_fn_type = FunctionType::get(
+      B.getInt32Ty(), {B.getInt8PtrTy(), B.getInt8PtrTy()}, /*isVarArg=*/true);
+  Function *sscanf_fn = mod->getFunction("sscanf");
+  if (!sscanf_fn) {
+    sscanf_fn = Function::Create(sscanf_fn_type, GlobalValue::ExternalLinkage,
+                                 "sscanf", mod);
+    sscanf_fn->setCallingConv(CallingConv::C);
+    sscanf_fn->addFnAttr(Attribute::NoUnwind);
+  }
+
   string name = "reader" + std::to_string(readers_.size());
   vector<Type *> fn_args({B.getInt8PtrTy(), PointerType::getUnqual(type)});
   FunctionType *fn_type = FunctionType::get(B.getInt32Ty(), fn_args, /*isVarArg=*/false);
@@ -223,40 +299,21 @@ string BPFModule::make_reader(Module *mod, Type *type) {
   arg_out->setName("out");
 
   BasicBlock *label_entry = BasicBlock::Create(*ctx_, "entry", fn);
-  BasicBlock *label_exit = BasicBlock::Create(*ctx_, "exit", fn);
   B.SetInsertPoint(label_entry);
 
-  vector<Value *> args({arg_in, nullptr});
+  Value *nread = B.CreateAlloca(B.getInt32Ty());
+  Value *sptr = B.CreateAlloca(B.getInt8PtrTy());
+  map<string, Value *> locals{{"nread", nread}, {"sptr", sptr}};
+  B.CreateStore(arg_in, sptr);
+  vector<Value *> args({nullptr, nullptr});
   string fmt;
-  parse_type(B, &args, &fmt, type, arg_out, false);
-
-  GlobalVariable *fmt_gvar = B.CreateGlobalString(fmt, "fmt");
-
-  args[1] = B.CreateInBoundsGEP(fmt_gvar, vector<Value *>({B.getInt64(0), B.getInt64(0)}));
+  parse_type(B, &args, &fmt, type, arg_out, locals, false);
 
   if (0)
     debug_printf(mod, B, "%p %p\n", vector<Value *>({arg_in, arg_out}));
 
-  vector<Type *> sscanf_fn_args({B.getInt8PtrTy(), B.getInt8PtrTy()});
-  FunctionType *sscanf_fn_type = FunctionType::get(B.getInt32Ty(), sscanf_fn_args, /*isVarArg=*/true);
-  Function *sscanf_fn = mod->getFunction("sscanf");
-  if (!sscanf_fn)
-    sscanf_fn = Function::Create(sscanf_fn_type, GlobalValue::ExternalLinkage, "sscanf", mod);
-  sscanf_fn->setCallingConv(CallingConv::C);
-  sscanf_fn->addFnAttr(Attribute::NoUnwind);
-
-  CallInst *call = B.CreateCall(sscanf_fn, args);
-  call->setTailCall(true);
-
-  BasicBlock *label_then = BasicBlock::Create(*ctx_, "then", fn);
-
-  Value *is_neq = B.CreateICmpNE(call, B.getInt32(args.size() - 2));
-  B.CreateCondBr(is_neq, label_then, label_exit);
-
-  B.SetInsertPoint(label_then);
-  B.CreateRet(B.getInt32(-1));
+  finish_sscanf(B, &args, &fmt, locals, true);
 
-  B.SetInsertPoint(label_exit);
   B.CreateRet(B.getInt32(0));
 
   readers_[type] = name;
@@ -293,9 +350,12 @@ string BPFModule::make_writer(Module *mod, Type *type) {
   BasicBlock *label_entry = BasicBlock::Create(*ctx_, "entry", fn);
   B.SetInsertPoint(label_entry);
 
+  map<string, Value *> locals{
+      {"nread", B.CreateAlloca(B.getInt64Ty())},
+  };
   vector<Value *> args({arg_out, B.CreateZExt(arg_len, B.getInt64Ty()), nullptr});
   string fmt;
-  parse_type(B, &args, &fmt, type, arg_in, true);
+  parse_type(B, &args, &fmt, type, arg_in, locals, true);
 
   GlobalVariable *fmt_gvar = B.CreateGlobalString(fmt, "fmt");
 
index 1b9bcbfc4b9f84349a7849c92568da6b10e6f51f..5d9f036815df017f31bf02035a864ada7e1cc065 100755 (executable)
@@ -121,6 +121,38 @@ BPF_HASH(stats, int, struct { u32 a[3]; u32 b; }, 10);
         self.assertEqual(l.a[2], 3)
         self.assertEqual(l.b, 4)
 
+    def test_sscanf_string(self):
+        text = """
+struct Symbol {
+    char name[128];
+    char path[128];
+};
+struct Event {
+    uint32_t pid;
+    uint32_t tid;
+    struct Symbol stack[64];
+};
+BPF_TABLE("array", int, struct Event, comms, 1);
+"""
+        b = BPF(text=text)
+        t = b.get_table("comms")
+        s1 = t.leaf_sprintf(t[0])
+        fill = b' { "" "" }' * 63
+        self.assertEqual(s1, b'{ 0x0 0x0 [ { "" "" }%s ] }' % fill)
+        l = t.Leaf(1, 2)
+        name = b"libxyz"
+        path = b"/usr/lib/libxyz.so"
+        l.stack[0].name = name
+        l.stack[0].path = path
+        s2 = t.leaf_sprintf(l)
+        self.assertEqual(s2,
+                b'{ 0x1 0x2 [ { "%s" "%s" }%s ] }' % (name, path, fill))
+        l = t.leaf_scanf(s2)
+        self.assertEqual(l.pid, 1)
+        self.assertEqual(l.tid, 2)
+        self.assertEqual(l.stack[0].name, name)
+        self.assertEqual(l.stack[0].path, path)
+
     def test_iosnoop(self):
         text = """
 #include <linux/blkdev.h>