Rename project -> pfixtools.
[apps/pfixtools.git] / srsd.c
diff --git a/srsd.c b/srsd.c
index aca3f79..8d88677 100644 (file)
--- a/srsd.c
+++ b/srsd.c
@@ -1,5 +1,5 @@
 /******************************************************************************/
-/*          postlicyd: a postfix policy daemon with a lot of features         */
+/*          pfixtools: a collection of postfix related tools                  */
 /*          ~~~~~~~~~                                                         */
 /*  ________________________________________________________________________  */
 /*                                                                            */
  */
 
 #include <fcntl.h>
+#include <netinet/in.h>
+#include <sys/epoll.h>
 #include <sys/stat.h>
 
 #include <srs2.h>
 
 #include "common.h"
+#include "daemon.h"
 #include "mem.h"
+#include "buffer.h"
+
+#define DAEMON_NAME             "srsd"
+#define DEFAULT_ENCODER_PORT    10000
+#define DEFAULT_DECODER_PORT    10001
+#define __tostr(x)  #x
+#define STR(x)      __tostr(x)
+
+/* srs encoder/decoder/listener worker {{{ */
+
+typedef struct srsd_t {
+    unsigned listener : 1;
+    unsigned decoder  : 1;
+    unsigned watchwr  : 1;
+    int fd;
+    buffer_t ibuf;
+    buffer_t obuf;
+} srsd_t;
+
+static srsd_t *srsd_new(void)
+{
+    srsd_t *srsd = p_new(srsd_t, 1);
+    srsd->fd = -1;
+    return srsd;
+}
 
-#define MAX_SIZE 0x10000
+static void srsd_delete(srsd_t **srsd)
+{
+    if (*srsd) {
+        if ((*srsd)->fd >= 0)
+            close((*srsd)->fd);
+        buffer_wipe(&(*srsd)->ibuf);
+        buffer_wipe(&(*srsd)->obuf);
+        p_delete(srsd);
+    }
+}
 
-static char **read_sfile(char *sfile)
+void urldecode(char *s, char *end)
 {
-    int fd  = -1;
-    int pos = 0;
-    int nb  = 0;
-    int len = 0;
+    char *p = s;
 
-    char  *buf = NULL;
-    char **res = NULL;
+    while (*p) {
+        if (*p == '%' && end - p >= 3) {
+            int h = (hexval(p[1]) << 4) | hexval(p[2]);
 
-    struct stat stat_buf;
+            if (h >= 0) {
+                *s++ = h;
+                p += 3;
+                continue;
+            }
+        }
 
-    if (stat(sfile, &stat_buf)) {
-        perror("stat");
-        exit(1);
+        *s++ = *p++;
     }
+    *s++ = '\0';
+}
 
-    if (stat_buf.st_size > MAX_SIZE) {
-        fprintf(stderr, "the secret file is too big\n");
-        exit(1);
-    }
+int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
+{
+    while (srsd->ibuf.len > 4) {
+        char buf[BUFSIZ], *p, *q, *nl;
+        int err;
+
+        nl = strchr(srsd->ibuf.data + 4, '\n');
+        if (!nl) {
+            if (srsd->ibuf.len > BUFSIZ) {
+                syslog(LOG_ERR, "unreasonnable amount of data without a \\n");
+                return -1;
+            }
+            return 0;
+        }
 
-    buf = (char *)malloc(stat_buf.st_size+1);
-    buf[stat_buf.st_size] = 0;
+        if (strncmp("get ", srsd->ibuf.data, 4)) {
+            syslog(LOG_ERR, "bad request, not starting with \"get \"");
+            return -1;
+        }
 
-    if ((fd = open(sfile, O_RDONLY)) < 0) {
-        perror("open");
-        exit (1);
-    }
+        for (p = srsd->ibuf.data + 4; p < nl && isspace(*p); p++);
+        for (q = nl++; q >= p && isspace(*q); *q-- = '\0');
 
-    for (;;) {
-        if ((nb = read(fd, &(buf[pos]), stat_buf.st_size)) < 0) {
-            if (errno == EINTR)
-                continue;
-            perror("read");
-            exit(1);
+        if (p == q) {
+            buffer_addstr(&srsd->obuf, "400 empty request ???\n");
+            syslog(LOG_WARNING, "empty request");
+            goto skip;
         }
-        pos += nb;
-        if (nb == 0 || pos == stat_buf.st_size) {
-            close(fd);
-            fd = -1;
-            break;
+
+        urldecode(p, q);
+
+        if (srsd->decoder) {
+            err = srs_reverse(srs, buf, ssizeof(buf), p);
+        } else {
+            err = srs_forward(srs, buf, ssizeof(buf), p, domain);
         }
-    }
 
-    for ( nb = pos = 0; pos < stat_buf.st_size ; pos++)
-    {
-        if ( buf[pos] == '\n' ) {
-            nb++;
-            buf[pos] = 0;
+        if (err == 0) {
+            buffer_addstr(&srsd->obuf, "200 ");
+            buffer_addstr(&srsd->obuf, buf);
+        } else {
+            switch (SRS_ERROR_TYPE(err)) {
+              case SRS_ERRTYPE_SRS:
+              case SRS_ERRTYPE_SYNTAX:
+                buffer_addstr(&srsd->obuf, "500 ");
+                break;
+              default:
+                buffer_addstr(&srsd->obuf, "400 ");
+                break;
+            }
+            buffer_addstr(&srsd->obuf, srs_strerror(err));
         }
+        buffer_addch(&srsd->obuf, '\n');
+
+      skip:
+        buffer_consume(&srsd->ibuf, nl - srsd->ibuf.data);
     }
 
-    res = p_new(char*, nb + 2);
+    return 0;
+}
 
-    nb = pos = 0;
-    while (pos < stat_buf.st_size)
-    {
-        len = strlen(&(buf[pos]));
-        if (len) {
-            res[nb++] = &(buf[pos]);
-        }
-        pos += len+1;
+int start_listener(int epollfd, int port, bool decoder)
+{
+    struct sockaddr_in addr = {
+        .sin_family = AF_INET,
+        .sin_addr   = { htonl(INADDR_LOOPBACK) },
+    };
+    struct epoll_event evt = { .events = EPOLLIN };
+    srsd_t *tmp;
+    int sock;
+
+    addr.sin_port = htons(port);
+    sock = tcp_listen((const struct sockaddr *)&addr, sizeof(addr));
+    if (sock < 0) {
+        return -1;
     }
 
-    return res;
+    evt.data.ptr  = tmp = srsd_new();
+    tmp->fd       = sock;
+    tmp->decoder  = decoder;
+    tmp->listener = true;
+    if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
+        UNIXERR("epoll_ctl");
+        return -1;
+    }
+    return 0;
 }
 
+/* }}} */
+/* administrivia {{{ */
 
-static char *encode(char * secret, char * sender, char * alias)
+static int main_initialize(void)
 {
-    int    err = 0;
-    char  *res = NULL;
-    srs_t *srs = srs_new();
+    openlog(DAEMON_NAME, LOG_PID, LOG_MAIL);
+    signal(SIGPIPE, SIG_IGN);
+    signal(SIGINT,  &common_sighandler);
+    signal(SIGTERM, &common_sighandler);
+    signal(SIGHUP,  &common_sighandler);
+    syslog(LOG_INFO, "Starting...");
+    return 0;
+}
 
-    srs_add_secret(srs, secret);
-    err = srs_forward_alloc(srs, &res, sender, alias);
+static void main_shutdown(void)
+{
+    syslog(LOG_INFO, cleanexit ? "Stopping..." : "Unclean exit...");
+    closelog();
+}
 
-    if (res == NULL) {
-        fprintf(stderr, "%s\n", srs_strerror(err));
-        exit (1);
-    }
+module_init(main_initialize);
+module_exit(main_shutdown);
 
-    return res;
+void usage(void)
+{
+    fputs("usage: "DAEMON_NAME" [ -e <port> ] [ -d <port> ] domain secrets\n"
+          "\n"
+          "    -e <port>    port to listen to for encoding requests\n"
+          "                 (default: "STR(DEFAULT_ENCODER_PORT)")\n"
+          "    -d <port>    port to listen to for decoding requests\n"
+          "                 (default: "STR(DEFAULT_DECODER_PORT)")\n"
+         , stderr);
 }
 
-static char * decode(char * secret, char * secrets[], char * sender)
-{
-    int     err = 0;
-    char *  res = NULL;
-    srs_t * srs = srs_new();
+/* }}} */
 
-    if (secret) {
-        srs_add_secret(srs, secret);
-    }
+int main_loop(srs_t *srs, const char *domain, int port_enc, int port_dec)
+{
+    int exitcode = EXIT_SUCCESS;
+    int epollfd = epoll_create(128);
 
-    for (; secrets && secrets[err] != 0; err++) {
-        srs_add_secret(srs, secrets[err]);
+    if (epollfd < 0) {
+        UNIXERR("epoll_create");
+        exitcode = EXIT_FAILURE;
+        goto error;
     }
 
-    err = srs_reverse_alloc(srs, &res, sender);
+    if (start_listener(epollfd, port_enc, false) < 0)
+        return EXIT_FAILURE;
+    if (start_listener(epollfd, port_dec, true) < 0)
+        return EXIT_FAILURE;
+
+    while (!sigint) {
+        struct epoll_event evts[1024];
+        int n;
+
+        n = epoll_wait(epollfd, evts, countof(evts), -1);
+        if (n < 0) {
+            if (errno != EAGAIN && errno != EINTR) {
+                UNIXERR("epoll_wait");
+                exitcode = EXIT_FAILURE;
+                break;
+            }
+            continue;
+        }
 
-    if (res == NULL) {
-        fprintf(stderr, "%s\n", srs_strerror(err));
-        exit(1);
+        while (--n >= 0) {
+            srsd_t *srsd = evts[n].data.ptr;
+
+            if (srsd->listener) {
+                struct epoll_event evt = { .events = EPOLLIN };
+                srsd_t *tmp;
+                int sock;
+
+                sock = accept_nonblock(srsd->fd);
+                if (sock < 0) {
+                    UNIXERR("accept");
+                    continue;
+                }
+
+                evt.data.ptr = tmp = srsd_new();
+                tmp->decoder = srsd->decoder;
+                tmp->fd      = sock;
+                if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
+                    UNIXERR("epoll_ctl");
+                    srsd_delete(&tmp);
+                    close(sock);
+                }
+                continue;
+            }
+
+            if (evts[n].events & EPOLLIN) {
+                int res = buffer_read(&srsd->ibuf, srsd->fd, -1);
+
+                if ((res < 0 && errno != EINTR && errno != EAGAIN)
+                ||  res == 0)
+                {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+
+                if (process_srs(srs, domain, srsd) < 0) {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+            }
+
+            if ((evts[n].events & EPOLLOUT) && srsd->obuf.len) {
+                int res = write(srsd->fd, srsd->obuf.data, srsd->obuf.len);
+
+                if (res < 0 && errno != EINTR && errno != EAGAIN) {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+
+                if (res > 0) {
+                    buffer_consume(&srsd->obuf, res);
+                }
+            }
+
+            if (srsd->watchwr == !srsd->obuf.len) {
+                struct epoll_event evt = {
+                    .events   = EPOLLIN | (srsd->obuf.len ? EPOLLOUT : 0),
+                    .data.ptr = srsd,
+                };
+                if (epoll_ctl(epollfd, EPOLL_CTL_MOD, srsd->fd, &evt) < 0) {
+                    UNIXERR("epoll_ctl");
+                    srsd_delete(&srsd);
+                    continue;
+                }
+                srsd->watchwr = srsd->obuf.len != 0;
+            }
+        }
     }
 
-    return res;
-}
+    close(epollfd);
 
-static void help(void)
-{
-    puts(
-            "Usage: srs-c [ -r | -d domain ] [ -s secret | -f sfile ] -e sender\n"
-            "Perform an SRS encoding / decoding\n"
-            "\n"
-            "    -r          perform an SRS decoding\n"
-            "    -d domain   use that domain (required for encoding)\n"
-            "\n"
-            "    -s secret   secret used in the encoding (sfile required if omitted)\n"
-            "    -f sfile    secret file for decoding.  the first line is taken if -s omitted\n"
-            "\n"
-            "    -e sender   the sender address we want to encode/decode\n"
-          );
-    exit (1);
+  error:
+    cleanexit = true;
+    return exitcode;
 }
 
-int main(int argc, char * argv[])
+static srs_t *srs_read_secrets(const char *sfile)
 {
-    char *buf    = NULL;
-    char *domain = NULL;
-    char *sender = NULL;
-    char *secret = NULL;
-    char *sfile  = NULL;
-
-    int    opt   = 0;
-    bool   rev   = false;
-    char **secr  = NULL;
-
-    while ((opt = getopt(argc, argv, "d:e:s:f:r")) != -1)
-    {
-        switch (opt) {
-            case 'd': domain = optarg;  break;
-            case 'e': sender = optarg;  break;
-            case 'f': sfile  = optarg;  break;
-            case 'r': rev    = true;    break;
-            case 's': secret = optarg;  break;
+    srs_t *srs;
+    char buf[BUFSIZ];
+    FILE *f;
+    int lineno = 0;
+
+    f = fopen(sfile, "r");
+    if (!f) {
+        UNIXERR("fopen");
+        return NULL;
+    }
+
+    srs = srs_new();
+
+    while (fgets(buf, sizeof(buf), f)) {
+        int n = strlen(buf);
+
+        ++lineno;
+        if (buf[n - 1] != '\n') {
+            syslog(LOG_CRIT, "%s:%d: line too long", sfile, lineno);
+            goto error;
         }
+
+        srs_add_secret(srs, buf);
+    }
+
+    if (!lineno) {
+        syslog(LOG_CRIT, "%s: empty file, no secrets", sfile);
+        goto error;
     }
 
-    if ( !sender || !(secret||sfile) || !(rev||domain) ) {
-        help ();
+    fclose(f);
+    return srs;
+
+  error:
+    fclose(f);
+    srs_free(srs);
+    return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+    int port_enc = DEFAULT_ENCODER_PORT;
+    int port_dec = DEFAULT_DECODER_PORT;
+
+    srs_t *srs;
+
+    if (atexit(common_shutdown)) {
+        fputs("Cannot hook my atexit function, quitting !\n", stderr);
+        return EXIT_FAILURE;
     }
+    common_initialize();
 
-    if (sfile) {
-        secr = read_sfile(sfile);
-        if (!secret && (!secr || !secr[0])) {
-            fprintf(stderr, "No secret given, and secret file is empty\n");
-            exit (1);
+    for (int c = 0; (c = getopt(argc, argv, "he:d:")) >= 0; ) {
+        switch (c) {
+          case 'e':
+            port_enc = atoi(optarg);
+            break;
+          case 'd':
+            port_dec = atoi(optarg);
+            break;
+          default:
+            usage();
+            return EXIT_FAILURE;
         }
     }
 
-    if (rev) {
-        buf = decode(secret, secr, sender);
-    } else {
-        buf = encode((secret ? secret : secr[0]), sender, domain);
+    if (argc - optind != 2) {
+        usage();
+        return EXIT_FAILURE;
     }
 
-    puts(buf);
-    return 0;
+    srs = srs_read_secrets(argv[optind + 1]);
+    if (!srs) {
+        return EXIT_FAILURE;
+    }
+
+    return main_loop(srs, argv[optind], port_enc, port_dec);
 }