Speedup (x4) lookup in large rbl database.
[apps/pfixtools.git] / postlicyd / rbl.c
index 9c76ed6..d03d06b 100644 (file)
@@ -61,8 +61,7 @@ enum {
 };
 
 struct rbldb_t {
-    A(uint32_t) ips;
-    bool        locked;
+    A(uint16_t) ips[1 << 16];
 };
 ARRAY(rbldb_t)
 
@@ -123,6 +122,7 @@ rbldb_t *rbldb_create(const char *file, bool lock)
     rbldb_t *db;
     file_map_t map;
     const char *p, *end;
+    uint32_t ips = 0;
 
     if (!file_map_open(&map, file, false)) {
         return NULL;
@@ -134,8 +134,8 @@ rbldb_t *rbldb_create(const char *file, bool lock)
         --end;
     }
     if (end != map.end) {
-        syslog(LOG_WARNING, "file %s miss a final \\n, ignoring last line",
-               file);
+        warn("file %s miss a final \\n, ignoring last line",
+             file);
     }
 
     db = p_new(rbldb_t, 1);
@@ -148,37 +148,37 @@ rbldb_t *rbldb_create(const char *file, bool lock)
         if (parse_ipv4(p, &p, &ip) < 0) {
             p = (char *)memchr(p, '\n', end - p) + 1;
         } else {
-            array_add(db->ips, ip);
+            array_add(db->ips[ip >> 16], ip & 0xffff);
+            ++ips;
         }
     }
     file_map_close(&map);
 
     /* Lookup may perform serveral I/O, so avoid swap.
      */
-    array_adjust(db->ips);
-    db->locked = lock && array_lock(db->ips);
-    if (lock && !db->locked) {
-        UNIXERR("mlock");
-    }
-
-    if (db->ips.len) {
-#       define QSORT_TYPE uint32_t
-#       define QSORT_BASE db->ips.data
-#       define QSORT_NELT db->ips.len
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        array_adjust(db->ips[i]);
+        if (lock && !array_lock(db->ips[i])) {
+            UNIXERR("mlock");
+        }
+        if (db->ips[i].len) {
+#       define QSORT_TYPE uint16_t
+#       define QSORT_BASE db->ips[i].data
+#       define QSORT_NELT db->ips[i].len
 #       define QSORT_LT(a,b) *a < *b
 #       include "qsort.c"
+        }
     }
 
-    syslog(LOG_INFO, "rbl %s loaded, %d IPs", file, db->ips.len);
+    info("rbl %s loaded, %d IPs", file, ips);
     return db;
 }
 
 static void rbldb_wipe(rbldb_t *db)
 {
-    if (db->locked) {
-      array_unlock(db->ips);
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        array_wipe(db->ips[i]);
     }
-    array_wipe(db->ips);
 }
 
 void rbldb_delete(rbldb_t **db)
@@ -191,20 +191,27 @@ void rbldb_delete(rbldb_t **db)
 
 uint32_t rbldb_stats(const rbldb_t *rbl)
 {
-    return rbl->ips.len;
+    uint32_t ips = 0;
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        ips += array_len(rbl->ips[i]);
+    }
+    printf("memory overhead of rbldb: %u\n", sizeof(rbldb_t));
+    return ips;
 }
 
 bool rbldb_ipv4_lookup(const rbldb_t *db, uint32_t ip)
 {
-    int l = 0, r = db->ips.len;
+    const uint16_t hip = ip >> 16;
+    const uint16_t lip = ip & 0xffff;
+    int l = 0, r = db->ips[hip].len;
 
     while (l < r) {
         int i = (r + l) / 2;
 
-        if (array_elt(db->ips, i) == ip)
+        if (array_elt(db->ips[hip], i) == lip)
             return true;
 
-        if (ip < array_elt(db->ips, i)) {
+        if (lip < array_elt(db->ips[hip], i)) {
             r = i;
         } else {
             l = i + 1;
@@ -247,7 +254,7 @@ static bool rbl_filter_constructor(filter_t *filter)
 
 #define PARSE_CHECK(Expr, Str, ...)                                            \
     if (!(Expr)) {                                                             \
-        syslog(LOG_ERR, Str, ##__VA_ARGS__);                                   \
+        err(Str, ##__VA_ARGS__);                                               \
         rbl_filter_delete(&data);                                              \
         return false;                                                          \
     }
@@ -305,8 +312,10 @@ static bool rbl_filter_constructor(filter_t *filter)
                     array_add(data->weights, weight);
                     break;
                 }
-                current = p + 1;
-                p = m_strchrnul(current, ':');
+                if (i != 2) {
+                    current = p + 1;
+                    p = m_strchrnul(current, ':');
+                }
             }
           } break;
 
@@ -351,11 +360,11 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query)
     const rbl_filter_t *data = filter->data;
 
     if (parse_ipv4(query->client_address, &end, &ip) != 0) {
-        syslog(LOG_WARNING, "invalid client address: %s, expected ipv4",
-               query->client_address);
+        warn("invalid client address: %s, expected ipv4",
+             query->client_address);
         return HTK_ERROR;
     }
-    for (int i = 0 ; i < data->rbls.len ; ++i) {
+    for (uint32_t i = 0 ; i < data->rbls.len ; ++i) {
         const rbldb_t *rbl = array_elt(data->rbls, i);
         int weight   = array_elt(data->weights, i);
         if (rbldb_ipv4_lookup(rbl, ip)) {
@@ -373,7 +382,7 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query)
 
 static int rbl_init(void)
 {
-    filter_type_t type =  filter_register("rbl", rbl_filter_constructor,
+    filter_type_t type =  filter_register("iplist", rbl_filter_constructor,
                                           rbl_filter_destructor, rbl_filter);
     /* Hooks.
      */