/* Unidirectional 16-bit Pascal string message broadcast server.
   (c) 2009 Tom Schouten (Licence: GPLv2)

   - atomic transfer with minimal data inspection (Pascal strings)
   - optional blocking (wait for a minimum number of clients)

*/

/* Length prefix of every message, both on stdin and on the client sockets
   (the file header calls it 16-bit; unsigned short is assumed to be that). */
typedef unsigned short message_size_t;

#define _GNU_SOURCE
/* ISO C */
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* POSIX */
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <unistd.h>

// OPTIONS
int redundancy = 0;                    /* -r: minimum nb of clients before stdin is read */
const char *path = "/tmp/broadcast";   /* -s: Unix socket path (string literal or argv; never written through) */
int verbose = 0;                       /* -v: verbosity level, incremented per -v */


// INTERNAL

int listen_fd = -1;            /* listening socket; -1 until start() creates it */
struct sockaddr_un address;    /* bound socket address; cleanup() unlinks its path */


fd_set conn_set;               /* descriptors of connected clients */
int conn_max_fd = 0;           /* highest fd ever added to conn_set (bound for scans) */
int conn_nb = 0;               /* number of connected clients */


/* Last signal recorded by handler(), 0 if none.  It is written from signal
   context, so ISO C requires it to be volatile sig_atomic_t. */
volatile sig_atomic_t gotsig = 0;

/* Drop client `fd`: remove it from conn_set, close it and decrement conn_nb.
   conn_max_fd is left unchanged; it only bounds the broadcast scan.
   The parameter is now declared explicitly; the old K&R form relied on
   implicit int, which is not valid C99 or later. */
void do_disconnect(int fd) {
    FD_CLR(fd, &conn_set);
    close(fd);
    conn_nb--;
    if (verbose > 0) { fprintf(stderr, "%3d: Disconnect\n", fd); }
}

/* Write all `bytes` bytes of `buf` to client socket `fd`, looping over
   partial writes.  On an unrecoverable error the client is disconnected
   via do_disconnect(); this function never exits the process.
   EPIPE (client closed its end) is not reported. */
void safe_write(int fd, void *buf, int bytes) {
    const char *cursor = buf;   /* byte pointer: arithmetic on void * is non-standard */

    while (bytes > 0) {
        ssize_t sent = write(fd, cursor, (size_t)bytes);
        if (sent < 0) {
            if (EINTR == errno) {
                continue;       /* interrupted before writing anything: retry */
            }
            if (EPIPE == errno) {
                /* This write raised the SIGPIPE that handler() recorded; it
                   is expected, so forget it -- but only it, so that a
                   pending SIGINT/SIGTERM is not discarded as well. */
                if (SIGPIPE == gotsig) {
                    gotsig = 0;
                }
            }
            else {
                fprintf(stderr, "fd %d: can't write %d bytes: %s\n",
                        fd, bytes, strerror(errno));
            }
            do_disconnect(fd);
            return;
        }
        cursor += sent;
        bytes -= (int)sent;
    }
}

/* Read exactly `bytes` bytes from `fd` into `buf`, looping over short reads.
   Exits the process with status 1 on a read error and with status 0 when
   the input reaches end-of-file.  Zero bytes is a no-op. */
void safe_read(int fd, void *buf, int bytes) {
    char *cursor = buf;         /* byte pointer: arithmetic on void * is non-standard */

    while (bytes > 0) {
        ssize_t received = read(fd, cursor, (size_t)bytes);
        if (received < 0) {
            if (EINTR == errno) {
                /* Interrupted by a signal: finish the message first;
                   mainloop() acts on gotsig afterwards. */
                continue;
            }
            fprintf(stderr, "can't read %d bytes: %s\n", bytes, strerror(errno));
            exit(1);
        }

        if (received == 0) {
            if (verbose > 0) {
                fprintf(stderr, "Input closed.\n");
            }
            exit(0);
        }

        cursor += received;
        bytes -= (int)received;
    }
}

/* Accept one pending connection on listen_fd and register it as a client
   in conn_set / conn_max_fd / conn_nb.  Failures are reported on stderr;
   the server keeps running with its existing clients. */
void do_connect(void) {
    struct sockaddr_un peer;    /* named apart from the global `address` */
    socklen_t addrlen = sizeof(peer);
    int fd = accept(listen_fd, (struct sockaddr *)&peer, &addrlen);

    if (-1 == fd) {
        /* Registering fd == -1 would make FD_SET write out of bounds and
           miscount conn_nb. */
        perror("accept()");
        return;
    }

    /* fd_set only holds descriptors below FD_SETSIZE; select() cannot
       serve a larger one, so refuse it instead of corrupting conn_set. */
    if (fd >= FD_SETSIZE) {
        fprintf(stderr, "%3d: Too many connections, rejected\n", fd);
        close(fd);
        return;
    }

    // add client
    FD_SET(fd, &conn_set);
    if (fd > conn_max_fd) { conn_max_fd = fd; }
    if (verbose > 0) { fprintf(stderr, "%3d: Connect\n", fd); }
    conn_nb ++;
}

/* Send the `buffer_size`-byte frame at `buf` to every descriptor currently
   in conn_set.  A client whose write fails is dropped inside safe_write(). */
void broadcast_buffer(void *buf, size_t buffer_size) {
    int client = 0;

    while (client <= conn_max_fd) {
        if (FD_ISSET(client, &conn_set)) {
            safe_write(client, buf, buffer_size);
        }
        client++;
    }
}

/* Note: for maximum throughput we should really read() and write() as
   much as possible at once.  Currently we choose ease of
   implementation and read() the size then the message. */

void do_broadcast(void) {

    // read input message size word
    message_size_t size;
    safe_read(0, &size, sizeof(size));

    if (verbose > 1) { fprintf(stderr, "%5d\n", size); }

    size_t payload_size = size;
    size_t buffer_size  = payload_size + sizeof(size);
    message_size_t buf[buffer_size / sizeof(size)];
    buf[0] = size;

    // read payload if any
    safe_read(0, &buf[1], payload_size);

    // pass it on
    broadcast_buffer(buf, buffer_size);
}

char *me = NULL;   /* program name (argv[0]), set in main() */

/* atexit() hook registered in main(): optionally announce the shutdown,
   then unlink the socket path held in the global `address`. */
void cleanup(void) {
    const int announce = (verbose > 0);

    if (announce) {
        fprintf(stderr, "%s exiting.\n", me);
    }
    remove(address.sun_path);
}

// Wait for connections + input messages, forever.  Exits the process when a
// signal has been recorded in gotsig (safe_read() also exits on input EOF);
// returns -1 only when select() fails.
int mainloop(void) {
    struct timeval tv;
    fd_set in;

    for (;;) {
        if (gotsig) {
            int sig = gotsig;
            fprintf(stderr, "\nGot signal %d (%s).\n", sig, strsignal(sig));
            exit(0);
        }

        FD_ZERO(&in);

        /* Only read from input when we have enough clients. */
        if (conn_nb >= redundancy) {
            FD_SET(0, &in);
        }
        FD_SET(listen_fd, &in);

        /* select() may modify tv, so reset it on every pass.  The 1 s
           timeout bounds how long a signal that arrives just before
           select() goes unnoticed. */
        tv.tv_sec = 1;
        tv.tv_usec = 0;
        if (-1 == select(1 + listen_fd, &in, NULL, NULL, &tv)) {
            if (EINTR == errno) {
                continue;   /* interrupted by a signal: re-check gotsig */
            }
            if (EBADF == errno) {
                /* stderr, like every other diagnostic of this program */
                fprintf(stderr, "Server socket closed.\n");
                return -1;
            }
            perror("select()");
            return -1;
        }

        if (FD_ISSET(0, &in)) { do_broadcast(); }
        if (FD_ISSET(listen_fd, &in)) { do_connect(); }
    }
}

/* Create the listening Unix-domain stream socket at `path` and store it in
   listen_fd / address.  Any existing file at `path` is removed first.
   Returns 0 on success or a negative errno value on failure; main() returns
   that value as the process exit status. */
int start(void) {
    size_t path_len = strlen(path);
    socklen_t addrlen;
    int err;

    /* sun_path is a fixed-size array; a longer -s argument used to be
       strcpy'd straight into it and overflow the global `address`. */
    if (path_len >= sizeof(address.sun_path)) {
        fprintf(stderr, "socket path too long: %s\n", path);
        return -ENAMETOOLONG;
    }

    /* create socket */
    if ((listen_fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
        err = errno;            /* save before perror() can change it */
        perror("socket()");
        return -err;
    }

    memset(&address, 0, sizeof(address));
    address.sun_family = AF_UNIX;
    memcpy(address.sun_path, path, path_len + 1);
    /* Everything before sun_path, plus the path and its NUL.  offsetof() is
       the portable way to size that header; it need not equal
       sizeof(sun_family). */
    addrlen = (socklen_t)(offsetof(struct sockaddr_un, sun_path) + path_len + 1);

    /* NOTE(review): this removes whatever file exists at `path`, not only
       a stale socket -- confirm that is intended. */
    if (0 == remove(address.sun_path)) {
        fprintf(stderr, "removing stale socket %s\n", address.sun_path);
    }

    /* bind + listen */
    if (-1 == bind(listen_fd, (struct sockaddr *) &address, addrlen)) {
        err = errno;
        perror("bind()");
        goto fail;
    }
    if (-1 == listen(listen_fd, 10)) {
        err = errno;
        perror("listen()");
        goto fail;
    }

    if (verbose > 0) { fprintf(stderr, "Listening on %s\n", path); }
    return 0;

fail:
    close(listen_fd);
    listen_fd = -1;
    return -err;
}

/* Signal handler: it only records the signal number.  mainloop() and
   safe_write() act on gotsig outside of signal context. */
void handler(int signum) {
    gotsig = signum;
}

/* Print the command-line help to stderr.  The socket is selected with -s;
   positional arguments are not read. */
void usage(void) {
    fprintf(stderr, "Pascal string message broadcast to Unix socket.\n");
    fprintf(stderr, "usage:   %s [-v] [-r <n>] [-s <socket>]\n", me);
    fprintf(stderr, "options: \n");
    fprintf(stderr, "  -r <n>       Minimum nb of clients (0).\n");
    fprintf(stderr, "  -s <socket>  Unix socket path (/tmp/broadcast).\n");
    fprintf(stderr, "  -v           Increase verbosity (-vv also logs message sizes).\n");
}

/* Print the help and terminate with exit status 1. */
void usage_exit(void){
    usage();
    exit(1);
}

int main (int argc, char **argv) {
    int rv;
    me   = argv[0];

    int c;

    while ((c = getopt(argc, argv, "r:s:v")) != -1) {
        switch (c) {
        case 'v': verbose++; break;
        case 'r': redundancy = strtol(optarg, NULL, 10); break;
        case 's': path = optarg; break;
        case '?':
            if ((optopt == 'r') ||
                (optopt == 's')) {
                fprintf (stderr, "Option -%c requires an argument.\n", 
                         optopt);
            }
            else if (isprint (optopt)) {
                fprintf (stderr, "Unknown option `-%c'.\n", optopt);
            }
            else {
                fprintf (stderr,
                         "Unknown option character `\\x%x'.\n",
                         optopt);
            }
        default:
            usage_exit();
        }
    }
    FD_ZERO(&conn_set);
    atexit(cleanup);
    signal(SIGPIPE, handler);
    signal(SIGINT,  handler);
    signal(SIGTERM, handler);
    signal(SIGHUP,  handler);
    if ((rv = start())) return rv;
    return mainloop();
}

