blob: 4471727db0c7eed056ce8e37006657b5caf60ef8 [file] [log] [blame]
//! Support for deserializing configuration via `serde`
use crate::util::config::value;
use crate::util::config::{Config, ConfigError, ConfigKey};
use crate::util::config::{ConfigValue as CV, Definition, Value};
use serde::{de, de::IntoDeserializer};
use std::collections::HashSet;
use std::path::PathBuf;
use std::vec;
/// Serde deserializer used to convert config values to a target type using
/// `Config::get`.
///
/// It pairs a borrowed `Config` with the `ConfigKey` of the value currently
/// being read. Map access extends a copy of `key` to reach nested values.
#[derive(Clone)]
pub(crate) struct Deserializer<'config> {
/// Configuration that values are looked up in.
pub(crate) config: &'config Config,
/// Key identifying the value this deserializer reads.
pub(crate) key: ConfigKey,
}
/// Generates a `de::Deserializer` method named `$method` that fetches the
/// value for `self.key` through the `Config` getter `$getter` and hands it to
/// the visitor's `$visit` method.
///
/// A missing value yields `ConfigError::missing`; an error raised by the
/// visitor is annotated with the key and the definition of the value.
macro_rules! deserialize_method {
    ($method:ident, $visit:ident, $getter:ident) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: de::Visitor<'de>,
        {
            let Value { val, definition } = match self.config.$getter(&self.key)? {
                Some(found) => found,
                None => return Err(ConfigError::missing(&self.key)),
            };
            let visited: Result<V::Value, ConfigError> = visitor.$visit(val);
            visited.map_err(|err| err.with_key_context(&self.key, definition))
        }
    };
}
impl<'de, 'config> de::Deserializer<'de> for Deserializer<'config> {
type Error = ConfigError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Future note: If you ever need to deserialize a non-self describing
// map type, this should implement a starts_with check (similar to how
// ConfigMapAccess does).
if let Some(v) = self.config.env.get(self.key.as_env_key()) {
let res: Result<V::Value, ConfigError> = if v == "true" || v == "false" {
visitor.visit_bool(v.parse().unwrap())
} else if let Ok(v) = v.parse::<i64>() {
visitor.visit_i64(v)
} else if self.config.cli_unstable().advanced_env
&& v.starts_with('[')
&& v.ends_with(']')
{
visitor.visit_seq(ConfigSeqAccess::new(self.clone())?)
} else {
visitor.visit_string(v.clone())
};
return res.map_err(|e| {
e.with_key_context(
&self.key,
Definition::Environment(self.key.as_env_key().to_string()),
)
});
}
let o_cv = self.config.get_cv(self.key.as_config_key())?;
if let Some(cv) = o_cv {
let res: (Result<V::Value, ConfigError>, PathBuf) = match cv {
CV::Integer(i, path) => (visitor.visit_i64(i), path),
CV::String(s, path) => (visitor.visit_string(s), path),
CV::List(_, path) => (visitor.visit_seq(ConfigSeqAccess::new(self.clone())?), path),
CV::Table(_, path) => (
visitor.visit_map(ConfigMapAccess::new_map(self.clone())?),
path,
),
CV::Boolean(b, path) => (visitor.visit_bool(b), path),
};
let (res, path) = res;
return res.map_err(|e| e.with_key_context(&self.key, Definition::Path(path)));
}
Err(ConfigError::missing(&self.key))
}
deserialize_method!(deserialize_bool, visit_bool, get_bool_priv);
deserialize_method!(deserialize_i8, visit_i64, get_integer);
deserialize_method!(deserialize_i16, visit_i64, get_integer);
deserialize_method!(deserialize_i32, visit_i64, get_integer);
deserialize_method!(deserialize_i64, visit_i64, get_integer);
deserialize_method!(deserialize_u8, visit_i64, get_integer);
deserialize_method!(deserialize_u16, visit_i64, get_integer);
deserialize_method!(deserialize_u32, visit_i64, get_integer);
deserialize_method!(deserialize_u64, visit_i64, get_integer);
deserialize_method!(deserialize_string, visit_string, get_string_priv);
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
if self.config.has_key(&self.key) {
visitor.visit_some(self)
} else {
// Treat missing values as `None`.
visitor.visit_none()
}
}
fn deserialize_struct<V>(
self,
name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
// Match on the magical struct name/field names that are passed in to
// detect when we're deserializing `Value<T>`.
//
// See more comments in `value.rs` for the protocol used here.
if name == value::NAME && fields == value::FIELDS {
return visitor.visit_map(ValueDeserializer { hits: 0, de: self });
}
visitor.visit_map(ConfigMapAccess::new_struct(self, fields)?)
}
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_map(ConfigMapAccess::new_map(self)?)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(ConfigSeqAccess::new(self)?)
}
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(ConfigSeqAccess::new(self)?)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: de::Visitor<'de>,
{
visitor.visit_seq(ConfigSeqAccess::new(self)?)
}
// These aren't really supported, yet.
serde::forward_to_deserialize_any! {
f32 f64 char str bytes
byte_buf unit unit_struct
enum identifier ignored_any newtype_struct
}
}
/// `MapAccess` implementation that walks a set of keys nested under the
/// deserializer's current key.
struct ConfigMapAccess<'config> {
/// Deserializer whose key is extended with each map key while its value is read.
de: Deserializer<'config>,
/// Keys that have not been yielded yet.
set_iter: <HashSet<KeyKind> as IntoIterator>::IntoIter,
/// Key most recently returned by `next_key_seed`, consumed by `next_value_seed`.
next: Option<KeyKind>,
}
/// A key yielded by `ConfigMapAccess`, tagged with how it is appended to the
/// parent `ConfigKey`.
#[derive(PartialEq, Eq, Hash)]
enum KeyKind {
/// Appended with `ConfigKey::push`; used for struct field names.
Normal(String),
/// Appended with `ConfigKey::push_sensitive`; used for table keys and for
/// names extracted from environment variables.
CaseSensitive(String),
}
impl<'config> ConfigMapAccess<'config> {
fn new_map(de: Deserializer<'config>) -> Result<ConfigMapAccess<'config>, ConfigError> {
let mut set = HashSet::new();
if let Some(mut v) = de.config.get_table(de.key.as_config_key())? {
// `v: Value<HashMap<String, CV>>`
for (key, _value) in v.val.drain() {
set.insert(KeyKind::CaseSensitive(key));
}
}
if de.config.cli_unstable().advanced_env {
// `CARGO_PROFILE_DEV_OVERRIDES_`
let env_pattern = format!("{}_", de.key.as_env_key());
for env_key in de.config.env.keys() {
if env_key.starts_with(&env_pattern) {
// `CARGO_PROFILE_DEV_OVERRIDES_bar_OPT_LEVEL = 3`
let rest = &env_key[env_pattern.len()..];
// `rest = bar_OPT_LEVEL`
let part = rest.splitn(2, '_').next().unwrap();
// `part = "bar"`
set.insert(KeyKind::CaseSensitive(part.to_string()));
}
}
}
Ok(ConfigMapAccess {
de,
set_iter: set.into_iter(),
next: None,
})
}
fn new_struct(
de: Deserializer<'config>,
fields: &'static [&'static str],
) -> Result<ConfigMapAccess<'config>, ConfigError> {
let mut set = HashSet::new();
for field in fields {
set.insert(KeyKind::Normal(field.to_string()));
}
// Assume that if we're deserializing a struct it exhaustively lists all
// possible fields on this key that we're *supposed* to use, so take
// this opportunity to warn about any keys that aren't recognized as
// fields and warn about them.
if let Some(mut v) = de.config.get_table(de.key.as_config_key())? {
for (t_key, value) in v.val.drain() {
if set.contains(&KeyKind::Normal(t_key.to_string())) {
continue;
}
de.config.shell().warn(format!(
"unused key `{}.{}` in config file `{}`",
de.key.as_config_key(),
t_key,
value.definition_path().display()
))?;
}
}
Ok(ConfigMapAccess {
de,
set_iter: set.into_iter(),
next: None,
})
}
}
impl<'de, 'config> de::MapAccess<'de> for ConfigMapAccess<'config> {
    type Error = ConfigError;

    /// Yields the next key name, remembering the key so that
    /// `next_value_seed` knows which value to read.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
    where
        K: de::DeserializeSeed<'de>,
    {
        let pending = match self.set_iter.next() {
            Some(pending) => pending,
            None => return Ok(None),
        };
        let name = match &pending {
            KeyKind::Normal(text) | KeyKind::CaseSensitive(text) => text.as_str(),
        };
        let outcome = seed.deserialize(name.into_deserializer()).map(Some);
        // The key is stored even when deserializing its name failed.
        self.next = Some(pending);
        outcome
    }

    /// Deserializes the value belonging to the key last returned by
    /// `next_key_seed`.
    ///
    /// Panics if no key is pending.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: de::DeserializeSeed<'de>,
    {
        let pending = self.next.take().expect("next field missing");
        match pending {
            KeyKind::Normal(segment) => self.de.key.push(&segment),
            KeyKind::CaseSensitive(segment) => self.de.key.push_sensitive(&segment),
        }
        // The child reads from the extended key; the segment is popped again
        // whether or not deserialization succeeded.
        let child = self.de.clone();
        let outcome = seed.deserialize(child);
        self.de.key.pop();
        outcome
    }
}
/// `SeqAccess` implementation over list elements gathered up front, each
/// paired with the `Definition` it came from.
struct ConfigSeqAccess {
/// Remaining `(element, definition)` pairs.
list_iter: vec::IntoIter<(String, Definition)>,
}
impl ConfigSeqAccess {
fn new(de: Deserializer<'_>) -> Result<ConfigSeqAccess, ConfigError> {
let mut res = Vec::new();
if let Some(v) = de.config.get_list(de.key.as_config_key())? {
for (s, path) in v.val {
res.push((s, Definition::Path(path)));
}
}
if de.config.cli_unstable().advanced_env {
// Parse an environment string as a TOML array.
if let Some(v) = de.config.env.get(de.key.as_env_key()) {
let def = Definition::Environment(de.key.as_env_key().to_string());
if !(v.starts_with('[') && v.ends_with(']')) {
return Err(ConfigError::new(
format!("should have TOML list syntax, found `{}`", v),
def,
));
}
let toml_s = format!("value={}", v);
let toml_v: toml::Value = toml::de::from_str(&toml_s).map_err(|e| {
ConfigError::new(format!("could not parse TOML list: {}", e), def.clone())
})?;
let values = toml_v
.as_table()
.unwrap()
.get("value")
.unwrap()
.as_array()
.expect("env var was not array");
for value in values {
// TODO: support other types.
let s = value.as_str().ok_or_else(|| {
ConfigError::new(
format!("expected string, found {}", value.type_str()),
def.clone(),
)
})?;
res.push((s.to_string(), def.clone()));
}
}
}
Ok(ConfigSeqAccess {
list_iter: res.into_iter(),
})
}
}
impl<'de> de::SeqAccess<'de> for ConfigSeqAccess {
    type Error = ConfigError;

    /// Deserializes the next collected element from its string form, or
    /// returns `Ok(None)` once the list is exhausted.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: de::DeserializeSeed<'de>,
    {
        let (element, _def) = match self.list_iter.next() {
            Some(entry) => entry,
            None => return Ok(None),
        };
        // TODO: add `def` to error?
        seed.deserialize(element.into_deserializer()).map(Some)
    }
}
/// This is a deserializer that deserializes into a `Value<T>` for
/// configuration.
///
/// This is a special deserializer because it deserializes one of its struct
/// fields into the location that this configuration value was defined in.
///
/// See more comments in `value.rs` for the protocol used here.
struct ValueDeserializer<'config> {
/// Number of times `next_key_seed` has been called; selects which of the
/// two fields (value, then definition) is being produced.
hits: u32,
/// Deserializer for the wrapped value, also used to look up where that
/// value was defined.
de: Deserializer<'config>,
}
impl<'de, 'config> de::MapAccess<'de> for ValueDeserializer<'config> {
    type Error = ConfigError;

    /// Yields the value field name on the first call, the definition field
    /// name on the second, and `None` afterwards.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
    where
        K: de::DeserializeSeed<'de>,
    {
        self.hits += 1;
        if self.hits == 1 {
            seed.deserialize(value::VALUE_FIELD.into_deserializer())
                .map(Some)
        } else if self.hits == 2 {
            seed.deserialize(value::DEFINITION_FIELD.into_deserializer())
                .map(Some)
        } else {
            Ok(None)
        }
    }

    /// Produces the value for the field announced by the last `next_key_seed`.
    ///
    /// For the value field this runs the wrapped deserializer. For the
    /// definition field it emits a `(tag, text)` pair: tag `1` with the
    /// environment variable name when the key is set in the environment,
    /// otherwise tag `0` with the path of the defining config file.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: de::DeserializeSeed<'de>,
    {
        // If this is the first time around we deserialize the `value` field
        // which is the actual deserializer
        if self.hits == 1 {
            return seed.deserialize(self.de.clone());
        }

        // ... otherwise we're deserializing the `definition` field, so we need
        // to figure out where the field we just deserialized was defined at.
        let env_key = self.de.key.as_env_key();
        if self.de.config.env.contains_key(env_key) {
            return seed.deserialize(Tuple2Deserializer(1i32, env_key));
        }

        let cv = match self.de.config.get_cv(self.de.key.as_config_key())? {
            Some(cv) => cv,
            None => {
                return Err(
                    failure::format_err!("failed to find definition of `{}`", self.de.key).into(),
                );
            }
        };
        let location = cv.definition_path();
        let location_text = match location.to_str() {
            Some(text) => text,
            None => {
                return Err(
                    failure::format_err!("failed to convert {:?} to utf-8", location).into(),
                );
            }
        };
        seed.deserialize(Tuple2Deserializer(0i32, location_text))
    }
}
/// A deserializer which takes two values and deserializes into a tuple of those
/// two values. This is similar to types like `StrDeserializer` in upstream
/// serde itself.
///
/// The two values are presented to the visitor as a two-element sequence,
/// first `.0` and then `.1`.
struct Tuple2Deserializer<T, U>(T, U);
impl<'de, T, U> de::Deserializer<'de> for Tuple2Deserializer<T, U>
where
T: IntoDeserializer<'de, ConfigError>,
U: IntoDeserializer<'de, ConfigError>,
{
type Error = ConfigError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, ConfigError>
where
V: de::Visitor<'de>,
{
struct SeqVisitor<T, U> {
first: Option<T>,
second: Option<U>,
}
impl<'de, T, U> de::SeqAccess<'de> for SeqVisitor<T, U>
where
T: IntoDeserializer<'de, ConfigError>,
U: IntoDeserializer<'de, ConfigError>,
{
type Error = ConfigError;
fn next_element_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: de::DeserializeSeed<'de>,
{
if let Some(first) = self.first.take() {
return seed.deserialize(first.into_deserializer()).map(Some);
}
if let Some(second) = self.second.take() {
return seed.deserialize(second.into_deserializer()).map(Some);
}
Ok(None)
}
}
visitor.visit_seq(SeqVisitor {
first: Some(self.0),
second: Some(self.1),
})
}
serde::forward_to_deserialize_any! {
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
bytes byte_buf map struct option unit newtype_struct
ignored_any unit_struct tuple_struct tuple enum identifier
}
}