1mod global;
39mod partial;
40
41use std::{
42 fmt::{Debug, Display, Formatter, Result as FmtResult},
43 fs,
44 io::Error as IoError,
45 net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr},
46 num::ParseIntError,
47 path::PathBuf,
48 str::FromStr,
49 sync::Mutex,
50 time::Duration,
51};
52
53use crossbeam_channel::{select, unbounded, Receiver, Sender};
54use links_domainmap::Domain;
55use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
56use serde::{Deserialize, Serialize};
57use strum::{Display as EnumDisplay, EnumString, ParseError};
58use tokio_rustls::rustls::{
59 crypto::ring::sign,
60 pki_types::{CertificateDer, PrivateKeyDer},
61 sign::CertifiedKey,
62 Error as RustlsError,
63};
64use tracing::{debug, error, Level};
65
66pub use self::{
67 global::{Config, Hsts, Redirector},
68 partial::{IntoPartialError, Partial, PartialHsts},
69};
70use crate::{server::Protocol, util::Unpoison};
71
/// A change to the set of certificate sources tracked by a [`CertificateWatcher`].
///
/// Updates are sent through the watcher's configuration channel (see
/// [`CertificateWatcher::get_config_sender`] and
/// [`CertificateWatcher::send_config_update`]) and are applied by
/// [`CertificateWatcher::watch`] while it is running.
#[derive(Debug)]
pub enum CertConfigUpdate {
    /// The default certificate source was replaced by the contained one; the
    /// watcher moves its file watches from the previous default to this one.
    DefaultUpdated(DefaultCertificateSource),
    /// A new certificate source was added and its files should be watched.
    SourceAdded(CertificateSource),
    /// A certificate source was removed and should no longer be tracked.
    SourceRemoved(CertificateSource),
}
82
83#[derive(Debug)]
85pub struct CertificateWatcher {
86 sources: Vec<CertificateSource>,
88 default_source: DefaultCertificateSource,
90 files_watcher: RecommendedWatcher,
92 files_rx: Receiver<Event>,
94 config_rx: Receiver<CertConfigUpdate>,
96 config_tx: Sender<CertConfigUpdate>,
99}
100
101impl CertificateWatcher {
102 pub fn new() -> anyhow::Result<Self> {
108 let (files_tx, files_rx) = unbounded();
109 let (config_tx, config_rx) = unbounded();
110 let files_watcher = notify::recommended_watcher(move |res| match res {
111 Ok(ev) => {
112 let _ = files_tx.send(ev).inspect_err(|err| {
113 error!("the certificate file watching channel closed unexpectedly: {err}");
114 });
115 }
116 Err(err) => error!(%err, "certificate file watching error"),
117 })?;
118
119 Ok(Self {
120 sources: Vec::new(),
121 default_source: DefaultCertificateSource::None,
122 files_watcher,
123 files_rx,
124 config_rx,
125 config_tx,
126 })
127 }
128
129 #[allow(clippy::missing_panics_doc)] pub fn watch(
140 &mut self,
141 debounce_time: Duration,
142 ) -> (Vec<CertificateSource>, DefaultCertificateSource) {
143 let debounced = Mutex::new(Option::<(Vec<_>, _)>::None);
144
145 let handle_files = |this: &mut Self, event: Event| {
146 if let EventKind::Access(_) = event.kind {
147 debug!(?event, "Ignoring file event from watcher");
148 return;
149 }
150
151 debug!(?event, "Received file event from watcher");
152 let file_sources = this
153 .sources
154 .iter()
155 .filter(|s| match s.source {
156 CertificateSourceType::Files { .. } => true,
157 })
158 .cloned()
159 .collect();
160
161 let mut db = debounced.lock().unpoison();
162 if let Some(ref mut debounced) = *db {
163 for source in file_sources {
164 if !debounced.0.contains(&source) {
165 debounced.0.push(source);
166 }
167 }
168
169 if matches!(this.default_source, DefaultCertificateSource::Some {
170 source: CertificateSourceType::Files { .. },
171 ..
172 }) {
173 debounced.1 = this.default_source.clone();
174 }
175 } else {
176 *db = Some((
177 file_sources,
178 if matches!(this.default_source, DefaultCertificateSource::Some {
179 source: CertificateSourceType::Files { .. },
180 ..
181 }) {
182 this.default_source.clone()
183 } else {
184 DefaultCertificateSource::None
185 },
186 ));
187 }
188 };
189
190 let handle_config = |this: &mut Self, msg| match msg {
191 CertConfigUpdate::DefaultUpdated(default) => {
192 if let Some((Err(err), source)) = this
193 .default_source
194 .clone()
195 .into_cs()
196 .map(|s| (s.unwatch(this), s))
197 {
198 error!(%err, ?source, "default certificate source could not be unwatched");
199 }
200
201 this.default_source = default;
202
203 if let Some((Err(err), source)) = this
204 .default_source
205 .clone()
206 .into_cs()
207 .map(|s| (s.watch(this), s))
208 {
209 error!(%err, ?source, "default certificate source could not be watched");
210 }
211 }
212 CertConfigUpdate::SourceRemoved(source) => {
213 this.sources.retain(|s| s != &source);
214 if let Err(err) = source.unwatch(this) {
215 error!(%err, ?source, "certificate source could not be unwatched");
216 }
217 }
218 CertConfigUpdate::SourceAdded(source) => {
219 if let Err(err) = source.watch(this) {
220 error!(%err, ?source, "certificate source could not be watched");
221 }
222 this.sources.push(source);
223 }
224 };
225
226 loop {
227 select! {
228 recv(self.files_rx) -> msg => handle_files(self, msg.expect("certificate watcher channel closed")),
229 recv(self.config_rx) -> msg => handle_config(self, msg.expect("certificate watcher channel closed")),
230 default(debounce_time) => if debounced.lock().unpoison().is_some() {
231 break debounced.into_inner().unpoison().expect("the option was just checked to be some");
232 }
233 }
234 }
235 }
236
237 #[must_use]
239 pub fn get_config_sender(&self) -> Sender<CertConfigUpdate> {
240 self.config_tx.clone()
241 }
242
243 pub fn send_config_update(&self, update: CertConfigUpdate) {
245 if self.config_tx.send(update).is_err() {
246 unreachable!("the receiver is owned by this watcher, so this channel can not be closed")
247 }
248 }
249}
250
251#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
253#[serde(rename_all = "kebab-case", untagged)]
254pub enum DefaultCertificateSource {
255 None,
257 Some {
259 #[serde(default)]
262 domains: Vec<Domain>,
263 #[serde(flatten)]
265 source: CertificateSourceType,
266 },
267}
268
269impl DefaultCertificateSource {
270 #[must_use]
274 pub fn into_cs(self) -> Option<CertificateSource> {
275 match self {
276 Self::None => None,
277 Self::Some { domains, source } => Some(CertificateSource { domains, source }),
278 }
279 }
280}
281
282#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
284#[serde(rename_all = "kebab-case")]
285pub struct CertificateSource {
286 pub domains: Vec<Domain>,
288 #[serde(flatten)]
290 pub source: CertificateSourceType,
291}
292
293impl CertificateSource {
294 pub fn get_certkey(&self) -> Result<CertifiedKey, CertificateAcquisitionError> {
304 match &self.source {
305 CertificateSourceType::Files { cert, key } => {
306 let certs = fs::read(cert)?;
307 let key = fs::read(key)?;
308
309 let certs: Result<Vec<CertificateDer>, _> = rustls_pemfile::certs(&mut &certs[..])
310 .map(|res| res.map(|der| CertificateDer::from(der.to_vec())))
311 .collect();
312 let certs = certs?;
313 let key = rustls_pemfile::pkcs8_private_keys(&mut &key[..])
314 .map(|res| {
315 res.map(|der| {
316 PrivateKeyDer::Pkcs8(der.secret_pkcs8_der().to_owned().into())
317 })
318 })
319 .next()
320 .ok_or(CertificateAcquisitionError::MissingKey)??;
321
322 let cert_key = CertifiedKey::new(
323 certs,
324 sign::any_supported_type(&key)
325 .map_err(CertificateAcquisitionError::InvalidKey)?,
326 );
327
328 let () = cert_key
329 .keys_match()
330 .map_err(CertificateAcquisitionError::KeyMismatch)?;
331
332 Ok(cert_key)
333 }
334 }
335 }
336
337 pub fn watch(&self, watcher: &mut CertificateWatcher) -> anyhow::Result<()> {
343 match &self.source {
344 CertificateSourceType::Files { cert, key } => {
345 watcher
346 .files_watcher
347 .watch(cert, RecursiveMode::NonRecursive)?;
348 watcher
349 .files_watcher
350 .watch(key, RecursiveMode::NonRecursive)?;
351 }
352 }
353
354 Ok(())
355 }
356
357 pub fn unwatch(&self, watcher: &mut CertificateWatcher) -> anyhow::Result<()> {
363 match &self.source {
364 CertificateSourceType::Files { cert, key } => {
365 watcher.files_watcher.unwatch(cert)?;
366 watcher.files_watcher.unwatch(key)?;
367 }
368 }
369
370 Ok(())
371 }
372}
373
/// Errors that can occur while loading a certificate and key from a
/// [`CertificateSource`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CertificateAcquisitionError {
    /// Reading a certificate or key file failed, or its PEM contents could not
    /// be parsed.
    #[error("A filesystem error occurred")]
    FileIo(#[from] IoError),
    /// The certificate file did not contain any certificate.
    #[error("No certificate found")]
    MissingCert,
    /// The key file did not contain a usable private key.
    #[error("No key found")]
    MissingKey,
    /// The private key could not be turned into a signing key.
    #[error("The private key is invalid or unsupported")]
    InvalidKey(#[source] RustlsError),
    /// The private key does not belong to the certificate.
    #[error("The private key does not match the certificate")]
    KeyMismatch(#[source] RustlsError),
}
395
/// Where a certificate and its private key are loaded from.
///
/// Serialized with a `source` tag holding the kebab-case variant name.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "source", rename_all = "kebab-case")]
pub enum CertificateSourceType {
    /// PEM files on the local file system (`source = "files"`).
    Files {
        /// Path to the PEM-encoded certificate chain.
        cert: PathBuf,
        /// Path to the PEM-encoded private key; only PKCS#8 keys are read.
        key: PathBuf,
    },
}
420
/// Errors returned when parsing a [`ListenAddress`] from a string.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum IntoListenAddressError {
    /// The string does not have the `protocol:address:port` shape; holds the
    /// whole input.
    #[error("\"{0}\" is not a valid listen address")]
    General(String),
    /// The protocol part could not be parsed as a [`Protocol`].
    #[error("{0}")]
    Protocol(#[from] ParseError),
    /// The address part is not a valid IPv4 or bracketed IPv6 address.
    #[error("invalid IP address for listener: {0}")]
    Address(#[from] AddrParseError),
    /// The port part is not a valid `u16`.
    #[error("invalid port number for listener: {0}")]
    Port(#[from] ParseIntError),
}
438
439#[derive(Copy, Clone, Eq, Serialize, Deserialize)]
467#[serde(try_from = "&str", into = "String")]
468pub struct ListenAddress {
469 pub protocol: Protocol,
472 pub address: Option<IpAddr>,
476 pub port: Option<u16>,
479}
480
481impl Debug for ListenAddress {
482 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
483 Display::fmt(self, fmt)
484 }
485}
486
487impl Display for ListenAddress {
488 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
489 fmt.write_fmt(format_args!(
490 "{}:{}:{}",
491 self.protocol,
492 self.address.map_or(String::new(), |a| match a {
493 IpAddr::V4(a) => a.to_string(),
494 IpAddr::V6(a) => format!("[{a}]"),
495 }),
496 self.port.map_or(String::new(), |n| n.to_string())
497 ))
498 }
499}
500
/// Two listen addresses are equal when their protocols and addresses match and
/// their ports match after substituting the protocol's default port for a
/// missing (`None`) port.
impl PartialEq for ListenAddress {
    #[expect(
        clippy::suspicious_operation_groupings,
        reason = "This is correct, a `None` address is distinct from all `Some(_)` addresses, but \
                  a `None` port is just the default port for that protocol"
    )]
    fn eq(&self, other: &Self) -> bool {
        self.protocol == other.protocol
            && self.address == other.address
            // Each side falls back to its own protocol's default port.
            && self.port.unwrap_or_else(|| self.protocol.default_port())
                == other.port.unwrap_or_else(|| other.protocol.default_port())
    }
}
514
515impl FromStr for ListenAddress {
516 type Err = IntoListenAddressError;
517
518 fn from_str(s: &str) -> Result<Self, Self::Err> {
519 let (protocol, rest) = s
520 .split_once(':')
521 .ok_or_else(|| IntoListenAddressError::General(s.to_string()))?;
522 let (address, port) = rest
523 .rsplit_once(':')
524 .ok_or_else(|| IntoListenAddressError::General(s.to_string()))?;
525
526 let address = if address.starts_with('[') && address.ends_with(']') {
527 Some(Ipv6Addr::from_str(address.trim_start_matches('[').trim_end_matches(']'))?.into())
528 } else if address.is_empty() {
529 None
530 } else {
531 Some(Ipv4Addr::from_str(address)?.into())
532 };
533
534 Ok(Self {
535 protocol: protocol.parse()?,
536 address,
537 port: match port {
538 "" => None,
539 s => Some(s.parse()?),
540 },
541 })
542 }
543}
544
545impl TryFrom<&str> for ListenAddress {
546 type Error = IntoListenAddressError;
547
548 fn try_from(s: &str) -> Result<Self, Self::Error> {
549 s.parse()
550 }
551}
552
553impl From<ListenAddress> for String {
554 fn from(address: ListenAddress) -> Self {
555 address.to_string()
556 }
557}
558
/// Log verbosity as written in configuration, in `snake_case` (e.g. `"info"`).
///
/// Converts to and from [`Level`]; see the `From` impls for the mapping.
#[derive(
    Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, EnumString, EnumDisplay,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LogLevel {
    /// Maps to [`Level::TRACE`].
    Trace,
    /// Maps to [`Level::DEBUG`].
    Debug,
    /// Maps to [`Level::INFO`], the same as `Info`.
    /// NOTE(review): presumably enables extra informational output elsewhere — confirm.
    Verbose,
    /// The default level; maps to [`Level::INFO`].
    #[default]
    Info,
    /// Maps to [`Level::WARN`].
    Warn,
    /// Maps to [`Level::ERROR`].
    Error,
}
587
588impl From<LogLevel> for Level {
589 fn from(log_level: LogLevel) -> Self {
590 match log_level {
591 LogLevel::Trace => Level::TRACE,
592 LogLevel::Debug => Level::DEBUG,
593 LogLevel::Verbose | LogLevel::Info => Level::INFO,
594 LogLevel::Warn => Level::WARN,
595 LogLevel::Error => Level::ERROR,
596 }
597 }
598}
599
600impl From<Level> for LogLevel {
601 fn from(log_level: Level) -> Self {
602 match log_level {
603 Level::TRACE => LogLevel::Trace,
604 Level::DEBUG => LogLevel::Debug,
605 Level::INFO => LogLevel::Info,
606 Level::WARN => LogLevel::Warn,
607 Level::ERROR => LogLevel::Error,
608 }
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`ListenAddress`] from its parts, to keep the tables compact.
    fn addr(protocol: Protocol, address: Option<IpAddr>, port: Option<u16>) -> ListenAddress {
        ListenAddress {
            protocol,
            address,
            port,
        }
    }

    #[test]
    fn listen_address_parse() {
        let cases = [
            (
                "http:0.0.0.0:80",
                addr(Protocol::Http, Some([0, 0, 0, 0].into()), Some(80)),
            ),
            (
                "http:[::]:80",
                addr(
                    Protocol::Http,
                    Some([0, 0, 0, 0, 0, 0, 0, 0].into()),
                    Some(80),
                ),
            ),
            ("https::", addr(Protocol::Https, None, None)),
            (
                "grpc:127.0.0.1:",
                addr(Protocol::Grpc, Some([127, 0, 0, 1].into()), None),
            ),
            (
                "grpc:[::1]:",
                addr(
                    Protocol::Grpc,
                    Some([0, 0, 0, 0, 0, 0, 0, 1].into()),
                    None,
                ),
            ),
            ("grpcs::530", addr(Protocol::Grpcs, None, Some(530))),
            // Protocol names are matched case-insensitively.
            ("GrPcS::530", addr(Protocol::Grpcs, None, Some(530))),
            // Leading zeros in the port are accepted.
            (
                "GrPcS:127.0.5.4:00530",
                addr(Protocol::Grpcs, Some([127, 0, 5, 4].into()), Some(530)),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(input.parse(), Ok(expected), "parsing {input:?}");
        }
    }

    #[test]
    fn listen_address_parse_invalid() {
        use IntoListenAddressError as E;

        let parse = |text: &str| text.parse::<ListenAddress>();

        // Missing the port separator: depending on where the split lands this is
        // reported as a shape, address, or port error.
        assert!(matches!(
            parse("http:[::1]"),
            Err(E::General(_) | E::Address(_) | E::Port(_))
        ));
        assert!(matches!(parse("http:localhost:80"), Err(E::Address(_))));
        assert!(matches!(parse("http:0.0.0.0"), Err(E::General(_))));
        assert!(matches!(parse(":0.0.0.0:443"), Err(E::Protocol(_))));
        assert!(matches!(
            parse("https:123.456.789.0:443"),
            Err(E::Address(_))
        ));

        for bad_port in [
            "https:[::1]:123456789",
            "https:[::1]:4a5",
            "https:[::1]:0x1bb",
        ] {
            assert!(matches!(parse(bad_port), Err(E::Port(_))), "{bad_port:?}");
        }
    }

    #[test]
    fn listen_address_to_from_string() {
        let round_trips = [
            ("http::", "http::"),
            ("https:[::]:", "https:[::]:"),
            ("grpc::456", "grpc::456"),
            ("grpcs:0.0.0.0:789", "grpcs:0.0.0.0:789"),
            ("grpcs:0.0.0.0:0000789", "grpcs:0.0.0.0:789"),
            (
                "grpcs:[0000:0000:0000:0000:0000:0000:0000:0000]:789",
                "grpcs:[::]:789",
            ),
        ];

        for (input, rendered) in round_trips {
            let parsed = input.parse::<ListenAddress>().unwrap();
            assert_eq!(parsed.to_string(), rendered, "rendering {input:?}");
        }
    }

    #[test]
    fn listen_address_eq() {
        let any_http = addr(Protocol::Http, None, None);

        assert_eq!(any_http, addr(Protocol::Http, None, None));
        // A missing port equals the protocol's explicit default port.
        assert_eq!(
            any_http,
            addr(Protocol::Http, None, Some(Protocol::HTTP_DEFAULT_PORT))
        );
        assert_ne!(any_http, addr(Protocol::Https, None, None));
        // A missing address is distinct from any concrete address.
        assert_ne!(
            addr(Protocol::Http, Some("::".parse().unwrap()), None),
            any_http
        );
        assert_ne!(
            addr(Protocol::Https, Some("::".parse().unwrap()), None),
            addr(Protocol::Http, None, Some(1000))
        );
    }

    #[test]
    fn log_level() {
        for (text, expected) in [
            ("verbose", LogLevel::Verbose),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
        ] {
            assert_eq!(text.parse(), Ok(expected), "parsing {text:?}");
        }

        for (text, expected) in [
            ("info", Level::INFO),
            ("verbose", Level::INFO),
            ("error", Level::ERROR),
        ] {
            assert_eq!(
                text.parse::<LogLevel>().map(Into::into),
                Ok(expected),
                "converting {text:?}"
            );
        }
    }
}