/*
 * Migrate Session Layer
 *
 * Alex C. Snoeren <snoeren@lcs.mit.edu>
 *
 * Copyright (c) 2001 Massachusetts Institute of Technology.
 *
 * This software is being provided by the copyright holders under the GNU
 * General Public License, either version 2 or, at your discretion, any later
 * version. For more information, see the `COPYING' file in the source
 * distribution.
 *
 * $Id: crypto.c,v 1.20 2002/08/28 17:46:24 snoeren Exp $
 *
 * This file requires OpenSSL, http://www.openssl.org
 */

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif

#ifdef STDC_HEADERS
# include <string.h>
#endif
#ifdef HAVE_ERRNO_H
# include <errno.h>
#endif

#ifdef HAVE_LIBSSL
# include <openssl/ssl.h>
# include <openssl/dh.h>
# include <openssl/err.h>
# include <openssl/evp.h>
# include <openssl/rand.h>
#endif

#include "migrated.h"

#ifndef MODERN_SSL
/*
 * Compatibility shim: some older versions of the OpenSSL headers don't
 * provide EVP_CIPHER_CTX_set_key_length().
 *
 * NOTE(review): instead of changing per-context state, this writes keylen
 * into the EVP_CIPHER object returned by EVP_CIPHER_CTX_cipher(), after
 * casting that pointer to non-const.  That cipher object is presumably
 * shared by every context using the same cipher, so the new key length
 * would be visible to all of them -- confirm.  It also relies on
 * EVP_CIPHER_key_length() expanding to an assignable lvalue in those old
 * headers.
 *
 * Always returns 0; the callers in this file ignore the result.
 */
int
EVP_CIPHER_CTX_set_key_length(EVP_CIPHER_CTX *x, int keylen) {

  EVP_CIPHER *cp = (EVP_CIPHER *) EVP_CIPHER_CTX_cipher(x);
  EVP_CIPHER_key_length(cp) = keylen;
  return 0;
}
#endif

/* Capacity, in bytes, of struct key_t.data (and of the on-stack secret
 * buffers in the handshake functions). */
#define MAX_KEYSIZE      1024
/* Number of random bytes in a resume challenge.  This is the plaintext
 * length; the ciphertext sent to the peer also carries the cipher padding. */
#define CHALLENGE_LEN     128    /* Length of auth challenge plaintext */
#define CIPHER (EVP_bf_ecb())    /* Cipher type, currently Blowfish (ECB mode) */

/*
 * Diffie-Hellman group: hex-encoded prime P and generator G, parsed by
 * crypto_init().
 *
 * STRONG_CRYPTO: a 1024-bit prime taken from OpenSSH.
 *   NOTE(review): this appears to be the RFC 2409 "Second Oakley Group"
 *   prime -- confirm.
 * Otherwise: a 128-bit prime.
 *   NOTE(review): a 128-bit DH modulus offers essentially no security;
 *   confirm this fallback is only meant for testing.
 */
#ifdef STRONG_CRYPTO
static const char *P =
            "FFFFFFFF" "FFFFFFFF" "C90FDAA2" "2168C234" "C4C6628B" "80DC1CD1"
            "29024E08" "8A67CC74" "020BBEA6" "3B139B22" "514A0879" "8E3404DD"
            "EF9519B3" "CD3A431B" "302B0A6D" "F25F1437" "4FE1356D" "6D51C245"
            "E485B576" "625E7EC6" "F44C42E9" "A637ED6B" "0BFF5CB6" "F406B7ED"
            "EE386BFB" "5A899FA5" "AE9F2411" "7C4B1FE6" "49286651" "ECE65381"
            "FFFFFFFF" "FFFFFFFF";
static const char *G = "02";
#else
static const char *P = "DC04EB6EB146437F17F6422B78DE6F7B"; /* 128-bit prime */
static const char *G = "02";
#endif

/* DH object built by crypto_init() and shared by every key exchange.
 * The handshake functions install a key pair in it only temporarily and
 * are expected to reset priv_key/pub_key to NULL afterwards. */
static DH *shared_parameters;

/* Key material: the first len bytes of data are valid.  Holds either this
 * side's DH private key (session_crypto_init) or a negotiated shared
 * secret (finish_handshake / session_handshake). */
struct key_t {
  char data[MAX_KEYSIZE];
  int  len;
};


/*
 * Initialise the crypto module: load OpenSSL's error strings and build the
 * shared Diffie-Hellman parameters from the compiled-in P and G.
 *
 * Returns 0 on success, -1 on failure.  On failure the reason is logged
 * (naming the call that failed) and shared_parameters is left NULL.
 */
int
crypto_init(void)
{
  const char *what;

  ERR_load_crypto_strings();
  if((shared_parameters = DH_new()) == NULL) {
    log_log(LOG_MOD_CRYPTO, LOG_ERR,
	    "DH_new(): %s", ERR_error_string(ERR_get_error(), NULL));
    return -1;
  }
  if(!BN_hex2bn(&shared_parameters->p, P)) {
    what = "BN_hex2bn(P)";
    goto err;
  }
  if(!BN_hex2bn(&shared_parameters->g, G)) {
    what = "BN_hex2bn(G)";
    goto err;
  }

  log_log(LOG_MOD_CRYPTO, LOG_INFO,
	  "Using prime %s (%s)", P, G);

  return 0;

 err:
  log_log(LOG_MOD_CRYPTO, LOG_ERR,
	  "%s: %s", what, ERR_error_string(ERR_get_error(), NULL));
  /* DH_free() also releases whichever of p/g was already parsed. */
  DH_free(shared_parameters);
  shared_parameters = NULL;
  return -1;
}

/*
 * Resume client: answer the peer's challenge.
 *
 * buf/len hold the encrypted challenge received from the peer.  It is
 * decrypted with the session key in handle->key, and the plaintext is left
 * in handle->buf / handle->buflen with handle->msghdr set up as a
 * SESSION_RESUME_RESPONSE.  If decryption fails this is logged at LOG_CRIT
 * but the (partial) response is still prepared, as before.
 *
 * Returns 0, or EINVAL (with handle->msghdr zeroed) when len is negative
 * or too large for the local copy.
 */
int
session_crypto_auth(session_handle *handle, char *buf, int len)
{
  char challenge[BUFLEN];
  EVP_CIPHER_CTX c;
  struct key_t *key;
  int padlen;

  key = handle->key;

  /* len comes from the peer's message: bound it before copying it onto
   * the stack, otherwise an oversized challenge overflows `challenge'. */
  if(len < 0 || (size_t)len > sizeof(challenge)) {
    log_log(LOG_MOD_CRYPTO, LOG_ERR,
	    "Bad challenge length %d (max %d)", len, (int)sizeof(challenge));
    memset(&handle->msghdr, 0, sizeof(handle->msghdr));
    return EINVAL;
  }

  /* Buf may be stored in the handle, and handle->buf is overwritten by
   * the decryption below, so work from a private copy. */
  memcpy(challenge, buf, len);

  log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	  "Decrypting challenge (%d, %d): %02hhx%02hhx%02hhx%02hhx...",
	  handle->session.id, len, buf[0], buf[1], buf[2], buf[3]);

  /* Decrypt challenge.  The first init selects the cipher so the key
   * length can be set; the second init loads the key at that length. */
  EVP_CIPHER_CTX_init(&c);
  EVP_DecryptInit(&c, CIPHER, key->data, NULL);
  EVP_CIPHER_CTX_set_key_length(&c, key->len);
  EVP_DecryptInit(&c, CIPHER, key->data, NULL);
  EVP_DecryptUpdate(&c, handle->buf, &handle->buflen, challenge, len);
  if(!EVP_DecryptFinal(&c, &handle->buf[handle->buflen], &padlen))
    log_log(LOG_MOD_CRYPTO, LOG_CRIT, "Unable to decrypt challenge");
  else {
    handle->buflen += padlen;
    log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	    "Decrypted challenge (%d): %02hhx%02hhx%02hhx%02hhx...",
	    handle->buflen, handle->buf[0], handle->buf[1], handle->buf[2],
	    handle->buf[3]);
  }

  /* Respond with plaintext */
  handle->msghdr.type = SESSION_RESUME_RESPONSE;
  handle->msghdr.len = handle->buflen;

  EVP_CIPHER_CTX_cleanup(&c);

  return 0;
}


/*
 * Resume client: build a SESSION_RESUME_REQUEST naming the session to be
 * resumed.  The body is our session id followed by the peer's session id,
 * each as an int in network byte order.  Always returns 0.
 */
int
session_crypto_resume(session_handle *handle)
{
  int ids[2];

  ids[0] = htonl(handle->session.id);
  ids[1] = htonl(handle->session.pid);

  handle->msghdr.type = SESSION_RESUME_REQUEST;
  memcpy(handle->buf, ids, sizeof(ids));
  handle->msghdr.len = sizeof(ids);

  return 0;
}


/*
 * Resume server: handle a SESSION_RESUME_REQUEST and build the challenge.
 *
 * The request body in handle->buf holds two ints in network byte order:
 * the requester's session id (stored as handle->session.pid) followed by
 * our id for the session (stored as handle->session.id and looked up with
 * find_session()).  CHALLENGE_LEN random bytes are kept in
 * handle->challenge, and their encryption under the session key is left in
 * handle->buf / handle->msghdr as a SESSION_RESUME_CHALLENGE.
 *
 * Returns 0 on success.  On failure it returns an errno value and zeroes
 * handle->msghdr so nothing is sent:
 *   EINVAL - malformed request
 *   ENOMEM - out of memory
 *   ENOENT - unknown session
 *   EIO    - PRNG or cipher failure
 * In those cases handle->challenge is left NULL, so no non-random
 * challenge can later be matched.
 */
int
session_crypto_challenge(session_handle *handle)
{
  EVP_CIPHER_CTX c;
  int padlen, lsessionid, psessionid, rc;
  struct key_t *key;
  session_handle *newhandle;

  assert(!handle->challenge);

  /* The request size comes off the wire; check it explicitly because
   * assert() is compiled out under NDEBUG. */
  if(handle->msghdr.len != 2*sizeof(lsessionid)) {
    log_log(LOG_MOD_MIGRATED, LOG_ERR,
	    "Malformed migrate request (%d bytes)", (int)handle->msghdr.len);
    memset(&handle->msghdr, 0, sizeof(handle->msghdr));
    return EINVAL;
  }

  if((handle->challenge = malloc(CHALLENGE_LEN)) == NULL) {
    log_log(LOG_MOD_CRYPTO, LOG_ERR,
	    "Can't allocate challenge: %s", strerror(errno));
    memset(&handle->msghdr, 0, sizeof(handle->msghdr));
    return ENOMEM;
  }

  /* Find session */
  memcpy(&psessionid, handle->buf, sizeof(psessionid));
  memcpy(&lsessionid, &handle->buf[sizeof(psessionid)], sizeof(lsessionid));
  lsessionid = ntohl(lsessionid);
  psessionid = ntohl(psessionid);
  handle->session.id = lsessionid;
  handle->session.pid = psessionid;

  log_log(LOG_MOD_MIGRATED, LOG_DEBUG,
	  "Received a migrate request for session %u:%u",
	  lsessionid, psessionid);

  newhandle = find_session(lsessionid);
  if(!newhandle) {
    log_log(LOG_MOD_MIGRATED, LOG_ERR,
	    "Unknown session %d", lsessionid);
    rc = ENOENT;
    goto fail;
  }
  handle->key = key = newhandle->key;

  /* Generate the random challenge first: if the PRNG fails we must not
   * send, or later accept a response to, uninitialised memory. */
  if (!RAND_bytes(handle->challenge, CHALLENGE_LEN)) {
    log_log(LOG_MOD_CRYPTO, LOG_ERR, "Can't generate random data: %s",
	    ERR_error_string(ERR_get_error(), NULL));
    rc = EIO;
    goto fail;
  }
  log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	  "Encrypting challenge (%d) : %02hhx%02hhx%02hhx%02hhx...",
	  CHALLENGE_LEN, handle->challenge[0], handle->challenge[1],
	  handle->challenge[2], handle->challenge[3]);

  /* First init selects the cipher so the key length can be set; the
   * second init loads the key at that length. */
  EVP_CIPHER_CTX_init(&c);
  EVP_EncryptInit(&c, CIPHER, key->data, NULL);
  EVP_CIPHER_CTX_set_key_length(&c, key->len);
  EVP_EncryptInit(&c, CIPHER, key->data, NULL);

  if(!EVP_EncryptUpdate(&c, handle->buf, &handle->buflen, handle->challenge,
			CHALLENGE_LEN) ||
     !EVP_EncryptFinal(&c, &handle->buf[handle->buflen], &padlen)) {
    log_log(LOG_MOD_CRYPTO, LOG_ERR, "Unable to encrypt challenge: %s",
	    ERR_error_string(ERR_get_error(), NULL));
    EVP_CIPHER_CTX_cleanup(&c);
    rc = EIO;
    goto fail;
  }
  handle->buflen += padlen;
  EVP_CIPHER_CTX_cleanup(&c);
  assert(handle->buflen < BUFLEN);

  log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	  "Encrypted challenge (%d, %d) : %02hhx%02hhx%02hhx%02hhx...",
	  lsessionid, handle->buflen, handle->buf[0], handle->buf[1],
	  handle->buf[2], handle->buf[3]);

  /* Send challenge */
  handle->msghdr.type = SESSION_RESUME_CHALLENGE;
  handle->msghdr.len = handle->buflen;

  return 0;

 fail:
  free(handle->challenge);
  handle->challenge = NULL;
  memset(&handle->msghdr, 0, sizeof(handle->msghdr));
  return rc;
}

/*
 * Resume server: check the client's SESSION_RESUME_RESPONSE.
 *
 * The response in handle->buf must be exactly the CHALLENGE_LEN bytes
 * previously stored in handle->challenge.  On a match, returns the handle
 * of the session being resumed (find_session(handle->session.id)).
 * Otherwise returns NULL.
 *
 * NOTE(review): memcmp() is not constant-time; consider CRYPTO_memcmp()
 * if the OpenSSL versions in use provide it.
 */
session_handle *
session_crypto_finalize(session_handle *handle)
{
  session_handle *newhandle;

  /* The response length comes off the wire: reject it explicitly instead
   * of asserting, so memcmp() below can never read past the
   * CHALLENGE_LEN-byte challenge when assert() is compiled out. */
  if(handle->msghdr.len != CHALLENGE_LEN) {
    log_log(LOG_MOD_CRYPTO, LOG_INFO,
	    "Challenge failed: bad response length (%d)",
	    (int)handle->msghdr.len);
    return NULL;
  }
  assert(handle->challenge);

  if(memcmp(handle->challenge, handle->buf, CHALLENGE_LEN)) {
    log_log(LOG_MOD_CRYPTO, LOG_INFO,
	    "Challenge failed.");
    return NULL;
  }

  /* Find session */
  newhandle = find_session(handle->session.id);
  assert(newhandle);

  return newhandle;
}


/*
 * read_peer(handle, bufptr, len): consume the session header at the front
 * of a SESSION_EST message.
 *
 * Reads, in order:
 *   - the peer's session id (sizeof(int) bytes, network order), stored in
 *     handle->session.pid;
 *   - the peer's buffer size (sizeof(size_t) bytes, converted with
 *     ntohl()), stored in handle->session.pbufsize;
 *   - M_MAXNAMESIZE bytes of peer name, copied into handle->session.pname
 *     only if that is still empty.
 * Advances bufptr and decrements len by the number of bytes consumed.
 *
 * Expands to a single statement (safe in an unbraced if/else).  The caller
 * must declare `int sessionid' and `size_t bufsize' as scratch variables.
 * No length checking is done: the caller must validate len afterwards (it
 * goes negative on a short message).
 *
 * NOTE(review): the buffer size travels as a size_t, so the wire format
 * depends on sizeof(size_t) on each host -- confirm peers agree.
 */
#define read_peer(handle, bufptr, len)                               \
  do {                                                               \
    memcpy(&sessionid, (bufptr), sizeof(sessionid));                 \
    (bufptr) += sizeof(sessionid);                                   \
    (len) -= sizeof(sessionid);                                      \
    (handle)->session.pid = ntohl(sessionid);                        \
    memcpy(&bufsize, (bufptr), sizeof(bufsize));                     \
    (bufptr) += sizeof(bufsize);                                     \
    (len) -= sizeof(bufsize);                                        \
    (handle)->session.pbufsize = ntohl(bufsize);                     \
    if(!(handle)->session.pname[0])                                  \
      memcpy(&(handle)->session.pname, (bufptr), M_MAXNAMESIZE);     \
    (bufptr) += M_MAXNAMESIZE;                                       \
    (len) -= M_MAXNAMESIZE;                                          \
  } while (0)

/*
 * send_peer(handle, bufptr, len): start a SESSION_EST message in
 * handle->buf.
 *
 * Sets handle->msghdr.type to SESSION_EST and writes, in order:
 *   - our session id (handle->session.id, htonl, sizeof(int) bytes);
 *   - our buffer size (handle->lbufsize, htonl, sizeof(size_t) bytes);
 *   - M_MAXNAMESIZE bytes of handle->session.dname.
 * Leaves bufptr just past the header and sets handle->buflen to the header
 * size; the caller appends the public key and fixes up msghdr.len.
 *
 * Expands to a single statement (safe in an unbraced if/else).  The caller
 * must declare `int sessionid' and `size_t bufsize' as scratch variables.
 * The len argument is accepted for symmetry with read_peer but unused.
 */
#define send_peer(handle, bufptr, len)                               \
  do {                                                               \
    (handle)->msghdr.type = SESSION_EST;                             \
    (bufptr) = (handle)->buf;                                        \
    sessionid = htonl((handle)->session.id);                         \
    memcpy((bufptr), &sessionid, sizeof(sessionid));                 \
    (bufptr) += sizeof(sessionid);                                   \
    bufsize = htonl((handle)->lbufsize);                             \
    memcpy((bufptr), &bufsize, sizeof(bufsize));                     \
    (bufptr) += sizeof(bufsize);                                     \
    memcpy((bufptr), (handle)->session.dname, M_MAXNAMESIZE);        \
    (bufptr) += M_MAXNAMESIZE;                                       \
    (handle)->buflen = sizeof(sessionid) + sizeof(bufsize)           \
                       + M_MAXNAMESIZE;                              \
  } while (0)


/*
 * Key-exchange initiator: generate a fresh DH key pair and build the
 * SESSION_EST message.
 *
 * The private key is saved in a newly allocated struct key_t in
 * handle->key (for finish_handshake()).  handle->buf / handle->msghdr
 * receive the session header followed by the public key.  Both keys are
 * cleared from shared_parameters afterwards.
 *
 * Returns 0 on success.  On failure it returns the OpenSSL error code if
 * DH_generate_key() fails, or ENOMEM if the key cannot be allocated.
 */
int
session_crypto_init(session_handle *handle)
{
  int  sessionid, len;
  size_t bufsize;
  char *bufptr = handle->buf;
  struct key_t *key;

  /* Generate new public/private key pair */
  assert(handle->key == NULL);
  shared_parameters->priv_key = NULL;
  if(!DH_generate_key(shared_parameters))
    return ERR_get_error();

  /* Save private key.  The allocation must not live inside assert():
   * under NDEBUG it would be compiled out and BN_bn2bin() would write
   * through a NULL pointer. */
  len = BN_num_bytes(shared_parameters->priv_key);
  assert(len <= MAX_KEYSIZE);
  if((key = malloc(sizeof(*key))) == NULL) {
    BN_clear_free(shared_parameters->priv_key);
    shared_parameters->priv_key = NULL;
    BN_clear_free(shared_parameters->pub_key);
    shared_parameters->pub_key = NULL;
    return ENOMEM;
  }
  BN_bn2bin(shared_parameters->priv_key, (unsigned char *)key->data);
  key->len = len;
  handle->key = key;
  BN_clear_free(shared_parameters->priv_key);
  shared_parameters->priv_key = NULL;

  /* Send local session info */
  send_peer(handle, bufptr, len);

  /* Send public key */
  assert(BN_num_bytes(shared_parameters->pub_key) <= MAX_KEYSIZE);
  handle->buflen += BN_num_bytes(shared_parameters->pub_key);
  BN_bn2bin(shared_parameters->pub_key, bufptr);
  BN_clear_free(shared_parameters->pub_key);
  shared_parameters->pub_key = NULL;
  handle->msghdr.len = handle->buflen;

  return 0;
}


/*
 * Key-exchange initiator: process the peer's SESSION_EST reply.
 *
 * Parses the peer's session header (read_peer).  The remainder of the
 * message is the peer's DH public key.  It is combined with the private
 * key that session_crypto_init() saved in handle->key to compute the
 * shared secret.  handle->key itself is only read, not modified or freed.
 *
 * Returns a newly malloc()ed struct key_t holding the shared secret, or
 * NULL on failure (the reason is logged).  shared_parameters is always
 * left with priv_key == NULL.
 */
void *
finish_handshake(session_handle *handle)
{
  struct key_t *key = NULL;
  BIGNUM       *pubkey = NULL;
  BIGNUM       *privkey = NULL;
  char          buf[MAX_KEYSIZE];
  const char   *errstr;
  int           sessionid, len = handle->msghdr.len;
  char         *bufptr = handle->buf;
  size_t        bufsize;

  /* Store peer's session info */
  read_peer(handle, bufptr, len);

  /* Read in peer's public key: whatever follows the session header.  The
   * length is peer-controlled (and negative for a short message), so
   * validate it rather than assert. */
  if(len <= 0 || len > MAX_KEYSIZE) {
    errstr = "Invalid peer public key length";
    goto err;
  }
  if((pubkey = BN_bin2bn(bufptr, len, NULL))==NULL) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }

  /* Read in private key */
  assert(((struct key_t *)handle->key)->len <= MAX_KEYSIZE);
  if(!(privkey = BN_bin2bn(handle->key, ((struct key_t *)handle->key)->len,
			   NULL))) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }
  assert(shared_parameters->pub_key == NULL);
  assert(shared_parameters->priv_key == NULL);
  shared_parameters->priv_key = privkey;

  /* Compute shared key */
  if((len = DH_compute_key(buf, pubkey, shared_parameters)) == -1) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }
  assert(len <= MAX_KEYSIZE);
  BN_clear_free(pubkey);
  BN_clear_free(privkey);
  privkey = NULL;
  pubkey = NULL;
  shared_parameters->priv_key = NULL;

  /* Store shared key */
  if(!(key = malloc(sizeof(struct key_t)))) {
    errstr = strerror(errno);
    goto err;
  }
  memcpy(key->data, buf, len);
  key->len = len;

  log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	  "negotiated %d-bit session key (%d:%d) %02hhx%02hhx%02hhx%02hhx...",
	  len*8, handle->session.id, handle->session.pid,
	  buf[0], buf[1], buf[2], buf[3]);

  return key;

 err:

  /* privkey may already be installed in the shared DH object; don't leave
   * it pointing at memory that is about to be freed. */
  if(privkey && shared_parameters->priv_key == privkey)
    shared_parameters->priv_key = NULL;
  if(privkey) BN_clear_free(privkey);
  if(pubkey) BN_clear_free(pubkey);
  log_log(LOG_MOD_CRYPTO, LOG_ERR,
	  "DH Finish Handshake: %s", errstr);
  return NULL;
}


/*
 * Key-exchange responder: receive the initiator's public key, generate our
 * own key pair, respond with our public key and compute the shared secret.
 *
 * Parses the initiator's session header (read_peer); the rest of the
 * message is its DH public key.  Our SESSION_EST reply (session header
 * followed by our public key) is left in handle->buf / handle->msghdr.
 *
 * Returns a newly malloc()ed struct key_t holding the shared secret, or
 * NULL on failure (the reason is logged).  On every path shared_parameters
 * is left with no key pair, so the next exchange starts clean.
 */

void *
session_handshake(session_handle *handle)
{
  char          buf[MAX_KEYSIZE];
  char         *bufptr = handle->buf;
  struct key_t *key = NULL;
  BIGNUM       *pkey = NULL;
  const char   *errstr;
  int           sessionid, len = handle->msghdr.len;
  size_t        bufsize;

  /* Store peer's session info */
  read_peer(handle, bufptr, len);

  /* Copy in peer's key: whatever follows the session header.  The length
   * is peer-controlled (negative for a short message), so validate it. */
  if(len <= 0 || len > MAX_KEYSIZE) {
    errstr = "Invalid peer public key length";
    goto err;
  }
  if((pkey = BN_bin2bn(bufptr, len, NULL))==NULL) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }

  /* Generate our key pair */
  assert(shared_parameters->priv_key == NULL);
  assert(shared_parameters->pub_key == NULL);
  if(!DH_generate_key(shared_parameters)) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }

  /* Send local session info */
  send_peer(handle, bufptr, len);

  /* Send public key */
  assert(BN_num_bytes(shared_parameters->pub_key) <= MAX_KEYSIZE);
  handle->buflen += BN_num_bytes(shared_parameters->pub_key);
  BN_bn2bin(shared_parameters->pub_key, bufptr);
  BN_clear_free(shared_parameters->pub_key);
  shared_parameters->pub_key = NULL;
  handle->msghdr.len = handle->buflen;

  /* Compute the shared secret, then discard our private key */
  if((len = DH_compute_key(buf, pkey, shared_parameters)) == -1) {
    errstr = ERR_error_string(ERR_get_error(), NULL);
    goto err;
  }
  assert(len <= MAX_KEYSIZE);
  BN_clear_free(shared_parameters->priv_key);
  shared_parameters->priv_key = NULL;

  if(!(key = malloc(sizeof(struct key_t)))) {
    errstr = strerror(errno);
    goto err;
  }
  key->len = len;
  memcpy(key->data, buf, key->len);

  log_log(LOG_MOD_CRYPTO, LOG_DEBUG,
	  "negotiated %d-bit session key (%d:%d) %02hhx%02hhx%02hhx%02hhx...",
	  key->len*8, handle->session.id, handle->session.pid,
	  buf[0], buf[1], buf[2], buf[3]);

  BN_clear_free(pkey);
  return key;

 err:

  if (pkey) BN_clear_free(pkey);
  /* Drop any key pair DH_generate_key() left in the shared DH object.
   * Otherwise the next exchange trips the NULL asserts above or, under
   * NDEBUG, DH_generate_key() silently reuses this private key. */
  if(shared_parameters->priv_key) {
    BN_clear_free(shared_parameters->priv_key);
    shared_parameters->priv_key = NULL;
  }
  if(shared_parameters->pub_key) {
    BN_clear_free(shared_parameters->pub_key);
    shared_parameters->pub_key = NULL;
  }
  log_log(LOG_MOD_CRYPTO, LOG_ERR, "DH Handshake: %s", errstr);
  return NULL;
}
