pub mod clib;
-const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
+// Maximum time to wait (in seconds) for a handshake response
+const MAX_HANDSHAKE_WAIT_TIME: u64 = 3u64;
+
+// Size of the random data to be signed by client
const CHALLENGE_LEN: usize = 256;
+
//
// API:
// * sock -- TCP data socket
// - non-blocking IO
// - remove all unwraps()
// - consider all unexpected packet types to be errors
+// - limit connection retries
fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
unsafe {
InvalidKey,
InvalidPacket(String),
InvalidStruct,
+ InvalidSignature,
+ ConnectionFailed,
}
impl std::fmt::Debug for OssuaryError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
ClientWaitAck(std::time::SystemTime),
ClientSendAuth,
- Failed,
+ Failed(OssuaryError),
Encrypted,
}
struct KeyMaterial {
impl ConnectionContext {
pub fn new(conn_type: ConnectionType) -> ConnectionContext {
//let mut rng = thread_rng();
- let mut rng = OsRng::new().unwrap();
+ let mut rng = OsRng::new().expect("RNG not available.");
let sec_key = generate_secret(&mut rng);
let pub_key = generate_public(&sec_key);
let mut nonce: [u8; 12] = [0; 12];
public_key: None,
}
}
- fn reset_state(&mut self) {
- self.state = match self.conn_type {
- ConnectionType::Client => ConnectionState::ClientNew,
- _ => ConnectionState::ServerNew,
+ fn reset_state(&mut self, permanent_err: Option<OssuaryError>) {
+ self.state = match permanent_err {
+ None => {
+ match self.conn_type {
+ ConnectionType::Client => ConnectionState::ClientNew,
+ _ => ConnectionState::ServerNew,
+ }
+ },
+ Some(e) => {
+ ConnectionState::Failed(e)
+ }
};
self.local_msg_id = 0;
self.challenge = None;
session: None,
};
self.remote_key = Some(key);
- self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
+ if let Some(secret) = self.local_key.secret.as_ref() {
+ self.local_key.session = Some(diffie_hellman(secret, public));
+ }
}
pub fn set_authorized_keys<'a,T>(&mut self, keys: T) -> Result<usize, OssuaryError>
where T: std::iter::IntoIterator<Item = &'a [u8]> {
let mut next_msg_id = conn.local_msg_id;
let more = match conn.state {
ConnectionState::ServerNew => {
- // wait for client
+ // Wait for client to initiate connection
true
},
+ ConnectionState::Encrypted => {
+ // Handshake finished
+ false
+ },
ConnectionState::ServerWaitAck(t) |
- ConnectionState::ServerWaitAuth(t) => {
- // TIMEOUT NACK
+ ConnectionState::ServerWaitAuth(t) |
+ ConnectionState::ClientWaitKey(t) |
+ ConnectionState::ClientWaitAck(t) => {
+ // Wait for response, with timeout
if let Ok(dur) = t.elapsed() {
- if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+ if dur.as_secs() > MAX_HANDSHAKE_WAIT_TIME {
let pkt: HandshakePacket = Default::default();
let _ = write_packet(&mut buf, struct_as_slice(&pkt),
&mut next_msg_id, PacketType::Reset);
- conn.state = ConnectionState::ServerNew;
+ conn.reset_state(None);
}
}
true
},
ConnectionState::ServerSendPubKey => {
- // Send pubkey
+ // Send session public key and nonce to the client
let mut pkt: HandshakePacket = Default::default();
pkt.public_key.copy_from_slice(&conn.local_key.public);
pkt.nonce.copy_from_slice(&conn.local_key.nonce);
ConnectionState::ServerSendChallenge => {
match conn.conn_type {
ConnectionType::AuthenticatedServer => {
+ // Send a block of random data over the encrypted session to
+ // the client. The client must sign it with its key to prove
+ // key possession.
+ let mut rng = match OsRng::new() {
+ Ok(rng) => rng,
+ Err(_) => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
let aad = [];
let mut challenge: [u8; CHALLENGE_LEN] = [0; CHALLENGE_LEN];
- let mut rng = OsRng::new().unwrap();
rng.fill_bytes(&mut challenge);
conn.challenge = Some(challenge.to_vec());
let mut ciphertext = Vec::with_capacity(CHALLENGE_LEN);
- let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.local_key.nonce,
- &aad, &challenge, &mut ciphertext).unwrap();
-
+ let session_key = match conn.local_key.session {
+ Some(ref s) => s,
+ None => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
+ let tag = match encrypt(session_key,
+ &conn.local_key.nonce,
+ &aad, &challenge, &mut ciphertext) {
+ Ok(tag) => tag,
+ Err(_) => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
let pkt: EncryptedPacket = EncryptedPacket {
tag_len: tag.len() as u16,
data_len: ciphertext.len() as u16,
true
},
_ => {
- // Unauthenticated
+ // For unauthenticated connections, we are done. Already encrypted.
let pkt: HandshakePacket = Default::default();
let _ = write_packet(&mut buf, struct_as_slice(&pkt),
&mut next_msg_id, PacketType::PubKeyAck);
conn.state = ConnectionState::Encrypted;
- false
+ false // handshake is finished (success)
},
}
},
ConnectionState::ClientNew => {
- // Send pubkey
+ // Send session public key and nonce to initiate connection
let mut pkt: HandshakePacket = Default::default();
pkt.public_key.copy_from_slice(&conn.local_key.public);
pkt.nonce.copy_from_slice(&conn.local_key.nonce);
conn.state = ConnectionState::ClientWaitKey(std::time::SystemTime::now());
true
},
- ConnectionState::ClientWaitKey(t) => {
- if let Ok(dur) = t.elapsed() {
- if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
- conn.reset_state();
- }
- }
- true
- },
ConnectionState::ClientSendAck => {
+ // Acknowledge reception of server's session public key and nonce
let pkt: HandshakePacket = Default::default();
let _ = write_packet(&mut buf, struct_as_slice(&pkt),
&mut next_msg_id, PacketType::PubKeyAck);
conn.state = ConnectionState::ClientWaitAck(std::time::SystemTime::now());
true
},
- ConnectionState::ClientWaitAck(t) => {
- if let Ok(dur) = t.elapsed() {
- if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
- conn.reset_state();
- }
- }
- true
- },
ConnectionState::ClientSendAuth => {
- // TODO: import secret key
- if conn.secret_key.is_none() {
- conn.reset_state();
- // TODO: raise error
- return true;
- }
- let secret = conn.secret_key.as_ref()
- .map(|sec| SecretKey::from_bytes(sec.as_bytes()).unwrap())
- .unwrap();
+ // Send signature of the server's challenge back to the server,
+ // along with the public part of the authentication key. This is
+ // done over the established encrypted channel.
+ let secret = match conn.secret_key {
+ Some(ref s) => match SecretKey::from_bytes(s.as_bytes()) {
+ Ok(s) => s, // local copy of secret key
+ Err(_) => {
+ conn.reset_state(Some(OssuaryError::InvalidKey));
+ return true;
+ }
+ },
+ None => {
+ conn.reset_state(Some(OssuaryError::InvalidKey));
+ return true;
+ }
+ };
let public = PublicKey::from_secret::<Sha512>(&secret);
let keypair = Keypair { secret: secret, public: public };
- let sig = keypair.sign::<Sha512>(&conn.challenge.as_ref().unwrap()).to_bytes();
+ let sig = match conn.challenge {
+ Some(ref c) => keypair.sign::<Sha512>(c).to_bytes(),
+ None => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
let mut pkt_data: Vec<u8> = Vec::with_capacity(CHALLENGE_LEN + 32);
pkt_data.extend_from_slice(public.as_bytes());
pkt_data.extend_from_slice(&sig);
let aad = [];
let mut ciphertext = Vec::with_capacity(pkt_data.len());
- let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.local_key.nonce,
- &aad, &pkt_data, &mut ciphertext).unwrap();
+ let session_key = match conn.local_key.session {
+ Some(ref s) => s,
+ None => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
+ let tag = match encrypt(session_key,
+ &conn.local_key.nonce,
+ &aad, &pkt_data, &mut ciphertext) {
+ Ok(t) => t,
+ Err(_) => {
+ conn.reset_state(None);
+ return true;
+ }
+ };
let pkt: EncryptedPacket = EncryptedPacket {
tag_len: tag.len() as u16,
let _ = write_packet(&mut buf, &pkt_buf,
&mut next_msg_id, PacketType::AuthResponse);
conn.state = ConnectionState::Encrypted;
- false
+ false // handshake is finished (success)
},
- ConnectionState::Failed => {
+ ConnectionState::Failed(_) => {
+ // This is a permanent failure.
let pkt: HandshakePacket = Default::default();
let _ = write_packet(&mut buf, struct_as_slice(&pkt),
&mut next_msg_id, PacketType::Disconnect);
- conn.reset_state();
- true
- },
- ConnectionState::Encrypted => {
- false
+ conn.reset_state(Some(OssuaryError::ConnectionFailed));
+ false // handshake is finished (failed)
},
};
conn.local_msg_id = next_msg_id;
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
// TODO: read_exact won't work.
- let pkt = read_packet(buf);
- if pkt.is_err() {
- return;
- }
- let pkt: NetworkPacket = pkt.unwrap();
+ let pkt: NetworkPacket = match read_packet(buf) {
+ Ok(p) => p,
+ Err(_) => {
+ return;
+ }
+ };
if pkt.header.msg_id != conn.remote_msg_id {
println!("Message gap detected. Restarting connection.");
println!("Server: {}", conn.is_server());
- conn.reset_state();
+ conn.reset_state(None);
return; // TODO: return error
}
conn.remote_msg_id = pkt.header.msg_id + 1;
let mut error = false;
match pkt.kind() {
PacketType::Reset => {
- conn.reset_state();
+ conn.reset_state(None);
return;
},
PacketType::Disconnect => {
// TODO: handle error
+ conn.reset_state(Some(OssuaryError::ConnectionFailed));
panic!("Remote side terminated connection.");
},
_ => {},
ConnectionState::ServerNew => {
match pkt.kind() {
PacketType::PublicKeyNonce => {
- let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
- conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
- conn.state = ConnectionState::ServerSendPubKey;
+ let data_pkt: Result<&HandshakePacket, _> = interpret_packet(&pkt);
+ match data_pkt {
+ Ok(ref data_pkt) => {
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ conn.state = ConnectionState::ServerSendPubKey;
+ },
+ Err(_) => {
+ error = true;
+ },
+ }
},
_ => { error = true; }
}
}
},
ConnectionState::ServerWaitAuth(_t) => {
- // TODO (auth)
match pkt.kind() {
PacketType::AuthResponse => {
- let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
- let ciphertext = &rest[..data_pkt.data_len as usize];
- let tag = &rest[data_pkt.data_len as usize..];
- let aad = [];
- let mut plaintext = Vec::with_capacity(ciphertext.len());
- let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.remote_key.as_ref().unwrap().nonce,
- &aad, &ciphertext, &tag, &mut plaintext);
- let pubkey = &plaintext[0..32];
- let sig = &plaintext[32..];
-
- if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
- let public = PublicKey::from_bytes(pubkey).unwrap();
- let sig = Signature::from_bytes(sig).unwrap();
- match public.verify::<Sha512>(conn.challenge.as_ref().unwrap(), &sig) {
- Ok(_) => {
- conn.state = ConnectionState::Encrypted;
- },
- Err(_) => {
- println!("Verify bad");
- // TODO: error
- conn.state = ConnectionState::Failed;
- },
- }
- }
- else {
- println!("Key not allowed");
- // TODO: error
- conn.state = ConnectionState::Failed;
- }
+ match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+ Ok((data_pkt, rest)) => {
+ let ciphertext = &rest[..data_pkt.data_len as usize];
+ let tag = &rest[data_pkt.data_len as usize..];
+ let aad = [];
+ let mut plaintext = Vec::with_capacity(ciphertext.len());
+ let session_key = match conn.local_key.session {
+ Some(ref k) => k,
+ None => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let remote_nonce = match conn.remote_key {
+ Some(ref rem) => rem.nonce,
+ None => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let _ = decrypt(session_key,
+ &remote_nonce,
+ &aad, &ciphertext, &tag, &mut plaintext);
+ let pubkey = &plaintext[0..32];
+ let sig = &plaintext[32..];
+ if conn.authorized_keys.iter().filter(|k| &pubkey == k).count() > 0 {
+ let public = match PublicKey::from_bytes(pubkey) {
+ Ok(p) => p,
+ Err(_) => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let sig = match Signature::from_bytes(sig) {
+ Ok(s) => s,
+ Err(_) => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let challenge = match conn.challenge {
+ Some(ref c) => c,
+ None => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ match public.verify::<Sha512>(challenge, &sig) {
+ Ok(_) => {
+ conn.state = ConnectionState::Encrypted;
+ },
+ Err(_) => {
+ println!("Verify bad");
+ conn.state = ConnectionState::Failed(
+ OssuaryError::InvalidSignature);
+ },
+ }
+ }
+ else {
+ println!("Key not allowed");
+ conn.state = ConnectionState::Failed(OssuaryError::InvalidKey);
+ }
+ },
+ Err(_) => {
+ conn.reset_state(None);
+ return;
+ },
+ };
},
_ => { error = true; }
}
ConnectionState::ClientWaitKey(_t) => {
match pkt.kind() {
PacketType::PublicKeyNonce => {
- let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
- conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
- conn.state = ConnectionState::ClientSendAck;
+ let data_pkt: Result<&HandshakePacket, _> = interpret_packet(&pkt);
+ match data_pkt {
+ Ok(data_pkt) => {
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ conn.state = ConnectionState::ClientSendAck;
+ },
+ Err(_) => {
+ error = true;
+ },
+ }
},
- _ => { }
+ _ => {
+ error = true;
+ }
}
},
ConnectionState::ClientSendAck => {
conn.state = ConnectionState::Encrypted;
},
PacketType::AuthChallenge => {
- let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
- let ciphertext = &rest[..data_pkt.data_len as usize];
- let tag = &rest[data_pkt.data_len as usize..];
- let aad = [];
- let mut plaintext = Vec::with_capacity(ciphertext.len());
- let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.remote_key.as_ref().unwrap().nonce,
- &aad, &ciphertext, &tag, &mut plaintext);
- conn.challenge = Some(plaintext);
- conn.state = ConnectionState::ClientSendAuth;
+ match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+ Ok((data_pkt, rest)) => {
+ let ciphertext = &rest[..data_pkt.data_len as usize];
+ let tag = &rest[data_pkt.data_len as usize..];
+ let aad = [];
+ let mut plaintext = Vec::with_capacity(ciphertext.len());
+
+ let session_key = match conn.local_key.session {
+ Some(ref k) => k,
+ None => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let remote_nonce = match conn.remote_key {
+ Some(ref rem) => rem.nonce,
+ None => {
+ conn.reset_state(None);
+ return;
+ }
+ };
+ let _ = decrypt(session_key,
+ &remote_nonce,
+ &aad, &ciphertext, &tag, &mut plaintext);
+ conn.challenge = Some(plaintext);
+ conn.state = ConnectionState::ClientSendAuth;
+ },
+ Err(_) => {
+ error = true;
+ },
+ }
+ },
+ _ => {
+ error = true;
},
- _ => {},
}
},
ConnectionState::ClientSendAuth => {
error = true;
}, // nop
- ConnectionState::Failed => {
+ ConnectionState::Failed(_) => {
error = true;
}, // nop
ConnectionState::Encrypted => {
}, // nop
}
if error {
- conn.reset_state();
+ conn.reset_state(None);
}
}
// TODO: should return a Result with error on forced-disconnect or permanent failure
-pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
+pub fn crypto_handshake_done(conn: &ConnectionContext) -> Result<bool, &OssuaryError> {
match conn.state {
- ConnectionState::Encrypted => true,
- _ => false,
+ ConnectionState::Encrypted => Ok(true),
+ ConnectionState::Failed(ref e) => Err(e),
+ _ => Ok(false),
}
}
let bytes;
let aad = [];
let mut ciphertext = Vec::with_capacity(in_buf.len());
- let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
+ let session_key = match conn.local_key.session {
+ Some(ref k) => k,
+ None => {
+ conn.reset_state(None);
+ return Err(OssuaryError::InvalidKey);;
+ }
+ };
+ let tag = match encrypt(session_key,
+ &conn.local_key.nonce, &aad, in_buf, &mut ciphertext) {
+ Ok(t) => t,
+ Err(_) => {
+ conn.reset_state(None);
+ return Err(OssuaryError::InvalidKey);;
+ }
+ };
let pkt: EncryptedPacket = EncryptedPacket {
tag_len: tag.len() as u16,
if pkt.header.msg_id != conn.remote_msg_id {
println!("Message gap detected. Restarting connection.");
println!("Server: {}", conn.is_server());
- conn.reset_state();
+ conn.reset_state(None);
return Err(OssuaryError::InvalidPacket("Message ID mismatch".into()))
}
conn.remote_msg_id = pkt.header.msg_id + 1;
match pkt.kind() {
PacketType::EncryptedData => {
- let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
- let ciphertext = &rest[..data_pkt.data_len as usize];
- let tag = &rest[data_pkt.data_len as usize..];
- let aad = [];
- let mut plaintext = Vec::with_capacity(ciphertext.len());
- let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
- &conn.remote_key.as_ref().unwrap().nonce,
- &aad, &ciphertext, &tag, &mut plaintext);
- let _ = out_buf.write(&plaintext);
- bytes_written = ciphertext.len() as u16;
- bytes_read = (ciphertext.len() +
- ::std::mem::size_of::<PacketHeader>() +
- ::std::mem::size_of::<EncryptedPacket>() +
- tag.len()) as u16;
+ match interpret_packet_extra::<EncryptedPacket>(&pkt) {
+ Ok((data_pkt, rest)) => {
+ let ciphertext = &rest[..data_pkt.data_len as usize];
+ let tag = &rest[data_pkt.data_len as usize..];
+ let aad = [];
+ let mut plaintext = Vec::with_capacity(ciphertext.len());
+ let session_key = match conn.local_key.session {
+ Some(ref k) => k,
+ None => {
+ conn.reset_state(None);
+ return Err(OssuaryError::InvalidKey);
+ }
+ };
+ let remote_nonce = match conn.remote_key {
+ Some(ref rem) => rem.nonce,
+ None => {
+ conn.reset_state(None);
+ return Err(OssuaryError::InvalidKey);
+ }
+ };
+ let _ = decrypt(session_key,
+ &remote_nonce,
+ &aad, &ciphertext, &tag, &mut plaintext);
+ let _ = out_buf.write(&plaintext);
+ bytes_written = ciphertext.len() as u16;
+ bytes_read = (ciphertext.len() +
+ ::std::mem::size_of::<PacketHeader>() +
+ ::std::mem::size_of::<EncryptedPacket>() +
+ tag.len()) as u16;
+ },
+ Err(_) => {
+ conn.reset_state(None);
+ return Err(OssuaryError::InvalidKey);
+ },
+ }
},
_ => {
return Err(OssuaryError::InvalidPacket("Received non-encrypted data on encrypted channel.".into()));
let listener = TcpListener::bind("127.0.0.1:9987").unwrap();
let mut server_stream = listener.incoming().next().unwrap().unwrap();
let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
- while crypto_handshake_done(&server_conn) == false {
+ while crypto_handshake_done(&server_conn).unwrap() == false {
if crypto_send_handshake(&mut server_conn, &mut server_stream) {
crypto_recv_handshake(&mut server_conn, &mut server_stream);
}
std::thread::sleep(std::time::Duration::from_millis(500));
let mut client_stream = TcpStream::connect("127.0.0.1:9987").unwrap();
let mut client_conn = ConnectionContext::new(ConnectionType::Client);
- while crypto_handshake_done(&client_conn) == false {
+ while crypto_handshake_done(&client_conn).unwrap() == false {
if crypto_send_handshake(&mut client_conn, &mut client_stream) {
crypto_recv_handshake(&mut client_conn, &mut client_stream);
}
is_server: bool) -> Result<(), std::io::Error>
where T: std::io::Read + std::io::Write {
// Run the opaque handshake until the connection is established
- while crypto_handshake_done(&conn) == false {
+ while crypto_handshake_done(&conn).unwrap() == false {
if crypto_send_handshake(&mut conn, &mut stream) {
crypto_recv_handshake(&mut conn, &mut stream);
}
is_server: bool) -> Result<(), std::io::Error>
where T: std::io::Read + std::io::Write {
// Run the opaque handshake until the connection is established
- while crypto_handshake_done(&conn) == false {
+ while crypto_handshake_done(&conn).unwrap() == false {
if crypto_send_handshake(&mut conn, &mut stream) {
crypto_recv_handshake(&mut conn, &mut stream);
}