package ins.api;

import ins.inr.*;
import ins.namespace.*;

import java.io.*;
import java.util.*;
import java.security.*;
import java.net.DatagramPacket;
import java.math.BigInteger;

import xjava.security.*;

import cryptix.provider.mac.HMAC_SHA1;
import cryptix.util.core.Hex;
import cryptix.util.test.BaseTest;

import security.crypto.*;
import security.srp.Util;

/**
 * An {@link Application} that authenticates traffic between clients and
 * servers using per-application RSA "group" keys and per-client HMAC-SHA-1
 * session keys.
 *
 * <p>Each instance reads its group key pair from a key file and advertises the
 * public half in the {@code group} attribute of outgoing source
 * name-specifiers, as the key's {@code getPublicKey()} and {@code getModulo()}
 * values joined by {@code ':'}.
 *
 * <p>Protocol, as implemented below:
 * <ol>
 * <li>The first packet a client sends to a server group triggers a handshake:
 *     the client picks a random {@code clientID}, tags its source name with
 *     that ID and its group key, keeps a copy of the packet, and puts an RSA
 *     signature of its source name in {@code Packet.option}.</li>
 * <li>The server checks every incoming packet's {@code op}/{@code group} pair
 *     against its ACL. For an unknown {@code clientID} it verifies the
 *     signature and replies with a fresh 128-bit session key (base64,
 *     RSA-encrypted under the client's group key) in {@code data}, plus its
 *     own signature of its source name in {@code option}.</li>
 * <li>The client verifies that reply, stores the key and resends the saved
 *     packet. From then on both sides put an HMAC of
 *     {@code Packet.toMACBytes()} in {@code option} and drop packets whose
 *     HMAC does not verify.</li>
 * </ol>
 *
 * <p>Subclasses override {@link #respondToPacket(Packet)} to handle packets
 * that pass these checks.
 */
public class SecureApp extends Application {

    /** Payload a server sends back when the ACL rejects an operation. */
    public static final String unAuthStr = "UNAUTHORIZED OPERATION";

    protected static Attribute opAttr = new Attribute("op");
    protected static Attribute groupAttr = new Attribute("group");
    protected static Attribute clientIDAttr = new Attribute("clientID");

    /** HMAC engine; null if the "HMAC-SHA-1" provider was unavailable. */
    private MessageDigest hmac;
    /** Access-control list; only loaded when constructed as a server. */
    private ACLlist aclList;

    // idToKey:       clientID -> session key (byte[]), used by both roles.
    // idToGroup:     clientID -> peer group-key string. The server records the
    //                client's group key; the client records the server group
    //                key it is handshaking with, so replies can be matched.
    // groupToID:     (client) server group-key string -> clientID we chose.
    // groupToPacket: (client) server group-key string -> copy of the packet
    //                that triggered the handshake, resent once the key arrives.
    private Hashtable idToKey, idToGroup, groupToID, groupToPacket;
    private boolean isServer = false;

    private KeyPair grpKP;
    private String grpPKStr; // "<public key>:<modulus>" of our group key

    /** Source of session keys and client IDs; must be unpredictable. */
    private Random rand;

    /**
     * Shared constructor logic: allocates tables, loads the group key pair,
     * sets up the HMAC engine and, for servers, loads the ACL.
     */
    private void init(String aclFile, boolean server, String keyFile)
        throws Exception {
        this.isServer = server;

        idToKey = new Hashtable();
        idToGroup = new Hashtable();
        groupToID = new Hashtable();
        // Allocated regardless of role: setServer() can switch roles after
        // construction and the client paths dereference this table.
        groupToPacket = new Hashtable();

        // The key pair read here is this application's group key pair.
        grpKP = RSAKeyPairGenerator.readKeys(keyFile);
        RSAPublicKey pub = (RSAPublicKey) grpKP.getPublic();
        grpPKStr = pub.getPublicKey() + ":" + pub.getModulo();

        try {
            hmac = MessageDigest.getInstance("HMAC-SHA-1");
        } catch (NoSuchAlgorithmException e) {
            // hmac stays null; keyHmac() then refuses every MAC operation.
            System.err.println("ERROR: Algorithm not found!");
        }
        // java.util.Random is predictable; session keys need a CSPRNG.
        rand = new java.security.SecureRandom();

        if (server) {
            aclList = new ACLlist(aclFile);
        }
    }

    public SecureApp(String aclFile, boolean isServer, String keyFile)
        throws Exception {
        super();
        init(aclFile, isServer, keyFile);
    }

    public SecureApp(String aclFile, boolean isServer, String keyFile,
                     String dsr_name) throws Exception {
        super(dsr_name);
        init(aclFile, isServer, keyFile);
    }

    public SecureApp(String aclFile, boolean isServer, String keyFile,
                     String inrname, int inrport) throws Exception {
        super(inrname, inrport);
        init(aclFile, isServer, keyFile);
    }

    // ************************************************************

    public void setServer(boolean b) {
        isServer = b;
    }

    public boolean isServer() {
        return isServer;
    }

    /**
     * Application hook, invoked for packets that passed authentication. On
     * the client it is also invoked for {@link #unAuthStr} replies.
     * Subclasses must override this; the default only logs.
     */
    public void respondToPacket(Packet p) {
        System.out.println("Doing nothing with packet");
    }

    /** Returns {@code bits / 8} random bytes for use as a session key. */
    private byte[] generateSessionKey(int bits) {
        byte[] key = new byte[bits / 8];
        rand.nextBytes(key);
        return key;
    }

    /**
     * Replies to p's sender with the {@link #unAuthStr} payload.
     * NOTE(review): this uses the inherited three-argument sendMessage. If
     * that delegates to the overridden sendMessage(Packet), a server without
     * a session key for the client runs the key handshake and replaces this
     * payload. Confirm against Application.
     */
    private void sendUnAuth(Packet p) {
        sendMessage(p.dNS, p.sNS, unAuthStr.getBytes());
    }

    /** Returns the string value of attr in ns, or null if it is absent. */
    private static String attrValue(NameSpecifier ns, Attribute attr) {
        AVelement av = ns.getAVelement(attr);
        return (av == null) ? null : av.getValue().toString();
    }

    /**
     * Parses a "{@code <public key>:<modulus>}" group-key string.
     *
     * @return the key, or null if the string is malformed (it arrives from
     *         the network, so malformed input must not throw)
     */
    private static RSAPublicKey parseGroupKey(String groupKeyStr) {
        StringTokenizer st = new StringTokenizer(groupKeyStr, ":");
        try {
            return new RSAPublicKey(new BigInteger(st.nextToken()),
                                    new BigInteger(st.nextToken()), "RSA");
        } catch (java.util.NoSuchElementException e) {
            return null;
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Checks that sigBytes is signer's RSA signature over expected.
     *
     * @return false if the signature or key is missing or does not match
     */
    private static boolean verifySignature(byte[] sigBytes,
                                           RSAPublicKey signer,
                                           String expected) {
        if (sigBytes == null || signer == null) {
            return false;
        }
        BigInteger[] sig = RSACrypto.byteArrToBigInt(sigBytes);
        return expected.equals(RSACrypto.verify(sig, signer));
    }

    /** Adds our clientID and group key to p.sNS if they are not present. */
    private void tagSource(Packet p, String clientID) {
        if (p.sNS.getAVelement(clientIDAttr) == null) {
            p.sNS.addAVelement(new AVelement(clientIDAttr,
                                             new Value(clientID)));
        }
        if (p.sNS.getAVelement(groupAttr) == null) {
            p.sNS.addAVelement(new AVelement(groupAttr,
                                             new Value(grpPKStr)));
        }
    }

    /**
     * Resets the HMAC engine and keys it with sessionKey.
     *
     * @return true only if the engine is ready to digest under this key, so
     *         callers never MAC with a stale or missing key
     */
    private boolean keyHmac(byte[] sessionKey) {
        if (hmac == null || sessionKey == null) {
            System.err.println("ERROR: HMAC unavailable or no session key!");
            return false;
        }
        try {
            hmac.reset();
            ((Parameterized) hmac).setParameter("key", sessionKey);
            return true;
        } catch (NoSuchParameterException e) {
            System.err.println("ERROR: Cannot set session key for HMAC!");
            e.printStackTrace();
        } catch (InvalidParameterTypeException e) {
            System.err.println("ERROR: Error in session key type for HMAC!");
            e.printStackTrace();
        }
        return false;
    }

    /**
     * Recomputes the HMAC of p.toMACBytes() under sessionKey and compares
     * it with the MAC carried in p.option.
     */
    private boolean verifyWithSessionKey(Packet p, byte[] sessionKey) {
        byte[] received = p.option;
        if (received == null || !keyHmac(sessionKey)) {
            return false;
        }
        byte[] macInput = p.toMACBytes();
        System.err.print("Verifying ===> ");
        printBytes(macInput);
        byte[] expected = hmac.digest(macInput);
        // No early-exit byte loop (constant-time on current JDKs).
        return MessageDigest.isEqual(received, expected);
    }

    /**
     * Stores the HMAC of p.toMACBytes() under sessionKey in p.option.
     *
     * @return false if the HMAC engine could not be keyed
     */
    private boolean insertMACintoPacket(Packet p, byte[] sessionKey) {
        if (!keyHmac(sessionKey)) {
            return false;
        }
        byte[] macInput = p.toMACBytes();
        System.err.print("Digesting ===> ");
        printBytes(macInput);
        p.option = hmac.digest(macInput);
        return true;
    }

    // **************************************************************

    /** Dispatches an incoming packet to the server or client handler. */
    public void receivePacket(Packet p) {
        if (isServer()) {
            receiveServPacket(p);
        } else {
            receiveCliPacket(p);
        }
    }

    /**
     * Server side. Checks the ACL, then either verifies the packet's HMAC or,
     * for an unknown clientID, verifies the handshake signature and answers
     * with a session key. Packets that fail verification are dropped.
     */
    private void receiveServPacket(Packet p) {
        System.err.println("Server received packet from " + p.sNS.toString());

        String op = attrValue(p.sNS, opAttr);
        String cliGroupKeyStr = attrValue(p.sNS, groupAttr);
        String clientID = attrValue(p.sNS, clientIDAttr);
        if (op == null || cliGroupKeyStr == null || clientID == null) {
            System.err.println("ERROR: Packet lacks op, group or clientID "
                               + "attribute; dropping it.");
            return;
        }

        // Cheap check first: can the claimed group perform this op at all?
        if (!aclList.getPerms(op, cliGroupKeyStr)) {
            sendUnAuth(p);
            return;
        }
        System.err.println("Operation APPROVED!");

        byte[] sessionKey = (byte[]) idToKey.get(clientID);
        if (sessionKey != null) {
            // Handshake already done: the packet must carry a valid HMAC.
            if (verifyWithSessionKey(p, sessionKey)) {
                System.err.println("Packet VERIFIED.");
                respondToPacket(p);
            } else {
                System.err.println("WARNING: Verification on packet failed!");
            }
            return;
        }

        System.err.println("Don't know about ClientID: " + clientID +
                           " => performing handshake");

        // The client signed its source name with its group private key.
        if (!verifySignature(p.option, parseGroupKey(cliGroupKeyStr),
                             p.sNS.toString())) {
            System.err.println("WARNING: Verify of client failed!");
            return;
        }
        System.err.println("Client indentity VERIFIED!");

        // Swap source and destination; sendMessage() fills in the session key.
        sendMessage(new Packet(p.dNS, p.sNS, null));
    }

    /**
     * Client side. Handles UNAUTHORIZED replies, HMAC-verified packets, and
     * the server's handshake response carrying the session key. Packets that
     * fail verification are dropped.
     */
    public void receiveCliPacket(Packet p) {
        System.err.println("client receiving from " + p.sNS.toString());

        // Checked first because such a reply may not carry clientID/group.
        if (Arrays.equals(p.data, unAuthStr.getBytes())) {
            System.err.println("OPERATION UNAUTHORIZED!");
            respondToPacket(p);
            return;
        }

        String clientID = attrValue(p.sNS, clientIDAttr);
        String servGroupKeyStr = attrValue(p.sNS, groupAttr);
        if (clientID == null || servGroupKeyStr == null) {
            System.err.println("ERROR: Packet lacks clientID or group "
                               + "attribute; dropping it.");
            return;
        }

        byte[] sessionKey = (byte[]) idToKey.get(clientID);
        if (sessionKey != null) {
            if (verifyWithSessionKey(p, sessionKey)) {
                System.err.println("Packet VERIFIED.");
                respondToPacket(p);
            } else {
                System.err.println("WARNING: Verification on packet failed!");
            }
            return;
        }

        System.err.println("Handshake RESPONSE (expecting session key)");

        if (!groupToID.containsKey(servGroupKeyStr)) {
            System.err.println("ERROR: Response to unknown handshake.");
            return;
        }
        // The server must be the one we opened this clientID's handshake with.
        if (!servGroupKeyStr.equals(idToGroup.get(clientID))) {
            System.err.println("We may be getting spoofed!!!!");
            return;
        }
        // Without this check anyone could hand us a key of their choosing,
        // since our group public key is public.
        if (!verifySignature(p.option, parseGroupKey(servGroupKeyStr),
                             p.sNS.toString())) {
            System.err.println("WARNING: Server response not verified!");
            return;
        }
        System.err.println("Server response VERIFIED!");

        if (p.data == null) {
            System.err.println("ERROR: Handshake response carries no key.");
            return;
        }
        String keyString =
            RSACrypto.decrypt(RSACrypto.byteArrToBigInt(p.data),
                              (RSAPrivateKey) grpKP.getPrivate());
        byte[] newKey = Util.fromb64(keyString);
        if (newKey == null) {
            System.err.println("ERROR: Could not decode session key.");
            return;
        }
        idToKey.put(clientID, newKey);

        // Now send the packet that originally triggered the handshake.
        Packet pending = (Packet) groupToPacket.remove(servGroupKeyStr);
        if (pending != null) {
            sendMessage(pending);
        }
    }

    /**
     * Secures p (HMAC, or handshake material on first contact) and sends it.
     *
     * @return true if the packet was secured and handed to the network.
     *         A client packet with no server group attribute is sent
     *         unsecured but still reports false.
     */
    public boolean sendMessage(Packet p) {
        boolean ready = isServer() ? prepareServerPacket(p)
                                   : prepareClientPacket(p);
        return ready && oldSendMessage(p);
    }

    /**
     * Client half of sendMessage: MACs p under the session key for its
     * server group, or turns it into a signed handshake packet.
     *
     * @return true if p should now be sent
     */
    private boolean prepareClientPacket(Packet p) {
        System.err.println("Client is sending message from " +
                           p.sNS.toString() +
                           " to " +
                           p.dNS.toString());

        String servGroupKeyStr = attrValue(p.dNS, groupAttr);
        if (servGroupKeyStr == null) {
            // No server group key: fall back to an unsecured send.
            System.err.print("ERROR: No group attribute " +
                             "in name specifier!");
            System.err.println(" Ignoring secureApp options.");
            oldSendMessage(p);
            return false;
        }

        String clientID = (String) groupToID.get(servGroupKeyStr);
        if (clientID != null) {
            // A handshake with this server exists; MAC with its session key.
            long start = System.currentTimeMillis();
            byte[] sessionKey = (byte[]) idToKey.get(clientID);
            if (sessionKey == null) {
                System.err.println("** Handshake performed, but " +
                                   "no session key received yet!");
                return false;
            }
            tagSource(p, clientID);
            if (!insertMACintoPacket(p, sessionKey)) {
                return false;
            }
            long delta = System.currentTimeMillis() - start;
            System.err.println("Digest HMAC in " +
                               delta + " milliseconds.");
            return true;
        }

        // HANDSHAKE
        System.err.println("New Connection => Handshaking");

        // Reject a malformed server key before recording handshake state.
        if (parseGroupKey(servGroupKeyStr) == null) {
            System.err.println("ERROR: Malformed server group key: " +
                               servGroupKeyStr);
            return false;
        }

        clientID = Long.toString(rand.nextLong());
        groupToID.put(servGroupKeyStr, clientID);
        idToGroup.put(clientID, servGroupKeyStr);

        // The server needs our clientID and group key to answer.
        tagSource(p, clientID);

        // Save the packet before option is overwritten; it is resent
        // once the session key arrives.
        groupToPacket.put(servGroupKeyStr, p.clone());

        // Sign our source name with our group private key.
        // XXX the payload is not yet encrypted with the server's group key.
        BigInteger[] signed =
            RSACrypto.sign(p.sNS.toString(), (RSAPrivateKey) grpKP.getPrivate());
        p.option = RSACrypto.bigIntToByteArr(signed);
        return true;
    }

    /**
     * Server half of sendMessage: MACs p under the client's session key, or
     * on first contact generates a session key and attaches it encrypted
     * (data) and signed (option).
     *
     * @return true if p should now be sent
     */
    private boolean prepareServerPacket(Packet p) {
        System.err.println("Sever is sending reponse...");

        String clientID = attrValue(p.dNS, clientIDAttr);
        if (clientID == null) {
            System.err.println("ERROR: No clientID attribute " +
                               "in name specifier for client!");
            return false;
        }

        byte[] sessionKey = (byte[]) idToKey.get(clientID);
        if (sessionKey != null) {
            tagSource(p, clientID);
            return insertMACintoPacket(p, sessionKey);
        }

        // HANDSHAKE: first time sending back to this client.
        System.err.println("Handshake (sending session key) => " +
                           p.dNS.toString());

        String cliGroupKeyStr = attrValue(p.dNS, groupAttr);
        if (cliGroupKeyStr == null) {
            System.err.println("ERROR: No group attribute " +
                               "in name specifier for client!");
            return false;
        }
        RSAPublicKey cliPK = parseGroupKey(cliGroupKeyStr);
        if (cliPK == null) {
            System.err.println("ERROR: Malformed client group key: " +
                               cliGroupKeyStr);
            return false;
        }

        byte[] key = generateSessionKey(128);

        // Record the key now so a second packet does not start another
        // handshake.
        idToKey.put(clientID, key);
        idToGroup.put(clientID, cliGroupKeyStr);

        // Our clientID/group attributes are part of what gets signed.
        tagSource(p, clientID);

        BigInteger[] sig = RSACrypto.sign(p.sNS.toString(),
                                          (RSAPrivateKey) grpKP.getPrivate());
        BigInteger[] keyData = RSACrypto.encrypt(Util.tob64(key), cliPK);
        p.option = RSACrypto.bigIntToByteArr(sig);
        p.data = RSACrypto.bigIntToByteArr(keyData);
        return true;
    }

    /**
     * Sends p unchanged to the resolver for its destination vspace. This is
     * the same logic as Application's sendMessage(Packet), kept so the
     * secured path can reach the network without recursing.
     *
     * @return false if no resolver is known or the datagram could not be sent
     */
    public boolean oldSendMessage(Packet p) {
        NameSpecifier dNS = p.dNS;

        // Find the vspace this packet is addressed to.
        String vspace = dNS.getVspace();
        if (vspace == null) {
            vspace = defaultVspace;
        }

        InetAddressPort addr;
        if (singlePeering) {
            if (singlePeeringAddress == null) {
                singlePeeringAddress =
                    this.inrmanager.getResolverForVspace(vspace, this);
                if (singlePeeringAddress == null) {
                    return false;
                }
            }
            addr = singlePeeringAddress;
        } else {
            addr = this.inrmanager.getResolverForVspace(vspace, this);
            if (addr == null) {
                return false;
            }
        }

        byte[] bytestosend = p.toBytes();
        DatagramPacket dpacket =
            new DatagramPacket(bytestosend, bytestosend.length,
                               addr.addr, addr.port);

        try {
            sendDatagramPacket(dpacket);
        } catch (IOException e) {
            // getMessage() may be null, so compare from the literal side.
            if ("Connection refused".equals(e.getMessage())) {
                printStatus("Connection refused from " + addr.addr +
                            " on port " + addr.port);
            } else {
                printStatus("Error sending message: " + e.getMessage());
            }
            return false;
        }
        return true;
    }

    /** Debug dump: prints the length, then each byte as a char, to stderr. */
    private void printBytes(byte[] key) {
        if (key == null) {
            return;
        }
        System.err.println("LENGTH: " + key.length);
        for (int i = 0; i < key.length; i++) {
            System.err.print((char) key[i]);
        }
        System.err.println();
    }
}





