summary history branches tags files
commit:4af85b1a853304449cb84c86082336ee6db20baa
author:Trevor Bentley
committer:Trevor Bentley
date:Wed Jan 16 23:31:30 2019 +0100
parents:50e96fd9722b6feb4d076191de218be2645fcd81
Add authentication with ed25519 keys
diff --git a/Cargo.toml b/Cargo.toml
line changes: +2/-0
index a3ea4b9..a4fe0fa
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,5 +20,7 @@ overflow-checks = false
 
 [dependencies]
 x25519-dalek = "0.3.0"
+ed25519-dalek = "0.8.1"
 chacha20-poly1305-aead = "0.1.2"
+sha2 = "0.7"
 rand = "0.6.1"

diff --git a/examples/ffi.c b/examples/ffi.c
line changes: +29/-6
index b000fb6..2329940
--- a/examples/ffi.c
+++ b/examples/ffi.c
@@ -4,16 +4,38 @@
 #include <unistd.h>
 #include "ossuary.h"
 
-uint8_t client_buf[256];
-uint8_t server_buf[256];
+uint8_t client_buf[512];
+uint8_t server_buf[512];
+
+uint8_t secret_key[] = {
+  0x10, 0x86, 0x6e, 0xc4, 0x8a, 0x11, 0xf3, 0xc5,
+  0x6d, 0x77, 0xa6, 0x4b, 0x2f, 0x54, 0xaa, 0x06,
+  0x6c, 0x0c, 0xb4, 0x75, 0xd8, 0xc8, 0x7d, 0x35,
+  0xb4, 0x91, 0xee, 0xd6, 0xac, 0x0b, 0xde, 0xbc
+};
+
+uint8_t public_key[32] = {
+  0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+  0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+  0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+  0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f
+};
+
+uint8_t *authorized_keys[] = {
+  public_key,
+};
 
 int main(int argc, char **argv) {
   int client_done, server_done;
-  uint16_t client_bytes, server_bytes, bytes;
+  uint16_t client_bytes, server_bytes, bytes, out_len;
   ConnectionContext *client_conn = NULL;
   ConnectionContext *server_conn = NULL;
-  client_conn = ossuary_create_connection(0);
-  server_conn = ossuary_create_connection(1);
+
+  client_conn = ossuary_create_connection(CONN_TYPE_CLIENT);
+  ossuary_set_secret_key(client_conn, secret_key);
+
+  server_conn = ossuary_create_connection(CONN_TYPE_AUTHENTICATED_SERVER);
+  ossuary_set_authorized_keys(server_conn, authorized_keys, 1);
 
   memset(client_buf, 0, sizeof(client_buf));
   memset(server_buf, 0, sizeof(server_buf));
@@ -60,7 +82,8 @@ int main(int argc, char **argv) {
   printf("server send data bytes: %d\n", bytes);
 
   // Client receives decrypted data
-  bytes = ossuary_recv_data(client_conn, client_buf, bytes, client_buf, sizeof(client_buf));
+  out_len = sizeof(client_buf);
+  bytes = ossuary_recv_data(client_conn, client_buf, bytes, client_buf, &out_len);
   printf("client recv data bytes: %d\n", bytes);
   printf("decrypted: %s\n", client_buf);
 

diff --git a/ffi/ossuary.h b/ffi/ossuary.h
line changes: +10/-2
index 27ce7b7..f8fc02e
--- a/ffi/ossuary.h
+++ b/ffi/ossuary.h
@@ -4,8 +4,16 @@
 
 typedef struct ConnectionContext ConnectionContext;
 
-ConnectionContext *ossuary_create_connection(uint8_t is_server);
+typedef enum {
+  CONN_TYPE_CLIENT = 0x00,
+  CONN_TYPE_AUTHENTICATED_SERVER = 0x01,
+  CONN_TYPE_UNAUTHENTICATED_SERVER = 0x02,
+} connection_type_t;
+
+ConnectionContext *ossuary_create_connection(connection_type_t type);
 int32_t ossuary_destroy_connection(ConnectionContext **conn);
+int32_t ossuary_set_secret_key(ConnectionContext *conn, uint8_t key[32]);
+int32_t ossuary_set_authorized_keys(ConnectionContext *conn, uint8_t key[][32], uint8_t count);
 int32_t ossuary_recv_handshake(ConnectionContext *conn,
                                uint8_t *in_buf, uint16_t *in_buf_len);
 int32_t ossuary_send_handshake(ConnectionContext *conn,
@@ -16,7 +24,7 @@ int32_t ossuary_send_data(ConnectionContext *conn,
                           uint8_t *out_buf, uint16_t out_buf_len);
 int32_t ossuary_recv_data(ConnectionContext *conn,
                   uint8_t *in_buf, uint16_t in_buf_len,
-                  uint8_t *out_buf, uint16_t out_buf_len);
+                  uint8_t *out_buf, uint16_t *out_buf_len);
 
 #define _OSSUARY_H
 #endif

diff --git a/src/clib.rs b/src/clib.rs
line changes: +134/-65
index 1b4168c..a3fb22b
--- a/src/clib.rs
+++ b/src/clib.rs
@@ -1,14 +1,16 @@
 use crate::{crypto_send_data, crypto_recv_data,
             crypto_send_handshake, crypto_recv_handshake, crypto_handshake_done,
-            ConnectionContext};
+            ConnectionContext, ConnectionType};
 
 #[no_mangle]
-pub extern "C" fn ossuary_create_connection(is_server: u8) -> *mut ConnectionContext {
-    let is_server: bool = match is_server {
-        0 => false,
-        _ => true,
+pub extern "C" fn ossuary_create_connection(conn_type: u8) -> *mut ConnectionContext {
+    let conn_type: ConnectionType = match conn_type {
+        0 => ConnectionType::Client,
+        1 => ConnectionType::AuthenticatedServer,
+        2 => ConnectionType::UnauthenticatedServer,
+        _ => { return ::std::ptr::null_mut(); }
     };
-    let mut conn = Box::new(ConnectionContext::new(is_server)); // todo
+    let mut conn = Box::new(ConnectionContext::new(conn_type)); // todo
     let ptr: *mut _ = &mut *conn;
     ::std::mem::forget(conn);
     ptr
@@ -25,6 +27,43 @@ pub extern "C" fn ossuary_destroy_connection(conn: &mut *mut ConnectionContext) 
 }
 
 #[no_mangle]
+pub extern "C" fn ossuary_set_authorized_keys(conn: *mut ConnectionContext, keys: *const *const u8, key_count: u8) -> i32 {
+    if conn.is_null() || keys.is_null() {
+        return -1 as i32;
+    }
+    let conn = unsafe { &mut *conn };
+    let keys: &[*const u8] = unsafe { std::slice::from_raw_parts(keys, key_count as usize) };
+    let mut r_keys: Vec<&[u8]> = Vec::with_capacity(key_count as usize);
+    for key in keys {
+        if !key.is_null() {
+            let key: &[u8] = unsafe { std::slice::from_raw_parts(*key, 32) };
+            r_keys.push(key);
+        }
+    }
+    let written = match conn.set_authorized_keys(r_keys) {
+        Ok(c) => c as i32,
+        Err(_) => -1i32,
+    };
+    ::std::mem::forget(conn);
+    written
+}
+
+#[no_mangle]
+pub extern "C" fn ossuary_set_secret_key(conn: *mut ConnectionContext, key: *const u8) -> i32 {
+    if conn.is_null() || key.is_null() {
+        return -1 as i32;
+    }
+    let conn = unsafe { &mut *conn };
+    let key: &[u8] = unsafe { std::slice::from_raw_parts(key, 32) };
+    let success = match conn.set_secret_key(key) {
+        Ok(_) => 0i32,
+        Err(_) => -1i32,
+    };
+    ::std::mem::forget(conn);
+    success
+}
+
+#[no_mangle]
 pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext,
                                          in_buf: *const u8, in_buf_len: *mut u16) -> i32 {
     if conn.is_null() || in_buf.is_null() || in_buf_len.is_null() {
@@ -36,8 +75,9 @@ pub extern "C" fn ossuary_recv_handshake(conn: *mut ConnectionContext,
     let mut slice = r_in_buf;
     crypto_recv_handshake(&mut conn, &mut slice);
     ::std::mem::forget(conn);
-    unsafe { *in_buf_len = (inlen - slice.len()) as u16 };
-    0i32
+    let written = (inlen - slice.len()) as u16;
+    unsafe { *in_buf_len = written };
+    written as i32 // TODO
 }
 
 #[no_mangle]
@@ -52,6 +92,7 @@ pub extern "C" fn ossuary_send_handshake(conn: *mut ConnectionContext,
     let mut slice = r_out_buf;
     let more = crypto_send_handshake(&mut conn, &mut slice);
     ::std::mem::forget(conn);
+    // TODO: error if data to send is larger than the given buffer
     unsafe { *out_buf_len = (outlen - slice.len()) as u16 };
     more as i32
 }
@@ -79,34 +120,42 @@ pub extern "C" fn ossuary_send_data(conn: *mut ConnectionContext,
     let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
     let mut out_slice = r_out_buf;
     let in_slice = r_in_buf;
+    let bytes_written: u16;
     match crypto_send_data(&mut conn, &in_slice, &mut out_slice) {
+        Ok(x) => {
+            bytes_written = x;
+        }
         Err(_) => { return -1; },
-        _ => {},
     }
     ::std::mem::forget(conn);
-    (out_buf_len - out_slice.len() as u16) as i32
+    //(out_buf_len - out_slice.len() as u16) as i32
+    bytes_written as i32
 }
 
 #[no_mangle]
 pub extern "C" fn ossuary_recv_data(conn: *mut ConnectionContext,
                                     in_buf: *mut u8, in_buf_len: u16,
-                                    out_buf: *mut u8, out_buf_len: u16) -> i32 {
-    if conn.is_null() || in_buf.is_null() || out_buf.is_null() {
+                                    out_buf: *mut u8, out_buf_len: *mut u16) -> i32 {
+    if conn.is_null() || in_buf.is_null() || out_buf.is_null() || out_buf_len.is_null() {
         return -1i32;
     }
     let mut conn = unsafe { &mut *conn };
-    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, out_buf_len as usize) };
+    let r_out_buf: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(out_buf, *out_buf_len as usize) };
     let r_in_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_buf_len as usize) };
     let mut out_slice = r_out_buf;
     let mut in_slice = r_in_buf;
+    let bytes_read: u16;
     match crypto_recv_data(&mut conn, &mut in_slice, &mut out_slice) {
-        Ok(_) => {},
-        Err(e) => {
-            println!("recv_data failed: {} {}", e, conn.is_server);
-            return -1; },
+        Ok((read,written)) => {
+            unsafe { *out_buf_len = written };
+            bytes_read = read;
+        },
+        Err(_) => {
+            return -1;
+        },
     }
     ::std::mem::forget(conn);
-    (out_buf_len - out_slice.len() as u16) as i32
+    bytes_read as i32
 }
 
 #[cfg(test)]
@@ -114,55 +163,66 @@ mod tests {
     use std::thread;
     use std::io::{Read,Write};
     use std::net::{TcpListener, TcpStream};
+    use std::io::BufRead;
     use crate::clib::*;
+
     fn server() -> Result<(), std::io::Error> {
         let listener = TcpListener::bind("127.0.0.1:9989").unwrap();
         for stream in listener.incoming() {
             let mut stream: TcpStream = stream.unwrap();
+            let mut reader = std::io::BufReader::new(stream.try_clone().unwrap());
             let mut conn = ossuary_create_connection(1);
+            let key: &[u8; 32] = &[0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+                                   0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+                                   0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+                                   0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f];
+            let keys: &[*const u8; 1] = &[key as *const u8];
+            ossuary_set_authorized_keys(conn, keys as *const *const u8, keys.len() as u8);
 
-            let out_buf: [u8; 256] = [0; 256];
-            let mut in_buf: [u8; 256] = [0; 256];
+            let out_buf: [u8; 512] = [0; 512];
 
             while ossuary_handshake_done(conn) == 0 {
                 let mut out_len = out_buf.len() as u16;
                 let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
-                let _ = stream.write(&out_buf[0..out_len as usize]);
+                let _ = stream.write_all(&out_buf[0..out_len as usize]).unwrap();
 
                 if more != 0 {
-                    let _ = stream.read(&mut in_buf);
+                    let in_buf = reader.fill_buf().unwrap();
                     let mut in_len = in_buf.len() as u16;
-                    ossuary_recv_handshake(conn, (&in_buf) as *const u8, &mut in_len);
+                    if in_len > 0 {
+                        let len = ossuary_recv_handshake(conn, in_buf as *const [u8] as *const u8, &mut in_len);
+                        reader.consume(len as usize);
+                    }
                 }
             }
 
-            let out_buf: [u8; 256] = [0; 256];
             let mut plaintext: [u8; 256] = [0; 256];
             plaintext[0..13].copy_from_slice("from server 1".as_bytes());
-            ossuary_send_data(
+            let sz = ossuary_send_data(
                 conn,
                 (&plaintext) as *const u8 as *mut u8, 13 as u16,
                 (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-            let _ = stream.write(&out_buf);
+            let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
-            let out_buf: [u8; 256] = [0; 256];
-            let mut plaintext: [u8; 256] = [0; 256];
             plaintext[0..13].copy_from_slice("from server 2".as_bytes());
-            ossuary_send_data(
+            let sz = ossuary_send_data(
                 conn,
                 (&plaintext) as *const u8 as *mut u8, 13 as u16,
                 (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-            let _ = stream.write(&out_buf);
+            let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
-            let _ = stream.read(&mut in_buf);
-            let out_buf: [u8; 256] = [0; 256];
-            let len = ossuary_recv_data(
-                conn,
-                (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16,
-                (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-            if len != -1 {
-                println!("CLIB READ: {:?}",
-                         std::str::from_utf8(&out_buf[0..len as usize]).unwrap());
+            let in_buf = reader.fill_buf().unwrap();
+            if in_buf.len() > 0 {
+                let mut out_len = out_buf.len() as u16;
+                let len = ossuary_recv_data(
+                    conn,
+                    (in_buf) as *const [u8] as *mut u8, in_buf.len() as u16,
+                    (&out_buf) as *const u8 as *mut u8, &mut out_len);
+                if len != -1 {
+                    println!("CLIB READ: {:?}",
+                             std::str::from_utf8(&out_buf[0..out_len as usize]).unwrap());
+                reader.consume(len as usize);
+                }
             }
 
             ossuary_destroy_connection(&mut conn);
@@ -174,14 +234,19 @@ mod tests {
     fn client() -> Result<(), std::io::Error> {
         let mut stream = TcpStream::connect("127.0.0.1:9989").unwrap();
         let mut conn = ossuary_create_connection(0);
+        let key: &[u8; 32] = &[0x10, 0x86, 0x6e, 0xc4, 0x8a, 0x11, 0xf3, 0xc5,
+                               0x6d, 0x77, 0xa6, 0x4b, 0x2f, 0x54, 0xaa, 0x06,
+                               0x6c, 0x0c, 0xb4, 0x75, 0xd8, 0xc8, 0x7d, 0x35,
+                               0xb4, 0x91, 0xee, 0xd6, 0xac, 0x0b, 0xde, 0xbc];
+        ossuary_set_secret_key(conn, key as *const u8);
 
-        let out_buf: [u8; 256] = [0; 256];
-        let mut in_buf: [u8; 256] = [0; 256];
+        let out_buf: [u8; 512] = [0; 512];
+        let mut in_buf: [u8; 512] = [0; 512];
 
         while ossuary_handshake_done(conn) == 0 {
             let mut out_len = out_buf.len() as u16;
             let more = ossuary_send_handshake(conn, (&out_buf) as *const u8 as *mut u8, &mut out_len);
-            let _ = stream.write(&out_buf[0.. out_len as usize]);
+            let _ = stream.write_all(&out_buf[0.. out_len as usize]).unwrap();
 
             if more != 0 {
                 let _ = stream.read(&mut in_buf);
@@ -193,41 +258,45 @@ mod tests {
         let out_buf: [u8; 256] = [0; 256];
         let mut plaintext: [u8; 256] = [0; 256];
         plaintext[0..11].copy_from_slice("from client".as_bytes());
-        ossuary_send_data(
+        let sz = ossuary_send_data(
             conn,
             (&plaintext) as *const u8 as *mut u8, 11 as u16,
             (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-        let _ = stream.write(&out_buf);
+        let _ = stream.write_all(&out_buf[0..sz as usize]).unwrap();
 
-        let _ = stream.read(&mut in_buf);
-        let out_buf: [u8; 256] = [0; 256];
-        let len = ossuary_recv_data(
-            conn,
-            (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16,
-            (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-        if len != -1 {
-            println!("CLIB READ: {:?}",
-                     std::str::from_utf8(&out_buf[0..len as usize]).unwrap());
-        }
-
-        let _ = stream.read(&mut in_buf);
-        let out_buf: [u8; 256] = [0; 256];
-        let len = ossuary_recv_data(
-            conn,
-            (&in_buf) as *const u8 as *mut u8, in_buf.len() as u16,
-            (&out_buf) as *const u8 as *mut u8, out_buf.len() as u16);
-        if len != -1 {
-            println!("CLIB READ: {:?}",
-                     std::str::from_utf8(&out_buf[0..len as usize]).unwrap());
+        let mut stream = std::io::BufReader::new(stream);
+        let mut count = 0;
+        loop {
+            let in_buf = stream.fill_buf().unwrap();
+            if in_buf.len() == 0 || count == 2 {
+                break;
+            }
+            let mut out_len = out_buf.len() as u16;
+            let len = ossuary_recv_data(
+                conn,
+                in_buf as *const [u8] as *mut u8, in_buf.len() as u16,
+                (&out_buf) as *const u8 as *mut u8, &mut out_len);
+            if len == -1 {
+                break;
+            }
+            if len > 0 {
+                println!("CLIB READ: {:?}",
+                         std::str::from_utf8(&out_buf[0..out_len as usize]).unwrap());
+                stream.consume(len as usize);
+                count += 1;
+            }
         }
 
         ossuary_destroy_connection(&mut conn);
         Ok(())
     }
+
     #[test]
     fn test() {
-        thread::spawn(move || { let _ = server(); });
+        let server = thread::spawn(move || { let _ = server(); });
+        std::thread::sleep(std::time::Duration::from_millis(500));
         let child = thread::spawn(move || { let _ = client(); });
         let _ = child.join();
+        let _ = server.join();
     }
 }

diff --git a/src/lib.rs b/src/lib.rs
line changes: +333/-99
index 94fc50c..0bf9d82
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,14 +2,19 @@
 #![feature(try_from)]
 
 extern crate x25519_dalek;
+extern crate ed25519_dalek;
 extern crate rand;
 extern crate chacha20_poly1305_aead;
+extern crate sha2;
 
 use chacha20_poly1305_aead::{encrypt,decrypt};
 use x25519_dalek::generate_secret;
 use x25519_dalek::generate_public;
 use x25519_dalek::diffie_hellman;
 
+use ed25519_dalek::{Signature, Keypair, SecretKey, PublicKey};
+use sha2::Sha512;
+
 //use rand::thread_rng;
 use rand::RngCore;
 use rand::rngs::OsRng;
@@ -19,6 +24,7 @@ use std::convert::TryInto;
 pub mod clib;
 
 const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
+const CHALLENGE_LEN: usize = 256;
 //
 // API:
 //  * sock -- TCP data socket
@@ -55,6 +61,11 @@ const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
 //   * data (encrypted)
 //
 
+// TODO:
+//  - non-blocking IO
+//  - remove all unwraps()
+//  - consider all unexpected packet types to be errors
+
 fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
     unsafe {
         ::std::slice::from_raw_parts(
@@ -63,10 +74,10 @@ fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
         )
     }
 }
-fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
+fn slice_as_struct<T>(p: &[u8]) -> Result<&T, OssuaryError> {
     unsafe {
         if p.len() < ::std::mem::size_of::<T>() {
-            return Err("Cannot cast bytes to struct: size mismatch");
+            return Err(OssuaryError::InvalidStruct);
         }
         Ok(&*(&p[..::std::mem::size_of::<T>()] as *const [u8] as *const T))
     }
@@ -75,6 +86,10 @@ fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
 pub enum OssuaryError {
     Io(std::io::Error),
     Unpack(core::array::TryFromSliceError),
+    KeySize(usize, usize), // (expected, actual)
+    InvalidKey,
+    InvalidPacket(String),
+    InvalidStruct,
 }
 impl std::fmt::Debug for OssuaryError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
@@ -91,6 +106,11 @@ impl From<core::array::TryFromSliceError> for OssuaryError {
         OssuaryError::Unpack(error)
     }
 }
+impl From<ed25519_dalek::SignatureError> for OssuaryError {
+    fn from(_error: ed25519_dalek::SignatureError) -> Self {
+        OssuaryError::InvalidKey
+    }
+}
 
 #[repr(C,packed)]
 struct HandshakePacket {
@@ -114,22 +134,24 @@ impl Default for HandshakePacket {
 #[derive(Clone, Copy)]
 enum PacketType {
     Unknown = 0x00,
-    PublicKeyNonce = 0x01,
-    PubKeyAck = 0x02,
-    AuthRequest = 0x03,
-    Reset = 0x04,
-    Disconnect = 0x05,
-    EncryptedData = 0x10,
+    Disconnect = 0x01,
+    Reset = 0x02,
+    PublicKeyNonce = 0x10,
+    PubKeyAck = 0x11,
+    AuthChallenge = 0x12,
+    AuthResponse = 0x13,
+    EncryptedData = 0x20,
 }
 impl PacketType {
     pub fn from_u16(i: u16) -> PacketType {
         match i {
-            0x01 => PacketType::PublicKeyNonce,
-            0x02 => PacketType::PubKeyAck,
-            0x03 => PacketType::AuthRequest,
-            0x04 => PacketType::Reset,
-            0x05 => PacketType::Disconnect,
-            0x10 => PacketType::EncryptedData,
+            0x01 => PacketType::Disconnect,
+            0x02 => PacketType::Reset,
+            0x10 => PacketType::PublicKeyNonce,
+            0x11 => PacketType::PubKeyAck,
+            0x12 => PacketType::AuthChallenge,
+            0x13 => PacketType::AuthResponse,
+            0x20 => PacketType::EncryptedData,
             _ => PacketType::Unknown,
         }
     }
@@ -163,11 +185,16 @@ enum ConnectionState {
     ServerNew,
     ServerSendPubKey,
     ServerWaitAck(std::time::SystemTime),
+    ServerSendChallenge,
+    ServerWaitAuth(std::time::SystemTime),
 
     ClientNew,
     ClientWaitKey(std::time::SystemTime),
     ClientSendAck,
+    ClientWaitAck(std::time::SystemTime),
+    ClientSendAuth,
 
+    Failed,
     Encrypted,
 }
 struct KeyMaterial {
@@ -176,16 +203,27 @@ struct KeyMaterial {
     session: Option<[u8; 32]>,
     nonce: [u8; 12],
 }
+
+pub enum ConnectionType {
+    Client,
+    AuthenticatedServer,
+    UnauthenticatedServer,
+}
 pub struct ConnectionContext {
     state: ConnectionState,
-    is_server: bool,
+    conn_type: ConnectionType,
     local_key: KeyMaterial,
     remote_key: Option<KeyMaterial>,
     local_msg_id: u16,
     remote_msg_id: u16,
+    challenge: Option<Vec<u8>>,
+    challenge_sig: Option<Vec<u8>>,
+    authorized_keys: Vec<[u8; 32]>,
+    secret_key: Option<SecretKey>,
+    public_key: Option<PublicKey>,
 }
 impl ConnectionContext {
-    fn new(server: bool) -> ConnectionContext {
+    pub fn new(conn_type: ConnectionType) -> ConnectionContext {
         //let mut rng = thread_rng();
         let mut rng = OsRng::new().unwrap();
         let sec_key = generate_secret(&mut rng);
@@ -199,22 +237,37 @@ impl ConnectionContext {
             session: None,
         };
         ConnectionContext {
-            state: match server {
-                true => ConnectionState::ServerNew,
-                false => ConnectionState::ClientNew,
+            state: match conn_type {
+                ConnectionType::Client => ConnectionState::ClientNew,
+                _ => ConnectionState::ServerNew,
             },
-            is_server: server,
+            conn_type: conn_type,
             local_key: key,
             remote_key: None,
             local_msg_id: 0u16,
             remote_msg_id: 0u16,
+            challenge: None,
+            challenge_sig: None,
+            authorized_keys: vec!(),
+            secret_key: None,
+            public_key: None,
         }
     }
     fn reset_state(&mut self) {
-        self.state = match self.is_server {
-            true => ConnectionState::ServerNew,
-            false => ConnectionState::ClientNew,
+        self.state = match self.conn_type {
+            ConnectionType::Client => ConnectionState::ClientNew,
+            _ => ConnectionState::ServerNew,
         };
+        self.local_msg_id = 0;
+        self.challenge = None;
+        self.challenge_sig = None;
+        self.remote_key = None;
+    }
+    fn is_server(&self) -> bool {
+        match self.conn_type {
+            ConnectionType::Client => false,
+            _ => true,
+        }
     }
     fn add_remote_key(&mut self, public: &[u8; 32], nonce: &[u8; 12]) {
         let key = KeyMaterial {
@@ -226,14 +279,46 @@ impl ConnectionContext {
         self.remote_key = Some(key);
         self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
     }
+    pub fn set_authorized_keys<'a,T>(&mut self, keys: T) -> Result<usize, OssuaryError>
+    where T: std::iter::IntoIterator<Item = &'a [u8]> {
+        let mut count: usize = 0;
+        for key in keys {
+            if key.len() != 32 {
+                return Err(OssuaryError::KeySize(32, key.len()));
+            }
+            let mut key_owned = [0u8; 32];
+            key_owned.copy_from_slice(key);
+            self.authorized_keys.push(key_owned);
+            count += 1;
+        }
+        Ok(count)
+    }
+    pub fn set_secret_key(&mut self, key: &[u8]) -> Result<(), OssuaryError> {
+        if key.len() != 32 {
+            return Err(OssuaryError::KeySize(32, key.len()));
+        }
+        let secret = SecretKey::from_bytes(key)?;
+        let public = PublicKey::from_secret::<Sha512>(&secret);
+        self.secret_key = Some(secret);
+        self.public_key = Some(public);
+        Ok(())
+    }
+    pub fn public_key(&self) -> Result<&[u8], OssuaryError> {
+        match self.public_key {
+            None => Err(OssuaryError::InvalidKey),
+            Some(ref p) => {
+                Ok(p.as_bytes())
+            }
+        }
+    }
 }
 
-fn interpret_packet<'a, T>(pkt: &'a NetworkPacket) -> Result<&'a T, &'static str> {
+fn interpret_packet<'a, T>(pkt: &'a NetworkPacket) -> Result<&'a T, OssuaryError> {
     let s: &T = slice_as_struct(&pkt.data)?;
     Ok(s)
 }
 
-fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]), &'static str> {
+fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]), OssuaryError> {
     let s: &T = slice_as_struct(&pkt.data)?;
     Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
 }
@@ -280,7 +365,8 @@ where T: std::ops::DerefMut<Target = U>,
             // wait for client
             true
         },
-        ConnectionState::ServerWaitAck(t) => {
+        ConnectionState::ServerWaitAck(t) |
+        ConnectionState::ServerWaitAuth(t) => {
             // TIMEOUT NACK
             if let Ok(dur) = t.elapsed() {
                 if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
@@ -302,6 +388,42 @@ where T: std::ops::DerefMut<Target = U>,
             conn.state = ConnectionState::ServerWaitAck(std::time::SystemTime::now());
             true
         },
+        ConnectionState::ServerSendChallenge => {
+            match conn.conn_type {
+                ConnectionType::AuthenticatedServer => {
+                    let aad = [];
+                    let mut challenge: [u8; CHALLENGE_LEN] = [0; CHALLENGE_LEN];
+                    let mut rng = OsRng::new().unwrap();
+                    rng.fill_bytes(&mut challenge);
+                    conn.challenge = Some(challenge.to_vec());
+                    let mut ciphertext = Vec::with_capacity(CHALLENGE_LEN);
+                    let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
+                                      &conn.local_key.nonce,
+                                      &aad, &challenge, &mut ciphertext).unwrap();
+
+                    let pkt: EncryptedPacket = EncryptedPacket {
+                        tag_len: tag.len() as u16,
+                        data_len: ciphertext.len() as u16,
+                    };
+                    let mut pkt_buf: Vec<u8>= vec![];
+                    pkt_buf.extend(struct_as_slice(&pkt));
+                    pkt_buf.extend(&ciphertext);
+                    pkt_buf.extend(&tag);
+                    let _ = write_packet(&mut buf, &pkt_buf,
+                                         &mut next_msg_id, PacketType::AuthChallenge);
+                    conn.state = ConnectionState::ServerWaitAuth(std::time::SystemTime::now());
+                    true
+                },
+                _ => {
+                    // Unauthenticated
+                    let pkt: HandshakePacket = Default::default();
+                    let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                         &mut next_msg_id, PacketType::PubKeyAck);
+                    conn.state = ConnectionState::Encrypted;
+                    false
+                },
+            }
+        },
         ConnectionState::ClientNew => {
             // Send pubkey
             let mut pkt: HandshakePacket = Default::default();
@@ -315,7 +437,7 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ClientWaitKey(t) => {
             if let Ok(dur) = t.elapsed() {
                 if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
-                    conn.state = ConnectionState::ClientNew;
+                    conn.reset_state();
                 }
             }
             true
@@ -324,14 +446,67 @@ where T: std::ops::DerefMut<Target = U>,
             let pkt: HandshakePacket = Default::default();
             let _ = write_packet(&mut buf, struct_as_slice(&pkt),
                                  &mut next_msg_id, PacketType::PubKeyAck);
+            conn.state = ConnectionState::ClientWaitAck(std::time::SystemTime::now());
+            true
+        },
+        ConnectionState::ClientWaitAck(t) => {
+            if let Ok(dur) = t.elapsed() {
+                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+                    conn.reset_state();
+                }
+            }
+            true
+        },
+        ConnectionState::ClientSendAuth => {
+            // TODO: import secret key
+            if conn.secret_key.is_none() {
+                conn.reset_state();
+                // TODO: raise error
+                return true;
+            }
+            let secret = conn.secret_key.as_ref()
+                .map(|sec| SecretKey::from_bytes(sec.as_bytes()).unwrap())
+                .unwrap();
+            let public = PublicKey::from_secret::<Sha512>(&secret);
+            let keypair = Keypair { secret: secret, public: public };
+            let sig = keypair.sign::<Sha512>(&conn.challenge.as_ref().unwrap()).to_bytes();
+            let mut pkt_data: Vec<u8> = Vec::with_capacity(CHALLENGE_LEN + 32);
+            pkt_data.extend_from_slice(public.as_bytes());
+            pkt_data.extend_from_slice(&sig);
+            conn.challenge_sig = Some(sig.to_vec());
+
+            let aad = [];
+            let mut ciphertext = Vec::with_capacity(pkt_data.len());
+            let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
+                              &conn.local_key.nonce,
+                              &aad, &pkt_data, &mut ciphertext).unwrap();
+
+            let pkt: EncryptedPacket = EncryptedPacket {
+                tag_len: tag.len() as u16,
+                data_len: ciphertext.len() as u16,
+            };
+            let mut pkt_buf: Vec<u8>= vec![];
+            pkt_buf.extend(struct_as_slice(&pkt));
+            pkt_buf.extend(&ciphertext);
+            pkt_buf.extend(&tag);
+            let _ = write_packet(&mut buf, &pkt_buf,
+                                 &mut next_msg_id, PacketType::AuthResponse);
             conn.state = ConnectionState::Encrypted;
             false
         },
+        ConnectionState::Failed => {
+            let pkt: HandshakePacket = Default::default();
+            let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+                                 &mut next_msg_id, PacketType::Disconnect);
+            conn.reset_state();
+            true
+        },
         ConnectionState::Encrypted => {
             false
         },
     };
     conn.local_msg_id = next_msg_id;
+    // TODO: either this should return amount write, or send_data() should not
     more
 }
 
@@ -347,6 +522,7 @@ where T: std::ops::DerefMut<Target = U>,
 
     if pkt.header.msg_id != conn.remote_msg_id {
         println!("Message gap detected.  Restarting connection.");
+        println!("Server: {}", conn.is_server());
         conn.reset_state();
         return; // TODO: return error
     }
@@ -355,12 +531,13 @@ where T: std::ops::DerefMut<Target = U>,
     let mut error = false;
     match pkt.kind() {
         PacketType::Reset => {
-            conn.state = match conn.is_server {
-                true => ConnectionState::ServerNew,
-                _ => ConnectionState::ClientNew,
-            };
+            conn.reset_state();
             return;
         },
+        PacketType::Disconnect => {
+            // TODO: handle error
+            panic!("Remote side terminated connection.");
+        },
         _ => {},
     }
 
@@ -378,7 +555,45 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ServerWaitAck(_t) => {
             match pkt.kind() {
                 PacketType::PubKeyAck => {
-                    conn.state = ConnectionState::Encrypted;
+                    conn.state = ConnectionState::ServerSendChallenge;
+                },
+                _ => { error = true; }
+            }
+        },
+        ConnectionState::ServerWaitAuth(_t) => {
+            // TODO (auth)
+            match pkt.kind() {
+                PacketType::AuthResponse => {
+                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+                    let ciphertext = &rest[..data_pkt.data_len as usize];
+                    let tag = &rest[data_pkt.data_len as usize..];
+                    let aad = [];
+                    let mut plaintext = Vec::with_capacity(ciphertext.len());
+                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
+                                    &conn.remote_key.as_ref().unwrap().nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext);
+                    let pubkey = &plaintext[0..32];
+                    let sig = &plaintext[32..];
+
+                    if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
+                        let public = PublicKey::from_bytes(pubkey).unwrap();
+                        let sig = Signature::from_bytes(sig).unwrap();
+                        match public.verify::<Sha512>(conn.challenge.as_ref().unwrap(), &sig) {
+                            Ok(_) => {
+                                conn.state = ConnectionState::Encrypted;
+                            },
+                            Err(_) => {
+                                println!("Verify bad");
+                                // TODO: error
+                                conn.state = ConnectionState::Failed;
+                            },
+                        }
+                    }
+                    else {
+                        println!("Key not allowed");
+                        // TODO: error
+                        conn.state = ConnectionState::Failed;
+                    }
                 },
                 _ => { error = true; }
             }
@@ -386,6 +601,9 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ServerSendPubKey => {
             error = true;
         }, // nop
+        ConnectionState::ServerSendChallenge => {
+            error = true;
+        }, // nop
         ConnectionState::ClientNew => {
             error = true;
         }, // nop
@@ -402,18 +620,42 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ClientSendAck => {
             error = true;
         }, // nop
+        ConnectionState::ClientWaitAck(_t) => {
+            match pkt.kind() {
+                PacketType::PubKeyAck => {
+                    conn.state = ConnectionState::Encrypted;
+                },
+                PacketType::AuthChallenge => {
+                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+                    let ciphertext = &rest[..data_pkt.data_len as usize];
+                    let tag = &rest[data_pkt.data_len as usize..];
+                    let aad = [];
+                    let mut plaintext = Vec::with_capacity(ciphertext.len());
+                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
+                                    &conn.remote_key.as_ref().unwrap().nonce,
+                                    &aad, &ciphertext, &tag, &mut plaintext);
+                    conn.challenge = Some(plaintext);
+                    conn.state = ConnectionState::ClientSendAuth;
+                },
+                _ => {},
+            }
+        },
+        ConnectionState::ClientSendAuth => {
+            error = true;
+        }, // nop
+        ConnectionState::Failed => {
+            error = true;
+        }, // nop
         ConnectionState::Encrypted => {
             error = true;
         }, // nop
     }
     if error {
-        conn.state = match conn.is_server {
-            true => ConnectionState::ServerNew,
-            _ => ConnectionState::ClientNew,
-        };
+        conn.reset_state();
     }
 }
 
+// TODO: should return a Result with error on forced-disconnect or permanent failure
 pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
     match conn.state {
         ConnectionState::Encrypted => true,
@@ -421,12 +663,14 @@ pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
     }
 }
 
-pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], mut out_buf: T) -> Result<u16, &'static str>
+pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], mut out_buf: T) -> Result<u16, OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Write {
     match conn.state {
         ConnectionState::Encrypted => {},
-        _ => { return Err("Encrypted channel not established."); }
+        _ => {
+            return Err(OssuaryError::InvalidPacket("Encrypted channel not established.".into()));
+        }
     }
     let mut next_msg_id = conn.local_msg_id;
     let bytes;
@@ -445,28 +689,32 @@ where T: std::ops::DerefMut<Target = U>,
     buf.extend(&tag);
     let _ = write_packet(&mut out_buf, &buf,
                          &mut next_msg_id, PacketType::EncryptedData);
-    bytes = buf.len() as u16;
+    bytes = (buf.len() + ::std::mem::size_of::<PacketHeader>()) as u16;
     conn.local_msg_id = next_msg_id;
     Ok(bytes)
 }
 
-pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<u16, &'static str>
+pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(u16, u16), OssuaryError>
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read,
       R: std::ops::DerefMut<Target = V>,
       V: std::io::Write {
-    let mut bytes: u16 = 0u16;
+    let bytes_written: u16;
+    let bytes_read: u16;
     match conn.state {
         ConnectionState::Encrypted => {},
-        _ => { return Err("Encrypted channel not established."); }
+        _ => {
+            return Err(OssuaryError::InvalidPacket("Encrypted channel not established.".into()));
+        }
     }
-    //if let Ok(pkt) = read_packet(in_buf) {
+
     match read_packet(in_buf) {
         Ok(pkt) => {
             if pkt.header.msg_id != conn.remote_msg_id {
                 println!("Message gap detected.  Restarting connection.");
+                println!("Server: {}", conn.is_server());
                 conn.reset_state();
-                return Ok(0u16); // TODO: return error
+                return Err(OssuaryError::InvalidPacket("Message ID mismatch".into()))
             }
             conn.remote_msg_id = pkt.header.msg_id + 1;
 
@@ -481,19 +729,22 @@ where T: std::ops::DerefMut<Target = U>,
                                     &conn.remote_key.as_ref().unwrap().nonce,
                                     &aad, &ciphertext, &tag, &mut plaintext);
                     let _ = out_buf.write(&plaintext);
-                    bytes = ciphertext.len() as u16;
+                    bytes_written = ciphertext.len() as u16;
+                    bytes_read = (ciphertext.len() +
+                                  ::std::mem::size_of::<PacketHeader>() +
+                                  ::std::mem::size_of::<EncryptedPacket>() +
+                                  tag.len()) as u16;
                 },
                 _ => {
-                    println!("bad packet: {:x}", pkt.kind() as u16);
-                    return Err("Received non-encrypted data on encrypted channel.");
+                    return Err(OssuaryError::InvalidPacket("Received non-encrypted data on encrypted channel.".into()));
                 },
             }
         },
         Err(_e) => {
-            // TODO
+            return Err(OssuaryError::InvalidPacket("Packet header did not parse.".into()));
         },
     }
-    Ok(bytes)
+    Ok((bytes_read, bytes_written))
 }
 
 #[cfg(test)]
@@ -504,54 +755,36 @@ mod tests {
     use std::net::{TcpListener, TcpStream};
     use crate::*;
 
-    fn event_loop<T>(mut conn: ConnectionContext,
-                     mut stream: T,
-                     is_server: bool) -> Result<(), std::io::Error>
-    where T: std::io::Read + std::io::Write {
-        while crypto_handshake_done(&conn) == false {
-            if crypto_send_handshake(&mut conn, &mut stream) {
-                crypto_recv_handshake(&mut conn, &mut stream);
-            }
-        }
-
-        if is_server {
-            let mut plaintext = "message from server".as_bytes();
-            let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
-        }
-        else {
-            let mut plaintext = "message from client".as_bytes();
-            let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
-        }
-
-        let mut plaintext = vec!();
-        let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
-        println!("LIB READ: {:?}", String::from_utf8(plaintext).unwrap());
-        Ok(())
-    }
-
-
-    fn server() -> Result<(), std::io::Error> {
-        let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
-        for stream in listener.incoming() {
-            let stream: TcpStream = stream.unwrap();
-            let conn = ConnectionContext::new(true);
-            let _ = event_loop(conn, stream, true);
-        }
-        Ok(())
-    }
-
-    fn client() -> Result<(), std::io::Error> {
-        let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-        let conn = ConnectionContext::new(false);
-        let _ = event_loop(conn, stream, false);
-        Ok(())
-    }
-
     #[test]
-    fn test() {
-        thread::spawn(move || { let _ = server(); });
-        let child = thread::spawn(move || { let _ = client(); });
-        let _ = child.join();
+    fn test_set_authorized_keys() {
+        let mut conn = ConnectionContext::new(ConnectionType::AuthenticatedServer);
+
+        // Vec of slices
+        let keys: Vec<&[u8]> = vec![
+            &[0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+              0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+              0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+              0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f]
+        ];
+        let _ = conn.set_authorized_keys(keys).unwrap();
+
+        // Vec of owned arrays
+        let keys: Vec<[u8; 32]> = vec![
+            [0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+             0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+             0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+             0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f]
+        ];
+        let _ = conn.set_authorized_keys(keys.iter().map(|x| x.as_ref()).collect::<Vec<&[u8]>>()).unwrap();
+
+        // Vec of vecs
+        let keys: Vec<Vec<u8>> = vec![
+            vec![0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+                 0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+                 0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+                 0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f]
+        ];
+        let _ = conn.set_authorized_keys(keys.iter().map(|x| x.as_slice())).unwrap();
     }
 
     #[bench]
@@ -559,7 +792,7 @@ mod tests {
         let server_thread = thread::spawn(move || {
             let listener = TcpListener::bind("127.0.0.1:9987").unwrap();
             let mut server_stream = listener.incoming().next().unwrap().unwrap();
-            let mut server_conn = ConnectionContext::new(true);
+            let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
             while crypto_handshake_done(&server_conn) == false {
                 if crypto_send_handshake(&mut server_conn, &mut server_stream) {
                     crypto_recv_handshake(&mut server_conn, &mut server_stream);
@@ -571,7 +804,7 @@ mod tests {
             loop {
                 bytes += crypto_recv_data(&mut server_conn,
                                           &mut server_stream,
-                                          &mut plaintext).unwrap() as u64;
+                                          &mut plaintext).unwrap().0 as u64;
                 if plaintext == [0xde, 0xde, 0xbe, 0xbe] {
                     if let Ok(dur) = start.elapsed() {
                         let t = dur.as_secs() as f64
@@ -585,8 +818,9 @@ mod tests {
             }
         });
 
+        std::thread::sleep(std::time::Duration::from_millis(500));
         let mut client_stream = TcpStream::connect("127.0.0.1:9987").unwrap();
-        let mut client_conn = ConnectionContext::new(false);
+        let mut client_conn = ConnectionContext::new(ConnectionType::Client);
         while crypto_handshake_done(&client_conn) == false {
             if crypto_send_handshake(&mut client_conn, &mut client_stream) {
                 crypto_recv_handshake(&mut client_conn, &mut client_stream);

diff --git a/tests/basic.rs b/tests/basic.rs
line changes: +56/-0
index 0000000..18c23da
--- /dev/null
+++ b/tests/basic.rs
@@ -0,0 +1,56 @@
+use ossuary::{ConnectionContext, ConnectionType};
+use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
+use ossuary::{crypto_send_data,crypto_recv_data};
+
+use std::thread;
+use std::net::{TcpListener, TcpStream};
+
+fn event_loop<T>(mut conn: ConnectionContext,
+                 mut stream: T,
+                 is_server: bool) -> Result<(), std::io::Error>
+where T: std::io::Read + std::io::Write {
+    // Run the opaque handshake until the connection is established
+    while crypto_handshake_done(&conn) == false {
+        if crypto_send_handshake(&mut conn, &mut stream) {
+            crypto_recv_handshake(&mut conn, &mut stream);
+        }
+    }
+
+    // Send a message to the other party
+    let mut plaintext = match is_server {
+        true => "message from server".as_bytes(),
+        false => "message from client".as_bytes(),
+    };
+    let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+
+    // Read a message from the other party
+    let mut plaintext = vec!();
+    let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
+    println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
+
+    Ok(())
+}
+
+fn server() -> Result<(), std::io::Error> {
+    let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
+    let stream: TcpStream = listener.incoming().next().unwrap().unwrap();
+    let conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
+    let _ = event_loop(conn, stream, true);
+    Ok(())
+}
+
+fn client() -> Result<(), std::io::Error> {
+    let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
+    let conn = ConnectionContext::new(ConnectionType::Client);
+    let _ = event_loop(conn, stream, false);
+    Ok(())
+}
+
+#[test]
+fn basic() {
+    let server = thread::spawn(move || { let _ = server(); });
+    std::thread::sleep(std::time::Duration::from_millis(500));
+    let child = thread::spawn(move || { let _ = client(); });
+    let _ = child.join();
+    let _ = server.join();
+}

diff --git a/tests/basic_auth.rs b/tests/basic_auth.rs
line changes: +68/-0
index 0000000..8139a71
--- /dev/null
+++ b/tests/basic_auth.rs
@@ -0,0 +1,68 @@
+use ossuary::{ConnectionContext, ConnectionType};
+use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
+use ossuary::{crypto_send_data,crypto_recv_data};
+
+use std::thread;
+use std::net::{TcpListener, TcpStream};
+
+fn event_loop<T>(mut conn: ConnectionContext,
+                 mut stream: T,
+                 is_server: bool) -> Result<(), std::io::Error>
+where T: std::io::Read + std::io::Write {
+    // Run the opaque handshake until the connection is established
+    while crypto_handshake_done(&conn) == false {
+        if crypto_send_handshake(&mut conn, &mut stream) {
+            crypto_recv_handshake(&mut conn, &mut stream);
+        }
+    }
+
+    // Send a message to the other party
+    let mut plaintext = match is_server {
+        true => "message from server".as_bytes(),
+        false => "message from client".as_bytes(),
+    };
+    let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+
+    // Read a message from the other party
+    let mut plaintext = vec!();
+    let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
+    println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
+
+    Ok(())
+}
+
+fn server() -> Result<(), std::io::Error> {
+    let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
+    let stream: TcpStream = listener.incoming().next().unwrap().unwrap();
+    let mut conn = ConnectionContext::new(ConnectionType::AuthenticatedServer);
+    let keys: Vec<&[u8]> = vec![
+        &[0xbe, 0x1c, 0xa0, 0x74, 0xf4, 0xa5, 0x8b, 0xbb,
+          0xd2, 0x62, 0xa7, 0xf9, 0x52, 0x3b, 0x6f, 0xb0,
+          0xbb, 0x9e, 0x86, 0x62, 0x28, 0x7c, 0x33, 0x89,
+          0xa2, 0xe1, 0x63, 0xdc, 0x55, 0xde, 0x28, 0x1f]
+    ];
+    let _ = conn.set_authorized_keys(keys).unwrap();
+    let _ = event_loop(conn, stream, true);
+    Ok(())
+}
+
+fn client() -> Result<(), std::io::Error> {
+    let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
+    let mut conn = ConnectionContext::new(ConnectionType::Client);
+    let _ = conn.set_secret_key(
+        &[0x10, 0x86, 0x6e, 0xc4, 0x8a, 0x11, 0xf3, 0xc5,
+          0x6d, 0x77, 0xa6, 0x4b, 0x2f, 0x54, 0xaa, 0x06,
+          0x6c, 0x0c, 0xb4, 0x75, 0xd8, 0xc8, 0x7d, 0x35,
+          0xb4, 0x91, 0xee, 0xd6, 0xac, 0x0b, 0xde, 0xbc]).unwrap();
+    let _ = event_loop(conn, stream, false);
+    Ok(())
+}
+
+#[test]
+fn basic() {
+    let server = thread::spawn(move || { let _ = server(); });
+    std::thread::sleep(std::time::Duration::from_millis(500));
+    let child = thread::spawn(move || { let _ = client(); });
+    let _ = child.join();
+    let _ = server.join();
+}