use x25519_dalek::generate_public;
use x25519_dalek::diffie_hellman;
-use rand::thread_rng;
+//use rand::thread_rng;
use rand::RngCore;
+use rand::rngs::OsRng;
use std::convert::TryInto;
pub mod clib;
+const MAX_PUB_KEY_ACK_TIME: u64 = 3u64;
//
// API:
// * sock -- TCP data socket
Ok(&*(&p[..::std::mem::size_of::<T>()] as *const [u8] as *const T))
}
}
-#[repr(packed)]
-#[allow(dead_code)]
+
+pub enum OssuaryError {
+ Io(std::io::Error),
+ Unpack(core::array::TryFromSliceError),
+}
+impl std::fmt::Debug for OssuaryError {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "OssuaryError")
+ }
+}
+impl From<std::io::Error> for OssuaryError {
+ fn from(error: std::io::Error) -> Self {
+ OssuaryError::Io(error)
+ }
+}
+impl From<core::array::TryFromSliceError> for OssuaryError {
+ fn from(error: core::array::TryFromSliceError) -> Self {
+ OssuaryError::Unpack(error)
+ }
+}
+
+#[repr(C,packed)]
struct HandshakePacket {
len: u16,
_reserved: u16,
#[repr(u16)]
#[derive(Clone, Copy)]
-#[allow(dead_code)]
enum PacketType {
- Unknown = 0,
+ Unknown = 0x00,
PublicKeyNonce = 0x01,
- AuthRequest = 0x02,
- EncryptedData = 0x03,
- Disconnect = 0x04,
+ PubKeyAck = 0x02,
+ AuthRequest = 0x03,
+ Reset = 0x04,
+ Disconnect = 0x05,
+ EncryptedData = 0x10,
}
impl PacketType {
pub fn from_u16(i: u16) -> PacketType {
match i {
0x01 => PacketType::PublicKeyNonce,
- 0x02 => PacketType::AuthRequest,
- 0x03 => PacketType::EncryptedData,
- 0x04 => PacketType::Disconnect,
+ 0x02 => PacketType::PubKeyAck,
+ 0x03 => PacketType::AuthRequest,
+ 0x04 => PacketType::Reset,
+ 0x05 => PacketType::Disconnect,
+ 0x10 => PacketType::EncryptedData,
_ => PacketType::Unknown,
}
}
}
-#[repr(packed)]
-#[allow(dead_code)]
+#[repr(C,packed)]
struct EncryptedPacket {
data_len: u16,
tag_len: u16,
}
-#[repr(packed)]
-#[allow(dead_code)]
+#[repr(C,packed)]
struct PacketHeader {
len: u16,
msg_id: u16,
_reserved: u16,
}
+struct NetworkPacket {
+ header: PacketHeader,
+ data: Box<[u8]>,
+}
+impl NetworkPacket {
+ fn kind(&self) -> PacketType {
+ self.header.packet_type
+ }
+}
+
enum ConnectionState {
- New,
- PubKeySent,
+ ServerNew,
+ ServerSendPubKey,
+ ServerWaitAck(std::time::SystemTime),
+
+ ClientNew,
+ ClientWaitKey(std::time::SystemTime),
+ ClientSendAck,
+
Encrypted,
- _Authenticated,
}
struct KeyMaterial {
secret: Option<[u8; 32]>,
}
pub struct ConnectionContext {
state: ConnectionState,
+ is_server: bool,
local_key: KeyMaterial,
remote_key: Option<KeyMaterial>,
+ local_msg_id: u16,
+ remote_msg_id: u16,
}
impl ConnectionContext {
- fn new() -> ConnectionContext {
- let mut rng = thread_rng();
+ fn new(server: bool) -> ConnectionContext {
+ //let mut rng = thread_rng();
+ let mut rng = OsRng::new().unwrap();
let sec_key = generate_secret(&mut rng);
let pub_key = generate_public(&sec_key);
let mut nonce: [u8; 12] = [0; 12];
session: None,
};
ConnectionContext {
- state: ConnectionState::New,
+ state: match server {
+ true => ConnectionState::ServerNew,
+ false => ConnectionState::ClientNew,
+ },
+ is_server: server,
local_key: key,
remote_key: None,
+ local_msg_id: 0u16,
+ remote_msg_id: 0u16,
}
}
+ fn reset_state(&mut self) {
+ self.state = match self.is_server {
+ true => ConnectionState::ServerNew,
+ false => ConnectionState::ClientNew,
+ };
+ }
fn add_remote_key(&mut self, public: &[u8; 32], nonce: &[u8; 12]) {
let key = KeyMaterial {
secret: None,
};
self.remote_key = Some(key);
self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
- self.state = ConnectionState::Encrypted;
- }
-}
-
-struct NetworkPacket {
- header: PacketHeader,
- data: Box<[u8]>,
-}
-impl NetworkPacket {
- fn kind(&self) -> PacketType {
- self.header.packet_type
}
}
Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
}
-pub enum OssuaryError {
- Io(std::io::Error),
- Unpack(core::array::TryFromSliceError),
-}
-impl From<std::io::Error> for OssuaryError {
- fn from(error: std::io::Error) -> Self {
- OssuaryError::Io(error)
- }
-}
-impl From<core::array::TryFromSliceError> for OssuaryError {
- fn from(error: core::array::TryFromSliceError) -> Self {
- OssuaryError::Unpack(error)
- }
-}
-
fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
})
}
-//fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
-//where T: std::io::Write {
-fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+fn write_packet<T,U>(stream: &mut T, data: &[u8], msg_id: &mut u16, kind: PacketType) -> Result<(), std::io::Error>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
let mut buf: Vec<u8> = Vec::with_capacity(::std::mem::size_of::<PacketHeader>());
buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
- buf.extend_from_slice(&(msg_id as u16).to_be_bytes());
+ buf.extend_from_slice(&(*msg_id as u16).to_be_bytes());
buf.extend_from_slice(&(kind as u16).to_be_bytes());
buf.extend_from_slice(&(0u16).to_be_bytes());
- stream.write(&buf)?;
- stream.write(data)?;
+ let _ = stream.write(&buf)?;
+ let _ = stream.write(data)?;
+ *msg_id = *msg_id + 1;
Ok(())
}
-pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+pub fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, mut buf: T) -> bool
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
- match conn.state {
- ConnectionState::New => {
+ let mut next_msg_id = conn.local_msg_id;
+ let more = match conn.state {
+ ConnectionState::ServerNew => {
+ // wait for client
+ true
+ },
+ ConnectionState::ServerWaitAck(t) => {
+ // TIMEOUT NACK
+ if let Ok(dur) = t.elapsed() {
+ if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+ let pkt: HandshakePacket = Default::default();
+ let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+ &mut next_msg_id, PacketType::Reset);
+ conn.state = ConnectionState::ServerNew;
+ }
+ }
+ true
+ },
+ ConnectionState::ServerSendPubKey => {
+ // Send pubkey
+ let mut pkt: HandshakePacket = Default::default();
+ pkt.public_key.copy_from_slice(&conn.local_key.public);
+ pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+ let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+ &mut next_msg_id, PacketType::PublicKeyNonce);
+ conn.state = ConnectionState::ServerWaitAck(std::time::SystemTime::now());
+ true
+ },
+ ConnectionState::ClientNew => {
+ // Send pubkey
let mut pkt: HandshakePacket = Default::default();
pkt.public_key.copy_from_slice(&conn.local_key.public);
pkt.nonce.copy_from_slice(&conn.local_key.nonce);
- let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
- conn.state = ConnectionState::PubKeySent;
+ let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+ &mut next_msg_id, PacketType::PublicKeyNonce);
+ conn.state = ConnectionState::ClientWaitKey(std::time::SystemTime::now());
true
},
- _ => {
+ ConnectionState::ClientWaitKey(t) => {
+ if let Ok(dur) = t.elapsed() {
+ if dur.as_secs() > MAX_PUB_KEY_ACK_TIME {
+ conn.state = ConnectionState::ClientNew;
+ }
+ }
+ true
+ },
+ ConnectionState::ClientSendAck => {
+ let pkt: HandshakePacket = Default::default();
+ let _ = write_packet(&mut buf, struct_as_slice(&pkt),
+ &mut next_msg_id, PacketType::PubKeyAck);
+ conn.state = ConnectionState::Encrypted;
false
- }
- }
+ },
+ ConnectionState::Encrypted => {
+ false
+ },
+ };
+ conn.local_msg_id = next_msg_id;
+ more
}
-pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T)
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
+ // TODO: read_exact won't work.
+ let pkt = read_packet(buf);
+ if pkt.is_err() {
+ return;
+ }
+ let pkt: NetworkPacket = pkt.unwrap();
+
+ if pkt.header.msg_id != conn.remote_msg_id {
+ println!("Message gap detected. Restarting connection.");
+ conn.reset_state();
+ return; // TODO: return error
+ }
+ conn.remote_msg_id = pkt.header.msg_id + 1;
+
+ let mut error = false;
+ match pkt.kind() {
+ PacketType::Reset => {
+ conn.state = match conn.is_server {
+ true => ConnectionState::ServerNew,
+ _ => ConnectionState::ClientNew,
+ };
+ return;
+ },
+ _ => {},
+ }
+
match conn.state {
- ConnectionState::New => { return true; },
- ConnectionState::PubKeySent => {},
- _ => { return false; }
+ ConnectionState::ServerNew => {
+ match pkt.kind() {
+ PacketType::PublicKeyNonce => {
+ let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ conn.state = ConnectionState::ServerSendPubKey;
+ },
+ _ => { error = true; }
+ }
+ },
+ ConnectionState::ServerWaitAck(_t) => {
+ match pkt.kind() {
+ PacketType::PubKeyAck => {
+ conn.state = ConnectionState::Encrypted;
+ },
+ _ => { error = true; }
+ }
+ },
+ ConnectionState::ServerSendPubKey => {
+ error = true;
+ }, // nop
+ ConnectionState::ClientNew => {
+ error = true;
+ }, // nop
+ ConnectionState::ClientWaitKey(_t) => {
+ match pkt.kind() {
+ PacketType::PublicKeyNonce => {
+ let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ conn.state = ConnectionState::ClientSendAck;
+ },
+ _ => { }
+ }
+ },
+ ConnectionState::ClientSendAck => {
+ error = true;
+ }, // nop
+ ConnectionState::Encrypted => {
+ error = true;
+ }, // nop
}
- // TODO: read_exact won't work.
- if let Ok(pkt) = read_packet(buf) {
- println!("Packet type: {}", pkt.kind() as u16);
- match pkt.kind() {
- PacketType::PublicKeyNonce => {
- let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
- conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
- println!("Session key: {:?}", conn.local_key.session.as_ref().unwrap());
- conn.state = ConnectionState::Encrypted;
- },
- _ => {},
- }
+ if error {
+ conn.state = match conn.is_server {
+ true => ConnectionState::ServerNew,
+ _ => ConnectionState::ClientNew,
+ };
}
- true
}
-pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], out_buf: T) -> Result<(), &'static str>
+pub fn crypto_handshake_done(conn: &ConnectionContext) -> bool {
+ match conn.state {
+ ConnectionState::Encrypted => true,
+ _ => false,
+ }
+}
+
+pub fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], mut out_buf: T) -> Result<u16, &'static str>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
match conn.state {
ConnectionState::Encrypted => {},
_ => { return Err("Encrypted channel not established."); }
}
+ let mut next_msg_id = conn.local_msg_id;
+ let bytes;
let aad = [];
let mut ciphertext = Vec::with_capacity(in_buf.len());
- let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
- println!("encrypted: {:?} {:?}", ciphertext, tag);
+ let tag = encrypt(conn.local_key.session.as_ref().unwrap(),
+ &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
let pkt: EncryptedPacket = EncryptedPacket {
tag_len: tag.len() as u16,
buf.extend(struct_as_slice(&pkt));
buf.extend(&ciphertext);
buf.extend(&tag);
- let _ = write_packet(out_buf, &buf, 0, PacketType::EncryptedData);
- Ok(())
+ let _ = write_packet(&mut out_buf, &buf,
+ &mut next_msg_id, PacketType::EncryptedData);
+ bytes = buf.len() as u16;
+ conn.local_msg_id = next_msg_id;
+ Ok(bytes)
}
-pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(), &'static str>
+pub fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<u16, &'static str>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read,
R: std::ops::DerefMut<Target = V>,
V: std::io::Write {
+ let mut bytes: u16 = 0u16;
match conn.state {
ConnectionState::Encrypted => {},
_ => { return Err("Encrypted channel not established."); }
}
- if let Ok(pkt) = read_packet(in_buf) {
- println!("Packet type: {}", pkt.kind() as u16);
- match pkt.kind() {
- PacketType::EncryptedData => {
- let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
- let ciphertext = &rest[..data_pkt.data_len as usize];
- let tag = &rest[data_pkt.data_len as usize..];
- let aad = [];
- let mut plaintext = Vec::with_capacity(ciphertext.len());
- let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
- let _ = out_buf.write(&plaintext);
- },
- _ => {
- return Err("Received non-encrypted data on encrypted channel.");
- },
- }
+ //if let Ok(pkt) = read_packet(in_buf) {
+ match read_packet(in_buf) {
+ Ok(pkt) => {
+ if pkt.header.msg_id != conn.remote_msg_id {
+ println!("Message gap detected. Restarting connection.");
+ conn.reset_state();
+ return Ok(0u16); // TODO: return error
+ }
+ conn.remote_msg_id = pkt.header.msg_id + 1;
+
+ match pkt.kind() {
+ PacketType::EncryptedData => {
+ let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+ let ciphertext = &rest[..data_pkt.data_len as usize];
+ let tag = &rest[data_pkt.data_len as usize..];
+ let aad = [];
+ let mut plaintext = Vec::with_capacity(ciphertext.len());
+ let _ = decrypt(conn.local_key.session.as_ref().unwrap(),
+ &conn.remote_key.as_ref().unwrap().nonce,
+ &aad, &ciphertext, &tag, &mut plaintext);
+ let _ = out_buf.write(&plaintext);
+ bytes = ciphertext.len() as u16;
+ },
+ _ => {
+ println!("bad packet: {:x}", pkt.kind() as u16);
+ return Err("Received non-encrypted data on encrypted channel.");
+ },
+ }
+ },
+ Err(_e) => {
+ // TODO
+ },
}
- Ok(())
+ Ok(bytes)
}
#[cfg(test)]
mod tests {
use std::thread;
+ use std::time;
use std::net::{TcpListener, TcpStream};
use crate::*;
fn event_loop<T>(mut conn: ConnectionContext, mut stream: T, is_server: bool) -> Result<(), std::io::Error>
where T: std::io::Read + std::io::Write {
- while crypto_send_handshake(&mut conn, &mut stream) == true {}
- while crypto_recv_handshake(&mut conn, &mut stream) == true {}
+ while crypto_handshake_done(&conn) == false {
+ if crypto_send_handshake(&mut conn, &mut stream) {
+ crypto_recv_handshake(&mut conn, &mut stream);
+ }
+ }
if is_server {
let mut plaintext = "hello, world".as_bytes();
let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+ let _ = stream.flush();
+ loop {
+ std::thread::sleep(time::Duration::from_millis(50));
+ }
}
loop {
let mut plaintext = vec!();
let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
println!("decrypted: {:?}", String::from_utf8(plaintext));
+ std::thread::sleep(time::Duration::from_millis(50));
}
}
let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
for stream in listener.incoming() {
let stream: TcpStream = stream.unwrap();
- let conn = ConnectionContext::new();
+ let conn = ConnectionContext::new(true);
let _ = event_loop(conn, stream, true);
}
Ok(())
pub fn client() -> Result<(), std::io::Error> {
let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
- let conn = ConnectionContext::new();
+ let conn = ConnectionContext::new(false);
let _ = event_loop(conn, stream, false);
Ok(())
}