summary history branches tags files
src/handshake.rs
//
// Copyright 2019 Trevor Bentley
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
use crate::*;

use comm::{read_packet, write_packet, write_stored_packet};

use std::io::Write;

/// Bytes occupied by the header that starts every handshake packet:
/// `len: u16` (2) + `version: u8` (1) + `_reserved: [u8; 5]` (5).
const HANDSHAKE_PACKET_HEADER_LEN: usize = 8;

/// Size of the encrypted subpacket carried by a `ServerHandshakePacket`:
/// the plaintext `ServerEncryptedHandshakePacket`, plus the size of an
/// `EncryptedPacket` and `TAG_LEN` extra bytes (presumably the framing and
/// authentication tag produced by `encrypt_to_bytes` — confirm there).
const SERVER_HANDSHAKE_SUBPACKET_LEN: usize =
    ::std::mem::size_of::<EncryptedPacket>()
        + ::std::mem::size_of::<ServerEncryptedHandshakePacket>()
        + TAG_LEN;

/// Size of the encrypted subpacket carried by a `ClientAuthenticationPacket`,
/// computed the same way from `ClientEncryptedAuthenticationPacket`.
const CLIENT_AUTH_SUBPACKET_LEN: usize =
    ::std::mem::size_of::<EncryptedPacket>()
        + ::std::mem::size_of::<ClientEncryptedAuthenticationPacket>()
        + TAG_LEN;

/// Expected `len` field of a `ClientHandshakePacket`.
const CLIENT_HANDSHAKE_PACKET_LEN: usize =
    HANDSHAKE_PACKET_HEADER_LEN + KEY_LEN + NONCE_LEN + CHALLENGE_LEN;

/// Expected `len` field of a `ServerHandshakePacket`.
const SERVER_HANDSHAKE_PACKET_LEN: usize =
    HANDSHAKE_PACKET_HEADER_LEN + KEY_LEN + NONCE_LEN + SERVER_HANDSHAKE_SUBPACKET_LEN;

/// Expected `len` field of a `ClientAuthenticationPacket`.
const CLIENT_AUTH_PACKET_LEN: usize = HANDSHAKE_PACKET_HEADER_LEN + CLIENT_AUTH_SUBPACKET_LEN;

/// Payload of `Reset` and `Disconnect` packets.
///
/// `error` separates a failed connection (`true`) from a clean close
/// (`false`). Field order is the wire layout (`repr(C, packed)`) and must
/// not change.
#[repr(C,packed)]
pub(crate) struct ResetPacket {
    /// Size of this struct in bytes.
    len: u16,
    pub(crate) error: bool,
    _reserved: u8,
}
impl Default for ResetPacket {
    /// An error reset (`error == true`).
    fn default() -> ResetPacket {
        ResetPacket {
            // `size_of` must be *called*: naming it without `()` casts the
            // function's address to `u16` instead of giving the struct size.
            len: ::std::mem::size_of::<ResetPacket>() as u16,
            error: true,
            _reserved: 0u8,
        }
    }
}
impl ResetPacket {
    /// A non-error reset, used to report that the connection was closed
    /// cleanly rather than failing.
    fn closed() -> ResetPacket {
        ResetPacket {
            error: false,
            ..Default::default()
        }
    }
}

/// First handshake packet, sent from client to server.
///
/// Carries the client's session public key and nonce, plus a challenge that
/// the server signs in its reply. Unlike the later handshake packets, it has
/// no encrypted subpacket. Field order is the wire layout
/// (`repr(C, packed)`) and must not change.
#[repr(C,packed)]
struct ClientHandshakePacket {
    len: u16,
    version: u8,
    _reserved: [u8; 5],
    public_key: [u8; KEY_LEN],
    nonce: [u8; NONCE_LEN],
    challenge: [u8; CHALLENGE_LEN],
}
impl Default for ClientHandshakePacket {
    /// A zero-filled packet with the correct header (`len`, `version`).
    fn default() -> Self {
        Self {
            len: CLIENT_HANDSHAKE_PACKET_LEN as u16,
            version: PROTOCOL_VERSION,
            _reserved: [0; 5],
            public_key: [0; KEY_LEN],
            nonce: [0; NONCE_LEN],
            challenge: [0; CHALLENGE_LEN],
        }
    }
}
impl ClientHandshakePacket {
    /// Builds a packet from raw key, nonce and challenge bytes.
    ///
    /// Panics (via `copy_from_slice`) if a slice's length differs from the
    /// corresponding field.
    fn new(pubkey: &[u8], nonce: &[u8], challenge: &[u8]) -> ClientHandshakePacket {
        let mut packet = Self::default();
        packet.public_key.copy_from_slice(pubkey);
        packet.nonce.copy_from_slice(nonce);
        packet.challenge.copy_from_slice(challenge);
        packet
    }

    /// Views `pkt` as a client handshake packet.
    ///
    /// Fails if the packet cannot be interpreted, if its version is not
    /// `PROTOCOL_VERSION`, or if its `len` field is not
    /// `CLIENT_HANDSHAKE_PACKET_LEN`.
    fn from_packet(pkt: &NetworkPacket) -> Result<&ClientHandshakePacket, OssuaryError> {
        let parsed = interpret_packet::<ClientHandshakePacket>(&pkt)?;
        // Copy fields out of the packed struct before using them.
        let version = parsed.version;
        if version != PROTOCOL_VERSION {
            return Err(OssuaryError::WrongProtocolVersion(version, PROTOCOL_VERSION));
        }
        if usize::from(parsed.len) != CLIENT_HANDSHAKE_PACKET_LEN {
            return Err(OssuaryError::InvalidPacket("Unexpected packet size.".into()));
        }
        Ok(parsed)
    }
}

/// Plaintext of the encrypted subpacket inside a `ServerHandshakePacket`.
///
/// Holds the server's public authentication key, the server's own challenge
/// for the client, and the server's signature over its session public key,
/// its nonce and the client's challenge. Field order is the byte layout
/// (`repr(C, packed)`) and must not change.
#[repr(C,packed)]
struct ServerEncryptedHandshakePacket {
    public_key: [u8; KEY_LEN],
    challenge: [u8; CHALLENGE_LEN],
    signature: [u8; SIGNATURE_LEN],
}
impl Default for ServerEncryptedHandshakePacket {
    /// An all-zero subpacket.
    fn default() -> Self {
        Self {
            public_key: [0; KEY_LEN],
            challenge: [0; CHALLENGE_LEN],
            signature: [0; SIGNATURE_LEN],
        }
    }
}
impl ServerEncryptedHandshakePacket {
    /// Reinterprets decrypted bytes as this struct, borrowing from `data`.
    fn from_bytes(data: &[u8]) -> Result<&ServerEncryptedHandshakePacket, OssuaryError> {
        let view: &Self = slice_as_struct(&data)?;
        Ok(view)
    }
}

#[repr(C,packed)]
struct ServerHandshakePacket {
    len: u16,
    version: u8,
    _reserved: [u8; 5],
    public_key: [u8; KEY_LEN],
    nonce: [u8; NONCE_LEN],
    subpacket: [u8; SERVER_HANDSHAKE_SUBPACKET_LEN],
}
impl Default for ServerHandshakePacket {
    fn default() -> ServerHandshakePacket {
        ServerHandshakePacket {
            len: SERVER_HANDSHAKE_PACKET_LEN as u16,
            version: PROTOCOL_VERSION,
            _reserved: [0u8; 5],
            public_key: [0u8; KEY_LEN],
            nonce: [0u8; NONCE_LEN],
            subpacket: [0; SERVER_HANDSHAKE_SUBPACKET_LEN],
        }
    }
}
impl ServerHandshakePacket {
    fn new(session_pubkey: &[u8], nonce: &[u8], session_privkey: &[u8],
           server_pubkey: &[u8], challenge: &[u8], signature: &[u8]) -> Result<ServerHandshakePacket, OssuaryError> {
        let mut pkt: ServerHandshakePacket = Default::default();
        let mut enc_pkt: ServerEncryptedHandshakePacket = Default::default();
        pkt.public_key.copy_from_slice(session_pubkey);
        pkt.nonce.copy_from_slice(nonce);
        enc_pkt.public_key.copy_from_slice(server_pubkey);
        enc_pkt.challenge.copy_from_slice(challenge);
        enc_pkt.signature.copy_from_slice(signature);
        encrypt_to_bytes(session_privkey, nonce, struct_as_slice(&enc_pkt), &mut pkt.subpacket)?;
        Ok(pkt)
    }
    fn from_packet(pkt: &NetworkPacket) -> Result<&ServerHandshakePacket, OssuaryError> {
        let hs_pkt = interpret_packet::<ServerHandshakePacket>(&pkt);
        match hs_pkt {
            Ok(pkt) => {
                if pkt.version != PROTOCOL_VERSION {
                    return Err(OssuaryError::WrongProtocolVersion(pkt.version, PROTOCOL_VERSION));
                }
                if pkt.len as usize != SERVER_HANDSHAKE_PACKET_LEN {
                    return Err(OssuaryError::InvalidPacket("Unexpected packet size.".into()));
                }
            },
            _ => {},
        }
        hs_pkt
    }
}

/// Plaintext of the encrypted subpacket inside a `ClientAuthenticationPacket`.
///
/// Holds the client's public authentication key and the client's signature
/// over its session public key, its nonce and the server's challenge. Field
/// order is the byte layout (`repr(C, packed)`) and must not change.
#[repr(C,packed)]
struct ClientEncryptedAuthenticationPacket {
    public_key: [u8; KEY_LEN],
    signature: [u8; SIGNATURE_LEN],
}
impl Default for ClientEncryptedAuthenticationPacket {
    /// An all-zero subpacket.
    fn default() -> Self {
        Self {
            public_key: [0; KEY_LEN],
            signature: [0; SIGNATURE_LEN],
        }
    }
}
impl ClientEncryptedAuthenticationPacket {
    /// Reinterprets decrypted bytes as this struct, borrowing from `data`.
    fn from_bytes(data: &[u8]) -> Result<&ClientEncryptedAuthenticationPacket, OssuaryError> {
        let view: &Self = slice_as_struct(&data)?;
        Ok(view)
    }
}

#[repr(C,packed)]
struct ClientAuthenticationPacket {
    len: u16,
    version: u8,
    _reserved: [u8; 5],
    subpacket: [u8; CLIENT_AUTH_SUBPACKET_LEN],
}
impl Default for ClientAuthenticationPacket {
    fn default() -> ClientAuthenticationPacket {
        ClientAuthenticationPacket {
            len: CLIENT_AUTH_PACKET_LEN as u16,
            version: PROTOCOL_VERSION,
            _reserved: [0u8; 5],
            subpacket: [0u8; CLIENT_AUTH_SUBPACKET_LEN],
        }
    }
}
impl ClientAuthenticationPacket {
    fn new(nonce: &[u8], session_privkey: &[u8],
           client_pubkey: &[u8], signature: &[u8]) -> Result<ClientAuthenticationPacket, OssuaryError> {
        let mut pkt: ClientAuthenticationPacket = Default::default();
        let mut enc_pkt: ClientEncryptedAuthenticationPacket = Default::default();
        enc_pkt.public_key.copy_from_slice(client_pubkey);
        enc_pkt.signature.copy_from_slice(signature);
        encrypt_to_bytes(session_privkey, nonce, struct_as_slice(&enc_pkt), &mut pkt.subpacket)?;
        Ok(pkt)
    }
    fn from_packet(pkt: &NetworkPacket) -> Result<&ClientAuthenticationPacket, OssuaryError> {
        let hs_pkt = interpret_packet::<ClientAuthenticationPacket>(&pkt);
        match hs_pkt {
            Ok(pkt) => {
                if pkt.version != PROTOCOL_VERSION {
                    return Err(OssuaryError::WrongProtocolVersion(pkt.version, PROTOCOL_VERSION));
                }
                if pkt.len as usize != CLIENT_AUTH_PACKET_LEN {
                    return Err(OssuaryError::InvalidPacket("Unexpected packet size.".into()));
                }
            },
            _ => {},
        }
        hs_pkt
    }
}

impl OssuaryConnection {
    /// Write the next handshake packet into the given buffer
    ///
    /// If a handshake packet is ready to be sent in the current connection
    /// state, this function writes it into the provided buffer and advances
    /// the connection's state machine.  The server handshake and client
    /// authentication packets contain an encrypted subpacket; the initial
    /// client handshake packet does not.
    ///
    /// This is a critical part of the handshaking stage, when a connection to
    /// a remote host is securely established.  Each side of the connection must
    /// call send_handshake() continuously, and any data that is written to the
    /// data buffer must be sent to the remote host.  This should be done until
    /// [`OssuaryConnection::handshake_done()`] returns true.
    ///
    /// Note that Ossuary does not perform network operations itself.  It is the
    /// caller's responsibility to put the written data on the wire.  However,
    /// you may pass a 'buf' that does this automatically, such as a TcpStream.
    ///
    /// Previously stored, unsent data is flushed first.  If that flush reports
    /// a non-zero byte count, this returns [`OssuaryError::WouldBlock`] with
    /// that count and no new packet is generated on this call.
    ///
    /// States that only wait for the peer write nothing unless more than
    /// `MAX_HANDSHAKE_WAIT_TIME` seconds have elapsed.  In that case a reset
    /// packet is sent and the connection state is reset.
    ///
    /// Returns the number of bytes written into `buf`, or an error.  You must
    /// handle [`OssuaryError::WouldBlock`], which is a recoverable error, but
    /// indicates that some bytes were written to the buffer.
    pub fn send_handshake<T,U>(&mut self, mut buf: T) -> Result<usize, OssuaryError>
    where T: std::ops::DerefMut<Target = U>,
          U: std::io::Write {
        // Try to send any unsent buffered data
        match write_stored_packet(self, &mut buf) {
            Ok(w) if w == 0 => {},
            // Stored data is still draining: report it, send nothing new.
            Ok(w) => return Err(OssuaryError::WouldBlock(w)),
            Err(e) => return Err(e),
        }
        let written = match self.state {
            // No-op states
            ConnectionState::Failed(_) |
            ConnectionState::ResetWait |
            ConnectionState::Encrypted |
            ConnectionState::ClientRaiseUntrustedServer |
            ConnectionState::ClientWaitServerApproval => {0},

            // Timeout wait states
            ConnectionState::ServerWaitHandshake(t) |
            ConnectionState::ServerWaitAuthentication(t) |
            ConnectionState::ClientWaitHandshake(t)  => {
                let mut w: usize = 0;
                // Wait for response, with timeout.  If the system clock has
                // moved backwards, `elapsed()` fails and no timeout fires.
                if let Ok(dur) = t.elapsed() {
                    if dur.as_secs() > MAX_HANDSHAKE_WAIT_TIME {
                        // Timed out: ask the peer to reset, zero our
                        // message ID counter and reset local state.
                        let pkt: ResetPacket = Default::default();
                        w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                         PacketType::Reset)?;
                        self.local_msg_id = 0;
                        self.reset_state(None);
                    }
                }
                w
            },

            ConnectionState::ClientSendHandshake => {
                // Send session public key and nonce to initiate connection,
                // along with our challenge (all zeros if none is set).
                let chal = self.local_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]);
                let pkt = ClientHandshakePacket::new(&self.local_key.public,
                                                     &self.local_key.nonce,
                                                     &chal);
                let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                     PacketType::ClientHandshake)?;
                // Start the clock for the server's reply.
                self.state = ConnectionState::ClientWaitHandshake(std::time::SystemTime::now());
                w
            },

            ConnectionState::ServerSendHandshake => {
                // Get a local copy of server's secret auth key, if it has one.
                // Default to 0s.
                let server_secret = match self.local_auth.secret_key {
                    Some(ref s) => match SecretKey::from_bytes(s.as_bytes()) {
                        Ok(s) => Some(s),
                        Err(_) => None,
                    },
                    _ => None,
                };
                // Sign the client's challenge if we have a key,
                // default to 0s.  The signed data is our session public key,
                // then our nonce, then the client's challenge.
                let sig: [u8; SIGNATURE_LEN] = match server_secret {
                    Some(s) => {
                        let mut sign_data = [0u8; KEY_LEN + NONCE_LEN + CHALLENGE_LEN];
                        sign_data[0..KEY_LEN].copy_from_slice(&self.local_key.public);
                        sign_data[KEY_LEN..KEY_LEN+NONCE_LEN].copy_from_slice(&self.local_key.nonce);
                        sign_data[KEY_LEN+NONCE_LEN..].copy_from_slice(&self.remote_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]));
                        let server_public = PublicKey::from(&s);
                        let keypair = Keypair { secret: s, public: server_public };
                        keypair.sign(&sign_data).to_bytes()
                    },
                    None => [0; SIGNATURE_LEN],
                };
                // Get server's public auth key, if it has one.
                // Default to 0s.
                let server_public = match self.local_auth.public_key {
                    Some(ref p) => p.as_bytes(),
                    None => &[0; KEY_LEN],
                };
                // Get session encryption key, which must be known by now.
                let session = match self.local_key.session {
                    Some(ref s) => s.as_bytes(),
                    None => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidKey);
                    }
                };
                // Our own challenge, which the client signs in its
                // authentication packet (all zeros if none is set).
                let chal = self.local_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]);
                let pkt = ServerHandshakePacket::new(&self.local_key.public,
                                                     &self.local_key.nonce,
                                                     session,
                                                     server_public,
                                                     &chal,
                                                     &sig)?;
                let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                     PacketType::ServerHandshake)?;
                // The nonce was used to encrypt the subpacket; advance it.
                increment_nonce(&mut self.local_key.nonce);
                self.state = ConnectionState::ServerWaitAuthentication(std::time::SystemTime::now());
                w
            },

            ConnectionState::ClientSendAuthentication => {
                // Get a local copy of client's secret auth key, if it has one.
                // Default to 0s.
                let client_secret = match self.local_auth.secret_key {
                    Some(ref s) => match SecretKey::from_bytes(s.as_bytes()) {
                        Ok(s) => Some(s),
                        Err(_) => None,
                    },
                    _ => None,
                };
                // Sign the server's challenge (stored in remote_auth when the
                // server handshake was received) if we have a key, default to
                // 0s.  The signed data is our session public key, then our
                // nonce, then the server's challenge.
                let sig: [u8; SIGNATURE_LEN] = match client_secret {
                    Some(s) => {
                        let client_public = PublicKey::from(&s);
                        let keypair = Keypair { secret: s, public: client_public };
                        let mut sign_data = [0u8; KEY_LEN + NONCE_LEN + CHALLENGE_LEN];
                        sign_data[0..KEY_LEN].copy_from_slice(&self.local_key.public);
                        sign_data[KEY_LEN..KEY_LEN+NONCE_LEN].copy_from_slice(&self.local_key.nonce);
                        sign_data[KEY_LEN+NONCE_LEN..].copy_from_slice(&self.remote_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]));
                        keypair.sign(&sign_data).to_bytes()
                    },
                    None => [0; SIGNATURE_LEN],
                };
                // Get client's public auth key, if it has one.
                // Default to 0s.
                let client_public = match self.local_auth.public_key {
                    Some(ref p) => p.as_bytes(),
                    None => &[0; KEY_LEN],
                };
                // Get session encryption key, which must be known by now.
                let session = match self.local_key.session {
                    Some(ref s) => s.as_bytes(),
                    None => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidKey);
                    }
                };
                let pkt = ClientAuthenticationPacket::new(&self.local_key.nonce,
                                                          session,
                                                          client_public,
                                                          &sig)?;
                let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                     PacketType::ClientAuthentication)?;
                // The nonce was used to encrypt the subpacket; advance it.
                increment_nonce(&mut self.local_key.nonce);
                // This was the client's last handshake packet.
                self.state = ConnectionState::Encrypted;
                w
            },

            ConnectionState::Failing(ref e) => {
                // Tell remote host to disconnect.  A ConnectionClosed error is
                // reported as a clean close; anything else as an error.
                let pkt: ResetPacket = match e {
                    OssuaryError::ConnectionClosed => ResetPacket::closed(),
                    _ => Default::default(),
                };
                let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                     PacketType::Disconnect)?;
                w
            },

            ConnectionState::Resetting(initial) => {
                // Tell remote host to reset
                let pkt: ResetPacket = Default::default();
                let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
                                     PacketType::Reset)?;
                self.local_msg_id = 0;
                // `initial == true`: wait in ResetWait.  Otherwise restart
                // from the connection's initial state right away.
                self.state = match initial {
                    true => ConnectionState::ResetWait,
                    false => self.initial_state(),
                };
                w
            }
        };

        // Finalize failure state if failing
        // (reached only if the Disconnect write above did not error out).
        match self.state {
            ConnectionState::Failing(ref e) => {
                self.state = ConnectionState::Failed(e.clone());
            },
            _ => {},
        }
        Ok(written)
    }

    /// Read the next handshake packet from the given buffer
    ///
    /// If a handshake packet has been received, this function reads and parses
    /// the encrypted packet from the provided buffer and updates its internal
    /// connection state.
    ///
    /// This is a critical part of the handshaking stage, when a connection to
    /// a remote host is securely established.  Each side of the connection must
    /// call recv_handshake() whenever data is received from the network until
    /// [`OssuaryConnection::handshake_done()`] returns true.
    ///
    /// Returns the number of bytes read from `buf`, or an error.  It is the
    /// caller's responsibility to ensure that the consumed bytes are removed
    /// from the data buffer before it is used again.  You must handle
    /// [`OssuaryError::WouldBlock`], which is a recoverable error, but
    /// indicates that some bytes were also read from the buffer.
    pub fn recv_handshake<T,U>(&mut self, buf: T) -> Result<usize, OssuaryError>
    where T: std::ops::DerefMut<Target = U>,
          U: std::io::Read {
        match self.state {
            ConnectionState::Failed(_) |
            ConnectionState::Encrypted => return Ok(0),
            // Timeout wait states
            ConnectionState::ServerWaitHandshake(t) |
            ConnectionState::ServerWaitAuthentication(t) |
            ConnectionState::ClientWaitHandshake(t)  => {
                // Wait for response, with timeout
                if let Ok(dur) = t.elapsed() {
                    if dur.as_secs() > MAX_HANDSHAKE_WAIT_TIME {
                        self.reset_state(None);
                        return Err(OssuaryError::ConnectionReset(0));
                    }
                }
            },
            _ => {},
        }

        let (pkt, bytes_read) = match read_packet(self, buf) {
            Ok(t) => { t },
            Err(OssuaryError::WouldBlock(b)) => {
                return Err(OssuaryError::WouldBlock(b));
            }
            Err(e) => {
                self.reset_state(Some(e.clone()));
                return Err(e);
            }
        };

        match pkt.kind() {
            PacketType::Reset => {
                match self.state {
                    ConnectionState::ResetWait => {},
                    _ => {
                        self.reset_state(None);
                        self.state = ConnectionState::Resetting(false);
                        return Err(OssuaryError::ConnectionReset(bytes_read));
                    },
                }
            },
            PacketType::Disconnect => {
                let rs_pkt = interpret_packet::<ResetPacket>(&pkt)?;
                match rs_pkt.error {
                    true => {
                        self.reset_state(Some(OssuaryError::ConnectionFailed));
                        return Err(OssuaryError::ConnectionFailed);
                    },
                    false => {
                        self.reset_state(Some(OssuaryError::ConnectionClosed));
                        return Err(OssuaryError::ConnectionClosed);
                    },
                }
            },
            _ => {},
        }

        if pkt.header.msg_id != self.remote_msg_id {
            match pkt.kind() {
                PacketType::Disconnect |
                PacketType::Reset => {},
                _ => {
                    match self.state {
                        ConnectionState::ResetWait => {},
                        _ => {
                            println!("Message gap detected.  Restarting connection.");
                            self.reset_state(None);
                            return Err(OssuaryError::InvalidPacket("Message ID does not match".into()));
                        },
                    }
                },
            }
        }
        self.remote_msg_id = pkt.header.msg_id + 1;

        match self.state {
            // no-op states
            ConnectionState::Failing(_) |
            ConnectionState::Failed(_) |
            ConnectionState::Resetting(_) |
            ConnectionState::ClientRaiseUntrustedServer |
            ConnectionState::ClientWaitServerApproval => {
                self.reset_state(None);
                return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
            },

            // Non-receive states.  Receiving handshake data is an error.
            ConnectionState::ClientSendHandshake |
            ConnectionState::ClientSendAuthentication |
            ConnectionState::ServerSendHandshake |
            ConnectionState::Encrypted => {
                self.reset_state(None);
                return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
            },

            ConnectionState::ServerWaitHandshake(_) => {
                match pkt.kind() {
                    PacketType::ClientHandshake => {
                        if let Ok(inner_pkt) = ClientHandshakePacket::from_packet(&pkt) {
                            let mut chal: [u8; CHALLENGE_LEN] = Default::default();
                            chal.copy_from_slice(&inner_pkt.challenge);
                            self.add_remote_key(&inner_pkt.public_key, &inner_pkt.nonce);
                            self.remote_auth = AuthKeyMaterial {
                                challenge: Some(chal),
                                public_key: None,
                                signature: None,
                                secret_key: None,
                            };
                            self.state = ConnectionState::ServerSendHandshake;
                        }
                    },
                    _ => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                    },
                }
            },

            ConnectionState::ClientWaitHandshake(_t) => {
                match pkt.kind() {
                    PacketType::ServerHandshake => {
                        let packet = ServerHandshakePacket::from_packet(&pkt);
                        if packet.is_err() { // TODO: refactor
                            self.reset_state(None);
                            return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                        }
                        if let Ok(inner_pkt) = packet {
                            self.add_remote_key(&inner_pkt.public_key, &inner_pkt.nonce);
                            let mut plaintext: [u8; SERVER_HANDSHAKE_SUBPACKET_LEN] = [0u8; SERVER_HANDSHAKE_SUBPACKET_LEN];
                            let session = match self.local_key.session {
                                Some(ref s) => s.as_bytes(),
                                _ => {
                                    self.reset_state(None);
                                    return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                }
                            };
                            let nonce = match self.remote_key {
                                Some(ref k) => k.nonce,
                                _ => {
                                    self.reset_state(None);
                                    return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                }
                            };
                            // note: pt is consumed by decrypt_to_bytes
                            match decrypt_to_bytes(session, &nonce, &inner_pkt.subpacket, &mut plaintext) {
                                Ok(_) => {},
                                Err(e) => {
                                    self.reset_state(None);
                                    return Err(e);
                                }
                            }
                            if let Ok(enc_pkt) = ServerEncryptedHandshakePacket::from_bytes(&plaintext) {
                                let mut chal: [u8; CHALLENGE_LEN] = [0u8; CHALLENGE_LEN];
                                let mut sig: [u8; SIGNATURE_LEN] = [0u8; SIGNATURE_LEN];
                                chal.copy_from_slice(&enc_pkt.challenge);
                                sig.copy_from_slice(&enc_pkt.signature);
                                let pubkey = match PublicKey::from_bytes(&enc_pkt.public_key) {
                                    Ok(p) => p,
                                    _ => {
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                    }
                                };
                                let signature = match Signature::from_bytes(&sig) {
                                    Ok(s) => s,
                                    Err(_) => {
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidSignature);
                                    }
                                };

                                // All servers should have an auth key set, so
                                // these parameters should be non-zero and the
                                // signature should verify.
                                if chal.iter().all(|x| *x == 0) ||
                                    sig.iter().all(|x| *x == 0) ||
                                    enc_pkt.public_key.iter().all(|x| *x == 0) {
                                        // Parameters must be non-zero
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidSignature);
                                    }

                                // This is the first encrypted message, so the nonce has not changed yet
                                let mut sign_data = [0u8; KEY_LEN + NONCE_LEN + CHALLENGE_LEN];
                                sign_data[0..KEY_LEN].copy_from_slice(self.remote_key.as_ref().map(|k| &k.public).unwrap_or(&[0u8; KEY_LEN]));
                                sign_data[KEY_LEN..KEY_LEN+NONCE_LEN].copy_from_slice(self.remote_key.as_ref().map(|k| &k.nonce).unwrap_or(&[0u8; NONCE_LEN]));
                                sign_data[KEY_LEN+NONCE_LEN..].copy_from_slice(&self.local_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]));
                                match pubkey.verify(&sign_data, &signature) {
                                    Ok(_) => {},
                                    Err(_) => {
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidSignature);
                                    },
                                }

                                self.remote_auth = AuthKeyMaterial {
                                    challenge: Some(chal),
                                    public_key: Some(pubkey),
                                    signature: Some(sig),
                                    secret_key: None,
                                };
                                let _ = self.remote_key.as_mut().map(|k| increment_nonce(&mut k.nonce));

                                match self.authorized_keys.contains(&enc_pkt.public_key) {
                                    true => self.state = ConnectionState::ClientSendAuthentication,
                                    false => self.state = ConnectionState::ClientRaiseUntrustedServer,
                                }
                            }
                        }
                    },
                    _ => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                    },
                }
            },

            ConnectionState::ServerWaitAuthentication(_t) => {
                match pkt.kind() {
                    PacketType::ClientAuthentication => {
                        let packet = ClientAuthenticationPacket::from_packet(&pkt);
                        if packet.is_err() { // TODO: refactor
                            self.reset_state(None);
                            return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                        }
                        if let Ok(inner_pkt) = packet {
                            let mut plaintext: [u8; CLIENT_AUTH_SUBPACKET_LEN] = [0u8; CLIENT_AUTH_SUBPACKET_LEN];
                            let session = match self.local_key.session {
                                Some(ref s) => s.as_bytes(),
                                _ => {
                                    self.reset_state(None);
                                    return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                }
                            };
                            let nonce = match self.remote_key {
                                Some(ref k) => k.nonce,
                                _ => {
                                    self.reset_state(None);
                                    return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                }
                            };
                            // note: pt is consumed by decrypt_to_bytes
                            match decrypt_to_bytes(session, &nonce, &inner_pkt.subpacket, &mut plaintext) {
                                Ok(_) => {},
                                Err(e) => {
                                    self.reset_state(None);
                                    return Err(e);
                                }
                            }
                            if let Ok(enc_pkt) = ClientEncryptedAuthenticationPacket::from_bytes(&plaintext) {
                                let mut sig: [u8; SIGNATURE_LEN] = [0u8; SIGNATURE_LEN];
                                sig.copy_from_slice(&enc_pkt.signature);
                                let pubkey = match PublicKey::from_bytes(&enc_pkt.public_key) {
                                    Ok(p) => p,
                                    _ => {
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                                    }
                                };
                                let challenge = self.local_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]);
                                let signature = match Signature::from_bytes(&sig) {
                                    Ok(s) => s,
                                    Err(_) => {
                                        self.reset_state(None);
                                        return Err(OssuaryError::InvalidSignature);
                                    }
                                };
                                match self.conn_type {
                                    ConnectionType::AuthenticatedServer => {
                                        if challenge.iter().all(|x| *x == 0) ||
                                            sig.iter().all(|x| *x == 0) ||
                                            enc_pkt.public_key.iter().all(|x| *x == 0) {
                                                // Parameters must be non-zero
                                                self.reset_state(None);
                                                return Err(OssuaryError::InvalidSignature);
                                        }

                                        // This is the first encrypted message, so the nonce has not changed yet
                                        let mut sign_data = [0u8; KEY_LEN + NONCE_LEN + CHALLENGE_LEN];
                                        sign_data[0..KEY_LEN].copy_from_slice(self.remote_key.as_ref().map(|k| &k.public).unwrap_or(&[0u8; KEY_LEN]));
                                        sign_data[KEY_LEN..KEY_LEN+NONCE_LEN].copy_from_slice(self.remote_key.as_ref().map(|k| &k.nonce).unwrap_or(&[0u8; NONCE_LEN]));
                                        sign_data[KEY_LEN+NONCE_LEN..].copy_from_slice(&self.local_auth.challenge.unwrap_or([0u8; CHALLENGE_LEN]));
                                        match pubkey.verify(&sign_data, &signature) {
                                            Ok(_) => {},
                                            Err(_) => {
                                                self.reset_state(None);
                                                return Err(OssuaryError::InvalidSignature);
                                            },
                                        }

                                        // Ensure this key is permitted to connect
                                        match self.authorized_keys.contains(&enc_pkt.public_key) {
                                            true => {},
                                            false => {
                                                self.reset_state(None);
                                                return Err(OssuaryError::InvalidKey);
                                            },
                                        }
                                    }
                                    _ => {},
                                }
                                self.remote_auth.signature = Some(sig);
                                self.remote_auth.public_key = Some(pubkey);
                                let _ = self.remote_key.as_mut().map(|k| increment_nonce(&mut k.nonce));
                                self.state = ConnectionState::Encrypted;
                            }
                        }
                    },
                    _ => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidPacket("Received unexpected handshake packet.".into()));
                    },
                }
            },

            ConnectionState::ResetWait => {
                match pkt.kind() {
                    PacketType::Reset => {
                        self.remote_msg_id = 0;
                        self.state = match self.conn_type {
                            ConnectionType::Client => ConnectionState::ClientSendHandshake,
                            _ => ConnectionState::ServerWaitHandshake(std::time::SystemTime::now()),
                        }
                    },
                    _ => {},
                }
            }
        };
        Ok(bytes_read)
    }


    /// Reports whether the handshake has finished.
    ///
    /// Returns `Ok(true)` once the connection is encrypted, and `Ok(false)`
    /// while the handshake is still in progress.  If the connection has
    /// failed, the stored error is returned.
    ///
    /// When the remote host presented an authentication key that is not
    /// trusted, this returns [`OssuaryError::UntrustedServer`] carrying that
    /// key's bytes (all zeros if no key was recorded).  The connection moves
    /// into a waiting state, so the error is raised once; later calls report
    /// `Ok(false)` until the handshake can proceed.
    ///
    /// To accept such a server, call [`OssuaryConnection::add_authorized_key`]
    /// with the reported key, which marks it trusted and lets the handshake
    /// resume.  Only do this when the application deliberately implements a
    /// Trust-On-First-Use policy and has confirmed the remote key has never
    /// been seen before — ideally after asking the user.
    ///
    pub fn handshake_done(&mut self) -> Result<bool, OssuaryError> {
        match self.state {
            ConnectionState::Encrypted => return Ok(true),
            ConnectionState::Failed(ref err) => return Err(err.clone()),
            ConnectionState::ClientRaiseUntrustedServer => {},
            _ => return Ok(false),
        }

        // Untrusted server: report it once, then wait for the application
        // to approve (or reject) the key.
        self.state = ConnectionState::ClientWaitServerApproval;
        let untrusted_key: Vec<u8> = match self.remote_auth.public_key {
            Some(ref pk) => pk.as_bytes().to_vec(),
            None => vec![0u8; KEY_LEN],
        };
        Err(OssuaryError::UntrustedServer(untrusted_key))
    }
}

/// Encrypts `data` with `session_key`/`nonce` and writes a framed packet to `out`.
///
/// The written layout is an `EncryptedPacket` header (tag length and
/// ciphertext length, each as `u16`), followed by the ciphertext, followed by
/// the authentication tag.  No additional authenticated data is used.
///
/// Returns the total number of bytes written to `out`.
///
/// # Errors
///
/// * `OssuaryError::InvalidKey` if the underlying `encrypt` call fails.
/// * `OssuaryError::InvalidPacket` if the ciphertext or tag is longer than the
///   `u16` length fields of the header can describe.
/// * An I/O error (converted into `OssuaryError`) if `out` is too small to
///   hold the complete packet.
fn encrypt_to_bytes(session_key: &[u8], nonce: &[u8],
                    data: &[u8], mut out: &mut [u8]) -> Result<usize, OssuaryError> {
    let aad = [];
    let mut ciphertext = Vec::with_capacity(data.len());
    let tag = match encrypt(session_key,
                            nonce,
                            &aad,
                            data,
                            &mut ciphertext) {
        Ok(t) => t,
        Err(_) => {
            return Err(OssuaryError::InvalidKey);
        }
    };
    // The header stores both lengths as u16; refuse to silently truncate them,
    // which would produce a header that disagrees with the payload.
    let max_len = ::std::u16::MAX as usize;
    if ciphertext.len() > max_len || tag.len() > max_len {
        return Err(OssuaryError::InvalidPacket("Packet too large to encrypt".into()));
    }
    let pkt: EncryptedPacket = EncryptedPacket {
        tag_len: tag.len() as u16,
        data_len: ciphertext.len() as u16,
    };
    let header = struct_as_slice(&pkt);
    // `write` on a byte slice may perform a short write and still return Ok,
    // which would leave a truncated packet in `out`.  `write_all` turns an
    // undersized buffer into an error instead.
    out.write_all(header)?;
    out.write_all(&ciphertext)?;
    out.write_all(&tag)?;
    Ok(header.len() + ciphertext.len() + tag.len())
}

/// Decrypts a framed packet (as produced by `encrypt_to_bytes`) into `out`.
///
/// `data` must begin with an `EncryptedPacket` header, followed by
/// `data_len` bytes of ciphertext, followed by the authentication tag (all
/// remaining bytes).  No additional authenticated data is used.
///
/// Returns the number of ciphertext bytes that were decrypted.
///
/// # Errors
///
/// * `OssuaryError::InvalidPacket` if the header cannot be read, the header's
///   tag length is not 16, or `data` is too short for the lengths the header
///   declares.
/// * Any error from the underlying `decrypt` call (e.g. authentication failure).
fn decrypt_to_bytes(session_key: &[u8], nonce: &[u8],
                    data: &[u8], mut out: &mut [u8]) -> Result<usize, OssuaryError> {
    let data_pkt: &EncryptedPacket = slice_as_struct(data)?;
    // Only a 16-byte tag is accepted.
    if data_pkt.tag_len != 16 {
        return Err(OssuaryError::InvalidPacket("Invalid packet length".into()));
    }
    let data_len = data_pkt.data_len as usize;
    let tag_len = data_pkt.tag_len as usize;
    // Both lengths come from the remote peer: bounds-check them so a malformed
    // packet yields an error rather than a slice-index panic.
    let rest = match data.get(::std::mem::size_of::<EncryptedPacket>()..) {
        Some(r) if r.len() >= data_len + tag_len => r,
        _ => return Err(OssuaryError::InvalidPacket("Invalid packet length".into())),
    };
    let (ciphertext, tag) = rest.split_at(data_len);
    let aad = [];
    decrypt(session_key,
            &nonce,
            &aad, &ciphertext, &tag,
            &mut out)?;
    Ok(ciphertext.len())
}