summary history branches tags files
commit:17872a54c1f7ba55b1bed8d1df5a2a6bc5766a4f
author:Trevor Bentley
committer:Trevor Bentley
date:Thu Dec 13 20:11:44 2018 +0100
parents:a966b7a288ad0139ac11a2cce34f1fb8e7159127
starting to implement API
diff --git a/src/lib.rs b/src/lib.rs
line changes: +126/-233
index af974cf..0badb3f
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,11 +5,15 @@ extern crate chacha20_poly1305_aead;
 use chacha20_poly1305_aead::{encrypt,decrypt};
 use x25519_dalek::generate_secret;
 use x25519_dalek::generate_public;
+use x25519_dalek::diffie_hellman;
+
 use rand::thread_rng;
+use rand::RngCore;
 
-use std::collections::VecDeque;
 use std::thread;
-//use std::io::prelude::*;
+use std::rc::Rc;
+use std::cell::RefCell;
+use std::collections::VecDeque;
 use std::net::{TcpListener, TcpStream};
 
 //
@@ -48,8 +52,6 @@ use std::net::{TcpListener, TcpStream};
 //   * data (encrypted)
 //
 
-use x25519_dalek::diffie_hellman;
-
 fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
     unsafe {
         ::std::slice::from_raw_parts(
@@ -58,13 +60,6 @@ fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
         )
     }
 }
-//unsafe fn struct_as_mut_slice<T: Sized>(p: &mut T) -> &[u8] {
-//    ::std::slice::from_raw_parts_mut(
-//        (p as *mut T) as *mut u8,
-//        ::std::mem::size_of::<T>(),
-//    )
-//}
-//unsafe fn slice_as_struct<T: Sized>(p: &[u8]) -> Result<&T, &'static str> {
 fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
     unsafe {
         if p.len() < ::std::mem::size_of::<T>() {
@@ -73,29 +68,6 @@ fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
         Ok(&*(&p[..::std::mem::size_of::<T>()] as *const [u8] as *const T))
     }
 }
-//unsafe fn slice_as_mut_struct<T: Sized>(p: &mut [u8]) -> Result<&mut T, &'static str> {
-//    if p.len() < ::std::mem::size_of::<T>() {
-//        return Err("Cannot cast bytes to struct: size mismatch");
-//    }
-//    Ok(&mut *(&mut p[..::std::mem::size_of::<T>()] as *mut [u8] as *mut T))
-//}
-//unsafe fn slice_as_owned_struct<T: Sized>(mut p: Box<[u8]>) -> Result<T, &'static str> {
-//    if p.len() < ::std::mem::size_of::<T>() {
-//        return Err("Cannot cast bytes to struct: size mismatch");
-//    }
-//    println!("box size: {} / T size: {}", ::std::mem::size_of_val(&*p), ::std::mem::size_of::<T>());
-//    println!("p: {}", p[0]);
-//    let p = Box::into_raw(p);
-//    Ok(*Box::from_raw(p as *mut T))
-//}
-//unsafe fn vec_as_struct<T: Sized>(p: &Vec<u8>) -> Result<&T, &'static str> {
-//    if p.len() < ::std::mem::size_of::<T>() {
-//        return Err("Cannot cast bytes to struct: buffer too small");
-//    }
-//    let s = &p.as_slice()[..::std::mem::size_of::<T>()];
-//    Ok(&*(s as *const [u8] as *const T))
-//}
-
 #[repr(packed)]
 #[allow(dead_code)]
 struct HandshakePacket {
@@ -141,52 +113,55 @@ struct PacketHeader {
     _reserved: u16,
 }
 
-//#[repr(packed)]
-//struct Packet<'a> {
-//    header: PacketHeader,
-//    data: &'a [u8],
-//}
-//impl <'a> Packet<'a> {
-//    fn new(msg_id: u16, kind: PacketType, data: &'a [u8]) -> Packet {
-//        Packet {
-//            header: PacketHeader {
-//                len: data.len() as u16,
-//                msg_id: msg_id,
-//                packet_type: kind,
-//                _reserved: 0u16,
-//            },
-//            data: data
-//        }
-//    }
-//}
-
-
-//enum ServerConnectionState {
-//    New,
-//    Encrypted,
-//    Authenticated,
-//}
-//
-//struct ServerConnectionContext {
-//    state: ServerConnectionState,
-//    remote_public: [u8; 32],
-//    nonce: [u8; 12],
-//}
-
-//pub fn read_struct<'a,T>(stream: &mut TcpStream, mut buf: &'a mut [u8]) -> Result<&'a T, &'static str> {
-//    let sz = ::std::mem::size_of::<T>();
-//    let buf_len = stream.read_exact(&mut buf[..sz]).unwrap();
-//    let pkt: &T = unsafe { slice_as_struct(buf)? };
-//    Ok(pkt)
-//}
-//
-//pub fn read_owned_struct<'a,T>(stream: &mut TcpStream) -> Result<T, &'static str> {
-//    let sz = ::std::mem::size_of::<T>();
-//    let mut buf: Box<[u8]> = vec![0; sz].into_boxed_slice();
-//    let buf_len = stream.read_exact(&mut buf[..sz]).unwrap();
-//    let pkt: T = unsafe { slice_as_owned_struct(buf)? };
-//    Ok(pkt)
-//}
+enum ConnectionState {
+    New,
+    Encrypted,
+    _Authenticated,
+}
+struct KeyMaterial {
+    secret: Option<[u8; 32]>,
+    public: [u8; 32],
+    session: Option<[u8; 32]>,
+    nonce: [u8; 12],
+}
+struct ConnectionContext {
+    stream: Rc<RefCell<TcpStream>>,
+    state: ConnectionState,
+    local_key: KeyMaterial,
+    remote_key: Option<KeyMaterial>,
+}
+impl ConnectionContext {
+    fn new(stream: TcpStream) -> ConnectionContext {
+        let mut rng = thread_rng();
+        let sec_key = generate_secret(&mut rng);
+        let pub_key = generate_public(&sec_key);
+        let mut nonce: [u8; 12] = [0; 12];
+        rng.fill_bytes(&mut nonce);
+        let key = KeyMaterial {
+            secret: Some(sec_key),
+            public: pub_key.to_bytes(),
+            nonce: nonce,
+            session: None,
+        };
+        ConnectionContext {
+            stream: Rc::new(RefCell::new(stream)),
+            state: ConnectionState::New,
+            local_key: key,
+            remote_key: None,
+        }
+    }
+    fn add_remote_key(&mut self, public: &[u8; 32], nonce: &[u8; 12]) {
+        let key = KeyMaterial {
+            secret: None,
+            public: public.to_owned(),
+            nonce: nonce.to_owned(),
+            session: None,
+        };
+        self.remote_key = Some(key);
+        self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
+        self.state = ConnectionState::Encrypted;
+    }
+}
 
 struct NetworkPacket {
     header: PacketHeader,
@@ -208,8 +183,11 @@ fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]
     Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
 }
 
-fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
-where T: std::io::Read {
+//fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
+//where T: std::io::Read {
+fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, std::io::Error>
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Read {
     let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
     println!("read header");
     let _ = stream.read_exact(&mut buf)?;
@@ -235,8 +213,11 @@ where T: std::io::Read {
     })
 }
 
-fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
-where T: std::io::Write {
+//fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+//where T: std::io::Write {
+fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Write {
     let buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
     unsafe {
         std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[0] = u16::to_be(data.len() as u16);
@@ -255,41 +236,70 @@ where T: std::io::Write {
     Ok(())
 }
 
-pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
-    let mut seed = thread_rng();
-    let sec_key = generate_secret(&mut seed);
-    let pub_key = generate_public(&sec_key);
-    println!("conn: {} {}", stream.peer_addr().unwrap(), is_server);
+//fn crypto_send_handshake<T>(conn: &ConnectionContext, buf: &mut T)
+//where T: std::io::Write {
+fn crypto_send_handshake<T,U>(conn: &ConnectionContext, buf: T)
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Write {
+    match conn.state {
+        ConnectionState::New => {
+            let mut pkt: HandshakePacket = Default::default();
+            pkt.public_key.copy_from_slice(&conn.local_key.public);
+            pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+            let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+        },
+        _ => {}
+    }
+}
+
+fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Read {
+    match conn.state {
+        ConnectionState::New => {},
+        _ => { return false; }
+    }
+    // TODO: read_exact won't work.
+    if let Ok(pkt) = read_packet(buf) {
+        println!("Packet type: {}", pkt.kind() as u16);
+        match pkt.kind() {
+            _ => {},
+        }
+    }
+    true
+}
 
-    //let _ = stream.write(pub_key.as_bytes());
+//fn crypto_send_data(conn: &ConnectionContext, buf: &[u8]) {
+//}
+//fn crypto_recv_data(conn: &ConnectionContext, buf: &mut [u8]) {
+//}
 
-    let mut remote_public: [u8; 32] = [0; 32];
-    let mut nonce: [u8; 12] = [0; 12];
-    let mut sess_key: [u8; 32] = [0; 32];
+//pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
+fn event_loop(mut conn: ConnectionContext, is_server: bool) -> Result<(), std::io::Error> {
+    println!("conn: {} {}", conn.stream.borrow().peer_addr().unwrap(), is_server);
 
-    let mut pkt: HandshakePacket = Default::default();
-    pkt.public_key.copy_from_slice(pub_key.as_bytes());
-    pkt.nonce.copy_from_slice(&[1,0,0,0,0,1,0,0,0,0,1,0]);
-    let _ = write_packet(stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+    crypto_send_handshake(&conn, conn.stream.borrow_mut());
+    //let mut pkt: HandshakePacket = Default::default();
+    //pkt.public_key.copy_from_slice(&conn.local_key.public);
+    //pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+    //let _ = write_packet(&mut conn.stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
 
     let mut pkt_queue: VecDeque<NetworkPacket> = VecDeque::with_capacity(20);
     loop {
-        pkt_queue.push_back(read_packet(stream)?);
+        pkt_queue.push_back(read_packet(conn.stream.borrow_mut())?);
         if let Some(pkt) = pkt_queue.pop_front() {
             println!("Packet type: {}", pkt.kind() as u16);
             match pkt.kind() {
                 PacketType::PublicKeyNonce => {
-                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).unwrap();
-                    remote_public.copy_from_slice(&data_pkt.public_key);
-                    nonce.copy_from_slice(&data_pkt.nonce);
-                    sess_key = diffie_hellman(&sec_key, &remote_public);
-                    println!("Server key: {:?}", sess_key);
+                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                    println!("Server key: {:?}", conn.local_key.session.as_ref().unwrap());
 
                     if is_server {
                         let aad = [1, 2, 3, 4];
                         let plaintext = b"hello, world";
                         let mut ciphertext = Vec::with_capacity(plaintext.len());
-                        let tag = encrypt(&sess_key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
+                        let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, plaintext, &mut ciphertext).unwrap();
                         println!("encrypted: {:?} {:?}", ciphertext, tag);
 
                         let pkt: EncryptedPacket = EncryptedPacket {
@@ -300,7 +310,7 @@ pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io
                         buf.extend(struct_as_slice(&pkt));
                         buf.extend(&ciphertext);
                         buf.extend(&tag);
-                        let _ = write_packet(stream, &buf, 0, PacketType::EncryptedData);
+                        let _ = write_packet(conn.stream.borrow_mut(), &buf, 0, PacketType::EncryptedData);
                     }
                 },
                 PacketType::AuthRequest => {},
@@ -311,7 +321,7 @@ pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io
                     let aad = [1, 2, 3, 4];
                     let mut plaintext = Vec::with_capacity(ciphertext.len());
                     println!("decrypting: {:?} {:?}", ciphertext, tag);
-                    let _ = decrypt(&sess_key, &nonce, &aad, &ciphertext, &tag, &mut plaintext);
+                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
                     println!("decrypted: {:?}", String::from_utf8(plaintext));
                 },
                 PacketType::Disconnect => {},
@@ -321,144 +331,27 @@ pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io
 }
 
 
-pub fn server2() -> Result<(), std::io::Error> {
+pub fn server() -> Result<(), std::io::Error> {
     let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
     for stream in listener.incoming() {
-        let mut stream: TcpStream = stream.unwrap();
-        let _ = event_loop(&mut stream, true);
+        let stream: TcpStream = stream.unwrap();
+        let conn = ConnectionContext::new(stream);
+        let _ = event_loop(conn, true);
     }
     Ok(())
 }
 
-pub fn client2() -> Result<(), std::io::Error> {
-    let mut stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-    let _ = event_loop(&mut stream, false);
+pub fn client() -> Result<(), std::io::Error> {
+    let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
+    let conn = ConnectionContext::new(stream);
+    let _ = event_loop(conn, false);
     Ok(())
 }
 
-//pub fn server() {
-//    let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
-//    for stream in listener.incoming() {
-//        let mut stream: TcpStream = stream.unwrap();
-//        stream.set_nonblocking(false);
-//
-//        let mut seed = thread_rng();
-//        let sec_key = generate_secret(&mut seed);
-//        let pub_key = generate_public(&sec_key);
-//        println!("Server conn: {}", stream.peer_addr().unwrap());
-//        let _ = stream.write(pub_key.as_bytes());
-//
-//        let mut buf: Box<[u8]> = Box::new([0u8; 65535]);
-//        {
-//            let hdr: &PacketHeader = read_struct(&mut stream, &mut buf).unwrap();
-//            println!("hdr len: {}", hdr.len);
-//        }
-//        let pkt: &HandshakePacket = read_struct(&mut stream, &mut buf).unwrap();
-//        println!("pkt len: {}", pkt.len);
-//        println!("Server write");
-//
-//        let mut remote_public: [u8; 32] = [0; 32];
-//        let mut nonce: [u8; 12] = [0; 12];
-//        remote_public.copy_from_slice(&pkt.public_key);
-//        nonce.copy_from_slice(&pkt.nonce);
-//
-//        let sess_key = diffie_hellman(&sec_key, &remote_public);
-//        println!("Server key: {:?}", sess_key);
-//
-//        let aad = [1, 2, 3, 4];
-//        let plaintext = b"hello, world";
-//        let mut ciphertext = Vec::with_capacity(plaintext.len());
-//        let tag = encrypt(&sess_key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
-//        println!("encrypted: {:?} {:?}", ciphertext, tag);
-//
-//        unsafe {
-//            let pkt: EncryptedPacket = EncryptedPacket {
-//                tag_len: tag.len() as u16,
-//                data_len: ciphertext.len() as u16,
-//            };
-//            let pkt = Packet::new(0, PacketType::EncryptedData, struct_as_slice(&pkt));
-//            //let _ = stream.write(struct_as_slice(&pkt));
-//            let _ = stream.write(struct_as_slice(&pkt.header));
-//            let _ = stream.write(pkt.data);
-//            let _ = stream.write(&ciphertext);
-//            let _ = stream.write(&tag);
-//        }
-//    }
-//}
-//
-//pub fn client() {
-//    let mut stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-//    stream.set_nonblocking(false);
-//    let mut seed = thread_rng();
-//    let sec_key = generate_secret(&mut seed);
-//    let pub_key = generate_public(&sec_key);
-//
-//    let mut pkt: HandshakePacket = Default::default();
-//    pkt.public_key.copy_from_slice(pub_key.as_bytes());
-//    pkt.nonce.copy_from_slice(&[1,0,0,0,0,1,0,0,0,0,1,0]);
-//
-//    unsafe {
-//        let pkt = Packet::new(0, PacketType::PublicKeyNonce, struct_as_slice(&pkt));
-//        //let _ = stream.write(struct_as_slice(&pkt));
-//        let _ = stream.write(struct_as_slice(&pkt.header));
-//        let _ = stream.write(pkt.data);
-//    }
-//    //let _ = stream.write(pub_key.as_bytes());
-//    println!("Client write");
-//    let mut buf = vec![0u8; 4096];
-//    let buf_len = stream.read(&mut buf).unwrap();
-//    println!("Client read: {}", buf_len);
-//    let mut remote_public: [u8; 32] = [0; 32];
-//    remote_public.copy_from_slice(&buf[..32]);
-//    let sess_key = diffie_hellman(&sec_key, &remote_public);
-//    println!("Client key: {:?}", sess_key);
-//
-//    let hdr_len = {
-//        let hdr: &PacketHeader = read_struct(&mut stream, &mut buf).unwrap();
-//        println!("hdr len: {}", hdr.len);
-//        hdr.len
-//    } as usize;
-//    let (tag_len, data_len) = {
-//        let pkt: &EncryptedPacket = read_struct(&mut stream, &mut buf).unwrap();
-//        (pkt.tag_len as usize, pkt.data_len as usize)
-//    };
-//    let mut tag: Vec<u8> = vec![0; tag_len];
-//    let mut ciphertext: Vec<u8> = vec![0; data_len];
-//    let _ = stream.read_exact(&mut ciphertext).unwrap();
-//    let _ = stream.read_exact(&mut tag).unwrap();
-//    let aad = [1, 2, 3, 4];
-//    let mut plaintext = Vec::with_capacity(ciphertext.len());
-//    println!("decrypting: {:?} {:?}", ciphertext, tag);
-//    decrypt(&sess_key, &pkt.nonce, &aad, &ciphertext, &tag, &mut plaintext);
-//    println!("decrypted: {:?}", String::from_utf8(plaintext));
-//
-//    println!("Client done");
-//}
-
 pub fn test() {
-    thread::spawn(move || { let _ = server2(); });
-    let child = thread::spawn(move || { let _ = client2(); });
+    thread::spawn(move || { let _ = server(); });
+    let child = thread::spawn(move || { let _ = client(); });
     let _ = child.join();
-    //let mut alice_csprng = thread_rng();
-    //let     alice_secret = generate_secret(&mut alice_csprng);
-    //let     alice_public = generate_public(&alice_secret);
-    //let mut bob_csprng = thread_rng();
-    //let     bob_secret = generate_secret(&mut bob_csprng);
-    //let     bob_public = generate_public(&bob_secret);
-    //use x25519_dalek::diffie_hellman;
-    //let shared_secret_a = diffie_hellman(&alice_secret, &bob_public.as_bytes());
-    //let shared_secret_b = diffie_hellman(&bob_secret, &alice_public.as_bytes());
-    //let key = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
-    //           17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31];
-    //let nonce = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
-    //let aad = [1, 2, 3, 4];
-    //let plaintext = b"hello, world";
-    //let mut ciphertext = Vec::with_capacity(plaintext.len());
-    //let tag = encrypt(&key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
-    //println!("encrypted: {:?}", ciphertext);
-    //let mut plaintext = Vec::with_capacity(ciphertext.len());
-    //decrypt(&key, &nonce, &aad, &ciphertext, &tag, &mut plaintext);
-    //println!("decrypted: {:?}", String::from_utf8(plaintext));
 }
 
 #[cfg(test)]