Code

client: Make SSL options configurable.
[sysdb.git] / src / client / sock.c
1 /*
2  * SysDB - src/client/sock.c
3  * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
28 #if HAVE_CONFIG_H
29 #       include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "client/sock.h"
33 #include "utils/error.h"
34 #include "utils/strbuf.h"
35 #include "utils/proto.h"
36 #include "utils/os.h"
37 #include "utils/ssl.h"
39 #include <arpa/inet.h>
41 #include <assert.h>
42 #include <errno.h>
43 #include <limits.h>
45 #include <stdlib.h>
47 #include <string.h>
48 #include <strings.h>
50 #include <unistd.h>
52 #include <sys/socket.h>
53 #include <sys/un.h>
55 #include <netdb.h>
57 /*
58  * private data types
59  */
61 struct sdb_client {
62         char *address;
63         int   fd;
64         bool  eof;
66         /* optional SSL settings */
67         sdb_ssl_options_t ssl_opts;
68         sdb_ssl_client_t *ssl;
69         sdb_ssl_session_t *ssl_session;
71         ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
72         ssize_t (*write)(sdb_client_t *, const void *, size_t);
73 };
75 /*
76  * private helper functions
77  */
79 static ssize_t
80 ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
81 {
82         char tmp[n];
83         ssize_t ret;
85         ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
86         if (ret <= 0)
87                 return ret;
89         sdb_strbuf_memappend(buf, tmp, ret);
90         return ret;
91 } /* ssl_read */
93 static ssize_t
94 ssl_write(sdb_client_t *client, const void *buf, size_t n)
95 {
96         return sdb_ssl_session_write(client->ssl_session, buf, n);
97 } /* ssl_write */
99 static ssize_t
100 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
102         return sdb_strbuf_read(buf, client->fd, n);
103 } /* client_read */
105 static ssize_t
106 client_write(sdb_client_t *client, const void *buf, size_t n)
108         return sdb_write(client->fd, n, buf);
109 } /* client_write */
111 static int
112 connect_unixsock(sdb_client_t *client, const char *address)
114         struct sockaddr_un sa;
116         client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
117         if (client->fd < 0) {
118                 char errbuf[1024];
119                 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
120                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
121                 return -1;
122         }
124         sa.sun_family = AF_UNIX;
125         strncpy(sa.sun_path, address, sizeof(sa.sun_path));
126         sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
128         if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
129                 char errbuf[1024];
130                 sdb_client_close(client);
131                 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
132                                 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
133                 return -1;
134         }
135         return client->fd;
136 } /* connect_unixsock */
138 static int
139 connect_tcp(sdb_client_t *client, const char *address)
141         struct addrinfo *ai, *ai_list = NULL;
142         int status;
144         if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
145                 sdb_log(SDB_LOG_ERR, "Failed to resolve '%s': %s",
146                                 address, gai_strerror(status));
147                 return -1;
148         }
150         for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
151                 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
152                 if (client->fd < 0) {
153                         char errbuf[1024];
154                         sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
155                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
156                         continue;
157                 }
159                 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
160                         char host[1024], port[32], errbuf[1024];
161                         sdb_client_close(client);
162                         getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host),
163                                         port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV);
164                         sdb_log(SDB_LOG_ERR, "Failed to connect to '%s:%s': %s",
165                                         host, port, sdb_strerror(errno, errbuf, sizeof(errbuf)));
166                         continue;
167                 }
168                 break;
169         }
170         freeaddrinfo(ai_list);
172         if (client->fd < 0)
173                 return -1;
175         client->ssl = sdb_ssl_client_create(&client->ssl_opts);
176         if (! client->ssl) {
177                 sdb_client_close(client);
178                 return -1;
179         }
180         client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
181         if (! client->ssl_session) {
182                 sdb_client_close(client);
183                 return -1;
184         }
186         client->read = ssl_read;
187         client->write = ssl_write;
188         return client->fd;
189 } /* connect_tcp */
191 static void
192 free_ssl_options(sdb_ssl_options_t *opts)
194         if (opts->ca_file)
195                 free(opts->ca_file);
196         if (opts->key_file)
197                 free(opts->key_file);
198         if (opts->cert_file)
199                 free(opts->cert_file);
200         if (opts->crl_file)
201                 free(opts->crl_file);
202         opts->ca_file = opts->key_file = opts->cert_file = opts->crl_file = NULL;
203 } /* free_ssl_options */
205 /*
206  * public API
207  */
209 sdb_client_t *
210 sdb_client_create(const char *address)
212         sdb_client_t *client;
214         if (! address)
215                 return NULL;
217         client = malloc(sizeof(*client));
218         if (! client) {
219                 sdb_log(SDB_LOG_ERR, "Out of memory");
220                 return NULL;
221         }
222         memset(client, 0, sizeof(*client));
223         client->fd = -1;
224         client->eof = 1;
226         client->ssl = NULL;
227         client->read = client_read;
228         client->write = client_write;
230         client->address = strdup(address);
231         if (! client->address) {
232                 sdb_client_destroy(client);
233                 sdb_log(SDB_LOG_ERR, "Out of memory");
234                 return NULL;
235         }
237         return client;
238 } /* sdb_client_create */
240 void
241 sdb_client_destroy(sdb_client_t *client)
243         if (! client)
244                 return;
246         sdb_client_close(client);
248         if (client->address)
249                 free(client->address);
250         client->address = NULL;
252         free_ssl_options(&client->ssl_opts);
254         free(client);
255 } /* sdb_client_destroy */
257 int
258 sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
260         int ret = 0;
262         if ((! client) || (! opts))
263                 return -1;
265         free_ssl_options(&client->ssl_opts);
267         if (opts->ca_file) {
268                 client->ssl_opts.ca_file = strdup(opts->ca_file);
269                 if (! client->ssl_opts.ca_file)
270                         ret = -1;
271         }
272         if (opts->key_file) {
273                 client->ssl_opts.key_file = strdup(opts->key_file);
274                 if (! client->ssl_opts.key_file)
275                         ret = -1;
276         }
277         if (opts->cert_file) {
278                 client->ssl_opts.cert_file = strdup(opts->cert_file);
279                 if (! client->ssl_opts.cert_file)
280                         ret = -1;
281         }
282         if (opts->crl_file) {
283                 client->ssl_opts.crl_file = strdup(opts->crl_file);
284                 if (! client->ssl_opts.crl_file)
285                         ret = -1;
286         }
288         if (ret)
289                 free_ssl_options(&client->ssl_opts);
290         return ret;
291 } /* sdb_client_set_ssl_options */
293 int
294 sdb_client_connect(sdb_client_t *client, const char *username)
296         sdb_strbuf_t *buf;
297         ssize_t status;
298         uint32_t rstatus;
300         if ((! client) || (! client->address))
301                 return -1;
303         if (client->fd >= 0)
304                 return -1;
306         if (*client->address == '/')
307                 connect_unixsock(client, client->address);
308         else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
309                 connect_unixsock(client, client->address + strlen("unix:"));
310         else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
311                 connect_tcp(client, client->address + strlen("tcp:"));
312         else
313                 connect_tcp(client, client->address);
315         if (client->fd < 0)
316                 return -1;
317         client->eof = 0;
319         /* XXX */
320         if (! username)
321                 username = "";
323         buf = sdb_strbuf_create(64);
324         rstatus = 0;
325         status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
326                         (uint32_t)strlen(username), username, &rstatus, buf);
327         if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
328                 sdb_strbuf_destroy(buf);
329                 return 0;
330         }
332         if (status < 0) {
333                 sdb_log(SDB_LOG_ERR, "%s", sdb_strbuf_string(buf));
334                 sdb_client_close(client);
335                 sdb_strbuf_destroy(buf);
336                 return (int)status;
337         }
338         if (client->eof)
339                 sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
340                                 "for server response");
342         if (rstatus == SDB_CONNECTION_ERROR) {
343                 sdb_log(SDB_LOG_ERR, "Access denied for user '%s': %s",
344                                 username, sdb_strbuf_string(buf));
345                 status = -((int)rstatus);
346         }
347         else if (rstatus != SDB_CONNECTION_OK) {
348                 sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
349                                 "(status %d) during startup", (int)rstatus);
350                 status = -((int)rstatus);
351         }
353         sdb_client_close(client);
354         sdb_strbuf_destroy(buf);
355         return (int)status;
356 } /* sdb_client_connect */
358 int
359 sdb_client_sockfd(sdb_client_t *client)
361         if (! client)
362                 return -1;
363         return client->fd;
364 } /* sdb_client_sockfd */
366 int
367 sdb_client_shutdown(sdb_client_t *client, int how)
369         if (! client) {
370                 errno = ENOTSOCK;
371                 return -1;
372         }
374         if (client->fd < 0) {
375                 errno = EBADF;
376                 return -1;
377         }
379         return shutdown(client->fd, how);
380 } /* sdb_client_shutdown */
382 void
383 sdb_client_close(sdb_client_t *client)
385         if (! client)
386                 return;
388         if (client->ssl_session) {
389                 sdb_ssl_session_destroy(client->ssl_session);
390                 client->ssl_session = NULL;
391         }
392         if (client->ssl) {
393                 sdb_ssl_client_destroy(client->ssl);
394                 client->ssl = NULL;
395         }
397         close(client->fd);
398         client->fd = -1;
399         client->eof = 1;
400 } /* sdb_client_close */
402 ssize_t
403 sdb_client_rpc(sdb_client_t *client,
404                 uint32_t cmd, uint32_t msg_len, const char *msg,
405                 uint32_t *code, sdb_strbuf_t *buf)
407         uint32_t rcode = 0;
408         ssize_t status;
410         if (! buf)
411                 return -1;
413         if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
414                 char errbuf[1024];
415                 sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
416                                 SDB_CONN_MSGTYPE_TO_STRING(cmd),
417                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
418                 if (code)
419                         *code = SDB_CONNECTION_ERROR;
420                 return -1;
421         }
423         while (42) {
424                 size_t offset = sdb_strbuf_len(buf);
426                 status = sdb_client_recv(client, &rcode, buf);
427                 if (status < 0) {
428                         char errbuf[1024];
429                         sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
430                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
431                         if (code)
432                                 *code = SDB_CONNECTION_ERROR;
433                         return status;
434                 }
436                 if (rcode == SDB_CONNECTION_LOG) {
437                         uint32_t prio = 0;
438                         if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
439                                 sdb_log(SDB_LOG_WARNING, "Received a LOG message "
440                                                 "with invalid or missing priority");
441                                 prio = (uint32_t)SDB_LOG_ERR;
442                         }
443                         sdb_log((int)prio, "%s", sdb_strbuf_string(buf) + offset);
444                         sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
445                         continue;
446                 }
447                 break;
448         }
450         if (code)
451                 *code = rcode;
452         return status;
453 } /* sdb_client_rpc */
455 ssize_t
456 sdb_client_send(sdb_client_t *client,
457                 uint32_t cmd, uint32_t msg_len, const char *msg)
459         char buf[2 * sizeof(uint32_t) + msg_len];
461         if ((! client) || (! client->fd))
462                 return -1;
463         if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
464                 return -1;
466         return client->write(client, buf, sizeof(buf));
467 } /* sdb_client_send */
469 ssize_t
470 sdb_client_recv(sdb_client_t *client,
471                 uint32_t *code, sdb_strbuf_t *buf)
473         uint32_t rstatus = UINT32_MAX;
474         uint32_t rlen = UINT32_MAX;
476         size_t total = 0;
477         size_t req = 2 * sizeof(uint32_t);
479         size_t data_offset = sdb_strbuf_len(buf);
481         if (code)
482                 *code = UINT32_MAX;
484         if ((! client) || (! client->fd) || (! buf)) {
485                 errno = EBADF;
486                 return -1;
487         }
489         while (42) {
490                 ssize_t status;
492                 errno = 0;
493                 status = client->read(client, buf, req);
494                 if (status < 0) {
495                         if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
496                                 continue;
497                         return status;
498                 }
499                 else if (! status) {
500                         client->eof = 1;
501                         break;
502                 }
504                 total += (size_t)status;
506                 if (total != req)
507                         continue;
509                 if (rstatus == UINT32_MAX) {
510                         const char *str = sdb_strbuf_string(buf) + data_offset;
511                         size_t len = sdb_strbuf_len(buf) - data_offset;
512                         ssize_t n;
514                         /* retrieve status and data len */
515                         assert(len >= 2 * sizeof(uint32_t));
516                         n = sdb_proto_unmarshal_int32(str, len, &rstatus);
517                         str += n; len -= (size_t)n;
518                         sdb_proto_unmarshal_int32(str, len, &rlen);
520                         if (! rlen)
521                                 break;
523                         req = (size_t)rlen;
524                         total = 0;
525                 }
526                 else /* finished reading data */
527                         break;
528         }
530         if (total != req) {
531                 /* unexpected EOF; clear partially read data */
532                 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
533                 return 0;
534         }
536         if (rstatus != UINT32_MAX)
537                 /* remove status,len */
538                 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
540         if (code)
541                 *code = rstatus;
543         return (ssize_t)total;
544 } /* sdb_client_recv */
546 bool
547 sdb_client_eof(sdb_client_t *client)
549         if ((! client) || (client->fd < 0))
550                 return 1;
551         return client->eof;
552 } /* sdb_client_eof */
554 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */