use chacha20_poly1305_aead::{encrypt,decrypt};
use x25519_dalek::generate_secret;
use x25519_dalek::generate_public;
+use x25519_dalek::diffie_hellman;
+
use rand::thread_rng;
+use rand::RngCore;
-use std::collections::VecDeque;
use std::thread;
-//use std::io::prelude::*;
+use std::rc::Rc;
+use std::cell::RefCell;
+use std::collections::VecDeque;
use std::net::{TcpListener, TcpStream};
//
// * data (encrypted)
//
-use x25519_dalek::diffie_hellman;
-
fn struct_as_slice<T: Sized>(p: &T) -> &[u8] {
unsafe {
::std::slice::from_raw_parts(
)
}
}
-//unsafe fn struct_as_mut_slice<T: Sized>(p: &mut T) -> &[u8] {
-// ::std::slice::from_raw_parts_mut(
-// (p as *mut T) as *mut u8,
-// ::std::mem::size_of::<T>(),
-// )
-//}
-//unsafe fn slice_as_struct<T: Sized>(p: &[u8]) -> Result<&T, &'static str> {
fn slice_as_struct<T>(p: &[u8]) -> Result<&T, &'static str> {
unsafe {
if p.len() < ::std::mem::size_of::<T>() {
Ok(&*(&p[..::std::mem::size_of::<T>()] as *const [u8] as *const T))
}
}
-//unsafe fn slice_as_mut_struct<T: Sized>(p: &mut [u8]) -> Result<&mut T, &'static str> {
-// if p.len() < ::std::mem::size_of::<T>() {
-// return Err("Cannot cast bytes to struct: size mismatch");
-// }
-// Ok(&mut *(&mut p[..::std::mem::size_of::<T>()] as *mut [u8] as *mut T))
-//}
-//unsafe fn slice_as_owned_struct<T: Sized>(mut p: Box<[u8]>) -> Result<T, &'static str> {
-// if p.len() < ::std::mem::size_of::<T>() {
-// return Err("Cannot cast bytes to struct: size mismatch");
-// }
-// println!("box size: {} / T size: {}", ::std::mem::size_of_val(&*p), ::std::mem::size_of::<T>());
-// println!("p: {}", p[0]);
-// let p = Box::into_raw(p);
-// Ok(*Box::from_raw(p as *mut T))
-//}
-//unsafe fn vec_as_struct<T: Sized>(p: &Vec<u8>) -> Result<&T, &'static str> {
-// if p.len() < ::std::mem::size_of::<T>() {
-// return Err("Cannot cast bytes to struct: buffer too small");
-// }
-// let s = &p.as_slice()[..::std::mem::size_of::<T>()];
-// Ok(&*(s as *const [u8] as *const T))
-//}
-
#[repr(packed)]
#[allow(dead_code)]
struct HandshakePacket {
_reserved: u16,
}
-//#[repr(packed)]
-//struct Packet<'a> {
-// header: PacketHeader,
-// data: &'a [u8],
-//}
-//impl <'a> Packet<'a> {
-// fn new(msg_id: u16, kind: PacketType, data: &'a [u8]) -> Packet {
-// Packet {
-// header: PacketHeader {
-// len: data.len() as u16,
-// msg_id: msg_id,
-// packet_type: kind,
-// _reserved: 0u16,
-// },
-// data: data
-// }
-// }
-//}
-
-
-//enum ServerConnectionState {
-// New,
-// Encrypted,
-// Authenticated,
-//}
-//
-//struct ServerConnectionContext {
-// state: ServerConnectionState,
-// remote_public: [u8; 32],
-// nonce: [u8; 12],
-//}
-
-//pub fn read_struct<'a,T>(stream: &mut TcpStream, mut buf: &'a mut [u8]) -> Result<&'a T, &'static str> {
-// let sz = ::std::mem::size_of::<T>();
-// let buf_len = stream.read_exact(&mut buf[..sz]).unwrap();
-// let pkt: &T = unsafe { slice_as_struct(buf)? };
-// Ok(pkt)
-//}
-//
-//pub fn read_owned_struct<'a,T>(stream: &mut TcpStream) -> Result<T, &'static str> {
-// let sz = ::std::mem::size_of::<T>();
-// let mut buf: Box<[u8]> = vec![0; sz].into_boxed_slice();
-// let buf_len = stream.read_exact(&mut buf[..sz]).unwrap();
-// let pkt: T = unsafe { slice_as_owned_struct(buf)? };
-// Ok(pkt)
-//}
+enum ConnectionState {
+ New,
+ Encrypted,
+ _Authenticated,
+}
+struct KeyMaterial {
+ secret: Option<[u8; 32]>,
+ public: [u8; 32],
+ session: Option<[u8; 32]>,
+ nonce: [u8; 12],
+}
+struct ConnectionContext {
+ stream: Rc<RefCell<TcpStream>>,
+ state: ConnectionState,
+ local_key: KeyMaterial,
+ remote_key: Option<KeyMaterial>,
+}
+impl ConnectionContext {
+ fn new(stream: TcpStream) -> ConnectionContext {
+ let mut rng = thread_rng();
+ let sec_key = generate_secret(&mut rng);
+ let pub_key = generate_public(&sec_key);
+ let mut nonce: [u8; 12] = [0; 12];
+ rng.fill_bytes(&mut nonce);
+ let key = KeyMaterial {
+ secret: Some(sec_key),
+ public: pub_key.to_bytes(),
+ nonce: nonce,
+ session: None,
+ };
+ ConnectionContext {
+ stream: Rc::new(RefCell::new(stream)),
+ state: ConnectionState::New,
+ local_key: key,
+ remote_key: None,
+ }
+ }
+ fn add_remote_key(&mut self, public: &[u8; 32], nonce: &[u8; 12]) {
+ let key = KeyMaterial {
+ secret: None,
+ public: public.to_owned(),
+ nonce: nonce.to_owned(),
+ session: None,
+ };
+ self.remote_key = Some(key);
+ self.local_key.session = Some(diffie_hellman(self.local_key.secret.as_ref().unwrap(), public));
+ self.state = ConnectionState::Encrypted;
+ }
+}
struct NetworkPacket {
header: PacketHeader,
Ok((s, &pkt.data[::std::mem::size_of::<T>()..]))
}
-fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
-where T: std::io::Read {
+//fn read_packet<T>(stream: &mut T) -> Result<NetworkPacket, std::io::Error>
+//where T: std::io::Read {
+fn read_packet<T,U>(mut stream: T) -> Result<NetworkPacket, std::io::Error>
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Read {
let mut buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
println!("read header");
let _ = stream.read_exact(&mut buf)?;
})
}
-fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
-where T: std::io::Write {
+//fn write_packet<T>(stream: &mut T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+//where T: std::io::Write {
+fn write_packet<T,U>(mut stream: T, data: &[u8], msg_id: u16, kind: PacketType) -> Result<(), std::io::Error>
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Write {
let buf: Box<[u8]> = Box::new([0u8; ::std::mem::size_of::<PacketHeader>()]);
unsafe {
std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u16, buf.len())[0] = u16::to_be(data.len() as u16);
Ok(())
}
-pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
- let mut seed = thread_rng();
- let sec_key = generate_secret(&mut seed);
- let pub_key = generate_public(&sec_key);
- println!("conn: {} {}", stream.peer_addr().unwrap(), is_server);
+//fn crypto_send_handshake<T>(conn: &ConnectionContext, buf: &mut T)
+//where T: std::io::Write {
+fn crypto_send_handshake<T,U>(conn: &ConnectionContext, buf: T)
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Write {
+ match conn.state {
+ ConnectionState::New => {
+ let mut pkt: HandshakePacket = Default::default();
+ pkt.public_key.copy_from_slice(&conn.local_key.public);
+ pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+ let _ = write_packet(buf, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+ },
+ _ => {}
+ }
+}
+
+fn crypto_recv_handshake<T,U>(conn: &mut ConnectionContext, buf: T) -> bool
+where T: std::ops::DerefMut<Target = U>,
+ U: std::io::Read {
+ match conn.state {
+ ConnectionState::New => {},
+ _ => { return false; }
+ }
+ // TODO: read_exact won't work.
+ if let Ok(pkt) = read_packet(buf) {
+ println!("Packet type: {}", pkt.kind() as u16);
+ match pkt.kind() {
+ _ => {},
+ }
+ }
+ true
+}
- //let _ = stream.write(pub_key.as_bytes());
+//fn crypto_send_data(conn: &ConnectionContext, buf: &[u8]) {
+//}
+//fn crypto_recv_data(conn: &ConnectionContext, buf: &mut [u8]) {
+//}
- let mut remote_public: [u8; 32] = [0; 32];
- let mut nonce: [u8; 12] = [0; 12];
- let mut sess_key: [u8; 32] = [0; 32];
+//pub fn event_loop(stream: &mut TcpStream, is_server: bool) -> Result<(), std::io::Error> {
+fn event_loop(mut conn: ConnectionContext, is_server: bool) -> Result<(), std::io::Error> {
+ println!("conn: {} {}", conn.stream.borrow().peer_addr().unwrap(), is_server);
- let mut pkt: HandshakePacket = Default::default();
- pkt.public_key.copy_from_slice(pub_key.as_bytes());
- pkt.nonce.copy_from_slice(&[1,0,0,0,0,1,0,0,0,0,1,0]);
- let _ = write_packet(stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
+ crypto_send_handshake(&conn, conn.stream.borrow_mut());
+ //let mut pkt: HandshakePacket = Default::default();
+ //pkt.public_key.copy_from_slice(&conn.local_key.public);
+ //pkt.nonce.copy_from_slice(&conn.local_key.nonce);
+ //let _ = write_packet(&mut conn.stream, struct_as_slice(&pkt), 0, PacketType::PublicKeyNonce);
let mut pkt_queue: VecDeque<NetworkPacket> = VecDeque::with_capacity(20);
loop {
- pkt_queue.push_back(read_packet(stream)?);
+ pkt_queue.push_back(read_packet(conn.stream.borrow_mut())?);
if let Some(pkt) = pkt_queue.pop_front() {
println!("Packet type: {}", pkt.kind() as u16);
match pkt.kind() {
PacketType::PublicKeyNonce => {
- let data_pkt: &HandshakePacket = interpret_packet(&pkt).unwrap();
- remote_public.copy_from_slice(&data_pkt.public_key);
- nonce.copy_from_slice(&data_pkt.nonce);
- sess_key = diffie_hellman(&sec_key, &remote_public);
- println!("Server key: {:?}", sess_key);
+ let data_pkt: &HandshakePacket = interpret_packet(&pkt).as_ref().unwrap();
+ conn.add_remote_key(&data_pkt.public_key, &data_pkt.nonce);
+ println!("Server key: {:?}", conn.local_key.session.as_ref().unwrap());
if is_server {
let aad = [1, 2, 3, 4];
let plaintext = b"hello, world";
let mut ciphertext = Vec::with_capacity(plaintext.len());
- let tag = encrypt(&sess_key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
+ let tag = encrypt(conn.local_key.session.as_ref().unwrap(), &conn.local_key.nonce, &aad, plaintext, &mut ciphertext).unwrap();
println!("encrypted: {:?} {:?}", ciphertext, tag);
let pkt: EncryptedPacket = EncryptedPacket {
buf.extend(struct_as_slice(&pkt));
buf.extend(&ciphertext);
buf.extend(&tag);
- let _ = write_packet(stream, &buf, 0, PacketType::EncryptedData);
+ let _ = write_packet(conn.stream.borrow_mut(), &buf, 0, PacketType::EncryptedData);
}
},
PacketType::AuthRequest => {},
let aad = [1, 2, 3, 4];
let mut plaintext = Vec::with_capacity(ciphertext.len());
println!("decrypting: {:?} {:?}", ciphertext, tag);
- let _ = decrypt(&sess_key, &nonce, &aad, &ciphertext, &tag, &mut plaintext);
+ let _ = decrypt(conn.local_key.session.as_ref().unwrap(), &conn.remote_key.as_ref().unwrap().nonce, &aad, &ciphertext, &tag, &mut plaintext);
println!("decrypted: {:?}", String::from_utf8(plaintext));
},
PacketType::Disconnect => {},
}
-pub fn server2() -> Result<(), std::io::Error> {
+pub fn server() -> Result<(), std::io::Error> {
let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
for stream in listener.incoming() {
- let mut stream: TcpStream = stream.unwrap();
- let _ = event_loop(&mut stream, true);
+ let stream: TcpStream = stream.unwrap();
+ let conn = ConnectionContext::new(stream);
+ let _ = event_loop(conn, true);
}
Ok(())
}
-pub fn client2() -> Result<(), std::io::Error> {
- let mut stream = TcpStream::connect("127.0.0.1:9988").unwrap();
- let _ = event_loop(&mut stream, false);
+pub fn client() -> Result<(), std::io::Error> {
+ let stream = TcpStream::connect("127.0.0.1:9988").unwrap();
+ let conn = ConnectionContext::new(stream);
+ let _ = event_loop(conn, false);
Ok(())
}
-//pub fn server() {
-// let listener = TcpListener::bind("127.0.0.1:9988").unwrap();
-// for stream in listener.incoming() {
-// let mut stream: TcpStream = stream.unwrap();
-// stream.set_nonblocking(false);
-//
-// let mut seed = thread_rng();
-// let sec_key = generate_secret(&mut seed);
-// let pub_key = generate_public(&sec_key);
-// println!("Server conn: {}", stream.peer_addr().unwrap());
-// let _ = stream.write(pub_key.as_bytes());
-//
-// let mut buf: Box<[u8]> = Box::new([0u8; 65535]);
-// {
-// let hdr: &PacketHeader = read_struct(&mut stream, &mut buf).unwrap();
-// println!("hdr len: {}", hdr.len);
-// }
-// let pkt: &HandshakePacket = read_struct(&mut stream, &mut buf).unwrap();
-// println!("pkt len: {}", pkt.len);
-// println!("Server write");
-//
-// let mut remote_public: [u8; 32] = [0; 32];
-// let mut nonce: [u8; 12] = [0; 12];
-// remote_public.copy_from_slice(&pkt.public_key);
-// nonce.copy_from_slice(&pkt.nonce);
-//
-// let sess_key = diffie_hellman(&sec_key, &remote_public);
-// println!("Server key: {:?}", sess_key);
-//
-// let aad = [1, 2, 3, 4];
-// let plaintext = b"hello, world";
-// let mut ciphertext = Vec::with_capacity(plaintext.len());
-// let tag = encrypt(&sess_key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
-// println!("encrypted: {:?} {:?}", ciphertext, tag);
-//
-// unsafe {
-// let pkt: EncryptedPacket = EncryptedPacket {
-// tag_len: tag.len() as u16,
-// data_len: ciphertext.len() as u16,
-// };
-// let pkt = Packet::new(0, PacketType::EncryptedData, struct_as_slice(&pkt));
-// //let _ = stream.write(struct_as_slice(&pkt));
-// let _ = stream.write(struct_as_slice(&pkt.header));
-// let _ = stream.write(pkt.data);
-// let _ = stream.write(&ciphertext);
-// let _ = stream.write(&tag);
-// }
-// }
-//}
-//
-//pub fn client() {
-// let mut stream = TcpStream::connect("127.0.0.1:9988").unwrap();
-// stream.set_nonblocking(false);
-// let mut seed = thread_rng();
-// let sec_key = generate_secret(&mut seed);
-// let pub_key = generate_public(&sec_key);
-//
-// let mut pkt: HandshakePacket = Default::default();
-// pkt.public_key.copy_from_slice(pub_key.as_bytes());
-// pkt.nonce.copy_from_slice(&[1,0,0,0,0,1,0,0,0,0,1,0]);
-//
-// unsafe {
-// let pkt = Packet::new(0, PacketType::PublicKeyNonce, struct_as_slice(&pkt));
-// //let _ = stream.write(struct_as_slice(&pkt));
-// let _ = stream.write(struct_as_slice(&pkt.header));
-// let _ = stream.write(pkt.data);
-// }
-// //let _ = stream.write(pub_key.as_bytes());
-// println!("Client write");
-// let mut buf = vec![0u8; 4096];
-// let buf_len = stream.read(&mut buf).unwrap();
-// println!("Client read: {}", buf_len);
-// let mut remote_public: [u8; 32] = [0; 32];
-// remote_public.copy_from_slice(&buf[..32]);
-// let sess_key = diffie_hellman(&sec_key, &remote_public);
-// println!("Client key: {:?}", sess_key);
-//
-// let hdr_len = {
-// let hdr: &PacketHeader = read_struct(&mut stream, &mut buf).unwrap();
-// println!("hdr len: {}", hdr.len);
-// hdr.len
-// } as usize;
-// let (tag_len, data_len) = {
-// let pkt: &EncryptedPacket = read_struct(&mut stream, &mut buf).unwrap();
-// (pkt.tag_len as usize, pkt.data_len as usize)
-// };
-// let mut tag: Vec<u8> = vec![0; tag_len];
-// let mut ciphertext: Vec<u8> = vec![0; data_len];
-// let _ = stream.read_exact(&mut ciphertext).unwrap();
-// let _ = stream.read_exact(&mut tag).unwrap();
-// let aad = [1, 2, 3, 4];
-// let mut plaintext = Vec::with_capacity(ciphertext.len());
-// println!("decrypting: {:?} {:?}", ciphertext, tag);
-// decrypt(&sess_key, &pkt.nonce, &aad, &ciphertext, &tag, &mut plaintext);
-// println!("decrypted: {:?}", String::from_utf8(plaintext));
-//
-// println!("Client done");
-//}
-
pub fn test() {
- thread::spawn(move || { let _ = server2(); });
- let child = thread::spawn(move || { let _ = client2(); });
+ thread::spawn(move || { let _ = server(); });
+ let child = thread::spawn(move || { let _ = client(); });
let _ = child.join();
- //let mut alice_csprng = thread_rng();
- //let alice_secret = generate_secret(&mut alice_csprng);
- //let alice_public = generate_public(&alice_secret);
- //let mut bob_csprng = thread_rng();
- //let bob_secret = generate_secret(&mut bob_csprng);
- //let bob_public = generate_public(&bob_secret);
- //use x25519_dalek::diffie_hellman;
- //let shared_secret_a = diffie_hellman(&alice_secret, &bob_public.as_bytes());
- //let shared_secret_b = diffie_hellman(&bob_secret, &alice_public.as_bytes());
- //let key = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- // 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31];
- //let nonce = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
- //let aad = [1, 2, 3, 4];
- //let plaintext = b"hello, world";
- //let mut ciphertext = Vec::with_capacity(plaintext.len());
- //let tag = encrypt(&key, &nonce, &aad, plaintext, &mut ciphertext).unwrap();
- //println!("encrypted: {:?}", ciphertext);
- //let mut plaintext = Vec::with_capacity(ciphertext.len());
- //decrypt(&key, &nonce, &aad, &ciphertext, &tag, &mut plaintext);
- //println!("decrypted: {:?}", String::from_utf8(plaintext));
}
#[cfg(test)]