Asynchronous DNS queries on iplist.
[apps/pfixtools.git] / postlicyd / iplist.c
index 256c360..c2b471b 100644 (file)
@@ -62,7 +62,7 @@ enum {
 };
 
 struct rbldb_t {
-    A(uint32_t) ips;
+    A(uint16_t) ips[1 << 16];
 };
 ARRAY(rbldb_t)
 
@@ -123,6 +123,7 @@ rbldb_t *rbldb_create(const char *file, bool lock)
     rbldb_t *db;
     file_map_t map;
     const char *p, *end;
+    uint32_t ips = 0;
 
     if (!file_map_open(&map, file, false)) {
         return NULL;
@@ -148,33 +149,37 @@ rbldb_t *rbldb_create(const char *file, bool lock)
         if (parse_ipv4(p, &p, &ip) < 0) {
             p = (char *)memchr(p, '\n', end - p) + 1;
         } else {
-            array_add(db->ips, ip);
+            array_add(db->ips[ip >> 16], ip & 0xffff);
+            ++ips;
         }
     }
     file_map_close(&map);
 
     /* Lookup may perform serveral I/O, so avoid swap.
      */
-    array_adjust(db->ips);
-    if (lock && !array_lock(db->ips)) {
-        UNIXERR("mlock");
-    }
-
-    if (db->ips.len) {
-#       define QSORT_TYPE uint32_t
-#       define QSORT_BASE db->ips.data
-#       define QSORT_NELT db->ips.len
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        array_adjust(db->ips[i]);
+        if (lock && !array_lock(db->ips[i])) {
+            UNIXERR("mlock");
+        }
+        if (db->ips[i].len) {
+#       define QSORT_TYPE uint16_t
+#       define QSORT_BASE db->ips[i].data
+#       define QSORT_NELT db->ips[i].len
 #       define QSORT_LT(a,b) *a < *b
 #       include "qsort.c"
+        }
     }
 
-    info("rbl %s loaded, %d IPs", file, db->ips.len);
+    info("rbl %s loaded, %d IPs", file, ips);
     return db;
 }
 
 static void rbldb_wipe(rbldb_t *db)
 {
-    array_wipe(db->ips);
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        array_wipe(db->ips[i]);
+    }
 }
 
 void rbldb_delete(rbldb_t **db)
@@ -187,20 +192,26 @@ void rbldb_delete(rbldb_t **db)
 
 uint32_t rbldb_stats(const rbldb_t *rbl)
 {
-    return rbl->ips.len;
+    uint32_t ips = 0;
+    for (int i = 0 ; i < 1 << 16 ; ++i) {
+        ips += array_len(rbl->ips[i]);
+    }
+    return ips;
 }
 
 bool rbldb_ipv4_lookup(const rbldb_t *db, uint32_t ip)
 {
-    int l = 0, r = db->ips.len;
+    const uint16_t hip = ip >> 16;
+    const uint16_t lip = ip & 0xffff;
+    int l = 0, r = db->ips[hip].len;
 
     while (l < r) {
         int i = (r + l) / 2;
 
-        if (array_elt(db->ips, i) == ip)
+        if (array_elt(db->ips[hip], i) == lip)
             return true;
 
-        if (ip < array_elt(db->ips, i)) {
+        if (lip < array_elt(db->ips[hip], i)) {
             r = i;
         } else {
             l = i + 1;
@@ -214,7 +225,7 @@ bool rbldb_ipv4_lookup(const rbldb_t *db, uint32_t ip)
 
 #include "filter.h"
 
-typedef struct rbl_filter_t {
+typedef struct iplist_filter_t {
     PA(rbldb_t) rbls;
     A(int)      weights;
     A(char)     hosts;
@@ -223,14 +234,23 @@ typedef struct rbl_filter_t {
 
     int32_t     hard_threshold;
     int32_t     soft_threshold;
-} rbl_filter_t;
+} iplist_filter_t;
+
+typedef struct iplist_async_data_t {
+    A(rbl_result_t) results;
+    int awaited;
+    uint32_t sum;
+    bool error;
+} iplist_async_data_t;
 
-static rbl_filter_t *rbl_filter_new(void)
+static filter_type_t filter_type = FTK_UNKNOWN;
+
+static iplist_filter_t *iplist_filter_new(void)
 {
-    return p_new(rbl_filter_t, 1);
+    return p_new(iplist_filter_t, 1);
 }
 
-static void rbl_filter_delete(rbl_filter_t **rbl)
+static void iplist_filter_delete(iplist_filter_t **rbl)
 {
     if (*rbl) {
         array_deep_wipe((*rbl)->rbls, rbldb_delete);
@@ -243,14 +263,14 @@ static void rbl_filter_delete(rbl_filter_t **rbl)
 }
 
 
-static bool rbl_filter_constructor(filter_t *filter)
+static bool iplist_filter_constructor(filter_t *filter)
 {
-    rbl_filter_t *data = rbl_filter_new();
+    iplist_filter_t *data = iplist_filter_new();
 
 #define PARSE_CHECK(Expr, Str, ...)                                            \
     if (!(Expr)) {                                                             \
         err(Str, ##__VA_ARGS__);                                               \
-        rbl_filter_delete(&data);                                              \
+        iplist_filter_delete(&data);                                              \
         return false;                                                          \
     }
 
@@ -268,7 +288,7 @@ static bool rbl_filter_constructor(filter_t *filter)
            *  the file pointed by filename MUST be a valid ip list issued from
            *  the rsync (or equivalent) service of a (r)bl.
            */
-          case ATK_FILE: {
+          case ATK_FILE: case ATK_RBLDNS: {
             bool lock = false;
             int  weight = 0;
             rbldb_t *rbl = NULL;
@@ -288,7 +308,7 @@ static bool rbl_filter_constructor(filter_t *filter)
                         lock = false;
                     } else {
                         PARSE_CHECK(false, "illegal locking state %.*s",
-                                    p - current, current);
+                                    (int)(p - current), current);
                     }
                     break;
 
@@ -296,7 +316,7 @@ static bool rbl_filter_constructor(filter_t *filter)
                     weight = strtol(current, &next, 10);
                     PARSE_CHECK(next == p && weight >= 0 && weight <= 1024,
                                 "illegal weight value %.*s",
-                                (p - current), current);
+                                (int)(p - current), current);
                     break;
 
                   case 2:
@@ -314,11 +334,11 @@ static bool rbl_filter_constructor(filter_t *filter)
             }
           } break;
 
-          /* host parameter.
+          /* dns parameter.
            *  weight:hostname.
            * define a RBL to use through DNS resolution.
            */
-          case ATK_HOST: {
+          case ATK_DNS: {
             int  weight = 0;
             const char *current = param->value;
             const char *p = m_strchrnul(param->value, ':');
@@ -331,7 +351,7 @@ static bool rbl_filter_constructor(filter_t *filter)
                     weight = strtol(current, &next, 10);
                     PARSE_CHECK(next == p && weight >= 0 && weight <= 1024,
                                 "illegal weight value %.*s",
-                                (p - current), current);
+                                (int)(p - current), current);
                     break;
 
                   case 1:
@@ -367,25 +387,72 @@ static bool rbl_filter_constructor(filter_t *filter)
         }
     }}
 
-    PARSE_CHECK(data->rbls.len
+    PARSE_CHECK(data->rbls.len || data->host_offsets.len,
                 "no file parameter in the filter %s", filter->name);
     filter->data = data;
     return true;
 }
 
-static void rbl_filter_destructor(filter_t *filter)
+static void iplist_filter_destructor(filter_t *filter)
 {
-    rbl_filter_t *data = filter->data;
-    rbl_filter_delete(&data);
+    iplist_filter_t *data = filter->data;
+    iplist_filter_delete(&data);
     filter->data = data;
 }
 
-static filter_result_t rbl_filter(const filter_t *filter, const query_t *query)
+static void iplist_filter_async(rbl_result_t *result, void *arg)
+{
+    filter_context_t   *context = arg;
+    const filter_t      *filter = context->current_filter;
+    const iplist_filter_t *data = filter->data;
+    iplist_async_data_t  *async = context->contexts[filter_type];
+
+
+    if (*result != RBL_ERROR) {
+        async->error = false;
+    }
+    --async->awaited;
+
+    debug("got asynchronous request result for filter %s, rbl %d, still awaiting %d answers",
+          filter->name, result - array_ptr(async->results, 0), async->awaited);
+
+    if (async->awaited == 0) {
+        filter_result_t res = HTK_FAIL;
+        if (async->error) {
+            res = HTK_ERROR;
+        } else {
+            for (uint32_t i = 0 ; i < array_len(data->host_offsets) ; ++i) {
+                int weight = array_elt(data->host_weights, i);
+
+                switch (array_elt(async->results, i)) {
+                  case RBL_ASYNC:
+                    crit("no more awaited answer but result is ASYNC");
+                    abort();
+                  case RBL_FOUND:
+                    async->sum += weight;
+                    break;
+                  default:
+                    break;
+                }
+            }
+            if (async->sum >= (uint32_t)data->hard_threshold) {
+                res = HTK_HARD_MATCH;
+            } else if (async->sum >= (uint32_t)data->soft_threshold) {
+                res = HTK_SOFT_MATCH;
+            }
+        }
+        debug("answering to filter %s", filter->name);
+        filter_post_async_result(context, res);
+    }
+}
+
+static filter_result_t iplist_filter(const filter_t *filter, const query_t *query,
+                                     filter_context_t *context)
 {
     uint32_t ip;
     int32_t sum = 0;
     const char *end = NULL;
-    const rbl_filter_t *data = filter->data;
+    const iplist_filter_t *data = filter->data;
     bool  error = true;
 
     if (parse_ipv4(query->client_address, &end, &ip) != 0) {
@@ -404,24 +471,22 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query)
         }
         error = false;
     }
-    for (uint32_t i = 0 ; i < data->host_offsets.len ; ++i) {
-        const char *rbl = array_ptr(data->hosts, array_elt(data->host_offsets, i));
-        int weight      = array_elt(data->host_weights, i);
-        switch (rbl_check(rbl, ip)) {
-          case RBL_FOUND:
-            error = false;
-            sum += weight;
-            if (sum >= data->hard_threshold) {
-                return HTK_HARD_MATCH;
+    if (array_len(data->host_offsets) > 0) {
+        iplist_async_data_t* async = context->contexts[filter_type];
+        array_ensure_exact_capacity(async->results, array_len(data->host_offsets));
+        async->sum = sum;
+        async->awaited = 0;
+        for (uint32_t i = 0 ; i < data->host_offsets.len ; ++i) {
+            const char *rbl = array_ptr(data->hosts, array_elt(data->host_offsets, i));
+            if (rbl_check(rbl, ip, array_ptr(async->results, i),
+                          iplist_filter_async, context)) {
+                error = false;
+                ++async->awaited;
             }
-            break;
-          case RBL_NOTFOUND:
-            error = false;
-            break;
-          case RBL_ERROR:
-            warn("rbl %s unavailable", rbl);
-            break;
         }
+        debug("filter %s awaiting %d asynchronous queries", filter->name, async->awaited);
+        async->error = error;
+        return HTK_ASYNC;
     }
     if (error) {
         err("filter %s: all the rbl returned an error", filter->name);
@@ -436,24 +501,39 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query)
     }
 }
 
-static int rbl_init(void)
+static void *iplist_context_constructor(void)
+{
+    return p_new(iplist_async_data_t, 1);
+}
+
+static void iplist_context_destructor(void *data)
+{
+    iplist_async_data_t *ctx = data;
+    p_delete(&ctx);
+}
+
+static int iplist_init(void)
 {
-    filter_type_t type =  filter_register("iplist", rbl_filter_constructor,
-                                          rbl_filter_destructor, rbl_filter);
+    filter_type =  filter_register("iplist", iplist_filter_constructor,
+                                   iplist_filter_destructor, iplist_filter,
+                                   iplist_context_constructor,
+                                   iplist_context_destructor);
     /* Hooks.
      */
-    (void)filter_hook_register(type, "abort");
-    (void)filter_hook_register(type, "error");
-    (void)filter_hook_register(type, "fail");
-    (void)filter_hook_register(type, "hard_match");
-    (void)filter_hook_register(type, "soft_match");
+    (void)filter_hook_register(filter_type, "abort");
+    (void)filter_hook_register(filter_type, "error");
+    (void)filter_hook_register(filter_type, "fail");
+    (void)filter_hook_register(filter_type, "hard_match");
+    (void)filter_hook_register(filter_type, "soft_match");
+    (void)filter_hook_register(filter_type, "async");
 
     /* Parameters.
      */
-    (void)filter_param_register(type, "file");
-    (void)filter_param_register(type, "host");
-    (void)filter_param_register(type, "hard_threshold");
-    (void)filter_param_register(type, "soft_threshold");
+    (void)filter_param_register(filter_type, "file");
+    (void)filter_param_register(filter_type, "rbldns");
+    (void)filter_param_register(filter_type, "dns");
+    (void)filter_param_register(filter_type, "hard_threshold");
+    (void)filter_param_register(filter_type, "soft_threshold");
     return 0;
 }
-module_init(rbl_init);
+module_init(iplist_init);