summary history branches tags files
commit:686446f173e41625f27e7c77f57bcb2ff71c01d2
author:Trevor Bentley
committer:Trevor Bentley
date:Thu May 30 17:47:14 2019 +0200
parents:bbef62042e72079a05b0cb5b193d8a8fc2513b96
implement disconnect()
diff --git a/examples/example.rs b/examples/example.rs
line changes: +21/-10
index 6e8a12b..a0e4036
--- a/examples/example.rs
+++ b/examples/example.rs
@@ -12,19 +12,22 @@ use std::thread;
 use std::net::{TcpListener, TcpStream};
 
 fn event_loop(mut conn: OssuaryConnection,
-              mut stream: TcpStream,
-              is_server: bool) -> Result<(), std::io::Error> {
-    let mut strings = vec!("message1", "message2", "message3");
+              mut stream: TcpStream) -> Result<(), std::io::Error> {
+    let mut strings = vec!("message3", "message2", "message1");
     let mut plaintext = Vec::<u8>::new();
     let start = std::time::Instant::now();
+    let name = match conn.is_server() {
+        true => "server",
+        false => "client",
+    };
 
     // Simply run for 2 seconds
-    while start.elapsed().as_secs() < 2 {
+    while start.elapsed().as_secs() < 5 {
         match conn.handshake_done() {
             // Handshaking
             Ok(false) => {
-                conn.send_handshake(&mut stream).expect("handshake failed");
-                conn.recv_handshake(&mut stream).expect("handshake failed");
+                let _ = conn.send_handshake(&mut stream).unwrap(); // you should check errors
+                let _ = conn.recv_handshake(&mut stream);
             },
             // Transmitting on encrypted connection
             Ok(true) => {
@@ -32,18 +35,26 @@ fn event_loop(mut conn: OssuaryConnection,
                     let _ = conn.send_data(plaintext.as_bytes(), &mut stream);
                 }
                 if let Ok(_) =  conn.recv_data(&mut stream, &mut plaintext) {
-                    println!("({}) received: {:?}", is_server,
+                    println!("({}) received: {:?}", name,
                              String::from_utf8(plaintext.clone()).unwrap());
                     plaintext.clear();
                 }
+                // Client issues a disconnect when finished
+                if strings.is_empty() && !conn.is_server() {
+                    conn.disconnect(false);
+                }
             },
             // Trust-On-First-Use
             Err(OssuaryError::UntrustedServer(pubkey)) => {
                 let keys: Vec<&[u8]> = vec![&pubkey];
                 let _ = conn.add_authorized_keys(keys).unwrap();
             }
+            Err(OssuaryError::ConnectionClosed) => {
+                println!("({}) Finished succesfully", name);
+                break;
+            },
             // Uh-oh.
-            Err(e) => panic!("Handshake failed with error: {:?}", e),
+            Err(e) => panic!("({}) Handshake failed with error: {:?}", name, e),
         }
     }
     Ok(())
@@ -55,7 +66,7 @@ fn server() -> Result<(), std::io::Error> {
     let _ = stream.set_read_timeout(Some(std::time::Duration::from_millis(100u64)));
     // This server lets any client connect
     let conn = OssuaryConnection::new(ConnectionType::UnauthenticatedServer, None).unwrap();
-    let _ = event_loop(conn, stream, true);
+    let _ = event_loop(conn, stream);
     Ok(())
 }
 
@@ -64,7 +75,7 @@ fn client() -> Result<(), std::io::Error> {
     let _ = stream.set_read_timeout(Some(std::time::Duration::from_millis(100u64)));
     // This client doesn't know any servers, but will use Trust-On-First-Use
     let conn = OssuaryConnection::new(ConnectionType::Client, None).unwrap();
-    let _ = event_loop(conn, stream, false);
+    let _ = event_loop(conn, stream);
     Ok(())
 }
 

diff --git a/src/comm.rs b/src/comm.rs
line changes: +14/-2
index b0946d4..5c17181
--- a/src/comm.rs
+++ b/src/comm.rs
@@ -14,6 +14,7 @@
 // limitations under the License.
 //
 use crate::*;
+use crate::handshake::ResetPacket;
 
 use std::convert::TryInto;
 
@@ -187,6 +188,7 @@ impl OssuaryConnection {
         let bytes_written: usize;
         let mut bytes_read: usize = 0;
         match self.state {
+            ConnectionState::Failed(_) => { return Ok((0,0)); }
             ConnectionState::Encrypted => {},
             _ => {
                 return Err(OssuaryError::InvalidPacket(
@@ -199,6 +201,7 @@ impl OssuaryConnection {
                 bytes_read += bytes;
                 if pkt.header.msg_id != self.remote_msg_id {
                     match pkt.kind() {
+                        PacketType::Disconnect |
                         PacketType::Reset => {},
                         _ => {
                             let msg_id = pkt.header.msg_id;
@@ -218,8 +221,17 @@ impl OssuaryConnection {
                         return Err(OssuaryError::ConnectionReset(bytes_read));
                     },
                     PacketType::Disconnect => {
-                        self.reset_state(Some(OssuaryError::ConnectionFailed));
-                        return Err(OssuaryError::ConnectionFailed);
+                        let rs_pkt = interpret_packet::<ResetPacket>(&pkt)?;
+                        match rs_pkt.error {
+                            true => {
+                                self.reset_state(Some(OssuaryError::ConnectionFailed));
+                                return Err(OssuaryError::ConnectionFailed);
+                            },
+                            false => {
+                                self.reset_state(Some(OssuaryError::ConnectionClosed));
+                                return Err(OssuaryError::ConnectionClosed);
+                            },
+                        }
                     },
                     PacketType::EncryptedData => {
                         match interpret_packet_extra::<EncryptedPacket>(&pkt) {

diff --git a/src/connection.rs b/src/connection.rs
line changes: +25/-0
index fdd1446..14aef13
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -94,6 +94,31 @@ impl OssuaryConnection {
         })
     }
 
+    /// Terminate a connection, or an on-going connection attempt.
+    ///
+    /// Calling this immediately closes the local end of Ossuary's connection,
+    /// and queues a disconnect packet to be sent to the remote host to inform
+    /// it to close its end.
+    ///
+    /// After calling disconnect(), the application should continue calling
+    /// Ossuary's functions (or at least its handshake functions) in a loop
+    /// until [`OssuaryConnection::handshake_done`] returns the matching error.
+    /// This allows Ossuary to generate the final disconnect packet.
+    ///
+    /// The handshake will return [`OssuaryError::ConnectionFailed`] if 'error'
+    /// is true, or [`OssuaryError::ConnectionClosed`] otherwise.
+    ///
+    /// 'error' - Indicates the reason for termination.  True means the channel
+    ///           is being closed because of some error, False means it is being
+    ///           closed due to completion or a clean shutdown.
+    ///
+    pub fn disconnect(&mut self, error: bool) {
+        match error {
+            true => self.reset_state(Some(OssuaryError::ConnectionFailed)),
+            false => self.reset_state(Some(OssuaryError::ConnectionClosed)),
+        }
+    }
+
     /// Get the initial state machine state of this connection
     pub(crate) fn initial_state(&self) -> ConnectionState {
         match self.conn_type {

diff --git a/src/error.rs b/src/error.rs
line changes: +10/-0
index e36618a..febf62a
--- a/src/error.rs
+++ b/src/error.rs
@@ -145,6 +145,15 @@ pub enum OssuaryError {
     /// When one side fails, it attempts to trigger a failure on the other side
     /// as well.
     ConnectionFailed,
+
+    /// The connection has been closed by request.
+    ///
+    /// This indicates that the connection has been permanently closed by
+    /// the local side's request, and not because of an error.  A call to
+    /// [`disconnect`] triggers this.
+    ///
+    /// [`disconnect`]: crate::OssuaryConnection::disconnect
+    ConnectionClosed,
 }
 impl std::fmt::Debug for OssuaryError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
@@ -160,6 +169,7 @@ impl std::fmt::Debug for OssuaryError {
             OssuaryError::InvalidSignature => write!(f, "OssuaryError::InvalidSignature"),
             OssuaryError::ConnectionReset(_) => write!(f, "OssuaryError::ConnectionReset"),
             OssuaryError::ConnectionFailed => write!(f, "OssuaryError::ConnectionFailed"),
+            OssuaryError::ConnectionClosed => write!(f, "OssuaryError::ConnectionClosed"),
             OssuaryError::UntrustedServer(_) => write!(f, "OssuaryError::UntrustedServer"),
             OssuaryError::DecryptionFailed => write!(f, "OssuaryError::DecryptionFailed"),
             OssuaryError::WrongProtocolVersion(r,l) => write!(f, "OssuaryError:WrongProtocolVersion {} != {}", r, l),

diff --git a/src/handshake.rs b/src/handshake.rs
line changes: +43/-8
index bda8294..5415f06
--- a/src/handshake.rs
+++ b/src/handshake.rs
@@ -28,15 +28,26 @@ const CLIENT_AUTH_SUBPACKET_LEN: usize = ::std::mem::size_of::<ClientEncryptedAu
     ::std::mem::size_of::<EncryptedPacket>() + TAG_LEN;
 
 #[repr(C,packed)]
-struct ResetPacket {
+pub(crate) struct ResetPacket {
     len: u16,
-    _reserved: u16,
+    pub(crate) error: bool,
+    _reserved: u8,
 }
 impl Default for ResetPacket {
     fn default() -> ResetPacket {
         ResetPacket {
             len: ::std::mem::size_of::<ResetPacket> as u16,
-            _reserved: 0u16,
+            error: true,
+            _reserved: 0u8,
+        }
+    }
+}
+impl ResetPacket {
+    fn closed() -> ResetPacket {
+        ResetPacket {
+            len: ::std::mem::size_of::<ResetPacket> as u16,
+            error: false,
+            _reserved: 0u8,
         }
     }
 }
@@ -373,11 +384,14 @@ impl OssuaryConnection {
                 w
             },
 
-            ConnectionState::Failing(_) => {
+            ConnectionState::Failing(ref e) => {
                 // Tell remote host to disconnect
-                let pkt: ResetPacket = Default::default();
+                let pkt: ResetPacket = match e {
+                    OssuaryError::ConnectionClosed => ResetPacket::closed(),
+                    _ => Default::default(),
+                };
                 let w = write_packet(self, &mut buf, struct_as_slice(&pkt),
-                                         PacketType::Disconnect)?;
+                                     PacketType::Disconnect)?;
                 w
             },
 
@@ -408,6 +422,7 @@ impl OssuaryConnection {
     where T: std::ops::DerefMut<Target = U>,
           U: std::io::Read {
         match self.state {
+            ConnectionState::Failed(_) |
             ConnectionState::Encrypted => return Ok(0),
             // Timeout wait states
             ConnectionState::ServerWaitHandshake(t) |
@@ -447,14 +462,24 @@ impl OssuaryConnection {
                 }
             },
             PacketType::Disconnect => {
-                self.reset_state(Some(OssuaryError::ConnectionFailed));
-                return Err(OssuaryError::ConnectionFailed);
+                let rs_pkt = interpret_packet::<ResetPacket>(&pkt)?;
+                match rs_pkt.error {
+                    true => {
+                        self.reset_state(Some(OssuaryError::ConnectionFailed));
+                        return Err(OssuaryError::ConnectionFailed);
+                    },
+                    false => {
+                        self.reset_state(Some(OssuaryError::ConnectionClosed));
+                        return Err(OssuaryError::ConnectionClosed);
+                    },
+                }
             },
             _ => {},
         }
 
         if pkt.header.msg_id != self.remote_msg_id {
             match pkt.kind() {
+                PacketType::Disconnect |
                 PacketType::Reset => {},
                 _ => {
                     match self.state {
@@ -729,6 +754,16 @@ impl OssuaryConnection {
 
     /// Returns whether the handshake process is complete.
     ///
+    /// Returns an error if the connection has failed, and specifically raises
+    /// [`OssuaryError::UntrustedServer`] if the handshake has stalled because
+    /// the remote host sent an authentication key that is not trusted.
+    ///
+    /// In the event of an untrusted server, calling
+    /// [`OssuaryConnection::add_authorized_key`] will mark the key as trusted
+    /// and allow the handshake to continue.  This should only be done if the
+    /// application is implementing a Trust-On-First-Use policy, and has
+    /// verified that the remote host's key has never been seen before.  It is
+    /// always best practice to prompt the user in this case before continuing.
     ///
     pub fn handshake_done(&mut self) -> Result<bool, OssuaryError> {
         match self.state {

diff --git a/tests/corruption.rs b/tests/corruption.rs
line changes: +34/-3
index 0048a0e..97916f2
--- a/tests/corruption.rs
+++ b/tests/corruption.rs
@@ -100,7 +100,17 @@ fn corruption() {
                         false => panic!("Unexpected connection failure."),
                     }
                 }
-                Err(e) => panic!("Handshake failed with error: {:?}", e),
+                Err(e) => match e {
+                    ref e if e == &corruption.4 => {}, // expected error
+                    OssuaryError::ConnectionFailed => {
+                        match corruption.5 {
+                            true => break,
+                            false => panic!("Unexpected connection failure."),
+                        }
+                    },
+                    OssuaryError::ConnectionReset(b) => { recv_buf.drain(0..b); },
+                    _ => panic!("Handshake failed: {:?}", e),
+                },
             }
             match recv_conn.handshake_done() {
                 Ok(true) => done += 1,
@@ -111,7 +121,17 @@ fn corruption() {
                         false => panic!("Unexpected connection failure."),
                     }
                 }
-                Err(e) => panic!("Handshake failed with error: {:?}", e),
+                Err(e) => match e {
+                    ref e if e == &corruption.4 => {}, // expected error
+                    OssuaryError::ConnectionFailed => {
+                        match corruption.5 {
+                            true => break,
+                            false => panic!("Unexpected connection failure."),
+                        }
+                    },
+                    OssuaryError::ConnectionReset(b) => { recv_buf.drain(0..b); },
+                    _ => panic!("Handshake failed: {:?}", e),
+                },
             }
             if done == 2 {
                 break;
@@ -160,7 +180,18 @@ fn corruption() {
                         },
                     }
                 },
-                Err(e) => panic!("ERROR: {:?}", e),
+                //Err(e) => panic!("ERROR: {:?}", e),
+                Err(e) => match e {
+                    ref e if e == &corruption.4 => {}, // expected error
+                    OssuaryError::ConnectionFailed => {
+                        match corruption.5 {
+                            true => break,
+                            false => panic!("Unexpected connection failure."),
+                        }
+                    },
+                    OssuaryError::ConnectionReset(b) => { recv_buf.drain(0..b); },
+                    _ => panic!("Handshake failed: {:?}", e),
+                },
                 _ => {},
             }
             loop_conn = match loop_conn {