Asynchronous DNS queries on iplist.
[apps/pfixtools.git] / postlicyd / iplist.c
index 867f88b..c2b471b 100644 (file)
@@ -225,7 +225,7 @@ bool rbldb_ipv4_lookup(const rbldb_t *db, uint32_t ip)
 
 #include "filter.h"
 
-typedef struct rbl_filter_t {
+typedef struct iplist_filter_t {
     PA(rbldb_t) rbls;
     A(int)      weights;
     A(char)     hosts;
@@ -234,14 +234,23 @@ typedef struct rbl_filter_t {
 
     int32_t     hard_threshold;
     int32_t     soft_threshold;
-} rbl_filter_t;
+} iplist_filter_t;
 
-static rbl_filter_t *rbl_filter_new(void)
+typedef struct iplist_async_data_t {
+    A(rbl_result_t) results;
+    int awaited;
+    uint32_t sum;
+    bool error;
+} iplist_async_data_t;
+
+static filter_type_t filter_type = FTK_UNKNOWN;
+
+static iplist_filter_t *iplist_filter_new(void)
 {
-    return p_new(rbl_filter_t, 1);
+    return p_new(iplist_filter_t, 1);
 }
 
-static void rbl_filter_delete(rbl_filter_t **rbl)
+static void iplist_filter_delete(iplist_filter_t **rbl)
 {
     if (*rbl) {
         array_deep_wipe((*rbl)->rbls, rbldb_delete);
@@ -254,14 +263,14 @@ static void rbl_filter_delete(rbl_filter_t **rbl)
 }
 
 
-static bool rbl_filter_constructor(filter_t *filter)
+static bool iplist_filter_constructor(filter_t *filter)
 {
-    rbl_filter_t *data = rbl_filter_new();
+    iplist_filter_t *data = iplist_filter_new();
 
 #define PARSE_CHECK(Expr, Str, ...)                                            \
     if (!(Expr)) {                                                             \
         err(Str, ##__VA_ARGS__);                                               \
-        rbl_filter_delete(&data);                                              \
+        iplist_filter_delete(&data);                                              \
         return false;                                                          \
     }
 
@@ -384,20 +393,66 @@ static bool rbl_filter_constructor(filter_t *filter)
     return true;
 }
 
-static void rbl_filter_destructor(filter_t *filter)
+static void iplist_filter_destructor(filter_t *filter)
 {
-    rbl_filter_t *data = filter->data;
-    rbl_filter_delete(&data);
+    iplist_filter_t *data = filter->data;
+    iplist_filter_delete(&data);
     filter->data = data;
 }
 
-static filter_result_t rbl_filter(const filter_t *filter, const query_t *query,
-                                  filter_context_t *context)
+static void iplist_filter_async(rbl_result_t *result, void *arg)
+{
+    filter_context_t   *context = arg;
+    const filter_t      *filter = context->current_filter;
+    const iplist_filter_t *data = filter->data;
+    iplist_async_data_t  *async = context->contexts[filter_type];
+
+
+    if (*result != RBL_ERROR) {
+        async->error = false;
+    }
+    --async->awaited;
+
+    debug("got asynchronous request result for filter %s, rbl %d, still awaiting %d answers",
+          filter->name, result - array_ptr(async->results, 0), async->awaited);
+
+    if (async->awaited == 0) {
+        filter_result_t res = HTK_FAIL;
+        if (async->error) {
+            res = HTK_ERROR;
+        } else {
+            for (uint32_t i = 0 ; i < array_len(data->host_offsets) ; ++i) {
+                int weight = array_elt(data->host_weights, i);
+
+                switch (array_elt(async->results, i)) {
+                  case RBL_ASYNC:
+                    crit("no more awaited answer but result is ASYNC");
+                    abort();
+                  case RBL_FOUND:
+                    async->sum += weight;
+                    break;
+                  default:
+                    break;
+                }
+            }
+            if (async->sum >= (uint32_t)data->hard_threshold) {
+                res = HTK_HARD_MATCH;
+            } else if (async->sum >= (uint32_t)data->soft_threshold) {
+                res = HTK_SOFT_MATCH;
+            }
+        }
+        debug("answering to filter %s", filter->name);
+        filter_post_async_result(context, res);
+    }
+}
+
+static filter_result_t iplist_filter(const filter_t *filter, const query_t *query,
+                                     filter_context_t *context)
 {
     uint32_t ip;
     int32_t sum = 0;
     const char *end = NULL;
-    const rbl_filter_t *data = filter->data;
+    const iplist_filter_t *data = filter->data;
     bool  error = true;
 
     if (parse_ipv4(query->client_address, &end, &ip) != 0) {
@@ -416,24 +471,22 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query,
         }
         error = false;
     }
-    for (uint32_t i = 0 ; i < data->host_offsets.len ; ++i) {
-        const char *rbl = array_ptr(data->hosts, array_elt(data->host_offsets, i));
-        int weight      = array_elt(data->host_weights, i);
-        switch (rbl_check(rbl, ip)) {
-          case RBL_FOUND:
-            error = false;
-            sum += weight;
-            if (sum >= data->hard_threshold) {
-                return HTK_HARD_MATCH;
+    if (array_len(data->host_offsets) > 0) {
+        iplist_async_data_t* async = context->contexts[filter_type];
+        array_ensure_exact_capacity(async->results, array_len(data->host_offsets));
+        async->sum = sum;
+        async->awaited = 0;
+        for (uint32_t i = 0 ; i < data->host_offsets.len ; ++i) {
+            const char *rbl = array_ptr(data->hosts, array_elt(data->host_offsets, i));
+            if (rbl_check(rbl, ip, array_ptr(async->results, i),
+                          iplist_filter_async, context)) {
+                error = false;
+                ++async->awaited;
             }
-            break;
-          case RBL_NOTFOUND:
-            error = false;
-            break;
-          case RBL_ERROR:
-            warn("rbl %s unavailable", rbl);
-            break;
         }
+        debug("filter %s awaiting %d asynchronous queries", filter->name, async->awaited);
+        async->error = error;
+        return HTK_ASYNC;
     }
     if (error) {
         err("filter %s: all the rbl returned an error", filter->name);
@@ -448,27 +501,39 @@ static filter_result_t rbl_filter(const filter_t *filter, const query_t *query,
     }
 }
 
-static int rbl_init(void)
+static void *iplist_context_constructor(void)
+{
+    return p_new(iplist_async_data_t, 1);
+}
+
+static void iplist_context_destructor(void *data)
+{
+    iplist_async_data_t *ctx = data;
+    p_delete(&ctx);
+}
+
+static int iplist_init(void)
 {
-    filter_type_t type =  filter_register("iplist", rbl_filter_constructor,
-                                          rbl_filter_destructor, rbl_filter,
-                                          NULL, NULL);
+    filter_type =  filter_register("iplist", iplist_filter_constructor,
+                                   iplist_filter_destructor, iplist_filter,
+                                   iplist_context_constructor,
+                                   iplist_context_destructor);
     /* Hooks.
      */
-    (void)filter_hook_register(type, "abort");
-    (void)filter_hook_register(type, "error");
-    (void)filter_hook_register(type, "fail");
-    (void)filter_hook_register(type, "hard_match");
-    (void)filter_hook_register(type, "soft_match");
-    (void)filter_hook_register(type, "async");
+    (void)filter_hook_register(filter_type, "abort");
+    (void)filter_hook_register(filter_type, "error");
+    (void)filter_hook_register(filter_type, "fail");
+    (void)filter_hook_register(filter_type, "hard_match");
+    (void)filter_hook_register(filter_type, "soft_match");
+    (void)filter_hook_register(filter_type, "async");
 
     /* Parameters.
      */
-    (void)filter_param_register(type, "file");
-    (void)filter_param_register(type, "rbldns");
-    (void)filter_param_register(type, "dns");
-    (void)filter_param_register(type, "hard_threshold");
-    (void)filter_param_register(type, "soft_threshold");
+    (void)filter_param_register(filter_type, "file");
+    (void)filter_param_register(filter_type, "rbldns");
+    (void)filter_param_register(filter_type, "dns");
+    (void)filter_param_register(filter_type, "hard_threshold");
+    (void)filter_param_register(filter_type, "soft_threshold");
     return 0;
 }
-module_init(rbl_init);
+module_init(iplist_init);