1 /**
2 * Phoebe DOM Implementation.
3 *
4 * This is a C++ approximation of the W3C DOM model, which follows
5 * fairly closely the specifications in the various .idl files, copies of
6 * which are provided for reference. Most important is this one:
7 *
8 * http://www.w3.org/TR/2004/REC-DOM-Level-3-Core-20040407/idl-definitions.html
9 *
10 * Authors:
11 * Bob Jamison
12 *
13 * Copyright (C) 2005 Bob Jamison
14 *
15 * This library is free software; you can redistribute it and/or
16 * modify it under the terms of the GNU Lesser General Public
17 * License as published by the Free Software Foundation; either
18 * version 2.1 of the License, or (at your option) any later version.
19 *
20 * This library is distributed in the hope that it will be useful,
21 * but WITHOUT ANY WARRANTY; without even the implied warranty of
22 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
23 * Lesser General Public License for more details.
24 *
25 * You should have received a copy of the GNU Lesser General Public
26 * License along with this library; if not, write to the Free Software
27 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
28 */
30 #include "socket.h"
31 #include "dom/util/thread.h"
33 #ifdef __WIN32__
34 #include <windows.h>
35 #else /* unix */
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/in.h>
39 #include <netdb.h>
40 #include <unistd.h>
41 #include <sys/ioctl.h>
43 #endif
45 #ifdef HAVE_SSL
46 #include <openssl/ssl.h>
47 #include <openssl/err.h>
48 #endif
49 #endif
52 namespace org
53 {
54 namespace w3c
55 {
56 namespace dom
57 {
58 namespace io
59 {
/**
 * Portable stand-in for bzero(): set the first n bytes at s to zero.
 */
static void mybzero(void *s, size_t n)
{
    unsigned char *cursor = static_cast<unsigned char *>(s);
    unsigned char *end    = cursor + n;
    for (; cursor != end; ++cursor)
        *cursor = 0;
}
/**
 * Portable stand-in for bcopy(): copy n bytes from src to dest, front to
 * back, one byte at a time.  Note the bcopy argument order (source first).
 */
static void mybcopy(void *src, void *dest, size_t n)
{
    const unsigned char *from = static_cast<const unsigned char *>(src);
    unsigned char *to = static_cast<unsigned char *>(dest);
    for (size_t i = 0; i < n; ++i)
        to[i] = from[i];
}
84 //#########################################################################
85 //# T C P C O N N E C T I O N
86 //#########################################################################
88 TcpSocket::TcpSocket()
89 {
90 init();
91 }
94 TcpSocket::TcpSocket(const DOMString &hostnameArg, int port)
95 {
96 init();
97 hostname = hostnameArg;
98 portno = port;
99 }
102 #ifdef HAVE_SSL
/**
 * OpenSSL static-locking callback (registered in TcpSocket::init()).
 *
 * This is a lock-state *checker*, not a lock.  It records in the static
 * `modes` array which slots are currently held and in which mode
 * (CRYPTO_READ or CRYPTO_WRITE).  It prints a diagnostic to stderr when
 * OpenSSL requests an impossible transition:
 *   - locking a held slot,
 *   - unlocking a free slot,
 *   - a read/write mismatch on unlock,
 *   - an out-of-range slot,
 *   - an unrecognised mode.
 * It never waits.
 *
 * NOTE(review): nothing here blocks or provides mutual exclusion, so
 * OpenSSL's shared state is unprotected if several threads use SSL at the
 * same time.  Such contention would show up as the "already locked"
 * message.  A real implementation needs one mutex per slot, created before
 * this callback is registered.  The `modes` array itself is also updated
 * without synchronisation.
 *
 * @param mode  bit mask of CRYPTO_LOCK/CRYPTO_UNLOCK and CRYPTO_READ/CRYPTO_WRITE
 * @param type  lock slot index, checked against [0, CRYPTO_NUM_LOCKS)
 * @param file  OpenSSL call-site file, used only in the diagnostic
 * @param line  OpenSSL call-site line, used only in the diagnostic
 */
static void cryptoLockCallback(int mode, int type, const char *file, int line)
{
    //printf("########### LOCK\n");
    // Mode currently held for each slot; 0 means "not held".  Static
    // storage, so every slot starts at zero.
    static int modes[CRYPTO_NUM_LOCKS]; /* = {0, 0, ... } */
    const char *errstr = NULL;

    // Exactly one of CRYPTO_READ / CRYPTO_WRITE must be requested.
    int rw = mode & (CRYPTO_READ|CRYPTO_WRITE);
    if (!((rw == CRYPTO_READ) || (rw == CRYPTO_WRITE)))
    {
        errstr = "invalid mode";
        goto err;
    }

    // Guards the modes[type] accesses below.
    if (type < 0 || type >= CRYPTO_NUM_LOCKS)
    {
        errstr = "type out of bounds";
        goto err;
    }

    if (mode & CRYPTO_LOCK)
    {
        if (modes[type])
        {
            errstr = "already locked";
            /* must not happen in a single-threaded program
             * (would deadlock)
             */
            goto err;
        }

        modes[type] = rw;
    }
    else if (mode & CRYPTO_UNLOCK)
    {
        if (!modes[type])
        {
            errstr = "not locked";
            goto err;
        }

        // A read/write mismatch is reported, but there is no goto here:
        // the slot is still released below.
        if (modes[type] != rw)
        {
            errstr = (rw == CRYPTO_READ) ?
                "CRYPTO_r_unlock on write lock" :
                "CRYPTO_w_unlock on read lock";
        }

        modes[type] = 0;
    }
    else
    {
        errstr = "invalid mode";
        goto err;
    }

    // Reached on every path; prints only when a problem was recorded.
err:
    if (errstr)
    {
        /* we cannot use bio_err here */
        fprintf(stderr, "openssl (lock_dbg_cb): %s (mode=%d, type=%d) at %s:%d\n",
            errstr, mode, type, file, line);
    }
}
/**
 * OpenSSL thread-id callback: identify the calling thread as an
 * unsigned long (Win32 thread id, or the pthread handle elsewhere).
 */
static unsigned long cryptoIdCallback()
{
#ifdef __WIN32__
    return static_cast<unsigned long>(GetCurrentThreadId());
#else
    return (unsigned long) pthread_self();
#endif
}
178 #endif
181 TcpSocket::TcpSocket(const TcpSocket &other)
182 {
183 init();
184 sock = other.sock;
185 hostname = other.hostname;
186 portno = other.portno;
187 }
189 static bool tcp_socket_inited = false;
191 void TcpSocket::init()
192 {
193 if (!tcp_socket_inited)
194 {
195 #ifdef __WIN32__
196 WORD wVersionRequested = MAKEWORD( 2, 2 );
197 WSADATA wsaData;
198 WSAStartup( wVersionRequested, &wsaData );
199 #endif
200 #ifdef HAVE_SSL
201 sslStream = NULL;
202 sslContext = NULL;
203 CRYPTO_set_locking_callback(cryptoLockCallback);
204 CRYPTO_set_id_callback(cryptoIdCallback);
205 SSL_library_init();
206 SSL_load_error_strings();
207 #endif
208 tcp_socket_inited = true;
209 }
210 sock = -1;
211 connected = false;
212 hostname = "";
213 portno = -1;
214 sslEnabled = false;
215 receiveTimeout = 0;
216 }
218 TcpSocket::~TcpSocket()
219 {
220 disconnect();
221 }
223 bool TcpSocket::isConnected()
224 {
225 if (!connected || sock < 0)
226 return false;
227 return true;
228 }
230 void TcpSocket::enableSSL(bool val)
231 {
232 sslEnabled = val;
233 }
236 bool TcpSocket::connect(const DOMString &hostnameArg, int portnoArg)
237 {
238 hostname = hostnameArg;
239 portno = portnoArg;
240 return connect();
241 }
245 #ifdef HAVE_SSL
246 /*
247 static int password_cb(char *buf, int bufLen, int rwflag, void *userdata)
248 {
249 char *password = "password";
250 if (bufLen < (int)(strlen(password)+1))
251 return 0;
253 strcpy(buf,password);
254 int ret = strlen(password);
255 return ret;
256 }
258 static void infoCallback(const SSL *ssl, int where, int ret)
259 {
260 switch (where)
261 {
262 case SSL_CB_ALERT:
263 {
264 printf("## %d SSL ALERT: %s\n", where, SSL_alert_desc_string_long(ret));
265 break;
266 }
267 default:
268 {
269 printf("## %d SSL: %s\n", where, SSL_state_string_long(ssl));
270 break;
271 }
272 }
273 }
274 */
275 #endif
278 bool TcpSocket::startTls()
279 {
280 #ifdef HAVE_SSL
281 sslStream = NULL;
282 sslContext = NULL;
284 //SSL_METHOD *meth = SSLv23_method();
285 //SSL_METHOD *meth = SSLv3_client_method();
286 SSL_METHOD *meth = TLSv1_client_method();
287 sslContext = SSL_CTX_new(meth);
288 //SSL_CTX_set_info_callback(sslContext, infoCallback);
290 #if 0
291 char *keyFile = "client.pem";
292 char *caList = "root.pem";
293 /* Load our keys and certificates*/
294 if (!(SSL_CTX_use_certificate_chain_file(sslContext, keyFile)))
295 {
296 fprintf(stderr, "Can't read certificate file\n");
297 disconnect();
298 return false;
299 }
301 SSL_CTX_set_default_passwd_cb(sslContext, password_cb);
303 if (!(SSL_CTX_use_PrivateKey_file(sslContext, keyFile, SSL_FILETYPE_PEM)))
304 {
305 fprintf(stderr, "Can't read key file\n");
306 disconnect();
307 return false;
308 }
310 /* Load the CAs we trust*/
311 if (!(SSL_CTX_load_verify_locations(sslContext, caList, 0)))
312 {
313 fprintf(stderr, "Can't read CA list\n");
314 disconnect();
315 return false;
316 }
317 #endif
319 /* Connect the SSL socket */
320 sslStream = SSL_new(sslContext);
321 SSL_set_fd(sslStream, sock);
323 if (SSL_connect(sslStream)<=0)
324 {
325 fprintf(stderr, "SSL connect error\n");
326 disconnect();
327 return false;
328 }
330 sslEnabled = true;
331 #endif /*HAVE_SSL*/
332 return true;
333 }
336 bool TcpSocket::connect()
337 {
338 if (hostname.size()<1)
339 {
340 printf("open: null hostname\n");
341 return false;
342 }
344 if (portno<1)
345 {
346 printf("open: bad port number\n");
347 return false;
348 }
350 sock = socket(PF_INET, SOCK_STREAM, 0);
351 if (sock < 0)
352 {
353 printf("open: error creating socket\n");
354 return false;
355 }
357 char *c_hostname = (char *)hostname.c_str();
358 struct hostent *server = gethostbyname(c_hostname);
359 if (!server)
360 {
361 printf("open: could not locate host '%s'\n", c_hostname);
362 return false;
363 }
365 struct sockaddr_in serv_addr;
366 mybzero((char *) &serv_addr, sizeof(serv_addr));
367 serv_addr.sin_family = AF_INET;
368 mybcopy((char *)server->h_addr, (char *)&serv_addr.sin_addr.s_addr,
369 server->h_length);
370 serv_addr.sin_port = htons(portno);
372 int ret = ::connect(sock, (const sockaddr *)&serv_addr, sizeof(serv_addr));
373 if (ret < 0)
374 {
375 printf("open: could not connect to host '%s'\n", c_hostname);
376 return false;
377 }
379 if (sslEnabled)
380 {
381 if (!startTls())
382 return false;
383 }
384 connected = true;
385 return true;
386 }
388 bool TcpSocket::disconnect()
389 {
390 bool ret = true;
391 connected = false;
392 #ifdef HAVE_SSL
393 if (sslEnabled)
394 {
395 if (sslStream)
396 {
397 int r = SSL_shutdown(sslStream);
398 switch(r)
399 {
400 case 1:
401 break; /* Success */
402 case 0:
403 case -1:
404 default:
405 //printf("Shutdown failed");
406 ret = false;
407 }
408 SSL_free(sslStream);
409 }
410 if (sslContext)
411 SSL_CTX_free(sslContext);
412 }
413 sslStream = NULL;
414 sslContext = NULL;
415 #endif /*HAVE_SSL*/
417 #ifdef __WIN32__
418 closesocket(sock);
419 #else
420 ::close(sock);
421 #endif
422 sock = -1;
423 sslEnabled = false;
425 return ret;
426 }
430 bool TcpSocket::setReceiveTimeout(unsigned long millis)
431 {
432 receiveTimeout = millis;
433 return true;
434 }
436 /**
437 * For normal sockets, return the number of bytes waiting to be received.
438 * For SSL, just return >0 when something is ready to be read.
439 */
440 long TcpSocket::available()
441 {
442 if (!isConnected())
443 return -1;
445 long count = 0;
446 #ifdef __WIN32__
447 if (ioctlsocket(sock, FIONREAD, (unsigned long *)&count) != 0)
448 return -1;
449 #else
450 if (ioctl(sock, FIONREAD, &count) != 0)
451 return -1;
452 #endif
453 if (count<=0 && sslEnabled)
454 {
455 #ifdef HAVE_SSL
456 return SSL_pending(sslStream);
457 #endif
458 }
459 return count;
460 }
464 bool TcpSocket::write(int ch)
465 {
466 if (!isConnected())
467 {
468 printf("write: socket closed\n");
469 return false;
470 }
471 unsigned char c = (unsigned char)ch;
473 if (sslEnabled)
474 {
475 #ifdef HAVE_SSL
476 int r = SSL_write(sslStream, &c, 1);
477 if (r<=0)
478 {
479 switch(SSL_get_error(sslStream, r))
480 {
481 default:
482 printf("SSL write problem");
483 return -1;
484 }
485 }
486 #endif
487 }
488 else
489 {
490 if (send(sock, (const char *)&c, 1, 0) < 0)
491 //if (send(sock, &c, 1, 0) < 0)
492 {
493 printf("write: could not send data\n");
494 return false;
495 }
496 }
497 return true;
498 }
500 bool TcpSocket::write(const DOMString &strArg)
501 {
502 DOMString str = strArg;
504 if (!isConnected())
505 {
506 printf("write(str): socket closed\n");
507 return false;
508 }
509 int len = str.size();
511 if (sslEnabled)
512 {
513 #ifdef HAVE_SSL
514 int r = SSL_write(sslStream, (unsigned char *)str.c_str(), len);
515 if (r<=0)
516 {
517 switch(SSL_get_error(sslStream, r))
518 {
519 default:
520 printf("SSL write problem");
521 return -1;
522 }
523 }
524 #endif
525 }
526 else
527 {
528 if (send(sock, str.c_str(), len, 0) < 0)
529 //if (send(sock, &c, 1, 0) < 0)
530 {
531 printf("write: could not send data\n");
532 return false;
533 }
534 }
535 return true;
536 }
538 int TcpSocket::read()
539 {
540 if (!isConnected())
541 return -1;
543 //We'll use this loop for timeouts, so that SSL and plain sockets
544 //will behave the same way
545 if (receiveTimeout > 0)
546 {
547 unsigned long tim = 0;
548 while (true)
549 {
550 int avail = available();
551 if (avail > 0)
552 break;
553 if (tim >= receiveTimeout)
554 return -2;
555 org::w3c::dom::util::Thread::sleep(20);
556 tim += 20;
557 }
558 }
560 //check again
561 if (!isConnected())
562 return -1;
564 unsigned char ch;
565 if (sslEnabled)
566 {
567 #ifdef HAVE_SSL
568 if (!sslStream)
569 return -1;
570 int r = SSL_read(sslStream, &ch, 1);
571 unsigned long err = SSL_get_error(sslStream, r);
572 switch (err)
573 {
574 case SSL_ERROR_NONE:
575 break;
576 case SSL_ERROR_ZERO_RETURN:
577 return -1;
578 case SSL_ERROR_SYSCALL:
579 printf("SSL read problem(syscall) %s\n",
580 ERR_error_string(ERR_get_error(), NULL));
581 return -1;
582 default:
583 printf("SSL read problem %s\n",
584 ERR_error_string(ERR_get_error(), NULL));
585 return -1;
586 }
587 #endif
588 }
589 else
590 {
591 int ret = recv(sock, (char *)&ch, 1, 0);
592 if (ret <= 0)
593 {
594 if (ret<0)
595 printf("read: could not receive data\n");
596 disconnect();
597 return -1;
598 }
599 }
600 return (int)ch;
601 }
603 bool TcpSocket::readLine(DOMString &result)
604 {
605 result = "";
607 while (isConnected())
608 {
609 int ch = read();
610 if (ch<0)
611 return true;
612 else if (ch=='\r') //we want canonical Net '\r\n' , so skip this
613 {}
614 else if (ch=='\n')
615 return true;
616 else
617 result.push_back((char)ch);
618 }
620 return true;
621 }
623 } //namespace io
624 } //namespace dom
625 } //namespace w3c
626 } //namespace org
629 //#########################################################################
630 //# E N D O F F I L E
631 //#########################################################################