Code

store_json: Base the memstore emitter on the store-writer API.
[sysdb.git] / src / frontend / sock.c
index 5279cbae84fd44d1c05623537a1bbf468860653b..1f3ec5fd74b70ee903a68ce8678bd2fb69c4a87b 100644 (file)
@@ -38,6 +38,7 @@
 #include "utils/error.h"
 #include "utils/llist.h"
 #include "utils/os.h"
+#include "utils/ssl.h"
 #include "utils/strbuf.h"
 
 #include <assert.h>
@@ -78,6 +79,11 @@ typedef struct {
        char *address;
        int   type;
 
+       /* optional SSL settings */
+       sdb_ssl_options_t ssl_opts;
+       sdb_ssl_server_t *ssl;
+
+       /* listener configuration */
        int sock_fd;
        int (*setup)(sdb_conn_t *, void *);
 } listener_t;
@@ -108,6 +114,30 @@ struct sdb_fe_socket {
        sdb_channel_t *chan;
 };
 
+/*
+ * SSL helper functions
+ */
+
+static ssize_t
+ssl_read(sdb_conn_t *conn, size_t n)
+{
+       char buf[n];
+       ssize_t ret;
+
+       ret = sdb_ssl_session_read(conn->ssl_session, buf, n);
+       if (ret <= 0)
+               return ret;
+
+       sdb_strbuf_memappend(conn->buf, buf, ret);
+       return ret;
+} /* ssl_read */
+
+static ssize_t
+ssl_write(sdb_conn_t *conn, const void *buf, size_t n)
+{
+       return sdb_ssl_session_write(conn->ssl_session, buf, n);
+} /* ssl_write */
+
 /*
  * connection management functions
  */
@@ -225,6 +255,34 @@ close_unixsock(listener_t *listener)
        unlink(listener->address);
 } /* close_unixsock */
 
+static int
+finish_tcp(sdb_conn_t *conn)
+{
+       if (! conn->ssl_session)
+               return 0;
+
+       sdb_ssl_session_destroy(conn->ssl_session);
+       conn->ssl_session = NULL;
+       return 0;
+} /* finish_tcp */
+
+static int
+setup_tcp(sdb_conn_t *conn, void *user_data)
+{
+       listener_t *listener = user_data;
+
+       conn->ssl_session = sdb_ssl_server_accept(listener->ssl, conn->fd);
+       if (! conn->ssl_session)
+               return -1;
+
+       conn->username = sdb_ssl_session_peer(conn->ssl_session);
+
+       conn->finish = finish_tcp;
+       conn->read = ssl_read;
+       conn->write = ssl_write;
+       return 0;
+} /* setup_tcp */
+
 static int
 open_tcp(listener_t *listener)
 {
@@ -233,6 +291,10 @@ open_tcp(listener_t *listener)
 
        assert(listener);
 
+       listener->ssl = sdb_ssl_server_create(&listener->ssl_opts);
+       if (! listener->ssl)
+               return -1;
+
        if ((status = sdb_resolve(SDB_NET_TCP, listener->address, &ai_list))) {
                sdb_log(SDB_LOG_ERR, "frontend: Failed to resolve '%s': %s",
                                listener->address, gai_strerror(status));
@@ -277,6 +339,8 @@ open_tcp(listener_t *listener)
 
        if (listener->sock_fd < 0)
                return -1;
+
+       listener->setup = setup_tcp;
        return 0;
 } /* open_tcp */
 
@@ -285,6 +349,9 @@ close_tcp(listener_t *listener)
 {
        assert(listener);
 
+       sdb_ssl_server_destroy(listener->ssl);
+       listener->ssl = NULL;
+
        if (listener->sock_fd >= 0)
                close(listener->sock_fd);
        listener->sock_fd = -1;
@@ -326,6 +393,7 @@ listener_listen(listener_t *listener)
                                listener->address, sdb_strerror(errno, buf, sizeof(buf)));
                return -1;
        }
+       sdb_log(SDB_LOG_INFO, "frontend: Listening on %s", listener->address);
        return 0;
 } /* listener_listen */
 
@@ -366,7 +434,8 @@ get_type(const char *address)
                        return impl->type;
                }
        }
-       return -1;
+       /* don't report an error, this could be an IPv6 address */
+       return listener_impls[0].type;
 } /* get_type */
 
 static void
@@ -376,6 +445,7 @@ listener_destroy(listener_t *listener)
                return;
 
        listener_close(listener);
+       sdb_ssl_free_options(&listener->ssl_opts);
 
        if (listener->address)
                free(listener->address);
@@ -412,6 +482,7 @@ listener_create(sdb_fe_socket_t *sock, const char *address)
        if ((! strncmp(address, listener_impls[type].prefix, len))
                        && (address[len] == ':'))
                address += strlen(listener_impls[type].prefix) + 1;
+       memset(listener, 0, sizeof(*listener));
 
        listener->sock_fd = -1;
        listener->address = strdup(address);
@@ -424,12 +495,7 @@ listener_create(sdb_fe_socket_t *sock, const char *address)
        }
        listener->type = type;
        listener->setup = NULL;
-
-       if (listener_impls[type].open(listener)) {
-               /* prints error */
-               listener_destroy(listener);
-               return NULL;
-       }
+       listener->ssl = NULL;
 
        ++sock->listeners_num;
        return listener;
@@ -504,7 +570,11 @@ connection_handler(void *data)
                                        "connection %s to list of open connections",
                                        SDB_OBJ(conn)->name);
                }
-               write(sock->trigger[TRIGGER_WRITE], "", 1);
+               if (write(sock->trigger[TRIGGER_WRITE], "", 1) <= 0) {
+                       /* This shouldn't happen and it's not critical; in the worst cases
+                        * it slows us down. */
+                       sdb_log(SDB_LOG_WARNING, "frontend: Failed to trigger main loop");
+               }
 
                /* pass ownership back to list; or destroy in case of an error */
                sdb_object_deref(SDB_OBJ(conn));
@@ -519,7 +589,7 @@ connection_accept(sdb_fe_socket_t *sock, listener_t *listener)
        int status;
 
        obj = SDB_OBJ(sdb_connection_accept(listener->sock_fd,
-                               listener->setup, NULL));
+                               listener->setup, listener));
        if (! obj)
                return -1;
 
@@ -645,7 +715,8 @@ sdb_fe_sock_destroy(sdb_fe_socket_t *sock)
 } /* sdb_fe_sock_destroy */
 
 int
-sdb_fe_sock_add_listener(sdb_fe_socket_t *sock, const char *address)
+sdb_fe_sock_add_listener(sdb_fe_socket_t *sock, const char *address,
+               const sdb_ssl_options_t *opts)
 {
        listener_t *listener;
 
@@ -655,6 +726,44 @@ sdb_fe_sock_add_listener(sdb_fe_socket_t *sock, const char *address)
        listener = listener_create(sock, address);
        if (! listener)
                return -1;
+
+       if (opts) {
+               int ret = 0;
+
+               if (opts->ca_file) {
+                       listener->ssl_opts.ca_file = strdup(opts->ca_file);
+                       if (! listener->ssl_opts.ca_file)
+                               ret = -1;
+               }
+               if (opts->key_file) {
+                       listener->ssl_opts.key_file = strdup(opts->key_file);
+                       if (! listener->ssl_opts.key_file)
+                               ret = -1;
+               }
+               if (opts->cert_file) {
+                       listener->ssl_opts.cert_file = strdup(opts->cert_file);
+                       if (! listener->ssl_opts.cert_file)
+                               ret = -1;
+               }
+               if (opts->crl_file) {
+                       listener->ssl_opts.crl_file = strdup(opts->crl_file);
+                       if (! listener->ssl_opts.crl_file)
+                               ret = -1;
+               }
+
+               if (ret) {
+                       listener_destroy(listener);
+                       --sock->listeners_num;
+                       return ret;
+               }
+       }
+
+       if (listener_impls[listener->type].open(listener)) {
+               /* prints error */
+               listener_destroy(listener);
+               --sock->listeners_num;
+               return -1;
+       }
        return 0;
 } /* sdb_fe_sock_add_listener */