+#![feature(try_from)]
+
extern crate x25519_dalek;
extern crate rand;
extern crate chacha20_poly1305_aead;
use rand::thread_rng;
use rand::RngCore;
+use std::convert::TryInto;
+
use std::thread;
-use std::rc::Rc;
-use std::cell::RefCell;
-use std::collections::VecDeque;
use std::net::{TcpListener, TcpStream};
//
#[derive(Clone, Copy)]
#[allow(dead_code)]
enum PacketType {
- PublicKeyNonce,
- AuthRequest,
- EncryptedData,
- Disconnect,
+ Unknown = 0,
+ PublicKeyNonce = 0x01,
+ AuthRequest = 0x02,
+ EncryptedData = 0x03,
+ Disconnect = 0x04,
+}
+impl PacketType {
+ pub fn from_u16(i: u16) -> PacketType {
+ match i {
+ 0x01 => PacketType::PublicKeyNonce,
+ 0x02 => PacketType::AuthRequest,
+ 0x03 => PacketType::EncryptedData,
+ 0x04 => PacketType::Disconnect,
+ _ => PacketType::Unknown,
+ }
+ }
}
#[repr(packed)]
enum ConnectionState {
New,
+ PubKeySent,
Encrypted,
_Authenticated,
}
nonce: [u8; 12],
}
struct ConnectionContext {
- stream: Rc<RefCell<TcpStream>>,
state: ConnectionState,
local_key: KeyMaterial,
remote_key: Option<KeyMaterial>,
}
impl ConnectionContext {
- fn new(stream: TcpStream) -> ConnectionContext {
+ fn new() -> ConnectionContext {
let mut rng = thread_rng();
let sec_key = generate_secret(&mut rng);
let pub_key = generate_public(&sec_key);
session: None,
};
ConnectionContext {
- stream: Rc::new(RefCell::new(stream)),
state: ConnectionState::New,
local_key: key,
remote_key: None,
Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
}
-//fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
-//where T: std::io::Read {
-fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, std::io::Error>
+pub enum OssuaryError {
+ Io(std::io::Error),
+ Unpack(core::array::TryFromSliceError),
+}
+impl From<std::io::Error> for OssuaryError {
+ fn from(error: std::io::Error) -> Self {
+ OssuaryError::Io(error)
+ }
+}
+impl From<core::array::TryFromSliceError> for OssuaryError {
+ fn from(error: core::array::TryFromSliceError) -> Self {
+ OssuaryError::Unpack(error)
+ }
+}
+
+fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
- println!("read header");
let _ = stream.read_exact(&mut buf)?;
let hdr = PacketHeader {
- len: unsafe {
- u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[0])
- },
- msg_id: unsafe {
- u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[1])
- },
- packet_type: unsafe {
- std::mem::transmute(u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[2]))
- },
- _reserved: unsafe {
- u16::from_be(std::slice::from_raw_parts(buf.as_ptr() as *const u16, buf.len())[3])
- },
+ len: u16::from_be_bytes(buf[0..2].try_into()?),
+ msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
+ packet_type: PacketType::from_u16(u16::from_be_bytes(buf[4..6].try_into()?)),
+ _reserved: u16::from_be_bytes(buf[6..8].try_into()?),
};
let mut buf: Box<[u8]> = vec![0u8; hdr.len as usize].into_boxed_slice();
let _ = stream.read_exact(&mut buf)?;
fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
- let buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
- unsafe {
- std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[0] = u16::to_be(data.len() as u16);
- }
- unsafe {
- std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[1] = u16::to_be(msg_id);
- }
- unsafe {
- std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[2] = u16::to_be(kind as u16);
- }
- unsafe {
- std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[3] = u16::to_be(0u16);
- }
+ let mut buf: Vec<u8> = Vec::with_capacity(::std::mem::size_of::<PacketHeader>());
+ buf.extend_from_slice(&(data.len() as u16).to_be_bytes());
+ buf.extend_from_slice(&(msg_id as u16).to_be_bytes());
+ buf.extend_from_slice(&(kind as u16).to_be_bytes());
+ buf.extend_from_slice(&(0u16).to_be_bytes());
stream.write(&buf)?;
stream.write(data)?;
Ok(())
}
-//fn crypto_send_handshake<T>(conn: &ConnectionContext, buf: &mut T)
-//where T: std::io::Write {
-fn crypto_send_handshake<T,U>(conn: &ConnectionContext, buf: T)
+fn crypto_send_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
match conn.state {
pkt.public_key.copy_from_slice(&conn.local_key.public);
pkt.nonce.copy_from_slice(&conn.local_key.nonce);
let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+ conn.state = ConnectionState::PubKeySent;
+ true
},
- _ => {}
+ _ => {
+ false
+ }
}
}
where T: std::ops::DerefMut<Target = U>,
U: std::io::Read {
match conn.state {
- ConnectionState::New => {},
+ ConnectionState::New => { return true; },
+ ConnectionState::PubKeySent => {},
_ => { return false; }
}
// TODO: read_exact won't work.
if let Ok(pkt) = read_packet(buf) {
println!("Packet type: {}", pkt.kind() as u16);
match pkt.kind() {
+ PacketType::PublicKeyNonce => {
+ let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ println!("Session key: {:?}", conn.local_key.session.as_ref().unwrap());
+ conn.state = ConnectionState::Encrypted;
+ },
_ => {},
}
}
true
}
-//fn crypto_send_data(conn: &ConnectionContext, buf: &[u8]) {
-//}
-//fn crypto_recv_data(conn: &ConnectionContext, buf: &mut [u8]) {
-//}
+fn crypto_send_data<T,U>(conn: &mut ConnectionContext, in_buf: &[u8], out_buf: T) -> Result<(), &'static str>
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Write {
+ match conn.state {
+ ConnectionState::Encrypted => {},
+ _ => { return Err("Encrypted channel not established."); }
+ }
+ let aad = [];
+ let mut ciphertext = Vec::with_capacity(in_buf.len());
+ let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, in_buf, &mut ciphertext).unwrap();
+ println!("encrypted: {:?} {:?}", ciphertext, tag);
-//pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
-fn event_loop(mut conn: ConnectionContext, is_server: bool) -> Result<(), std::io::Error> {
- println!("conn: {} {}", conn.stream.borrow().peer_addr().unwrap(), is_server);
+ let pkt: EncryptedPacket = EncryptedPacket {
+ tag_len: tag.len() as u16,
+ data_len: ciphertext.len() as u16,
+ };
+ let mut buf: Vec<u8>= vec![];
+ buf.extend(struct_as_slice(&pkt));
+ buf.extend(&ciphertext);
+ buf.extend(&tag);
+ let _ = write_packet(out_buf, &buf, 0, PacketType::EncryptedData);
+ Ok(())
+}
- crypto_send_handshake(&conn, conn.stream.borrow_mut());
- //let mut pkt: HandshakePacket = Default::default();
- //pkt.public_key.copy_from_slice(&conn.local_key.public);
- //pkt.nonce.copy_from_slice(&conn.local_key.nonce);
- //let _ = write_packet(&mut conn.stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+fn crypto_recv_data<T,U,R,V>(conn: &mut ConnectionContext, in_buf: T, mut out_buf: R) -> Result<(), &'static str>
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Read,
+ R: std::ops::DerefMut<Target = V>,
+ V: std::io::Write {
+ match conn.state {
+ ConnectionState::Encrypted => {},
+ _ => { return Err("Encrypted channel not established."); }
+ }
+ if let Ok(pkt) = read_packet(in_buf) {
+ println!("Packet type: {}", pkt.kind() as u16);
+ match pkt.kind() {
+ PacketType::EncryptedData => {
+ let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
+ let ciphertext = &rest[..data_pkt.data_len as usize];
+ let tag = &rest[data_pkt.data_len as usize..];
+ let aad = [];
+ let mut plaintext = Vec::with_capacity(ciphertext.len());
+ let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
+ let _ = out_buf.write(&plaintext);
+ },
+ _ => {
+ return Err("Received non-encrypted data on encrypted channel.");
+ },
+ }
+ }
+ Ok(())
+}
- let mut pkt_queue: VecDeque<NetworkPacket> = VecDeque::with_capacity(20);
- loop {
- pkt_queue.push_back(read_packet(conn.stream.borrow_mut())?);
- if let Some(pkt) = pkt_queue.pop_front() {
- println!("Packet type: {}", pkt.kind() as u16);
- match pkt.kind() {
- PacketType::PublicKeyNonce => {
- let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
- conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
- println!("Server key: {:?}", conn.local_key.session.as_ref().unwrap());
+fn event_loop<T>(mut conn: ConnectionContext, mut stream: T, is_server: bool) -> Result<(), std::io::Error>
+where T: std::io::Read + std::io::Write {
+ while crypto_send_handshake(&mut conn, &mut stream) == true {}
+ while crypto_recv_handshake(&mut conn, &mut stream) == true {}
- if is_server {
- let aad = [1, 2, 3, 4];
- let plaintext = b"hello, world";
- let mut ciphertext = Vec::with_capacity(plaintext.len());
- let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, plaintext, &mut ciphertext).unwrap();
- println!("encrypted: {:?} {:?}", ciphertext, tag);
+ if is_server {
+ let mut plaintext = "hello, world".as_bytes();
+ let _ = crypto_send_data(&mut conn, &mut plaintext, &mut stream);
+ }
- let pkt: EncryptedPacket = EncryptedPacket {
- tag_len: tag.len() as u16,
- data_len: ciphertext.len() as u16,
- };
- let mut buf: Vec<u8>= vec![];
- buf.extend(struct_as_slice(&pkt));
- buf.extend(&ciphertext);
- buf.extend(&tag);
- let _ = write_packet(conn.stream.borrow_mut(), &buf, 0, PacketType::EncryptedData);
- }
- },
- PacketType::AuthRequest => {},
- PacketType::EncryptedData => {
- let (data_pkt, rest) = interpret_packet_extra::<EncryptedPacket>(&pkt).unwrap();
- let ciphertext = &rest[..data_pkt.data_len as usize];
- let tag = &rest[data_pkt.data_len as usize..];
- let aad = [1, 2, 3, 4];
- let mut plaintext = Vec::with_capacity(ciphertext.len());
- println!("decrypting: {:?} {:?}", ciphertext, tag);
- let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
- println!("decrypted: {:?}", String::from_utf8(plaintext));
- },
- PacketType::Disconnect => {},
- }
- }
+ loop {
+ let mut plaintext = vec!();
+ let _ = crypto_recv_data(&mut conn, &mut stream, &mut plaintext);
+ println!("decrypted: {:?}", String::from_utf8(plaintext));
}
}
let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
for stream in listener.incoming() {
let stream: TcpStream = stream.unwrap();
- let conn = ConnectionContext::new(stream);
- let _ = event_loop(conn, true);
+ let conn = ConnectionContext::new();
+ let _ = event_loop(conn, stream, true);
}
Ok(())
}
pub fn client() -> Result<(), std::io::Error> {
let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
- let conn = ConnectionContext::new(stream);
- let _ = event_loop(conn, false);
+ let conn = ConnectionContext::new();
+ let _ = event_loop(conn, stream, false);
Ok(())
}