Code

store, plugin: Let the plugin module determine an objects backends.
[sysdb.git] / src / client / sock.c
index 805dc1d1ce5350da8d1aaced43d9bce53df1d9cd..b1efe58e189700fd46c7b77b46c55278c39a150d 100644 (file)
 #      include "config.h"
 #endif /* HAVE_CONFIG_H */
 
+#include "sysdb.h"
 #include "client/sock.h"
 #include "utils/error.h"
 #include "utils/strbuf.h"
 #include "utils/proto.h"
+#include "utils/os.h"
+#include "utils/ssl.h"
 
 #include <arpa/inet.h>
 
+#include <assert.h>
 #include <errno.h>
 #include <limits.h>
 
@@ -49,6 +53,8 @@
 #include <sys/socket.h>
 #include <sys/un.h>
 
+#include <netdb.h>
+
 /*
  * private data types
  */
 struct sdb_client {
        char *address;
        int   fd;
-       _Bool eof;
+       bool  eof;
+
+       /* optional SSL settings */
+       sdb_ssl_options_t ssl_opts;
+       sdb_ssl_client_t *ssl;
+       sdb_ssl_session_t *ssl_session;
+
+       ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
+       ssize_t (*write)(sdb_client_t *, const void *, size_t);
 };
 
 /*
  * private helper functions
  */
 
+static ssize_t
+ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
+{
+       char tmp[n];
+       ssize_t ret;
+
+       ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
+       if (ret <= 0)
+               return ret;
+
+       sdb_strbuf_memappend(buf, tmp, ret);
+       return ret;
+} /* ssl_read */
+
+static ssize_t
+ssl_write(sdb_client_t *client, const void *buf, size_t n)
+{
+       return sdb_ssl_session_write(client->ssl_session, buf, n);
+} /* ssl_write */
+
+static ssize_t
+client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
+{
+       return sdb_strbuf_read(buf, client->fd, n);
+} /* client_read */
+
+static ssize_t
+client_write(sdb_client_t *client, const void *buf, size_t n)
+{
+       return sdb_write(client->fd, n, buf);
+} /* client_write */
+
 static int
 connect_unixsock(sdb_client_t *client, const char *address)
 {
@@ -71,7 +117,7 @@ connect_unixsock(sdb_client_t *client, const char *address)
        client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
        if (client->fd < 0) {
                char errbuf[1024];
-               sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
+               sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
                                sdb_strerror(errno, errbuf, sizeof(errbuf)));
                return -1;
        }
@@ -83,13 +129,86 @@ connect_unixsock(sdb_client_t *client, const char *address)
        if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
                char errbuf[1024];
                sdb_client_close(client);
-               sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
+               sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': %s",
                                sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
                return -1;
        }
        return client->fd;
 } /* connect_unixsock */
 
+static int
+connect_tcp(sdb_client_t *client, const char *address)
+{
+       char host[SDB_MAX(strlen("localhost"), (address ? strlen(address) : 0)) + 1];
+       struct addrinfo *ai, *ai_list = NULL;
+       char *peer, *tmp;
+       int status;
+
+       if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
+               sdb_log(SDB_LOG_ERR, "client: Failed to resolve '%s': %s",
+                               address, gai_strerror(status));
+               return -1;
+       }
+
+       for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
+               client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
+               if (client->fd < 0) {
+                       char errbuf[1024];
+                       sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
+                                       sdb_strerror(errno, errbuf, sizeof(errbuf)));
+                       continue;
+               }
+
+               if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
+                       char h[1024], p[32], errbuf[1024];
+                       sdb_client_close(client);
+                       getnameinfo(ai->ai_addr, ai->ai_addrlen, h, sizeof(h),
+                                       p, sizeof(p), NI_NUMERICHOST | NI_NUMERICSERV);
+                       sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s:%s': %s",
+                                       h, p, sdb_strerror(errno, errbuf, sizeof(errbuf)));
+                       continue;
+               }
+               break;
+       }
+       freeaddrinfo(ai_list);
+
+       if (client->fd < 0)
+               return -1;
+
+       client->ssl = sdb_ssl_client_create(&client->ssl_opts);
+       if (! client->ssl) {
+               sdb_client_close(client);
+               return -1;
+       }
+       client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
+       if (! client->ssl_session) {
+               sdb_client_close(client);
+               return -1;
+       }
+
+       strncpy(host, address, sizeof(host));
+       if ((tmp = strrchr(host, (int)':')))
+               *tmp = '\0';
+       if (! host[0])
+               strncpy(host, "localhost", sizeof(host));
+       peer = sdb_ssl_session_peer(client->ssl_session);
+       if ((! peer) || strcasecmp(peer, host)) {
+               /* TODO: also check alt-name */
+               sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': "
+                               "peer name '%s' does not match host address",
+                               address, peer);
+               sdb_client_close(client);
+               if (peer)
+                       free(peer);
+               return -1;
+       }
+       free(peer);
+
+       client->read = ssl_read;
+       client->write = ssl_write;
+       return client->fd;
+} /* connect_tcp */
+
 /*
  * public API
  */
@@ -104,17 +223,21 @@ sdb_client_create(const char *address)
 
        client = malloc(sizeof(*client));
        if (! client) {
-               sdb_log(SDB_LOG_ERR, "Out of memory");
+               sdb_log(SDB_LOG_ERR, "client: Out of memory");
                return NULL;
        }
        memset(client, 0, sizeof(*client));
        client->fd = -1;
        client->eof = 1;
 
+       client->ssl = NULL;
+       client->read = client_read;
+       client->write = client_write;
+
        client->address = strdup(address);
        if (! client->address) {
                sdb_client_destroy(client);
-               sdb_log(SDB_LOG_ERR, "Out of memory");
+               sdb_log(SDB_LOG_ERR, "client: Out of memory");
                return NULL;
        }
 
@@ -133,9 +256,47 @@ sdb_client_destroy(sdb_client_t *client)
                free(client->address);
        client->address = NULL;
 
+       sdb_ssl_free_options(&client->ssl_opts);
+
        free(client);
 } /* sdb_client_destroy */
 
+int
+sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
+{
+       int ret = 0;
+
+       if ((! client) || (! opts))
+               return -1;
+
+       sdb_ssl_free_options(&client->ssl_opts);
+
+       if (opts->ca_file) {
+               client->ssl_opts.ca_file = strdup(opts->ca_file);
+               if (! client->ssl_opts.ca_file)
+                       ret = -1;
+       }
+       if (opts->key_file) {
+               client->ssl_opts.key_file = strdup(opts->key_file);
+               if (! client->ssl_opts.key_file)
+                       ret = -1;
+       }
+       if (opts->cert_file) {
+               client->ssl_opts.cert_file = strdup(opts->cert_file);
+               if (! client->ssl_opts.cert_file)
+                       ret = -1;
+       }
+       if (opts->crl_file) {
+               client->ssl_opts.crl_file = strdup(opts->crl_file);
+               if (! client->ssl_opts.crl_file)
+                       ret = -1;
+       }
+
+       if (ret)
+               sdb_ssl_free_options(&client->ssl_opts);
+       return ret;
+} /* sdb_client_set_ssl_options */
+
 int
 sdb_client_connect(sdb_client_t *client, const char *username)
 {
@@ -149,14 +310,14 @@ sdb_client_connect(sdb_client_t *client, const char *username)
        if (client->fd >= 0)
                return -1;
 
-       if (!strncasecmp(client->address, "unix:", strlen("unix:")))
-               connect_unixsock(client, client->address + strlen("unix:"));
-       else if (*client->address == '/')
+       if (*client->address == '/')
                connect_unixsock(client, client->address);
-       else {
-               sdb_log(SDB_LOG_ERR, "Unknown address type: %s", client->address);
-               return -1;
-       }
+       else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
+               connect_unixsock(client, client->address + strlen("unix:"));
+       else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
+               connect_tcp(client, client->address + strlen("tcp:"));
+       else
+               connect_tcp(client, client->address);
 
        if (client->fd < 0)
                return -1;
@@ -166,35 +327,33 @@ sdb_client_connect(sdb_client_t *client, const char *username)
        if (! username)
                username = "";
 
-       status = sdb_client_send(client, CONNECTION_STARTUP,
-                       (uint32_t)strlen(username), username);
-       if (status < 0) {
-               char errbuf[1024];
-               sdb_client_close(client);
-               sdb_log(SDB_LOG_ERR, "Failed to send STARTUP message to server: %s",
-                               sdb_strerror(errno, errbuf, sizeof(errbuf)));
-               return (int)status;
-       }
-
        buf = sdb_strbuf_create(64);
        rstatus = 0;
-       status = sdb_client_recv(client, &rstatus, buf);
-       if ((status > 0) && (rstatus == CONNECTION_OK)) {
+       status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
+                       (uint32_t)strlen(username), username, &rstatus, buf);
+       if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
                sdb_strbuf_destroy(buf);
                return 0;
        }
 
        if (status < 0) {
-               char errbuf[1024];
-               sdb_log(SDB_LOG_ERR, "Failed to receive server response: %s",
-                               sdb_strerror(errno, errbuf, sizeof(errbuf)));
+               sdb_log(SDB_LOG_ERR, "client: %s", sdb_strbuf_string(buf));
+               sdb_client_close(client);
+               sdb_strbuf_destroy(buf);
+               return (int)status;
        }
-       else if (client->eof)
-               sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
+       if (client->eof)
+               sdb_log(SDB_LOG_ERR, "client: Encountered end-of-file while waiting "
                                "for server response");
 
-       if (rstatus != CONNECTION_OK) {
-               sdb_log(SDB_LOG_ERR, "Access denied for user '%s'", username);
+       if (rstatus == SDB_CONNECTION_ERROR) {
+               sdb_log(SDB_LOG_ERR, "client: Access denied for user '%s': %s",
+                               username, sdb_strbuf_string(buf));
+               status = -((int)rstatus);
+       }
+       else if (rstatus != SDB_CONNECTION_OK) {
+               sdb_log(SDB_LOG_ERR, "client: Received unsupported authentication "
+                               "request (status %d) during startup", (int)rstatus);
                status = -((int)rstatus);
        }
 
@@ -211,25 +370,107 @@ sdb_client_sockfd(sdb_client_t *client)
        return client->fd;
 } /* sdb_client_sockfd */
 
+int
+sdb_client_shutdown(sdb_client_t *client, int how)
+{
+       if (! client) {
+               errno = ENOTSOCK;
+               return -1;
+       }
+
+       if (client->fd < 0) {
+               errno = EBADF;
+               return -1;
+       }
+
+       return shutdown(client->fd, how);
+} /* sdb_client_shutdown */
+
 void
 sdb_client_close(sdb_client_t *client)
 {
        if (! client)
                return;
 
+       if (client->ssl_session) {
+               sdb_ssl_session_destroy(client->ssl_session);
+               client->ssl_session = NULL;
+       }
+       if (client->ssl) {
+               sdb_ssl_client_destroy(client->ssl);
+               client->ssl = NULL;
+       }
+
        close(client->fd);
        client->fd = -1;
        client->eof = 1;
 } /* sdb_client_close */
 
+ssize_t
+sdb_client_rpc(sdb_client_t *client,
+               uint32_t cmd, uint32_t msg_len, const char *msg,
+               uint32_t *code, sdb_strbuf_t *buf)
+{
+       uint32_t rcode = 0;
+       ssize_t status;
+
+       if (! buf)
+               return -1;
+
+       if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
+               char errbuf[1024];
+               sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
+                               SDB_CONN_MSGTYPE_TO_STRING(cmd),
+                               sdb_strerror(errno, errbuf, sizeof(errbuf)));
+               if (code)
+                       *code = SDB_CONNECTION_ERROR;
+               return -1;
+       }
+
+       while (42) {
+               size_t offset = sdb_strbuf_len(buf);
+
+               status = sdb_client_recv(client, &rcode, buf);
+               if (status < 0) {
+                       char errbuf[1024];
+                       sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
+                                       sdb_strerror(errno, errbuf, sizeof(errbuf)));
+                       if (code)
+                               *code = SDB_CONNECTION_ERROR;
+                       return status;
+               }
+
+               if (rcode == SDB_CONNECTION_LOG) {
+                       uint32_t prio = 0;
+                       if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
+                               sdb_log(SDB_LOG_WARNING, "client: Received a LOG message "
+                                               "with invalid or missing priority");
+                               prio = (uint32_t)SDB_LOG_ERR;
+                       }
+                       sdb_log((int)prio, "client: %s", sdb_strbuf_string(buf) + offset);
+                       sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
+                       continue;
+               }
+               break;
+       }
+
+       if (code)
+               *code = rcode;
+       return status;
+} /* sdb_client_rpc */
+
 ssize_t
 sdb_client_send(sdb_client_t *client,
                uint32_t cmd, uint32_t msg_len, const char *msg)
 {
+       char buf[2 * sizeof(uint32_t) + msg_len];
+
        if ((! client) || (! client->fd))
                return -1;
+       if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
+               return -1;
 
-       return sdb_proto_send_msg(client->fd, cmd, msg_len, msg);
+       return client->write(client, buf, sizeof(buf));
 } /* sdb_client_send */
 
 ssize_t
@@ -255,11 +496,8 @@ sdb_client_recv(sdb_client_t *client,
        while (42) {
                ssize_t status;
 
-               if (sdb_proto_select(client->fd, SDB_PROTO_SELECTIN))
-                       return -1;
-
                errno = 0;
-               status = sdb_strbuf_read(buf, client->fd, req);
+               status = client->read(client, buf, req);
                if (status < 0) {
                        if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
                                continue;
@@ -276,9 +514,15 @@ sdb_client_recv(sdb_client_t *client,
                        continue;
 
                if (rstatus == UINT32_MAX) {
+                       const char *str = sdb_strbuf_string(buf) + data_offset;
+                       size_t len = sdb_strbuf_len(buf) - data_offset;
+                       ssize_t n;
+
                        /* retrieve status and data len */
-                       rstatus = sdb_proto_get_int(buf, data_offset);
-                       rlen = sdb_proto_get_int(buf, data_offset + sizeof(rstatus));
+                       assert(len >= 2 * sizeof(uint32_t));
+                       n = sdb_proto_unmarshal_int32(str, len, &rstatus);
+                       str += n; len -= (size_t)n;
+                       sdb_proto_unmarshal_int32(str, len, &rlen);
 
                        if (! rlen)
                                break;
@@ -306,7 +550,7 @@ sdb_client_recv(sdb_client_t *client,
        return (ssize_t)total;
 } /* sdb_client_recv */
 
-_Bool
+bool
 sdb_client_eof(sdb_client_t *client)
 {
        if ((! client) || (client->fd < 0))