summary history branches tags files
commit:ab0d628b9d012e3db2eebd13f09d54d2798f3fe2
author:Trevor Bentley
committer:Trevor Bentley
date:Thu Dec 20 22:47:18 2018 +0100
parents:a1eb942b74d288cc4f2948167924f041ae674dab
FFI examples working
diff --git a/Cargo.toml b/Cargo.toml
line changes: +14/-0
index 69eac8f..561f15b
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,20 @@ version = "0.1.0"
 authors = ["Trevor Bentley"]
 edition = "2018"
 
+[lib]
+crate_type = ["lib", "cdylib", "staticlib"]
+
+[profile.release]
+opt-level = "z"
+debug = false
+rpath = false
+lto = true
+debug-assertsion = false
+codegen-units = 1
+panic = 'abort' #'unwind'
+incremental = true
+overflow-checks = false
+
 [dependencies]
 x25519-dalek = "0.3.0"
 chacha20-poly1305-aead = "0.1.2"

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
line changes: +12/-0
index 0000000..78bd865
--- /dev/null
+++ b/examples/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.12)
+project(ffi C)
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -flto -O3")
+
+add_executable (ffi ffi.c)
+target_include_directories (ffi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../ffi)
+target_link_libraries (ffi LINK_PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../target/release/libossuary.dylib)
+
+#target_link_libraries (ffi LINK_PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../target/release/libossuary.a)
+#find_library(SECURITY_FRAMEWORK Security)
+#target_link_libraries(ffi LINK_PUBLIC ${SECURITY_FRAMEWORK})

diff --git a/examples/ffi.c b/examples/ffi.c
line changes: +69/-0
index 0000000..b000fb6
--- /dev/null
+++ b/examples/ffi.c
@@ -0,0 +1,69 @@
+#include <stdio.h>
+#include <stdint.h>
+#include <string.h>
+#include <unistd.h>
+#include "ossuary.h"
+
+uint8_t client_buf[256];
+uint8_t server_buf[256];
+
+int main(int argc, char **argv) {
+  int client_done, server_done;
+  uint16_t client_bytes, server_bytes, bytes;
+  ConnectionContext *client_conn = NULL;
+  ConnectionContext *server_conn = NULL;
+  client_conn = ossuary_create_connection(0);
+  server_conn = ossuary_create_connection(1);
+
+  memset(client_buf, 0, sizeof(client_buf));
+  memset(server_buf, 0, sizeof(server_buf));
+
+  // Client and server send handshakes
+  int count = 0;
+  do {
+    client_done = ossuary_handshake_done(client_conn);
+    server_done = ossuary_handshake_done(server_conn);
+    printf("done: %d %d\n", client_done, server_done);
+
+    if (!client_done) {
+      client_bytes = sizeof(client_buf);
+      ossuary_send_handshake(client_conn, client_buf, &client_bytes);
+      printf("client send handshake bytes: %d\n", client_bytes);
+
+      if (client_bytes) {
+        ossuary_recv_handshake(server_conn, client_buf, &client_bytes);
+        printf("server recv handshake bytes: %d\n", client_bytes);
+      }
+    }
+
+    if (!server_done) {
+      server_bytes = sizeof(server_buf);
+      ossuary_send_handshake(server_conn, server_buf, &server_bytes);
+      printf("server send handshake bytes: %d\n", server_bytes);
+
+      if (server_bytes) {
+        ossuary_recv_handshake(client_conn, server_buf, &server_bytes);
+        printf("client recv handshake bytes: %d\n", server_bytes);
+      }
+    }
+
+    //if (++count == 8) break;
+    usleep(100000);
+  } while (!client_done || !server_done);
+
+  memset(client_buf, 0, sizeof(client_buf));
+  memset(server_buf, 0, sizeof(server_buf));
+
+  // Server sends encrypted data
+  bytes = snprintf((char*)server_buf, sizeof(server_buf), "hello world");
+  bytes = ossuary_send_data(server_conn, server_buf, bytes, client_buf, sizeof(client_buf));
+  printf("server send data bytes: %d\n", bytes);
+
+  // Client receives decrypted data
+  bytes = ossuary_recv_data(client_conn, client_buf, bytes, client_buf, sizeof(client_buf));
+  printf("client recv data bytes: %d\n", bytes);
+  printf("decrypted: %s\n", client_buf);
+
+  ossuary_destroy_connection(&client_conn);
+  ossuary_destroy_connection(&server_conn);
+}

diff --git a/ffi/ossuary.h b/ffi/ossuary.h
line changes: +22/-0
index 0000000..27ce7b7
--- /dev/null
+++ b/ffi/ossuary.h
@@ -0,0 +1,22 @@
+#ifndef _OSSUARY_H
+
+#include <stdint.h>
+
+typedef struct ConnectionContext ConnectionContext;
+
+ConnectionContext *ossuary_create_connection(uint8_t is_server);
+int32_t ossuary_destroy_connection(ConnectionContext **conn);
+int32_t ossuary_recv_handshake(ConnectionContext *conn,
+                               uint8_t *in_buf, uint16_t *in_buf_len);
+int32_t ossuary_send_handshake(ConnectionContext *conn,
+                               uint8_t *out_buf, uint16_t *out_buf_len);
+uint8_t ossuary_handshake_done(ConnectionContext *conn);
+int32_t ossuary_send_data(ConnectionContext *conn,
+                          uint8_t *in_buf, uint16_t in_buf_len,
+                          uint8_t *out_buf, uint16_t out_buf_len);
+int32_t ossuary_recv_data(ConnectionContext *conn,
+                  uint8_t *in_buf, uint16_t in_buf_len,
+                  uint8_t *out_buf, uint16_t out_buf_len);
+
+#define _OSSUARY_H
+#endif

diff --git a/src/clib.rs b/src/clib.rs
line changes: +161/-26
index ad03083..2318816
--- a/src/clib.rs
+++ b/src/clib.rs
@@ -1,38 +1,112 @@
-use crate::{crypto_send_handshake, crypto_recv_handshake, ConnectionContext};
+use crate::{crypto_send_data, crypto_recv_data,
+            crypto_send_handshake, crypto_recv_handshake, crypto_handshake_done,
+            ConnectionContext};
 
 #[no_mangle]
-pub extern "C" fn ossuary_create_connection() -> *mut ConnectionContext {
-    let mut conn = Box::new(ConnectionContext::new());
+pub extern "C" fn ossuary_create_connection(is_server: u8) -> *mut ConnectionContext {
+    let is_server: bool = match is_server {
+        0 => false,
+        _ => true,
+    };
+    let mut conn = Box::new(ConnectionContext::new(is_server)); // todo
     let ptr: *mut _ = &mut *conn;
     ::std::mem::forget(conn);
     ptr
 }
 
 #[no_mangle]
-pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext, in_buf: *const u8, in_buf_len: u16) -> i32 {
-    if conn.is_null() || in_buf.is_null() {
+pub extern "C" fn ossuary_destroy_connection(conn: &mut *mut ConnectionContext) {
+    if conn.is_null() {
+        return;
+    }
+    let obj: Box<ConnectionContext> = unsafe { ::std::mem::transmute(*conn) };
+    ::std::mem::drop(obj);
+    *conn = ::std::ptr::null_mut();
+}
+
+#[no_mangle]
+pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext,
+                                         in_buf: *const u8, in_buf_len: *mut u16) -> i32 {
+    if conn.is_null() || in_buf.is_null() || in_buf_len.is_null() {
         return -1i32;
     }
     let mut conn = unsafe { &mut *conn };
-    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
+    let inlen = unsafe { *in_buf_len as usize };
+    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, inlen) };
     let mut slice = r_in_buf;
     crypto_recv_handshake(&mut conn, &mut slice);
+    ::std::mem::forget(conn);
+    unsafe { *in_buf_len = (inlen - slice.len()) as u16 };
+    0i32
+}
 
+#[no_mangle]
+pub extern "C" fn ossuary_send_handshake(conn: *mut ConnectionContext,
+                                         out_buf: *mut u8, out_buf_len: *mut u16) -> i32 {
+    if conn.is_null() || out_buf.is_null() || out_buf_len.is_null() {
+        return -1i32;
+    }
+    let mut conn = unsafe { &mut *conn };
+    let outlen = unsafe { *out_buf_len as usize };
+    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, outlen) };
+    let mut slice = r_out_buf;
+    let more = crypto_send_handshake(&mut conn, &mut slice);
     ::std::mem::forget(conn);
-    (in_buf_len - slice.len() as u16) as i32
+    unsafe { *out_buf_len = (outlen - slice.len()) as u16 };
+    more as i32
 }
 
 #[no_mangle]
-pub extern "C" fn ossuary_send_handshake(conn: *mut ConnectionContext, in_buf: *mut u8, in_buf_len: u16) -> i32 {
-    if conn.is_null() || in_buf.is_null() {
+pub extern "C" fn ossuary_handshake_done(conn: *const ConnectionContext) -> u8 {
+    if conn.is_null() {
+        return 0u8;
+    }
+    let conn = unsafe { &*conn };
+    let done = crypto_handshake_done(&conn);
+    ::std::mem::forget(conn);
+    done as u8
+}
+
+#[no_mangle]
+pub extern "C" fn ossuary_send_data(conn: *mut ConnectionContext,
+                                    in_buf: *mut u8, in_buf_len: u16,
+                                    out_buf: *mut u8, out_buf_len: u16) -> i32 {
+    if conn.is_null() || in_buf.is_null() || out_buf.is_null() {
         return -1i32;
     }
     let mut conn = unsafe { &mut *conn };
-    let r_in_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(in_buf, in_buf_len as usize) };
-    let mut slice = r_in_buf;
-    crypto_send_handshake(&mut conn, &mut slice);
+    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, out_buf_len as usize) };
+    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
+    let mut out_slice = r_out_buf;
+    let in_slice = r_in_buf;
+    match crypto_send_data(&mut conn, &in_slice, &mut out_slice) {
+        Err(_) => { return -1; },
+        _ => {},
+    }
     ::std::mem::forget(conn);
-    (in_buf_len - slice.len() as u16) as i32
+    (out_buf_len - out_slice.len() as u16) as i32
+}
+
+#[no_mangle]
+pub extern "C" fn ossuary_recv_data(conn: *mut ConnectionContext,
+                                    in_buf: *mut u8, in_buf_len: u16,
+                                    out_buf: *mut u8, out_buf_len: u16) -> i32 {
+    if conn.is_null() || in_buf.is_null() || out_buf.is_null() {
+        return -1i32;
+    }
+    let mut conn = unsafe { &mut *conn };
+    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, out_buf_len as usize) };
+    let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
+    let mut out_slice = r_out_buf;
+    let mut in_slice = r_in_buf;
+    match crypto_recv_data(&mut conn, &mut in_slice, &mut out_slice) {
+        Ok(_) => {},
+        Err(e) => {
+            println!("recv_data failed: {} {}", e, conn.is_server);
+            return -1; },
+    }
+    ::std::mem::forget(conn);
+    (out_buf_len - out_slice.len() as u16) as i32
 }
 
 #[cfg(test)]
@@ -42,38 +116,99 @@ mod tests {
     use std::net::{TcpListener, TcpStream};
     use crate::clib::*;
     pub fn server() -> Result<(), std::io::Error> {
-        println!("server start");
         let listener = TcpListener::bind("127.0.0.1:9989").unwrap();
         for stream in listener.incoming() {
             let mut stream: TcpStream = stream.unwrap();
-            let conn = ossuary_create_connection();
+            let mut conn = ossuary_create_connection(1);
+
+            let out_buf: [u8; 256] = [0; 256];
             let mut in_buf: [u8; 256] = [0; 256];
-            ossuary_send_handshake(conn, (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16);
-            let _ = stream.write(&in_buf);
-            let _ = stream.read(&mut in_buf);
-            ossuary_recv_handshake(conn, (&in_buf) as *const u8, in_buf.len() as u16);
+
+            while ossuary_handshake_done(conn) == 0 {
+                let mut out_len = out_buf.len() as u16;
+                let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
+                let _ = stream.write(&out_buf[0..out_len as usize]);
+
+                if more != 0 {
+                    let _ = stream.read(&mut in_buf);
+                    let mut in_len = in_buf.len() as u16;
+                    ossuary_recv_handshake(conn, (&in_buf) as *const u8, &mut in_len);
+                }
+            }
+
+            let out_buf: [u8; 256] = [0; 256];
+            let mut plaintext: [u8; 256] = [0; 256];
+            plaintext[0..11].copy_from_slice("hello world".as_bytes());
+            ossuary_send_data(
+                conn,
+                (&plaintext) as *const u8 as *mut u8, 11 as u16,
+                (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+            let _ = stream.write(&out_buf);
+
+            let out_buf: [u8; 256] = [0; 256];
+            let mut plaintext: [u8; 256] = [0; 256];
+            plaintext[0..13].copy_from_slice("goodbye world".as_bytes());
+            ossuary_send_data(
+                conn,
+                (&plaintext) as *const u8 as *mut u8, 13 as u16,
+                (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+            let _ = stream.write(&out_buf);
+
+            ossuary_destroy_connection(&mut conn);
             break;
         }
-        println!("server done");
         Ok(())
     }
 
     pub fn client() -> Result<(), std::io::Error> {
-        println!("client start");
         let mut stream = TcpStream::connect("127.0.0.1:9989").unwrap();
-        let conn = ossuary_create_connection();
+        let mut conn = ossuary_create_connection(0);
+
+        let out_buf: [u8; 256] = [0; 256];
         let mut in_buf: [u8; 256] = [0; 256];
-        ossuary_send_handshake(conn, (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16);
-        let _ = stream.write(&in_buf);
+
+        while ossuary_handshake_done(conn) == 0 {
+            let mut out_len = out_buf.len() as u16;
+            let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
+            let _ = stream.write(&out_buf[0.. out_len as usize]);
+
+            if more != 0 {
+                let _ = stream.read(&mut in_buf);
+                let mut in_len = in_buf.len() as u16;
+                ossuary_recv_handshake(conn, (&in_buf) as *const u8, &mut in_len);
+            }
+        }
+
+
         let _ = stream.read(&mut in_buf);
-        ossuary_recv_handshake(conn, (&in_buf) as *const u8, in_buf.len() as u16);
-        println!("client done");
+        let out_buf: [u8; 256] = [0; 256];
+        let len = ossuary_recv_data(
+            conn,
+            (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16,
+            (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+        if len != -1 {
+            println!("FFI RECEIVED: {:?}", std::str::from_utf8(&out_buf[0..len as usize]));
+        }
+
+        let _ = stream.read(&mut in_buf);
+        let out_buf: [u8; 256] = [0; 256];
+        let len = ossuary_recv_data(
+            conn,
+            (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16,
+            (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
+        if len != -1 {
+            println!("FFI RECEIVED: {:?}", std::str::from_utf8(&out_buf[0..len as usize]));
+        }
+
+        ossuary_destroy_connection(&mut conn);
         Ok(())
     }
     pub fn test() {
+        println!("FFI START");
         thread::spawn(move || { let _ = server(); });
         let child = thread::spawn(move || { let _ = client(); });
         let _ = child.join();
+        println!("FFI DONE");
     }
     #[test]
     fn it_works() {

diff --git a/src/lib.rs b/src/lib.rs
line changes: +271/-105
index c2cc991..6ddbf09
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,13 +9,15 @@ use x25519_dalek::generate_secret;
 use x25519_dalek::generate_public;
 use x25519_dalek::diffie_hellman;
 
-use rand::thread_rng;
+//use rand::thread_rng;
 use rand::RngCore;
+use rand::rngs::OsRng;
 
 use std::convert::TryInto;
 
 pub mod clib;
 
+const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
 //
 // API:
 //  * sock -- TCP data socket
@@ -68,8 +70,28 @@ fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
         Ok(&*(&p[..::std::mem::size_of::<T>()] as *const [u8] as *const T))
     }
 }
-#[repr(packed)]
-#[allow(dead_code)]
+
+pub enum OssuaryError {
+    Io(std::io::Error),
+    Unpack(core::array::TryFromSliceError),
+}
+impl std::fmt::Debug for OssuaryError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "OssuaryError")
+    }
+}
+impl From<std::io::Error> for OssuaryError {
+    fn from(error: std::io::Error) -> Self {
+        OssuaryError::Io(error)
+    }
+}
+impl From<core::array::TryFromSliceError> for OssuaryError {
+    fn from(error: core::array::TryFromSliceError) -> Self {
+        OssuaryError::Unpack(error)
+    }
+}
+
+#[repr(C,packed)]
 struct HandshakePacket {
     len: u16,
     _reserved: u16,
@@ -89,35 +111,36 @@ impl Default for HandshakePacket {
 
 #[repr(u16)]
 #[derive(Clone, Copy)]
-#[allow(dead_code)]
 enum PacketType {
-    Unknown = 0,
+    Unknown = 0x00,
     PublicKeyNonce = 0x01,
-    AuthRequest = 0x02,
-    EncryptedData = 0x03,
-    Disconnect = 0x04,
+    PubKeyAck = 0x02,
+    AuthRequest = 0x03,
+    Reset = 0x04,
+    Disconnect = 0x05,
+    EncryptedData = 0x10,
 }
 impl PacketType {
     pub fn from_u16(i: u16) -> PacketType {
         match i {
             0x01 => PacketType::PublicKeyNonce,
-            0x02 => PacketType::AuthRequest,
-            0x03 => PacketType::EncryptedData,
-            0x04 => PacketType::Disconnect,
+            0x02 => PacketType::PubKeyAck,
+            0x03 => PacketType::AuthRequest,
+            0x04 => PacketType::Reset,
+            0x05 => PacketType::Disconnect,
+            0x10 => PacketType::EncryptedData,
             _ => PacketType::Unknown,
         }
     }
 }
 
-#[repr(packed)]
-#[allow(dead_code)]
+#[repr(C,packed)]
 struct EncryptedPacket {
     data_len: u16,
     tag_len: u16,
 }
 
-#[repr(packed)]
-#[allow(dead_code)]
+#[repr(C,packed)]
 struct PacketHeader {
     len: u16,
     msg_id: u16,
@@ -125,11 +148,26 @@ struct PacketHeader {
     _reserved: u16,
 }
 
+struct NetworkPacket {
+    header: PacketHeader,
+    data: Box<[u8]>,
+}
+impl NetworkPacket {
+    fn kind(&self) -> PacketType {
+        self.header.packet_type
+    }
+}
+
 enum ConnectionState {
-    New,
-    PubKeySent,
+    ServerNew,
+    ServerSendPubKey,
+    ServerWaitAck(std::time::SystemTime),
+
+    ClientNew,
+    ClientWaitKey(std::time::SystemTime),
+    ClientSendAck,
+
     Encrypted,
-    _Authenticated,
 }
 struct KeyMaterial {
     secret: Option<[u8; 32]>,
@@ -139,12 +177,16 @@ struct KeyMaterial {
 }
 pub struct ConnectionContext {
     state: ConnectionState,
+    is_server: bool,
     local_key: KeyMaterial,
     remote_key: Option<KeyMaterial>,
+    local_msg_id: u16,
+    remote_msg_id: u16,
 }
 impl ConnectionContext {
-    fn new() -> ConnectionContext {
-        let mut rng = thread_rng();
+    fn new(server: bool) -> ConnectionContext {
+        //let mut rng = thread_rng();
+        let mut rng = OsRng::new().unwrap();
         let sec_key = generate_secret(&mut rng);
         let pub_key = generate_public(&sec_key);
         let mut nonce: [u8; 12] = [0; 12];
@@ -156,11 +198,23 @@ impl ConnectionContext {
             session: None,
         };
         ConnectionContext {
-            state: ConnectionState::New,
+            state: match server {
+                true => ConnectionState::ServerNew,
+                false => ConnectionState::ClientNew,
+            },
+            is_server: server,
             local_key: key,
             remote_key: None,
+            local_msg_id: 0u16,
+            remote_msg_id: 0u16,
         }
     }
+    fn reset_state(&mut self) {
+        self.state = match self.is_server {
+            true => ConnectionState::ServerNew,
+            false => ConnectionState::ClientNew,
+        };
+    }
     fn add_remote_key(&mut self, public: &[u8; 32], nonce: &[u8; 12]) {
         let key = KeyMaterial {
             secret: None,
@@ -170,17 +224,6 @@ impl ConnectionContext {
         };
         self.remote_key = Some(key);
         self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
-        self.state = ConnectionState::Encrypted;
-    }
-}
-
-struct NetworkPacket {
-    header: PacketHeader,
-    data: Box<[u8]>,
-}
-impl NetworkPacket {
-    fn kind(&self) -> PacketType {
-        self.header.packet_type
     }
 }
 
@@ -194,21 +237,6 @@ fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]
     Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
 }
 
-pub enum OssuaryError {
-    Io(std::io::Error),
-    Unpack(core::array::TryFromSliceError),
-}
-impl From<std::io::Error> for OssuaryError {
-    fn from(error: std::io::Error) -> Self {
-        OssuaryError::Io(error)
-    }
-}
-impl From<core::array::TryFromSliceError> for OssuaryError {
-    fn from(error: core::array::TryFromSliceError) -> Self {
-        OssuaryError::Unpack(error)
-    }
-}
-
 fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
@@ -228,74 +256,183 @@ where T: std::ops::DerefMut<Target = U>,
     })
 }
 
-//fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
-//where T: std::io::Write {
-fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+fn write_packet<T,U>(stream: &mut T, data: &[u8], msg_id: &mut u16, kind: PacketType) -> Result<(), std::io::Error>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
     let mut buf: Vec<u8> = Vec::with_capacity(::std::mem::size_of::<PacketHeader>());
     buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
-    buf.extend_from_slice(&(msg_id as u16).to_be_bytes());
+    buf.extend_from_slice(&(*msg_id as u16).to_be_bytes());
     buf.extend_from_slice(&(kind as u16).to_be_bytes());
     buf.extend_from_slice(&(0u16).to_be_bytes());
-    stream.write(&buf)?;
-    stream.write(data)?;
+    let _ = stream.write(&buf)?;
+    let _ = stream.write(data)?;
+    *msg_id = *msg_id + 1;
     Ok(())
 }
 
-pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, mut buf: T) -> bool
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
-    match conn.state {
-        ConnectionState::New => {
+    let mut next_msg_id = conn.local_msg_id;
+    let more = match conn.state {
+        ConnectionState::ServerNew => {
+            // wait for client
+            true
+        },
+        ConnectionState::ServerWaitAck(t) => {
+            // TIMEOUT NACK
+            if let Ok(dur) = t.elapsed() {
+                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+                    let pkt: HandshakePacket = Default::default();
+                    let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                         &mut next_msg_id, PacketType::Reset);
+                    conn.state = ConnectionState::ServerNew;
+                }
+            }
+            true
+        },
+        ConnectionState::ServerSendPubKey => {
+            // Send pubkey
+            let mut pkt: HandshakePacket = Default::default();
+            pkt.public_key.copy_from_slice(&conn.local_key.public);
+            pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                 &mut next_msg_id, PacketType::PublicKeyNonce);
+            conn.state = ConnectionState::ServerWaitAck(std::time::SystemTime::now());
+            true
+        },
+        ConnectionState::ClientNew => {
+            // Send pubkey
             let mut pkt: HandshakePacket = Default::default();
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
-            let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
-            conn.state = ConnectionState::PubKeySent;
+            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                 &mut next_msg_id, PacketType::PublicKeyNonce);
+            conn.state = ConnectionState::ClientWaitKey(std::time::SystemTime::now());
             true
         },
-        _ => {
+        ConnectionState::ClientWaitKey(t) => {
+            if let Ok(dur) = t.elapsed() {
+                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+                    conn.state = ConnectionState::ClientNew;
+                }
+            }
+            true
+        },
+        ConnectionState::ClientSendAck => {
+            let pkt: HandshakePacket = Default::default();
+            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                 &mut next_msg_id, PacketType::PubKeyAck);
+            conn.state = ConnectionState::Encrypted;
             false
-        }
-    }
+        },
+        ConnectionState::Encrypted => {
+            false
+        },
+    };
+    conn.local_msg_id = next_msg_id;
+    more
 }
 
-pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T)
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
+    // TODO: read_exact won't work.
+    let pkt = read_packet(buf);
+    if pkt.is_err() {
+        return;
+    }
+    let pkt: NetworkPacket = pkt.unwrap();
+
+    if pkt.header.msg_id != conn.remote_msg_id {
+        println!("Message gap detected.  Restarting connection.");
+        conn.reset_state();
+        return; // TODO: return error
+    }
+    conn.remote_msg_id = pkt.header.msg_id + 1;
+
+    let mut error = false;
+    match pkt.kind() {
+        PacketType::Reset => {
+            conn.state = match conn.is_server {
+                true => ConnectionState::ServerNew,
+                _ => ConnectionState::ClientNew,
+            };
+            return;
+        },
+        _ => {},
+    }
+
     match conn.state {
-        ConnectionState::New => { return true; },
-        ConnectionState::PubKeySent => {},
-        _ => { return false; }
+        ConnectionState::ServerNew => {
+            match pkt.kind() {
+                PacketType::PublicKeyNonce => {
+                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                    conn.state = ConnectionState::ServerSendPubKey;
+                },
+                _ => { error = true; }
+            }
+        },
+        ConnectionState::ServerWaitAck(_t) => {
+            match pkt.kind() {
+                PacketType::PubKeyAck => {
+                    conn.state = ConnectionState::Encrypted;
+                },
+                _ => { error = true; }
+            }
+        },
+        ConnectionState::ServerSendPubKey => {
+            error = true;
+        }, // nop
+        ConnectionState::ClientNew => {
+            error = true;
+        }, // nop
+        ConnectionState::ClientWaitKey(_t) => {
+            match pkt.kind() {
+                PacketType::PublicKeyNonce => {
+                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                    conn.state = ConnectionState::ClientSendAck;
+                },
+                _ => { }
+            }
+        },
+        ConnectionState::ClientSendAck => {
+            error = true;
+        }, // nop
+        ConnectionState::Encrypted => {
+            error = true;
+        }, // nop
     }
-    // TODO: read_exact won't work.
-    if let Ok(pkt) = read_packet(buf) {
-        println!("Packet type: {}", pkt.kind() as u16);
-        match pkt.kind() {
-            PacketType::PublicKeyNonce => {
-                let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
-                conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
-                println!("Session key: {:?}", conn.local_key.session.as_ref().unwrap());
-                conn.state = ConnectionState::Encrypted;
-            },
-            _ => {},
-        }
+    if error {
+        conn.state = match conn.is_server {
+            true => ConnectionState::ServerNew,
+            _ => ConnectionState::ClientNew,
+        };
     }
-    true
 }
 
-pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], out_buf: T) -> Result<(), &'static str>
+pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
+    match conn.state {
+        ConnectionState::Encrypted => true,
+        _ => false,
+    }
+}
+
+pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], mut out_buf: T) -> Result<u16, &'static str>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
     match conn.state {
         ConnectionState::Encrypted => {},
         _ => { return Err("Encrypted channel not established."); }
     }
+    let mut next_msg_id = conn.local_msg_id;
+    let bytes;
     let aad = [];
     let mut ciphertext = Vec::with_capacity(in_buf.len());
-    let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
-    println!("encrypted: {:?} {:?}", ciphertext, tag);
+    let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
+                      &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
 
     let pkt: EncryptedPacket = EncryptedPacket {
         tag_len: tag.len() as u16,
@@ -305,59 +442,88 @@ where T: std::ops::DerefMut<Target = U>,
     buf.extend(struct_as_slice(&pkt));
     buf.extend(&ciphertext);
     buf.extend(&tag);
-    let _ = write_packet(out_buf, &buf, 0, PacketType::EncryptedData);
-    Ok(())
+    let _ = write_packet(&mut out_buf, &buf,
+                         &mut next_msg_id, PacketType::EncryptedData);
+    bytes = buf.len() as u16;
+    conn.local_msg_id = next_msg_id;
+    Ok(bytes)
 }
 
-pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(), &'static str>
+pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<u16, &'static str>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read,
       R: std::ops::DerefMut<Target = V>,
       V: std::io::Write {
+    let mut bytes: u16 = 0u16;
     match conn.state {
         ConnectionState::Encrypted => {},
         _ => { return Err("Encrypted channel not established."); }
     }
-    if let Ok(pkt) = read_packet(in_buf) {
-        println!("Packet type: {}", pkt.kind() as u16);
-        match pkt.kind() {
-            PacketType::EncryptedData => {
-                let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
-                let ciphertext = &rest[..data_pkt.data_len as usize];
-                let tag = &rest[data_pkt.data_len as usize..];
-                let aad = [];
-                let mut plaintext = Vec::with_capacity(ciphertext.len());
-                let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
-                let _ = out_buf.write(&plaintext);
-            },
-            _ => {
-                return Err("Received non-encrypted data on encrypted channel.");
-            },
-        }
+    //if let Ok(pkt) = read_packet(in_buf) {
+    match read_packet(in_buf) {
+        Ok(pkt) => {
+            if pkt.header.msg_id != conn.remote_msg_id {
+                println!("Message gap detected.  Restarting connection.");
+                conn.reset_state();
+                return Ok(0u16); // TODO: return error
+            }
+            conn.remote_msg_id = pkt.header.msg_id + 1;
+
+            match pkt.kind() {
+                PacketType::EncryptedData => {
+                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+                    let ciphertext = &rest[..data_pkt.data_len as usize];
+                    let tag = &rest[data_pkt.data_len as usize..];
+                    let aad = [];
+                    let mut plaintext = Vec::with_capacity(ciphertext.len());
+                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
+                                    &conn.remote_key.as_ref().unwrap().nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext);
+                    let _ = out_buf.write(&plaintext);
+                    bytes = ciphertext.len() as u16;
+                },
+                _ => {
+                    println!("bad packet: {:x}", pkt.kind() as u16);
+                    return Err("Received non-encrypted data on encrypted channel.");
+                },
+            }
+        },
+        Err(_e) => {
+            // TODO
+        },
     }
-    Ok(())
+    Ok(bytes)
 }
 
 #[cfg(test)]
 mod tests {
     use std::thread;
+    use std::time;
     use std::net::{TcpListener, TcpStream};
     use crate::*;
 
     fn event_loop<T>(mut conn: ConnectionContext, mut stream: T, is_server: bool) -> Result<(), std::io::Error>
     where T: std::io::Read + std::io::Write {
-        while crypto_send_handshake(&mut conn, &mut stream) == true {}
-        while crypto_recv_handshake(&mut conn, &mut stream) == true {}
+        while crypto_handshake_done(&conn) == false {
+            if crypto_send_handshake(&mut conn, &mut stream) {
+                crypto_recv_handshake(&mut conn, &mut stream);
+            }
+        }
 
         if is_server {
             let mut plaintext = "hello, world".as_bytes();
             let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+            let _ = stream.flush();
+            loop {
+                std::thread::sleep(time::Duration::from_millis(50));
+            }
         }
 
         loop {
             let mut plaintext = vec!();
             let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
             println!("decrypted: {:?}", String::from_utf8(plaintext));
+            std::thread::sleep(time::Duration::from_millis(50));
         }
     }
 
@@ -366,7 +532,7 @@ mod tests {
         let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
         for stream in listener.incoming() {
             let stream: TcpStream = stream.unwrap();
-            let conn = ConnectionContext::new();
+            let conn = ConnectionContext::new(true);
             let _ = event_loop(conn, stream, true);
         }
         Ok(())
@@ -374,7 +540,7 @@ mod tests {
 
     pub fn client() -> Result<(), std::io::Error> {
         let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-        let conn = ConnectionContext::new();
+        let conn = ConnectionContext::new(false);
         let _ = event_loop(conn, stream, false);
         Ok(())
     }