#include "basenet.h"
#include "backend.h"
#include <stdio.h>
#include "sockets.h"
#include <assert.h>
#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#ifndef _WIN32
#include <fcntl.h>
#endif
#include "xalloc.h"

/*
 * basenet: peer-addressed packet transport over TCP (TLS when built with
 * BASENET_SSL and the address carries an "ssl:" prefix).
 *
 * Endpoints come in three flavours: listening sockets (is_server), outgoing
 * connections that endpoint_periodic re-opens after an error (is_client), and
 * connections accepted by a listener (is_accepted).
 *
 * Wire format, all integers big-endian:
 *   handshake (queued once per connection by each side):
 *       "BA10" | node id high 32 bits | node id low 32 bits       (12 bytes)
 *   packet:
 *       total length incl. header | port | xid | pid | payload    (16 + plen)
 */

#define BASENET_HANDSHAKE_MAGIC "BA10"
#define BASENET_DEFAULT_PORT "1000"

enum {
    BASENET_MAGIC_LEN= 4,             /* strlen (BASENET_HANDSHAKE_MAGIC) */
    BASENET_HANDSHAKE_LEN= 4+ 4+ 4,   /* magic + two 32-bit halves of the node id */
    BASENET_HEADER_LEN= 4+ 4+ 4+ 4,   /* length, port, xid, pid */
    BASENET_MAX_PACKET_LEN= 0xFFffFF, /* length fields at or above this are rejected */
    BASENET_BUFFER_LIMIT= 100*1024,   /* per-direction buffer cap of each endpoint */
    BASENET_PERIODIC_MS= 1000,        /* interval of endpoint_periodic (rebind / reconnect) */
    BASENET_LISTEN_BACKLOG= 200
};

/* Ask the kernel not to raise SIGPIPE when the peer has gone away, where the
 * platform supports doing so per call. */
#ifdef MSG_NOSIGNAL
#define BASENET_SEND_FLAGS MSG_NOSIGNAL
#else
#define BASENET_SEND_FLAGS 0
#endif

static void delete_endpoint (endpoint_t *endpoint);
static void on_endpoint_error (endpoint_t *endpoint);
static void write_to_buffer (basenet_buffer_t *buffer, const void *data, size_t len);
static void reset_endpoint_socket_hooks (endpoint_t *endpoint);
static void eat_buffer (basenet_buffer_t *buffer, size_t len);
static void resize_buffer (basenet_buffer_t *buffer, size_t size_at_least);
static packet_t* try_recv (endpoint_t *endpoint);
#ifdef BASENET_SSL
static void create_ssl_context (endpoint_t *endpoint, int as_client, int fd);
static void free_ssl_context (endpoint_t *endpoint);
static int handle_ssl_error (endpoint_t *endpoint, int ret, const char *call);
#endif

/* Hash a node name to a 64-bit id: two Adler-32-style running sums modulo
 * 65521, with the first sum in the high 32 bits and the second in the low. */
basenet_id_t basenet_name_to_id (const char *name)
{
    unsigned int a= 1, b= 0;

    while (*name) {
        a= (a+ (unsigned char) *name++) % 65521;
        b= (a+ b) % 65521;
    }
    return ((unsigned long long) a) << 32 | b;
}

/* Create a basenet instance whose node id is derived from `myname`.
 *   recv_callback - called for each received packet; the packet is freed after
 *                   the callback returns (use basenet_clone_packet to keep it).
 *                   May be NULL, in which case packets stay queued for
 *                   basenet_recv.
 *   pemfile       - optional PEM file holding the private key and certificate
 *                   used by SSL listeners; copied. */
basenet_t* basenet_new (const char *myname, basenet_recv_callback_t recv_callback, const char *pemfile)
{
    basenet_t *basenet= XALLOC_STRUCT (basenet_t);

    if (pemfile)
        basenet->pemfile= xstrdup (pemfile);
    basenet->nodeid= basenet_name_to_id (myname);
    basenet->recv_callback= recv_callback;
#ifdef BASENET_SSL
    {
        /* One-time OpenSSL initialisation, guarded by a plain (unlocked) static. */
        static int once= 1;
        if (once) {
            once= 0;
            SSL_load_error_strings ();
            ERR_load_BIO_strings ();
            OpenSSL_add_all_algorithms ();
            SSL_library_init ();
        }
    }
#endif
    return basenet;
}

/* Switch `fd` between blocking (blocking != 0) and non-blocking mode. */
static void set_blocking (int fd, int blocking)
{
#ifdef _WIN32
    u_long non_blocking= blocking ? 0 : 1;   /* FIONBIO expects a u_long* */
    ioctlsocket (fd, FIONBIO, &non_blocking);
#else
    int flags= fcntl (fd, F_GETFL);

    if (flags < 0) {
        fprintf (stderr, "fcntl (%d, F_GETFL): %s\n", fd, strerror (errno));
        return;
    }
    flags= blocking ? (flags & ~O_NONBLOCK) : (flags | O_NONBLOCK);
    if (fcntl (fd, F_SETFL, flags) < 0)
        fprintf (stderr, "fcntl (%d, F_SETFL): %s\n", fd, strerror (errno));
#endif
}

/* Parse "[tcp:|ssl:]a.b.c.d[:port]" into endpoint->sin (port defaults to
 * 1000). "ssl:" marks the endpoint for TLS in BASENET_SSL builds.
 * Returns 0 on success, -1 (after printing a message) on malformed input. */
static int parse_address (endpoint_t *endpoint, const char *_addr)
{
    char tmp [1024], *addr= tmp;
    char *ip, *colon, *end;
    const char *port;
    long port_num;

    /* Checked at runtime: with NDEBUG an assert would leave strcpy unbounded. */
    if (strlen (_addr) >= sizeof (tmp)) {
        fprintf (stderr, "address too long: %s\n", _addr);
        return -1;
    }
    strcpy (tmp, _addr);

    if (!strncmp (addr, "tcp:", 4))
        addr+= 4;
#ifdef BASENET_SSL
    if (!strncmp (addr, "ssl:", 4)) {
        endpoint->ssl_socket= 1;
        addr+= 4;
    }
#endif

    ip= addr;
    if ( (colon= strchr (addr, ':')) ) {
        *colon= 0;
        port= colon+ 1;
    } else {
        port= BASENET_DEFAULT_PORT;
    }

    errno= 0;
    port_num= strtol (port, &end, 10);
    if (end == port || *end != '\0' || errno == ERANGE || port_num <= 0 || port_num > 65535) {
        fprintf (stderr, "invalid port number: %s\n", port);
        return -1;
    }

    endpoint->sin.sin_family= AF_INET;
    endpoint->sin.sin_port= htons ((unsigned short) port_num);
    endpoint->sin.sin_addr.s_addr= inet_addr (ip);
    if (endpoint->sin.sin_addr.s_addr == INADDR_NONE) {
        fprintf (stderr, "invalid address: %s\n", ip);
        return -1;
    }
    return 0;
}

/* Accept one pending connection on a listening endpoint and register it as
 * a new accepted endpoint (TLS if the listener was bound with "ssl:"). */
static void accept_connection (endpoint_t *endpoint)
{
    struct sockaddr_in sin;
    socklen_t socklen= sizeof (sin);
    int client_fd= accept (endpoint->fd, (struct sockaddr*) &sin, &socklen);

    if (client_fd < 0) {
        fprintf (stderr, "accept (%d): %s\n", endpoint->fd, SOCKET_STRERROR ());
        return;
    }
    /* The accepted socket, not the listener, must be made non-blocking: it
     * does not inherit O_NONBLOCK on every platform (e.g. Linux). */
    set_blocking (client_fd, 0);
#ifdef BASENET_SSL
    basenet_endpoint_from_fd (endpoint->basenet, client_fd, endpoint->ssl_socket);
#else
    basenet_endpoint_from_fd (endpoint->basenet, client_fd, 0);
#endif
}

/* Read whatever is available into endpoint->incoming.
 * Returns 1 if bytes were appended, 0 if nothing could be read right now and
 * -1 if the connection failed (on_endpoint_error has already run). */
static int fill_incoming (endpoint_t *endpoint)
{
    basenet_buffer_t *in= &endpoint->incoming;
    unsigned char *p;
    int space_left;
    int amt;

    resize_buffer (in, in->limit);
    space_left= in->limit- in->in_use;
    /* A zero-length read would return 0 and be mistaken for EOF. */
    if (space_left <= 0)
        return 0;
    p= in->buffer+ in->in_use;

#ifdef BASENET_SSL
    if (endpoint->ssl) {
        amt= SSL_read (endpoint->ssl, p, space_left);
        if (amt <= 0)
            return handle_ssl_error (endpoint, amt, "SSL_read");
    } else
#endif
    {
        amt= recv (endpoint->fd, (char*) p, space_left, 0);
        if (amt < 0) {
            int err= SOCKET_ERRNO ();
            if (err == EAGAIN || err == EINTR || err == EWOULDBLOCK)
                return 0;
            fprintf (stderr, "recv (%d, %p, %d, 0): %s\n", endpoint->fd, (void*) p, space_left, SOCKET_STRERROR ());
            on_endpoint_error (endpoint);
            return -1;
        }
        if (amt == 0) {   /* orderly shutdown by the peer */
            on_endpoint_error (endpoint);
            return -1;
        }
    }
    in->in_use+= amt;
    return 1;
}

/* Consume the peer's 12-byte handshake from the front of the incoming buffer
 * once it has fully arrived, recording the peer's node id.
 * Returns -1 on a bad magic (on_endpoint_error has run), 0 otherwise. */
static int try_handshake (endpoint_t *endpoint)
{
    /* Always parse from the buffer start: the handshake may have arrived
     * spread over several reads. */
    const unsigned char *p= (const unsigned char*) endpoint->incoming.buffer;
    uint32_t high, low;

    if (endpoint->handshake_done || endpoint->incoming.in_use < BASENET_HANDSHAKE_LEN)
        return 0;
    if (memcmp (p, BASENET_HANDSHAKE_MAGIC, BASENET_MAGIC_LEN)) {
        fprintf (stderr, "bad magic received: closing\n");
        on_endpoint_error (endpoint);
        return -1;
    }
    p+= BASENET_MAGIC_LEN;
    memcpy (&high, p, sizeof (high));
    p+= sizeof (high);
    memcpy (&low, p, sizeof (low));
    endpoint->peer= ((unsigned long long) ntohl (high)) << 32 | ntohl (low);
    endpoint->handshake_done= 1;
    eat_buffer (&endpoint->incoming, BASENET_HANDSHAKE_LEN);
    fprintf (stderr, "handshake done %p (%d)\n", (void*) endpoint, endpoint->fd);
    return 0;
}

/* Backend read hook: accepts on listeners; on connections reads data,
 * completes the handshake and hands every complete packet to recv_callback. */
static void on_endpoint_read (int fd, void *_endpoint)
{
    endpoint_t *endpoint= (endpoint_t*) _endpoint;
    basenet_t *basenet= endpoint->basenet;
    packet_t *first= NULL, **tail= &first;   /* packets to deliver */
    packet_t *packet;

    (void) fd;   /* hooks are registered on endpoint->fd, which is used directly */

    if (endpoint->is_server)
        accept_connection (endpoint);

    if ((endpoint->is_client || endpoint->is_accepted)
        && fill_incoming (endpoint) >= 0
        && try_handshake (endpoint) == 0
        && basenet->recv_callback) {
        /* Without a callback, packets stay buffered for basenet_recv. */
        while ( (packet= try_recv (endpoint)) ) {
            *tail= packet;
            tail= &packet->next;
        }
    }

    reset_endpoint_socket_hooks (endpoint);

    /* The callback borrows each packet; it is released once it returns. */
    while (first) {
        packet_t *next= first->next;
        basenet->recv_callback (basenet, first);
        xfree (first->buffer);
        xfree (first);
        first= next;
    }
}

/* Backend write hook: flush as much of endpoint->outgoing as the socket takes. */
static void on_endpoint_write (int fd, void *_endpoint)
{
    endpoint_t *endpoint= (endpoint_t*) _endpoint;
    basenet_buffer_t *out= &endpoint->outgoing;
    int amt;

    (void) fd;   /* hooks are registered on endpoint->fd, which is used directly */

    if (out->in_use == 0) {
        reset_endpoint_socket_hooks (endpoint);
        return;
    }

#ifdef BASENET_SSL
    if (endpoint->ssl) {
        /* Data on a TLS connection must go through the SSL session. */
        amt= SSL_write (endpoint->ssl, out->buffer, (int) out->in_use);
        if (amt > 0)
            eat_buffer (out, amt);
        else
            handle_ssl_error (endpoint, amt, "SSL_write");
        reset_endpoint_socket_hooks (endpoint);
        return;
    }
#endif

    amt= send (endpoint->fd, (const char*) out->buffer, out->in_use, BASENET_SEND_FLAGS);
    if (amt >= 0) {
        eat_buffer (out, amt);
    } else {
        int err= SOCKET_ERRNO ();
        if (err != EAGAIN && err != EINTR && err != EWOULDBLOCK) {
            fprintf (stderr, "send (%d, %p, %d): %s\n", endpoint->fd, (void*) out->buffer, (int) out->in_use, SOCKET_STRERROR ());
            on_endpoint_error (endpoint);
        }
    }
    reset_endpoint_socket_hooks (endpoint);
}

/* (Re)arm the backend hooks for the endpoint's current state: read while the
 * incoming buffer has room on a bound/connected/accepted socket, write while
 * output is pending on a connected/accepted socket.
 * NOTE(review): under BASENET_SSL the ssl_wants_to_read/write hints are not
 * used here. An SSL_write waiting for handshake input keeps the write hook
 * armed and is simply retried whenever the socket is writable. */
static void reset_endpoint_socket_hooks (endpoint_t *endpoint)
{
    if (endpoint->dead || endpoint->fd < 0)
        return;

    if (endpoint->incoming.in_use < endpoint->incoming.limit
        && (endpoint->is_bound || endpoint->is_connected || endpoint->is_accepted))
        backend_set_socket (endpoint->fd, BACKEND_READ, on_endpoint_read, endpoint);
    else
        backend_set_socket (endpoint->fd, BACKEND_READ, NULL, NULL);

    if (endpoint->outgoing.in_use && (endpoint->is_connected || endpoint->is_accepted))
        backend_set_socket (endpoint->fd, BACKEND_WRITE, on_endpoint_write, endpoint);
    else
        backend_set_socket (endpoint->fd, BACKEND_WRITE, NULL, NULL);
}

/* Queue our handshake (magic + node id) as the first outgoing bytes. */
static void queue_handshake (endpoint_t *endpoint)
{
    uint32_t high= htonl ((uint32_t) (endpoint->basenet->nodeid >> 32));
    uint32_t low= htonl ((uint32_t) (endpoint->basenet->nodeid & 0xFFffFFff));

    assert (endpoint->outgoing.in_use == 0);
    write_to_buffer (&endpoint->outgoing, BASENET_HANDSHAKE_MAGIC, BASENET_MAGIC_LEN);
    write_to_buffer (&endpoint->outgoing, &high, sizeof (high));
    write_to_buffer (&endpoint->outgoing, &low, sizeof (low));
}

/* Release everything owned by a dead endpoint and unlink it from its basenet. */
static void destroy_endpoint (endpoint_t *endpoint)
{
    endpoint_t **link;

#ifdef BASENET_SSL
    free_ssl_context (endpoint);
#endif
    if (endpoint->fd >= 0) {
        backend_set_socket (endpoint->fd, BACKEND_READ | BACKEND_WRITE, NULL, NULL);
        SOCKET_CLOSE (endpoint->fd);
        fprintf (stderr, "CLOSE %d\n", endpoint->fd);
    }
    xfree (endpoint->incoming.buffer);
    xfree (endpoint->outgoing.buffer);
    for (link= &endpoint->basenet->first_endpoint; *link; link= &(*link)->next) {
        if (*link == endpoint) {
            *link= endpoint->next;
            break;
        }
    }
    xfree (endpoint);
}

/* Create a fresh non-blocking TCP socket for the endpoint (plus an SSL client
 * session for "ssl:" connections). Leaves fd at -1 on failure. */
static void open_endpoint_socket (endpoint_t *endpoint)
{
    endpoint->fd= socket (AF_INET, SOCK_STREAM, 0);
    if (endpoint->fd < 0) {
        SOCKET_PERROR ("socket (AF_INET, SOCK_STREAM)");
        endpoint->fd= -1;
        return;
    }
    set_blocking (endpoint->fd, 0);
#ifdef BASENET_SSL
    if (endpoint->is_client && endpoint->ssl_socket)
        create_ssl_context (endpoint, 1, endpoint->fd);
#endif
}

/* Bind and listen on a server endpoint's socket; retried each period on failure. */
static void try_bind (endpoint_t *endpoint)
{
    int true_val= 1;

    setsockopt (endpoint->fd, SOL_SOCKET, SO_REUSEADDR, (char*) &true_val, sizeof (true_val));
    if (bind (endpoint->fd, (struct sockaddr*) &endpoint->sin, sizeof (endpoint->sin))) {
        fprintf (stderr, "bind (%d, %s:%d): %s\n", endpoint->fd,
                 inet_ntoa (endpoint->sin.sin_addr), ntohs (endpoint->sin.sin_port), SOCKET_STRERROR ());
        return;
    }
    if (listen (endpoint->fd, BASENET_LISTEN_BACKLOG) < 0)
        fprintf (stderr, "listen (%d): %s\n", endpoint->fd, SOCKET_STRERROR ());
    /* Marked bound even if listen failed: binding the same socket again
     * would only fail on every later tick. */
    endpoint->is_bound= 1;
}

/* Drive a non-blocking connect; once it completes, queue our handshake. */
static void try_connect (endpoint_t *endpoint)
{
    fprintf (stderr, "conn\n");
    if (connect (endpoint->fd, (struct sockaddr*) &endpoint->sin, sizeof (endpoint->sin)) < 0) {
        int err= SOCKET_ERRNO ();
        if (err != EISCONN) {
            switch (err) {
            case EAGAIN:
            case EINPROGRESS:
            case EALREADY:
#ifdef _WIN32
            case EWOULDBLOCK:
#endif
                break;   /* still in progress: checked again next period */
            default:
                fprintf (stderr, "connect (%d, %s:%d): %s %d\n", endpoint->fd,
                         inet_ntoa (endpoint->sin.sin_addr), ntohs (endpoint->sin.sin_port),
                         SOCKET_STRERROR (), err);
                on_endpoint_error (endpoint);
                break;
            }
            return;
        }
    }
    endpoint->is_connected= 1;
    queue_handshake (endpoint);
}

/* Per-endpoint timer, re-armed every BASENET_PERIODIC_MS: frees dead
 * endpoints, (re)creates sockets, binds servers, connects clients. */
static void endpoint_periodic (void *_endpoint)
{
    endpoint_t *endpoint= (endpoint_t*) _endpoint;

    if (endpoint->dead) {
        destroy_endpoint (endpoint);   /* no re-arm: the endpoint is gone */
        return;
    }
    if (endpoint->fd < 0)
        open_endpoint_socket (endpoint);
    if (endpoint->fd >= 0) {
        if (endpoint->is_server && !endpoint->is_bound)
            try_bind (endpoint);
        if (endpoint->is_client && !endpoint->is_connected)
            try_connect (endpoint);
    }
    reset_endpoint_socket_hooks (endpoint);
    backend_add_timer (BASENET_PERIODIC_MS, endpoint_periodic, endpoint);
}

/* Allocate an endpoint, link it into basenet and schedule its first tick. */
static endpoint_t* create_endpoint (basenet_t *basenet)
{
    endpoint_t *endpoint= XALLOC_STRUCT (endpoint_t);

    endpoint->basenet= basenet;
    endpoint->fd= -1;
    endpoint->incoming.limit= BASENET_BUFFER_LIMIT;
    endpoint->outgoing.limit= BASENET_BUFFER_LIMIT;
    endpoint->next= basenet->first_endpoint;
    basenet->first_endpoint= endpoint;
    backend_add_timer (0, endpoint_periodic, endpoint);
    return endpoint;
}

/* Mark an endpoint dead; endpoint_periodic frees it on its next tick. */
static void delete_endpoint (endpoint_t *endpoint)
{
    if (endpoint->fd >= 0)
        backend_set_socket (endpoint->fd, BACKEND_WRITE | BACKEND_READ, NULL, NULL);
    endpoint->dead= 1;
}

/* Connection failure: accepted endpoints are deleted; client endpoints are
 * closed and reset so endpoint_periodic reconnects them. Pending data is
 * dropped. */
static void on_endpoint_error (endpoint_t *endpoint)
{
    if (endpoint->is_accepted) {
        delete_endpoint (endpoint);
    } else if (endpoint->is_client) {
        if (endpoint->fd >= 0) {
            backend_set_socket (endpoint->fd, BACKEND_WRITE | BACKEND_READ, NULL, NULL);
            fprintf (stderr, "CLOSE %d\n", endpoint->fd);
            SOCKET_CLOSE (endpoint->fd);
        }
#ifdef BASENET_SSL
        /* The SSL session is tied to the closed fd; a new one is created
         * together with the next socket. */
        free_ssl_context (endpoint);
#endif
        endpoint->fd= -1;
        endpoint->is_connected= 0;
        endpoint->handshake_done= 0;
        endpoint->outgoing.in_use= 0;
        endpoint->incoming.in_use= 0;
    } else {
        assert (0 && "shouldnt happen");
    }
}

/* Listen on `addr`; binding happens asynchronously. Returns 0, or -1 if the
 * address does not parse. */
int basenet_bind (basenet_t *basenet, const char *addr)
{
    endpoint_t *endpoint= create_endpoint (basenet);

    endpoint->is_server= 1;
    endpoint->fd= -1;
    if (parse_address (endpoint, addr) < 0) {
        delete_endpoint (endpoint);
        return -1;
    }
    return 0;
}

/* Maintain a connection to `addr`, reconnecting after errors. Returns 0, or
 * -1 if the address does not parse. */
int basenet_connect (basenet_t *basenet, const char *addr)
{
    endpoint_t *endpoint= create_endpoint (basenet);

    endpoint->is_client= 1;
    endpoint->fd= -1;
    if (parse_address (endpoint, addr) < 0) {
        delete_endpoint (endpoint);
        return -1;
    }
    endpoint->peer_address= endpoint->sin;
    return 0;
}

#ifdef BASENET_SSL
/* Classify the result `ret` of an SSL_read/SSL_write call named `call`.
 * Returns 0 on success or a WANT_READ/WANT_WRITE retry (recorded in
 * ssl_wants_to_*), and -1 on any other failure, after logging it and passing
 * the endpoint to on_endpoint_error. */
static int handle_ssl_error (endpoint_t *endpoint, int ret, const char *call)
{
    const char *error_text= "unknown";
    char buf [256];   /* OpenSSL error strings need up to 256 bytes */
    unsigned long ssl_err;
    int code;

    if (ret > 0)
        return 0;
    endpoint->ssl_wants_to_read= 0;
    endpoint->ssl_wants_to_write= 0;

    code= SSL_get_error (endpoint->ssl, ret);
    switch (code) {
    case SSL_ERROR_WANT_READ:
        endpoint->ssl_wants_to_read= 1;
        return 0;
    case SSL_ERROR_WANT_WRITE:
        endpoint->ssl_wants_to_write= 1;
        return 0;
    case SSL_ERROR_NONE:             error_text= "none"; break;
    case SSL_ERROR_ZERO_RETURN:      error_text= "connection closed"; break;
    case SSL_ERROR_WANT_CONNECT:     error_text= "want connect"; break;
    case SSL_ERROR_WANT_ACCEPT:      error_text= "want accept"; break;
    case SSL_ERROR_WANT_X509_LOOKUP: error_text= "x509 lookup"; break;
    case SSL_ERROR_SYSCALL:          error_text= "syscall"; break;
    case SSL_ERROR_SSL:              error_text= "ssl"; break;
    default:                         break;
    }

    ssl_err= ERR_get_error ();
    ERR_error_string_n (ssl_err, buf, sizeof (buf));
    fprintf (stderr, "%s (%d): %s: %s: %d\n", call, endpoint->fd, error_text, buf, code);
    on_endpoint_error (endpoint);
    return -1;
}

/* Create the SSL context and session for `fd` if the endpoint has none yet:
 * client mode when `as_client`, else server mode using basenet->pemfile. */
static void create_ssl_context (endpoint_t *endpoint, int as_client, int fd)
{
    basenet_t *basenet= endpoint->basenet;
    int err;

    if (endpoint->ssl != NULL)
        return;
    assert (fd != -1);

    endpoint->ssl_ctx= SSL_CTX_new (as_client ? SSLv23_client_method () : SSLv23_server_method ());
    assert (endpoint->ssl_ctx);
    /* on_endpoint_write retries SSL_write from a buffer that may have been
     * reallocated or grown, and consumes whatever was written. */
    SSL_CTX_set_mode (endpoint->ssl_ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);

    if (!as_client) {
        if (!basenet->pemfile) {
            fprintf (stderr, "no ssl certificate set\n");
        } else {
            if (!SSL_CTX_use_RSAPrivateKey_file (endpoint->ssl_ctx, basenet->pemfile, SSL_FILETYPE_PEM))
                fprintf (stderr, "could not load RSA private key from %s\n", basenet->pemfile);
            if (!SSL_CTX_use_certificate_file (endpoint->ssl_ctx, basenet->pemfile, SSL_FILETYPE_PEM))
                fprintf (stderr, "could not load certificate from %s\n", basenet->pemfile);
        }
    }

    endpoint->ssl= SSL_new (endpoint->ssl_ctx);
    assert (endpoint->ssl);
    err= SSL_set_fd (endpoint->ssl, fd);
    assert (err == 1);
    (void) err;
    if (as_client)
        SSL_set_connect_state (endpoint->ssl);
    else
        SSL_set_accept_state (endpoint->ssl);
    /* NOTE(review): peer certificates are not verified. */
    SSL_set_verify (endpoint->ssl, SSL_VERIFY_NONE, NULL);
    endpoint->ssl_wants_to_read= 1;
    endpoint->ssl_wants_to_write= 1;
}

/* Free the endpoint's SSL session and context, if any. */
static void free_ssl_context (endpoint_t *endpoint)
{
    if (endpoint->ssl) {
        SSL_free (endpoint->ssl);
        endpoint->ssl= NULL;
    }
    if (endpoint->ssl_ctx) {
        SSL_CTX_free (endpoint->ssl_ctx);
        endpoint->ssl_ctx= NULL;
    }
}
#endif

/* Wrap an already-connected socket `fd` in an accepted endpoint and queue our
 * handshake; `is_ssl` selects TLS (server side) in BASENET_SSL builds.
 * Always returns 0. */
int basenet_endpoint_from_fd (basenet_t *basenet, int fd, int is_ssl)
{
    endpoint_t *endpoint= create_endpoint (basenet);
    socklen_t sinlen= sizeof (endpoint->peer_address);

    if (getpeername (fd, (struct sockaddr*) &endpoint->peer_address, &sinlen) < 0)
        fprintf (stderr, "getpeername (%d): %s\n", fd, SOCKET_STRERROR ());
    endpoint->fd= fd;
    endpoint->is_accepted= 1;
    queue_handshake (endpoint);
#ifdef BASENET_SSL
    if (is_ssl) {
        endpoint->ssl_socket= 1;
        create_ssl_context (endpoint, 0, fd);
    }
#else
    (void) is_ssl;
#endif
    return 0;
}

/* Ensure the buffer's allocation holds at least `size_at_least` bytes. */
static void resize_buffer (basenet_buffer_t *buffer, size_t size_at_least)
{
    if (buffer->len < size_at_least) {
        buffer->len+= size_at_least;
        buffer->buffer= xrealloc (buffer->buffer, buffer->len);
    }
}

/* Append `len` bytes from `data`, growing the allocation as needed. */
static void write_to_buffer (basenet_buffer_t *buffer, const void *data, size_t len)
{
    resize_buffer (buffer, buffer->in_use+ len);
    memcpy (buffer->buffer+ buffer->in_use, data, len);
    buffer->in_use+= len;
}

/* Drop the first `len` bytes, shifting the remainder to the front. */
static void eat_buffer (basenet_buffer_t *buffer, size_t len)
{
    assert (len <= (size_t) buffer->in_use);
    if (buffer->in_use == len) {
        buffer->in_use= 0;
    } else {
        memmove (buffer->buffer, buffer->buffer+ len, buffer->in_use- len);
        buffer->in_use-= len;
    }
}

/* Queue `packet` on the first live, handshaken endpoint whose peer id is
 * packet->peer and whose output buffer has room. Returns 1 if queued, -1 if
 * no such endpoint exists (the packet is not kept). */
int basenet_send (basenet_t *basenet, const packet_t *packet)
{
    endpoint_t *endpoint;
    int need_space= BASENET_HEADER_LEN+ packet->plen;

    for (endpoint= basenet->first_endpoint; endpoint; endpoint= endpoint->next) {
        uint32_t header [4];

        if (endpoint->dead || !endpoint->handshake_done || endpoint->peer != packet->peer)
            continue;
        if (endpoint->outgoing.in_use+ need_space > endpoint->outgoing.limit)
            continue;

        header [0]= htonl ((uint32_t) need_space);
        header [1]= htonl (packet->port);
        header [2]= htonl (packet->xid);
        header [3]= htonl (packet->pid);
        write_to_buffer (&endpoint->outgoing, header, sizeof (header));
        write_to_buffer (&endpoint->outgoing, packet->buffer, packet->plen);
        reset_endpoint_socket_hooks (endpoint);
        return 1;
    }
    return -1;
}

/* Pop one complete packet from the endpoint's incoming buffer, or NULL.
 * A length field that cannot describe a valid packet (shorter than the
 * header, at/above BASENET_MAX_PACKET_LEN, or larger than the buffer could
 * ever hold) is treated as a protocol error and the connection is dropped. */
static packet_t* try_recv (endpoint_t *endpoint)
{
    const unsigned char *p;
    uint32_t header [4];
    uint32_t decoded_len;
    packet_t *packet;

    if (endpoint->dead || !endpoint->handshake_done)
        return NULL;
    if (endpoint->incoming.in_use < BASENET_HEADER_LEN)
        return NULL;

    p= (const unsigned char*) endpoint->incoming.buffer;
    memcpy (header, p, sizeof (header));
    decoded_len= ntohl (header [0]);
    if (decoded_len < BASENET_HEADER_LEN
        || decoded_len >= BASENET_MAX_PACKET_LEN
        || decoded_len > (size_t) endpoint->incoming.limit) {
        fprintf (stderr, "bad packet length %u: disconnecting\n", (unsigned) decoded_len);
        on_endpoint_error (endpoint);
        return NULL;
    }
    if (endpoint->incoming.in_use < decoded_len)
        return NULL;   /* wait for the rest of the packet */

    packet= XALLOC_STRUCT (packet_t);
    packet->next= NULL;
    packet->source_address= endpoint->peer_address;
    packet->peer= endpoint->peer;
    packet->port= ntohl (header [1]);
    packet->xid= ntohl (header [2]);
    packet->pid= ntohl (header [3]);
    packet->plen= decoded_len- BASENET_HEADER_LEN;
    packet->buffer= xalloc (packet->plen);
    memcpy (packet->buffer, p+ BASENET_HEADER_LEN, packet->plen);
    eat_buffer (&endpoint->incoming, decoded_len);
    reset_endpoint_socket_hooks (endpoint);
    return packet;
}

/* Return the next buffered packet from any endpoint (caller frees it with
 * basenet_delete_packet), or NULL if none is complete. */
packet_t* basenet_recv (basenet_t *basenet)
{
    endpoint_t *endpoint;
    packet_t *packet;

    for (endpoint= basenet->first_endpoint; endpoint; endpoint= endpoint->next)
        if ( (packet= try_recv (endpoint)) )
            return packet;
    return NULL;
}

/* Run the backend event loop for up to `timeout` (backend units). */
void basenet_poll (basenet_t *basenet, unsigned int timeout)
{
    (void) basenet;
    backend_poll (timeout);
}

/* Free a packet and its payload. */
void basenet_delete_packet (packet_t *packet)
{
    xfree (packet->buffer);
    xfree (packet);
}

/* Deep-copy a packet (payload included); the copy is unlinked (next == NULL). */
packet_t* basenet_clone_packet (const packet_t *packet)
{
    packet_t *npacket= XALLOC_STRUCT (packet_t);

    *npacket= *packet;
    npacket->next= NULL;
    npacket->buffer= xalloc (packet->plen);
    memcpy (npacket->buffer, packet->buffer, packet->plen);
    return npacket;
}