Code

frontend: Fix invalid command handling when receiving data in chunks.
authorSebastian Harl <sh@tokkee.org>
Wed, 14 May 2014 18:38:17 +0000 (20:38 +0200)
committerSebastian Harl <sh@tokkee.org>
Wed, 14 May 2014 18:38:17 +0000 (20:38 +0200)
When skipping over invalid command data, make sure not to skip more data than
what's currently available. Rather, remember how much data needs to be ignored
and do so after actually receiving it.

Added a new test which catches these cases and also provides some more
low-level communication tests.

src/frontend/connection-private.h
src/frontend/connection.c
t/unit/frontend/connection_test.c

index 574049e3dde05b42d1dc1e0dd326242d45ff43bc..ae9cbbc91ba0cfa2dd532cb55795de95164a96fe 100644 (file)
@@ -64,6 +64,10 @@ struct sdb_conn {
        uint32_t cmd;
        uint32_t cmd_len;
 
+       /* amount of data to skip, e.g., after receiving invalid commands; if this
+        * is non-zero, the 'skip_len' first bytes of 'buf' are invalid */
+       size_t skip_len;
+
        sdb_strbuf_t *errbuf;
 
        /* user information */
index 923f000743db5f3694698ffbecdb9a0a033ec4b4..b91393997d6588cb631779230dc57b19cdf0ea22 100644 (file)
@@ -117,6 +117,7 @@ connection_init(sdb_object_t *obj, va_list ap)
 
        conn->cmd = CONNECTION_IDLE;
        conn->cmd_len = 0;
+       conn->skip_len = 0;
 
        /* update the object name */
        snprintf(obj->name + strlen(CONN_FD_PREFIX),
@@ -259,6 +260,7 @@ command_handle(sdb_conn_t *conn)
        int status = -1;
 
        assert(conn && (conn->cmd != CONNECTION_IDLE));
+       assert(! conn->skip_len);
 
        sdb_log(SDB_LOG_DEBUG, "frontend: Handling command %u (len: %u)",
                        conn->cmd, conn->cmd_len);
@@ -368,22 +370,33 @@ command_init(sdb_conn_t *conn)
 
        assert(conn && (conn->cmd == CONNECTION_IDLE) && (! conn->cmd_len));
 
+       if (conn->skip_len)
+               return -1;
+
        /* reset */
        sdb_strbuf_sprintf(conn->errbuf, "");
 
        conn->cmd = connection_get_int32(conn, 0);
        conn->cmd_len = connection_get_int32(conn, sizeof(uint32_t));
 
-       len = 2 * sizeof(uint32_t);
        if (conn->cmd == CONNECTION_IDLE) {
                const char *errmsg = "Invalid command 0";
                sdb_strbuf_sprintf(conn->errbuf, errmsg);
                sdb_connection_send(conn, CONNECTION_ERROR,
                                (uint32_t)strlen(errmsg), errmsg);
-               len += conn->cmd_len;
+               conn->skip_len += conn->cmd_len;
                conn->cmd_len = 0;
        }
-       sdb_strbuf_skip(conn->buf, 0, len);
+       sdb_strbuf_skip(conn->buf, 0, 2 * sizeof(uint32_t));
+
+       if (conn->skip_len) {
+               len = sdb_strbuf_len(conn->buf);
+               if (len > conn->skip_len)
+                       len = conn->skip_len;
+               sdb_strbuf_skip(conn->buf, 0, len);
+               conn->skip_len -= len;
+               /* connection_read will handle anything else */
+       }
        return 0;
 } /* command_init */
 
@@ -412,6 +425,13 @@ connection_read(sdb_conn_t *conn)
                else if (! status) /* EOF */
                        break;
 
+               if (conn->skip_len) {
+                       size_t len = (size_t)status < conn->skip_len
+                               ? (size_t)status : conn->skip_len;
+                       sdb_strbuf_skip(conn->buf, 0, len);
+                       conn->skip_len -= len;
+               }
+
                n += status;
        }
 
index da94f06ec3cfd1bf5b1391e7249b98dd6b9a0f20..43322e02adf00c7c852642e6a685263a44c8f3fd 100644 (file)
@@ -31,6 +31,7 @@
 
 #include "frontend/connection.h"
 #include "frontend/connection-private.h"
+#include "utils/proto.h"
 #include "libsysdb_test.h"
 
 #include "utils/strbuf.h"
@@ -167,6 +168,31 @@ mock_client(void *arg)
        return NULL;
 } /* mock_client */
 
+static void
+connection_setup(sdb_conn_t *conn)
+{
+       ssize_t check, expected;
+
+       expected = 2 * sizeof(uint32_t) + strlen("fakeuser");
+       check = sdb_connection_send(conn, CONNECTION_STARTUP,
+                       (uint32_t)strlen("fakeuser"), "fakeuser");
+       fail_unless(check == expected,
+                       "sdb_connection_send(STARTUP, fakeuser) = %zi; expected: %zi",
+                       check, expected);
+
+       mock_conn_rewind(conn);
+       check = sdb_connection_read(conn);
+       fail_unless(check == expected,
+                       "On startup: sdb_connection_read() = %zi; expected: %zi",
+                       check, expected);
+
+       fail_unless(sdb_strbuf_len(conn->errbuf) == 0,
+                       "sdb_connection_read() left %zu bytes in the error "
+                       "buffer; expected: 0", sdb_strbuf_len(conn->errbuf));
+
+       mock_conn_truncate(conn);
+} /* connection_setup */
+
 /*
  * tests
  */
@@ -200,6 +226,7 @@ START_TEST(test_conn_accept)
 }
 END_TEST
 
+/* test connection setup and very basic commands */
 START_TEST(test_conn_setup)
 {
        sdb_conn_t *conn = mock_conn_create();
@@ -209,6 +236,7 @@ START_TEST(test_conn_setup)
                const char *msg;
                const char *err;
        } golden_data[] = {
+               /* code == UINT32_MAX => no data will be sent */
                { UINT32_MAX,         NULL,       NULL },
                { CONNECTION_IDLE,    "fakedata", "Invalid command 0" },
                { CONNECTION_PING,    NULL,       "Authentication required" },
@@ -267,6 +295,109 @@ START_TEST(test_conn_setup)
 }
 END_TEST
 
+/* test simple I/O on open connections */
+START_TEST(test_conn_io)
+{
+       sdb_conn_t *conn = mock_conn_create();
+
+       struct {
+               uint32_t code;
+               uint32_t msg_len;
+               const char *msg;
+               size_t buf_len; /* number of bytes we expect in conn->buf */
+               const char *err;
+       } golden_data[] = {
+               /* code == UINT32_MAX => this is a follow-up package */
+               { CONNECTION_PING, 20, "9876543210", 10, NULL },
+               { UINT32_MAX,      -1, "9876543210",  0, NULL },
+               { CONNECTION_IDLE, 20, "9876543210",  0, "Invalid command 0" },
+               { UINT32_MAX,      -1, "9876543210",  0, "Invalid command 0" },
+               { CONNECTION_IDLE, 20, "9876543210",  0, "Invalid command 0" },
+               { UINT32_MAX,      -1, "9876543210",  0, "Invalid command 0" },
+               { CONNECTION_PING, 10, "9876543210",  0, NULL },
+               { CONNECTION_PING, 20, "9876543210", 10, NULL },
+               { UINT32_MAX,      -1, "9876543210",  0, NULL },
+       };
+
+       size_t i;
+
+       connection_setup(conn);
+
+       for (i = 0; i < SDB_STATIC_ARRAY_LEN(golden_data); ++i) {
+               size_t msg_len = strlen(golden_data[i].msg);
+               char buffer[2 * sizeof(uint32_t) + msg_len];
+               size_t offset = 0;
+
+               ssize_t check;
+
+               mock_conn_truncate(conn);
+
+               if (golden_data[i].code != UINT32_MAX) {
+                       uint32_t tmp;
+
+                       tmp = htonl(golden_data[i].code);
+                       memcpy(buffer, &tmp, sizeof(tmp));
+                       tmp = htonl(golden_data[i].msg_len);
+                       memcpy(buffer + sizeof(tmp), &tmp, sizeof(tmp));
+
+                       msg_len += 2 * sizeof(uint32_t);
+                       offset += 2 * sizeof(uint32_t);
+               }
+
+               memcpy(buffer + offset, golden_data[i].msg,
+                               strlen(golden_data[i].msg));
+
+               check = sdb_proto_send(conn->fd, msg_len, buffer);
+               fail_unless(check == (ssize_t)msg_len,
+                               "sdb_proto_send(%s) = %zi; expected: %zu",
+                               check, msg_len);
+
+               mock_conn_rewind(conn);
+               check = sdb_connection_read(conn);
+               fail_unless(check == (ssize_t)msg_len,
+                               "sdb_connection_read() = %zi; expected: %zu",
+                               check, msg_len);
+
+               if (golden_data[i].buf_len) {
+                       /* partial commands need to be stored in the object */
+                       fail_unless(conn->cmd == golden_data[i].code,
+                                       "sdb_connection_read() set partial command "
+                                       "to %u; expected: %u", conn->cmd, golden_data[i].code);
+                       fail_unless(conn->cmd_len > golden_data[i].buf_len,
+                                       "sdb_connection_read() set partial command length "
+                                       "to %u; expected: > %u", conn->cmd_len,
+                                       golden_data[i].buf_len);
+               }
+               else {
+                       fail_unless(conn->cmd == CONNECTION_IDLE,
+                                       "sdb_connection_read() did not reset command; "
+                                       "got %u; expected: %u", conn->cmd, CONNECTION_IDLE);
+                       fail_unless(conn->cmd_len == 0,
+                                       "sdb_connection_read() did not reset command length; "
+                                       "got %u; expected: 0", conn->cmd_len);
+               }
+
+               fail_unless(sdb_strbuf_len(conn->buf) == golden_data[i].buf_len,
+                               "sdb_connection_read() left %zu bytes in the buffer; "
+                               "expected: %zu", sdb_strbuf_len(conn->buf),
+                               golden_data[i].buf_len);
+
+               if (golden_data[i].err) {
+                       const char *err = sdb_strbuf_string(conn->errbuf);
+                       fail_unless(strcmp(err, golden_data[i].err) == 0,
+                                       "sdb_connection_read(): got error '%s'; "
+                                       "expected: '%s'", err, golden_data[i].err);
+               }
+               else
+                       fail_unless(sdb_strbuf_len(conn->errbuf) == 0,
+                                       "sdb_connection_read() left %zu bytes in the error "
+                                       "buffer; expected: 0", sdb_strbuf_len(conn->errbuf));
+       }
+
+       mock_conn_destroy(conn);
+}
+END_TEST
+
 Suite *
 fe_conn_suite(void)
 {
@@ -276,6 +407,7 @@ fe_conn_suite(void)
        tc = tcase_create("core");
        tcase_add_test(tc, test_conn_accept);
        tcase_add_test(tc, test_conn_setup);
+       tcase_add_test(tc, test_conn_io);
        suite_add_tcase(s, tc);
 
        return s;