Code

client: Add support for TCP connections.
[sysdb.git] / src / client / sock.c
1 /*
2  * SysDB - src/client/sock.c
3  * Copyright (C) 2013 Sebastian 'tokkee' Harl <sh@tokkee.org>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  * 2. Redistributions in binary form must reproduce the above copyright
12  *    notice, this list of conditions and the following disclaimer in the
13  *    documentation and/or other materials provided with the distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
17  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
19  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
28 #if HAVE_CONFIG_H
29 #       include "config.h"
30 #endif /* HAVE_CONFIG_H */
32 #include "client/sock.h"
33 #include "utils/error.h"
34 #include "utils/strbuf.h"
35 #include "utils/proto.h"
36 #include "utils/os.h"
38 #include <arpa/inet.h>
40 #include <assert.h>
41 #include <errno.h>
42 #include <limits.h>
44 #include <stdlib.h>
46 #include <string.h>
47 #include <strings.h>
49 #include <unistd.h>
51 #include <sys/socket.h>
52 #include <sys/un.h>
54 #include <netdb.h>
56 /*
57  * private data types
58  */
60 struct sdb_client {
61         char *address;
62         int   fd;
63         bool  eof;
65         ssize_t (*read)(sdb_client_t *, sdb_strbuf_t *, size_t);
66         ssize_t (*write)(sdb_client_t *, const void *, size_t);
67 };
69 /*
70  * private helper functions
71  */
73 static ssize_t
74 client_read(sdb_client_t *client, sdb_strbuf_t *buf, size_t n)
75 {
76         return sdb_strbuf_read(buf, client->fd, n);
77 } /* client_read */
79 static ssize_t
80 client_write(sdb_client_t *client, const void *buf, size_t n)
81 {
82         return sdb_write(client->fd, n, buf);
83 } /* client_write */
85 static int
86 connect_unixsock(sdb_client_t *client, const char *address)
87 {
88         struct sockaddr_un sa;
90         client->fd = socket(AF_UNIX, SOCK_STREAM, /* protocol = */ 0);
91         if (client->fd < 0) {
92                 char errbuf[1024];
93                 sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
94                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
95                 return -1;
96         }
98         sa.sun_family = AF_UNIX;
99         strncpy(sa.sun_path, address, sizeof(sa.sun_path));
100         sa.sun_path[sizeof(sa.sun_path) - 1] = '\0';
102         if (connect(client->fd, (struct sockaddr *)&sa, sizeof(sa))) {
103                 char errbuf[1024];
104                 sdb_client_close(client);
105                 sdb_log(SDB_LOG_ERR, "Failed to connect to '%s': %s",
106                                 sa.sun_path, sdb_strerror(errno, errbuf, sizeof(errbuf)));
107                 return -1;
108         }
109         return client->fd;
110 } /* connect_unixsock */
112 static int
113 connect_tcp(sdb_client_t *client, const char *address)
115         struct addrinfo *ai, *ai_list = NULL;
116         int status;
118         if ((status = sdb_resolve(SDB_NET_TCP, address, &ai_list))) {
119                 sdb_log(SDB_LOG_ERR, "Failed to resolve '%s': %s",
120                                 address, gai_strerror(status));
121                 return -1;
122         }
124         for (ai = ai_list; ai != NULL; ai = ai->ai_next) {
125                 client->fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
126                 if (client->fd < 0) {
127                         char errbuf[1024];
128                         sdb_log(SDB_LOG_ERR, "Failed to open socket: %s",
129                                         sdb_strerror(errno, errbuf, sizeof(errbuf)));
130                         continue;
131                 }
133                 if (connect(client->fd, ai->ai_addr, ai->ai_addrlen)) {
134                         char host[1024], port[32], errbuf[1024];
135                         sdb_client_close(client);
136                         getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host),
137                                         port, sizeof(port), NI_NUMERICHOST | NI_NUMERICSERV);
138                         sdb_log(SDB_LOG_ERR, "Failed to connect to '%s:%s': %s",
139                                         host, port, sdb_strerror(errno, errbuf, sizeof(errbuf)));
140                         continue;
141                 }
142                 break;
143         }
144         freeaddrinfo(ai_list);
145         return client->fd;
146 } /* connect_tcp */
148 /*
149  * public API
150  */
152 sdb_client_t *
153 sdb_client_create(const char *address)
155         sdb_client_t *client;
157         if (! address)
158                 return NULL;
160         client = malloc(sizeof(*client));
161         if (! client) {
162                 sdb_log(SDB_LOG_ERR, "Out of memory");
163                 return NULL;
164         }
165         memset(client, 0, sizeof(*client));
166         client->fd = -1;
167         client->eof = 1;
169         client->read = client_read;
170         client->write = client_write;
172         client->address = strdup(address);
173         if (! client->address) {
174                 sdb_client_destroy(client);
175                 sdb_log(SDB_LOG_ERR, "Out of memory");
176                 return NULL;
177         }
179         return client;
180 } /* sdb_client_create */
182 void
183 sdb_client_destroy(sdb_client_t *client)
185         if (! client)
186                 return;
188         sdb_client_close(client);
190         if (client->address)
191                 free(client->address);
192         client->address = NULL;
194         free(client);
195 } /* sdb_client_destroy */
197 int
198 sdb_client_connect(sdb_client_t *client, const char *username)
200         sdb_strbuf_t *buf;
201         ssize_t status;
202         uint32_t rstatus;
204         if ((! client) || (! client->address))
205                 return -1;
207         if (client->fd >= 0)
208                 return -1;
210         if (*client->address == '/')
211                 connect_unixsock(client, client->address);
212         else if (!strncasecmp(client->address, "unix:", strlen("unix:")))
213                 connect_unixsock(client, client->address + strlen("unix:"));
214         else if (!strncasecmp(client->address, "tcp:", strlen("tcp:")))
215                 connect_tcp(client, client->address + strlen("tcp:"));
216         else
217                 connect_tcp(client, client->address);
219         if (client->fd < 0)
220                 return -1;
221         client->eof = 0;
223         /* XXX */
224         if (! username)
225                 username = "";
227         status = sdb_client_send(client, SDB_CONNECTION_STARTUP,
228                         (uint32_t)strlen(username), username);
229         if (status < 0) {
230                 char errbuf[1024];
231                 sdb_client_close(client);
232                 sdb_log(SDB_LOG_ERR, "Failed to send STARTUP message to server: %s",
233                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
234                 return (int)status;
235         }
237         buf = sdb_strbuf_create(64);
238         rstatus = 0;
239         status = sdb_client_recv(client, &rstatus, buf);
240         if ((status > 0) && (rstatus == SDB_CONNECTION_OK)) {
241                 sdb_strbuf_destroy(buf);
242                 return 0;
243         }
245         if (status < 0) {
246                 char errbuf[1024];
247                 sdb_log(SDB_LOG_ERR, "Failed to receive server response: %s",
248                                 sdb_strerror(errno, errbuf, sizeof(errbuf)));
249         }
250         else if (client->eof)
251                 sdb_log(SDB_LOG_ERR, "Encountered end-of-file while waiting "
252                                 "for server response");
254         if (rstatus == SDB_CONNECTION_ERROR) {
255                 sdb_log(SDB_LOG_ERR, "Access denied for user '%s': %s",
256                                 username, sdb_strbuf_string(buf));
257                 status = -((int)rstatus);
258         }
259         else if (rstatus != SDB_CONNECTION_OK) {
260                 sdb_log(SDB_LOG_ERR, "Received unsupported authentication request "
261                                 "(status %d) during startup", (int)rstatus);
262                 status = -((int)rstatus);
263         }
265         sdb_client_close(client);
266         sdb_strbuf_destroy(buf);
267         return (int)status;
268 } /* sdb_client_connect */
270 int
271 sdb_client_sockfd(sdb_client_t *client)
273         if (! client)
274                 return -1;
275         return client->fd;
276 } /* sdb_client_sockfd */
278 int
279 sdb_client_shutdown(sdb_client_t *client, int how)
281         if (! client) {
282                 errno = ENOTSOCK;
283                 return -1;
284         }
286         if (client->fd < 0) {
287                 errno = EBADF;
288                 return -1;
289         }
291         return shutdown(client->fd, how);
292 } /* sdb_client_shutdown */
294 void
295 sdb_client_close(sdb_client_t *client)
297         if (! client)
298                 return;
300         close(client->fd);
301         client->fd = -1;
302         client->eof = 1;
303 } /* sdb_client_close */
305 ssize_t
306 sdb_client_send(sdb_client_t *client,
307                 uint32_t cmd, uint32_t msg_len, const char *msg)
309         char buf[2 * sizeof(uint32_t) + msg_len];
311         if ((! client) || (! client->fd))
312                 return -1;
313         if (sdb_proto_marshal(buf, sizeof(buf), cmd, msg_len, msg) < 0)
314                 return -1;
316         return client->write(client, buf, sizeof(buf));
317 } /* sdb_client_send */
319 ssize_t
320 sdb_client_recv(sdb_client_t *client,
321                 uint32_t *code, sdb_strbuf_t *buf)
323         uint32_t rstatus = UINT32_MAX;
324         uint32_t rlen = UINT32_MAX;
326         size_t total = 0;
327         size_t req = 2 * sizeof(uint32_t);
329         size_t data_offset = sdb_strbuf_len(buf);
331         if (code)
332                 *code = UINT32_MAX;
334         if ((! client) || (! client->fd) || (! buf)) {
335                 errno = EBADF;
336                 return -1;
337         }
339         while (42) {
340                 ssize_t status;
342                 if (sdb_select(client->fd, SDB_SELECTIN))
343                         return -1;
345                 errno = 0;
346                 status = client->read(client, buf, req);
347                 if (status < 0) {
348                         if ((errno == EAGAIN) || (errno == EWOULDBLOCK))
349                                 continue;
350                         return status;
351                 }
352                 else if (! status) {
353                         client->eof = 1;
354                         break;
355                 }
357                 total += (size_t)status;
359                 if (total != req)
360                         continue;
362                 if (rstatus == UINT32_MAX) {
363                         const char *str = sdb_strbuf_string(buf) + data_offset;
364                         size_t len = sdb_strbuf_len(buf) - data_offset;
365                         ssize_t n;
367                         /* retrieve status and data len */
368                         assert(len >= 2 * sizeof(uint32_t));
369                         n = sdb_proto_unmarshal_int32(str, len, &rstatus);
370                         str += n; len -= (size_t)n;
371                         sdb_proto_unmarshal_int32(str, len, &rlen);
373                         if (! rlen)
374                                 break;
376                         req = (size_t)rlen;
377                         total = 0;
378                 }
379                 else /* finished reading data */
380                         break;
381         }
383         if (total != req) {
384                 /* unexpected EOF; clear partially read data */
385                 sdb_strbuf_skip(buf, data_offset, sdb_strbuf_len(buf));
386                 return 0;
387         }
389         if (rstatus != UINT32_MAX)
390                 /* remove status,len */
391                 sdb_strbuf_skip(buf, data_offset, 2 * sizeof(rstatus));
393         if (code)
394                 *code = rstatus;
396         return (ssize_t)total;
397 } /* sdb_client_recv */
399 bool
400 sdb_client_eof(sdb_client_t *client)
402         if ((! client) || (client->fd < 0))
403                 return 1;
404         return client->eof;
405 } /* sdb_client_eof */
407 /* vim: set tw=78 sw=4 ts=4 noexpandtab : */