Code

client, sysdb: Let TCP connection use SSL.
[sysdb.git] / src / client / sock.c
1 /*
2  * SysDB - src/client/sock.c
3  * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
28 #if HAVE_CONFIG_H
29 #       include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "client/sock.h"
33 #include "utils/error.h"
34 #include "utils/strbuf.h"
35 #include "utils/proto.h"
36 #include "utils/os.h"
37 #include "utils/ssl.h"
39 #include <arpa/inet.h>
41 #include <assert.h>
42 #include <errno.h>
43 #include <limits.h>
45 #include <stdlib.h>
47 #include <string.h>
48 #include <strings.h>
50 #include <unistd.h>
52 #include <sys/socket.h>
53 #include <sys/un.h>
55 #include <netdb.h>
57 /*
58  * private data types
59  */
61 struct sdb_client {
62         char *address;
63         int   fd;
64         bool  eof;
66         /* optional SSL settings */
67         sdb_ssl_client_t *ssl;
68         sdb_ssl_session_t *ssl_session;
70         ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
71         ssize_t (*write)(sdb_client_t *, const void *, size_t);
72 };
74 /*
75  * private helper functions
76  */
78 static ssize_t
79 ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
80 {
81         char tmp[n];
82         ssize_t ret;
84         ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
85         if (ret <= 0)
86                 return ret;
88         sdb_strbuf_memappend(buf, tmp, ret);
89         return ret;
90 } /* ssl_read */
92 static ssize_t
93 ssl_write(sdb_client_t *client, const void *buf, size_t n)
94 {
95         return sdb_ssl_session_write(client->ssl_session, buf, n);
96 } /* ssl_write */
98 static ssize_t
99 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
101         return sdb_strbuf_read(buf, client->fd, n);
102 } /* client_read */
104 static ssize_t
105 client_write(sdb_client_t *client, const void *buf, size_t n)
107         return sdb_write(client->fd, n, buf);
108 } /* client_write */
110 static int
111 connect_unixsock(sdb_client_t *client, const char *address)
113         struct sockaddr_un sa;
115         client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
116         if (client->fd < 0) {
117                 char errbuf[1024];
118                 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
119                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
120                 return -1;
121         }
123         sa.sun_family = AF_UNIX;
124         strncpy(sa.sun_path, address, sizeof(sa.sun_path));
125         sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
127         if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
128                 char errbuf[1024];
129                 sdb_client_close(client);
130                 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
131                                 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
132                 return -1;
133         }
134         return client->fd;
135 } /* connect_unixsock */
137 static int
138 connect_tcp(sdb_client_t *client, const char *address)
140         struct addrinfo *ai, *ai_list = NULL;
141         int status;
143         if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
144                 sdb_log(SDB_LOG_ERR, "Failed to resolve '%s': %s",
145                                 address, gai_strerror(status));
146                 return -1;
147         }
149         for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
150                 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
151                 if (client->fd < 0) {
152                         char errbuf[1024];
153                         sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
154                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
155                         continue;
156                 }
158                 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
159                         char host[1024], port[32], errbuf[1024];
160                         sdb_client_close(client);
161                         getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host),
162                                         port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV);
163                         sdb_log(SDB_LOG_ERR, "Failed to connect to '%s:%s': %s",
164                                         host, port, sdb_strerror(errno, errbuf, sizeof(errbuf)));
165                         continue;
166                 }
167                 break;
168         }
169         freeaddrinfo(ai_list);
171         if (client->fd < 0)
172                 return -1;
174         /* TODO: make options configurable */
175         client->ssl = sdb_ssl_client_create(NULL);
176         if (! client->ssl) {
177                 sdb_client_close(client);
178                 return -1;
179         }
180         client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
181         if (! client->ssl_session) {
182                 sdb_client_close(client);
183                 return -1;
184         }
186         client->read = ssl_read;
187         client->write = ssl_write;
188         return client->fd;
189 } /* connect_tcp */
191 /*
192  * public API
193  */
195 sdb_client_t *
196 sdb_client_create(const char *address)
198         sdb_client_t *client;
200         if (! address)
201                 return NULL;
203         client = malloc(sizeof(*client));
204         if (! client) {
205                 sdb_log(SDB_LOG_ERR, "Out of memory");
206                 return NULL;
207         }
208         memset(client, 0, sizeof(*client));
209         client->fd = -1;
210         client->eof = 1;
212         client->ssl = NULL;
213         client->read = client_read;
214         client->write = client_write;
216         client->address = strdup(address);
217         if (! client->address) {
218                 sdb_client_destroy(client);
219                 sdb_log(SDB_LOG_ERR, "Out of memory");
220                 return NULL;
221         }
223         return client;
224 } /* sdb_client_create */
226 void
227 sdb_client_destroy(sdb_client_t *client)
229         if (! client)
230                 return;
232         sdb_client_close(client);
234         if (client->address)
235                 free(client->address);
236         client->address = NULL;
238         free(client);
239 } /* sdb_client_destroy */
241 int
242 sdb_client_connect(sdb_client_t *client, const char *username)
244         sdb_strbuf_t *buf;
245         ssize_t status;
246         uint32_t rstatus;
248         if ((! client) || (! client->address))
249                 return -1;
251         if (client->fd >= 0)
252                 return -1;
254         if (*client->address == '/')
255                 connect_unixsock(client, client->address);
256         else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
257                 connect_unixsock(client, client->address + strlen("unix:"));
258         else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
259                 connect_tcp(client, client->address + strlen("tcp:"));
260         else
261                 connect_tcp(client, client->address);
263         if (client->fd < 0)
264                 return -1;
265         client->eof = 0;
267         /* XXX */
268         if (! username)
269                 username = "";
271         buf = sdb_strbuf_create(64);
272         rstatus = 0;
273         status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
274                         (uint32_t)strlen(username), username, &rstatus, buf);
275         if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
276                 sdb_strbuf_destroy(buf);
277                 return 0;
278         }
280         if (status < 0) {
281                 sdb_log(SDB_LOG_ERR, "%s", sdb_strbuf_string(buf));
282                 sdb_client_close(client);
283                 sdb_strbuf_destroy(buf);
284                 return (int)status;
285         }
286         if (client->eof)
287                 sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
288                                 "for server response");
290         if (rstatus == SDB_CONNECTION_ERROR) {
291                 sdb_log(SDB_LOG_ERR, "Access denied for user '%s': %s",
292                                 username, sdb_strbuf_string(buf));
293                 status = -((int)rstatus);
294         }
295         else if (rstatus != SDB_CONNECTION_OK) {
296                 sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
297                                 "(status %d) during startup", (int)rstatus);
298                 status = -((int)rstatus);
299         }
301         sdb_client_close(client);
302         sdb_strbuf_destroy(buf);
303         return (int)status;
304 } /* sdb_client_connect */
306 int
307 sdb_client_sockfd(sdb_client_t *client)
309         if (! client)
310                 return -1;
311         return client->fd;
312 } /* sdb_client_sockfd */
314 int
315 sdb_client_shutdown(sdb_client_t *client, int how)
317         if (! client) {
318                 errno = ENOTSOCK;
319                 return -1;
320         }
322         if (client->fd < 0) {
323                 errno = EBADF;
324                 return -1;
325         }
327         return shutdown(client->fd, how);
328 } /* sdb_client_shutdown */
330 void
331 sdb_client_close(sdb_client_t *client)
333         if (! client)
334                 return;
336         if (client->ssl_session) {
337                 sdb_ssl_session_destroy(client->ssl_session);
338                 client->ssl_session = NULL;
339         }
340         if (client->ssl) {
341                 sdb_ssl_client_destroy(client->ssl);
342                 client->ssl = NULL;
343         }
345         close(client->fd);
346         client->fd = -1;
347         client->eof = 1;
348 } /* sdb_client_close */
350 ssize_t
351 sdb_client_rpc(sdb_client_t *client,
352                 uint32_t cmd, uint32_t msg_len, const char *msg,
353                 uint32_t *code, sdb_strbuf_t *buf)
355         uint32_t rcode = 0;
356         ssize_t status;
358         if (! buf)
359                 return -1;
361         if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
362                 char errbuf[1024];
363                 sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
364                                 SDB_CONN_MSGTYPE_TO_STRING(cmd),
365                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
366                 if (code)
367                         *code = SDB_CONNECTION_ERROR;
368                 return -1;
369         }
371         while (42) {
372                 size_t offset = sdb_strbuf_len(buf);
374                 status = sdb_client_recv(client, &rcode, buf);
375                 if (status < 0) {
376                         char errbuf[1024];
377                         sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
378                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
379                         if (code)
380                                 *code = SDB_CONNECTION_ERROR;
381                         return status;
382                 }
384                 if (rcode == SDB_CONNECTION_LOG) {
385                         uint32_t prio = 0;
386                         if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
387                                 sdb_log(SDB_LOG_WARNING, "Received a LOG message "
388                                                 "with invalid or missing priority");
389                                 prio = (uint32_t)SDB_LOG_ERR;
390                         }
391                         sdb_log((int)prio, "%s", sdb_strbuf_string(buf) + offset);
392                         sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
393                         continue;
394                 }
395                 break;
396         }
398         if (code)
399                 *code = rcode;
400         return status;
401 } /* sdb_client_rpc */
403 ssize_t
404 sdb_client_send(sdb_client_t *client,
405                 uint32_t cmd, uint32_t msg_len, const char *msg)
407         char buf[2 * sizeof(uint32_t) + msg_len];
409         if ((! client) || (! client->fd))
410                 return -1;
411         if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
412                 return -1;
414         return client->write(client, buf, sizeof(buf));
415 } /* sdb_client_send */
417 ssize_t
418 sdb_client_recv(sdb_client_t *client,
419                 uint32_t *code, sdb_strbuf_t *buf)
421         uint32_t rstatus = UINT32_MAX;
422         uint32_t rlen = UINT32_MAX;
424         size_t total = 0;
425         size_t req = 2 * sizeof(uint32_t);
427         size_t data_offset = sdb_strbuf_len(buf);
429         if (code)
430                 *code = UINT32_MAX;
432         if ((! client) || (! client->fd) || (! buf)) {
433                 errno = EBADF;
434                 return -1;
435         }
437         while (42) {
438                 ssize_t status;
440                 errno = 0;
441                 status = client->read(client, buf, req);
442                 if (status < 0) {
443                         if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
444                                 continue;
445                         return status;
446                 }
447                 else if (! status) {
448                         client->eof = 1;
449                         break;
450                 }
452                 total += (size_t)status;
454                 if (total != req)
455                         continue;
457                 if (rstatus == UINT32_MAX) {
458                         const char *str = sdb_strbuf_string(buf) + data_offset;
459                         size_t len = sdb_strbuf_len(buf) - data_offset;
460                         ssize_t n;
462                         /* retrieve status and data len */
463                         assert(len >= 2 * sizeof(uint32_t));
464                         n = sdb_proto_unmarshal_int32(str, len, &rstatus);
465                         str += n; len -= (size_t)n;
466                         sdb_proto_unmarshal_int32(str, len, &rlen);
468                         if (! rlen)
469                                 break;
471                         req = (size_t)rlen;
472                         total = 0;
473                 }
474                 else /* finished reading data */
475                         break;
476         }
478         if (total != req) {
479                 /* unexpected EOF; clear partially read data */
480                 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
481                 return 0;
482         }
484         if (rstatus != UINT32_MAX)
485                 /* remove status,len */
486                 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
488         if (code)
489                 *code = rstatus;
491         return (ssize_t)total;
492 } /* sdb_client_recv */
494 bool
495 sdb_client_eof(sdb_client_t *client)
497         if ((! client) || (client->fd < 0))
498                 return 1;
499         return client->eof;
500 } /* sdb_client_eof */
502 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */