summary history branches tags files
commit:02a7478e626c8d381e4ccecf8b2848220c5e336d
author:Trevor Bentley
committer:Trevor Bentley
date:Thu Jan 17 12:48:18 2019 +0100
parents:4af85b1a853304449cb84c86082336ee6db20baa
Improve error handling, remove all unwraps
diff --git a/ffi/ossuary.h b/ffi/ossuary.h
line changes: +2/-2
index f8fc02e..b786fba
--- a/ffi/ossuary.h
+++ b/ffi/ossuary.h
@@ -13,12 +13,12 @@ typedef enum {
 ConnectionContext *ossuary_create_connection(connection_type_t type);
 int32_t ossuary_destroy_connection(ConnectionContext **conn);
 int32_t ossuary_set_secret_key(ConnectionContext *conn, uint8_t key[32]);
-int32_t ossuary_set_authorized_keys(ConnectionContext *conn, uint8_t key[][32], uint8_t count);
+int32_t ossuary_set_authorized_keys(ConnectionContext *conn, uint8_t *key[], uint8_t count);
 int32_t ossuary_recv_handshake(ConnectionContext *conn,
                                uint8_t *in_buf, uint16_t *in_buf_len);
 int32_t ossuary_send_handshake(ConnectionContext *conn,
                                uint8_t *out_buf, uint16_t *out_buf_len);
-uint8_t ossuary_handshake_done(ConnectionContext *conn);
+int32_t ossuary_handshake_done(ConnectionContext *conn);
 int32_t ossuary_send_data(ConnectionContext *conn,
                           uint8_t *in_buf, uint16_t in_buf_len,
                           uint8_t *out_buf, uint16_t out_buf_len);

diff --git a/src/clib.rs b/src/clib.rs
line changes: +6/-3
index a3fb22b..d7624c8
--- a/src/clib.rs
+++ b/src/clib.rs
@@ -98,14 +98,17 @@ pub extern "C" fn ossuary_send_handshake(conn: *mut ConnectionContext,
 }
 
 #[no_mangle]
-pub extern "C" fn ossuary_handshake_done(conn: *const ConnectionContext) -> u8 {
+pub extern "C" fn ossuary_handshake_done(conn: *const ConnectionContext) -> i32 {
     if conn.is_null() {
-        return 0u8;
+        return -1i32;
     }
     let conn = unsafe { &*conn };
     let done = crypto_handshake_done(&conn);
     ::std::mem::forget(conn);
-    done as u8
+    match done {
+        Ok(done) => done as i32,
+        Err(_) => -1i32,
+    }
 }
 
 #[no_mangle]

diff --git a/src/lib.rs b/src/lib.rs
line changes: +303/-138
index 0bf9d82..8452917
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,8 +23,12 @@ use std::convert::TryInto;
 
 pub mod clib;
 
-const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
+// Maximum time to wait (in seconds) for a handshake response
+const MAX_HANDSHAKE_WAIT_TIME: u64 = 3u64;
+
+// Size of the random data to be signed by client
 const CHALLENGE_LEN: usize = 256;
+
 //
 // API:
 //  * sock -- TCP data socket
@@ -65,6 +69,7 @@ const CHALLENGE_LEN: usize = 256;
 //  - non-blocking IO
 //  - remove all unwraps()
 //  - consider all unexpected packet types to be errors
+//  - limit connection retries
 
 fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
     unsafe {
@@ -90,6 +95,8 @@ pub enum OssuaryError {
     InvalidKey,
     InvalidPacket(String),
     InvalidStruct,
+    InvalidSignature,
+    ConnectionFailed,
 }
 impl std::fmt::Debug for OssuaryError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
@@ -194,7 +201,7 @@ enum ConnectionState {
     ClientWaitAck(std::time::SystemTime),
     ClientSendAuth,
 
-    Failed,
+    Failed(OssuaryError),
     Encrypted,
 }
 struct KeyMaterial {
@@ -225,7 +232,7 @@ pub struct ConnectionContext {
 impl ConnectionContext {
     pub fn new(conn_type: ConnectionType) -> ConnectionContext {
         //let mut rng = thread_rng();
-        let mut rng = OsRng::new().unwrap();
+        let mut rng = OsRng::new().expect("RNG not available.");
         let sec_key = generate_secret(&mut rng);
         let pub_key = generate_public(&sec_key);
         let mut nonce: [u8; 12] = [0; 12];
@@ -253,10 +260,17 @@ impl ConnectionContext {
             public_key: None,
         }
     }
-    fn reset_state(&mut self) {
-        self.state = match self.conn_type {
-            ConnectionType::Client => ConnectionState::ClientNew,
-            _ => ConnectionState::ServerNew,
+    fn reset_state(&mut self, permanent_err: Option<OssuaryError>) {
+        self.state = match permanent_err {
+            None => {
+                match self.conn_type {
+                    ConnectionType::Client => ConnectionState::ClientNew,
+                    _ => ConnectionState::ServerNew,
+                }
+            },
+            Some(e) => {
+                ConnectionState::Failed(e)
+            }
         };
         self.local_msg_id = 0;
         self.challenge = None;
@@ -277,7 +291,9 @@ impl ConnectionContext {
             session: None,
         };
         self.remote_key = Some(key);
-        self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
+        if let Some(secret) = self.local_key.secret.as_ref() {
+            self.local_key.session = Some(diffie_hellman(secret, public));
+        }
     }
     pub fn set_authorized_keys<'a,T>(&mut self, keys: T) -> Result<usize, OssuaryError>
     where T: std::iter::IntoIterator<Item = &'a [u8]> {
@@ -362,24 +378,30 @@ where T: std::ops::DerefMut<Target = U>,
     let mut next_msg_id = conn.local_msg_id;
     let more = match conn.state {
         ConnectionState::ServerNew => {
-            // wait for client
+            // Wait for client to initiate connection
             true
         },
+        ConnectionState::Encrypted => {
+            // Handshake finished
+            false
+        },
         ConnectionState::ServerWaitAck(t) |
-        ConnectionState::ServerWaitAuth(t) => {
-            // TIMEOUT NACK
+        ConnectionState::ServerWaitAuth(t) |
+        ConnectionState::ClientWaitKey(t) |
+        ConnectionState::ClientWaitAck(t) => {
+            // Wait for response, with timeout
             if let Ok(dur) = t.elapsed() {
-                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+                if dur.as_secs() > MAX_HANDSHAKE_WAIT_TIME {
                     let pkt: HandshakePacket = Default::default();
                     let _ = write_packet(&mut buf, struct_as_slice(&pkt),
                                          &mut next_msg_id, PacketType::Reset);
-                    conn.state = ConnectionState::ServerNew;
+                    conn.reset_state(None);
                 }
             }
             true
         },
         ConnectionState::ServerSendPubKey => {
-            // Send pubkey
+            // Send session public key and nonce to the client
             let mut pkt: HandshakePacket = Default::default();
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
@@ -391,16 +413,37 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ServerSendChallenge => {
             match conn.conn_type {
                 ConnectionType::AuthenticatedServer => {
+                    // Send a block of random data over the encrypted session to
+                    // the client.  The client must sign it with its key to prove
+                    // key possession.
+                    let mut rng = match OsRng::new() {
+                        Ok(rng) => rng,
+                        Err(_) => {
+                            conn.reset_state(None);
+                            return true;
+                        }
+                    };
                     let aad = [];
                     let mut challenge: [u8; CHALLENGE_LEN] = [0; CHALLENGE_LEN];
-                    let mut rng = OsRng::new().unwrap();
                     rng.fill_bytes(&mut challenge);
                     conn.challenge = Some(challenge.to_vec());
                     let mut ciphertext = Vec::with_capacity(CHALLENGE_LEN);
-                    let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
-                                      &conn.local_key.nonce,
-                                      &aad, &challenge, &mut ciphertext).unwrap();
-
+                    let session_key = match conn.local_key.session {
+                        Some(ref s) => s,
+                        None => {
+                            conn.reset_state(None);
+                            return true;
+                        }
+                    };
+                    let tag = match encrypt(session_key,
+                                            &conn.local_key.nonce,
+                                            &aad, &challenge, &mut ciphertext) {
+                        Ok(tag) => tag,
+                        Err(_) => {
+                            conn.reset_state(None);
+                            return true;
+                        }
+                    };
                     let pkt: EncryptedPacket = EncryptedPacket {
                         tag_len: tag.len() as u16,
                         data_len: ciphertext.len() as u16,
@@ -415,17 +458,17 @@ where T: std::ops::DerefMut<Target = U>,
                     true
                 },
                 _ => {
-                    // Unauthenticated
+                    // For unauthenticated connections, we are done.  Already encrypted.
                     let pkt: HandshakePacket = Default::default();
                     let _ = write_packet(&mut buf, struct_as_slice(&pkt),
                                          &mut next_msg_id, PacketType::PubKeyAck);
                     conn.state = ConnectionState::Encrypted;
-                    false
+                    false // handshake is finished (success)
                 },
             }
         },
         ConnectionState::ClientNew => {
-            // Send pubkey
+            // Send session public key and nonce to initiate connection
             let mut pkt: HandshakePacket = Default::default();
             pkt.public_key.copy_from_slice(&conn.local_key.public);
             pkt.nonce.copy_from_slice(&conn.local_key.nonce);
@@ -434,42 +477,40 @@ where T: std::ops::DerefMut<Target = U>,
             conn.state = ConnectionState::ClientWaitKey(std::time::SystemTime::now());
             true
         },
-        ConnectionState::ClientWaitKey(t) => {
-            if let Ok(dur) = t.elapsed() {
-                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
-                    conn.reset_state();
-                }
-            }
-            true
-        },
         ConnectionState::ClientSendAck => {
+            // Acknowledge reception of server's session public key and nonce
             let pkt: HandshakePacket = Default::default();
             let _ = write_packet(&mut buf, struct_as_slice(&pkt),
                                  &mut next_msg_id, PacketType::PubKeyAck);
             conn.state = ConnectionState::ClientWaitAck(std::time::SystemTime::now());
             true
         },
-        ConnectionState::ClientWaitAck(t) => {
-            if let Ok(dur) = t.elapsed() {
-                if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
-                    conn.reset_state();
-                }
-            }
-            true
-        },
         ConnectionState::ClientSendAuth => {
-            // TODO: import secret key
-            if conn.secret_key.is_none() {
-                conn.reset_state();
-                // TODO: raise error
-                return true;
-            }
-            let secret = conn.secret_key.as_ref()
-                .map(|sec| SecretKey::from_bytes(sec.as_bytes()).unwrap())
-                .unwrap();
+            // Send signature of the server's challenge back to the server,
+            // along with the public part of the authentication key.  This is
+            // done over the established encrypted channel.
+            let secret = match conn.secret_key {
+                Some(ref s) => match SecretKey::from_bytes(s.as_bytes()) {
+                    Ok(s) => s, // local copy of secret key
+                    Err(_) => {
+                        conn.reset_state(Some(OssuaryError::InvalidKey));
+                        return true;
+                    }
+                },
+                None => {
+                    conn.reset_state(Some(OssuaryError::InvalidKey));
+                    return true;
+                }
+            };
             let public = PublicKey::from_secret::<Sha512>(&secret);
             let keypair = Keypair { secret: secret, public: public };
-            let sig = keypair.sign::<Sha512>(&conn.challenge.as_ref().unwrap()).to_bytes();
+            let sig = match conn.challenge {
+                Some(ref c) => keypair.sign::<Sha512>(c).to_bytes(),
+                None => {
+                    conn.reset_state(None);
+                    return true;
+                }
+            };
             let mut pkt_data: Vec<u8> = Vec::with_capacity(CHALLENGE_LEN + 32);
             pkt_data.extend_from_slice(public.as_bytes());
             pkt_data.extend_from_slice(&sig);
@@ -477,9 +518,22 @@ where T: std::ops::DerefMut<Target = U>,
 
             let aad = [];
             let mut ciphertext = Vec::with_capacity(pkt_data.len());
-            let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
-                              &conn.local_key.nonce,
-                              &aad, &pkt_data, &mut ciphertext).unwrap();
+            let session_key = match conn.local_key.session {
+                Some(ref s) => s,
+                None => {
+                    conn.reset_state(None);
+                    return true;
+                }
+            };
+            let tag = match encrypt(session_key,
+                                    &conn.local_key.nonce,
+                                    &aad, &pkt_data, &mut ciphertext) {
+                Ok(t) => t,
+                Err(_) => {
+                    conn.reset_state(None);
+                    return true;
+                }
+            };
 
             let pkt: EncryptedPacket = EncryptedPacket {
                 tag_len: tag.len() as u16,
@@ -492,17 +546,15 @@ where T: std::ops::DerefMut<Target = U>,
             let _ = write_packet(&mut buf, &pkt_buf,
                                  &mut next_msg_id, PacketType::AuthResponse);
             conn.state = ConnectionState::Encrypted;
-            false
+            false // handshake is finished (success)
         },
-        ConnectionState::Failed => {
+        ConnectionState::Failed(_) => {
+            // This is a permanent failure.
             let pkt: HandshakePacket = Default::default();
             let _ = write_packet(&mut buf, struct_as_slice(&pkt),
                                  &mut next_msg_id, PacketType::Disconnect);
-            conn.reset_state();
-            true
-        },
-        ConnectionState::Encrypted => {
-            false
+            conn.reset_state(Some(OssuaryError::ConnectionFailed));
+            false // handshake is finished (failed)
         },
     };
     conn.local_msg_id = next_msg_id;
@@ -514,16 +566,17 @@ pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T)
 where T: std::ops::DerefMut<Target = U>,
       U: std::io::Read {
     // TODO: read_exact won't work.
-    let pkt = read_packet(buf);
-    if pkt.is_err() {
-        return;
-    }
-    let pkt: NetworkPacket = pkt.unwrap();
+    let pkt: NetworkPacket = match read_packet(buf) {
+        Ok(p) => p,
+        Err(_) => {
+            return;
+        }
+    };
 
     if pkt.header.msg_id != conn.remote_msg_id {
         println!("Message gap detected.  Restarting connection.");
         println!("Server: {}", conn.is_server());
-        conn.reset_state();
+        conn.reset_state(None);
         return; // TODO: return error
     }
     conn.remote_msg_id = pkt.header.msg_id + 1;
@@ -531,11 +584,12 @@ where T: std::ops::DerefMut<Target = U>,
     let mut error = false;
     match pkt.kind() {
         PacketType::Reset => {
-            conn.reset_state();
+            conn.reset_state(None);
             return;
         },
         PacketType::Disconnect => {
             // TODO: handle error
+            conn.reset_state(Some(OssuaryError::ConnectionFailed));
             panic!("Remote side terminated connection.");
         },
         _ => {},
@@ -545,9 +599,16 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ServerNew => {
             match pkt.kind() {
                 PacketType::PublicKeyNonce => {
-                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
-                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
-                    conn.state = ConnectionState::ServerSendPubKey;
+                    let data_pkt: Result<&HandshakePacket, _> = interpret_packet(&pkt);
+                    match data_pkt {
+                        Ok(ref data_pkt) => {
+                            conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                            conn.state = ConnectionState::ServerSendPubKey;
+                        },
+                        Err(_) => {
+                            error = true;
+                        },
+                    }
                 },
                 _ => { error = true; }
             }
@@ -561,39 +622,76 @@ where T: std::ops::DerefMut<Target = U>,
             }
         },
         ConnectionState::ServerWaitAuth(_t) => {
-            // TODO (auth)
             match pkt.kind() {
                 PacketType::AuthResponse => {
-                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
-                    let ciphertext = &rest[..data_pkt.data_len as usize];
-                    let tag = &rest[data_pkt.data_len as usize..];
-                    let aad = [];
-                    let mut plaintext = Vec::with_capacity(ciphertext.len());
-                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
-                                    &conn.remote_key.as_ref().unwrap().nonce,
-                                    &aad, &ciphertext, &tag, &mut plaintext);
-                    let pubkey = &plaintext[0..32];
-                    let sig = &plaintext[32..];
-
-                    if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
-                        let public = PublicKey::from_bytes(pubkey).unwrap();
-                        let sig = Signature::from_bytes(sig).unwrap();
-                        match public.verify::<Sha512>(conn.challenge.as_ref().unwrap(), &sig) {
-                            Ok(_) => {
-                                conn.state = ConnectionState::Encrypted;
-                            },
-                            Err(_) => {
-                                println!("Verify bad");
-                                // TODO: error
-                                conn.state = ConnectionState::Failed;
-                            },
-                        }
-                    }
-                    else {
-                        println!("Key not allowed");
-                        // TODO: error
-                        conn.state = ConnectionState::Failed;
-                    }
+                    match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+                        Ok((data_pkt, rest)) => {
+                            let ciphertext = &rest[..data_pkt.data_len as usize];
+                            let tag = &rest[data_pkt.data_len as usize..];
+                            let aad = [];
+                            let mut plaintext = Vec::with_capacity(ciphertext.len());
+                            let session_key = match conn.local_key.session {
+                                Some(ref k) => k,
+                                None => {
+                                    conn.reset_state(None);
+                                    return;
+                                }
+                            };
+                            let remote_nonce = match conn.remote_key {
+                                Some(ref rem) => rem.nonce,
+                                None => {
+                                    conn.reset_state(None);
+                                    return;
+                                }
+                            };
+                            let _ = decrypt(session_key,
+                                            &remote_nonce,
+                                            &aad, &ciphertext, &tag, &mut plaintext);
+                            let pubkey = &plaintext[0..32];
+                            let sig = &plaintext[32..];
+                            if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
+                                let public = match PublicKey::from_bytes(pubkey) {
+                                    Ok(p) => p,
+                                    Err(_) => {
+                                        conn.reset_state(None);
+                                        return;
+                                    }
+                                };
+                                let sig = match Signature::from_bytes(sig) {
+                                    Ok(s) => s,
+                                    Err(_) => {
+                                        conn.reset_state(None);
+                                        return;
+                                    }
+                                };
+                                let challenge = match conn.challenge {
+                                    Some(ref c) => c,
+                                    None => {
+                                        conn.reset_state(None);
+                                        return;
+                                    }
+                                };
+                                match public.verify::<Sha512>(challenge, &sig) {
+                                    Ok(_) => {
+                                        conn.state = ConnectionState::Encrypted;
+                                    },
+                                    Err(_) => {
+                                        println!("Verify bad");
+                                        conn.state = ConnectionState::Failed(
+                                            OssuaryError::InvalidSignature);
+                                    },
+                                }
+                            }
+                            else {
+                                println!("Key not allowed");
+                                conn.state = ConnectionState::Failed(OssuaryError::InvalidKey);
+                            }
+                        },
+                        Err(_) => {
+                            conn.reset_state(None);
+                            return;
+                        },
+                    };
                 },
                 _ => { error = true; }
             }
@@ -610,11 +708,20 @@ where T: std::ops::DerefMut<Target = U>,
         ConnectionState::ClientWaitKey(_t) => {
             match pkt.kind() {
                 PacketType::PublicKeyNonce => {
-                    let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
-                    conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
-                    conn.state = ConnectionState::ClientSendAck;
+                    let data_pkt: Result<&HandshakePacket, _> = interpret_packet(&pkt);
+                    match data_pkt {
+                        Ok(data_pkt) => {
+                            conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+                            conn.state = ConnectionState::ClientSendAck;
+                        },
+                        Err(_) => {
+                            error = true;
+                        },
+                    }
                 },
-                _ => { }
+                _ => {
+                    error = true;
+                }
             }
         },
         ConnectionState::ClientSendAck => {
@@ -626,24 +733,47 @@ where T: std::ops::DerefMut<Target = U>,
                     conn.state = ConnectionState::Encrypted;
                 },
                 PacketType::AuthChallenge => {
-                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
-                    let ciphertext = &rest[..data_pkt.data_len as usize];
-                    let tag = &rest[data_pkt.data_len as usize..];
-                    let aad = [];
-                    let mut plaintext = Vec::with_capacity(ciphertext.len());
-                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
-                                    &conn.remote_key.as_ref().unwrap().nonce,
-                                    &aad, &ciphertext, &tag, &mut plaintext);
-                    conn.challenge = Some(plaintext);
-                    conn.state = ConnectionState::ClientSendAuth;
+                    match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+                        Ok((data_pkt, rest)) => {
+                            let ciphertext = &rest[..data_pkt.data_len as usize];
+                            let tag = &rest[data_pkt.data_len as usize..];
+                            let aad = [];
+                            let mut plaintext = Vec::with_capacity(ciphertext.len());
+
+                            let session_key = match conn.local_key.session {
+                                Some(ref k) => k,
+                                None => {
+                                    conn.reset_state(None);
+                                    return;
+                                }
+                            };
+                            let remote_nonce = match conn.remote_key {
+                                Some(ref rem) => rem.nonce,
+                                None => {
+                                    conn.reset_state(None);
+                                    return;
+                                }
+                            };
+                            let _ = decrypt(session_key,
+                                            &remote_nonce,
+                                            &aad, &ciphertext, &tag, &mut plaintext);
+                            conn.challenge = Some(plaintext);
+                            conn.state = ConnectionState::ClientSendAuth;
+                        },
+                        Err(_) => {
+                            error = true;
+                        },
+                    }
+                },
+                _ => {
+                    error = true;
                 },
-                _ => {},
             }
         },
         ConnectionState::ClientSendAuth => {
             error = true;
         }, // nop
-        ConnectionState::Failed => {
+        ConnectionState::Failed(_) => {
             error = true;
         }, // nop
         ConnectionState::Encrypted => {
@@ -651,15 +781,16 @@ where T: std::ops::DerefMut<Target = U>,
         }, // nop
     }
     if error {
-        conn.reset_state();
+        conn.reset_state(None);
     }
 }
 
 // TODO: should return a Result with error on forced-disconnect or permanent failure
-pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
+pub fn crypto_handshake_done(conn: &ConnectionContext) -> Result<bool, &OssuaryError> {
     match conn.state {
-        ConnectionState::Encrypted => true,
-        _ => false,
+        ConnectionState::Encrypted => Ok(true),
+        ConnectionState::Failed(ref e) => Err(e),
+        _ => Ok(false),
     }
 }
 
@@ -676,8 +807,21 @@ where T: std::ops::DerefMut<Target = U>,
     let bytes;
     let aad = [];
     let mut ciphertext = Vec::with_capacity(in_buf.len());
-    let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
-                      &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
+    let session_key = match conn.local_key.session {
+        Some(ref k) => k,
+        None => {
+            conn.reset_state(None);
+            return Err(OssuaryError::InvalidKey);;
+        }
+    };
+    let tag = match encrypt(session_key,
+                            &conn.local_key.nonce, &aad, in_buf, &mut ciphertext) {
+        Ok(t) => t,
+        Err(_) => {
+            conn.reset_state(None);
+            return Err(OssuaryError::InvalidKey);;
+        }
+    };
 
     let pkt: EncryptedPacket = EncryptedPacket {
         tag_len: tag.len() as u16,
@@ -713,27 +857,48 @@ where T: std::ops::DerefMut<Target = U>,
             if pkt.header.msg_id != conn.remote_msg_id {
                 println!("Message gap detected.  Restarting connection.");
                 println!("Server: {}", conn.is_server());
-                conn.reset_state();
+                conn.reset_state(None);
                 return Err(OssuaryError::InvalidPacket("Message ID mismatch".into()))
             }
             conn.remote_msg_id = pkt.header.msg_id + 1;
 
             match pkt.kind() {
                 PacketType::EncryptedData => {
-                    let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
-                    let ciphertext = &rest[..data_pkt.data_len as usize];
-                    let tag = &rest[data_pkt.data_len as usize..];
-                    let aad = [];
-                    let mut plaintext = Vec::with_capacity(ciphertext.len());
-                    let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
-                                    &conn.remote_key.as_ref().unwrap().nonce,
-                                    &aad, &ciphertext, &tag, &mut plaintext);
-                    let _ = out_buf.write(&plaintext);
-                    bytes_written = ciphertext.len() as u16;
-                    bytes_read = (ciphertext.len() +
-                                  ::std::mem::size_of::<PacketHeader>() +
-                                  ::std::mem::size_of::<EncryptedPacket>() +
-                                  tag.len()) as u16;
+                    match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+                        Ok((data_pkt, rest)) => {
+                            let ciphertext = &rest[..data_pkt.data_len as usize];
+                            let tag = &rest[data_pkt.data_len as usize..];
+                            let aad = [];
+                            let mut plaintext = Vec::with_capacity(ciphertext.len());
+                            let session_key = match conn.local_key.session {
+                                Some(ref k) => k,
+                                None => {
+                                    conn.reset_state(None);
+                                    return Err(OssuaryError::InvalidKey);
+                                }
+                            };
+                            let remote_nonce = match conn.remote_key {
+                                Some(ref rem) => rem.nonce,
+                                None => {
+                                    conn.reset_state(None);
+                                    return Err(OssuaryError::InvalidKey);
+                                }
+                            };
+                            let _ = decrypt(session_key,
+                                            &remote_nonce,
+                                            &aad, &ciphertext, &tag, &mut plaintext);
+                            let _ = out_buf.write(&plaintext);
+                            bytes_written = ciphertext.len() as u16;
+                            bytes_read = (ciphertext.len() +
+                                          ::std::mem::size_of::<PacketHeader>() +
+                                          ::std::mem::size_of::<EncryptedPacket>() +
+                                          tag.len()) as u16;
+                        },
+                        Err(_) => {
+                            conn.reset_state(None);
+                            return Err(OssuaryError::InvalidKey);
+                        },
+                    }
                 },
                 _ => {
                     return Err(OssuaryError::InvalidPacket("Received non-encrypted data on encrypted channel.".into()));
@@ -793,7 +958,7 @@ mod tests {
             let listener = TcpListener::bind("127.0.0.1:9987").unwrap();
             let mut server_stream = listener.incoming().next().unwrap().unwrap();
             let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
-            while crypto_handshake_done(&server_conn) == false {
+            while crypto_handshake_done(&server_conn).unwrap() == false {
                 if crypto_send_handshake(&mut server_conn, &mut server_stream) {
                     crypto_recv_handshake(&mut server_conn, &mut server_stream);
                 }
@@ -821,7 +986,7 @@ mod tests {
         std::thread::sleep(std::time::Duration::from_millis(500));
         let mut client_stream = TcpStream::connect("127.0.0.1:9987").unwrap();
         let mut client_conn = ConnectionContext::new(ConnectionType::Client);
-        while crypto_handshake_done(&client_conn) == false {
+        while crypto_handshake_done(&client_conn).unwrap() == false {
             if crypto_send_handshake(&mut client_conn, &mut client_stream) {
                 crypto_recv_handshake(&mut client_conn, &mut client_stream);
             }

diff --git a/tests/basic.rs b/tests/basic.rs
line changes: +1/-1
index 18c23da..839483f
--- a/tests/basic.rs
+++ b/tests/basic.rs
@@ -10,7 +10,7 @@ fn event_loop<T>(mut conn: ConnectionContext,
                  is_server: bool) -> Result<(), std::io::Error>
 where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
-    while crypto_handshake_done(&conn) == false {
+    while crypto_handshake_done(&conn).unwrap() == false {
         if crypto_send_handshake(&mut conn, &mut stream) {
             crypto_recv_handshake(&mut conn, &mut stream);
         }

diff --git a/tests/basic_auth.rs b/tests/basic_auth.rs
line changes: +1/-1
index 8139a71..8e7f961
--- a/tests/basic_auth.rs
+++ b/tests/basic_auth.rs
@@ -10,7 +10,7 @@ fn event_loop<T>(mut conn: ConnectionContext,
                  is_server: bool) -> Result<(), std::io::Error>
 where T: std::io::Read + std::io::Write {
     // Run the opaque handshake until the connection is established
-    while crypto_handshake_done(&conn) == false {
+    while crypto_handshake_done(&conn).unwrap() == false {
         if crypto_send_handshake(&mut conn, &mut stream) {
             crypto_recv_handshake(&mut conn, &mut stream);
         }