blob: 62e5042f08b9d1b0b26d714933ce5b2681933810 [file] [log] [blame]
extern crate schannel;
use self::schannel::cert_context::{CertContext, HashAlgorithm, KeySpec};
use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions};
use self::schannel::crypt_prov::{AcquireOptions, ProviderType};
use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred};
use self::schannel::tls_stream;
use std::error;
use std::fmt;
use std::io;
use std::str;
use {TlsAcceptorBuilder, TlsConnectorBuilder};
/// SSPI status code `SEC_E_NO_CREDENTIALS`.
///
/// When a certificate lookup on the TLS stream fails with this raw OS error, the accessors
/// on `TlsStream` report "no certificate" (`Ok(None)`) instead of an error. The comparison
/// reinterprets this `u32` as `i32` because `io::Error::raw_os_error` yields an `i32`.
const SEC_E_NO_CREDENTIALS: u32 = 0x8009_030E;
/// Schannel protocol versions, oldest first.
///
/// `convert_protocols` slices this table with `min as usize` / `max as usize`, so the
/// position of every entry has to line up with the discriminant of the matching
/// crate-level `::Protocol` variant.
static PROTOCOLS: &[Protocol] = &[
    Protocol::Ssl3,
    Protocol::Tls10,
    Protocol::Tls11,
    Protocol::Tls12,
];
/// Narrows `PROTOCOLS` to the inclusive version range `[min, max]`.
///
/// `None` leaves that side unbounded. The upper bound is applied first; the lower bound is
/// then applied to what remains. A bound that does not index inside the current slice is
/// ignored rather than reported as an error.
fn convert_protocols(min: Option<::Protocol>, max: Option<::Protocol>) -> &'static [Protocol] {
    let capped = match max {
        Some(hi) => PROTOCOLS.get(..=hi as usize).unwrap_or(PROTOCOLS),
        None => PROTOCOLS,
    };
    match min {
        Some(lo) => capped.get(lo as usize..).unwrap_or(capped),
        None => capped,
    }
}
/// Error type of the Schannel backend: a transparent newtype around `io::Error`.
pub struct Error(io::Error);

impl error::Error for Error {
    /// Forwards to the wrapped error's own source, so this wrapper adds no extra link.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let Error(ref inner) = *self;
        error::Error::source(inner)
    }
}

impl fmt::Display for Error {
    // Delegates with the caller's formatter so width/precision flags are honoured.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Error(ref inner) = *self;
        fmt::Display::fmt(inner, fmt)
    }
}

impl fmt::Debug for Error {
    // Delegates with the caller's formatter so `{:#?}` still pretty-prints.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Error(ref inner) = *self;
        fmt::Debug::fmt(inner, fmt)
    }
}

impl From<io::Error> for Error {
    fn from(inner: io::Error) -> Error {
        Error(inner)
    }
}
/// A certificate whose private key is available to Schannel, used as a TLS identity.
#[derive(Clone)]
pub struct Identity {
    cert: CertContext,
}

impl Identity {
    /// Imports an identity from a PKCS #12 archive protected by `pass`.
    ///
    /// The identity is the first certificate in the archive for which acquiring the private
    /// key with `silent(true)` and `compare_key(true)` succeeds.
    ///
    /// # Errors
    ///
    /// Propagates failures from the PFX import, and returns an `InvalidInput` error when no
    /// certificate in the archive qualifies.
    pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
        let store = PfxImportOptions::new().password(pass).import(buf)?;
        let cert = store
            .certs()
            .find(|candidate| {
                candidate
                    .private_key()
                    .silent(true)
                    .compare_key(true)
                    .acquire()
                    .is_ok()
            })
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "No identity found in PKCS #12 archive",
                )
            })?;
        Ok(Identity { cert })
    }

    /// Builds an identity from a PEM certificate chain and a PEM PKCS #8 private key.
    ///
    /// The first PEM block of `pem` is the leaf certificate; any further blocks are treated
    /// as intermediates and added to the same in-memory store. The key is imported into a
    /// freshly named CryptoAPI key container (RSA full provider), and the leaf is linked to
    /// that container.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when `key` does not start with `-----BEGIN PRIVATE KEY-----`,
    /// when `pem` is not valid UTF-8, or when `pem` contains no certificate. Errors from
    /// certificate parsing, key import and store operations are propagated.
    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result<Identity, Error> {
        if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "not a PKCS#8 key").into());
        }
        // Reject non-UTF-8 input with an error up front; the PEM splitter must never be the
        // thing that trips over it.
        if str::from_utf8(pem).is_err() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "PEM representation contains non-UTF-8 bytes",
            )
            .into());
        }

        let mut store = Memory::new()?.into_store();
        let mut cert_iter = pem::PemBlock::new(pem).into_iter();
        let leaf = cert_iter.next().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "at least one certificate must be provided to create an identity",
            )
        })?;
        let cert = CertContext::from_pem(str::from_utf8(leaf).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "leaf cert contains invalid utf8",
            )
        })?)?;

        // Open the named container if it already exists, otherwise create it.
        let name = gen_container_name();
        let mut options = AcquireOptions::new();
        options.container(&name);
        let type_ = ProviderType::rsa_full();
        let mut container = match options.acquire(type_) {
            Ok(container) => container,
            Err(_) => options.new_keyset(true).acquire(type_)?,
        };
        container.import().import_pkcs8_pem(&key)?;

        // Attach the imported key to the leaf certificate.
        cert.set_key_prov_info()
            .container(&name)
            .type_(type_)
            .keep_open(true)
            .key_spec(KeySpec::key_exchange())
            .set()?;

        // Only the leaf carries the private key, so it alone is the identity.
        let leaf_context = store.add_cert(&cert, CertAdd::Always)?;

        // Intermediates go into the same store as the leaf, but must not replace it as the
        // identity: none of them has key provider info attached.
        for int_cert in cert_iter {
            let certificate = Certificate::from_pem(int_cert)?;
            store.add_cert(&certificate.0, CertAdd::Always)?;
        }

        Ok(Identity { cert: leaf_context })
    }
}
/// Generates a fresh key-container name for `Identity::from_pkcs8`.
///
/// Every imported key needs its own container name so that several keys can be active at
/// once. `from_pkcs8` reopens an existing container with the same name before importing
/// into it. CryptoAPI key containers are (per the `CryptAcquireContext` documentation)
/// persisted per user, not per process. With a process-local counter alone, two concurrently
/// running processes would both produce `native-tls-0` and overwrite each other's key.
/// Prefixing the process id keeps names distinct between live processes.
fn gen_container_name() -> String {
    use std::sync::atomic::{AtomicUsize, Ordering};
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
    format!(
        "native-tls-{}-{}",
        std::process::id(),
        COUNTER.fetch_add(1, Ordering::Relaxed)
    )
}
#[derive(Clone)]
pub struct Certificate(CertContext);
impl Certificate {
pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
let cert = CertContext::new(buf)?;
Ok(Certificate(cert))
}
pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
match str::from_utf8(buf) {
Ok(s) => {
let cert = CertContext::from_pem(s)?;
Ok(Certificate(cert))
}
Err(_) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"PEM representation contains non-UTF-8 bytes",
)
.into()),
}
}
pub fn to_der(&self) -> Result<Vec<u8>, Error> {
Ok(self.0.to_der().to_vec())
}
}
pub struct MidHandshakeTlsStream<S>(tls_stream::MidHandshakeTlsStream<S>);
impl<S> fmt::Debug for MidHandshakeTlsStream<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> MidHandshakeTlsStream<S> {
pub fn get_ref(&self) -> &S {
self.0.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut()
}
}
impl<S> MidHandshakeTlsStream<S>
where
S: io::Read + io::Write,
{
pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
match self.0.handshake() {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
pub enum HandshakeError<S> {
Failure(Error),
WouldBlock(MidHandshakeTlsStream<S>),
}
impl<S> From<tls_stream::HandshakeError<S>> for HandshakeError<S> {
fn from(e: tls_stream::HandshakeError<S>) -> HandshakeError<S> {
match e {
tls_stream::HandshakeError::Failure(e) => HandshakeError::Failure(e.into()),
tls_stream::HandshakeError::Interrupted(s) => {
HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
}
}
}
}
impl<S> From<io::Error> for HandshakeError<S> {
fn from(e: io::Error) -> HandshakeError<S> {
HandshakeError::Failure(e.into())
}
}
#[derive(Clone, Debug)]
pub struct TlsConnector {
cert: Option<CertContext>,
roots: CertStore,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
disable_built_in_roots: bool,
#[cfg(feature = "alpn")]
alpn: Vec<String>,
}
impl TlsConnector {
pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
let cert = builder.identity.as_ref().map(|i| i.0.cert.clone());
let mut roots = Memory::new()?.into_store();
for cert in &builder.root_certificates {
roots.add_cert(&(cert.0).0, CertAdd::ReplaceExisting)?;
}
Ok(TlsConnector {
cert,
roots,
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
disable_built_in_roots: builder.disable_built_in_roots,
#[cfg(feature = "alpn")]
alpn: builder.alpn.clone(),
})
}
pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
if let Some(cert) = self.cert.as_ref() {
builder.cert(cert.clone());
}
let cred = builder.acquire(Direction::Outbound)?;
let mut builder = tls_stream::Builder::new();
builder
.cert_store(self.roots.clone())
.domain(domain)
.use_sni(self.use_sni)
.accept_invalid_hostnames(self.accept_invalid_hostnames);
if self.accept_invalid_certs {
builder.verify_callback(|_| Ok(()));
} else if self.disable_built_in_roots {
let roots_copy = self.roots.clone();
builder.verify_callback(move |res| {
if let Err(err) = res.result() {
// Propagate previous error encountered during normal cert validation.
return Err(err);
}
if let Some(chain) = res.chain() {
if chain
.certificates()
.any(|cert| roots_copy.certs().any(|root_cert| root_cert == cert))
{
return Ok(());
}
}
Err(io::Error::new(
io::ErrorKind::Other,
"unable to find any user-specified roots in the final cert chain",
))
});
}
#[cfg(feature = "alpn")]
{
if !self.alpn.is_empty() {
builder.request_application_protocols(
&self.alpn.iter().map(|s| s.as_bytes()).collect::<Vec<_>>(),
);
}
}
match builder.connect(cred, stream) {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
#[derive(Clone)]
pub struct TlsAcceptor {
cert: CertContext,
min_protocol: Option<::Protocol>,
max_protocol: Option<::Protocol>,
}
impl TlsAcceptor {
pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
Ok(TlsAcceptor {
cert: builder.identity.0.cert.clone(),
min_protocol: builder.min_protocol,
max_protocol: builder.max_protocol,
})
}
pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
where
S: io::Read + io::Write,
{
let mut builder = SchannelCred::builder();
builder.enabled_protocols(convert_protocols(self.min_protocol, self.max_protocol));
builder.cert(self.cert.clone());
// FIXME we're probably missing the certificate chain?
let cred = builder.acquire(Direction::Inbound)?;
match tls_stream::Builder::new().accept(cred, stream) {
Ok(s) => Ok(TlsStream(s)),
Err(e) => Err(e.into()),
}
}
}
/// An established TLS session layered over the stream `S`.
pub struct TlsStream<S>(tls_stream::TlsStream<S>);

impl<S> fmt::Debug for TlsStream<S>
where
    S: fmt::Debug,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let TlsStream(ref inner) = *self;
        fmt::Debug::fmt(inner, fmt)
    }
}

impl<S> TlsStream<S> {
    /// Borrows the underlying stream.
    pub fn get_ref(&self) -> &S {
        let TlsStream(ref inner) = *self;
        inner.get_ref()
    }

    /// Mutably borrows the underlying stream.
    pub fn get_mut(&mut self) -> &mut S {
        let TlsStream(ref mut inner) = *self;
        inner.get_mut()
    }
}
impl<S: io::Read + io::Write> TlsStream<S> {
pub fn buffered_read_size(&self) -> Result<usize, Error> {
Ok(self.0.get_buf().len())
}
pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
match self.0.peer_certificate() {
Ok(cert) => Ok(Some(Certificate(cert))),
Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => Ok(None),
Err(e) => Err(Error(e)),
}
}
#[cfg(feature = "alpn")]
pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
Ok(self.0.negotiated_application_protocol()?)
}
pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
let cert = if self.0.is_server() {
self.0.certificate()
} else {
self.0.peer_certificate()
};
let cert = match cert {
Ok(cert) => cert,
Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => return Ok(None),
Err(e) => return Err(Error(e)),
};
let signature_algorithms = cert.sign_hash_algorithms()?;
let hash = match signature_algorithms.rsplit('/').next().unwrap() {
"MD5" | "SHA1" | "SHA256" => HashAlgorithm::sha256(),
"SHA384" => HashAlgorithm::sha384(),
"SHA512" => HashAlgorithm::sha512(),
_ => return Ok(None),
};
let digest = cert.fingerprint(hash)?;
Ok(Some(digest))
}
pub fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown()?;
Ok(())
}
}
/// Reads decrypted application data by delegating to the inner Schannel stream.
impl<S> io::Read for TlsStream<S>
where
    S: io::Read + io::Write,
{
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        io::Read::read(&mut self.0, buf)
    }
}

/// Writes application data by delegating to the inner Schannel stream.
impl<S> io::Write for TlsStream<S>
where
    S: io::Read + io::Write,
{
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        io::Write::write(&mut self.0, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        io::Write::flush(&mut self.0)
    }
}
mod pem {
    /// Marker that opens every PEM block.
    const BEGIN_MARKER: &[u8] = b"-----BEGIN";

    /// Byte offset of the first occurrence of `needle` in `haystack`.
    ///
    /// `needle` must be non-empty (`slice::windows` panics on a zero size); it is only ever
    /// called with `BEGIN_MARKER`.
    fn find_marker(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }

    /// Splits data into PEM blocks at their `-----BEGIN` guard lines.
    ///
    /// Each yielded item runs from one `-----BEGIN` marker up to (not including) the next
    /// one, or to the end of the input. Anything before the first marker is skipped.
    ///
    /// The input is handled as raw bytes, so non-UTF-8 data never panics here; callers that
    /// need text must validate each yielded block themselves.
    pub struct PemBlock<'a> {
        pem_block: &'a [u8],
        cur_end: usize,
    }

    impl<'a> PemBlock<'a> {
        pub fn new(data: &'a [u8]) -> PemBlock<'a> {
            PemBlock {
                pem_block: data,
                cur_end: find_marker(data, BEGIN_MARKER).unwrap_or(data.len()),
            }
        }
    }

    impl<'a> Iterator for PemBlock<'a> {
        type Item = &'a [u8];

        fn next(&mut self) -> Option<Self::Item> {
            let last = self.pem_block.len();
            if self.cur_end >= last {
                return None;
            }
            let begin = self.cur_end;
            // Search from `begin + 1` so the marker at `begin` is not matched again.
            self.cur_end = match find_marker(&self.pem_block[begin + 1..], BEGIN_MARKER) {
                Some(offset) => begin + 1 + offset,
                None => last,
            };
            Some(&self.pem_block[begin..self.cur_end])
        }
    }

    #[test]
    fn test_split() {
        // Split three certs, CRLF line terminators.
        assert_eq!(
            PemBlock::new(
                b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\
                  -----BEGIN SECOND-----\r\n-----END SECOND\r\n\
                  -----BEGIN THIRD-----\r\n-----END THIRD\r\n"
            )
            .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8],
                b"-----BEGIN SECOND-----\r\n-----END SECOND\r\n",
                b"-----BEGIN THIRD-----\r\n-----END THIRD\r\n"
            ]
        );
        // Split three certs, CRLF line terminators except at EOF.
        assert_eq!(
            PemBlock::new(
                b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\
                  -----BEGIN SECOND-----\r\n-----END SECOND-----\r\n\
                  -----BEGIN THIRD-----\r\n-----END THIRD-----"
            )
            .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8],
                b"-----BEGIN SECOND-----\r\n-----END SECOND-----\r\n",
                b"-----BEGIN THIRD-----\r\n-----END THIRD-----"
            ]
        );
        // Split two certs, LF line terminators.
        assert_eq!(
            PemBlock::new(
                b"-----BEGIN FIRST-----\n-----END FIRST-----\n\
                  -----BEGIN SECOND-----\n-----END SECOND\n"
            )
            .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8],
                b"-----BEGIN SECOND-----\n-----END SECOND\n"
            ]
        );
        // Split two certs, CR line terminators.
        assert_eq!(
            PemBlock::new(
                b"-----BEGIN FIRST-----\r-----END FIRST-----\r\
                  -----BEGIN SECOND-----\r-----END SECOND\r"
            )
            .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN FIRST-----\r-----END FIRST-----\r" as &[u8],
                b"-----BEGIN SECOND-----\r-----END SECOND\r"
            ]
        );
        // Split two certs, LF line terminators except at EOF.
        assert_eq!(
            PemBlock::new(
                b"-----BEGIN FIRST-----\n-----END FIRST-----\n\
                  -----BEGIN SECOND-----\n-----END SECOND"
            )
            .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8],
                b"-----BEGIN SECOND-----\n-----END SECOND"
            ]
        );
        // Split a single cert, LF line terminators.
        assert_eq!(
            PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----\n").collect::<Vec<&[u8]>>(),
            vec![b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8]]
        );
        // Split a single cert, LF line terminators except at EOF.
        assert_eq!(
            PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----").collect::<Vec<&[u8]>>(),
            vec![b"-----BEGIN FIRST-----\n-----END FIRST-----" as &[u8]]
        );
        // (Don't) split garbage.
        assert_eq!(
            PemBlock::new(b"junk").collect::<Vec<&[u8]>>(),
            Vec::<&[u8]>::new()
        );
        assert_eq!(
            PemBlock::new(b"junk-----BEGIN garbage").collect::<Vec<&[u8]>>(),
            vec![b"-----BEGIN garbage" as &[u8]]
        );
        // Non-UTF-8 input is split without panicking.
        assert_eq!(
            PemBlock::new(b"\xff-----BEGIN A-----\n\xfe\n-----BEGIN B-----\n")
                .collect::<Vec<&[u8]>>(),
            vec![
                b"-----BEGIN A-----\n\xfe\n" as &[u8],
                b"-----BEGIN B-----\n"
            ]
        );
        assert_eq!(
            PemBlock::new(b"\xff\xfe").collect::<Vec<&[u8]>>(),
            Vec::<&[u8]>::new()
        );
    }
}