// Size of the random data to be signed by client
const CHALLENGE_LEN: usize = 256;
+// Internal buffer for copy of network data
+const PACKET_BUF_SIZE: usize = 16384
+ + ::std::mem::size_of::<PacketHeader>()
+ + ::std::mem::size_of::<EncryptedPacket>()
+ + 16; // chacha20 tag
+
//
// API:
// * sock -- TCP data socket
pub enum OssuaryError {
Io(std::io::Error),
+ WouldBlock(usize), // bytes consumed
Unpack(core::array::TryFromSliceError),
KeySize(usize, usize), // (expected, actual)
InvalidKey,
AuthenticatedServer,
UnauthenticatedServer,
}
+
pub struct ConnectionContext {
state: ConnectionState,
conn_type: ConnectionType,
authorized_keys: Vec<[u8; 32]>,
secret_key: Option<SecretKey>,
public_key: Option<PublicKey>,
+ packet_buf: [u8; PACKET_BUF_SIZE],
+ packet_buf_used: usize,
}
impl ConnectionContext {
pub fn new(conn_type: ConnectionType) -> ConnectionContext {
authorized_keys: vec!(),
secret_key: None,
public_key: None,
+ packet_buf: [0u8; PACKET_BUF_SIZE],
+ packet_buf_used: 0,
}
}
fn reset_state(&mut self, permanent_err: Option<OssuaryError>) {
Ok(s)
}
-fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &[u8]), OssuaryError> {
+fn interpret_packet_extra<'a, T>(pkt: &'a NetworkPacket) -> Result<(&'a T, &'a [u8]), OssuaryError> {
let s: &T = slice_as_struct(&pkt.data)?;
Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
}
-fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
+fn read_packet<T,U>(conn: &mut ConnectionContext, mut stream: T) ->Result<(NetworkPacket, usize), OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
- let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
- let _ = stream.read_exact(&mut buf)?;
+ let header_size = ::std::mem::size_of::<PacketHeader>();
+ let bytes_read: usize;
+ match stream.read(&mut conn.packet_buf[conn.packet_buf_used..]) {
+ Ok(b) => bytes_read = b,
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
+ return Err(OssuaryError::WouldBlock(0))
+ },
+ Err(e) => return Err(e.into()),
+ }
+ conn.packet_buf_used += bytes_read;
+ let buf: &[u8] = &conn.packet_buf;
let hdr = PacketHeader {
len: u16::from_be_bytes(buf[0..2].try_into()?),
msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
packet_type: PacketType::from_u16(u16::from_be_bytes(buf[4..6].try_into()?)),
_reserved: u16::from_be_bytes(buf[6..8].try_into()?),
};
- let mut buf: Box<[u8]> = vec![0u8; hdr.len as usize].into_boxed_slice();
- let _ = stream.read_exact(&mut buf)?;
- Ok(NetworkPacket {
+ let packet_len = hdr.len as usize;
+ if conn.packet_buf_used < header_size + packet_len {
+ return Err(OssuaryError::WouldBlock(bytes_read));
+ }
+ let buf: Box<[u8]> = (&conn.packet_buf[header_size..header_size+packet_len])
+ .to_vec().into_boxed_slice();
+ let excess = conn.packet_buf_used - header_size - packet_len;
+ unsafe {
+ // no safe way to memmove() in Rust?
+ std::ptr::copy::<u8>(
+ conn.packet_buf.as_ptr().offset((header_size + packet_len) as isize),
+ conn.packet_buf.as_mut_ptr(),
+ excess);
+ }
+ conn.packet_buf_used = excess;
+ Ok((NetworkPacket {
header: hdr,
data: buf,
- })
+ },
+ header_size + packet_len))
}
fn write_packet<T,U>(stream: &mut T, data: &[u8], msg_id: &mut u16, kind: PacketType) -> Result<(), std::io::Error>
more
}
-pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T)
+// TODO u16 should be usize
+pub fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> Result<usize, OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
// TODO: read_exact won't work.
- let pkt: NetworkPacket = match read_packet(buf) {
- Ok(p) => p,
- Err(_) => {
- return;
+ let mut bytes_read: usize = 0;
+ let pkt: NetworkPacket = match read_packet(conn, buf) {
+ Ok((p, r)) => {
+ bytes_read += r;
+ p
+ },
+ Err(OssuaryError::WouldBlock(b)) => {
+ return Err(OssuaryError::WouldBlock(b));
+ }
+ Err(e) => {
+ return Err(e);
}
};
println!("Message gap detected. Restarting connection.");
println!("Server: {}", conn.is_server());
conn.reset_state(None);
- return; // TODO: return error
+ return Err(OssuaryError::InvalidPacket("Message ID does not match".into()));
}
conn.remote_msg_id = pkt.header.msg_id + 1;
match pkt.kind() {
PacketType::Reset => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::ConnectionFailed);
},
PacketType::Disconnect => {
// TODO: handle error
Some(ref k) => k,
None => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let remote_nonce = match conn.remote_key {
Some(ref rem) => rem.nonce,
None => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let _ = decrypt(session_key,
Ok(p) => p,
Err(_) => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let sig = match Signature::from_bytes(sig) {
Ok(s) => s,
Err(_) => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let challenge = match conn.challenge {
Some(ref c) => c,
None => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
match public.verify::<Sha512>(challenge, &sig) {
},
Err(_) => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidPacket("Response invalid".into()));
},
};
},
Some(ref k) => k,
None => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let remote_nonce = match conn.remote_key {
Some(ref rem) => rem.nonce,
None => {
conn.reset_state(None);
- return;
+ return Err(OssuaryError::InvalidKey);
}
};
let _ = decrypt(session_key,
if error {
conn.reset_state(None);
}
+ Ok(bytes_read)
}
// TODO: should return a Result with error on forced-disconnect or permanent failure
R: std::ops::DerefMut<Target = V>,
V: std::io::Write {
let bytes_written: u16;
- let bytes_read: u16;
+ let mut bytes_read: u16 = 0u16;
match conn.state {
ConnectionState::Encrypted => {},
_ => {
}
}
- match read_packet(in_buf) {
- Ok(pkt) => {
+ match read_packet(conn, in_buf) {
+ Ok((pkt, bytes)) => {
+ bytes_read += bytes as u16;
if pkt.header.msg_id != conn.remote_msg_id {
println!("Message gap detected. Restarting connection.");
println!("Server: {}", conn.is_server());
&aad, &ciphertext, &tag, &mut plaintext);
let _ = out_buf.write(&plaintext);
bytes_written = ciphertext.len() as u16;
- bytes_read = (ciphertext.len() +
- ::std::mem::size_of::<PacketHeader>() +
- ::std::mem::size_of::<EncryptedPacket>() +
- tag.len()) as u16;
},
Err(_) => {
conn.reset_state(None);
},
}
},
+ Err(OssuaryError::WouldBlock(b)) => {
+ return Err(OssuaryError::WouldBlock(b));
+ },
Err(_e) => {
return Err(OssuaryError::InvalidPacket("Packet header did not parse.".into()));
},
let mut server_conn = ConnectionContext::new(ConnectionType::UnauthenticatedServer);
while crypto_handshake_done(&server_conn).unwrap() == false {
if crypto_send_handshake(&mut server_conn, &mut server_stream) {
- crypto_recv_handshake(&mut server_conn, &mut server_stream);
+ loop {
+ match crypto_recv_handshake(&mut server_conn, &mut server_stream) {
+ Ok(_) => break,
+ Err(OssuaryError::WouldBlock(_)) => {},
+ _ => panic!("Handshake failed"),
+ }
+ }
}
}
let mut plaintext = vec!();
let mut bytes: u64 = 0;
let start = std::time::SystemTime::now();
loop {
- bytes += crypto_recv_data(&mut server_conn,
- &mut server_stream,
- &mut plaintext).unwrap().0 as u64;
+ match crypto_recv_data(&mut server_conn,
+ &mut server_stream,
+ &mut plaintext) {
+ Ok((read, _written)) => bytes += read as u64,
+ Err(OssuaryError::WouldBlock(_)) => continue,
+ _ => panic!("Recv failed"),
+ }
if plaintext == [0xde, 0xde, 0xbe, 0xbe] {
if let Ok(dur) = start.elapsed() {
let t = dur.as_secs() as f64
std::thread::sleep(std::time::Duration::from_millis(500));
let mut client_stream = TcpStream::connect("127.0.0.1:9987").unwrap();
+ client_stream.set_nonblocking(true).unwrap();
let mut client_conn = ConnectionContext::new(ConnectionType::Client);
while crypto_handshake_done(&client_conn).unwrap() == false {
if crypto_send_handshake(&mut client_conn, &mut client_stream) {
- crypto_recv_handshake(&mut client_conn, &mut client_stream);
+ loop {
+ match crypto_recv_handshake(&mut client_conn, &mut client_stream) {
+ Ok(_) => break,
+ Err(OssuaryError::WouldBlock(_)) => {},
+ _ => panic!("Handshake failed"),
+ }
+ }
}
}
let mut client_stream = std::io::BufWriter::new(client_stream);
}
let mut plaintext: &[u8] = &[0xde, 0xde, 0xbe, 0xbe];
let _ = crypto_send_data(&mut client_conn, &mut plaintext, &mut client_stream);
- // Unwrap the BufWriter, flushing the buffer
- let _ = client_stream.into_inner().unwrap();
+ drop(client_stream); // flush the buffer
let _ = server_thread.join();
}
}
use ossuary::{ConnectionContext, ConnectionType};
use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
use ossuary::{crypto_send_data,crypto_recv_data};
+use ossuary::OssuaryError;
use std::thread;
use std::net::{TcpListener, TcpStream};
// Run the opaque handshake until the connection is established
while crypto_handshake_done(&conn).unwrap() == false {
if crypto_send_handshake(&mut conn, &mut stream) {
- crypto_recv_handshake(&mut conn, &mut stream);
+ loop {
+ match crypto_recv_handshake(&mut conn, &mut stream) {
+ Ok(_) => break,
+ Err(OssuaryError::WouldBlock(_)) => {},
+ _ => panic!("Handshake failed."),
+ }
+ }
}
}
// Send a message to the other party
- let mut plaintext = match is_server {
- true => "message from server".as_bytes(),
- false => "message from client".as_bytes(),
+ let strings = ("message_from_server", "message_from_client");
+ let (mut plaintext, response) = match is_server {
+ true => (strings.0.as_bytes(), strings.1.as_bytes()),
+ false => (strings.1.as_bytes(), strings.0.as_bytes()),
};
let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
// Read a message from the other party
- let mut plaintext = vec!();
- let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
- println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
+ let mut recv_plaintext = vec!();
+ loop {
+ match crypto_recv_data(&mut conn, &mut stream, &mut recv_plaintext) {
+ Ok(_) => {
+ println!("(basic) received: {:?}",
+ String::from_utf8(recv_plaintext.clone()).unwrap());
+ assert_eq!(recv_plaintext.as_slice(), response);
+ break;
+ },
+ _ => {},
+ }
+ }
Ok(())
}
use ossuary::{ConnectionContext, ConnectionType};
use ossuary::{crypto_send_handshake,crypto_recv_handshake, crypto_handshake_done};
use ossuary::{crypto_send_data,crypto_recv_data};
+use ossuary::OssuaryError;
use std::thread;
use std::net::{TcpListener, TcpStream};
// Run the opaque handshake until the connection is established
while crypto_handshake_done(&conn).unwrap() == false {
if crypto_send_handshake(&mut conn, &mut stream) {
- crypto_recv_handshake(&mut conn, &mut stream);
+ loop {
+ match crypto_recv_handshake(&mut conn, &mut stream) {
+ Ok(_) => break,
+ Err(OssuaryError::WouldBlock(_)) => {},
+ _ => panic!("Handshake failed."),
+ }
+ }
}
}
// Send a message to the other party
- let mut plaintext = match is_server {
- true => "message from server".as_bytes(),
- false => "message from client".as_bytes(),
+ let strings = ("message_from_server", "message_from_client");
+ let (mut plaintext, response) = match is_server {
+ true => (strings.0.as_bytes(), strings.1.as_bytes()),
+ false => (strings.1.as_bytes(), strings.0.as_bytes()),
};
let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
// Read a message from the other party
- let mut plaintext = vec!();
- let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
- println!("(basic) received: {:?}", String::from_utf8(plaintext).unwrap());
-
+ let mut recv_plaintext = vec!();
+ loop {
+ match crypto_recv_data(&mut conn, &mut stream, &mut recv_plaintext) {
+ Ok(_) => {
+ println!("(basic_auth) received: {:?}",
+ String::from_utf8(recv_plaintext.clone()).unwrap());
+ assert_eq!(recv_plaintext.as_slice(), response);
+ break;
+ },
+ _ => {},
+ }
+ }
Ok(())
}
}
#[test]
-fn basic() {
+fn basic_auth() {
let server = thread::spawn(move || { let _ = server(); });
std::thread::sleep(std::time::Duration::from_millis(500));
let child = thread::spawn(move || { let _ = client(); });