Code

Split the memstore module from the store module.
[sysdb.git] / src / client / sock.c
1 /*
2  * SysDB - src/client/sock.c
3  * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
28 #if HAVE_CONFIG_H
29 #       include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "sysdb.h"
33 #include "client/sock.h"
34 #include "utils/error.h"
35 #include "utils/strbuf.h"
36 #include "utils/proto.h"
37 #include "utils/os.h"
38 #include "utils/ssl.h"
40 #include <arpa/inet.h>
42 #include <assert.h>
43 #include <errno.h>
44 #include <limits.h>
46 #include <stdlib.h>
48 #include <string.h>
49 #include <strings.h>
51 #include <unistd.h>
53 #include <sys/socket.h>
54 #include <sys/un.h>
56 #include <netdb.h>
58 /*
59  * private data types
60  */
62 struct sdb_client {
63         char *address;
64         int   fd;
65         bool  eof;
67         /* optional SSL settings */
68         sdb_ssl_options_t ssl_opts;
69         sdb_ssl_client_t *ssl;
70         sdb_ssl_session_t *ssl_session;
72         ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
73         ssize_t (*write)(sdb_client_t *, const void *, size_t);
74 };
76 /*
77  * private helper functions
78  */
80 static ssize_t
81 ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
82 {
83         char tmp[n];
84         ssize_t ret;
86         ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
87         if (ret <= 0)
88                 return ret;
90         sdb_strbuf_memappend(buf, tmp, ret);
91         return ret;
92 } /* ssl_read */
94 static ssize_t
95 ssl_write(sdb_client_t *client, const void *buf, size_t n)
96 {
97         return sdb_ssl_session_write(client->ssl_session, buf, n);
98 } /* ssl_write */
100 static ssize_t
101 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
103         return sdb_strbuf_read(buf, client->fd, n);
104 } /* client_read */
106 static ssize_t
107 client_write(sdb_client_t *client, const void *buf, size_t n)
109         return sdb_write(client->fd, n, buf);
110 } /* client_write */
112 static int
113 connect_unixsock(sdb_client_t *client, const char *address)
115         struct sockaddr_un sa;
117         client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
118         if (client->fd < 0) {
119                 char errbuf[1024];
120                 sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
121                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
122                 return -1;
123         }
125         sa.sun_family = AF_UNIX;
126         strncpy(sa.sun_path, address, sizeof(sa.sun_path));
127         sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
129         if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
130                 char errbuf[1024];
131                 sdb_client_close(client);
132                 sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': %s",
133                                 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
134                 return -1;
135         }
136         return client->fd;
137 } /* connect_unixsock */
139 static int
140 connect_tcp(sdb_client_t *client, const char *address)
142         char host[SDB_MAX(strlen("localhost"), (address ? strlen(address) : 0)) + 1];
143         struct addrinfo *ai, *ai_list = NULL;
144         char *peer, *tmp;
145         int status;
147         if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
148                 sdb_log(SDB_LOG_ERR, "client: Failed to resolve '%s': %s",
149                                 address, gai_strerror(status));
150                 return -1;
151         }
153         for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
154                 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
155                 if (client->fd < 0) {
156                         char errbuf[1024];
157                         sdb_log(SDB_LOG_ERR, "client: Failed to open socket: %s",
158                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
159                         continue;
160                 }
162                 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
163                         char h[1024], p[32], errbuf[1024];
164                         sdb_client_close(client);
165                         getnameinfo(ai->ai_addr, ai->ai_addrlen, h, sizeof(h),
166                                         p, sizeof(p), NI_NUMERICHOST | NI_NUMERICSERV);
167                         sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s:%s': %s",
168                                         h, p, sdb_strerror(errno, errbuf, sizeof(errbuf)));
169                         continue;
170                 }
171                 break;
172         }
173         freeaddrinfo(ai_list);
175         if (client->fd < 0)
176                 return -1;
178         client->ssl = sdb_ssl_client_create(&client->ssl_opts);
179         if (! client->ssl) {
180                 sdb_client_close(client);
181                 return -1;
182         }
183         client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
184         if (! client->ssl_session) {
185                 sdb_client_close(client);
186                 return -1;
187         }
189         strncpy(host, address, sizeof(host));
190         if ((tmp = strrchr(host, (int)':')))
191                 *tmp = '\0';
192         if (! host[0])
193                 strncpy(host, "localhost", sizeof(host));
194         peer = sdb_ssl_session_peer(client->ssl_session);
195         if ((! peer) || strcasecmp(peer, host)) {
196                 /* TODO: also check alt-name */
197                 sdb_log(SDB_LOG_ERR, "client: Failed to connect to '%s': "
198                                 "peer name '%s' does not match host address",
199                                 address, peer);
200                 sdb_client_close(client);
201                 if (peer)
202                         free(peer);
203                 return -1;
204         }
205         free(peer);
207         client->read = ssl_read;
208         client->write = ssl_write;
209         return client->fd;
210 } /* connect_tcp */
212 /*
213  * public API
214  */
216 sdb_client_t *
217 sdb_client_create(const char *address)
219         sdb_client_t *client;
221         if (! address)
222                 return NULL;
224         client = malloc(sizeof(*client));
225         if (! client) {
226                 sdb_log(SDB_LOG_ERR, "client: Out of memory");
227                 return NULL;
228         }
229         memset(client, 0, sizeof(*client));
230         client->fd = -1;
231         client->eof = 1;
233         client->ssl = NULL;
234         client->read = client_read;
235         client->write = client_write;
237         client->address = strdup(address);
238         if (! client->address) {
239                 sdb_client_destroy(client);
240                 sdb_log(SDB_LOG_ERR, "client: Out of memory");
241                 return NULL;
242         }
244         return client;
245 } /* sdb_client_create */
247 void
248 sdb_client_destroy(sdb_client_t *client)
250         if (! client)
251                 return;
253         sdb_client_close(client);
255         if (client->address)
256                 free(client->address);
257         client->address = NULL;
259         sdb_ssl_free_options(&client->ssl_opts);
261         free(client);
262 } /* sdb_client_destroy */
264 int
265 sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
267         int ret = 0;
269         if ((! client) || (! opts))
270                 return -1;
272         sdb_ssl_free_options(&client->ssl_opts);
274         if (opts->ca_file) {
275                 client->ssl_opts.ca_file = strdup(opts->ca_file);
276                 if (! client->ssl_opts.ca_file)
277                         ret = -1;
278         }
279         if (opts->key_file) {
280                 client->ssl_opts.key_file = strdup(opts->key_file);
281                 if (! client->ssl_opts.key_file)
282                         ret = -1;
283         }
284         if (opts->cert_file) {
285                 client->ssl_opts.cert_file = strdup(opts->cert_file);
286                 if (! client->ssl_opts.cert_file)
287                         ret = -1;
288         }
289         if (opts->crl_file) {
290                 client->ssl_opts.crl_file = strdup(opts->crl_file);
291                 if (! client->ssl_opts.crl_file)
292                         ret = -1;
293         }
295         if (ret)
296                 sdb_ssl_free_options(&client->ssl_opts);
297         return ret;
298 } /* sdb_client_set_ssl_options */
300 int
301 sdb_client_connect(sdb_client_t *client, const char *username)
303         sdb_strbuf_t *buf;
304         ssize_t status;
305         uint32_t rstatus;
307         if ((! client) || (! client->address))
308                 return -1;
310         if (client->fd >= 0)
311                 return -1;
313         if (*client->address == '/')
314                 connect_unixsock(client, client->address);
315         else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
316                 connect_unixsock(client, client->address + strlen("unix:"));
317         else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
318                 connect_tcp(client, client->address + strlen("tcp:"));
319         else
320                 connect_tcp(client, client->address);
322         if (client->fd < 0)
323                 return -1;
324         client->eof = 0;
326         /* XXX */
327         if (! username)
328                 username = "";
330         buf = sdb_strbuf_create(64);
331         rstatus = 0;
332         status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
333                         (uint32_t)strlen(username), username, &rstatus, buf);
334         if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
335                 sdb_strbuf_destroy(buf);
336                 return 0;
337         }
339         if (status < 0) {
340                 sdb_log(SDB_LOG_ERR, "client: %s", sdb_strbuf_string(buf));
341                 sdb_client_close(client);
342                 sdb_strbuf_destroy(buf);
343                 return (int)status;
344         }
345         if (client->eof)
346                 sdb_log(SDB_LOG_ERR, "client: Encountered end-of-file while waiting "
347                                 "for server response");
349         if (rstatus == SDB_CONNECTION_ERROR) {
350                 sdb_log(SDB_LOG_ERR, "client: Access denied for user '%s': %s",
351                                 username, sdb_strbuf_string(buf));
352                 status = -((int)rstatus);
353         }
354         else if (rstatus != SDB_CONNECTION_OK) {
355                 sdb_log(SDB_LOG_ERR, "client: Received unsupported authentication "
356                                 "request (status %d) during startup", (int)rstatus);
357                 status = -((int)rstatus);
358         }
360         sdb_client_close(client);
361         sdb_strbuf_destroy(buf);
362         return (int)status;
363 } /* sdb_client_connect */
365 int
366 sdb_client_sockfd(sdb_client_t *client)
368         if (! client)
369                 return -1;
370         return client->fd;
371 } /* sdb_client_sockfd */
373 int
374 sdb_client_shutdown(sdb_client_t *client, int how)
376         if (! client) {
377                 errno = ENOTSOCK;
378                 return -1;
379         }
381         if (client->fd < 0) {
382                 errno = EBADF;
383                 return -1;
384         }
386         return shutdown(client->fd, how);
387 } /* sdb_client_shutdown */
389 void
390 sdb_client_close(sdb_client_t *client)
392         if (! client)
393                 return;
395         if (client->ssl_session) {
396                 sdb_ssl_session_destroy(client->ssl_session);
397                 client->ssl_session = NULL;
398         }
399         if (client->ssl) {
400                 sdb_ssl_client_destroy(client->ssl);
401                 client->ssl = NULL;
402         }
404         close(client->fd);
405         client->fd = -1;
406         client->eof = 1;
407 } /* sdb_client_close */
409 ssize_t
410 sdb_client_rpc(sdb_client_t *client,
411                 uint32_t cmd, uint32_t msg_len, const char *msg,
412                 uint32_t *code, sdb_strbuf_t *buf)
414         uint32_t rcode = 0;
415         ssize_t status;
417         if (! buf)
418                 return -1;
420         if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
421                 char errbuf[1024];
422                 sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
423                                 SDB_CONN_MSGTYPE_TO_STRING(cmd),
424                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
425                 if (code)
426                         *code = SDB_CONNECTION_ERROR;
427                 return -1;
428         }
430         while (42) {
431                 size_t offset = sdb_strbuf_len(buf);
433                 status = sdb_client_recv(client, &rcode, buf);
434                 if (status < 0) {
435                         char errbuf[1024];
436                         sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
437                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
438                         if (code)
439                                 *code = SDB_CONNECTION_ERROR;
440                         return status;
441                 }
443                 if (rcode == SDB_CONNECTION_LOG) {
444                         uint32_t prio = 0;
445                         if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
446                                 sdb_log(SDB_LOG_WARNING, "client: Received a LOG message "
447                                                 "with invalid or missing priority");
448                                 prio = (uint32_t)SDB_LOG_ERR;
449                         }
450                         sdb_log((int)prio, "client: %s", sdb_strbuf_string(buf) + offset);
451                         sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
452                         continue;
453                 }
454                 break;
455         }
457         if (code)
458                 *code = rcode;
459         return status;
460 } /* sdb_client_rpc */
462 ssize_t
463 sdb_client_send(sdb_client_t *client,
464                 uint32_t cmd, uint32_t msg_len, const char *msg)
466         char buf[2 * sizeof(uint32_t) + msg_len];
468         if ((! client) || (! client->fd))
469                 return -1;
470         if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
471                 return -1;
473         return client->write(client, buf, sizeof(buf));
474 } /* sdb_client_send */
476 ssize_t
477 sdb_client_recv(sdb_client_t *client,
478                 uint32_t *code, sdb_strbuf_t *buf)
480         uint32_t rstatus = UINT32_MAX;
481         uint32_t rlen = UINT32_MAX;
483         size_t total = 0;
484         size_t req = 2 * sizeof(uint32_t);
486         size_t data_offset = sdb_strbuf_len(buf);
488         if (code)
489                 *code = UINT32_MAX;
491         if ((! client) || (! client->fd) || (! buf)) {
492                 errno = EBADF;
493                 return -1;
494         }
496         while (42) {
497                 ssize_t status;
499                 errno = 0;
500                 status = client->read(client, buf, req);
501                 if (status < 0) {
502                         if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
503                                 continue;
504                         return status;
505                 }
506                 else if (! status) {
507                         client->eof = 1;
508                         break;
509                 }
511                 total += (size_t)status;
513                 if (total != req)
514                         continue;
516                 if (rstatus == UINT32_MAX) {
517                         const char *str = sdb_strbuf_string(buf) + data_offset;
518                         size_t len = sdb_strbuf_len(buf) - data_offset;
519                         ssize_t n;
521                         /* retrieve status and data len */
522                         assert(len >= 2 * sizeof(uint32_t));
523                         n = sdb_proto_unmarshal_int32(str, len, &rstatus);
524                         str += n; len -= (size_t)n;
525                         sdb_proto_unmarshal_int32(str, len, &rlen);
527                         if (! rlen)
528                                 break;
530                         req = (size_t)rlen;
531                         total = 0;
532                 }
533                 else /* finished reading data */
534                         break;
535         }
537         if (total != req) {
538                 /* unexpected EOF; clear partially read data */
539                 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
540                 return 0;
541         }
543         if (rstatus != UINT32_MAX)
544                 /* remove status,len */
545                 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
547         if (code)
548                 *code = rstatus;
550         return (ssize_t)total;
551 } /* sdb_client_recv */
553 bool
554 sdb_client_eof(sdb_client_t *client)
556         if ((! client) || (client->fd < 0))
557                 return 1;
558         return client->eof;
559 } /* sdb_client_eof */
561 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */