* Copyright © 2008 Florent Bruneau
*/
+#include <unbound.h>
#include <netdb.h>
+#include "array.h"
+#include "epoll.h"
+#include "server.h"
#include "rbl.h"
-static inline rbl_result_t rbl_dns_check(const char *hostname)
+
+typedef struct rbl_context_t {
+ rbl_result_t *result;
+ rbl_result_callback_t call;
+ void *data;
+} rbl_context_t;
+ARRAY(rbl_context_t);
+
+static struct ub_ctx *ctx = NULL;
+static PA(rbl_context_t) ctx_pool = ARRAY_INIT;
+
+static rbl_context_t *rbl_context_new(void)
+{
+ return p_new(rbl_context_t, 1);
+}
+
+static void rbl_context_delete(rbl_context_t **context)
+{
+ if (*context) {
+ p_delete(context);
+ }
+}
+
+static void rbl_context_wipe(rbl_context_t *context)
+{
+ p_clear(context, 1);
+}
+
+static rbl_context_t *rbl_context_acquire(void)
+{
+ if (array_len(ctx_pool) > 0) {
+ return array_pop_last(ctx_pool);
+ } else {
+ return rbl_context_new();
+ }
+}
+
+static void rbl_context_release(rbl_context_t *context)
+{
+ rbl_context_wipe(context);
+ array_add(ctx_pool, context);
+}
+
+static void rbl_exit(void)
{
- debug("looking up for %s", hostname);
- struct hostent *host = gethostbyname(hostname);
- if (host != NULL) {
- debug("host found");
- return RBL_FOUND;
+ if (ctx != NULL) {
+ ub_ctx_delete(ctx);
+ ctx = NULL;
+ }
+ array_deep_wipe(ctx_pool, rbl_context_delete);
+}
+module_exit(rbl_exit);
+
+static void rbl_callback(void *arg, int err, struct ub_result *result)
+{
+ rbl_context_t *context = arg;
+ if (err != 0) {
+ debug("asynchronous request led to an error");
+ *context->result = RBL_ERROR;
+ } else if (result->nxdomain) {
+ debug("asynchronous request done, %s NOT FOUND", result->qname);
+ *context->result = RBL_NOTFOUND;
+ } else {
+ debug("asynchronous request done, %s FOUND", result->qname);
+ *context->result = RBL_FOUND;
+ }
+ if (context->call != NULL) {
+ debug("calling callback");
+ context->call(context->result, context->data);
} else {
- if (h_errno == HOST_NOT_FOUND) {
- debug("host not found: %s", hostname);
- return RBL_NOTFOUND;
+ debug("no callback defined");
+ }
+ ub_resolve_free(result);
+ rbl_context_release(context);
+}
+
+static int rbl_handler(server_t *event, void *config)
+{
+ int retval = 0;
+ debug("rbl_handler called: ub_fd triggered");
+ epoll_modify(event->fd, 0, event);
+ if ((retval = ub_process(ctx)) != 0) {
+ err("error in DNS resolution: %s", ub_strerror(retval));
+ }
+ epoll_modify(event->fd, EPOLLIN, event);
+ return 0;
+}
+
+static inline bool rbl_dns_check(const char *hostname, rbl_result_t *result,
+ rbl_result_callback_t callback, void *data)
+{
+ if (ctx == NULL) {
+ ctx = ub_ctx_create();
+ ub_ctx_async(ctx, true);
+ if (server_register(ub_fd(ctx), rbl_handler, NULL) == NULL) {
+ crit("cannot register asynchronous DNS event handler");
+ abort();
}
- debug("dns error: %m");
- return RBL_ERROR;
+ }
+ rbl_context_t *context = rbl_context_acquire();
+ context->result = result;
+ context->call = callback;
+ context->data = data;
+ if (ub_resolve_async(ctx, (char*)hostname, 1, 1, context, rbl_callback, NULL) == 0) {
+ *result = RBL_ASYNC;
+ return true;
+ } else {
+ *result = RBL_ERROR;
+ rbl_context_release(context);
+ return false;
}
}
-rbl_result_t rbl_check(const char *rbl, uint32_t ip)
+bool rbl_check(const char *rbl, uint32_t ip, rbl_result_t *result,
+ rbl_result_callback_t callback, void *data)
{
char host[257];
- snprintf(host, 257, "%d.%d.%d.%d.%s",
- ip & 0xff, (ip >> 8) & 0xff, (ip >> 16) & 0xff, (ip >> 24) & 0xff,
- rbl);
- return rbl_dns_check(host);
+ int len;
+
+ len = snprintf(host, 257, "%d.%d.%d.%d.%s.",
+ ip & 0xff, (ip >> 8) & 0xff, (ip >> 16) & 0xff, (ip >> 24) & 0xff,
+ rbl);
+ if (len >= (int)sizeof(host))
+ return RBL_ERROR;
+ if (host[len - 2] == '.')
+ host[len - 1] = '\0';
+ return rbl_dns_check(host, result, callback, data);
}
-rbl_result_t rhbl_check(const char *rhbl, const char *hostname)
+bool rhbl_check(const char *rhbl, const char *hostname, rbl_result_t *result,
+ rbl_result_callback_t callback, void *data)
{
char host[257];
- snprintf(host, 257, "%s.%s", hostname, rhbl);
- return rbl_dns_check(host);
+ int len;
+
+ len = snprintf(host, 257, "%s.%s.", hostname, rhbl);
+ if (len >= (int)sizeof(host))
+ return RBL_ERROR;
+ if (host[len - 2] == '.')
+ host[len - 1] = '\0';
+ return rbl_dns_check(host, result, callback, data);
}