summary history branches tags files
commit:dfaa9714f0c54c240a7c9e10a285a6dceface342
author:Trevor Bentley
committer:Trevor Bentley
date:Thu Jan 17 18:43:31 2019 +0100
parents:02a7478e626c8d381e4ccecf8b2848220c5e336d
Support non-blocking I/O, expand error handling
diff --git a/src/clib.rs b/src/clib.rs
line changes: +14/-9
index d7624c8..d8ed02c
--- a/src/clib.rs
+++ b/src/clib.rs
@@ -73,10 +73,15 @@ pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext,
     let inlen = unsafe { *in_buf_len as usize };
     let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, inlen) };
     let mut slice = r_in_buf;
-    crypto_recv_handshake(&mut conn, &mut slice);
+    let written = match crypto_recv_handshake(&mut conn, &mut slice) {
+        Ok(read) => {
+            read as u16
+        },
+        _ => {
+            0u16
+        }
+    };
     ::std::mem::forget(conn);
-    let written = (inlen - slice.len()) as u16;
-    unsafe { *in_buf_len = written };
     written as i32 // TODO
 }
 
@@ -131,7 +136,6 @@ pub extern "C" fn ossuary_send_data(conn: *mut ConnectionContext,
         Err(_) => { return -1; },
     }
     ::std::mem::forget(conn);
-    //(out_buf_len - out_slice.len() as u16) as i32
     bytes_written as i32
 }
 
@@ -164,7 +168,7 @@ pub extern "C" fn ossuary_recv_data(conn: *mut ConnectionContext,
 #[cfg(test)]
 mod tests {
     use std::thread;
-    use std::io::{Read,Write};
+    use std::io::{Write};
     use std::net::{TcpListener, TcpStream};
     use std::io::BufRead;
     use crate::clib::*;
@@ -244,17 +248,18 @@ mod tests {
         ossuary_set_secret_key(conn, key as *const u8);
 
         let out_buf: [u8; 512] = [0; 512];
-        let mut in_buf: [u8; 512] = [0; 512];
 
+        let mut reader = std::io::BufReader::new(stream.try_clone().unwrap());
         while ossuary_handshake_done(conn) == 0 {
             let mut out_len = out_buf.len() as u16;
             let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
             let _ = stream.write_all(&out_buf[0.. out_len as usize]).unwrap();
 
             if more != 0 {
-                let _ = stream.read(&mut in_buf);
+                let in_buf = reader.fill_buf().unwrap();
                 let mut in_len = in_buf.len() as u16;
-                ossuary_recv_handshake(conn, (&in_buf) as *const u8, &mut in_len);
+                let len = ossuary_recv_handshake(conn, in_buf as *const [u8] as *const u8, &mut in_len);
+                reader.consume(len as usize);
             }
         }
 
@@ -295,7 +300,7 @@ mod tests {
     }
 
     #[test]
-    fn test() {
+    fn test_clib() {
         let server = thread::spawn(move || { let _ = server(); });
         std::thread::sleep(std::time::Duration::from_millis(500));
         let child = thread::spawn(move || { let _ = client(); });

diff --git a/src/lib.rs b/src/lib.rs
line changes: +97/-37
index 8452917..c3d4b3b
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,6 +29,12 @@ const MAX_HANDSHAKE_WAIT_TIME: u64 = 3u64;
 // Size of the random data to be signed by client
 const CHALLENGE_LEN: usize = 256;
 
+// Internal buffer for copy of network data
+const PACKET_BUF_SIZE: usize = 16384
+    + ::std::mem::size_of::<PacketHeader>()
+    + ::std::mem::size_of::<EncryptedPacket>()
+    + 16; // chacha20 tag
+
 //
 // API:
 //  * sock -- TCP data socket
@@ -90,6 +96,7 @@ fn slice_as_struct<T>(p: &[u8]) -> Result<&T, OssuaryError> {
 
 pub enum OssuaryError {
     Io(std::io::Error),
+    WouldBlock(usize), // bytes consumed
     Unpack(core::array::TryFromSliceError),
     KeySize(usize, usize), // (expected, actual)
     InvalidKey,
@@ -216,6 +223,7 @@ pub enum ConnectionType {
     AuthenticatedServer,
     UnauthenticatedServer,
 }
+
 pub struct ConnectionContext {
     state: ConnectionState,
     conn_type: ConnectionType,
@@ -228,6 +236,8 @@ pub struct ConnectionContext {
     authorized_keys: Vec<[u8; 32]>,
     secret_key: Option<SecretKey>,
     public_key: Option<PublicKey>,
+    packet_buf: [u8; PACKET_BUF_SIZE],
+    packet_buf_used: usize,
 }
 impl ConnectionContext {
     pub fn new(conn_type: ConnectionType) -> ConnectionContext {
@@ -258,6 +268,8 @@ impl ConnectionContext {
             authorized_keys: vec!(),
             secret_key: None,
             public_key: None,
+            packet_buf: [0u8; PACKET_BUF_SIZE],
+            packet_buf_used: 0,
         }
     }
     fn reset_state(&mut self, permanent_err: Option<OssuaryError>) {
@@ -334,28 +346,51 @@ fn interpret_packet<'a, T>(pkt: &'a NetworkPacket) -> Result<&'a T, OssuaryError
     Ok(s)
 }
 
-fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]), OssuaryError> {
+fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &'a [u8]), OssuaryError> {
     let s: &T = slice_as_struct(&pkt.data)?;
     Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
 }
 
-fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
+fn read_packet<T,U>(conn: &mut ConnectionContext, mut stream: T) ->Result<(NetworkPacket, usize), OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
-    let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
-    let _ = stream.read_exact(&mut buf)?;
+    let header_size = ::std::mem::size_of::<PacketHeader>();
+    let bytes_read: usize;
+    match stream.read(&mut conn.packet_buf[conn.packet_buf_used..]) {
+        Ok(b) => bytes_read = b,
+        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+            return Err(OssuaryError::WouldBlock(0))
+        },
+        Err(e) => return Err(e.into()),
+    }
+    conn.packet_buf_used += bytes_read;
+    let buf: &[u8] = &conn.packet_buf;
     let hdr = PacketHeader {
         len: u16::from_be_bytes(buf[0..2].try_into()?),
         msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
         packet_type: PacketType::from_u16(u16::from_be_bytes(buf[4..6].try_into()?)),
         _reserved: u16::from_be_bytes(buf[6..8].try_into()?),
     };
-    let mut buf: Box<[u8]> = vec![0u8; hdr.len as usize].into_boxed_slice();
-    let _ = stream.read_exact(&mut buf)?;
-    Ok(NetworkPacket {
+    let packet_len = hdr.len as usize;
+    if conn.packet_buf_used < header_size + packet_len {
+        return Err(OssuaryError::WouldBlock(bytes_read));
+    }
+    let buf: Box<[u8]> = (&conn.packet_buf[header_size..header_size+packet_len])
+        .to_vec().into_boxed_slice();
+    let excess = conn.packet_buf_used - header_size - packet_len;
+    unsafe {
+        // no safe way to memmove() in Rust?
+        std::ptr::copy::<u8>(
+            conn.packet_buf.as_ptr().offset((header_size + packet_len) as isize),
+            conn.packet_buf.as_mut_ptr(),
+            excess);
+    }
+    conn.packet_buf_used = excess;
+    Ok((NetworkPacket {
         header: hdr,
         data: buf,
-    })
+    },
+    header_size + packet_len))
 }
 
 fn write_packet<T,U>(stream: &mut T, data: &[u8], msg_id: &mut u16, kind: PacketType) -> Result<(), std::io::Error>
@@ -562,14 +597,22 @@ where T: std::ops::DerefMut<Target = U>,
     more
 }
 
-pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T)
+// TODO u16 should be usize
+pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> Result<usize, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
     // TODO: read_exact won't work.
-    let pkt: NetworkPacket = match read_packet(buf) {
-        Ok(p) => p,
-        Err(_) => {
-            return;
+    let mut bytes_read: usize = 0;
+    let pkt: NetworkPacket = match read_packet(conn, buf) {
+        Ok((p, r)) => {
+            bytes_read += r;
+            p
+        },
+        Err(OssuaryError::WouldBlock(b)) => {
+            return Err(OssuaryError::WouldBlock(b));
+        }
+        Err(e) => {
+            return Err(e);
         }
     };
 
@@ -577,7 +620,7 @@ where T: std::ops::DerefMut<Target = U>,
         println!("Message gap detected.  Restarting connection.");
         println!("Server: {}", conn.is_server());
         conn.reset_state(None);
-        return; // TODO: return error
+        return Err(OssuaryError::InvalidPacket("Message ID does not match".into()));
     }
     conn.remote_msg_id = pkt.header.msg_id + 1;
 
@@ -585,7 +628,7 @@ where T: std::ops::DerefMut<Target = U>,
     match pkt.kind() {
         PacketType::Reset => {
             conn.reset_state(None);
-            return;
+            return Err(OssuaryError::ConnectionFailed);
         },
         PacketType::Disconnect => {
             // TODO: handle error
@@ -634,14 +677,14 @@ where T: std::ops::DerefMut<Target = U>,
                                 Some(ref k) => k,
                                 None => {
                                     conn.reset_state(None);
-                                    return;
+                                    return Err(OssuaryError::InvalidKey);
                                 }
                             };
                             let remote_nonce = match conn.remote_key {
                                 Some(ref rem) => rem.nonce,
                                 None => {
                                     conn.reset_state(None);
-                                    return;
+                                    return Err(OssuaryError::InvalidKey);
                                 }
                             };
                             let _ = decrypt(session_key,
@@ -654,21 +697,21 @@ where T: std::ops::DerefMut<Target = U>,
                                     Ok(p) => p,
                                     Err(_) => {
                                         conn.reset_state(None);
-                                        return;
+                                        return Err(OssuaryError::InvalidKey);
                                     }
                                 };
                                 let sig = match Signature::from_bytes(sig) {
                                     Ok(s) => s,
                                     Err(_) => {
                                         conn.reset_state(None);
-                                        return;
+                                        return Err(OssuaryError::InvalidKey);
                                     }
                                 };
                                 let challenge = match conn.challenge {
                                     Some(ref c) => c,
                                     None => {
                                         conn.reset_state(None);
-                                        return;
+                                        return Err(OssuaryError::InvalidKey);
                                     }
                                 };
                                 match public.verify::<Sha512>(challenge, &sig) {
@@ -689,7 +732,7 @@ where T: std::ops::DerefMut<Target = U>,
                         },
                         Err(_) => {
                             conn.reset_state(None);
-                            return;
+                            return Err(OssuaryError::InvalidPacket("Response invalid".into()));
                         },
                     };
                 },
@@ -744,14 +787,14 @@ where T: std::ops::DerefMut<Target = U>,
                                 Some(ref k) => k,
                                 None => {
                                     conn.reset_state(None);
-                                    return;
+                                    return Err(OssuaryError::InvalidKey);
                                 }
                             };
                             let remote_nonce = match conn.remote_key {
                                 Some(ref rem) => rem.nonce,
                                 None => {
                                     conn.reset_state(None);
-                                    return;
+                                    return Err(OssuaryError::InvalidKey);
                                 }
                             };
                             let _ = decrypt(session_key,
@@ -783,6 +826,7 @@ where T: std::ops::DerefMut<Target = U>,
     if error {
         conn.reset_state(None);
     }
+    Ok(bytes_read)
 }
 
 // TODO: should return a Result with error on forced-disconnect or permanent failure
@@ -844,7 +888,7 @@ where T: std::ops::DerefMut<Target = U>,
       R: std::ops::DerefMut<Target = V>,
       V: std::io::Write {
     let bytes_written: u16;
-    let bytes_read: u16;
+    let mut bytes_read: u16 = 0u16;
     match conn.state {
         ConnectionState::Encrypted => {},
         _ => {
@@ -852,8 +896,9 @@ where T: std::ops::DerefMut<Target = U>,
         }
     }
 
-    match read_packet(in_buf) {
-        Ok(pkt) => {
+    match read_packet(conn, in_buf) {
+        Ok((pkt, bytes)) => {
+            bytes_read += bytes as u16;
             if pkt.header.msg_id != conn.remote_msg_id {
                 println!("Message gap detected.  Restarting connection.");
                 println!("Server: {}", conn.is_server());
@@ -889,10 +934,6 @@ where T: std::ops::DerefMut<Target = U>,
                                             &aad, &ciphertext, &tag, &mut plaintext);
                             let _ = out_buf.write(&plaintext);
                             bytes_written = ciphertext.len() as u16;
-                            bytes_read = (ciphertext.len() +
-                                          ::std::mem::size_of::<PacketHeader>() +
-                                          ::std::mem::size_of::<EncryptedPacket>() +
-                                          tag.len()) as u16;
                         },
                         Err(_) => {
                             conn.reset_state(None);
@@ -905,6 +946,9 @@ where T: std::ops::DerefMut<Target = U>,
                 },
             }
         },
+        Err(OssuaryError::WouldBlock(b)) => {
+            return Err(OssuaryError::WouldBlock(b));
+        },
         Err(_e) => {
             return Err(OssuaryError::InvalidPacket("Packet header did not parse.".into()));
         },
@@ -960,16 +1004,26 @@ mod tests {
             let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
             while crypto_handshake_done(&server_conn).unwrap() == false {
                 if crypto_send_handshake(&mut server_conn, &mut server_stream) {
-                    crypto_recv_handshake(&mut server_conn, &mut server_stream);
+                    loop {
+                        match crypto_recv_handshake(&mut server_conn, &mut server_stream) {
+                            Ok(_) => break,
+                            Err(OssuaryError::WouldBlock(_)) => {},
+                            _ => panic!("Handshake failed"),
+                        }
+                    }
                 }
             }
             let mut plaintext = vec!();
             let mut bytes: u64 = 0;
             let start = std::time::SystemTime::now();
             loop {
-                bytes += crypto_recv_data(&mut server_conn,
-                                          &mut server_stream,
-                                          &mut plaintext).unwrap().0 as u64;
+                match crypto_recv_data(&mut server_conn,
+                                       &mut server_stream,
+                                       &mut plaintext) {
+                    Ok((read, _written)) => bytes += read as u64,
+                    Err(OssuaryError::WouldBlock(_)) => continue,
+                    _ => panic!("Recv failed"),
+                }
                 if plaintext == [0xde, 0xde, 0xbe, 0xbe] {
                     if let Ok(dur) = start.elapsed() {
                         let t = dur.as_secs() as f64
@@ -985,10 +1039,17 @@ mod tests {
 
         std::thread::sleep(std::time::Duration::from_millis(500));
         let mut client_stream = TcpStream::connect("127.0.0.1:9987").unwrap();
+        client_stream.set_nonblocking(true).unwrap();
         let mut client_conn = ConnectionContext::new(ConnectionType::Client);
         while crypto_handshake_done(&client_conn).unwrap() == false {
             if crypto_send_handshake(&mut client_conn, &mut client_stream) {
-                crypto_recv_handshake(&mut client_conn, &mut client_stream);
+                loop {
+                    match crypto_recv_handshake(&mut client_conn, &mut client_stream) {
+                        Ok(_) => break,
+                        Err(OssuaryError::WouldBlock(_)) => {},
+                        _ => panic!("Handshake failed"),
+                    }
+                }
             }
         }
         let mut client_stream = std::io::BufWriter::new(client_stream);
@@ -1008,8 +1069,7 @@ mod tests {
         }
         let mut plaintext: &[u8] = &[0xde, 0xde, 0xbe, 0xbe];
         let _ = crypto_send_data(&mut client_conn, &mut plaintext, &mut client_stream);
-        // Unwrap the BufWriter, flushing the buffer
-        let _ = client_stream.into_inner().unwrap();
+        drop(client_stream); // flush the buffer
         let _ = server_thread.join();
     }
 }

diff --git a/tests/basic.rs b/tests/basic.rs
line changes: +24/-7
index 839483f..2968818
--- a/tests/basic.rs
+++ b/tests/basic.rs
@@ -1,6 +1,7 @@
 use ossuary::{ConnectionContext, ConnectionType};
 use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
 use ossuary::{crypto_send_data,crypto_recv_data};
+use ossuary::OssuaryError;
 
 use std::thread;
 use std::net::{TcpListener, TcpStream};
@@ -12,21 +13,37 @@ where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
     while crypto_handshake_done(&conn).unwrap() == false {
         if crypto_send_handshake(&mut conn, &mut stream) {
-            crypto_recv_handshake(&mut conn, &mut stream);
+            loop {
+                match crypto_recv_handshake(&mut conn, &mut stream) {
+                    Ok(_) => break,
+                    Err(OssuaryError::WouldBlock(_)) => {},
+                    _ => panic!("Handshake failed."),
+                }
+            }
         }
     }
 
     // Send a message to the other party
-    let mut plaintext = match is_server {
-        true => "message from server".as_bytes(),
-        false => "message from client".as_bytes(),
+    let strings = ("message_from_server", "message_from_client");
+    let (mut plaintext, response) = match is_server {
+        true => (strings.0.as_bytes(), strings.1.as_bytes()),
+        false => (strings.1.as_bytes(), strings.0.as_bytes()),
     };
     let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
 
     // Read a message from the other party
-    let mut plaintext = vec!();
-    let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
-    println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
+    let mut recv_plaintext = vec!();
+    loop {
+        match crypto_recv_data(&mut conn, &mut stream, &mut recv_plaintext) {
+            Ok(_) => {
+                println!("(basic) received: {:?}",
+                         String::from_utf8(recv_plaintext.clone()).unwrap());
+                assert_eq!(recv_plaintext.as_slice(), response);
+                break;
+            },
+            _ => {},
+        }
+    }
 
     Ok(())
 }

diff --git a/tests/basic_auth.rs b/tests/basic_auth.rs
line changes: +25/-9
index 8e7f961..41264d1
--- a/tests/basic_auth.rs
+++ b/tests/basic_auth.rs
@@ -1,6 +1,7 @@
 use ossuary::{ConnectionContext, ConnectionType};
 use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
 use ossuary::{crypto_send_data,crypto_recv_data};
+use ossuary::OssuaryError;
 
 use std::thread;
 use std::net::{TcpListener, TcpStream};
@@ -12,22 +13,37 @@ where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
     while crypto_handshake_done(&conn).unwrap() == false {
         if crypto_send_handshake(&mut conn, &mut stream) {
-            crypto_recv_handshake(&mut conn, &mut stream);
+            loop {
+                match crypto_recv_handshake(&mut conn, &mut stream) {
+                    Ok(_) => break,
+                    Err(OssuaryError::WouldBlock(_)) => {},
+                    _ => panic!("Handshake failed."),
+                }
+            }
         }
     }
 
     // Send a message to the other party
-    let mut plaintext = match is_server {
-        true => "message from server".as_bytes(),
-        false => "message from client".as_bytes(),
+    let strings = ("message_from_server", "message_from_client");
+    let (mut plaintext, response) = match is_server {
+        true => (strings.0.as_bytes(), strings.1.as_bytes()),
+        false => (strings.1.as_bytes(), strings.0.as_bytes()),
     };
     let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
 
     // Read a message from the other party
-    let mut plaintext = vec!();
-    let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
-    println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
-
+    let mut recv_plaintext = vec!();
+    loop {
+        match crypto_recv_data(&mut conn, &mut stream, &mut recv_plaintext) {
+            Ok(_) => {
+                println!("(basic_auth) received: {:?}",
+                         String::from_utf8(recv_plaintext.clone()).unwrap());
+                assert_eq!(recv_plaintext.as_slice(), response);
+                break;
+            },
+            _ => {},
+        }
+    }
     Ok(())
 }
 
@@ -59,7 +75,7 @@ fn client() -> Result<(), std::io::Error> {
 }
 
 #[test]
-fn basic() {
+fn basic_auth() {
     let server = thread::spawn(move || { let _ = server(); });
     std::thread::sleep(std::time::Duration::from_millis(500));
     let child = thread::spawn(move || { let _ = client(); });