Factorize code.
[apps/pfixtools.git] / main-srsd.c
index 891f7ee..89e21de 100644 (file)
 
 /*
  * Copyright © 2005-2007 Pierre Habouzit
+ * Copyright © 2008 Florent Bruneau
  */
 
-#include <fcntl.h>
-#include <netinet/in.h>
-#include <sys/epoll.h>
-#include <sys/stat.h>
+#include "common.h"
 
 #include <srs2.h>
 
-#include "common.h"
+#include "epoll.h"
 #include "mem.h"
 #include "buffer.h"
+#include "server.h"
 
 #define DAEMON_NAME             "pfix-srsd"
-#define DEFAULT_ENCODER_PORT    10000
-#define DEFAULT_DECODER_PORT    10001
+#define DEFAULT_ENCODER_PORT    10001
+#define DEFAULT_DECODER_PORT    10002
 #define RUNAS_USER              "nobody"
 #define RUNAS_GROUP             "nogroup"
 
-#define __tostr(x)  #x
-#define STR(x)      __tostr(x)
+DECLARE_MAIN
 
-/* srs encoder/decoder/listener worker {{{ */
+typedef struct srs_config_t {
+    srs_t* srs;
+    const char* domain;
+} srs_config_t;
 
-typedef struct srsd_t {
-    unsigned listener : 1;
-    unsigned decoder  : 1;
-    unsigned watchwr  : 1;
-    int fd;
-    buffer_t ibuf;
-    buffer_t obuf;
-} srsd_t;
+static const char* decoder_ptr = "decoder";
+static const char* encoder_ptr = "encoder";
 
-static srsd_t *srsd_new(void)
+static void *srsd_new_decoder(void)
 {
-    srsd_t *srsd = p_new(srsd_t, 1);
-    srsd->fd = -1;
-    return srsd;
+    return (void*)decoder_ptr;
 }
 
-static void srsd_delete(srsd_t **srsd)
+static void *srsd_new_encoder(void)
 {
-    if (*srsd) {
-        if ((*srsd)->fd >= 0)
-            close((*srsd)->fd);
-        buffer_wipe(&(*srsd)->ibuf);
-        buffer_wipe(&(*srsd)->obuf);
-        p_delete(srsd);
-    }
+    return (void*)encoder_ptr;
+}
+
+static void *srsd_stater(server_t *server)
+{
+    return server->data;
 }
 
 void urldecode(char *s, char *end)
@@ -102,8 +94,14 @@ void urldecode(char *s, char *end)
     *s++ = '\0';
 }
 
-int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
+int process_srs(server_t *srsd, void* vconfig)
 {
+    srs_config_t* config = vconfig;
+    int res = buffer_read(&srsd->ibuf, srsd->fd, -1);
+
+    if ((res < 0 && errno != EINTR && errno != EAGAIN) || res == 0)
+        return -1;
+
     while (srsd->ibuf.len > 4) {
         char buf[BUFSIZ], *p, *q, *nl;
         int err;
@@ -114,6 +112,9 @@ int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
                 syslog(LOG_ERR, "unreasonnable amount of data without a \\n");
                 return -1;
             }
+            if (srsd->obuf.len) {
+              epoll_modify(srsd->fd, EPOLLIN | EPOLLOUT, srsd);
+            }
             return 0;
         }
 
@@ -133,10 +134,10 @@ int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
 
         urldecode(p, q);
 
-        if (srsd->decoder) {
-            err = srs_reverse(srs, buf, ssizeof(buf), p);
+        if (srsd->data == (void*)decoder_ptr) {
+            err = srs_reverse(config->srs, buf, ssizeof(buf), p);
         } else {
-            err = srs_forward(srs, buf, ssizeof(buf), p, domain);
+            err = srs_forward(config->srs, buf, ssizeof(buf), p, config->domain);
         }
 
         if (err == 0) {
@@ -159,59 +160,20 @@ int process_srs(srs_t *srs, const char *domain, srsd_t *srsd)
       skip:
         buffer_consume(&srsd->ibuf, nl - srsd->ibuf.data);
     }
-
+    if (srsd->obuf.len) {
+      epoll_modify(srsd->fd, EPOLLIN | EPOLLOUT, srsd);
+    }
     return 0;
 }
 
-int start_listener(int epollfd, int port, bool decoder)
+int start_listener(int port, bool decoder)
 {
-    struct sockaddr_in addr = {
-        .sin_family = AF_INET,
-        .sin_addr   = { htonl(INADDR_LOOPBACK) },
-    };
-    struct epoll_event evt = { .events = EPOLLIN };
-    srsd_t *tmp;
-    int sock;
-
-    addr.sin_port = htons(port);
-    sock = tcp_listen_nonblock((const struct sockaddr *)&addr, sizeof(addr));
-    if (sock < 0) {
-        return -1;
-    }
-
-    evt.data.ptr  = tmp = srsd_new();
-    tmp->fd       = sock;
-    tmp->decoder  = decoder;
-    tmp->listener = true;
-    if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
-        UNIXERR("epoll_ctl");
-        return -1;
-    }
-    return 0;
+    return start_server(port, decoder ? srsd_new_decoder : srsd_new_encoder, NULL);
 }
 
 /* }}} */
 /* administrivia {{{ */
 
-static int main_initialize(void)
-{
-    openlog(DAEMON_NAME, LOG_PID, LOG_MAIL);
-    signal(SIGPIPE, SIG_IGN);
-    signal(SIGINT,  &common_sighandler);
-    signal(SIGTERM, &common_sighandler);
-    signal(SIGHUP,  &common_sighandler);
-    signal(SIGSEGV, &common_sighandler);
-    syslog(LOG_INFO, "Starting...");
-    return 0;
-}
-
-static void main_shutdown(void)
-{
-    closelog();
-}
-
-module_init(main_initialize);
-module_exit(main_shutdown);
 
 void usage(void)
 {
@@ -224,116 +186,12 @@ void usage(void)
           "                 (default: "STR(DEFAULT_DECODER_PORT)")\n"
           "    -p <pidfile> file to write our pid to\n"
           "    -u           unsafe mode: don't drop privilegies\n"
+          "    -f           stay in foreground\n"
          , stderr);
 }
 
 /* }}} */
 
-int main_loop(srs_t *srs, const char *domain, int port_enc, int port_dec)
-{
-    int exitcode = EXIT_SUCCESS;
-    int epollfd = epoll_create(128);
-
-    if (epollfd < 0) {
-        UNIXERR("epoll_create");
-        exitcode = EXIT_FAILURE;
-        goto error;
-    }
-
-    if (start_listener(epollfd, port_enc, false) < 0)
-        return EXIT_FAILURE;
-    if (start_listener(epollfd, port_dec, true) < 0)
-        return EXIT_FAILURE;
-
-    while (!sigint) {
-        struct epoll_event evts[1024];
-        int n;
-
-        n = epoll_wait(epollfd, evts, countof(evts), -1);
-        if (n < 0) {
-            if (errno != EAGAIN && errno != EINTR) {
-                UNIXERR("epoll_wait");
-                exitcode = EXIT_FAILURE;
-                break;
-            }
-            continue;
-        }
-
-        while (--n >= 0) {
-            srsd_t *srsd = evts[n].data.ptr;
-
-            if (srsd->listener) {
-                struct epoll_event evt = { .events = EPOLLIN };
-                srsd_t *tmp;
-                int sock;
-
-                sock = accept_nonblock(srsd->fd);
-                if (sock < 0) {
-                    UNIXERR("accept");
-                    continue;
-                }
-
-                evt.data.ptr = tmp = srsd_new();
-                tmp->decoder = srsd->decoder;
-                tmp->fd      = sock;
-                if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock, &evt) < 0) {
-                    UNIXERR("epoll_ctl");
-                    srsd_delete(&tmp);
-                    close(sock);
-                }
-                continue;
-            }
-
-            if (evts[n].events & EPOLLIN) {
-                int res = buffer_read(&srsd->ibuf, srsd->fd, -1);
-
-                if ((res < 0 && errno != EINTR && errno != EAGAIN)
-                ||  res == 0)
-                {
-                    srsd_delete(&srsd);
-                    continue;
-                }
-
-                if (process_srs(srs, domain, srsd) < 0) {
-                    srsd_delete(&srsd);
-                    continue;
-                }
-            }
-
-            if ((evts[n].events & EPOLLOUT) && srsd->obuf.len) {
-                int res = write(srsd->fd, srsd->obuf.data, srsd->obuf.len);
-
-                if (res < 0 && errno != EINTR && errno != EAGAIN) {
-                    srsd_delete(&srsd);
-                    continue;
-                }
-
-                if (res > 0) {
-                    buffer_consume(&srsd->obuf, res);
-                }
-            }
-
-            if (srsd->watchwr == !srsd->obuf.len) {
-                struct epoll_event evt = {
-                    .events   = EPOLLIN | (srsd->obuf.len ? EPOLLOUT : 0),
-                    .data.ptr = srsd,
-                };
-                if (epoll_ctl(epollfd, EPOLL_CTL_MOD, srsd->fd, &evt) < 0) {
-                    UNIXERR("epoll_ctl");
-                    srsd_delete(&srsd);
-                    continue;
-                }
-                srsd->watchwr = srsd->obuf.len != 0;
-            }
-        }
-    }
-
-    close(epollfd);
-
-  error:
-    return exitcode;
-}
-
 static srs_t *srs_read_secrets(const char *sfile)
 {
     srs_t *srs;
@@ -378,25 +236,21 @@ static srs_t *srs_read_secrets(const char *sfile)
 int main(int argc, char *argv[])
 {
     bool unsafe  = false;
+    bool daemonize = true;
     int port_enc = DEFAULT_ENCODER_PORT;
     int port_dec = DEFAULT_DECODER_PORT;
     const char *pidfile = NULL;
 
-    FILE *f = NULL;
-    int res;
     srs_t *srs;
 
-    if (atexit(common_shutdown)) {
-        fputs("Cannot hook my atexit function, quitting !\n", stderr);
-        return EXIT_FAILURE;
-    }
-    common_initialize();
-
-    for (int c = 0; (c = getopt(argc, argv, "he:d:p:u")) >= 0; ) {
+    for (int c = 0; (c = getopt(argc, argv, "hfu" "e:d:p:")) >= 0; ) {
         switch (c) {
           case 'e':
             port_enc = atoi(optarg);
             break;
+          case 'f':
+            daemonize = false;
+            break;
           case 'd':
             port_dec = atoi(optarg);
             break;
@@ -422,38 +276,21 @@ int main(int argc, char *argv[])
         return EXIT_FAILURE;
     }
 
-    if (pidfile) {
-        f = fopen(pidfile, "w");
-        if (!f) {
-            syslog(LOG_CRIT, "unable to write pidfile %s", pidfile);
-        }
-        fprintf(f, "%d\n", getpid());
-        fflush(f);
-    }
-
-    if (!unsafe && drop_privileges(RUNAS_USER, RUNAS_GROUP) < 0) {
-        syslog(LOG_CRIT, "unable to drop privileges");
-        return EXIT_FAILURE;
-    }
-
-    if (daemon_detach() < 0) {
-        syslog(LOG_CRIT, "unable to fork");
+    if (common_setup(pidfile, unsafe, RUNAS_USER, RUNAS_GROUP, daemonize)
+          != EXIT_SUCCESS) {
         return EXIT_FAILURE;
     }
-
-    if (f) {
-        rewind(f);
-        ftruncate(fileno(f), 0);
-        fprintf(f, "%d\n", getpid());
-        fflush(f);
-    }
-    res = main_loop(srs, argv[optind], port_enc, port_dec);
-    if (f) {
-        rewind(f);
-        ftruncate(fileno(f), 0);
-        fclose(f);
-        f = NULL;
+    {
+      srs_config_t config = {
+        .srs    = srs,
+        .domain = argv[optind]
+      };
+
+      if (start_listener(port_enc, false) < 0)
+          return EXIT_FAILURE;
+      if (start_listener(port_dec, true) < 0)
+          return EXIT_FAILURE;
+
+      return server_loop(srsd_stater, NULL, process_srs, &config);
     }
-    syslog(LOG_INFO, "Stopping...");
-    return res;
 }