src/comm.rs
//
// Copyright 2019 Trevor Bentley
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
use crate::*;
use crate::handshake::ResetPacket;
use std::convert::TryInto;
/// Read a complete network packet from the input stream.
///
/// Performs a single `read()` from `stream` into the connection's internal
/// read buffer, appending after any bytes left over from earlier calls. It
/// then tries to parse one packet from the front of that buffer.
///
/// On success, returns a NetworkPacket struct containing the header and data,
/// and a `usize` equal to the packet's total size (header plus payload). Any
/// bytes that follow the packet stay buffered for the next call.
///
/// # Errors
///
/// * [`OssuaryError::WouldBlock`] carrying 0 if `stream` would block.
/// * [`OssuaryError::WouldBlock`] carrying the number of bytes read by this
///   call if a complete packet is not buffered yet.
/// * [`OssuaryError::InvalidPacket`] if the header announces a packet larger
///   than `PACKET_BUF_SIZE`.
/// * Any other I/O error from `stream`, converted into `OssuaryError`.
pub(crate) fn read_packet<T,U>(conn: &mut OssuaryConnection,
                               mut stream: T) ->Result<(NetworkPacket, usize), OssuaryError>
where T: std::ops::DerefMut<Target = U>,
      U: std::io::Read {
    let header_size = ::std::mem::size_of::<PacketHeader>();
    let bytes_read = match stream.read(&mut conn.read_buf[conn.read_buf_used..]) {
        Ok(b) => b,
        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
            return Err(OssuaryError::WouldBlock(0))
        },
        Err(e) => return Err(e.into()),
    };
    conn.read_buf_used += bytes_read;

    // The header fields below are decoded straight out of `read_buf`. Until
    // a full header has arrived, those bytes are leftovers from an earlier
    // packet. Decoding them could yield a bogus length and a spurious
    // "Oversized packet" error, so wait for more input instead.
    if conn.read_buf_used < header_size {
        return Err(OssuaryError::WouldBlock(bytes_read));
    }

    let buf: &[u8] = &conn.read_buf;
    let hdr = PacketHeader {
        len: u16::from_be_bytes(buf[0..2].try_into()?),
        msg_id: u16::from_be_bytes(buf[2..4].try_into()?),
        packet_type: PacketType::from_u16(u16::from_be_bytes(buf[4..6].try_into()?)),
        _reserved: u16::from_be_bytes(buf[6..8].try_into()?),
    };
    let packet_len = hdr.len as usize;
    let packet_end = header_size + packet_len;
    if conn.read_buf_used < packet_end {
        // A packet larger than the buffer can never complete.
        if packet_end > PACKET_BUF_SIZE {
            return Err(OssuaryError::InvalidPacket("Oversized packet".into()));
        }
        return Err(OssuaryError::WouldBlock(bytes_read));
    }
    let data: Box<[u8]> = conn.read_buf[header_size..packet_end]
        .to_vec().into_boxed_slice();

    // Shift any bytes that belong to the next packet to the front of the
    // buffer. `copy_within` is a safe, overlap-aware memmove.
    conn.read_buf.copy_within(packet_end..conn.read_buf_used, 0);
    conn.read_buf_used -= packet_end;
    Ok((NetworkPacket {
        header: hdr,
        data: data,
    },
        packet_end))
}
/// Write a packet from OssuaryConnection's internal storage to the out buffer.
///
/// All packets are buffered to internal storage before writing, so this is
/// the function responsible for putting all packets "on the wire".
///
/// On success, returns the number of bytes written to the output buffer
pub(crate) fn write_stored_packet<T,U>(conn: &mut OssuaryConnection,
stream: &mut T) -> Result<usize, OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
let mut written = 0;
while written < conn.write_buf_used {
match stream.write(&conn.write_buf[written..conn.write_buf_used]) {
Ok(w) => {
written += w;
},
Err(e) => {
if written > 0 && written < conn.write_buf_used {
unsafe {
// no safe way to memmove() in Rust?
std::ptr::copy::<u8>(
conn.write_buf.as_ptr().offset(written as isize),
conn.write_buf.as_mut_ptr(),
conn.write_buf_used - written);
}
}
conn.write_buf_used -= written;
return Err(e.into());
},
}
}
conn.write_buf_used = 0;
Ok(written)
}
/// Write a packet to the OssuaryConnection's internal packet buffer
///
/// All packets are buffered internally because there is no guarantee that a
/// complete packet can be written without blocking, and Ossuary is a non-
/// blocking library.
///
/// On success, returns the number of bytes written to the output buffer.
pub(crate) fn write_packet<T,U>(conn: &mut OssuaryConnection,
stream: &mut T, data: &[u8],
kind: PacketType) -> Result<usize, OssuaryError>
where T: std::ops::DerefMut<Target = U>,
U: std::io::Write {
let msg_id = conn.local_msg_id as u16;
conn.write_buf[0..2].copy_from_slice(&(data.len() as u16).to_be_bytes());
conn.write_buf[2..4].copy_from_slice(&msg_id.to_be_bytes());
conn.write_buf[4..6].copy_from_slice(&(kind as u16).to_be_bytes());
conn.write_buf[6..8].copy_from_slice(&(0u16).to_be_bytes());
conn.write_buf[8..8+data.len()].copy_from_slice(&data);
conn.write_buf_used = 8 + data.len();
conn.local_msg_id += 1;
let written = write_stored_packet(conn, stream)?;
Ok(written)
}
impl OssuaryConnection {
    /// Encrypts data into a packet suitable for sending over the network
    ///
    /// The caller provides unencrypted plaintext data, in any format, in the
    /// `in_buf` buffer. `send_data()` encrypts it and writes it in the proper
    /// packet format into `out_buf`.
    ///
    /// This is the core function for data transmission via ossuary. All data
    /// to be sent over an Ossuary connection should pass through this function.
    ///
    /// Note that Ossuary does not perform network operations itself. It is the
    /// caller's responsibility to put the written data on the wire. However,
    /// you may pass a 'buf' that does this automatically, such as a TcpStream.
    ///
    /// Returns the number of bytes written to `out_buf`, or an error.
    ///
    /// You must handle [`OssuaryError::WouldBlock`]. It is a recoverable error
    /// indicating that some bytes were written to the buffer. It is returned
    /// in two situations:
    ///
    /// * Data buffered by an earlier call was still pending. This call wrote
    ///   that data out and returned *without* encrypting `in_buf`, which must
    ///   be passed again.
    /// * `in_buf` was encrypted and buffered, but the packet could not be
    ///   written out completely. `in_buf` has been consumed. The rest of the
    ///   packet is sent by a later `send_data()` call or by
    ///   [`OssuaryConnection::flush()`].
    ///
    /// NOTE(review): both cases use the same error variant. Calling `flush()`
    /// until it returns 0 before sending new data avoids the first case.
    ///
    /// # Errors
    ///
    /// * [`OssuaryError::InvalidPacket`] if the encrypted channel is not
    ///   established.
    /// * [`OssuaryError::InvalidPacket`] if the encrypted packet would not fit
    ///   in the 16-bit length fields or the internal write buffer. The local
    ///   nonce is not consumed in that case.
    /// * [`OssuaryError::InvalidKey`] if there is no session key or encryption
    ///   fails. `reset_state(None)` is called first.
    pub fn send_data<T,U>(&mut self,
                          in_buf: &[u8],
                          mut out_buf: T) -> Result<usize, OssuaryError>
    where T: std::ops::DerefMut<Target = U>,
          U: std::io::Write {
        // Try to send any unsent buffered data
        match write_stored_packet(self, &mut out_buf) {
            Ok(0) => {},
            Ok(w) => return Err(OssuaryError::WouldBlock(w)),
            Err(e) => return Err(e),
        }
        match self.state {
            ConnectionState::Encrypted => {},
            _ => {
                return Err(OssuaryError::InvalidPacket(
                    "Encrypted channel not established.".into()));
            }
        }
        let aad = [];
        let mut ciphertext = Vec::with_capacity(in_buf.len());
        let session_key = match self.local_key.session {
            Some(ref k) => k,
            None => {
                self.reset_state(None);
                return Err(OssuaryError::InvalidKey);
            }
        };
        let tag = match encrypt(session_key.as_bytes(),
                                &self.local_key.nonce, &aad, in_buf, &mut ciphertext) {
            Ok(t) => t,
            Err(_) => {
                self.reset_state(None);
                return Err(OssuaryError::InvalidKey);
            }
        };

        // Validate every size before the nonce is consumed. If the packet
        // cannot be sent, the local nonce must stay in step with the remote's.
        let data_len: u16 = match ciphertext.len().try_into() {
            Ok(len) => len,
            Err(_) => return Err(OssuaryError::InvalidPacket("Oversized packet".into())),
        };
        let tag_len: u16 = match tag.len().try_into() {
            Ok(len) => len,
            Err(_) => return Err(OssuaryError::InvalidPacket("Oversized packet".into())),
        };
        let pkt: EncryptedPacket = EncryptedPacket {
            tag_len: tag_len,
            data_len: data_len,
        };
        let mut buf: Vec<u8> = Vec::new();
        buf.extend(struct_as_slice(&pkt));
        buf.extend(&ciphertext);
        buf.extend(&tag);
        // write_packet() prepends an 8-byte header and stores the frame in
        // the remaining space of write_buf.
        let frame_len = 8 + buf.len();
        let capacity = self.write_buf.len().saturating_sub(self.write_buf_used);
        if buf.len() > u16::MAX as usize || frame_len > capacity {
            return Err(OssuaryError::InvalidPacket("Oversized packet".into()));
        }

        increment_nonce(&mut self.local_key.nonce);
        write_packet(self, &mut out_buf, &buf, PacketType::EncryptedData)
    }

    /// Decrypts data from a packet received from a remote host
    ///
    /// The caller provides encrypted data from a remote host in the `in_buf`
    /// buffer. `recv_data()` decrypts it and writes the plaintext result into
    /// `out_buf`.
    ///
    /// This is the core function for data transmission via ossuary. All data
    /// received over an Ossuary connection should pass through this function.
    ///
    /// On success, returns `(bytes_read, bytes_written)`:
    ///
    /// * `bytes_read` is the size of the packet consumed, as reported by
    ///   `read_packet()`.
    /// * `bytes_written` is the number of plaintext bytes accepted by a
    ///   single `write()` on `out_buf`.
    ///
    /// Returns `Ok((0, 0))` immediately if the connection is in the `Failed`
    /// state.
    ///
    /// You must handle [`OssuaryError::WouldBlock`], which is a recoverable
    /// error, but indicates that some bytes were read from `in_buf`. This
    /// indicates that an incomplete packet was received.
    ///
    /// # Errors
    ///
    /// * [`OssuaryError::InvalidPacket`] if the channel is not established,
    ///   the packet fails to parse, or the message ID is out of sequence. The
    ///   last two call `reset_state(None)`.
    /// * [`OssuaryError::InvalidPacket`] if a non-data packet type arrives, or
    ///   an encrypted packet's data length exceeds its payload.
    /// * [`OssuaryError::ConnectionReset`] when a Reset packet arrives.
    /// * [`OssuaryError::ConnectionFailed`] or
    ///   [`OssuaryError::ConnectionClosed`] when a Disconnect packet arrives,
    ///   depending on its error flag.
    /// * [`OssuaryError::InvalidKey`] if keys are missing or the encrypted
    ///   header cannot be interpreted.
    /// * Errors from `decrypt()` or from writing to `out_buf` are propagated.
    pub fn recv_data<T,U,R,V>(&mut self,
                              in_buf: T,
                              mut out_buf: R) -> Result<(usize, usize), OssuaryError>
    where T: std::ops::DerefMut<Target = U>,
          U: std::io::Read,
          R: std::ops::DerefMut<Target = V>,
          V: std::io::Write {
        match self.state {
            ConnectionState::Failed(_) => { return Ok((0,0)); }
            ConnectionState::Encrypted => {},
            _ => {
                return Err(OssuaryError::InvalidPacket(
                    "Encrypted channel not established.".into()));
            }
        }
        let (pkt, bytes_read) = match read_packet(self, in_buf) {
            Ok(parsed) => parsed,
            Err(OssuaryError::WouldBlock(b)) => {
                return Err(OssuaryError::WouldBlock(b));
            },
            Err(_) => {
                self.reset_state(None);
                return Err(OssuaryError::InvalidPacket("Packet header did not parse.".into()));
            },
        };

        // Copy the ID out of the header once and use the local everywhere.
        let msg_id = pkt.header.msg_id;
        if msg_id != self.remote_msg_id {
            match pkt.kind() {
                PacketType::Disconnect |
                PacketType::Reset => {},
                _ => {
                    println!("Message gap detected. Restarting connection. ({} != {})",
                             msg_id, self.remote_msg_id);
                    self.reset_state(None);
                    return Err(OssuaryError::InvalidPacket("Message ID mismatch".into()))
                },
            }
        }
        // Message IDs are 16 bits on the wire, and the sender wraps its
        // counter with `as u16`. Wrap here too instead of overflowing
        // (a panic in debug builds) after 65535 messages.
        self.remote_msg_id = msg_id.wrapping_add(1);

        let bytes_written = match pkt.kind() {
            PacketType::Reset => {
                self.reset_state(None);
                self.state = ConnectionState::Resetting(false);
                return Err(OssuaryError::ConnectionReset(bytes_read));
            },
            PacketType::Disconnect => {
                let rs_pkt = interpret_packet::<ResetPacket>(&pkt)?;
                if rs_pkt.error {
                    self.reset_state(Some(OssuaryError::ConnectionFailed));
                    return Err(OssuaryError::ConnectionFailed);
                }
                self.reset_state(Some(OssuaryError::ConnectionClosed));
                return Err(OssuaryError::ConnectionClosed);
            },
            PacketType::EncryptedData => {
                let (data_pkt, rest) = match interpret_packet_extra::<EncryptedPacket>(&pkt) {
                    Ok(parts) => parts,
                    Err(_) => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidKey);
                    },
                };
                // `data_len` comes from the remote peer. Check it before
                // slicing so a malformed packet cannot panic the receiver.
                let data_len = data_pkt.data_len as usize;
                if data_len > rest.len() {
                    self.reset_state(None);
                    return Err(OssuaryError::InvalidPacket(
                        "Encrypted data length exceeds packet size.".into()));
                }
                let (ciphertext, tag) = rest.split_at(data_len);
                let aad = [];
                let mut plaintext = Vec::with_capacity(ciphertext.len());
                let session_key = match self.local_key.session {
                    Some(ref k) => k,
                    None => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidKey);
                    }
                };
                let remote_nonce = match self.remote_key {
                    Some(ref rem) => rem.nonce,
                    None => {
                        self.reset_state(None);
                        return Err(OssuaryError::InvalidKey);
                    }
                };
                decrypt(session_key.as_bytes(),
                        &remote_nonce,
                        &aad, &ciphertext, &tag, &mut plaintext)?;
                // The packet is decrypted and already removed from the read
                // buffer. Advance the nonce now, so that a failing write to
                // `out_buf` cannot leave it out of step with the sender.
                if let Some(key) = self.remote_key.as_mut() {
                    increment_nonce(&mut key.nonce);
                }
                // NOTE(review): a single write(). If `out_buf` accepts fewer
                // bytes than `plaintext.len()`, the remainder is not retained.
                out_buf.write(&plaintext)?
            },
            _ => {
                return Err(OssuaryError::InvalidPacket(
                    "Received non-encrypted data on encrypted channel.".into()));
            },
        };
        Ok((bytes_read, bytes_written))
    }

    /// Write any cached encrypted data waiting to be sent
    ///
    /// If a previous call to [`OssuaryConnection::send_data`] was unable to
    /// write out all of its data, the remaining data is cached internally. It
    /// can be explicitly flushed by calling this function until it returns 0.
    ///
    /// After each call, it is the caller's responsibility to put the written
    /// data onto the network, unless `out_buf` is an object that handles that
    /// implicitly, such as a TcpStream.
    ///
    /// Returns the number of bytes written to `out_buf`. 0 means nothing was
    /// pending. Errors are those of the internal buffered write, which keeps
    /// any unsent data cached.
    pub fn flush<R,V>(&mut self,
                      mut out_buf: R) -> Result<usize, OssuaryError>
    where R: std::ops::DerefMut<Target = V>,
          V: std::io::Write {
        write_stored_packet(self, &mut out_buf)
    }
}