diff --git a/src/client/sock.c b/src/client/sock.c
index 8602532da9eea9d4575d85d52fdc82cc34976c91..b1efe58e189700fd46c7b77b46c55278c39a150d 100644 (file)
--- a/src/client/sock.c
+++ b/src/client/sock.c
# include "config.h"
#endif /* HAVE_CONFIG_H */
+#include "sysdb.h"
#include "client/sock.h"
#include "utils/error.h"
#include "utils/strbuf.h"
#include "utils/proto.h"
+#include "utils/os.h"
+#include "utils/ssl.h"
#include <arpa/inet.h>
+#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <sys/socket.h>
#include <sys/un.h>
+#include <netdb.h>
+
/*
* private data types
*/
struct sdb_client {
char *address;
int fd;
- _Bool eof;
+ bool eof;
+
+ /* optional SSL settings */
+ sdb_ssl_options_t ssl_opts;
+ sdb_ssl_client_t *ssl;
+ sdb_ssl_session_t *ssl_session;
+
+ ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
+ ssize_t (*write)(sdb_client_t *, const void *, size_t);
};
/*
* private helper functions
*/
+static ssize_t
+ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
+{
+ char tmp[n];
+ ssize_t ret;
+
+ ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
+ if (ret <= 0)
+ return ret;
+
+ sdb_strbuf_memappend(buf, tmp, ret);
+ return ret;
+} /* ssl_read */
+
+static ssize_t
+ssl_write(sdb_client_t *client, const void *buf, size_t n)
+{
+ return sdb_ssl_session_write(client->ssl_session, buf, n);
+} /* ssl_write */
+
+static ssize_t
+client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
+{
+ return sdb_strbuf_read(buf, client->fd, n);
+} /* client_read */
+
+static ssize_t
+client_write(sdb_client_t *client, const void *buf, size_t n)
+{
+ return sdb_write(client->fd, n, buf);
+} /* client_write */
+
static int
connect_unixsock(sdb_client_t *client, const char *address)
{
client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
if (client->fd < 0) {
char errbuf[1024];
- sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
+ sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
sdb_strerror(errno, errbuf, sizeof(errbuf)));
return -1;
}
if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
char errbuf[1024];
sdb_client_close(client);
- sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
+ sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': %s",
sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
return -1;
}
return client->fd;
} /* connect_unixsock */
+static int
+connect_tcp(sdb_client_t *client, const char *address)
+{
+ char host[SDB_MAX(strlen("localhost"), (address ? strlen(address) : 0)) + 1];
+ struct addrinfo *ai, *ai_list = NULL;
+ char *peer, *tmp;
+ int status;
+
+ if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
+ sdb_log(SDB_LOG_ERR, "client: Failed to resolve '%s': %s",
+ address, gai_strerror(status));
+ return -1;
+ }
+
+ for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
+ client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
+ if (client->fd < 0) {
+ char errbuf[1024];
+ sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
+ sdb_strerror(errno, errbuf, sizeof(errbuf)));
+ continue;
+ }
+
+ if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
+ char h[1024], p[32], errbuf[1024];
+ sdb_client_close(client);
+ getnameinfo(ai->ai_addr, ai->ai_addrlen, h, sizeof(h),
+ p, sizeof(p), NI_NUMERICHOST | NI_NUMERICSERV);
+ sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s:%s': %s",
+ h, p, sdb_strerror(errno, errbuf, sizeof(errbuf)));
+ continue;
+ }
+ break;
+ }
+ freeaddrinfo(ai_list);
+
+ if (client->fd < 0)
+ return -1;
+
+ client->ssl = sdb_ssl_client_create(&client->ssl_opts);
+ if (! client->ssl) {
+ sdb_client_close(client);
+ return -1;
+ }
+ client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
+ if (! client->ssl_session) {
+ sdb_client_close(client);
+ return -1;
+ }
+
+ strncpy(host, address, sizeof(host));
+ if ((tmp = strrchr(host, (int)':')))
+ *tmp = '\0';
+ if (! host[0])
+ strncpy(host, "localhost", sizeof(host));
+ peer = sdb_ssl_session_peer(client->ssl_session);
+ if ((! peer) || strcasecmp(peer, host)) {
+ /* TODO: also check alt-name */
+ sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': "
+ "peer name '%s' does not match host address",
+ address, peer);
+ sdb_client_close(client);
+ if (peer)
+ free(peer);
+ return -1;
+ }
+ free(peer);
+
+ client->read = ssl_read;
+ client->write = ssl_write;
+ return client->fd;
+} /* connect_tcp */
+
/*
* public API
*/
client = malloc(sizeof(*client));
if (! client) {
- sdb_log(SDB_LOG_ERR, "Out of memory");
+ sdb_log(SDB_LOG_ERR, "client: Out of memory");
return NULL;
}
memset(client, 0, sizeof(*client));
client->fd = -1;
client->eof = 1;
+ client->ssl = NULL;
+ client->read = client_read;
+ client->write = client_write;
+
client->address = strdup(address);
if (! client->address) {
sdb_client_destroy(client);
- sdb_log(SDB_LOG_ERR, "Out of memory");
+ sdb_log(SDB_LOG_ERR, "client: Out of memory");
return NULL;
}
free(client->address);
client->address = NULL;
+ sdb_ssl_free_options(&client->ssl_opts);
+
free(client);
} /* sdb_client_destroy */
+int
+sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
+{
+ int ret = 0;
+
+ if ((! client) || (! opts))
+ return -1;
+
+ sdb_ssl_free_options(&client->ssl_opts);
+
+ if (opts->ca_file) {
+ client->ssl_opts.ca_file = strdup(opts->ca_file);
+ if (! client->ssl_opts.ca_file)
+ ret = -1;
+ }
+ if (opts->key_file) {
+ client->ssl_opts.key_file = strdup(opts->key_file);
+ if (! client->ssl_opts.key_file)
+ ret = -1;
+ }
+ if (opts->cert_file) {
+ client->ssl_opts.cert_file = strdup(opts->cert_file);
+ if (! client->ssl_opts.cert_file)
+ ret = -1;
+ }
+ if (opts->crl_file) {
+ client->ssl_opts.crl_file = strdup(opts->crl_file);
+ if (! client->ssl_opts.crl_file)
+ ret = -1;
+ }
+
+ if (ret)
+ sdb_ssl_free_options(&client->ssl_opts);
+ return ret;
+} /* sdb_client_set_ssl_options */
+
int
sdb_client_connect(sdb_client_t *client, const char *username)
{
if (client->fd >= 0)
return -1;
- if (!strncasecmp(client->address, "unix:", strlen("unix:")))
- connect_unixsock(client, client->address + strlen("unix:"));
- else if (*client->address == '/')
+ if (*client->address == '/')
connect_unixsock(client, client->address);
- else {
- sdb_log(SDB_LOG_ERR, "Unknown address type: %s", client->address);
- return -1;
- }
+ else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
+ connect_unixsock(client, client->address + strlen("unix:"));
+ else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
+ connect_tcp(client, client->address + strlen("tcp:"));
+ else
+ connect_tcp(client, client->address);
if (client->fd < 0)
return -1;
if (! username)
username = "";
- status = sdb_client_send(client, SDB_CONNECTION_STARTUP,
- (uint32_t)strlen(username), username);
- if (status < 0) {
- char errbuf[1024];
- sdb_client_close(client);
- sdb_log(SDB_LOG_ERR, "Failed to send STARTUP message to server: %s",
- sdb_strerror(errno, errbuf, sizeof(errbuf)));
- return (int)status;
- }
-
buf = sdb_strbuf_create(64);
rstatus = 0;
- status = sdb_client_recv(client, &rstatus, buf);
- if ((status > 0) && (rstatus == SDB_CONNECTION_OK)) {
+ status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
+ (uint32_t)strlen(username), username, &rstatus, buf);
+ if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
sdb_strbuf_destroy(buf);
return 0;
}
if (status < 0) {
- char errbuf[1024];
- sdb_log(SDB_LOG_ERR, "Failed to receive server response: %s",
- sdb_strerror(errno, errbuf, sizeof(errbuf)));
+ sdb_log(SDB_LOG_ERR, "client: %s", sdb_strbuf_string(buf));
+ sdb_client_close(client);
+ sdb_strbuf_destroy(buf);
+ return (int)status;
}
- else if (client->eof)
- sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
+ if (client->eof)
+ sdb_log(SDB_LOG_ERR, "client: Encountered end-of-file while waiting "
"for server response");
if (rstatus == SDB_CONNECTION_ERROR) {
- sdb_log(SDB_LOG_ERR, "Access denied for user '%s'", username);
+ sdb_log(SDB_LOG_ERR, "client: Access denied for user '%s': %s",
+ username, sdb_strbuf_string(buf));
status = -((int)rstatus);
}
else if (rstatus != SDB_CONNECTION_OK) {
- sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
- "(status %d) during startup", (int)rstatus);
+ sdb_log(SDB_LOG_ERR, "client: Received unsupported authentication "
+ "request (status %d) during startup", (int)rstatus);
status = -((int)rstatus);
}
if (! client)
return;
+ if (client->ssl_session) {
+ sdb_ssl_session_destroy(client->ssl_session);
+ client->ssl_session = NULL;
+ }
+ if (client->ssl) {
+ sdb_ssl_client_destroy(client->ssl);
+ client->ssl = NULL;
+ }
+
close(client->fd);
client->fd = -1;
client->eof = 1;
} /* sdb_client_close */
+ssize_t
+sdb_client_rpc(sdb_client_t *client,
+ uint32_t cmd, uint32_t msg_len, const char *msg,
+ uint32_t *code, sdb_strbuf_t *buf)
+{
+ uint32_t rcode = 0;
+ ssize_t status;
+
+ if (! buf)
+ return -1;
+
+ if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
+ char errbuf[1024];
+ sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
+ SDB_CONN_MSGTYPE_TO_STRING(cmd),
+ sdb_strerror(errno, errbuf, sizeof(errbuf)));
+ if (code)
+ *code = SDB_CONNECTION_ERROR;
+ return -1;
+ }
+
+ while (42) {
+ size_t offset = sdb_strbuf_len(buf);
+
+ status = sdb_client_recv(client, &rcode, buf);
+ if (status < 0) {
+ char errbuf[1024];
+ sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
+ sdb_strerror(errno, errbuf, sizeof(errbuf)));
+ if (code)
+ *code = SDB_CONNECTION_ERROR;
+ return status;
+ }
+
+ if (rcode == SDB_CONNECTION_LOG) {
+ uint32_t prio = 0;
+ if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
+ sdb_log(SDB_LOG_WARNING, "client: Received a LOG message "
+ "with invalid or missing priority");
+ prio = (uint32_t)SDB_LOG_ERR;
+ }
+ sdb_log((int)prio, "client: %s", sdb_strbuf_string(buf) + offset);
+ sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
+ continue;
+ }
+ break;
+ }
+
+ if (code)
+ *code = rcode;
+ return status;
+} /* sdb_client_rpc */
+
ssize_t
sdb_client_send(sdb_client_t *client,
uint32_t cmd, uint32_t msg_len, const char *msg)
{
+ char buf[2 * sizeof(uint32_t) + msg_len];
+
if ((! client) || (! client->fd))
return -1;
+ if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
+ return -1;
- return sdb_proto_send_msg(client->fd, cmd, msg_len, msg);
+ return client->write(client, buf, sizeof(buf));
} /* sdb_client_send */
ssize_t
while (42) {
ssize_t status;
- if (sdb_proto_select(client->fd, SDB_PROTO_SELECTIN))
- return -1;
-
errno = 0;
- status = sdb_strbuf_read(buf, client->fd, req);
+ status = client->read(client, buf, req);
if (status < 0) {
if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
continue;
continue;
if (rstatus == UINT32_MAX) {
+ const char *str = sdb_strbuf_string(buf) + data_offset;
+ size_t len = sdb_strbuf_len(buf) - data_offset;
+ ssize_t n;
+
/* retrieve status and data len */
- rstatus = sdb_proto_get_int(buf, data_offset);
- rlen = sdb_proto_get_int(buf, data_offset + sizeof(rstatus));
+ assert(len >= 2 * sizeof(uint32_t));
+ n = sdb_proto_unmarshal_int32(str, len, &rstatus);
+ str += n; len -= (size_t)n;
+ sdb_proto_unmarshal_int32(str, len, &rlen);
if (! rlen)
break;
return (ssize_t)total;
} /* sdb_client_recv */
-_Bool
+bool
sdb_client_eof(sdb_client_t *client)
{
if ((! client) || (client->fd < 0))