Srs tcp_table(5) daemon.
[apps/pfixtools.git] / srsd.c
diff --git a/srsd.c b/srsd.c
index 7d25046..cd5c125 100644 (file)
--- a/srsd.c
+++ b/srsd.c
  */
 
 #include <fcntl.h>
+#include <netinet/in.h>
+#include <sys/epoll.h>
 #include <sys/stat.h>
 
 #include <srs2.h>
 
 #include "common.h"
+#include "daemon.h"
 #include "mem.h"
+#include "buffer.h"
 
-static srs_t *srs = NULL;
+#define DEFAULT_ENCODER_PORT    10000
+#define DEFAULT_DECODER_PORT    10001
+#define __tostr(x)  #x
+#define STR(x)      __tostr(x)
 
-static int read_sfile(const char *sfile)
+/* srs encoder/decoder/listener worker {{{ */
+
+typedef struct srsd_t {
+    unsigned listener : 1;
+    unsigned decoder  : 1;
+    unsigned watchwr  : 1;
+    int fd;
+    buffer_t ibuf;
+    buffer_t obuf;
+} srsd_t;
+
+static srsd_t *srsd_new(void)
 {
-    FILE *f;
+    srsd_t *srsd = p_new(srsd_t, 1);
+    srsd->fd = -1;
+    return srsd;
+}
+
+static void srsd_delete(srsd_t **srsd)
+{
+    if (*srsd) {
+        if ((*srsd)->fd >= 0)
+            close((*srsd)->fd);
+        buffer_wipe(&(*srsd)->ibuf);
+        buffer_wipe(&(*srsd)->obuf);
+        p_delete(srsd);
+    }
+}
+
+int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
+{
+    while (srsd->ibuf.len > 4) {
+        char buf[BUFSIZ], *p, *q, *nl;
+        int err;
+
+        if (strncmp("get ", srsd->ibuf.data, 4)) {
+            syslog(LOG_ERR, "bad request, not starting with \"get \"");
+            return -1;
+        }
+
+        nl = strchr(srsd->ibuf.data + 4, '\n');
+        if (!nl)
+            return 0;
+
+        for (p = srsd->ibuf.data + 4; p < nl && isspace(*p); p++);
+        for (q = nl++; q >= p && isspace(*q); *q-- = '\0');
+
+        if (p == q) {
+            syslog(LOG_WARNING, "empty request");
+            goto skip;
+        }
+
+        if (srsd->decoder) {
+            err = srs_reverse(srs, buf, ssizeof(buf), p);
+        } else {
+            err = srs_forward(srs, buf, ssizeof(buf), p, domain);
+        }
+
+        if (err == 0) {
+            buffer_addstr(&srsd->obuf, "200 ");
+            buffer_addstr(&srsd->obuf, buf);
+            buffer_addstr(&srsd->obuf, "\r\n");
+        } else {
+            switch (SRS_ERROR_TYPE(err)) {
+              case SRS_ERRTYPE_SRS:
+              case SRS_ERRTYPE_SYNTAX:
+                buffer_addstr(&srsd->obuf, "500 ");
+                break;
+              default:
+                buffer_addstr(&srsd->obuf, "400 ");
+                break;
+            }
+            buffer_addstr(&srsd->obuf, srs_strerror(err));
+            buffer_addstr(&srsd->obuf, "\r\n");
+        }
+
+      skip:
+        buffer_consume(&srsd->ibuf, nl - srsd->ibuf.data);
+    }
+
+    return 0;
+}
+
+int start_listener(int epollfd, int port, bool decoder)
+{
+    struct sockaddr_in addr = {
+        .sin_family = AF_INET,
+        .sin_addr   = { htonl(INADDR_LOOPBACK) },
+    };
+    struct epoll_event evt = { .events = EPOLLIN };
+    srsd_t *tmp;
+    int sock;
+
+    addr.sin_port = htons(port);
+    sock = tcp_listen((const struct sockaddr *)&addr, sizeof(addr));
+    if (sock < 0) {
+        return -1;
+    }
+
+    evt.data.ptr  = tmp = srsd_new();
+    tmp->fd       = sock;
+    tmp->decoder  = decoder;
+    tmp->listener = true;
+    if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
+        UNIXERR("epoll_ctl");
+        return -1;
+    }
+    return 0;
+}
+
+/* }}} */
+/* administrivia {{{ */
+
+static int main_initialize(void)
+{
+    openlog("srsd", LOG_PID, LOG_MAIL);
+    signal(SIGPIPE, SIG_IGN);
+    signal(SIGINT,  &common_sighandler);
+    signal(SIGTERM, &common_sighandler);
+    signal(SIGHUP,  &common_sighandler);
+    syslog(LOG_INFO, "Starting...");
+    return 0;
+}
+
+static void main_shutdown(void)
+{
+    syslog(LOG_INFO, cleanexit ? "Stopping..." : "Unclean exit...");
+    closelog();
+}
+
+module_init(main_initialize);
+module_exit(main_shutdown);
+
+void usage(void)
+{
+    fputs("usage: srsd [ -e <port> ] [ -d <port> ] domain secrets\n"
+          "\n"
+          "    -e <port>    port to listen to for encoding requests\n"
+          "                 (default: "STR(DEFAULT_ENCODER_PORT)")\n"
+          "    -d <port>    port to listen to for decoding requests\n"
+          "                 (default: "STR(DEFAULT_DECODER_PORT)")\n"
+         , stderr);
+}
+
+/* }}} */
+
+int main_loop(srs_t *srs, const char *domain, int port_enc, int port_dec)
+{
+    int exitcode = EXIT_SUCCESS;
+    int epollfd = epoll_create(128);
+
+    if (epollfd < 0) {
+        UNIXERR("epoll_create");
+        exitcode = EXIT_FAILURE;
+        goto error;
+    }
+
+    if (start_listener(epollfd, port_enc, false) < 0)
+        return EXIT_FAILURE;
+    if (start_listener(epollfd, port_dec, true) < 0)
+        return EXIT_FAILURE;
+
+    while (!sigint) {
+        struct epoll_event evts[1024];
+        int n;
+
+        n = epoll_wait(epollfd, evts, countof(evts), -1);
+        if (n < 0) {
+            if (errno != EAGAIN && errno != EINTR) {
+                UNIXERR("epoll_wait");
+                exitcode = EXIT_FAILURE;
+                break;
+            }
+            continue;
+        }
+
+        while (--n >= 0) {
+            srsd_t *srsd = evts[n].data.ptr;
+
+            if (srsd->listener) {
+                struct epoll_event evt = { .events = EPOLLIN };
+                srsd_t *tmp;
+                int sock;
+
+                sock = accept(srsd->fd, NULL, NULL);
+                if (sock < 0) {
+                    UNIXERR("accept");
+                    continue;
+                }
+
+                evt.data.ptr = tmp = srsd_new();
+                tmp->decoder = srsd->decoder;
+                tmp->fd      = sock;
+                if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
+                    UNIXERR("epoll_ctl");
+                    srsd_delete(&tmp);
+                    close(sock);
+                }
+                continue;
+            }
+
+            if (evts[n].events & EPOLLIN) {
+                int res = buffer_read(&srsd->ibuf, srsd->fd, -1);
+
+                if ((res < 0 && errno != EINTR && errno != EAGAIN)
+                ||  res == 0)
+                {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+
+                if (process_srs(srs, domain, srsd) < 0) {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+            }
+
+            if ((evts[n].events & EPOLLOUT) && srsd->obuf.len) {
+                int res = write(srsd->fd, srsd->obuf.data, srsd->obuf.len);
+
+                if (res < 0 && errno != EINTR && errno != EAGAIN) {
+                    srsd_delete(&srsd);
+                    continue;
+                }
+
+                if (res > 0) {
+                    buffer_consume(&srsd->obuf, res);
+                }
+            }
+
+            if (srsd->watchwr == !srsd->obuf.len) {
+                struct epoll_event evt = {
+                    .events   = EPOLLIN | (srsd->obuf.len ? EPOLLOUT : 0),
+                    .data.ptr = srsd,
+                };
+                if (epoll_ctl(epollfd, EPOLL_CTL_MOD, srsd->fd, &evt) < 0) {
+                    UNIXERR("epoll_ctl");
+                    srsd_delete(&srsd);
+                    continue;
+                }
+                srsd->watchwr = srsd->obuf.len != 0;
+            }
+        }
+    }
+
+    close(epollfd);
+
+  error:
+    cleanexit = true;
+    return exitcode;
+}
+
+static srs_t *srs_read_secrets(const char *sfile)
+{
+    srs_t *srs;
     char buf[BUFSIZ];
-    srs_t *newsrs;
+    FILE *f;
+    int lineno = 0;
 
     f = fopen(sfile, "r");
     if (!f) {
         UNIXERR("fopen");
-        return -1;
+        return NULL;
     }
 
-    newsrs = srs_new();
+    srs = srs_new();
 
     while (fgets(buf, sizeof(buf), f)) {
         int n = strlen(buf);
 
-        if (buf[n - 1] != '\n')
+        ++lineno;
+        if (buf[n - 1] != '\n') {
+            syslog(LOG_CRIT, "%s:%d: line too long", sfile, lineno);
             goto error;
-        while (n > 0 && isspace((unsigned char)buf[n - 1]))
-            buf[--n] = '\0';
-        if (n > 0)
-            srs_add_secret(newsrs, buf);
+        }
+
+        srs_add_secret(srs, buf);
     }
-    fclose(f);
 
-    if (srs) {
-        srs_free(srs);
+    if (!lineno) {
+        syslog(LOG_CRIT, "%s: empty file, no secrets", sfile);
+        goto error;
     }
-    srs = newsrs;
-    return 0;
 
-  error:
     fclose(f);
-    srs_free(newsrs);
-    return -1;
-}
+    return srs;
 
-static void help(void)
-{
-    puts(
-            "Usage: srs-c [ -r | -d domain ] -f sfile -e sender\n"
-            "Perform an SRS encoding / decoding\n"
-            "\n"
-            "    -r          perform an SRS decoding\n"
-            "    -d domain   use that domain (required for encoding)\n"
-            "\n"
-            "    -f sfile    secret file for decoding.  the first line is taken if -s omitted\n"
-            "\n"
-            "    -e sender   the sender address we want to encode/decode\n"
-          );
-    exit(1);
+  error:
+    fclose(f);
+    srs_free(srs);
+    return NULL;
 }
 
 int main(int argc, char *argv[])
 {
-    char *res    = NULL;
-    char *domain = NULL;
-    char *sender = NULL;
-    char *sfile  = NULL;
-
-    int    opt   = 0;
-    bool   rev   = false;
-    int    err   = 0;
-
-    while ((opt = getopt(argc, argv, "d:e:f:r")) != -1) {
-        switch (opt) {
-            case 'd': domain = optarg;  break;
-            case 'e': sender = optarg;  break;
-            case 'f': sfile  = optarg;  break;
-            case 'r': rev    = true;    break;
-        }
-    }
+    int port_enc = DEFAULT_ENCODER_PORT;
+    int port_dec = DEFAULT_DECODER_PORT;
+
+    srs_t *srs;
 
-    if (!sender || !sfile || !(rev||domain)) {
-        help();
+    if (atexit(common_shutdown)) {
+        fputs("Cannot hook my atexit function, quitting !\n", stderr);
+        return EXIT_FAILURE;
     }
+    common_initialize();
 
-    if (read_sfile(sfile) < 0)
-        return -1;
+    for (int c = 0; (c = getopt(argc, argv, "he:d:")) >= 0; ) {
+        switch (c) {
+          case 'e':
+            port_enc = atoi(optarg);
+            break;
+          case 'd':
+            port_dec = atoi(optarg);
+            break;
+          default:
+            usage();
+            return EXIT_FAILURE;
+        }
+    }
 
-    if (rev) {
-        err = srs_reverse_alloc(srs, &res, sender);
-    } else {
-        err = srs_forward_alloc(srs, &res, sender, domain);
+    if (argc - optind != 2) {
+        usage();
+        return EXIT_FAILURE;
     }
 
-    if (res == NULL) {
-        fprintf(stderr, "%s\n", srs_strerror(err));
-        return -1;
+    srs = srs_read_secrets(argv[optind + 1]);
+    if (!srs) {
+        return EXIT_FAILURE;
     }
-    puts(res);
-    return 0;
+
+    return main_loop(srs, argv[optind], port_enc, port_dec);
 }