summary history branches tags files
commit:fab25d47f39213b74f7af01d9f6e5760e16fac6a
author:Trevor Bentley
committer:Trevor Bentley
date:Thu Dec 13 21:37:48 2018 +0100
parents:17872a54c1f7ba55b1bed8d1df5a2a6bc5766a4f
Implemented 4-function API
diff --git a/src/lib.rs b/src/lib.rs
line changes: +127/-101
index 0badb3f..9adf0c7
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,5 @@
+#![feature(try_from)]
+
 extern crate x25519_dalek;
 extern crate rand;
 extern crate chacha20_poly1305_aead;
@@ -10,10 +12,9 @@ use x25519_dalek::diffie_hellman;
 use rand::thread_rng;
 use rand::RngCore;
 
+use std::convert::TryInto;
+
 use std::thread;
-use std::rc::Rc;
-use std::cell::RefCell;
-use std::collections::VecDeque;
 use std::net::{TcpListener, TcpStream};
 
 //
@@ -91,10 +92,22 @@ impl Default for HandshakePacket {
 #[derive(Clone, Copy)]
 #[allow(dead_code)]
 enum PacketType {
-    PublicKeyNonce,
-    AuthRequest,
-    EncryptedData,
-    Disconnect,
+    Unknown = 0,
+    PublicKeyNonce = 0x01,
+    AuthRequest = 0x02,
+    EncryptedData = 0x03,
+    Disconnect = 0x04,
+}
+impl PacketType {
+    pub fn from_u16(i: u16) -> PacketType {
+        match i {
+            0x01 => PacketType::PublicKeyNonce,
+            0x02 => PacketType::AuthRequest,
+            0x03 => PacketType::EncryptedData,
+            0x04 => PacketType::Disconnect,
+            _ => PacketType::Unknown,
+        }
+    }
 }
 
 #[repr(packed)]
@@ -115,6 +128,7 @@ struct PacketHeader {
 
 enum ConnectionState {
     New,
+    PubKeySent,
     Encrypted,
     _Authenticated,
 }
@@ -125,13 +139,12 @@ struct KeyMaterial {
     nonce: [u8; 12],
 }
 struct ConnectionContext {
-    stream: Rc<RefCell<TcpStream>>,
     state: ConnectionState,
     local_key: KeyMaterial,
     remote_key: Option<KeyMaterial>,
 }
 impl ConnectionContext {
-    fn new(stream: TcpStream) -> ConnectionContext {
+    fn new() -> ConnectionContext {
         let mut rng = thread_rng();
         let sec_key = generate_secret(&mut rng);
         let pub_key = generate_public(&sec_key);
@@ -144,7 +157,6 @@ impl ConnectionContext {
             session: None,
         };
         ConnectionContext {
-            stream: Rc::new(RefCell::new(stream)),
             state: ConnectionState::New,
             local_key: key,
             remote_key: None,
@@ -183,27 +195,31 @@ fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]
     Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
 }
 
-//fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
-//where T: std::io::Read {
-fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, std::io::Error>
+pub enum OssuaryError {
+    Io(std::io::Error),
+    Unpack(core::array::TryFromSliceError),
+}
+impl From<std::io::Error> for OssuaryError {
+    fn from(error: std::io::Error) -> Self {
+        OssuaryError::Io(error)
+    }
+}
+impl From<core::array::TryFromSliceError> for OssuaryError {
+    fn from(error: core::array::TryFromSliceError) -> Self {
+        OssuaryError::Unpack(error)
+    }
+}
+
+fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
     let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
-    println!("read header");
     let _ = stream.read_exact(&mut buf)?;
     let hdr = PacketHeader {
-        len: unsafe {
-            u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[0])
-        },
-        msg_id: unsafe {
-            u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[1])
-        },
-        packet_type: unsafe {
-            std::mem::transmute(u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[2]))
-        },
-        _reserved: unsafe {
-            u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[3])
-        },
+        len: u16::from_be_bytes(buf[0..2].try_into()?),
+        msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
+        packet_type: PacketType::from_u16(u16::from_be_bytes(buf[4..6].try_into()?)),
+        _reserved: u16::from_be_bytes(buf[6..8].try_into()?),
     };
     let mut buf: Box<[u8]> = vec![0u8; hdr.len as usize].into_boxed_slice();
     let _ = stream.read_exact(&mut buf)?;
@@ -218,27 +234,17 @@ where T: std::ops::DerefMut<Target = U>,
 fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
-    let buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
-    unsafe {
-        std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[0] = u16::to_be(data.len() as u16);
-    }
-    unsafe {
-        std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[1] = u16::to_be(msg_id);
-    }
-    unsafe {
-        std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[2] = u16::to_be(kind as u16);
-    }
-    unsafe {
-        std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[3] = u16::to_be(0u16);
-    }
+    let mut buf: Vec<u8> = Vec::with_capacity(::std::mem::size_of::<PacketHeader>());
+    buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
+    buf.extend_from_slice(&(msg_id as u16).to_be_bytes());
+    buf.extend_from_slice(&(kind as u16).to_be_bytes());
+    buf.extend_from_slice(&(0u16).to_be_bytes());
     stream.write(&buf)?;
     stream.write(data)?;
     Ok(())
 }
 
-//fn crypto_send_handshake<T>(conn: &ConnectionContext, buf: &mut T)
-//where T: std::io::Write {
-fn crypto_send_handshake<T,U>(conn: &ConnectionContext, buf: T)
+fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
     match conn.state {
@@ -247,8 +253,12 @@ where T: std::ops::DerefMut<Target = U>,
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
             let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+            conn.state = ConnectionState::PubKeySent;
+            true
         },
-        _ => {}
+        _ => {
+            false
+        }
     }
 }
 
@@ -256,77 +266,93 @@ fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
     match conn.state {
-        ConnectionState::New => {},
+        ConnectionState::New => { return true; },
+        ConnectionState::PubKeySent => {},
         _ => { return false; }
     }
     // TODO: read_exact won't work.
     if let Ok(pkt) = read_packet(buf) {
         println!("Packet type: {}", pkt.kind() as u16);
         match pkt.kind() {
+            PacketType::PublicKeyNonce => {
+                let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+                conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                println!("Session key: {:?}", conn.local_key.session.as_ref().unwrap());
+                conn.state = ConnectionState::Encrypted;
+            },
             _ => {},
         }
     }
     true
 }
 
-//fn crypto_send_data(conn: &ConnectionContext, buf: &[u8]) {
-//}
-//fn crypto_recv_data(conn: &ConnectionContext, buf: &mut [u8]) {
-//}
+fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], out_buf: T) -> Result<(), &'static str>
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Write {
+    match conn.state {
+        ConnectionState::Encrypted => {},
+        _ => { return Err("Encrypted channel not established."); }
+    }
+    let aad = [];
+    let mut ciphertext = Vec::with_capacity(in_buf.len());
+    let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
+    println!("encrypted: {:?} {:?}", ciphertext, tag);
 
-//pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
-fn event_loop(mut conn: ConnectionContext, is_server: bool) -> Result<(), std::io::Error> {
-    println!("conn: {} {}", conn.stream.borrow().peer_addr().unwrap(), is_server);
+    let pkt: EncryptedPacket = EncryptedPacket {
+        tag_len: tag.len() as u16,
+        data_len: ciphertext.len() as u16,
+    };
+    let mut buf: Vec<u8>= vec![];
+    buf.extend(struct_as_slice(&pkt));
+    buf.extend(&ciphertext);
+    buf.extend(&tag);
+    let _ = write_packet(out_buf, &buf, 0, PacketType::EncryptedData);
+    Ok(())
+}
 
-    crypto_send_handshake(&conn, conn.stream.borrow_mut());
-    //let mut pkt: HandshakePacket = Default::default();
-    //pkt.public_key.copy_from_slice(&conn.local_key.public);
-    //pkt.nonce.copy_from_slice(&conn.local_key.nonce);
-    //let _ = write_packet(&mut conn.stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(), &'static str>
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Read,
+      R: std::ops::DerefMut<Target = V>,
+      V: std::io::Write {
+    match conn.state {
+        ConnectionState::Encrypted => {},
+        _ => { return Err("Encrypted channel not established."); }
+    }
+    if let Ok(pkt) = read_packet(in_buf) {
+        println!("Packet type: {}", pkt.kind() as u16);
+        match pkt.kind() {
+            PacketType::EncryptedData => {
+                let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+                let ciphertext = &rest[..data_pkt.data_len as usize];
+                let tag = &rest[data_pkt.data_len as usize..];
+                let aad = [];
+                let mut plaintext = Vec::with_capacity(ciphertext.len());
+                let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
+                let _ = out_buf.write(&plaintext);
+            },
+            _ => {
+                return Err("Received non-encrypted data on encrypted channel.");
+            },
+        }
+    }
+    Ok(())
+}
 
-    let mut pkt_queue: VecDeque<NetworkPacket> = VecDeque::with_capacity(20);
-    loop {
-        pkt_queue.push_back(read_packet(conn.stream.borrow_mut())?);
-        if let Some(pkt) = pkt_queue.pop_front() {
-            println!("Packet type: {}", pkt.kind() as u16);
-            match pkt.kind() {
-                PacketType::PublicKeyNonce => {
-                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
-                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
-                    println!("Server key: {:?}", conn.local_key.session.as_ref().unwrap());
+fn event_loop<T>(mut conn: ConnectionContext, mut stream: T, is_server: bool) -> Result<(), std::io::Error>
+where T: std::io::Read + std::io::Write {
+    while crypto_send_handshake(&mut conn, &mut stream) == true {}
+    while crypto_recv_handshake(&mut conn, &mut stream) == true {}
 
-                    if is_server {
-                        let aad = [1, 2, 3, 4];
-                        let plaintext = b"hello, world";
-                        let mut ciphertext = Vec::with_capacity(plaintext.len());
-                        let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, plaintext, &mut ciphertext).unwrap();
-                        println!("encrypted: {:?} {:?}", ciphertext, tag);
+    if is_server {
+        let mut plaintext = "hello, world".as_bytes();
+        let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+    }
 
-                        let pkt: EncryptedPacket = EncryptedPacket {
-                            tag_len: tag.len() as u16,
-                            data_len: ciphertext.len() as u16,
-                        };
-                        let mut buf: Vec<u8>= vec![];
-                        buf.extend(struct_as_slice(&pkt));
-                        buf.extend(&ciphertext);
-                        buf.extend(&tag);
-                        let _ = write_packet(conn.stream.borrow_mut(), &buf, 0, PacketType::EncryptedData);
-                    }
-                },
-                PacketType::AuthRequest => {},
-                PacketType::EncryptedData => {
-                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
-                    let ciphertext = &rest[..data_pkt.data_len as usize];
-                    let tag = &rest[data_pkt.data_len as usize..];
-                    let aad = [1, 2, 3, 4];
-                    let mut plaintext = Vec::with_capacity(ciphertext.len());
-                    println!("decrypting: {:?} {:?}", ciphertext, tag);
-                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
-                    println!("decrypted: {:?}", String::from_utf8(plaintext));
-                },
-                PacketType::Disconnect => {},
-            }
-        }
+    loop {
+        let mut plaintext = vec!();
+        let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
+        println!("decrypted: {:?}", String::from_utf8(plaintext));
     }
 }
 
@@ -335,16 +361,16 @@ pub fn server() -> Result<(), std::io::Error> {
     let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
     for stream in listener.incoming() {
         let stream: TcpStream = stream.unwrap();
-        let conn = ConnectionContext::new(stream);
-        let _ = event_loop(conn, true);
+        let conn = ConnectionContext::new();
+        let _ = event_loop(conn, stream, true);
     }
     Ok(())
 }
 
 pub fn client() -> Result<(), std::io::Error> {
     let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-    let conn = ConnectionContext::new(stream);
-    let _ = event_loop(conn, false);
+    let conn = ConnectionContext::new();
+    let _ = event_loop(conn, stream, false);
     Ok(())
 }