/*
 * TESLA: A Transparent, Extensible Session-Layer Architecture
 *
 * Jon Salz <jsalz@mit.edu>
 * Alex C. Snoeren <snoeren@lcs.mit.edu>
 *
 * Copyright (c) 2001-2 Massachusetts Institute of Technology.
 *
 * This software is being provided by the copyright holders under the GNU
 * General Public License, either version 2 or, at your discretion, any later
 * version. For more information, see the `COPYING' file in the source
 * distribution.
 *
 * $Id: teslamaster.cc,v 1.40 2002/10/10 19:18:53 snoeren Exp $
 *
 * The master process implementation.
 *
 */

#include "config.h"

#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/wait.h>
#include <signal.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cstring>
#include <iostream>
#include <map>
#include <set>
#include <strstream>

#include "tesla/async.hh"
#include "tesla/client.hh"
#include "tesla/flow_handler.hh"
#include "tesla/tesla.h"
#include "tesla-internal.h"
#include "top_handler.hh"
#include "tesla/serial.hh"

using namespace std;

// Forward declaration (redundant with top_handler.hh included above).
class top_handler;

// top_handlers scheduled for teardown.  main() drains this set after each
// async::select() pass: each entry is close()d, removed from 'conns' and
// deleted.  Entries are added outside this file.
set<top_handler*> close_set;
// Every live application control channel.  main() keeps serving while this
// is non-empty and at least one client is essential.
set<client*> clients;
// Live connections indexed by the master-assigned connection ID.
map<int, top_handler*> conns; /* by conn_id */
// The same connections indexed by the master's end (pair[1]) of the
// socketpair created for each connection.
map<int, top_handler*> conns_by_fd;

// Application ends (pair[0]) of socketpairs for connections restored from
// saved state, held until a client claims one with MSG_GET_FD.
map<int, int> saved_fds; /* by conn_id */

/*
 * read_fully: read exactly 'count' bytes from fd into buf.
 *
 * Short reads are continued, and reads interrupted by a signal (EINTR) are
 * retried.  Returns 'count' on success, or 0 on EOF, any other read error,
 * or a negative count.  A request for 0 bytes also returns 0, so callers
 * that test the result with '!' treat an empty read as failure.
 */
int read_fully(int fd, void *buf, int count)
{
    // A negative count would otherwise reach read() as a huge size_t.
    if (count < 0)
        return 0;

    const int requested = count;
    char *cursor = static_cast<char*>(buf);
    while (count > 0) {
        ssize_t got = read(fd, cursor, count);
        if (got < 0 && errno == EINTR)
            continue;           // interrupted before any data arrived; retry
        if (got <= 0)
            return 0;           // EOF or hard error
        assert(got <= count);
        cursor += got;
        count -= static_cast<int>(got);
    }
    return requested;
}

/*
 * write_fully: write exactly 'count' bytes from buf to fd.
 *
 * Short writes are continued, and writes interrupted by a signal (EINTR)
 * are retried.  Returns 'count' on success, or 0 on any other write error
 * or a negative count.  A request for 0 bytes also returns 0, so callers
 * that test the result with '!' must not pass empty buffers.
 */
int write_fully(int fd, const void *buf, int count)
{
    // A negative count would otherwise reach write() as a huge size_t.
    if (count < 0)
        return 0;

    const int requested = count;
    const char *cursor = static_cast<const char*>(buf);
    while (count > 0) {
        ssize_t put = write(fd, cursor, count);
        if (put < 0 && errno == EINTR)
            continue;           // interrupted before anything was written; retry
        if (put <= 0)
            return 0;
        assert(put <= count);
        cursor += put;
        count -= static_cast<int>(put);
    }
    return requested;
}

void send_msg(int fd, master_msg_t *msg) {
    ts_debug_2("Sending message: id=%d, conn_id=%d, type=%s", msg->id, msg->conn_id,
	       MSG_TYPE(msg->type));

    int len = sizeof *msg - sizeof msg->id;
    if (write(fd, &len, sizeof len) != sizeof len)
	ts_fatal("Unable to write length");
    if (write(fd, &msg->id, sizeof *msg - sizeof msg->len) != sizeof *msg - sizeof msg->len)
	ts_fatal("Unable to write message body");
}

/*
 * send_msg_header: emit only the fixed MASTER_MSG_HDRLEN-byte header of
 * *msg, announcing that 'length' payload bytes follow.  The caller must
 * write exactly that many bytes next.  msg->len is overwritten with the
 * framed length (header remainder plus payload); a failed write is fatal.
 */
void send_msg_header(int fd, master_msg_t *msg, int length) {
    ts_debug_2("Sending message: id=%d, conn_id=%d, type=%s", msg->id, msg->conn_id,
	       MSG_TYPE(msg->type));

    // The length field itself is not counted in the framed length.
    msg->len = MASTER_MSG_HDRLEN + length - sizeof msg->len;

    if (write_fully(fd, msg, MASTER_MSG_HDRLEN))
	return;

    ts_fatal("Unable to write length");
}

/*
 * send_fd: send a descriptor (or an error status) to the peer on the
 * UNIX-domain socket clifd, using the 2-byte send_fd()/recv_fd() protocol.
 *
 * Two data bytes are always sent: buf[0] is 0, and buf[1] is the status
 * (0 = OK, with fd attached as SCM_RIGHTS ancillary data; nonzero = error
 * derived from a negative fd, with no descriptor attached).
 *
 * The descriptor is not closed here.  Any send failure is fatal, so the
 * function only ever returns 0.
 */
int
send_fd(int clifd, int fd) {
    // Control buffer sized for exactly one int.  The union with struct
    // cmsghdr guarantees the alignment CMSG_FIRSTHDR/CMSG_DATA require.
    union {
	struct cmsghdr hdr;
	char space[CMSG_SPACE(sizeof(int))];
    } control;
    struct iovec	iov[1];
    struct msghdr	msg;
    char buf[2]; /* send_fd()/recv_fd() 2-byte protocol */

    memset(&msg, 0, sizeof msg);
    iov[0].iov_base = buf;
    iov[0].iov_len  = 2;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;
    msg.msg_name    = NULL;
    msg.msg_namelen = 0;

    if (fd < 0) {
	msg.msg_control    = NULL;
	msg.msg_controllen = 0;
	// Nonzero status means error.  Unsigned negation yields -fd modulo 2^N
	// without the overflow -INT_MIN would cause.
	buf[1] = static_cast<char>(0u - static_cast<unsigned>(fd));
	if (buf[1] == 0) buf[1] = 1;	/* -256, etc. would screw up protocol */
    } else {
	memset(&control, 0, sizeof control);
	msg.msg_control    = control.space;
	msg.msg_controllen = sizeof control.space;

	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type  = SCM_RIGHTS;
	cmsg->cmsg_len   = CMSG_LEN(sizeof(int));
	// memcpy avoids an unaligned/aliasing store through CMSG_DATA.
	memcpy(CMSG_DATA(cmsg), &fd, sizeof fd);	/* the fd to pass */
	buf[1] = 0;		/* zero status means OK */
    }
    buf[0] = 0;			/* null byte flag to recv_fd() */

    ssize_t sent;
    do {
	sent = sendmsg(clifd, &msg, 0);
    } while (sent < 0 && errno == EINTR);

    if (sent != 2) ts_fatal("Bollocks! %s", strerror(errno));
    return(0);
}


// Next connection ID to assign.  client::ravail() takes IDs from here
// sequentially (starting at 1000) when it creates a connection for
// MSG_SOCKET or MSG_ACCEPT.
int next_conn_id = 1000;

/*
 * client::ravail: service one request on this client's control channel.
 *
 * Frame format, as read here: msg.len, then msg.len bytes that fill
 * master_msg_t starting at msg.id.  The request is dispatched on msg.type.
 * Any reply is written back on the same fd, either with send_msg() or, for
 * MSG_GETSOCKOPT, with send_msg_header() followed by a raw payload.  A
 * reply of type MSG_FH_FOLLOWS is immediately followed by a descriptor
 * sent with send_fd().
 *
 * On EOF, a malformed frame, or MSG_NAK, the client is removed from
 * 'clients' and deleted; 'this' must not be used after those paths.
 * MSG_SAVE_STATE writes all connection state to a file and exits.
 */
void client::ravail() {
    master_msg_t msg;

    // read_fully() also absorbs short reads and EINTR on the length prefix,
    // which a single read() would have misreported as a closed client.
    if (!read_fully(fd, &msg.len, sizeof msg.len)) {
	ts_debug_2("Client %d has closed", fd);
	clients.erase(this);
	delete this;
	return;
    }

    // msg.len is peer-supplied.  Reject zero/negative lengths and anything
    // larger than the space after msg.len; otherwise read_fully() below
    // would overrun 'msg' on the stack.
    if (msg.len <= 0 || static_cast<size_t>(msg.len) > sizeof msg - sizeof msg.len) {
	ts_error("Invalid message length %d on %d", static_cast<int>(msg.len), fd);
	clients.erase(this);
	delete this;
	return;
    }

    if (!read_fully(fd, &msg.id, msg.len)) {
	ts_error("Got wrong number of bytes on %d", fd);
	clients.erase(this);
	delete this;
	return;
    }

    ts_debug_2("Received message: id=%d, conn_id=%d, pid=%d, type=%s",
	       msg.id, msg.conn_id, msg.pid, MSG_TYPE(msg.type));

    switch (msg.type) {
      case MSG_HELLO:
	{
	  msg.merrno = 0;
	  send_msg(fd, &msg);
	  break;
	}

      case MSG_NAK:
	{
	    ts_debug_2("Client %d died", fd);
	    clients.erase(this);
	    delete this;
	    break;
	}

      case MSG_SOCKET:
	{
	    ts_debug_2(" - domain=%d; type=%d", msg.body.socket.domain, msg.body.socket.type);

	    // NOTE(review): binding a reference to *(flow_handler*)0 is undefined
	    // behavior.  plumb() presumably treats it as "no parent handler";
	    // fixing it needs a plumb() signature change, so it is left as is.
	    flow_handler *h = flow_handler::plumb(*(flow_handler*)0, msg.body.socket.domain, msg.body.socket.type);

	    if (h) {
		ts_debug_2("   - got a stack!");

		string out;
		h->dump(out);
		ts_debug_2("Handlers:\n%s", out.c_str());

		int pair[2];
		if (socketpair(PF_UNIX, SOCK_STREAM, 0, pair) != 0)
		    ts_fatal("Unable to socketpair");
		int conn_id = next_conn_id++;
		
		msg.conn_id = conn_id;
		msg.type = MSG_FH_FOLLOWS;
		msg.merrno = 0;
		send_msg(fd, &msg);

		// pair[0] goes to the application; the master keeps pair[1].
		send_fd(fd, pair[0]);
		close(pair[0]);

		if (fcntl(pair[1], F_SETFL, O_NONBLOCK))
		    ts_fatal("fcntl: %s", strerror(errno));

		top_handler *c = new top_handler(pair[1], h->get_info().get_type(), conn_id, h);
		conns[conn_id] = c;
		conns_by_fd[pair[1]] = c;
	    } else {
		ts_debug_2("   - no stack; declining");
		msg.type = MSG_NAK;
		send_msg(fd, &msg);
	    }
	    break;
	}

      case MSG_CONNECT:
	{
	    address a(msg.body.address.addr, msg.body.address.addrlen);
	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    int ret = i->second->get_handler()->connect(a);
	    ts_debug_2("Connect request: %d", ret);
	    msg.type = (ret == 0 ? MSG_ACK : MSG_NAK);
	    msg.merrno = -ret;
	    send_msg(fd, &msg);
	    break;
	}

      case MSG_BIND:
	{
	    address a(msg.body.address.addr, msg.body.address.addrlen);
	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    int ret = i->second->get_handler()->bind(a);
	    ts_debug_2("Bind request: %d", ret);
	    msg.type = (ret == 0 ? MSG_ACK : MSG_NAK);
	    msg.merrno = -ret;
	    send_msg(fd, &msg);
	    break;
	}

      case MSG_LISTEN:
	{
	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    int ret = i->second->get_handler()->listen(5);
	    ts_debug_2("Listen request: %d", ret);
	    msg.type = ret == 0 ? MSG_ACK : MSG_NAK;
	    msg.merrno = -ret;
	    // No reply is sent for listen.  Presumably the client does not wait
	    // for one -- verify before re-enabling.
	    //	    send_msg(fd, &msg);
	    break;
	}

      case MSG_ACCEPT:
	{
	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    acceptret h = i->second->get_handler()->accept();

	    if (h) {
		ts_debug_2("   - got a stack!");

		string out;
		h.h->dump(out);
		ts_debug_2("Handlers:\n%s", out.c_str());

		int pair[2];
		if (socketpair(PF_UNIX, h.h->get_type(), 0, pair) != 0)
		    ts_fatal("Unable to socketpair");
		int conn_id = next_conn_id++;

		// First an ACK carrying the peer address, then the new fd.
		msg.conn_id = conn_id;
		msg.merrno = 0;
		msg.type = MSG_ACK;
		msg.body.address.addrlen = h.addr.addrlen();
		memcpy(msg.body.address.addr, h.addr.addr(), h.addr.addrlen());
		send_msg(fd, &msg);

		msg.type = MSG_FH_FOLLOWS;
		send_msg(fd, &msg);

		send_fd(fd, pair[0]);
		close(pair[0]);
		
		// All TESLA sockets are non-blocking
		if(fcntl(pair[1], F_SETFL, O_NONBLOCK))
		  ts_fatal("fcntl: %s", strerror(errno));

		top_handler *c = new top_handler(pair[1], h.h->get_type(), conn_id, h.h);
		conns[conn_id] = c;
		conns_by_fd[pair[1]] = c;
	    } else {
		ts_debug_2("   - accept failed");
		msg.type = MSG_NAK;
		send_msg(fd, &msg);
	    }
	    break;
	}

      case MSG_GETSOCKOPT:
	{
	    ts_debug_2("Got getsockopt");

	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    string val = i->second->getsockopt(msg.body.sockopt.level,
					       msg.body.sockopt.optname,
					       msg.body.sockopt.optlen);

	    if (val.length()) {
		// Variable-length reply: header, then the sockopt struct without
		// its optval array, then exactly val.length() bytes of value.
		msg.type = MSG_ACK;
		msg.merrno = 0;
		send_msg_header(fd, &msg, sizeof msg.body.sockopt - sizeof msg.body.sockopt.optval + val.length());
		msg.body.sockopt.optlen = val.length();
		if (!write_fully(fd, &msg.body.sockopt, sizeof msg.body.sockopt - sizeof msg.body.sockopt.optval) ||
		    !write_fully(fd, val.data(), val.length()))
		{
		    ts_fatal("Unable to write sockopt data to %d", fd);
		}
	    } else {
		msg.type = MSG_NAK;
		send_msg(fd, &msg);
	    }

	    break;
	}

      case MSG_SETSOCKOPT:
	{
	    ts_debug_2("Got setsockopt");

	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    // optlen is client-supplied.  Refuse values that would read past the
	    // optval buffer (a negative signed optlen also fails this test).
	    if (static_cast<size_t>(msg.body.sockopt.optlen) > sizeof msg.body.sockopt.optval) {
		ts_error("setsockopt optlen %d exceeds buffer", static_cast<int>(msg.body.sockopt.optlen));
		msg.type = MSG_NAK;
		msg.merrno = EINVAL;
		send_msg(fd, &msg);
		break;
	    }

	    string opt(msg.body.sockopt.optval, msg.body.sockopt.optlen);

	    if (msg.body.sockopt.level == SOL_TESLA_IOCTL) {
		ts_debug_1("Got ioctl; target=%s; optname=%d; optlen=%d",
			   msg.body.sockopt.target, msg.body.sockopt.optname, static_cast<int>(opt.length()));
		string out;
		int ret = i->second->ioctl(msg.body.sockopt.target, msg.body.sockopt.optname,
					   opt, out);
		ts_debug_1("ioctl returned %d", ret);
		if (ret == flow_handler::PASS_FD) {
		    ts_debug_1(" - handler wants to pass FD");
		    if (out.length() != sizeof(int)) {
			ts_error("Invalid PASS_FD from handler");
			msg.type = MSG_NAK;
			msg.merrno = EINVAL;
			send_msg(fd, &msg);
		    } else {
			int pass_fd = *(const int *)out.data();

			ts_debug_1(" - passing fd %d", pass_fd);
			msg.type = MSG_FH_FOLLOWS;
			send_msg(fd, &msg);
			send_fd(fd, pass_fd);
			::close(pass_fd);
		    }
		} else if (ret == 0) {
		    ts_debug_1(" - return value len is %d", static_cast<int>(out.length()));
		    msg.type = MSG_ACK;
		    assert(out.length() <= sizeof msg.body.sockopt.optval);
		    msg.body.sockopt.optlen = out.length();
		    memcpy(msg.body.sockopt.optval, out.data(), out.length());
		    send_msg(fd, &msg);
		} else {
		    msg.type = MSG_NAK;
		    msg.merrno = -ret;
		    send_msg(fd, &msg);
		}
	    } else {
		int ret = i->second->setsockopt(msg.body.sockopt.level, msg.body.sockopt.optname, opt);
		msg.type = (ret ? MSG_NAK : MSG_ACK);
		msg.merrno = -ret;
		send_msg(fd, &msg);
	    }

	    break;
	}

      case MSG_GETSOCKNAME:
      case MSG_GETPEERNAME:
	{
	    ts_debug_2("Got get%sname", msg.type == MSG_GETSOCKNAME ? "sock" : "peer");

	    map<int, top_handler*>::const_iterator i = conns.find(msg.conn_id);
	    if (i == conns.end())
		ts_fatal("Reference to nonexistent conn_id");

	    address val = msg.type == MSG_GETSOCKNAME ? i->second->getsockname() :
		i->second->getpeername();

	    if (val.addrlen()) {
		msg.type = MSG_ACK;
		msg.merrno = 0;
		msg.body.address.addrlen = val.addrlen();
		memcpy(msg.body.address.addr, val.addr(), val.addrlen());

		send_msg(fd, &msg);
	    } else {
		msg.type = MSG_NAK;
		send_msg(fd, &msg);
	    }

	    break;
	}

      case MSG_ASYNC_MESSAGES:
	{
	    ts_debug_2("Told to %s async messages", msg.body.async_messages.enabled ? "enable" : "disable");

	    // The async-message socketpair is created lazily on first enable.
	    messages_enabled = msg.body.async_messages.enabled;
	    if (messages_enabled && msg_fd < 0) {
		int pair[2];
		if (socketpair(PF_UNIX, SOCK_STREAM, 0, pair) != 0)
		    ts_fatal("Unable to socketpair");
		msg_fd_remote = pair[0];
		msg_fd = pair[1];
	    }

	    if (messages_enabled && msg.body.async_messages.pass) {
		msg.type = MSG_FH_FOLLOWS;
		send_msg(fd, &msg);
		send_fd(fd, msg_fd_remote);
	    } else {
		msg.type = MSG_ACK;
		msg.merrno = 0;
		send_msg(fd, &msg);
	    }

	    ts_debug_2("Async messages enabled: master FD is %d (remote FD is %d)", msg_fd, msg_fd_remote);

	    break;
	}

      case MSG_FORK:
	{
	    ts_debug_1("Application is forking - passing new control channel");

	    int pair[2];
	    if (socketpair(PF_UNIX, SOCK_STREAM, 0, pair) != 0)
		ts_fatal("Unable to socketpair");
	    int remote = pair[0];
	    int local = pair[1];

	    msg.type = MSG_FH_FOLLOWS;
	    send_msg(fd, &msg);
	    send_fd(fd, remote);
	    close(remote);

	    clients.insert(new client(local));
	    break;
	}

      case MSG_INESSENTIAL:
	{
	    ts_debug_1("Client says inessential=%d", msg.body.inessential.val);
	    inessential = msg.body.inessential.val != 0;
	    break;
	}

      case MSG_SAVE_STATE:
	{
	    ts_debug_1("Client says to save state to %s", msg.body.save_state.filename);

	    {
		ofstream of(msg.body.save_state.filename, ofstream::binary | ofstream::app);
		if (!of)
		    ts_fatal("Unable to open state file for writing");

		oserialstream out(of);
		
		// Version number
		out << 65;

		ts_debug_1("There are %d handlers to serialize", static_cast<int>(conns.size()));
		for (map<int, top_handler*>::const_iterator i = conns.begin(); i != conns.end(); ++i)
		    ts_debug_1(" - conn %d, top_handler %p", i->first, static_cast<void*>(i->second));

		// Save state for each handler
		for (map<int, top_handler*>::const_iterator i = conns.begin(); i != conns.end(); ++i) {
		    oserialstring out1;

		    top_handler *top = i->second;
		    out1 << *top;
		    if (out1) {
			out << i->first;
			out.write(out1.str().data(), out1.str().length());
			ts_debug_1("Saved state on connid %d - %d bytes", i->first, static_cast<int>(out1.str().length()));
		    } else {
			ts_debug_1("Unable to save state on connid %d: %s",
				   i->first,
				   out1.error().c_str());
		    }
		}

		// End
		out << -1;
	    }
	    
	    // Acknowledge
	    msg.type = MSG_ACK;
	    send_msg(fd, &msg);

	    ts_debug_1("Done saving state; exiting");
	    exit(1);
	}

      case MSG_GET_FD:
	{
	    ts_debug_1("Client wants FD for conn ID %d", msg.conn_id);
	    map<int, int>::iterator i = saved_fds.find(msg.conn_id);

	    if (i == saved_fds.end()) {
		ts_debug_1(" - couldn't find it");
		msg.type = MSG_NAK;
		send_msg(fd, &msg);
	    } else {
		// Hand the restored connection's fd over exactly once.
		ts_debug_1(" - found it");
		msg.type = MSG_FH_FOLLOWS;
		send_msg(fd, &msg);
		send_fd(fd, i->second);
		close(i->second);
		saved_fds.erase(i);
	    }
	    break;
	}

      default:
	ts_fatal("Unexpected message type %d", msg.type);
    }
}

// Set from argv[1] in main().  It is not read anywhere in this file --
// presumably it is used by other translation units; verify before removing.
pid_t pid;

void send_message(unsigned int conn_id, string handler, int name, string value)
{
    ts_async_message_s msg;

    assert(handler.length() < sizeof(msg.handler));

    ts_debug_1("Sending async message: conn_id %d, handler %s, name %d, value %s",
	       conn_id, handler.c_str(), name, value.c_str());

    strcpy(msg.handler, handler.c_str());
    msg.name = name;
    msg.conn_id = conn_id;
    msg.data_length = value.length();

    for (set<client*>::const_iterator i = clients.begin(); i != clients.end(); ++i) {    
	if (!(*i)->are_messages_enabled()) continue;

	int fd = (*i)->get_msg_fd();

	if (!write_fully(fd, &msg, sizeof msg))
	    ts_fatal("Unable to write async message to %d: %s", fd, strerror(errno));
	if (!write_fully(fd, value.data(), value.length()))
	    ts_fatal("Unable to write async message to %d: %s", fd, strerror(errno));
    }
}

/*
 * send_message_fd: broadcast an asynchronous message carrying a file
 * descriptor to every client that has async messages enabled.
 *
 * Each recipient's message fd receives a ts_async_message_s header whose
 * data_length is -1 (instead of a payload size), followed by the_fd passed
 * with send_fd().  the_fd is not closed here.  Any write failure is fatal.
 */
void send_message_fd(unsigned int conn_id, string handler, int name, int the_fd)
{
    ts_async_message_s msg;

    assert(handler.length() < sizeof(msg.handler));

    ts_debug_1("Sending async FD message: conn_id %d, handler %s, name %d",
	       conn_id, handler.c_str(), name);

    // Zero the header first so padding and the unused tail of msg.handler
    // never leak stack contents to clients.  The copy is bounded because
    // the assert above disappears under NDEBUG.
    memset(&msg, 0, sizeof msg);
    strncpy(msg.handler, handler.c_str(), sizeof(msg.handler) - 1);
    msg.name = name;
    msg.conn_id = conn_id;
    msg.data_length = -1;

    for (set<client*>::const_iterator i = clients.begin(); i != clients.end(); ++i) {
	if (!(*i)->are_messages_enabled()) continue;

	int fd = (*i)->get_msg_fd();

	if (!write_fully(fd, &msg, sizeof msg))
	    ts_fatal("Unable to write async FD message to %d: %s", fd, strerror(errno));
	if (send_fd(fd, the_fd))
	    ts_fatal("Unable to write async FD message to %d: %s", fd, strerror(errno));
    }
}

/*
 * Disabled SIGPIPE handler, kept for reference.  main() ignores SIGPIPE
 * (signal(SIGPIPE, SIG_IGN)) instead of installing this.
 *
 * NOTE(review): as written it returns right after logging, so the rest is
 * unreachable.  That remainder also looks info->si_fd up in 'conns', which
 * is keyed by conn_id rather than fd; 'conns_by_fd' is presumably the
 * intended map if this is ever re-enabled.
 */
#if 0
extern "C"
void pipe_sigaction(int signal, siginfo_t *info, void *arg)
{
    assert(signal == SIGPIPE && info);

    ts_debug_1("Received SIGPIPE on %d", info->si_fd);
    return;

    map<int, top_handler*>::const_iterator i = conns.find(info->si_fd);
    if (i == conns.end())
	return;

    ts_debug_2(" - second SIGPIPE to top_handler");

    // SIGPIPE from writing on a connection to the client
    i->second->sigpipe();
}
#endif

int main(int argc, const char* argv[])
{
    if (argc != 4)
	ts_fatal("Three arguments are required");

    // Just in case
    __ts_debug_init();

    if (fork()) {
      exit(0);
    }

    // Detach from TTY if debugging disabled, or debugging to a file
    if (__ts_debug_level == 0 ||
	(__ts_debug_file != stdout && __ts_debug_file != stderr))
    {
	ts_debug_1("Detaching from TTY");

	// Grand Child
	setsid();
    }

    flow_handler::make_handlers();

#if 0
    struct sigaction act;
    memset(&act, 0, sizeof act);
    act.sa_sigaction = pipe_sigaction;
    act.sa_flags = SA_RESTART | SA_SIGINFO;
    sigaction(SIGPIPE, &act, 0);
#else
    signal(SIGPIPE, SIG_IGN);
#endif

    pid = atoi(argv[1]);

    int fd = atoi(argv[2]);
    assert(fd >= 0 && fd < FD_SETSIZE);
    ts_debug_2("Listening on %d", fd);
    clients.insert(new client(fd));

    int statefd = atoi(argv[3]);
    if (statefd != -1) {
	ts_debug_1("Reading state from FD %d", statefd);

	ifstream inf(statefd);
	iserialstream in(inf);

	int version;
	in >> version;

	ts_debug_1("Stream version %d", version);

	while (1) {
	    int conn_id;
	    in >> conn_id;

	    if (!in || conn_id == -1)
		break;

	    ts_debug_1("Reading conn id %d", conn_id);

	    flow_handler *flow;
	    in >> flow;

	    int pair[2];
	    if (socketpair(PF_UNIX, SOCK_STREAM, 0, pair) != 0)
		ts_fatal("Unable to socketpair");

	    top_handler *top = static_cast<top_handler *>(flow);
	    top->set_fd(pair[1]);
	    top->set_conn_id(conn_id);
	    top->rewant();

	    saved_fds[conn_id] = pair[0];

	    conns[conn_id] = top;
	    conns_by_fd[pair[1]] = top;

	    if (!in)
		ts_fatal("Error while deserializing: %s", in.error().c_str());
	}

	ts_debug_1("Done deserializing!");
    }

    while (clients.size()) {
	ts_debug_2("Main loop; %d clients remain", clients.size());

	bool all_inessential = true;
	for (set<client*>::const_iterator i = clients.begin(); i != clients.end(); ++i) {
	    if (!(*i)->is_inessential()) {
		all_inessential = false;
		break;
	    }
	}
	if (all_inessential) {
	    ts_debug_1("All clients are inessential; bailing");
	    break;
	}

	async::select();

	for (set<top_handler*>::const_iterator i = close_set.begin(); i != close_set.end(); ++i) {
	    (*i)->close();
	    conns.erase((*i)->get_conn_id());
	    delete *i;
	}

	close_set.clear();
    }

    ts_debug_1("All essential clients are closed; terminating handlers");

    // Close down handlers
    while (true) {
	bool all_may_exit = true;
	for (map<int, top_handler*>::const_iterator i = conns.begin(); i != conns.end(); ++i) {
	    if (!i->second->may_exit())
		all_may_exit = false;
	}
	if (all_may_exit)
	    break;

	ts_debug_1("Some handlers are unable to terminate; postponing exit");

	async::select();
    }

    ts_debug_1("Closing any inessential clients");
    for (set<client*>::const_iterator i = clients.begin(); i != clients.end(); ++i)
	delete *i;
    clients.clear();

    ts_debug_1("All handlers are done; exiting master");

    return 0;
}
