1use std::{
4 collections::HashMap, env, ffi::OsStr, fs, io::Error as IoError, path::Path, str::FromStr,
5};
6
7use basic_toml::Error as TomlError;
8use pico_args::Arguments;
9use serde::{Deserialize, Serialize};
10use serde_json::Error as JsonError;
11use serde_yaml::Error as YamlError;
12use strum::{Display as EnumDisplay, EnumString};
13use thiserror::Error;
14use tracing::{instrument, warn};
15
16use crate::{
17 config::{global::Hsts, CertificateSource, DefaultCertificateSource, ListenAddress, LogLevel},
18 stats::StatisticCategories,
19 store::BackendType,
20};
21
22#[derive(Debug, Error)]
24pub enum IntoPartialError {
25 #[error("failed to parse from toml")]
27 Toml(#[from] TomlError),
28 #[error("failed to parse from yaml")]
30 Yaml(#[from] YamlError),
31 #[error("failed to parse from json")]
33 Json(#[from] JsonError),
34 #[error("failed to read config file")]
36 Io(#[from] IoError),
37 #[error("file extension unknown, could not determine format")]
39 UnknownExtension,
40}
41
42fn deserialize_arg<T: for<'a> Deserialize<'a>>(
46 args: &mut Arguments,
47 key: &'static str,
48) -> Option<T> {
49 args.opt_value_from_fn(key, |s| serde_json::from_str(s))
50 .map_err(|err| {
51 warn!(
52 %err,
53 "Error parsing configuration from command-line argument '{key}'"
54 );
55 })
56 .ok()
57 .flatten()
58}
59
60fn parse_env_var<T: FromStr>(key: &str) -> Option<T> {
63 env::var(key).map_or(None, |s| {
64 s.parse()
65 .map_err(|_| {
66 warn!("Error parsing configuration from environment variable '{key}'");
67 })
68 .ok()
69 })
70}
71
72fn deserialize_env_var<T: for<'a> Deserialize<'a>>(key: &str) -> Option<T> {
76 env::var(key)
77 .map_or(None, |s| {
78 serde_json::from_str(&s)
79 .map_err(|err| {
80 warn!(
81 %err,
82 "Error parsing configuration from environment variable '{key}'"
83 );
84 })
85 .ok()
86 })
87 .flatten()
88}
89
90#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
97pub struct Partial {
98 pub log_level: Option<LogLevel>,
102 pub token: Option<String>,
104 pub listeners: Option<Vec<ListenAddress>>,
106 pub statistics: Option<StatisticCategories>,
108 pub default_certificate: Option<DefaultCertificateSource>,
110 pub certificates: Option<Vec<CertificateSource>>,
112 pub hsts: Option<PartialHsts>,
114 pub hsts_max_age: Option<u32>,
117 pub https_redirect: Option<bool>,
119 pub send_alt_svc: Option<bool>,
122 pub send_server: Option<bool>,
124 pub send_csp: Option<bool>,
126 pub store: Option<BackendType>,
128 pub store_config: Option<HashMap<String, String>>,
134}
135
136impl Partial {
137 pub fn from_toml(toml: &str) -> Result<Self, IntoPartialError> {
142 Ok(basic_toml::from_str(toml)?)
143 }
144
145 pub fn from_yaml(yaml: &str) -> Result<Self, IntoPartialError> {
150 Ok(serde_yaml::from_str(yaml)?)
151 }
152
153 pub fn from_json(json: &str) -> Result<Self, IntoPartialError> {
158 Ok(serde_json::from_str(json)?)
159 }
160
161 #[instrument(level = "debug", ret, err)]
174 pub fn from_file(path: &Path) -> Result<Self, IntoPartialError> {
175 let parse = match path.extension().map(OsStr::to_str) {
176 Some(Some("toml")) => Self::from_toml,
177 Some(Some("yaml" | "yml")) => Self::from_yaml,
178 Some(Some("json")) => Self::from_json,
179 _ => return Err(IntoPartialError::UnknownExtension),
180 };
181
182 parse(&fs::read_to_string(path)?)
183 }
184
185 #[must_use]
188 #[instrument(level = "debug", ret)]
189 pub fn from_args() -> Self {
190 let mut args = Arguments::from_env();
191 Self {
192 log_level: args.opt_value_from_str("--log-level").unwrap_or(None),
193 token: args.opt_value_from_str("--token").unwrap_or(None),
194 listeners: deserialize_arg(&mut args, "--listeners"),
195 statistics: deserialize_arg(&mut args, "--statistics"),
196 default_certificate: deserialize_arg(&mut args, "--default-certificate"),
197 certificates: deserialize_arg(&mut args, "--certificates"),
198 hsts: args.opt_value_from_str("--hsts").unwrap_or(None),
199 hsts_max_age: args.opt_value_from_str("--hsts-max-age").unwrap_or(None),
200 https_redirect: args.opt_value_from_str("--https-redirect").unwrap_or(None),
201 send_alt_svc: args.opt_value_from_str("--send-alt-svc").unwrap_or(None),
202 send_server: args.opt_value_from_str("--send-server").unwrap_or(None),
203 send_csp: args.opt_value_from_str("--send-csp").unwrap_or(None),
204 store: args.opt_value_from_str("--store").unwrap_or(None),
205 store_config: deserialize_arg(&mut args, "--store-config"),
206 }
207 }
208
209 #[must_use]
212 #[instrument(level = "debug", ret)]
213 pub fn from_env_vars() -> Self {
214 Self {
215 log_level: parse_env_var("LINKS_LOG_LEVEL"),
216 token: parse_env_var("LINKS_TOKEN"),
217 listeners: deserialize_env_var("LINKS_LISTENERS"),
218 statistics: deserialize_env_var("LINKS_STATISTICS"),
219 default_certificate: deserialize_env_var("LINKS_DEFAULT_CERTIFICATE"),
220 certificates: deserialize_env_var("LINKS_CERTIFICATES"),
221 hsts: parse_env_var("LINKS_HSTS"),
222 hsts_max_age: parse_env_var("LINKS_HSTS_MAX_AGE"),
223 https_redirect: parse_env_var("LINKS_HTTPS_REDIRECT"),
224 send_alt_svc: parse_env_var("LINKS_SEND_ALT_SVC"),
225 send_server: parse_env_var("LINKS_SEND_SERVER"),
226 send_csp: parse_env_var("LINKS_SEND_CSP"),
227 store: parse_env_var("LINKS_STORE"),
228 store_config: deserialize_env_var("LINKS_STORE_CONFIG"),
229 }
230 }
231
232 #[must_use]
234 pub fn hsts(&self) -> Option<Hsts> {
235 match self.hsts? {
236 PartialHsts::Disable => Some(Hsts::Disable),
237 PartialHsts::Enable => Some(Hsts::Enable(self.hsts_max_age?)),
238 PartialHsts::IncludeSubDomains => Some(Hsts::IncludeSubDomains(self.hsts_max_age?)),
239 PartialHsts::Preload => Some(Hsts::Preload(self.hsts_max_age?)),
240 }
241 }
242}
243
244#[derive(
263 Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, EnumString, EnumDisplay,
264)]
265#[serde(rename_all = "snake_case")]
266#[strum(serialize_all = "snake_case", ascii_case_insensitive)]
267pub enum PartialHsts {
268 #[strum(serialize = "disable", serialize = "off")]
270 Disable,
271 #[default]
274 #[strum(serialize = "enable", serialize = "on")]
275 Enable,
276 #[strum(
287 serialize = "includeSubDomains",
288 serialize = "include",
289 to_string = "include"
290 )]
291 IncludeSubDomains,
292 Preload,
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid single-entry certificate list, shared by the argument and environment tests.
    const CERTIFICATES_JSON: &str = r#"[{"source": "files", "domains": ["example.com"], "cert": "./cert.pem", "key": "./key.pem"}]"#;

    #[test]
    fn test_deserialize_arg() {
        let mut args = Arguments::from_vec(vec![
            "--certificates".into(),
            CERTIFICATES_JSON.into(),
            "--listeners".into(),
            "yes, please".into(),
        ]);

        assert!(deserialize_arg::<Vec<ListenAddress>>(&mut args, "--listeners").is_none());
        assert!(deserialize_arg::<Vec<CertificateSource>>(&mut args, "--certificates").is_some());
        // A flag that was never passed is `None` as well.
        assert!(deserialize_arg::<HashMap<String, String>>(&mut args, "--store-config").is_none());
    }

    #[test]
    fn test_parse_env_var() {
        env::set_var("LINKS_LOG_LEVEL", "no logging, thanks");
        env::set_var("LINKS_HSTS", "include");

        let log_level = parse_env_var::<LogLevel>("LINKS_LOG_LEVEL");
        let hsts = parse_env_var::<PartialHsts>("LINKS_HSTS");

        // Remove the variables before asserting so a failure cannot leak them into other
        // tests running in the same process.
        env::remove_var("LINKS_LOG_LEVEL");
        env::remove_var("LINKS_HSTS");

        assert!(log_level.is_none());
        assert_eq!(hsts, Some(PartialHsts::IncludeSubDomains));
    }

    #[test]
    fn test_deserialize_env_var() {
        env::set_var("LINKS_CERTIFICATES", CERTIFICATES_JSON);
        env::set_var("LINKS_LISTENERS", "yes, please");

        let listeners = deserialize_env_var::<Vec<ListenAddress>>("LINKS_LISTENERS");
        let certificates = deserialize_env_var::<Vec<CertificateSource>>("LINKS_CERTIFICATES");

        env::remove_var("LINKS_CERTIFICATES");
        env::remove_var("LINKS_LISTENERS");

        assert!(listeners.is_none());
        assert!(certificates.is_some_and(|certificates| certificates.len() == 1));
    }

    #[test]
    fn test_partial_hsts_from_str() {
        assert_eq!("off".parse::<PartialHsts>().ok(), Some(PartialHsts::Disable));
        assert_eq!("ON".parse::<PartialHsts>().ok(), Some(PartialHsts::Enable));
        assert_eq!(
            "includeSubDomains".parse::<PartialHsts>().ok(),
            Some(PartialHsts::IncludeSubDomains)
        );
        assert_eq!(
            "include".parse::<PartialHsts>().ok(),
            Some(PartialHsts::IncludeSubDomains)
        );
        assert_eq!("preload".parse::<PartialHsts>().ok(), Some(PartialHsts::Preload));
        assert!("sometimes".parse::<PartialHsts>().is_err());
        assert_eq!(PartialHsts::IncludeSubDomains.to_string(), "include");
    }

    #[test]
    fn test_hsts() {
        let with_hsts = |hsts: Option<PartialHsts>, hsts_max_age: Option<u32>| Partial {
            hsts,
            hsts_max_age,
            ..Partial::default()
        };

        assert!(Partial::default().hsts().is_none());
        assert!(matches!(
            with_hsts(Some(PartialHsts::Disable), None).hsts(),
            Some(Hsts::Disable)
        ));
        assert!(matches!(
            with_hsts(Some(PartialHsts::Preload), Some(60)).hsts(),
            Some(Hsts::Preload(60))
        ));
        // Every enabling mode needs a max-age.
        assert!(with_hsts(Some(PartialHsts::Enable), None).hsts().is_none());
    }

    #[test]
    fn test_from_file_unknown_extension() {
        assert!(matches!(
            Partial::from_file(Path::new("links.ini")),
            Err(IntoPartialError::UnknownExtension)
        ));
        assert!(matches!(
            Partial::from_file(Path::new("links")),
            Err(IntoPartialError::UnknownExtension)
        ));
    }

    #[test]
    fn test_empty_documents() {
        assert_eq!(Partial::from_toml("").ok(), Some(Partial::default()));
        assert_eq!(Partial::from_json("{}").ok(), Some(Partial::default()));
    }
}