b67910efa4b5eb843504b825587e52a14367f465
1 /*
2 * SysDB - src/client/sock.c
3 * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4 * All rights reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 * 1. Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
28 #if HAVE_CONFIG_H
29 # include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "client/sock.h"
33 #include "utils/error.h"
34 #include "utils/strbuf.h"
35 #include "utils/proto.h"
36 #include "utils/os.h"
37 #include "utils/ssl.h"
39 #include <arpa/inet.h>
41 #include <assert.h>
42 #include <errno.h>
43 #include <limits.h>
45 #include <stdlib.h>
47 #include <string.h>
48 #include <strings.h>
50 #include <unistd.h>
52 #include <sys/socket.h>
53 #include <sys/un.h>
55 #include <netdb.h>
57 /*
58 * private data types
59 */
61 struct sdb_client {
62 char *address;
63 int fd;
64 bool eof;
66 /* optional SSL settings */
67 sdb_ssl_options_t ssl_opts;
68 sdb_ssl_client_t *ssl;
69 sdb_ssl_session_t *ssl_session;
71 ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
72 ssize_t (*write)(sdb_client_t *, const void *, size_t);
73 };
75 /*
76 * private helper functions
77 */
79 static ssize_t
80 ssl_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
81 {
82 char tmp[n];
83 ssize_t ret;
85 ret = sdb_ssl_session_read(client->ssl_session, tmp, n);
86 if (ret <= 0)
87 return ret;
89 sdb_strbuf_memappend(buf, tmp, ret);
90 return ret;
91 } /* ssl_read */
93 static ssize_t
94 ssl_write(sdb_client_t *client, const void *buf, size_t n)
95 {
96 return sdb_ssl_session_write(client->ssl_session, buf, n);
97 } /* ssl_write */
99 static ssize_t
100 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
101 {
102 return sdb_strbuf_read(buf, client->fd, n);
103 } /* client_read */
105 static ssize_t
106 client_write(sdb_client_t *client, const void *buf, size_t n)
107 {
108 return sdb_write(client->fd, n, buf);
109 } /* client_write */
111 static int
112 connect_unixsock(sdb_client_t *client, const char *address)
113 {
114 struct sockaddr_un sa;
116 client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
117 if (client->fd < 0) {
118 char errbuf[1024];
119 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
120 sdb_strerror(errno, errbuf, sizeof(errbuf)));
121 return -1;
122 }
124 sa.sun_family = AF_UNIX;
125 strncpy(sa.sun_path, address, sizeof(sa.sun_path));
126 sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
128 if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
129 char errbuf[1024];
130 sdb_client_close(client);
131 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
132 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
133 return -1;
134 }
135 return client->fd;
136 } /* connect_unixsock */
138 static int
139 connect_tcp(sdb_client_t *client, const char *address)
140 {
141 struct addrinfo *ai, *ai_list = NULL;
142 int status;
144 if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
145 sdb_log(SDB_LOG_ERR, "Failed to resolve '%s': %s",
146 address, gai_strerror(status));
147 return -1;
148 }
150 for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
151 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
152 if (client->fd < 0) {
153 char errbuf[1024];
154 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
155 sdb_strerror(errno, errbuf, sizeof(errbuf)));
156 continue;
157 }
159 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
160 char host[1024], port[32], errbuf[1024];
161 sdb_client_close(client);
162 getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host),
163 port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV);
164 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s:%s': %s",
165 host, port, sdb_strerror(errno, errbuf, sizeof(errbuf)));
166 continue;
167 }
168 break;
169 }
170 freeaddrinfo(ai_list);
172 if (client->fd < 0)
173 return -1;
175 client->ssl = sdb_ssl_client_create(&client->ssl_opts);
176 if (! client->ssl) {
177 sdb_client_close(client);
178 return -1;
179 }
180 client->ssl_session = sdb_ssl_client_connect(client->ssl, client->fd);
181 if (! client->ssl_session) {
182 sdb_client_close(client);
183 return -1;
184 }
186 client->read = ssl_read;
187 client->write = ssl_write;
188 return client->fd;
189 } /* connect_tcp */
191 static void
192 free_ssl_options(sdb_ssl_options_t *opts)
193 {
194 if (opts->ca_file)
195 free(opts->ca_file);
196 if (opts->key_file)
197 free(opts->key_file);
198 if (opts->cert_file)
199 free(opts->cert_file);
200 if (opts->crl_file)
201 free(opts->crl_file);
202 opts->ca_file = opts->key_file = opts->cert_file = opts->crl_file = NULL;
203 } /* free_ssl_options */
205 /*
206 * public API
207 */
209 sdb_client_t *
210 sdb_client_create(const char *address)
211 {
212 sdb_client_t *client;
214 if (! address)
215 return NULL;
217 client = malloc(sizeof(*client));
218 if (! client) {
219 sdb_log(SDB_LOG_ERR, "Out of memory");
220 return NULL;
221 }
222 memset(client, 0, sizeof(*client));
223 client->fd = -1;
224 client->eof = 1;
226 client->ssl = NULL;
227 client->read = client_read;
228 client->write = client_write;
230 client->address = strdup(address);
231 if (! client->address) {
232 sdb_client_destroy(client);
233 sdb_log(SDB_LOG_ERR, "Out of memory");
234 return NULL;
235 }
237 return client;
238 } /* sdb_client_create */
240 void
241 sdb_client_destroy(sdb_client_t *client)
242 {
243 if (! client)
244 return;
246 sdb_client_close(client);
248 if (client->address)
249 free(client->address);
250 client->address = NULL;
252 free_ssl_options(&client->ssl_opts);
254 free(client);
255 } /* sdb_client_destroy */
257 int
258 sdb_client_set_ssl_options(sdb_client_t *client, const sdb_ssl_options_t *opts)
259 {
260 int ret = 0;
262 if ((! client) || (! opts))
263 return -1;
265 free_ssl_options(&client->ssl_opts);
267 if (opts->ca_file) {
268 client->ssl_opts.ca_file = strdup(opts->ca_file);
269 if (! client->ssl_opts.ca_file)
270 ret = -1;
271 }
272 if (opts->key_file) {
273 client->ssl_opts.key_file = strdup(opts->key_file);
274 if (! client->ssl_opts.key_file)
275 ret = -1;
276 }
277 if (opts->cert_file) {
278 client->ssl_opts.cert_file = strdup(opts->cert_file);
279 if (! client->ssl_opts.cert_file)
280 ret = -1;
281 }
282 if (opts->crl_file) {
283 client->ssl_opts.crl_file = strdup(opts->crl_file);
284 if (! client->ssl_opts.crl_file)
285 ret = -1;
286 }
288 if (ret)
289 free_ssl_options(&client->ssl_opts);
290 return ret;
291 } /* sdb_client_set_ssl_options */
293 int
294 sdb_client_connect(sdb_client_t *client, const char *username)
295 {
296 sdb_strbuf_t *buf;
297 ssize_t status;
298 uint32_t rstatus;
300 if ((! client) || (! client->address))
301 return -1;
303 if (client->fd >= 0)
304 return -1;
306 if (*client->address == '/')
307 connect_unixsock(client, client->address);
308 else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
309 connect_unixsock(client, client->address + strlen("unix:"));
310 else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
311 connect_tcp(client, client->address + strlen("tcp:"));
312 else
313 connect_tcp(client, client->address);
315 if (client->fd < 0)
316 return -1;
317 client->eof = 0;
319 /* XXX */
320 if (! username)
321 username = "";
323 buf = sdb_strbuf_create(64);
324 rstatus = 0;
325 status = sdb_client_rpc(client, SDB_CONNECTION_STARTUP,
326 (uint32_t)strlen(username), username, &rstatus, buf);
327 if ((status >= 0) && (rstatus == SDB_CONNECTION_OK)) {
328 sdb_strbuf_destroy(buf);
329 return 0;
330 }
332 if (status < 0) {
333 sdb_log(SDB_LOG_ERR, "%s", sdb_strbuf_string(buf));
334 sdb_client_close(client);
335 sdb_strbuf_destroy(buf);
336 return (int)status;
337 }
338 if (client->eof)
339 sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
340 "for server response");
342 if (rstatus == SDB_CONNECTION_ERROR) {
343 sdb_log(SDB_LOG_ERR, "Access denied for user '%s': %s",
344 username, sdb_strbuf_string(buf));
345 status = -((int)rstatus);
346 }
347 else if (rstatus != SDB_CONNECTION_OK) {
348 sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
349 "(status %d) during startup", (int)rstatus);
350 status = -((int)rstatus);
351 }
353 sdb_client_close(client);
354 sdb_strbuf_destroy(buf);
355 return (int)status;
356 } /* sdb_client_connect */
358 int
359 sdb_client_sockfd(sdb_client_t *client)
360 {
361 if (! client)
362 return -1;
363 return client->fd;
364 } /* sdb_client_sockfd */
366 int
367 sdb_client_shutdown(sdb_client_t *client, int how)
368 {
369 if (! client) {
370 errno = ENOTSOCK;
371 return -1;
372 }
374 if (client->fd < 0) {
375 errno = EBADF;
376 return -1;
377 }
379 return shutdown(client->fd, how);
380 } /* sdb_client_shutdown */
382 void
383 sdb_client_close(sdb_client_t *client)
384 {
385 if (! client)
386 return;
388 if (client->ssl_session) {
389 sdb_ssl_session_destroy(client->ssl_session);
390 client->ssl_session = NULL;
391 }
392 if (client->ssl) {
393 sdb_ssl_client_destroy(client->ssl);
394 client->ssl = NULL;
395 }
397 close(client->fd);
398 client->fd = -1;
399 client->eof = 1;
400 } /* sdb_client_close */
402 ssize_t
403 sdb_client_rpc(sdb_client_t *client,
404 uint32_t cmd, uint32_t msg_len, const char *msg,
405 uint32_t *code, sdb_strbuf_t *buf)
406 {
407 uint32_t rcode = 0;
408 ssize_t status;
410 if (! buf)
411 return -1;
413 if (sdb_client_send(client, cmd, msg_len, msg) < 0) {
414 char errbuf[1024];
415 sdb_strbuf_sprintf(buf, "Failed to send %s message to server: %s",
416 SDB_CONN_MSGTYPE_TO_STRING(cmd),
417 sdb_strerror(errno, errbuf, sizeof(errbuf)));
418 if (code)
419 *code = SDB_CONNECTION_ERROR;
420 return -1;
421 }
423 while (42) {
424 size_t offset = sdb_strbuf_len(buf);
426 status = sdb_client_recv(client, &rcode, buf);
427 if (status < 0) {
428 char errbuf[1024];
429 sdb_strbuf_sprintf(buf, "Failed to receive server response: %s",
430 sdb_strerror(errno, errbuf, sizeof(errbuf)));
431 if (code)
432 *code = SDB_CONNECTION_ERROR;
433 return status;
434 }
436 if (rcode == SDB_CONNECTION_LOG) {
437 uint32_t prio = 0;
438 if (sdb_proto_unmarshal_int32(SDB_STRBUF_STR(buf), &prio) < 0) {
439 sdb_log(SDB_LOG_WARNING, "Received a LOG message "
440 "with invalid or missing priority");
441 prio = (uint32_t)SDB_LOG_ERR;
442 }
443 sdb_log((int)prio, "%s", sdb_strbuf_string(buf) + offset);
444 sdb_strbuf_skip(buf, offset, sdb_strbuf_len(buf) - offset);
445 continue;
446 }
447 break;
448 }
450 if (code)
451 *code = rcode;
452 return status;
453 } /* sdb_client_rpc */
455 ssize_t
456 sdb_client_send(sdb_client_t *client,
457 uint32_t cmd, uint32_t msg_len, const char *msg)
458 {
459 char buf[2 * sizeof(uint32_t) + msg_len];
461 if ((! client) || (! client->fd))
462 return -1;
463 if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
464 return -1;
466 return client->write(client, buf, sizeof(buf));
467 } /* sdb_client_send */
469 ssize_t
470 sdb_client_recv(sdb_client_t *client,
471 uint32_t *code, sdb_strbuf_t *buf)
472 {
473 uint32_t rstatus = UINT32_MAX;
474 uint32_t rlen = UINT32_MAX;
476 size_t total = 0;
477 size_t req = 2 * sizeof(uint32_t);
479 size_t data_offset = sdb_strbuf_len(buf);
481 if (code)
482 *code = UINT32_MAX;
484 if ((! client) || (! client->fd) || (! buf)) {
485 errno = EBADF;
486 return -1;
487 }
489 while (42) {
490 ssize_t status;
492 errno = 0;
493 status = client->read(client, buf, req);
494 if (status < 0) {
495 if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
496 continue;
497 return status;
498 }
499 else if (! status) {
500 client->eof = 1;
501 break;
502 }
504 total += (size_t)status;
506 if (total != req)
507 continue;
509 if (rstatus == UINT32_MAX) {
510 const char *str = sdb_strbuf_string(buf) + data_offset;
511 size_t len = sdb_strbuf_len(buf) - data_offset;
512 ssize_t n;
514 /* retrieve status and data len */
515 assert(len >= 2 * sizeof(uint32_t));
516 n = sdb_proto_unmarshal_int32(str, len, &rstatus);
517 str += n; len -= (size_t)n;
518 sdb_proto_unmarshal_int32(str, len, &rlen);
520 if (! rlen)
521 break;
523 req = (size_t)rlen;
524 total = 0;
525 }
526 else /* finished reading data */
527 break;
528 }
530 if (total != req) {
531 /* unexpected EOF; clear partially read data */
532 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
533 return 0;
534 }
536 if (rstatus != UINT32_MAX)
537 /* remove status,len */
538 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
540 if (code)
541 *code = rstatus;
543 return (ssize_t)total;
544 } /* sdb_client_recv */
546 bool
547 sdb_client_eof(sdb_client_t *client)
548 {
549 if ((! client) || (client->fd < 0))
550 return 1;
551 return client->eof;
552 } /* sdb_client_eof */
554 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */