summary history branches tags files
commit:0f958c5f77e11c7ec9cfc888d6347df167d8cbfc
author:Trevor Bentley
committer:Trevor Bentley
date:Sun Jan 20 20:04:37 2019 +0100
parents:84309f8def0ce862cc9aae5d37426e9d8fbc81e5
Support non-blocking IO and related error handling
diff --git a/examples/ffi.c b/examples/ffi.c
line changes: +3/-2
index 2329940..af586cf
--- a/examples/ffi.c
+++ b/examples/ffi.c
@@ -77,13 +77,14 @@ int main(int argc, char **argv) {
   memset(server_buf, 0, sizeof(server_buf));
 
   // Server sends encrypted data
+  out_len = sizeof(client_buf);
   bytes = snprintf((char*)server_buf, sizeof(server_buf), "hello world");
-  bytes = ossuary_send_data(server_conn, server_buf, bytes, client_buf, sizeof(client_buf));
+  bytes = ossuary_send_data(server_conn, server_buf, bytes, client_buf, &out_len);
   printf("server send data bytes: %d\n", bytes);
 
   // Client receives decrypted data
   out_len = sizeof(client_buf);
-  bytes = ossuary_recv_data(client_conn, client_buf, bytes, client_buf, &out_len);
+  bytes = ossuary_recv_data(client_conn, client_buf, &bytes, client_buf, &out_len);
   printf("client recv data bytes: %d\n", bytes);
   printf("decrypted: %s\n", client_buf);
 

diff --git a/ffi/ossuary.h b/ffi/ossuary.h
line changes: +2/-2
index b786fba..e62a10b
--- a/ffi/ossuary.h
+++ b/ffi/ossuary.h
@@ -21,9 +21,9 @@ int32_t ossuary_send_handshake(ConnectionContext *conn,
 int32_t ossuary_handshake_done(ConnectionContext *conn);
 int32_t ossuary_send_data(ConnectionContext *conn,
                           uint8_t *in_buf, uint16_t in_buf_len,
-                          uint8_t *out_buf, uint16_t out_buf_len);
+                          uint8_t *out_buf, uint16_t *out_buf_len);
 int32_t ossuary_recv_data(ConnectionContext *conn,
-                  uint8_t *in_buf, uint16_t in_buf_len,
+                  uint8_t *in_buf, uint16_t *in_buf_len,
                   uint8_t *out_buf, uint16_t *out_buf_len);
 
 #define _OSSUARY_H

diff --git a/src/clib.rs b/src/clib.rs
line changes: +96/-41
index a7c467f..1d7f99d
--- a/src/clib.rs
+++ b/src/clib.rs
@@ -1,6 +1,8 @@
-use crate::{crypto_send_data, crypto_recv_data,
+use crate::{crypto_send_data, crypto_recv_data, crypto_flush,
             crypto_send_handshake, crypto_recv_handshake, crypto_handshake_done,
-            ConnectionContext, ConnectionType};
+            ConnectionContext, ConnectionType, OssuaryError};
+
+const ERROR_WOULD_BLOCK: i32 = -64;
 
 #[no_mangle]
 pub extern "C" fn ossuary_create_connection(conn_type: u8) -> *mut ConnectionContext {
@@ -27,7 +29,9 @@ pub extern "C" fn ossuary_destroy_connection(conn: &mut *mut ConnectionContext) 
 }
 
 #[no_mangle]
-pub extern "C" fn ossuary_set_authorized_keys(conn: *mut ConnectionContext, keys: *const *const u8, key_count: u8) -> i32 {
+pub extern "C" fn ossuary_set_authorized_keys(conn: *mut ConnectionContext,
+                                              keys: *const *const u8,
+                                              key_count: u8) -> i32 {
     if conn.is_null() || keys.is_null() {
         return -1 as i32;
     }
@@ -49,7 +53,8 @@ pub extern "C" fn ossuary_set_authorized_keys(conn: *mut ConnectionContext, keys
 }
 
 #[no_mangle]
-pub extern "C" fn ossuary_set_secret_key(conn: *mut ConnectionContext, key: *const u8) -> i32 {
+pub extern "C" fn ossuary_set_secret_key(conn: *mut ConnectionContext,
+                                         key: *const u8) -> i32 {
     if conn.is_null() || key.is_null() {
         return -1 as i32;
     }
@@ -73,16 +78,19 @@ pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext,
     let inlen = unsafe { *in_buf_len as usize };
     let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, inlen) };
     let mut slice = r_in_buf;
-    let written = match crypto_recv_handshake(&mut conn, &mut slice) {
+    let read: i32 = match crypto_recv_handshake(&mut conn, &mut slice) {
         Ok(read) => {
-            read as u16
+            unsafe { *in_buf_len = read as u16; }
+            read as i32
         },
-        _ => {
-            0u16
-        }
+        Err(OssuaryError::WouldBlock(b)) => {
+            unsafe { *in_buf_len = b as u16; }
+            ERROR_WOULD_BLOCK
+        },
+        _ => -1i32,
     };
     ::std::mem::forget(conn);
-    written as i32 // TODO
+    read as i32 // TODO
 }
 
 #[no_mangle]
@@ -95,11 +103,19 @@ pub extern "C" fn ossuary_send_handshake(conn: *mut ConnectionContext,
     let outlen = unsafe { *out_buf_len as usize };
     let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, outlen) };
     let mut slice = r_out_buf;
-    let more = crypto_send_handshake(&mut conn, &mut slice);
+    let wrote: i32 = match crypto_send_handshake(&mut conn, &mut slice) {
+        Ok(w) => {
+            unsafe { *out_buf_len = w as u16 };
+            w as i32
+        },
+        Err(OssuaryError::WouldBlock(w)) => {
+            unsafe { *out_buf_len = w as u16 };
+            ERROR_WOULD_BLOCK
+        },
+        Err(_) => -1,
+    };
     ::std::mem::forget(conn);
-    // TODO: error if data to send is larger than the given buffer
-    unsafe { *out_buf_len = (outlen - slice.len()) as u16 };
-    more as i32
+    wrote
 }
 
 #[no_mangle]
@@ -119,52 +135,83 @@ pub extern "C" fn ossuary_handshake_done(conn: *const ConnectionContext) -> i32 
 #[no_mangle]
 pub extern "C" fn ossuary_send_data(conn: *mut ConnectionContext,
                                     in_buf: *mut u8, in_buf_len: u16,
-                                    out_buf: *mut u8, out_buf_len: u16) -> i32 {
-    if conn.is_null() || in_buf.is_null() || out_buf.is_null() {
+                                    out_buf: *mut u8, out_buf_len: *mut u16) -> i32 {
+    if conn.is_null() || in_buf.is_null() || out_buf.is_null() || out_buf_len.is_null(){
         return -1i32;
     }
     let mut conn = unsafe { &mut *conn };
-    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, out_buf_len as usize) };
+    let r_out_buf: &mut [u8] = unsafe {
+        std::slice::from_raw_parts_mut(out_buf, *out_buf_len as usize)
+    };
     let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
     let mut out_slice = r_out_buf;
     let in_slice = r_in_buf;
-    let bytes_written: i32;
-    match crypto_send_data(&mut conn, &in_slice, &mut out_slice) {
-        Ok(x) => {
-            bytes_written = x as i32;
-        }
-        Err(_) => { return -1; },
-    }
+    let bytes_written = match crypto_send_data(&mut conn, &in_slice, &mut out_slice) {
+        Ok(w) => {
+            unsafe { *out_buf_len = w as u16; }
+            w as i32
+        },
+        Err(OssuaryError::WouldBlock(w)) => {
+            unsafe { *out_buf_len = w as u16; }
+            ERROR_WOULD_BLOCK
+        },
+        Err(_) => -1i32,
+    };
     ::std::mem::forget(conn);
     bytes_written
 }
 
 #[no_mangle]
 pub extern "C" fn ossuary_recv_data(conn: *mut ConnectionContext,
-                                    in_buf: *mut u8, in_buf_len: u16,
+                                    in_buf: *mut u8, in_buf_len: *mut u16,
                                     out_buf: *mut u8, out_buf_len: *mut u16) -> i32 {
-    if conn.is_null() || in_buf.is_null() || out_buf.is_null() || out_buf_len.is_null() {
+    if conn.is_null() || in_buf.is_null() || out_buf.is_null() ||
+        in_buf_len.is_null() || out_buf_len.is_null() {
         return -1i32;
     }
     let mut conn = unsafe { &mut *conn };
     let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, *out_buf_len as usize) };
-    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
+    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, *in_buf_len as usize) };
     let mut out_slice = r_out_buf;
     let mut in_slice = r_in_buf;
-    let bytes_read: u16;
-    match crypto_recv_data(&mut conn, &mut in_slice, &mut out_slice) {
+    let bytes_read = match crypto_recv_data(&mut conn, &mut in_slice, &mut out_slice) {
         Ok((read,written)) => {
-            unsafe { *out_buf_len = written as u16 };
-            bytes_read = read as u16;
+            unsafe {
+                *in_buf_len = read as u16;
+                *out_buf_len = written as u16;
+            };
+            read as i32
         },
-        Err(_) => {
-            return -1;
+        Err(OssuaryError::WouldBlock(w)) => {
+            unsafe {
+                *out_buf_len = w as u16;
+            };
+            ERROR_WOULD_BLOCK
         },
-    }
+        Err(_) => -1i32,
+    };
     ::std::mem::forget(conn);
     bytes_read as i32
 }
 
+#[no_mangle]
+pub extern "C" fn ossuary_flush(conn: *mut ConnectionContext,
+                                out_buf: *mut u8, out_buf_len: u16) -> i32 {
+    if conn.is_null() || out_buf.is_null() {
+        return -1i32;
+    }
+    let mut conn = unsafe { &mut *conn };
+    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, out_buf_len as usize) };
+    let mut out_slice = r_out_buf;
+    let bytes_written = match crypto_flush(&mut conn, &mut out_slice) {
+        Ok(x) => x as i32,
+        Err(OssuaryError::WouldBlock(_)) => ERROR_WOULD_BLOCK,
+        Err(_) => -1i32,
+    };
+    ::std::mem::forget(conn);
+    bytes_written
+}
+
 #[cfg(test)]
 mod tests {
     use std::thread;
@@ -193,7 +240,7 @@ mod tests {
                 let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
                 let _ = stream.write_all(&out_buf[0..out_len as usize]).unwrap();
 
-                if more != 0 {
+                if more >= 0 {
                     let in_buf = reader.fill_buf().unwrap();
                     let mut in_len = in_buf.len() as u16;
                     if in_len > 0 {
@@ -205,25 +252,30 @@ mod tests {
 
             let mut plaintext: [u8; 256] = [0; 256];
             plaintext[0..13].copy_from_slice("from server 1".as_bytes());
+            let mut out_len: u16 = out_buf.len() as u16;
             let sz = ossuary_send_data(
                 conn,
                 (&plaintext) as *const u8 as *mut u8, 13 as u16,
-                (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+                (&out_buf) as *const u8 as *mut u8,
+                &mut out_len);
             let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
             plaintext[0..13].copy_from_slice("from server 2".as_bytes());
+            let mut out_len: u16 = out_buf.len() as u16;
             let sz = ossuary_send_data(
                 conn,
                 (&plaintext) as *const u8 as *mut u8, 13 as u16,
-                (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+                (&out_buf) as *const u8 as *mut u8,
+                &mut out_len);
             let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
             let in_buf = reader.fill_buf().unwrap();
             if in_buf.len() > 0 {
                 let mut out_len = out_buf.len() as u16;
+                let mut in_len = in_buf.len() as u16;
                 let len = ossuary_recv_data(
                     conn,
-                    (in_buf) as *const [u8] as *mut u8, in_buf.len() as u16,
+                    (in_buf) as *const [u8] as *mut u8, &mut in_len,
                     (&out_buf) as *const u8 as *mut u8, &mut out_len);
                 if len != -1 {
                     println!("CLIB READ: {:?}",
@@ -255,7 +307,7 @@ mod tests {
             let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
             let _ = stream.write_all(&out_buf[0.. out_len as usize]).unwrap();
 
-            if more != 0 {
+            if more >= 0 {
                 let in_buf = reader.fill_buf().unwrap();
                 let mut in_len = in_buf.len() as u16;
                 let len = ossuary_recv_handshake(conn, in_buf as *const [u8] as *const u8, &mut in_len);
@@ -266,10 +318,12 @@ mod tests {
         let out_buf: [u8; 256] = [0; 256];
         let mut plaintext: [u8; 256] = [0; 256];
         plaintext[0..11].copy_from_slice("from client".as_bytes());
+        let mut out_len: u16 = out_buf.len() as u16;
         let sz = ossuary_send_data(
             conn,
             (&plaintext) as *const u8 as *mut u8, 11 as u16,
-            (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+            (&out_buf) as *const u8 as *mut u8,
+            &mut out_len);
         let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
         let mut stream = std::io::BufReader::new(stream);
@@ -280,9 +334,10 @@ mod tests {
                 break;
             }
             let mut out_len = out_buf.len() as u16;
+            let mut in_len = in_buf.len() as u16;
             let len = ossuary_recv_data(
                 conn,
-                in_buf as *const [u8] as *mut u8, in_buf.len() as u16,
+                in_buf as *const [u8] as *mut u8, &mut in_len,
                 (&out_buf) as *const u8 as *mut u8, &mut out_len);
             if len == -1 {
                 break;

diff --git a/src/lib.rs b/src/lib.rs
line changes: +243/-109
index cddf3d1..9d4db53
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -72,10 +72,9 @@ const PACKET_BUF_SIZE: usize = 16384
 //
 
 // TODO:
-//  - non-blocking IO
-//  - remove all unwraps()
 //  - consider all unexpected packet types to be errors
 //  - limit connection retries
+//  - protocol version number
 
 fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
     unsafe {
@@ -103,16 +102,32 @@ pub enum OssuaryError {
     InvalidPacket(String),
     InvalidStruct,
     InvalidSignature,
+    ConnectionReset,
     ConnectionFailed,
 }
 impl std::fmt::Debug for OssuaryError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(f, "OssuaryError")
+        match self {
+            OssuaryError::Io(e) => write!(f, "OssuaryError::Io {}", e),
+            OssuaryError::WouldBlock(_) => write!(f, "OssuaryError::WouldBlock"),
+            OssuaryError::Unpack(_) => write!(f, "OssuaryError::Unpack"),
+            OssuaryError::KeySize(_,_) => write!(f, "OssuaryError::KeySize"),
+            OssuaryError::InvalidKey => write!(f, "OssuaryError::InvalidKey"),
+            OssuaryError::InvalidPacket(_) => write!(f, "OssuaryError::InvalidPacket"),
+            OssuaryError::InvalidStruct => write!(f, "OssuaryError::InvalidStruct"),
+            OssuaryError::InvalidSignature => write!(f, "OssuaryError::InvalidSignature"),
+            OssuaryError::ConnectionReset => write!(f, "OssuaryError::ConnectionReset"),
+            OssuaryError::ConnectionFailed => write!(f, "OssuaryError::ConnectionFailed"),
+        }
+        //write!(f, "OssuaryError")
     }
 }
 impl From<std::io::Error> for OssuaryError {
     fn from(error: std::io::Error) -> Self {
-        OssuaryError::Io(error)
+        match error.kind() {
+            std::io::ErrorKind::WouldBlock => OssuaryError::WouldBlock(0),
+            _ => OssuaryError::Io(error),
+        }
     }
 }
 impl From<core::array::TryFromSliceError> for OssuaryError {
@@ -125,6 +140,11 @@ impl From<ed25519_dalek::SignatureError> for OssuaryError {
         OssuaryError::InvalidKey
     }
 }
+impl From<chacha20_poly1305_aead::DecryptError> for OssuaryError {
+    fn from(_error: chacha20_poly1305_aead::DecryptError) -> Self {
+        OssuaryError::InvalidKey
+    }
+}
 
 #[repr(C,packed)]
 struct HandshakePacket {
@@ -236,8 +256,10 @@ pub struct ConnectionContext {
     authorized_keys: Vec<[u8; 32]>,
     secret_key: Option<SecretKey>,
     public_key: Option<PublicKey>,
-    packet_buf: [u8; PACKET_BUF_SIZE],
-    packet_buf_used: usize,
+    read_buf: [u8; PACKET_BUF_SIZE],
+    read_buf_used: usize,
+    write_buf: [u8; PACKET_BUF_SIZE],
+    write_buf_used: usize,
 }
 impl ConnectionContext {
     pub fn new(conn_type: ConnectionType) -> ConnectionContext {
@@ -268,8 +290,10 @@ impl ConnectionContext {
             authorized_keys: vec!(),
             secret_key: None,
             public_key: None,
-            packet_buf: [0u8; PACKET_BUF_SIZE],
-            packet_buf_used: 0,
+            read_buf: [0u8; PACKET_BUF_SIZE],
+            read_buf_used: 0,
+            write_buf: [0u8; PACKET_BUF_SIZE],
+            write_buf_used: 0,
         }
     }
     fn reset_state(&mut self, permanent_err: Option<OssuaryError>) {
@@ -288,6 +312,10 @@ impl ConnectionContext {
         self.challenge = None;
         self.challenge_sig = None;
         self.remote_key = None;
+        self.read_buf_used = 0;
+        self.write_buf_used = 0;
+        self.local_msg_id = 0u16;
+        self.remote_msg_id = 0u16;
     }
     fn is_server(&self) -> bool {
         match self.conn_type {
@@ -356,15 +384,15 @@ where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
     let header_size = ::std::mem::size_of::<PacketHeader>();
     let bytes_read: usize;
-    match stream.read(&mut conn.packet_buf[conn.packet_buf_used..]) {
+    match stream.read(&mut conn.read_buf[conn.read_buf_used..]) {
         Ok(b) => bytes_read = b,
         Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
             return Err(OssuaryError::WouldBlock(0))
         },
         Err(e) => return Err(e.into()),
     }
-    conn.packet_buf_used += bytes_read;
-    let buf: &[u8] = &conn.packet_buf;
+    conn.read_buf_used += bytes_read;
+    let buf: &[u8] = &conn.read_buf;
     let hdr = PacketHeader {
         len: u16::from_be_bytes(buf[0..2].try_into()?),
         msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
@@ -372,20 +400,23 @@ where T: std::ops::DerefMut<Target = U>,
         _reserved: u16::from_be_bytes(buf[6..8].try_into()?),
     };
     let packet_len = hdr.len as usize;
-    if conn.packet_buf_used < header_size + packet_len {
+    if conn.read_buf_used < header_size + packet_len {
+        if header_size + packet_len > PACKET_BUF_SIZE {
+            panic!("oversized packet");
+        }
         return Err(OssuaryError::WouldBlock(bytes_read));
     }
-    let buf: Box<[u8]> = (&conn.packet_buf[header_size..header_size+packet_len])
+    let buf: Box<[u8]> = (&conn.read_buf[header_size..header_size+packet_len])
         .to_vec().into_boxed_slice();
-    let excess = conn.packet_buf_used - header_size - packet_len;
+    let excess = conn.read_buf_used - header_size - packet_len;
     unsafe {
         // no safe way to memmove() in Rust?
         std::ptr::copy::<u8>(
-            conn.packet_buf.as_ptr().offset((header_size + packet_len) as isize),
-            conn.packet_buf.as_mut_ptr(),
+            conn.read_buf.as_ptr().offset((header_size + packet_len) as isize),
+            conn.read_buf.as_mut_ptr(),
             excess);
     }
-    conn.packet_buf_used = excess;
+    conn.read_buf_used = excess;
     Ok((NetworkPacket {
         header: hdr,
         data: buf,
@@ -393,57 +424,95 @@ where T: std::ops::DerefMut<Target = U>,
     header_size + packet_len))
 }
 
-fn write_packet<T,U>(stream: &mut T, data: &[u8], msg_id: &mut u16, kind: PacketType) -> Result<(), std::io::Error>
+fn write_stored_packet<T,U>(conn: &mut ConnectionContext,
+                            stream: &mut T) -> Result<usize, OssuaryError>
+where T: std::ops::DerefMut<Target = U>,
+      U: std::io::Write {
+    let mut written = 0;
+    while written < conn.write_buf_used {
+        match stream.write(&conn.write_buf[written..conn.write_buf_used]) {
+            Ok(w) => {
+                written += w;
+            },
+            Err(e) => {
+                if written > 0 && written < conn.write_buf_used {
+                    unsafe {
+                        // no safe way to memmove() in Rust?
+                        std::ptr::copy::<u8>(
+                            conn.write_buf.as_ptr().offset(written as isize),
+                            conn.write_buf.as_mut_ptr(),
+                            conn.write_buf_used - written);
+                    }
+                }
+                conn.write_buf_used -= written;
+                return Err(e.into());
+            },
+        }
+    }
+    conn.write_buf_used = 0;
+    Ok(written)
+}
+
+fn write_packet<T,U>(conn: &mut ConnectionContext,
+                     stream: &mut T, data: &[u8],
+                     kind: PacketType) -> Result<usize, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
-    let mut buf: Vec<u8> = Vec::with_capacity(::std::mem::size_of::<PacketHeader>());
-    buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
-    buf.extend_from_slice(&(*msg_id as u16).to_be_bytes());
-    buf.extend_from_slice(&(kind as u16).to_be_bytes());
-    buf.extend_from_slice(&(0u16).to_be_bytes());
-    let _ = stream.write(&buf)?;
-    let _ = stream.write(data)?;
-    *msg_id = *msg_id + 1;
-    Ok(())
+    let msg_id = conn.local_msg_id as u16;
+    conn.write_buf[0..2].copy_from_slice(&(data.len() as u16).to_be_bytes());
+    conn.write_buf[2..4].copy_from_slice(&msg_id.to_be_bytes());
+    conn.write_buf[4..6].copy_from_slice(&(kind as u16).to_be_bytes());
+    conn.write_buf[6..8].copy_from_slice(&(0u16).to_be_bytes());
+    conn.write_buf[8..8+data.len()].copy_from_slice(&data);
+    conn.write_buf_used = 8 + data.len();
+    conn.local_msg_id += 1;
+    let written = write_stored_packet(conn, stream)?;
+    Ok(written)
 }
 
-pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, mut buf: T) -> bool
+pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, mut buf: T) -> Result<usize, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
-    let mut next_msg_id = conn.local_msg_id;
-    let more = match conn.state {
+    // Try to send any unsent buffered data
+    match write_stored_packet(conn, &mut buf) {
+        Ok(w) if w == 0 => {},
+        Ok(w) => return Err(OssuaryError::WouldBlock(w)),
+        Err(e) => return Err(e),
+    }
+    let written = match conn.state {
         ConnectionState::ServerNew => {
             // Wait for client to initiate connection
-            true
+            0
         },
         ConnectionState::Encrypted => {
             // Handshake finished
-            false
+            0
         },
         ConnectionState::ServerWaitAck(t) |
         ConnectionState::ServerWaitAuth(t) |
         ConnectionState::ClientWaitKey(t) |
         ConnectionState::ClientWaitAck(t) => {
+            let mut w: usize = 0;
             // Wait for response, with timeout
             if let Ok(dur) = t.elapsed() {
                 if dur.as_secs() > MAX_HANDSHAKE_WAIT_TIME {
                     let pkt: HandshakePacket = Default::default();
-                    let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                         &mut next_msg_id, PacketType::Reset);
+                    w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                     PacketType::Reset)?;
                     conn.reset_state(None);
                 }
             }
-            true
+            w
         },
         ConnectionState::ServerSendPubKey => {
             // Send session public key and nonce to the client
             let mut pkt: HandshakePacket = Default::default();
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
-            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                 &mut next_msg_id, PacketType::PublicKeyNonce);
+            let w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                 PacketType::PublicKeyNonce)?;
             conn.state = ConnectionState::ServerWaitAck(std::time::SystemTime::now());
-            true
+            w
         },
         ConnectionState::ServerSendChallenge => {
             match conn.conn_type {
@@ -455,7 +524,7 @@ where T: std::ops::DerefMut<Target = U>,
                         Ok(rng) => rng,
                         Err(_) => {
                             conn.reset_state(None);
-                            return true;
+                            return Err(OssuaryError::InvalidKey);
                         }
                     };
                     let aad = [];
@@ -467,7 +536,7 @@ where T: std::ops::DerefMut<Target = U>,
                         Some(ref s) => s,
                         None => {
                             conn.reset_state(None);
-                            return true;
+                            return Err(OssuaryError::InvalidKey);
                         }
                     };
                     let tag = match encrypt(session_key,
@@ -476,7 +545,7 @@ where T: std::ops::DerefMut<Target = U>,
                         Ok(tag) => tag,
                         Err(_) => {
                             conn.reset_state(None);
-                            return true;
+                            return Err(OssuaryError::InvalidKey);
                         }
                     };
                     let pkt: EncryptedPacket = EncryptedPacket {
@@ -487,18 +556,18 @@ where T: std::ops::DerefMut<Target = U>,
                     pkt_buf.extend(struct_as_slice(&pkt));
                     pkt_buf.extend(&ciphertext);
                     pkt_buf.extend(&tag);
-                    let _ = write_packet(&mut buf, &pkt_buf,
-                                         &mut next_msg_id, PacketType::AuthChallenge);
+                    let w = write_packet(conn, &mut buf, &pkt_buf,
+                                         PacketType::AuthChallenge)?;
                     conn.state = ConnectionState::ServerWaitAuth(std::time::SystemTime::now());
-                    true
+                    w
                 },
                 _ => {
                     // For unauthenticated connections, we are done.  Already encrypted.
                     let pkt: HandshakePacket = Default::default();
-                    let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                         &mut next_msg_id, PacketType::PubKeyAck);
+                    let w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                         PacketType::PubKeyAck)?;
                     conn.state = ConnectionState::Encrypted;
-                    false // handshake is finished (success)
+                    w // handshake is finished (success)
                 },
             }
         },
@@ -507,18 +576,18 @@ where T: std::ops::DerefMut<Target = U>,
             let mut pkt: HandshakePacket = Default::default();
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
-            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                 &mut next_msg_id, PacketType::PublicKeyNonce);
+            let w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                 PacketType::PublicKeyNonce)?;
             conn.state = ConnectionState::ClientWaitKey(std::time::SystemTime::now());
-            true
+            w
         },
         ConnectionState::ClientSendAck => {
             // Acknowledge reception of server's session public key and nonce
             let pkt: HandshakePacket = Default::default();
-            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                 &mut next_msg_id, PacketType::PubKeyAck);
+            let w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                 PacketType::PubKeyAck)?;
             conn.state = ConnectionState::ClientWaitAck(std::time::SystemTime::now());
-            true
+            w
         },
         ConnectionState::ClientSendAuth => {
             // Send signature of the server's challenge back to the server,
@@ -529,12 +598,12 @@ where T: std::ops::DerefMut<Target = U>,
                     Ok(s) => s, // local copy of secret key
                     Err(_) => {
                         conn.reset_state(Some(OssuaryError::InvalidKey));
-                        return true;
+                        return Err(OssuaryError::InvalidKey);
                     }
                 },
                 None => {
                     conn.reset_state(Some(OssuaryError::InvalidKey));
-                    return true;
+                    return Err(OssuaryError::InvalidKey);
                 }
             };
             let public = PublicKey::from_secret::<Sha512>(&secret);
@@ -543,7 +612,7 @@ where T: std::ops::DerefMut<Target = U>,
                 Some(ref c) => keypair.sign::<Sha512>(c).to_bytes(),
                 None => {
                     conn.reset_state(None);
-                    return true;
+                    return Err(OssuaryError::InvalidSignature);
                 }
             };
             let mut pkt_data: Vec<u8> = Vec::with_capacity(CHALLENGE_LEN + 32);
@@ -557,7 +626,7 @@ where T: std::ops::DerefMut<Target = U>,
                 Some(ref s) => s,
                 None => {
                     conn.reset_state(None);
-                    return true;
+                    return Err(OssuaryError::InvalidKey);
                 }
             };
             let tag = match encrypt(session_key,
@@ -566,7 +635,7 @@ where T: std::ops::DerefMut<Target = U>,
                 Ok(t) => t,
                 Err(_) => {
                     conn.reset_state(None);
-                    return true;
+                    return Err(OssuaryError::InvalidKey);
                 }
             };
 
@@ -578,31 +647,33 @@ where T: std::ops::DerefMut<Target = U>,
             pkt_buf.extend(struct_as_slice(&pkt));
             pkt_buf.extend(&ciphertext);
             pkt_buf.extend(&tag);
-            let _ = write_packet(&mut buf, &pkt_buf,
-                                 &mut next_msg_id, PacketType::AuthResponse);
+            let w = write_packet(conn, &mut buf, &pkt_buf,
+                                 PacketType::AuthResponse)?;
             conn.state = ConnectionState::Encrypted;
-            false // handshake is finished (success)
+            w // handshake is finished (success)
         },
         ConnectionState::Failed(_) => {
             // This is a permanent failure.
             let pkt: HandshakePacket = Default::default();
-            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
-                                 &mut next_msg_id, PacketType::Disconnect);
+            let w = write_packet(conn, &mut buf, struct_as_slice(&pkt),
+                                 PacketType::Disconnect)?;
             conn.reset_state(Some(OssuaryError::ConnectionFailed));
-            false // handshake is finished (failed)
+            w // handshake is finished (failed)
         },
     };
-    conn.local_msg_id = next_msg_id;
-    // TODO: either this should return amount write, or send_data() should not
-    more
+    Ok(written)
 }
 
-// TODO u16 should be usize
 pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> Result<usize, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
-    // TODO: read_exact won't work.
     let mut bytes_read: usize = 0;
+
+    match conn.state {
+        ConnectionState::Encrypted => return Ok(0),
+        _ => {},
+    }
+
     let pkt: NetworkPacket = match read_packet(conn, buf) {
         Ok((p, r)) => {
             bytes_read += r;
@@ -616,28 +687,27 @@ where T: std::ops::DerefMut<Target = U>,
         }
     };
 
-    if pkt.header.msg_id != conn.remote_msg_id {
-        println!("Message gap detected.  Restarting connection.");
-        println!("Server: {}", conn.is_server());
-        conn.reset_state(None);
-        return Err(OssuaryError::InvalidPacket("Message ID does not match".into()));
-    }
-    conn.remote_msg_id = pkt.header.msg_id + 1;
-
     let mut error = false;
     match pkt.kind() {
         PacketType::Reset => {
             conn.reset_state(None);
-            return Err(OssuaryError::ConnectionFailed);
+            return Err(OssuaryError::ConnectionReset);
         },
         PacketType::Disconnect => {
-            // TODO: handle error
             conn.reset_state(Some(OssuaryError::ConnectionFailed));
-            panic!("Remote side terminated connection.");
+            return Err(OssuaryError::ConnectionFailed);
         },
         _ => {},
     }
 
+    if pkt.header.msg_id != conn.remote_msg_id {
+        println!("Message gap detected.  Restarting connection.");
+        println!("Server: {}", conn.is_server());
+        conn.reset_state(None);
+        return Err(OssuaryError::InvalidPacket("Message ID does not match".into()));
+    }
+    conn.remote_msg_id = pkt.header.msg_id + 1;
+
     match conn.state {
         ConnectionState::ServerNew => {
             match pkt.kind() {
@@ -687,9 +757,9 @@ where T: std::ops::DerefMut<Target = U>,
                                     return Err(OssuaryError::InvalidKey);
                                 }
                             };
-                            let _ = decrypt(session_key,
-                                            &remote_nonce,
-                                            &aad, &ciphertext, &tag, &mut plaintext);
+                            decrypt(session_key,
+                                    &remote_nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext)?;
                             let pubkey = &plaintext[0..32];
                             let sig = &plaintext[32..];
                             if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
@@ -797,9 +867,9 @@ where T: std::ops::DerefMut<Target = U>,
                                     return Err(OssuaryError::InvalidKey);
                                 }
                             };
-                            let _ = decrypt(session_key,
-                                            &remote_nonce,
-                                            &aad, &ciphertext, &tag, &mut plaintext);
+                            decrypt(session_key,
+                                    &remote_nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext)?;
                             conn.challenge = Some(plaintext);
                             conn.state = ConnectionState::ClientSendAuth;
                         },
@@ -829,7 +899,6 @@ where T: std::ops::DerefMut<Target = U>,
     Ok(bytes_read)
 }
 
-// TODO: should return a Result with error on forced-disconnect or permanent failure
 pub fn crypto_handshake_done(conn: &ConnectionContext) -> Result<bool, &OssuaryError> {
     match conn.state {
         ConnectionState::Encrypted => Ok(true),
@@ -841,14 +910,18 @@ pub fn crypto_handshake_done(conn: &ConnectionContext) -> Result<bool, &OssuaryE
 pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], mut out_buf: T) -> Result<usize, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
+    // Try to send any unsent buffered data
+    match write_stored_packet(conn, &mut out_buf) {
+        Ok(w) if w == 0 => {},
+        Ok(w) => return Err(OssuaryError::WouldBlock(w)),
+        Err(e) => return Err(e),
+    }
     match conn.state {
         ConnectionState::Encrypted => {},
         _ => {
             return Err(OssuaryError::InvalidPacket("Encrypted channel not established.".into()));
         }
     }
-    let mut next_msg_id = conn.local_msg_id;
-    let bytes;
     let aad = [];
     let mut ciphertext = Vec::with_capacity(in_buf.len());
     let session_key = match conn.local_key.session {
@@ -875,11 +948,9 @@ where T: std::ops::DerefMut<Target = U>,
     buf.extend(struct_as_slice(&pkt));
     buf.extend(&ciphertext);
     buf.extend(&tag);
-    let _ = write_packet(&mut out_buf, &buf,
-                         &mut next_msg_id, PacketType::EncryptedData);
-    bytes = buf.len() + ::std::mem::size_of::<PacketHeader>();
-    conn.local_msg_id = next_msg_id;
-    Ok(bytes)
+    let written = write_packet(conn, &mut out_buf, &buf,
+                               PacketType::EncryptedData)?;
+    Ok(written)
 }
 
 pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(usize, usize), OssuaryError>
@@ -900,7 +971,9 @@ where T: std::ops::DerefMut<Target = U>,
         Ok((pkt, bytes)) => {
             bytes_read += bytes;
             if pkt.header.msg_id != conn.remote_msg_id {
-                println!("Message gap detected.  Restarting connection.");
+                let msg_id = pkt.header.msg_id;
+                println!("Message gap detected.  Restarting connection. ({} != {})",
+                         msg_id, conn.remote_msg_id);
                 println!("Server: {}", conn.is_server());
                 conn.reset_state(None);
                 return Err(OssuaryError::InvalidPacket("Message ID mismatch".into()))
@@ -929,11 +1002,13 @@ where T: std::ops::DerefMut<Target = U>,
                                     return Err(OssuaryError::InvalidKey);
                                 }
                             };
-                            let _ = decrypt(session_key,
-                                            &remote_nonce,
-                                            &aad, &ciphertext, &tag, &mut plaintext);
-                            let _ = out_buf.write(&plaintext);
-                            bytes_written = ciphertext.len();
+                            decrypt(session_key,
+                                    &remote_nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext)?;
+                            bytes_written = match out_buf.write(&plaintext) {
+                                Ok(w) => w,
+                                Err(e) => return Err(e.into()),
+                            };
                         },
                         Err(_) => {
                             conn.reset_state(None);
@@ -956,6 +1031,12 @@ where T: std::ops::DerefMut<Target = U>,
     Ok((bytes_read, bytes_written))
 }
 
+pub fn crypto_flush<R,V>(conn: &mut ConnectionContext, mut out_buf: R) -> Result<usize, OssuaryError>
+where R: std::ops::DerefMut<Target = V>,
+      V: std::io::Write {
+    return write_stored_packet(conn, &mut out_buf);
+}
+
 #[cfg(test)]
 mod tests {
     extern crate test;
@@ -1003,7 +1084,7 @@ mod tests {
             let mut server_stream = listener.incoming().next().unwrap().unwrap();
             let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
             while crypto_handshake_done(&server_conn).unwrap() == false {
-                if crypto_send_handshake(&mut server_conn, &mut server_stream) {
+                if crypto_send_handshake(&mut server_conn, &mut server_stream).is_ok() {
                     loop {
                         match crypto_recv_handshake(&mut server_conn, &mut server_stream) {
                             Ok(_) => break,
@@ -1013,18 +1094,24 @@ mod tests {
                     }
                 }
             }
+            println!("server handshook");
             let mut plaintext = vec!();
             let mut bytes: u64 = 0;
             let start = std::time::SystemTime::now();
             loop {
+                //std::thread::sleep(std::time::Duration::from_millis(100));
                 match crypto_recv_data(&mut server_conn,
                                        &mut server_stream,
                                        &mut plaintext) {
                     Ok((read, _written)) => bytes += read as u64,
                     Err(OssuaryError::WouldBlock(_)) => continue,
-                    _ => panic!("Recv failed"),
+                    Err(e) => {
+                        println!("err: {:?}", e);
+                        panic!("Recv failed")
+                    },
                 }
                 if plaintext == [0xde, 0xde, 0xbe, 0xbe] {
+                    println!("finished");
                     if let Ok(dur) = start.elapsed() {
                         let t = dur.as_secs() as f64
                             + dur.subsec_nanos() as f64 * 1e-9;
@@ -1042,24 +1129,32 @@ mod tests {
         client_stream.set_nonblocking(true).unwrap();
         let mut client_conn = ConnectionContext::new(ConnectionType::Client);
         while crypto_handshake_done(&client_conn).unwrap() == false {
-            if crypto_send_handshake(&mut client_conn, &mut client_stream) {
+            if crypto_send_handshake(&mut client_conn, &mut client_stream).is_ok() {
                 loop {
                     match crypto_recv_handshake(&mut client_conn, &mut client_stream) {
                         Ok(_) => break,
                         Err(OssuaryError::WouldBlock(_)) => {},
-                        _ => panic!("Handshake failed"),
+                        Err(e) => {
+                            println!("err: {:?}", e);
+                            panic!("Handshake failed")
+                        },
                     }
                 }
             }
         }
+        println!("client handshook");
         let mut client_stream = std::io::BufWriter::new(client_stream);
         let mut bytes: u64 = 0;
         let start = std::time::SystemTime::now();
         let mut plaintext: &[u8] = &[0xaa; 16384];
         b.iter(|| {
-            bytes += crypto_send_data(&mut client_conn,
-                                      &mut plaintext,
-                                      &mut client_stream).unwrap() as u64;
+            match crypto_send_data(&mut client_conn,
+                                   &mut plaintext,
+                                   &mut client_stream) {
+                Ok(b) => bytes += b as u64,
+                Err(OssuaryError::WouldBlock(_)) => {},
+                _ => panic!("send error"),
+            }
         });
         if let Ok(dur) = start.elapsed() {
             let t = dur.as_secs() as f64
@@ -1068,8 +1163,47 @@ mod tests {
             println!("{:.2} MB/s", bytes as f64 / 1024.0 / 1024.0 / t);
         }
         let mut plaintext: &[u8] = &[0xde, 0xde, 0xbe, 0xbe];
-        let _ = crypto_send_data(&mut client_conn, &mut plaintext, &mut client_stream);
-        drop(client_stream); // flush the buffer
+        loop {
+            match crypto_send_data(&mut client_conn, &mut plaintext, &mut client_stream) {
+                Ok(w) => {
+                    println!("wrote finish: {}", w);
+                    break;
+                },
+                Err(OssuaryError::WouldBlock(_)) => {},
+                _ => panic!("Send failed"),
+            }
+        }
+        loop {
+            match crypto_flush(&mut client_conn, &mut client_stream) {
+                Ok(w) => {
+                    if w == 0 {
+                        break;
+                    }
+                    println!("flushed: {}", w);
+                },
+                _ => panic!("Flush failed"),
+            }
+        }
+
+        let mut client_stream: Option<std::io::BufWriter<_>> = Some(client_stream);
+        loop {
+            client_stream = match client_stream {
+                None => break,
+                Some(s) => match s.into_inner() {
+                    Ok(_) => None,
+                    Err(e) => {
+                        match e.error().kind() {
+                            std::io::ErrorKind::WouldBlock => {
+                                Some(e.into_inner())
+                            },
+                            _ => panic!("error: {:?}", e.error()),
+                        }
+                    },
+                }
+            };
+        }
+        println!("flushed");
+        //drop(client_stream); // flush the buffer
         let _ = server_thread.join();
     }
 }

diff --git a/tests/basic.rs b/tests/basic.rs
line changes: +1/-1
index 2968818..2bc5a12
--- a/tests/basic.rs
+++ b/tests/basic.rs
@@ -12,7 +12,7 @@ fn event_loop<T>(mut conn: ConnectionContext,
 where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
     while crypto_handshake_done(&conn).unwrap() == false {
-        if crypto_send_handshake(&mut conn, &mut stream) {
+        if crypto_send_handshake(&mut conn, &mut stream).is_ok() {
             loop {
                 match crypto_recv_handshake(&mut conn, &mut stream) {
                     Ok(_) => break,

diff --git a/tests/basic_auth.rs b/tests/basic_auth.rs
line changes: +1/-1
index 41264d1..fad8807
--- a/tests/basic_auth.rs
+++ b/tests/basic_auth.rs
@@ -12,7 +12,7 @@ fn event_loop<T>(mut conn: ConnectionContext,
 where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
     while crypto_handshake_done(&conn).unwrap() == false {
-        if crypto_send_handshake(&mut conn, &mut stream) {
+        if crypto_send_handshake(&mut conn, &mut stream).is_ok() {
             loop {
                 match crypto_recv_handshake(&mut conn, &mut stream) {
                     Ok(_) => break,