| use super::Stream; |
| use futures_util::future::poll_fn; |
| use futures_util::task::noop_waker_ref; |
| use rustls::{ClientConnection, Connection, ServerConnection}; |
| use std::io::{self, Cursor, Read, Write}; |
| use std::pin::Pin; |
| use std::task::{Context, Poll}; |
| use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; |
| |
| struct Good<'a>(&'a mut Connection); |
| |
| impl<'a> AsyncRead for Good<'a> { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| let mut buf2 = buf.initialize_unfilled(); |
| |
| Poll::Ready(match self.0.write_tls(buf2.by_ref()) { |
| Ok(n) => { |
| buf.advance(n); |
| Ok(()) |
| } |
| Err(err) => Err(err), |
| }) |
| } |
| } |
| |
| impl<'a> AsyncWrite for Good<'a> { |
| fn poll_write( |
| mut self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| mut buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| let len = self.0.read_tls(buf.by_ref())?; |
| self.0 |
| .process_new_packets() |
| .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; |
| Poll::Ready(Ok(len)) |
| } |
| |
| fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| self.0 |
| .process_new_packets() |
| .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; |
| Poll::Ready(Ok(())) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| self.0.send_close_notify(); |
| dbg!("sent close notify"); |
| self.poll_flush(cx) |
| } |
| } |
| |
| struct Pending; |
| |
| impl AsyncRead for Pending { |
| fn poll_read( |
| self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| _: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| Poll::Pending |
| } |
| } |
| |
| impl AsyncWrite for Pending { |
| fn poll_write( |
| self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| _buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| Poll::Pending |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| Poll::Ready(Ok(())) |
| } |
| |
| fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| Poll::Ready(Ok(())) |
| } |
| } |
| |
| struct Expected(Cursor<Vec<u8>>); |
| |
| impl AsyncRead for Expected { |
| fn poll_read( |
| self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<io::Result<()>> { |
| let this = self.get_mut(); |
| let n = std::io::Read::read(&mut this.0, buf.initialize_unfilled())?; |
| buf.advance(n); |
| |
| Poll::Ready(Ok(())) |
| } |
| } |
| |
| impl AsyncWrite for Expected { |
| fn poll_write( |
| self: Pin<&mut Self>, |
| _cx: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<io::Result<usize>> { |
| Poll::Ready(Ok(buf.len())) |
| } |
| |
| fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| Poll::Ready(Ok(())) |
| } |
| |
| fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| Poll::Ready(Ok(())) |
| } |
| } |
| |
| #[tokio::test] |
| async fn stream_good() -> io::Result<()> { |
| const FILE: &[u8] = include_bytes!("../../README.md"); |
| |
| let (server, mut client) = make_pair(); |
| let mut server = Connection::from(server); |
| poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; |
| |
| io::copy(&mut Cursor::new(FILE), &mut server.writer())?; |
| server.send_close_notify(); |
| |
| { |
| let mut good = Good(&mut server); |
| let mut stream = Stream::new(&mut good, &mut client); |
| |
| let mut buf = Vec::new(); |
| dbg!(stream.read_to_end(&mut buf).await)?; |
| assert_eq!(buf, FILE); |
| |
| dbg!(stream.write_all(b"Hello World!").await)?; |
| stream.session.send_close_notify(); |
| |
| dbg!(stream.shutdown().await)?; |
| } |
| |
| let mut buf = String::new(); |
| dbg!(server.process_new_packets()).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |
| dbg!(server.reader().read_to_string(&mut buf))?; |
| assert_eq!(buf, "Hello World!"); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| #[tokio::test] |
| async fn stream_bad() -> io::Result<()> { |
| let (server, mut client) = make_pair(); |
| let mut server = Connection::from(server); |
| poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; |
| client.set_buffer_limit(Some(1024)); |
| |
| let mut bad = Pending; |
| let mut stream = Stream::new(&mut bad, &mut client); |
| assert_eq!( |
| poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, |
| 8 |
| ); |
| assert_eq!( |
| poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x42; 8])).await?, |
| 8 |
| ); |
| let r = poll_fn(|cx| stream.as_mut_pin().poll_write(cx, &[0x00; 1024])).await?; // fill buffer |
| assert!(r < 1024); |
| |
| let mut cx = Context::from_waker(noop_waker_ref()); |
| let ret = stream.as_mut_pin().poll_write(&mut cx, &[0x01]); |
| assert!(ret.is_pending()); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| #[tokio::test] |
| async fn stream_handshake() -> io::Result<()> { |
| let (server, mut client) = make_pair(); |
| let mut server = Connection::from(server); |
| |
| { |
| let mut good = Good(&mut server); |
| let mut stream = Stream::new(&mut good, &mut client); |
| let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; |
| |
| assert!(r > 0); |
| assert!(w > 0); |
| |
| poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake |
| } |
| |
| assert!(!server.is_handshaking()); |
| assert!(!client.is_handshaking()); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| #[tokio::test] |
| async fn stream_buffered_handshake() -> io::Result<()> { |
| use tokio::io::BufWriter; |
| |
| let (server, mut client) = make_pair(); |
| let mut server = Connection::from(server); |
| |
| { |
| let mut good = BufWriter::new(Good(&mut server)); |
| let mut stream = Stream::new(&mut good, &mut client); |
| let (r, w) = poll_fn(|cx| stream.handshake(cx)).await?; |
| |
| assert!(r > 0); |
| assert!(w > 0); |
| |
| poll_fn(|cx| stream.handshake(cx)).await?; // finish server handshake |
| } |
| |
| assert!(!server.is_handshaking()); |
| assert!(!client.is_handshaking()); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| #[tokio::test] |
| async fn stream_handshake_eof() -> io::Result<()> { |
| let (_, mut client) = make_pair(); |
| |
| let mut bad = Expected(Cursor::new(Vec::new())); |
| let mut stream = Stream::new(&mut bad, &mut client); |
| |
| let mut cx = Context::from_waker(noop_waker_ref()); |
| let r = stream.handshake(&mut cx); |
| assert_eq!( |
| r.map_err(|err| err.kind()), |
| Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) |
| ); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| // see https://github.com/tokio-rs/tls/issues/77 |
| #[tokio::test] |
| async fn stream_handshake_regression_issues_77() -> io::Result<()> { |
| let (_, mut client) = make_pair(); |
| |
| let mut bad = Expected(Cursor::new(b"\x15\x03\x01\x00\x02\x02\x00".to_vec())); |
| let mut stream = Stream::new(&mut bad, &mut client); |
| |
| let mut cx = Context::from_waker(noop_waker_ref()); |
| let r = stream.handshake(&mut cx); |
| assert_eq!( |
| r.map_err(|err| err.kind()), |
| Poll::Ready(Err(io::ErrorKind::UnexpectedEof)) |
| ); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| #[tokio::test] |
| async fn stream_eof() -> io::Result<()> { |
| let (server, mut client) = make_pair(); |
| let mut server = Connection::from(server); |
| poll_fn(|cx| do_handshake(&mut client, &mut server, cx)).await?; |
| |
| let mut bad = Expected(Cursor::new(Vec::new())); |
| let mut stream = Stream::new(&mut bad, &mut client); |
| |
| let mut buf = Vec::new(); |
| let result = stream.read_to_end(&mut buf).await; |
| assert_eq!( |
| result.err().map(|e| e.kind()), |
| Some(io::ErrorKind::UnexpectedEof) |
| ); |
| |
| Ok(()) as io::Result<()> |
| } |
| |
| fn make_pair() -> (ServerConnection, ClientConnection) { |
| use std::convert::TryFrom; |
| |
| let (sconfig, cconfig) = utils::make_configs(); |
| let server = ServerConnection::new(sconfig).unwrap(); |
| |
| let domain = rustls::ServerName::try_from("foobar.com").unwrap(); |
| let client = ClientConnection::new(cconfig, domain).unwrap(); |
| |
| (server, client) |
| } |
| |
| fn do_handshake( |
| client: &mut ClientConnection, |
| server: &mut Connection, |
| cx: &mut Context<'_>, |
| ) -> Poll<io::Result<()>> { |
| let mut good = Good(server); |
| let mut stream = Stream::new(&mut good, client); |
| |
| while stream.session.is_handshaking() { |
| ready!(stream.handshake(cx))?; |
| } |
| |
| while stream.session.wants_write() { |
| ready!(stream.write_io(cx))?; |
| } |
| |
| Poll::Ready(Ok(())) |
| } |
| |
// Reuse the integration-test helper file verbatim, so these unit tests and
// the integration tests share one copy of the helpers. This presumably
// defines the `utils` module (e.g. `utils::make_configs`, used by `make_pair`);
// TODO confirm against tests/utils.rs.
include!("../../tests/utils.rs");