Code

SSL utils: Add sdb_ssl_free_options().
[sysdb.git] / src / client / sock.c
1 /*
2  * SysDB - src/client/sock.c
3  * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
28 #if HAVE_CONFIG_H
29 #       include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "client/sock.h"
33 #include "utils/error.h"
34 #include "utils/strbuf.h"
35 #include "utils/proto.h"
36 #include "utils/os.h"
37 #include "utils/ssl.h"
39 #include <arpa/inet.h>
41 #include <assert.h>
42 #include <errno.h>
43 #include <limits.h>
45 #include <stdlib.h>
47 #include <string.h>
48 #include <strings.h>
50 #include <unistd.h>
52 #include <sys/socket.h>
53 #include <sys/un.h>
55 #include <netdb.h>
57 /*
58  * private data types
59  */
61 struct sdb_client {
62         char *address;
63         int   fd;
64         bool  eof;
66         /* optional SSL settings */
67         sdb_ssl_options_t ssl_opts;
68         sdb_ssl_client_t *ssl;
69         sdb_ssl_session_t *ssl_session;
71         ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
72         ssize_t (*write)(sdb_client_t *, const void *, size_t);
73 };
75 /*
76  * private helper functions
77  */
79 static ssize_t
80 ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
81 {
82         char tmp[n];
83         ssize_t ret;
85         ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
86         if (ret <= 0)
87                 return ret;
89         sdb_strbuf_memappend(buf, tmp, ret);
90         return ret;
91 } /* ssl_read */
93 static ssize_t
94 ssl_write(sdb_client_t *client, const void *buf, size_t n)
95 {
96         return sdb_ssl_session_write(client->ssl_session, buf, n);
97 } /* ssl_write */
99 static ssize_t
100 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
102         return sdb_strbuf_read(buf, client->fd, n);
103 } /* client_read */
105 static ssize_t
106 client_write(sdb_client_t *client, const void *buf, size_t n)
108         return sdb_write(client->fd, n, buf);
109 } /* client_write */
111 static int
112 connect_unixsock(sdb_client_t *client, const char *address)
114         struct sockaddr_un sa;
116         client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
117         if (client->fd < 0) {
118                 char errbuf[1024];
119                 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
120                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
121                 return -1;
122         }
124         sa.sun_family = AF_UNIX;
125         strncpy(sa.sun_path, address, sizeof(sa.sun_path));
126         sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
128         if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
129                 char errbuf[1024];
130                 sdb_client_close(client);
131                 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
132                                 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
133                 return -1;
134         }
135         return client->fd;
136 } /* connect_unixsock */
138 static int
139 connect_tcp(sdb_client_t *client, const char *address)
141         struct addrinfo *ai, *ai_list = NULL;
142         int status;
144         if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
145                 sdb_log(SDB_LOG_ERR, "Failed to resolve '%s': %s",
146                                 address, gai_strerror(status));
147                 return -1;
148         }
150         for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
151                 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
152                 if (client->fd < 0) {
153                         char errbuf[1024];
154                         sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
155                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
156                         continue;
157                 }
159                 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
160                         char host[1024], port[32], errbuf[1024];
161                         sdb_client_close(client);
162                         getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host),
163                                         port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV);
164                         sdb_log(SDB_LOG_ERR, "Failed to connect to '%s:%s': %s",
165                                         host, port, sdb_strerror(errno, errbuf, sizeof(errbuf)));
166                         continue;
167                 }
168                 break;
169         }
170         freeaddrinfo(ai_list);
172         if (client->fd < 0)
173                 return -1;
175         client->ssl = sdb_ssl_client_create(&client->ssl_opts);
176         if (! client->ssl) {
177                 sdb_client_close(client);
178                 return -1;
179         }
180         client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
181         if (! client->ssl_session) {
182                 sdb_client_close(client);
183                 return -1;
184         }
186         client->read = ssl_read;
187         client->write = ssl_write;
188         return client->fd;
189 } /* connect_tcp */
191 /*
192  * public API
193  */
195 sdb_client_t *
196 sdb_client_create(const char *address)
198         sdb_client_t *client;
200         if (! address)
201                 return NULL;
203         client = malloc(sizeof(*client));
204         if (! client) {
205                 sdb_log(SDB_LOG_ERR, "Out of memory");
206                 return NULL;
207         }
208         memset(client, 0, sizeof(*client));
209         client->fd = -1;
210         client->eof = 1;
212         client->ssl = NULL;
213         client->read = client_read;
214         client->write = client_write;
216         client->address = strdup(address);
217         if (! client->address) {
218                 sdb_client_destroy(client);
219                 sdb_log(SDB_LOG_ERR, "Out of memory");
220                 return NULL;
221         }
223         return client;
224 } /* sdb_client_create */
226 void
227 sdb_client_destroy(sdb_client_t *client)
229         if (! client)
230                 return;
232         sdb_client_close(client);
234         if (client->address)
235                 free(client->address);
236         client->address = NULL;
238         sdb_ssl_free_options(&client->ssl_opts);
240         free(client);
241 } /* sdb_client_destroy */
243 int
244 sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
246         int ret = 0;
248         if ((! client) || (! opts))
249                 return -1;
251         sdb_ssl_free_options(&client->ssl_opts);
253         if (opts->ca_file) {
254                 client->ssl_opts.ca_file = strdup(opts->ca_file);
255                 if (! client->ssl_opts.ca_file)
256                         ret = -1;
257         }
258         if (opts->key_file) {
259                 client->ssl_opts.key_file = strdup(opts->key_file);
260                 if (! client->ssl_opts.key_file)
261                         ret = -1;
262         }
263         if (opts->cert_file) {
264                 client->ssl_opts.cert_file = strdup(opts->cert_file);
265                 if (! client->ssl_opts.cert_file)
266                         ret = -1;
267         }
268         if (opts->crl_file) {
269                 client->ssl_opts.crl_file = strdup(opts->crl_file);
270                 if (! client->ssl_opts.crl_file)
271                         ret = -1;
272         }
274         if (ret)
275                 sdb_ssl_free_options(&client->ssl_opts);
276         return ret;
277 } /* sdb_client_set_ssl_options */
279 int
280 sdb_client_connect(sdb_client_t *client, const char *username)
282         sdb_strbuf_t *buf;
283         ssize_t status;
284         uint32_t rstatus;
286         if ((! client) || (! client->address))
287                 return -1;
289         if (client->fd >= 0)
290                 return -1;
292         if (*client->address == '/')
293                 connect_unixsock(client, client->address);
294         else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
295                 connect_unixsock(client, client->address + strlen("unix:"));
296         else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
297                 connect_tcp(client, client->address + strlen("tcp:"));
298         else
299                 connect_tcp(client, client->address);
301         if (client->fd < 0)
302                 return -1;
303         client->eof = 0;
305         /* XXX */
306         if (! username)
307                 username = "";
309         buf = sdb_strbuf_create(64);
310         rstatus = 0;
311         status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
312                         (uint32_t)strlen(username), username, &rstatus, buf);
313         if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
314                 sdb_strbuf_destroy(buf);
315                 return 0;
316         }
318         if (status < 0) {
319                 sdb_log(SDB_LOG_ERR, "%s", sdb_strbuf_string(buf));
320                 sdb_client_close(client);
321                 sdb_strbuf_destroy(buf);
322                 return (int)status;
323         }
324         if (client->eof)
325                 sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
326                                 "for server response");
328         if (rstatus == SDB_CONNECTION_ERROR) {
329                 sdb_log(SDB_LOG_ERR, "Access denied for user '%s': %s",
330                                 username, sdb_strbuf_string(buf));
331                 status = -((int)rstatus);
332         }
333         else if (rstatus != SDB_CONNECTION_OK) {
334                 sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
335                                 "(status %d) during startup", (int)rstatus);
336                 status = -((int)rstatus);
337         }
339         sdb_client_close(client);
340         sdb_strbuf_destroy(buf);
341         return (int)status;
342 } /* sdb_client_connect */
344 int
345 sdb_client_sockfd(sdb_client_t *client)
347         if (! client)
348                 return -1;
349         return client->fd;
350 } /* sdb_client_sockfd */
352 int
353 sdb_client_shutdown(sdb_client_t *client, int how)
355         if (! client) {
356                 errno = ENOTSOCK;
357                 return -1;
358         }
360         if (client->fd < 0) {
361                 errno = EBADF;
362                 return -1;
363         }
365         return shutdown(client->fd, how);
366 } /* sdb_client_shutdown */
368 void
369 sdb_client_close(sdb_client_t *client)
371         if (! client)
372                 return;
374         if (client->ssl_session) {
375                 sdb_ssl_session_destroy(client->ssl_session);
376                 client->ssl_session = NULL;
377         }
378         if (client->ssl) {
379                 sdb_ssl_client_destroy(client->ssl);
380                 client->ssl = NULL;
381         }
383         close(client->fd);
384         client->fd = -1;
385         client->eof = 1;
386 } /* sdb_client_close */
388 ssize_t
389 sdb_client_rpc(sdb_client_t *client,
390                 uint32_t cmd, uint32_t msg_len, const char *msg,
391                 uint32_t *code, sdb_strbuf_t *buf)
393         uint32_t rcode = 0;
394         ssize_t status;
396         if (! buf)
397                 return -1;
399         if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
400                 char errbuf[1024];
401                 sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
402                                 SDB_CONN_MSGTYPE_TO_STRING(cmd),
403                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
404                 if (code)
405                         *code = SDB_CONNECTION_ERROR;
406                 return -1;
407         }
409         while (42) {
410                 size_t offset = sdb_strbuf_len(buf);
412                 status = sdb_client_recv(client, &rcode, buf);
413                 if (status < 0) {
414                         char errbuf[1024];
415                         sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
416                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
417                         if (code)
418                                 *code = SDB_CONNECTION_ERROR;
419                         return status;
420                 }
422                 if (rcode == SDB_CONNECTION_LOG) {
423                         uint32_t prio = 0;
424                         if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
425                                 sdb_log(SDB_LOG_WARNING, "Received a LOG message "
426                                                 "with invalid or missing priority");
427                                 prio = (uint32_t)SDB_LOG_ERR;
428                         }
429                         sdb_log((int)prio, "%s", sdb_strbuf_string(buf) + offset);
430                         sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
431                         continue;
432                 }
433                 break;
434         }
436         if (code)
437                 *code = rcode;
438         return status;
439 } /* sdb_client_rpc */
441 ssize_t
442 sdb_client_send(sdb_client_t *client,
443                 uint32_t cmd, uint32_t msg_len, const char *msg)
445         char buf[2 * sizeof(uint32_t) + msg_len];
447         if ((! client) || (! client->fd))
448                 return -1;
449         if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
450                 return -1;
452         return client->write(client, buf, sizeof(buf));
453 } /* sdb_client_send */
455 ssize_t
456 sdb_client_recv(sdb_client_t *client,
457                 uint32_t *code, sdb_strbuf_t *buf)
459         uint32_t rstatus = UINT32_MAX;
460         uint32_t rlen = UINT32_MAX;
462         size_t total = 0;
463         size_t req = 2 * sizeof(uint32_t);
465         size_t data_offset = sdb_strbuf_len(buf);
467         if (code)
468                 *code = UINT32_MAX;
470         if ((! client) || (! client->fd) || (! buf)) {
471                 errno = EBADF;
472                 return -1;
473         }
475         while (42) {
476                 ssize_t status;
478                 errno = 0;
479                 status = client->read(client, buf, req);
480                 if (status < 0) {
481                         if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
482                                 continue;
483                         return status;
484                 }
485                 else if (! status) {
486                         client->eof = 1;
487                         break;
488                 }
490                 total += (size_t)status;
492                 if (total != req)
493                         continue;
495                 if (rstatus == UINT32_MAX) {
496                         const char *str = sdb_strbuf_string(buf) + data_offset;
497                         size_t len = sdb_strbuf_len(buf) - data_offset;
498                         ssize_t n;
500                         /* retrieve status and data len */
501                         assert(len >= 2 * sizeof(uint32_t));
502                         n = sdb_proto_unmarshal_int32(str, len, &rstatus);
503                         str += n; len -= (size_t)n;
504                         sdb_proto_unmarshal_int32(str, len, &rlen);
506                         if (! rlen)
507                                 break;
509                         req = (size_t)rlen;
510                         total = 0;
511                 }
512                 else /* finished reading data */
513                         break;
514         }
516         if (total != req) {
517                 /* unexpected EOF; clear partially read data */
518                 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
519                 return 0;
520         }
522         if (rstatus != UINT32_MAX)
523                 /* remove status,len */
524                 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
526         if (code)
527                 *code = rstatus;
529         return (ssize_t)total;
530 } /* sdb_client_recv */
532 bool
533 sdb_client_eof(sdb_client_t *client)
535         if ((! client) || (client->fd < 0))
536                 return 1;
537         return client->eof;
538 } /* sdb_client_eof */
540 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */