/*
 * TESLA: A Transparent, Extensible Session-Layer Architecture
 *
 * Jon Salz <jsalz@mit.edu>
 * Alex C. Snoeren <snoeren@lcs.mit.edu>
 *
 * Copyright (c) 2001-2 Massachusetts Institute of Technology.
 *
 * This software is being provided by the copyright holders under the GNU
 * General Public License, either version 2 or, at your discretion, any later
 * version. For more information, see the `COPYING' file in the source
 * distribution.
 *
 * $Id: libc.c,v 1.19 2002/10/08 19:03:33 snoeren Exp $
 *
 * libc overrides for the libtesla.so wrapper library.  #included by tesla.c.
 *
 */

/*
 * socket() override.
 *
 * When TESLA is disabled (tesla_enabled(1) is false) the call goes straight
 * to libc.  Otherwise a MSG_SOCKET request is sent to the master:
 *   - MSG_NAK:        the master declined; fall back to the libc socket().
 *   - MSG_FH_FOLLOWS: the master handed us a descriptor; record it in
 *                     ts_fds[] as wrapped and return it.
 *   - anything else:  fatal protocol error.
 *
 * Returns the new descriptor, or whatever the libc socket() returns on the
 * pass-through paths.
 */
int socket(int domain, int type, int protocol)
{
    master_msg_t msg;
    unsigned int id;

    if (!tesla_enabled(1))
	return TS_CALL_LIBC(socket, domain, type, protocol);

    ts_debug_3("socket(domain=%d, type=%d, protocol=%d)", domain, type, protocol);

    /* Ask the master to give us a socket. */
    msg.type = MSG_SOCKET;
    msg.body.socket.domain = domain;
    msg.body.socket.type = type;
    msg.body.socket.protocol = protocol;
    id = send_master(&msg);

    if (recv_master(&msg) != id)
	ts_fatal("Mis-sequenced response");

    if (msg.type == MSG_NAK) {
	ts_debug_3(" - declined by master");
	return TS_CALL_LIBC(socket, domain, type, protocol);
    } else if (msg.type == MSG_FH_FOLLOWS) {
	int fd = msg.body.fh_pass.fh;

	assert(fd >= 0 && fd < FD_SETSIZE);
	ts_fds[fd].wrapped = 1;
	ts_fds[fd].domain = domain;
	ts_fds[fd].type = type;
	ts_fds[fd].conn_id = msg.conn_id;
	/* The slot may have belonged to an earlier descriptor with the same
	 * number; do not let its connect state leak into select() or
	 * getsockopt(SO_ERROR) for this brand-new socket. */
	ts_fds[fd].connecting = 0;
	ts_fds[fd].connect_status = 0;
	ts_debug_3(" - got FD %d from master; ID %d", fd, msg.conn_id);
	return fd;
    } else {
	ts_fatal("Expected NAK or FH_FOLLOWS but got %d", msg.type);
    }
}

// Non-zero once connect() has left at least one wrapped socket with its
// connect result still pending (set in connect() when the result byte is not
// yet readable).  select() only does its special handling when this is set.
// Nothing in the code visible here ever resets it to 0.
static int async_connect = 0;

/*
 * connect() override.
 *
 * Unwrapped (or out-of-range) descriptors are passed straight to libc.
 * For a wrapped descriptor a MSG_CONNECT carrying the address is sent to
 * the master.  If the master does not answer MSG_ACK the call fails with
 * errno taken from the reply.  Datagram sockets are done at that point.
 * For other socket types the outcome arrives as a single byte on the
 * descriptor itself: 'A' means the connect failed (reported as
 * ECONNREFUSED), 'B' means it succeeded.  If that byte is not available
 * yet (read fails with EAGAIN) the descriptor is marked as connecting and
 * the call fails with EINPROGRESS; select() picks up the result later.
 *
 * The outcome is also remembered in ts_fds[fd].connect_status.
 */
int connect(int fd, const struct sockaddr *addr, socklen_t addrlen)
{
    master_msg_t msg;
    unsigned int id;
    char buf;
    int bytes;

    /* Range-check fd before touching ts_fds[]; anything we cannot track
     * is libc's problem (it will report EBADF etc. as appropriate). */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(connect, fd, addr, addrlen);

    ts_fds[fd].connect_status = 0;

    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_CONNECT;
    assert(addrlen <= sizeof msg.body.address.addr);
    msg.body.address.addrlen = addrlen;
    memcpy(msg.body.address.addr, addr, addrlen);
    id = send_master(&msg);

    /* Wait for connect by reading on socket. */
    ts_debug_3("Connect - waiting for response");

    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");

    if (msg.type != MSG_ACK) {
	ts_debug_3(" - connect itself failed!");
	ts_fds[fd].connect_status = errno = msg.merrno;
	return -1;
    }

    if (ts_fds[fd].type == SOCK_DGRAM)
	return 0;

    bytes = TS_CALL_LIBC(read, fd, &buf, 1);

    if (bytes < 0 && errno == EAGAIN) {
	/* Result not available yet: let select() finish the job. */
	async_connect = 1;
	ts_fds[fd].connecting = 1;
	errno = EINPROGRESS;
	return -1;
    }

    if (bytes != 1)
	ts_fatal("Didn't read one byte after connect");

    if (buf == 'A') {
	ts_debug_2("Connect failed");
	ts_fds[fd].connect_status = errno = ECONNREFUSED;
	return -1;
    } else if (buf == 'B') {
	ts_debug_2("Connect succeeded");
	ts_fds[fd].connect_status = 0;
	return 0;
    } else {
	ts_fatal("Expected A or B for connect; got %u", (unsigned int)buf);
    }
}

/*
 * listen() override.
 *
 * Unwrapped (or out-of-range) descriptors go to libc.  For a wrapped
 * descriptor a MSG_LISTEN is sent to the master and 0 is returned
 * immediately: no reply is awaited and `backlog' is not forwarded.
 */
int listen(int fd, int backlog)
{
    master_msg_t msg;

    /* Range-check fd before indexing ts_fds[]. */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(listen, fd, backlog);

    ts_debug_2("Listening!");

    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_LISTEN;
    send_master(&msg);

    return 0;
}

/*
 * bind() override.
 *
 * Unwrapped (or out-of-range) descriptors go to libc.  For a wrapped
 * descriptor the address is sent to the master in a MSG_BIND.
 * Returns 0 on MSG_ACK; on MSG_NAK returns -1 with errno taken from the
 * reply.  Any other reply is a fatal protocol error.
 */
int bind(int fd, const struct sockaddr *addr, socklen_t addrlen)
{
    master_msg_t msg;
    unsigned int id;

    /* Range-check fd before indexing ts_fds[]. */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(bind, fd, addr, addrlen);

    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_BIND;
    assert(addrlen <= sizeof msg.body.address.addr);
    msg.body.address.addrlen = addrlen;
    memcpy(msg.body.address.addr, addr, addrlen);
    id = send_master(&msg);

    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");

    if (msg.type == MSG_ACK) {
	ts_debug_2("Bind succeeded");
	return 0;
    } else if (msg.type == MSG_NAK) {
	ts_debug_2("Bind failed");
	errno = msg.merrno;
	return -1;
    }

    ts_fatal("Expected ACK or NAK but got %d", msg.type);
}

/*
 * accept() override.
 *
 * Unwrapped (or out-of-range) descriptors go to libc.  For a wrapped
 * listening descriptor:
 *   1. Read one byte from the descriptor; it must be 'C'.  If the read
 *      does not yield exactly one byte, return -1 (errno is whatever the
 *      read left behind).
 *   2. Send MSG_ACCEPT to the master.  MSG_NAK makes the call return -1.
 *   3. On MSG_ACK the reply carries the peer address; a second message
 *      (MSG_FH_FOLLOWS) carries the new descriptor, which is recorded in
 *      ts_fds[] as wrapped and returned.
 *
 * The peer address is truncated to the caller's buffer size, and
 * *addrlen receives the untruncated length, as for the regular accept().
 */
int accept(int fd, struct sockaddr *addr, socklen_t *addrlen)
{
    master_msg_t msg;
    unsigned int id;
    unsigned char buf;
    ssize_t bytes;	/* read() returns a signed count; -1 means error */
    int newfd;

    /* Range-check fd before indexing ts_fds[]. */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(accept, fd, addr, addrlen);

    ts_debug_2("In accept... reading from master");

    /* Read a byte from the FD. If it's blocking, it will wait for a byte (C) */
    bytes = TS_CALL_LIBC(read, fd, &buf, 1);

    ts_debug_2(" - got %d bytes", (int)bytes);

    if (bytes != 1)
	return -1;

    if (buf != 'C')
	ts_fatal("Expected C but got %u", (unsigned int)buf);

    ts_debug_2("Got a 'C' for accept");

    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_ACCEPT;
    id = send_master(&msg);

    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");

    if (msg.type != MSG_ACK) {
	if (msg.type != MSG_NAK)
	    ts_fatal("Mismatched message");
	return -1;
    }

    /* Hand the peer address back before `msg' is reused for the next
     * message.  Never write more than the caller said the buffer holds. */
    if (addr && addrlen) {
	socklen_t full = msg.body.address.addrlen;
	socklen_t room = *addrlen;

	memcpy(addr, msg.body.address.addr, full < room ? full : room);
	*addrlen = full;
    }

    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");
    if (msg.type != MSG_FH_FOLLOWS)
	ts_fatal("Expected filehandle");

    newfd = msg.body.fh_pass.fh;
    assert(newfd >= 0 && newfd < FD_SETSIZE);

    ts_fds[newfd].wrapped = 1;
    ts_fds[newfd].conn_id = msg.conn_id;
    /* The accepted socket has the same domain/type as the listener, and
     * starts with no connect in progress; do not inherit stale values
     * from a previous user of this descriptor number. */
    ts_fds[newfd].domain = ts_fds[fd].domain;
    ts_fds[newfd].type = ts_fds[fd].type;
    ts_fds[newfd].connecting = 0;
    ts_fds[newfd].connect_status = 0;

    ts_debug_2("Returning FD %d", newfd);

    return newfd;
}


/*
 * __close() override.
 *
 * Always closes the descriptor via the libc __close().  If the descriptor
 * was tracked as wrapped and the close succeeded, its ts_fds[] slot is
 * marked unwrapped.  Returns the libc result.
 */
int __close(int fd)
{
    int ret;

    /* Range-check fd before indexing ts_fds[] (close(-1) is common). */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(__close, fd);

    if((ret = TS_CALL_LIBC(__close, fd)) != -1) {
	ts_debug_2("__Closing fd %d", fd);
	ts_fds[fd].wrapped = 0;
	/* A pending async connect dies with the descriptor; otherwise
	 * select() would treat the next user of this number specially. */
	ts_fds[fd].connecting = 0;
    }
    return ret;
}



/*
 * close() override.
 *
 * Always closes the descriptor via the libc close().  If the descriptor
 * was tracked as wrapped and the close succeeded, its ts_fds[] slot is
 * marked unwrapped.  Returns the libc result.
 */
int close(int fd)
{
    int ret;

    /* Range-check fd before indexing ts_fds[] (close(-1) is common). */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(close, fd);

    if((ret = TS_CALL_LIBC(close, fd)) != -1) {
	ts_debug_2("Closing fd %d", fd);
	ts_fds[fd].wrapped = 0;
	/* A pending async connect dies with the descriptor; otherwise
	 * select() would treat the next user of this number specially. */
	ts_fds[fd].connecting = 0;
    }
    return ret;
}

/*
 * setsockopt() override; options on wrapped descriptors are relayed to the
 * master process.
 *
 * Two kinds of call are handled:
 *
 *  - level == SOL_TESLA_IOCTL: `optval' points to a ts_ioctl_s.  With a
 *    null target the request is served locally (restore/save state,
 *    hello, message handler, inessential flag, FD usage query).  With a
 *    non-null target the request is forwarded to the master using the
 *    optname/optval/optlen found inside the ts_ioctl_s, and any reply
 *    payload is copied to io->retval (at most *io->retlen bytes, with
 *    *io->retlen updated to the number of bytes copied).
 *
 *  - any other level: forwarded to the master if the descriptor is
 *    wrapped, otherwise passed to the libc setsockopt().
 *
 * Returns 0 on success (or a TS_FD_IS_* code for SO_TESLA_FD_USAGE), -1
 * with errno set on failure.
 */
int
setsockopt(int fd, int level, int optname, const void *optval,
	   socklen_t optlen)
{
    unsigned int id;
    master_msg_t msg;
    const ts_ioctl_s *io = 0;

    ts_debug_2("Beginning of setsockopt");

    if (level == SOL_TESLA_IOCTL) {
	ts_debug_2(" - it's an ioctl");

	io = optval;

	/* Validate before the first dereference of io. */
	if (io == 0 || optlen != sizeof(ts_ioctl_s)) {
	    ts_error("Invalid arguments to IOCTL in setsockopt");
	    errno = EINVAL;
	    return -1;
	}

	/* A null target is legal (handled internally below), so do not
	 * hand it to %s. */
	ts_debug_2("ioctl: target=%s, optname=%d, optlen=%d, retlen=%p",
		   io->target ? io->target : "(null)",
		   io->optname, io->optlen, io->retlen);

	if (io->target == 0) {
	    /* Handle internally */
	    if (io->optname == SO_TESLA_RESTORE_STATE) {
		return restore_state((char*)io->optval);
	    }

	    if (!tesla_enabled(1))
		return TS_CALL_LIBC(setsockopt, fd, level, optname, optval, optlen);

	    if (io->optname == SO_TESLA_HELLO)
		return 0;

	    if (io->optname == SO_TESLA_MESSAGES) {
		const ts_message_handler_s *h;

		h = io->optval;

		if (io->optlen != sizeof(ts_message_handler_s)) {
		    errno = EINVAL;
		    return -1;
		}
		return set_message_handler(h->handler, h->signal);
	    }

	    if (io->optname == SO_TESLA_INESSENTIAL) {
		if (io->optlen != sizeof(int)) {
		    errno = EINVAL;
		    return -1;
		}
		msg.type = MSG_INESSENTIAL;
		msg.body.inessential.val = *(int*)io->optval;
		send_master(&msg);
		return 0;
	    }

	    if (io->optname == SO_TESLA_SAVE_STATE) {
		return save_state((char*)io->optval);
	    }

	    if (io->optname == SO_TESLA_FD_USAGE) {
		if (fd < 0 || fd >= FD_SETSIZE) {
		    errno = EINVAL;
		    return -1;
		}

		if (fd == master_fd)
		    return TS_FD_IS_MASTER;
		if (fd == msg_fd)
		    return TS_FD_IS_MSG;
		if (ts_fds[fd].wrapped)
		    return TS_FD_IS_WRAPPED;

		return TS_FD_IS_UNWRAPPED;
	    }

	    errno = EINVAL;
	    return -1;
	}

	/* Range-check first, then look at the table entry. */
	if (fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped) {
	    ts_error("ioctl on unwrapped FD %d", fd);
	    errno = EINVAL;
	    return -1;
	}

	assert(strlen(io->target) < sizeof msg.body.sockopt.target);
	strcpy(msg.body.sockopt.target, io->target);
	optname = io->optname;
	optval = io->optval;
	optlen = io->optlen;
    }

    if (!tesla_enabled(1) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(setsockopt, fd, level, optname, optval, optlen);

    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_SETSOCKOPT;
    msg.body.sockopt.level = level;
    msg.body.sockopt.optname = optname;
    msg.body.sockopt.optlen = optlen;

    ts_debug_3("In setsockopt: %d %d", level, optname);

    /* The request goes out in three pieces: the message header, the fixed
     * part of the sockopt body (everything but the optval array), and
     * then the caller's option bytes straight from `optval'. */
    ts_debug_2("Sending header");
    id = send_master_header(&msg, optlen + sizeof msg.body.sockopt - sizeof msg.body.sockopt.optval);
    ts_debug_2("Sending option info");
    if (write_fully(master_fd, &msg.body.sockopt, sizeof msg.body.sockopt - sizeof msg.body.sockopt.optval) != (sizeof msg.body.sockopt - sizeof msg.body.sockopt.optval))
	ts_fatal("unable to send setsockopt data");
    ts_debug_2("Sending option data");
    if (write_fully(master_fd, optval, optlen) != optlen)
	ts_fatal("unable to send setsockopt data");
    ts_debug_2("Getting ack/nak");

    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");

    if (io && msg.type == MSG_FH_FOLLOWS) {
	/* The ioctl produced a file handle; return it as an int. */
	if (io->retlen && io->retval) {
	    if (*(io->retlen) > sizeof(int))
		*(io->retlen) = sizeof(int);
	    memcpy(io->retval, &msg.body.fh_pass.fh, *(io->retlen));
	}
	return 0;
    }

    errno = msg.merrno;
    if (msg.type != MSG_ACK)
	return -1;

    if (io && io->retlen && io->retval) {
	/* Copy back no more than the caller's buffer can take. */
	if (msg.body.sockopt.optlen < *(io->retlen))
	    *(io->retlen) = msg.body.sockopt.optlen;
	memcpy(io->retval, msg.body.sockopt.optval, *(io->retlen));
    }

    return 0;
}

int
getsockopt(int fd, int level, int optname, void *optval, socklen_t *optlen)
{
    unsigned int id;
    master_msg_t msg;

    if (!tesla_enabled(0) || !ts_fds[fd].wrapped || fd < 0 || fd >= FD_SETSIZE)
	return TS_CALL_LIBC(getsockopt, fd, level, optname, optval, optlen);
  
    msg.conn_id = ts_fds[fd].conn_id;
    msg.type = MSG_GETSOCKOPT;
    msg.body.sockopt.level = level;
    msg.body.sockopt.optname = optname;
    msg.body.sockopt.optlen = *optlen;

    ts_debug_2("In getsockopt");

    if (level == SOL_SOCKET && optname == SO_ERROR) {
	ts_debug_1("- returning connect_status");
	if (*optlen > sizeof ts_fds[fd].connect_status)
	    *optlen = sizeof ts_fds[fd].connect_status;
	memcpy(optval, &ts_fds[fd].connect_status, *optlen);
	return 0;
    }

    id = send_master(&msg);
    if (recv_master(&msg) != id)
	ts_fatal("Mismatched message");

    if (msg.type == MSG_ACK) {
	memcpy(optval, msg.body.sockopt.optval, *optlen = msg.body.sockopt.optlen);
	return 0;
    }
    
    *optlen = 0;
    return -1;
}

/*
 * Shared back end for getpeername() and getsockname().
 *
 * Sends a message of type `which' for the connection behind `fd' to the
 * master.  On MSG_ACK the returned address is copied into `addr',
 * truncated to the caller's *addrlen, *addrlen is set to the number of
 * bytes copied, and 0 is returned.  Any other reply yields -1.
 *
 * The caller must already have established that `fd' is wrapped.
 */
static int
getxname(int fd, struct sockaddr *addr, socklen_t *addrlen, int which)
{
    master_msg_t reply;
    unsigned int seq;

    reply.conn_id = ts_fds[fd].conn_id;
    reply.type = which;
    seq = send_master(&reply);

    if (recv_master(&reply) != seq)
	ts_fatal("Mismatched message");

    if (reply.type != MSG_ACK)
	return -1;

    /* Shrink the caller's length if the real address is not longer. */
    if (!(*addrlen < reply.body.address.addrlen))
	*addrlen = reply.body.address.addrlen;

    memcpy(addr, reply.body.address.addr, *addrlen);
    return 0;
}

/*
 * getpeername() override: wrapped descriptors are answered by the master
 * (MSG_GETPEERNAME via getxname()); everything else goes to libc.
 */
int
getpeername(int fd, struct sockaddr *addr, socklen_t *addrlen)
{
    /* Range-check fd before indexing ts_fds[]. */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(getpeername, fd, addr, addrlen);

    return getxname(fd, addr, addrlen, MSG_GETPEERNAME);
}

/*
 * getsockname() override: wrapped descriptors are answered by the master
 * (MSG_GETSOCKNAME via getxname()); everything else goes to libc.
 */
int
getsockname(int fd, struct sockaddr *addr, socklen_t *addrlen)
{
    /* Range-check fd before indexing ts_fds[]. */
    if (!tesla_enabled(0) || fd < 0 || fd >= FD_SETSIZE || !ts_fds[fd].wrapped)
	return TS_CALL_LIBC(getsockname, fd, addr, addrlen);

    return getxname(fd, addr, addrlen, MSG_GETSOCKNAME);
}

/*
 * dup2() override.
 *
 * Keeps the ts_fds[] bookkeeping in step with the descriptor table:
 *   - if `newfd' currently holds the master descriptor, the master
 *     descriptor is first moved out of the way with dup();
 *   - `newfd' is closed through our close() so that any wrapped state on
 *     it is dropped;
 *   - after a successful libc dup2(), the wrapped state of `fd' (if any)
 *     is copied to `newfd'.
 *
 * dup2(fd, fd) and negative descriptors are handed straight to libc:
 * the former must leave the descriptor open, the latter fails with EBADF
 * and must not close anything.
 *
 * NOTE(review): as before, `newfd' is closed even if `fd' later turns out
 * not to be an open descriptor.
 */
int
dup2(int fd, int newfd)
{
    int ret;
    int src_tracked;
    int dst_tracked;

    if (!tesla_enabled(0) || fd == newfd || fd < 0 || newfd < 0)
	return TS_CALL_LIBC(dup2, fd, newfd);

    /* Only descriptors below FD_SETSIZE have a ts_fds[] slot. */
    src_tracked = fd < FD_SETSIZE;
    dst_tracked = newfd < FD_SETSIZE;

    // Move the master FD if necessary
    if (newfd == master_fd) {
	int moved = TS_CALL_LIBC(dup, master_fd);

	/* Without a new home for the master FD we must not clobber it. */
	if (moved < 0)
	    return -1;
	master_fd = moved;
    }

    // Close whatever it is we were dup'ing on top of (necessary in
    // case the destination FD is Tesla-wrapped)
    if (dst_tracked)
	close(newfd);

    ret = TS_CALL_LIBC(dup2, fd, newfd);
    if (ret != -1 && src_tracked && dst_tracked && ts_fds[fd].wrapped) {
	ts_fds[newfd].wrapped = 1;
	ts_fds[newfd].domain = ts_fds[fd].domain;
	ts_fds[newfd].type = ts_fds[fd].type;
	ts_fds[newfd].conn_id = ts_fds[fd].conn_id;
    }
    return ret;
}

/*
 * dup() override.
 *
 * Duplicates via libc and, if the source descriptor is wrapped, copies
 * its ts_fds[] entry (wrapped flag, domain, type, connection id) to the
 * new descriptor.  Returns the libc result.
 */
int
dup(int fd)
{
    int ret;

    if (!tesla_enabled(0))
	return TS_CALL_LIBC(dup, fd);

    ret = TS_CALL_LIBC(dup, fd);
    if (ret == -1)
	return ret;

    /* Only descriptors below FD_SETSIZE have a ts_fds[] slot; check both
     * ends before reading or writing the table. */
    if (fd >= 0 && fd < FD_SETSIZE && ret >= 0 && ret < FD_SETSIZE &&
	ts_fds[fd].wrapped) {
	ts_fds[ret].wrapped = 1;
	ts_fds[ret].domain = ts_fds[fd].domain;
	ts_fds[ret].type = ts_fds[fd].type;
	ts_fds[ret].conn_id = ts_fds[fd].conn_id;
    }
    return ret;
}

/*
 * fork() override.
 *
 * If TESLA is disabled or there is no master descriptor, this is a plain
 * libc fork().  Otherwise a MSG_FORK is sent to the master first, which
 * answers with MSG_FH_FOLLOWS carrying a fresh descriptor.  After the
 * libc fork():
 *   - on failure the fresh descriptor is closed and errno is preserved;
 *   - the parent closes the fresh descriptor and keeps its master_fd;
 *   - the child closes the inherited master_fd and adopts the fresh
 *     descriptor as its own master_fd.
 *
 * Returns what the libc fork() returned.
 */
pid_t fork()
{
    master_msg_t msg;
    unsigned int seq;
    pid_t child;
    int saved_errno;

    if (!tesla_enabled(0) || master_fd < 0)
	return TS_CALL_LIBC(fork);

    ts_debug_1("Forking");

    msg.type = MSG_FORK;
    seq = send_master(&msg);

    if (recv_master(&msg) != seq)
	ts_fatal("Mis-sequenced response");
    assert(msg.type == MSG_FH_FOLLOWS);

    child = TS_CALL_LIBC(fork);

    if (child < 0) {
	/* close() may disturb errno; report the fork failure itself. */
	saved_errno = errno;
	close(msg.body.fh_pass.fh);
	errno = saved_errno;
	return child;
    }

    if (child != 0) {
	// Parent
	ts_debug_1("Parent");
	close(msg.body.fh_pass.fh);
	return child;
    }

    // Child
    ts_debug_1("Child");
    close(master_fd);
    master_fd = msg.body.fh_pass.fh;
    return child;
}

/* We do funky stuff for dup and exec, so if an application calls
 * vfork+exec, it would look like vfork+tesla internal stuff+exec,
 * which violates the vfork requirements (nothing happens between
 * vfork and exec).  We punt this by turning vfork into fork.
 */

pid_t vfork()
{
    /* With TESLA enabled, go through our fork() override; otherwise call
     * the libc fork() directly.  Either way this is a fork, never a real
     * vfork. */
    if (tesla_enabled(0))
	return fork();

    return TS_CALL_LIBC(fork);
}

// We have to handle only one special case for select - when
// the app does a connect() on a non-blocking socket and selects
// on write (or read) to wait for the connection to succeed or fail.
//
// For such a descriptor the connect result arrives as one byte ('A' =
// failed, 'B' = succeeded) on the descriptor itself.  So we wait for it
// to become readable instead, consume the byte, record the outcome in
// ts_fds[].connect_status, and then report the descriptor as writable if
// the caller asked about writability.  The readability we observed was
// only the result byte, so it is never reported to the caller; if nothing
// else is ready after that, the select is restarted.
//
// If no async connect was ever started, or TESLA is disabled, this is a
// plain libc select().
int select(int n, fd_set *_rfds, fd_set *_wfds, fd_set *efds,
	   struct timeval *timeout)
{
    int i;
    int limit;
    int have_any;
    int ret;
    fd_set rfds;
    fd_set wfds;
    fd_set cwfds;	/* connecting fds the caller wants write status for */
    fd_set saved_efds;	/* caller's exception set, for restarts */

    if (!tesla_enabled(0) || !async_connect)
	return TS_CALL_LIBC(select, n, _rfds, _wfds, efds, timeout);

    /* ts_fds[] and fd_set only cover descriptors below FD_SETSIZE. */
    limit = n < FD_SETSIZE ? n : FD_SETSIZE;

    if (efds)
	memcpy(&saved_efds, efds, sizeof(fd_set));

 again:
    have_any = 0;

    // Start with specified fds
    if (_rfds)
	memcpy(&rfds, _rfds, sizeof(fd_set));
    else
	FD_ZERO(&rfds);
    if (_wfds)
	memcpy(&wfds, _wfds, sizeof(fd_set));
    else
	FD_ZERO(&wfds);
    /* The previous pass overwrote the caller's exception set with its
     * results; put the original request back before selecting again. */
    if (efds)
	memcpy(efds, &saved_efds, sizeof(fd_set));

    FD_ZERO(&cwfds);

    // fds might contain handles with a pending asynchronous
    // connect; if so handle them specially
    for (i = 0; i < limit; ++i) {
	if (!ts_fds[i].connecting)
	    continue;
	if (!FD_ISSET(i, &wfds) && !FD_ISSET(i, &rfds))
	    continue;

	have_any = 1;
	ts_debug_2(" - specially handling async connect select on %d", i);

	/* Wait for the result byte, not for writability. */
	FD_SET(i, &rfds);
	if (FD_ISSET(i, &wfds)) {
	    FD_SET(i, &cwfds);
	    FD_CLR(i, &wfds);
	}
    }

    ret = TS_CALL_LIBC(select, n, &rfds, &wfds, efds, timeout);
    if (ret < 0)
	return ret;	/* error: leave the caller's read/write sets alone */

    if (have_any && ret > 0) {
	for (i = 0; i < limit; ++i) {
	    int bytes;
	    char ch;

	    if (!ts_fds[i].connecting || !FD_ISSET(i, &rfds))
		continue;

	    ts_debug_2(" - handling connecting select on %d", i);
	    ts_fds[i].connecting = 0;

	    // Read the byte with the connect result
	    bytes = TS_CALL_LIBC(read, i, &ch, 1);
	    if (bytes != 1) {
		ts_error("Unable to read connect result after async connect");
		continue;
	    }

	    if (ch == 'A') {
		ts_debug_2("   - async connect failed");
		ts_fds[i].connect_status = ECONNREFUSED;
	    } else if (ch == 'B') {
		ts_debug_2("   - async connect succeeded");
		ts_fds[i].connect_status = 0;
	    } else {
		ts_error("   - weird byte (neither A nor B) on async connect");
	    }

	    /* The byte we just consumed was the only thing that made the
	     * descriptor readable, so never report that to the caller. */
	    FD_CLR(i, &rfds);
	    --ret;

	    // Turn the read into a write if necessary
	    if (FD_ISSET(i, &cwfds)) {
		FD_SET(i, &wfds);
		++ret;
	    }
	}

	// If we just handled a connect and don't actually have any data, go again
	if (ret == 0) {
	    ts_debug_2("Restarting select");
	    goto again;
	}
    }

    /* Always hand the results back: we selected on private copies. */
    if (_rfds)
	memcpy(_rfds, &rfds, sizeof(fd_set));
    if (_wfds)
	memcpy(_wfds, &wfds, sizeof(fd_set));
    return ret;
}
