Allow the use of custom keys in BPF_HASH_OF_MAPS (#3500)
authormasibw <masi19bw@gmail.com>
Tue, 22 Jun 2021 06:18:23 +0000 (15:18 +0900)
committerGitHub <noreply@github.com>
Tue, 22 Jun 2021 06:18:23 +0000 (23:18 -0700)
 - Allow the use of custom keys in BPF_HASH_OF_MAPS
 - Add both python and C++ tests

src/cc/api/BPF.cc
src/cc/api/BPF.h
src/cc/api/BPFTable.cc
src/cc/api/BPFTable.h
src/cc/export/helpers.h
tests/cc/test_map_in_map.cc
tests/python/CMakeLists.txt
tests/python/test_map_in_map.py

index 04ef3bf7e10ab8994c832121b9cfc51242eac47f..87d4a331da32e80b9fe84a41fb4155eb13b3ae39 100644 (file)
@@ -833,13 +833,6 @@ BPFStackBuildIdTable BPF::get_stackbuildid_table(const std::string &name, bool u
   return BPFStackBuildIdTable({}, use_debug_file, check_debug_file_crc, get_bsymcache());
 }
 
-BPFMapInMapTable BPF::get_map_in_map_table(const std::string& name) {
-  TableStorage::iterator it;
-  if (bpf_module_->table_storage().Find(Path({bpf_module_->id(), name}), it))
-    return BPFMapInMapTable(it->second);
-  return BPFMapInMapTable({});
-}
-
 BPFSockmapTable BPF::get_sockmap_table(const std::string& name) {
   TableStorage::iterator it;
   if (bpf_module_->table_storage().Find(Path({bpf_module_->id(), name}), it))
index d6f3b2a97447255a5a8599decfabeed8d7d4d84f..c266828e52040fe5cbe3e3b6723dee959b36ad7c 100644 (file)
@@ -211,8 +211,13 @@ class BPF {
   BPFStackBuildIdTable get_stackbuildid_table(const std::string &name,
                                               bool use_debug_file = true,
                                               bool check_debug_file_crc = true);
-
-  BPFMapInMapTable get_map_in_map_table(const std::string& name);
+  template <class KeyType>
+  BPFMapInMapTable<KeyType> get_map_in_map_table(const std::string& name){
+      TableStorage::iterator it;
+      if (bpf_module_->table_storage().Find(Path({bpf_module_->id(), name}), it))
+        return BPFMapInMapTable<KeyType>(it->second);
+      return BPFMapInMapTable<KeyType>({});
+  }
 
   bool add_module(std::string module);
 
index d1050512521f36aa4d1a36f3eb739f35873c4d5d..f27e712548e8c30e631dfc2c88417ac9e738f6fc 100644 (file)
@@ -689,27 +689,6 @@ StatusTuple BPFXskmapTable::remove_value(const int& index) {
     return StatusTuple::OK();
 }
 
-BPFMapInMapTable::BPFMapInMapTable(const TableDesc& desc)
-    : BPFTableBase<int, int>(desc) {
-    if(desc.type != BPF_MAP_TYPE_ARRAY_OF_MAPS &&
-       desc.type != BPF_MAP_TYPE_HASH_OF_MAPS)
-      throw std::invalid_argument("Table '" + desc.name +
-                                  "' is not a map-in-map table");
-}
-
-StatusTuple BPFMapInMapTable::update_value(const int& index,
-                                           const int& inner_map_fd) {
-    if (!this->update(const_cast<int*>(&index), const_cast<int*>(&inner_map_fd)))
-      return StatusTuple(-1, "Error updating value: %s", std::strerror(errno));
-    return StatusTuple::OK();
-}
-
-StatusTuple BPFMapInMapTable::remove_value(const int& index) {
-    if (!this->remove(const_cast<int*>(&index)))
-      return StatusTuple(-1, "Error removing value: %s", std::strerror(errno));
-    return StatusTuple::OK();
-}
-
 BPFSockmapTable::BPFSockmapTable(const TableDesc& desc)
     : BPFTableBase<int, int>(desc) {
     if(desc.type != BPF_MAP_TYPE_SOCKMAP)
index d63f3c5db993ef991ee249ce8a5ac9fbfb959f47..8786a3f9a80bdf5d0bc3eb0022fa149457224a93 100644 (file)
@@ -479,12 +479,26 @@ public:
   StatusTuple remove_value(const int& index);
 };
 
-class BPFMapInMapTable : public BPFTableBase<int, int> {
-public:
-  BPFMapInMapTable(const TableDesc& desc);
-
-  StatusTuple update_value(const int& index, const int& inner_map_fd);
-  StatusTuple remove_value(const int& index);
+template <class KeyType>
+class BPFMapInMapTable : public BPFTableBase<KeyType, int> {
+ public:
+  BPFMapInMapTable(const TableDesc& desc) : BPFTableBase<KeyType, int>(desc) {
+    if (desc.type != BPF_MAP_TYPE_ARRAY_OF_MAPS &&
+        desc.type != BPF_MAP_TYPE_HASH_OF_MAPS)
+      throw std::invalid_argument("Table '" + desc.name +
+                                  "' is not a map-in-map table");
+  }
+  virtual StatusTuple update_value(const KeyType& key, const int& inner_map_fd) {
+    if (!this->update(const_cast<KeyType*>(&key),
+                      const_cast<int*>(&inner_map_fd)))
+      return StatusTuple(-1, "Error updating value: %s", std::strerror(errno));
+    return StatusTuple::OK();
+  }
+  virtual StatusTuple remove_value(const KeyType& key) {
+    if (!this->remove(const_cast<KeyType*>(&key)))
+      return StatusTuple(-1, "Error removing value: %s", std::strerror(errno));
+    return StatusTuple::OK();
+  }
 };
 
 class BPFSockmapTable : public BPFTableBase<int, int> {
index 12072b06a975a8debe9a9b563e0192691055c477..a3283d2e2ede6a05e45c6e6cd941d4b2514e9350 100644 (file)
@@ -384,8 +384,17 @@ struct _name##_table_t _name = { .max_entries = (_max_entries) }
 #define BPF_ARRAY_OF_MAPS(_name, _inner_map_name, _max_entries) \
   BPF_TABLE("array_of_maps$" _inner_map_name, int, int, _name, _max_entries)
 
-#define BPF_HASH_OF_MAPS(_name, _inner_map_name, _max_entries) \
-  BPF_TABLE("hash_of_maps$" _inner_map_name, int, int, _name, _max_entries)
+#define BPF_HASH_OF_MAPS2(_name, _inner_map_name) \
+  BPF_TABLE("hash_of_maps$" _inner_map_name, int, int, _name, 10240)
+#define BPF_HASH_OF_MAPS3(_name, _key_type, _inner_map_name) \
+  BPF_TABLE("hash_of_maps$" _inner_map_name, _key_type, int, _name, 10240)
+#define BPF_HASH_OF_MAPS4(_name, _key_type, _inner_map_name, _max_entries) \
+  BPF_TABLE("hash_of_maps$" _inner_map_name, _key_type, int, _name, _max_entries)
+
+#define BPF_HASH_OF_MAPSX(_name, _2, _3, _4, NAME, ...) NAME
+
+#define BPF_HASH_OF_MAPS(...) \
+  BPF_HASH_OF_MAPSX(__VA_ARGS__, BPF_HASH_OF_MAPS4, BPF_HASH_OF_MAPS3, BPF_HASH_OF_MAPS2)(__VA_ARGS__)
 
 #define BPF_SK_STORAGE(_name, _leaf_type) \
 struct _name##_table_t { \
index f8c1a0b6683763dbc0f841557b289bff6ac2795f..7a383de92a916616179c2e0757ab46f53510908e 100644 (file)
@@ -30,7 +30,7 @@ TEST_CASE("test hash of maps", "[hash_of_maps]") {
       BPF_ARRAY(ex1, int, 1024);
       BPF_ARRAY(ex2, int, 1024);
       BPF_ARRAY(ex3, u64, 1024);
-      BPF_HASH_OF_MAPS(maps_hash, "ex1", 10);
+      BPF_HASH_OF_MAPS(maps_hash, int, "ex1", 10);
 
       int syscall__getuid(void *ctx) {
          int key = 0, data, *val, cntl_val;
@@ -63,7 +63,7 @@ TEST_CASE("test hash of maps", "[hash_of_maps]") {
     res = bpf.init(BPF_PROGRAM);
     REQUIRE(res.code() == 0);
 
-    auto t = bpf.get_map_in_map_table("maps_hash");
+    auto t = bpf.get_map_in_map_table<int>("maps_hash");
     auto ex1_table = bpf.get_array_table<int>("ex1");
     auto ex2_table = bpf.get_array_table<int>("ex2");
     auto ex3_table = bpf.get_array_table<unsigned long long>("ex3");
@@ -115,6 +115,119 @@ TEST_CASE("test hash of maps", "[hash_of_maps]") {
   }
 }
 
+TEST_CASE("test hash of maps using custom key", "[hash_of_maps_custom_key]") {
+  {
+    const std::string BPF_PROGRAM = R"(
+        struct custom_key {
+          int value_1;
+          int value_2;
+        };
+
+        BPF_ARRAY(cntl, int, 1);
+        BPF_TABLE("hash", int, int, ex1, 1024);
+        BPF_TABLE("hash", int, int, ex2, 1024);
+        BPF_HASH_OF_MAPS(maps_hash, struct custom_key, "ex1", 10);
+
+        int syscall__getuid(void *ctx) {
+          struct custom_key hash_key = {1, 0};
+          int key = 0, data, *val, cntl_val;
+          void *inner_map;
+
+          val = cntl.lookup(&key);
+          if (!val || *val == 0)
+            return 0;
+
+          hash_key.value_2 = *val;
+          inner_map = maps_hash.lookup(&hash_key);
+          if (!inner_map)
+            return 0;
+
+          val = bpf_map_lookup_elem(inner_map, &key);
+          if (!val) {
+            data = 1;
+            bpf_map_update_elem(inner_map, &key, &data, 0);
+          } else {
+            data = 1 + *val;
+            bpf_map_update_elem(inner_map, &key, &data, 0);
+          }
+
+          return 0;
+        }
+    )";
+
+    struct custom_key {
+      int value_1;
+      int value_2;
+    };
+
+    ebpf::BPF bpf;
+    ebpf::StatusTuple res(0);
+    res = bpf.init(BPF_PROGRAM);
+    REQUIRE(res.code() == 0);
+
+    auto t = bpf.get_map_in_map_table<struct custom_key>("maps_hash");
+    auto ex1_table = bpf.get_hash_table<int, int>("ex1");
+    auto ex2_table = bpf.get_hash_table<int, int>("ex2");
+    auto cntl_table = bpf.get_array_table<int>("cntl");
+    int ex1_fd = ex1_table.get_fd();
+    int ex2_fd = ex2_table.get_fd();
+
+    // test effectiveness of map-in-map
+    std::string getuid_fnname = bpf.get_syscall_fnname("getuid");
+    res = bpf.attach_kprobe(getuid_fnname, "syscall__getuid");
+    REQUIRE(res.code() == 0);
+
+    struct custom_key hash_key = {1, 1};
+
+    res = t.update_value(hash_key, ex1_fd);
+    REQUIRE(res.code() == 0);
+
+    struct custom_key hash_key2 = {1, 2};
+    res = t.update_value(hash_key2, ex2_fd);
+    REQUIRE(res.code() == 0);
+
+    int key = 0, value = 0, value2 = 0;
+
+    // Can't get value when value didn't set.
+    res = ex1_table.get_value(key, value);
+    REQUIRE(res.code() != 0);
+    REQUIRE(value == 0);
+
+    // Call syscall__getuid, then set value to ex1_table
+    res = cntl_table.update_value(key, 1);
+    REQUIRE(res.code() == 0);
+    REQUIRE(getuid() >= 0);
+
+    // Now we can get value from ex1_table
+    res = ex1_table.get_value(key, value);
+    REQUIRE(res.code() == 0);
+    REQUIRE(value >= 1);
+
+    // Can't get value when value didn't set.
+    res = ex2_table.get_value(key, value2);
+    REQUIRE(res.code() != 0);
+    REQUIRE(value2 == 0);
+
+    // Call syscall__getuid, then set value to ex2_table
+    res = cntl_table.update_value(key, 2);
+    REQUIRE(res.code() == 0);
+    REQUIRE(getuid() >= 0);
+
+    // Now we can get value from ex2_table
+    res = ex2_table.get_value(key, value2);
+    REQUIRE(res.code() == 0);
+    REQUIRE(value > 0);
+
+    res = bpf.detach_kprobe(getuid_fnname);
+    REQUIRE(res.code() == 0);
+
+    res = t.remove_value(hash_key);
+    REQUIRE(res.code() == 0);
+    res = t.remove_value(hash_key2);
+    REQUIRE(res.code() == 0);
+  }
+}
+
 TEST_CASE("test array of maps", "[array_of_maps]") {
   {
     const std::string BPF_PROGRAM = R"(
@@ -158,7 +271,7 @@ TEST_CASE("test array of maps", "[array_of_maps]") {
     res = bpf.init(BPF_PROGRAM);
     REQUIRE(res.code() == 0);
 
-    auto t = bpf.get_map_in_map_table("maps_array");
+    auto t = bpf.get_map_in_map_table<int>("maps_array");
     auto ex1_table = bpf.get_hash_table<int, int>("ex1");
     auto ex2_table = bpf.get_hash_table<int, int>("ex2");
     auto ex3_table =
index 7e60413ba9cbae7c2604d49c3457dedee615e2f9..3cd96031ceb1de7d763dc92882965b0742236519 100644 (file)
@@ -95,3 +95,5 @@ add_test(NAME py_queuestack WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
   COMMAND ${TEST_WRAPPER} py_queuestack sudo ${CMAKE_CURRENT_SOURCE_DIR}/test_queuestack.py)
 add_test(NAME py_test_map_batch_ops WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
   COMMAND ${TEST_WRAPPER} py_test_map_batch_ops sudo ${CMAKE_CURRENT_SOURCE_DIR}/test_map_batch_ops.py)
+add_test(NAME py_test_map_in_map WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+  COMMAND ${TEST_WRAPPER} py_test_map_in_map sudo ${CMAKE_CURRENT_SOURCE_DIR}/test_map_in_map.py)
index 0f5960431edd3f21032f7a9aca9b5fbfb07298eb..bd909d844520093b145bf35f33b91f28eb1f02ee 100755 (executable)
@@ -12,6 +12,13 @@ from unittest import main, skipUnless, TestCase
 import ctypes as ct
 import os
 
+
+class CustomKey(ct.Structure):
+  _fields_ = [
+    ("value_1", ct.c_int),
+    ("value_2", ct.c_int)
+  ]
+
 def kernel_version_ge(major, minor):
     # True if running kernel is >= X.Y
     version = distutils.version.LooseVersion(os.uname()[2]).version
@@ -30,7 +37,7 @@ class TestUDST(TestCase):
       BPF_ARRAY(cntl, int, 1);
       BPF_TABLE("hash", int, int, ex1, 1024);
       BPF_TABLE("hash", int, int, ex2, 1024);
-      BPF_HASH_OF_MAPS(maps_hash, "ex1", 10);
+      BPF_HASH_OF_MAPS(maps_hash, int, "ex1", 10);
 
       int syscall__getuid(void *ctx) {
          int key = 0, data, *val, cntl_val;
@@ -77,7 +84,7 @@ class TestUDST(TestCase):
 
         cntl_map[0] = ct.c_int(1)
         os.getuid()
-        assert(ex1_map[ct.c_int(0)] >= 1)
+        assert(ex1_map[ct.c_int(0)].value >= 1)
 
         try:
           ex2_map[ct.c_int(0)]
@@ -87,12 +94,85 @@ class TestUDST(TestCase):
 
         cntl_map[0] = ct.c_int(2)
         os.getuid()
-        assert(ex2_map[ct.c_int(0)] >= 1)
+        assert(ex2_map[ct.c_int(0)].value >= 1)
 
         b.detach_kprobe(event=syscall_fnname)
         del hash_maps[ct.c_int(1)]
         del hash_maps[ct.c_int(2)]
 
+    def test_hash_table_custom_key(self):
+        bpf_text = """
+        struct custom_key {
+          int value_1;
+          int value_2;
+        };
+
+        BPF_ARRAY(cntl, int, 1);
+        BPF_TABLE("hash", int, int, ex1, 1024);
+        BPF_TABLE("hash", int, int, ex2, 1024);
+        BPF_HASH_OF_MAPS(maps_hash, struct custom_key, "ex1", 10);
+
+        int syscall__getuid(void *ctx) {
+          struct custom_key hash_key = {1, 0};
+          int key = 0, data, *val, cntl_val;
+          void *inner_map;
+
+          val = cntl.lookup(&key);
+          if (!val || *val == 0)
+            return 0;
+
+          hash_key.value_2 = *val;
+          inner_map = maps_hash.lookup(&hash_key);
+          if (!inner_map)
+            return 0;
+
+          val = bpf_map_lookup_elem(inner_map, &key);
+          if (!val) {
+            data = 1;
+            bpf_map_update_elem(inner_map, &key, &data, 0);
+          } else {
+            data = 1 + *val;
+            bpf_map_update_elem(inner_map, &key, &data, 0);
+          }
+
+          return 0;
+        }
+"""
+        b = BPF(text=bpf_text)
+        cntl_map = b.get_table("cntl")
+        ex1_map = b.get_table("ex1")
+        ex2_map = b.get_table("ex2")
+        hash_maps = b.get_table("maps_hash")
+
+        hash_maps[CustomKey(1, 1)] = ct.c_int(ex1_map.get_fd())
+        hash_maps[CustomKey(1, 2)] = ct.c_int(ex2_map.get_fd())
+        syscall_fnname = b.get_syscall_fnname("getuid")
+        b.attach_kprobe(event=syscall_fnname, fn_name="syscall__getuid")
+
+        try:
+          ex1_map[ct.c_int(0)]
+          raise Exception("Unexpected success for ex1_map[0]")
+        except KeyError:
+          pass
+
+        cntl_map[0] = ct.c_int(1)
+        os.getuid()
+        assert(ex1_map[ct.c_int(0)].value >= 1)
+
+        try:
+          ex2_map[ct.c_int(0)]
+          raise Exception("Unexpected success for ex2_map[0]")
+        except KeyError:
+          pass
+
+        cntl_map[0] = ct.c_int(2)
+        os.getuid()
+        assert(ex2_map[ct.c_int(0)].value >= 1)
+
+        b.detach_kprobe(event=syscall_fnname)
+        del hash_maps[CustomKey(1, 1)]
+        del hash_maps[CustomKey(1, 2)]
+
     def test_array_table(self):
         bpf_text = """
       BPF_ARRAY(cntl, int, 1);
@@ -136,11 +216,11 @@ class TestUDST(TestCase):
 
         cntl_map[0] = ct.c_int(1)
         os.getuid()
-        assert(ex1_map[ct.c_int(0)] >= 1)
+        assert(ex1_map[ct.c_int(0)].value >= 1)
 
         cntl_map[0] = ct.c_int(2)
         os.getuid()
-        assert(ex2_map[ct.c_int(0)] >= 1)
+        assert(ex2_map[ct.c_int(0)].value >= 1)
 
         b.detach_kprobe(event=syscall_fnname)