1use std::{
30 fmt::{Debug, Formatter, Result as FmtResult},
31 net::{IpAddr, Ipv6Addr, SocketAddr},
32 os::raw::c_int,
33 sync::Arc,
34 thread,
35};
36
37use hyper::{rt, server::conn::http2, service::service_fn, Request};
38use hyper_util::{
39 rt::{TokioExecutor, TokioIo},
40 server::conn::auto::Builder,
41 service::TowerToHyperService,
42};
43use links_id::Id;
44use links_normalized::{Link, Normalized};
45use parking_lot::Mutex;
46use socket2::{Domain, Protocol as SocketProtocol, Socket, Type};
47use strum::{Display as EnumDisplay, EnumString};
48use tokio::{
49 io::{AsyncRead, AsyncWrite, Error as IoError},
50 net::{TcpListener, TcpStream},
51 spawn,
52 task::JoinHandle,
53};
54use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
55use tonic::{
56 codegen::{CompressionEncoding, InterceptedService},
57 service::Routes,
58 transport::Server as RpcServer,
59};
60use tower::util::ServiceExt;
61use tracing::{debug, error, trace, warn};
62
63use crate::{
64 api::{self, Api, LinksServer},
65 certs::CertificateResolver,
66 config::{Config, ListenAddress},
67 redirector::{https_redirector, redirector},
68 stats::ExtraStatisticInfo,
69 store::{Current, Store},
70};
71
/// Backlog passed to `listen(2)` for every listener socket, i.e. the requested
/// length of the kernel queue of connections not yet accepted (2^10 = 1024).
const LISTENER_TCP_BACKLOG_SIZE: c_int = 1 << 10;
78
79pub async fn http_handler(
83 stream: impl rt::Read + rt::Write + Send + Unpin + 'static,
84 store: Store,
85 config: &'static Config,
86 stat_info: ExtraStatisticInfo,
87) {
88 let redirector_service = service_fn(move |req: Request<_>| {
89 redirector(req, store.clone(), config.redirector(), stat_info.clone())
90 });
91
92 if let Err(err) = Builder::new(TokioExecutor::new())
93 .serve_connection(stream, redirector_service)
94 .await
95 {
96 error!(?err, "Error while handling HTTP connection");
97 }
98}
99
100pub async fn http_to_https_handler(
107 stream: impl rt::Read + rt::Write + Send + Unpin + 'static,
108 config: &'static Config,
109) {
110 let redirector_service =
111 service_fn(move |req: Request<_>| https_redirector(req, config.redirector()));
112
113 if let Err(err) = Builder::new(TokioExecutor::new())
114 .serve_connection(stream, redirector_service)
115 .await
116 {
117 error!(?err, "Error while handling HTTP connection");
118 }
119}
120
121pub async fn rpc_handler(
123 stream: impl rt::Read + rt::Write + Send + Unpin + 'static,
124 service: Routes,
125) {
126 if let Err(rpc_err) = http2::Builder::new(TokioExecutor::new())
127 .serve_connection(
128 stream,
129 TowerToHyperService::new(
130 service.map_request(|req: Request<_>| req.map(tonic::body::boxed)),
131 ),
132 )
133 .await
134 {
135 error!(?rpc_err, "Error while handling gRPC connection");
136 }
137}
138
139#[async_trait::async_trait]
146pub trait Acceptor<S: AsyncRead + AsyncWrite + Send + Unpin + 'static>:
147 Send + Sync + 'static
148{
149 async fn accept(&self, stream: S, local_addr: SocketAddr, remote_addr: SocketAddr);
155
156 fn protocol(&self) -> Protocol;
158}
159
160#[derive(Debug)]
164pub struct PlainHttpAcceptor {
165 config: &'static Config,
166 current_store: &'static Current,
167}
168
169impl PlainHttpAcceptor {
170 pub fn new(config: &'static Config, current_store: &'static Current) -> &'static Self {
177 Box::leak(Box::new(Self {
178 config,
179 current_store,
180 }))
181 }
182}
183
184#[async_trait::async_trait]
185impl Acceptor<TcpStream> for PlainHttpAcceptor {
186 async fn accept(&self, stream: TcpStream, local_addr: SocketAddr, remote_addr: SocketAddr) {
187 let config = self.config;
188 let current_store = self.current_store;
189
190 spawn(async move {
191 trace!("New plain connection from {remote_addr} on {local_addr}");
192
193 if config.https_redirect() {
194 http_to_https_handler(TokioIo::new(stream), config).await;
195 } else {
196 http_handler(
197 TokioIo::new(stream),
198 current_store.get(),
199 config,
200 ExtraStatisticInfo::default(),
201 )
202 .await;
203 }
204 });
205 }
206
207 fn protocol(&self) -> Protocol {
208 Protocol::Http
209 }
210}
211
212pub struct TlsHttpAcceptor {
215 config: &'static Config,
216 current_store: &'static Current,
217 tls_acceptor: TlsAcceptor,
218}
219
220impl TlsHttpAcceptor {
221 pub fn new(
229 config: &'static Config,
230 current_store: &'static Current,
231 cert_resolver: Arc<CertificateResolver>,
232 ) -> &'static Self {
233 let mut server_config = ServerConfig::builder()
234 .with_no_client_auth()
235 .with_cert_resolver(cert_resolver);
236 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
237
238 let server_config = Arc::new(server_config);
239 let tls_acceptor = TlsAcceptor::from(server_config);
240
241 Box::leak(Box::new(Self {
242 config,
243 current_store,
244 tls_acceptor,
245 }))
246 }
247}
248
249#[async_trait::async_trait]
250impl Acceptor<TcpStream> for TlsHttpAcceptor {
251 async fn accept(&self, stream: TcpStream, local_addr: SocketAddr, remote_addr: SocketAddr) {
252 let config = self.config;
253 let current_store = self.current_store;
254 let tls_acceptor = self.tls_acceptor.clone();
255
256 spawn(async move {
257 trace!("New TLS connection from {remote_addr} on {local_addr}");
258
259 match tls_acceptor.accept(stream).await {
260 Ok(stream) => {
261 let tls_conn = stream.get_ref().1;
262 let extra_info = ExtraStatisticInfo {
263 tls_sni: tls_conn.server_name().map(Arc::from),
264 tls_version: tls_conn.protocol_version(),
265 tls_cipher_suite: tls_conn.negotiated_cipher_suite(),
266 };
267
268 http_handler(
269 TokioIo::new(stream),
270 current_store.get(),
271 config,
272 extra_info,
273 )
274 .await;
275 }
276 Err(err) => warn!("Error accepting incoming TLS connection: {err:?}"),
277 }
278 });
279 }
280
281 fn protocol(&self) -> Protocol {
282 Protocol::Https
283 }
284}
285
286impl Debug for TlsHttpAcceptor {
287 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
288 #[derive(Debug)]
289 struct TlsAcceptor {}
290
291 fmt.debug_struct("TlsHttpAcceptor")
292 .field("config", self.config)
293 .field("current_store", self.current_store)
294 .field("tls_acceptor", &TlsAcceptor {})
295 .finish()
296 }
297}
298
299#[derive(Debug)]
302pub struct PlainRpcAcceptor {
303 service: Mutex<Routes>,
304}
305
306impl PlainRpcAcceptor {
307 pub fn new(config: &'static Config, current_store: &'static Current) -> &'static Self {
314 let service = RpcServer::builder()
315 .add_service(InterceptedService::new(
316 LinksServer::new(Api::new(current_store))
317 .send_compressed(CompressionEncoding::Gzip)
318 .accept_compressed(CompressionEncoding::Gzip),
319 api::get_auth_checker(config),
320 ))
321 .into_service()
322 .prepare();
323
324 Box::leak(Box::new(Self {
325 service: Mutex::new(service),
326 }))
327 }
328}
329
330#[async_trait::async_trait]
331impl Acceptor<TcpStream> for PlainRpcAcceptor {
332 async fn accept(&self, stream: TcpStream, local_addr: SocketAddr, remote_addr: SocketAddr) {
333 let service = self.service.lock().clone();
334
335 spawn(async move {
336 trace!("New plain connection from {remote_addr} on {local_addr}");
337
338 rpc_handler(TokioIo::new(stream), service).await;
339 });
340 }
341
342 fn protocol(&self) -> Protocol {
343 Protocol::Grpc
344 }
345}
346
347pub struct TlsRpcAcceptor {
350 service: Arc<Mutex<Routes>>,
351 tls_acceptor: TlsAcceptor,
352}
353
354impl TlsRpcAcceptor {
355 pub fn new(
363 config: &'static Config,
364 current_store: &'static Current,
365 cert_resolver: Arc<CertificateResolver>,
366 ) -> &'static Self {
367 let mut server_config = ServerConfig::builder()
368 .with_no_client_auth()
369 .with_cert_resolver(cert_resolver);
370 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
371
372 let server_config = Arc::new(server_config);
373 let tls_acceptor = TlsAcceptor::from(server_config);
374
375 let service = RpcServer::builder()
376 .add_service(InterceptedService::new(
377 LinksServer::new(Api::new(current_store))
378 .send_compressed(CompressionEncoding::Gzip)
379 .accept_compressed(CompressionEncoding::Gzip),
380 api::get_auth_checker(config),
381 ))
382 .into_service()
383 .prepare();
384
385 Box::leak(Box::new(Self {
386 service: Arc::new(Mutex::new(service)),
387 tls_acceptor,
388 }))
389 }
390}
391
392#[async_trait::async_trait]
393impl Acceptor<TcpStream> for TlsRpcAcceptor {
394 async fn accept(&self, stream: TcpStream, local_addr: SocketAddr, remote_addr: SocketAddr) {
395 let tls_acceptor = self.tls_acceptor.clone();
396 let service = self.service.lock().clone();
397
398 spawn(async move {
399 trace!("New TLS connection from {remote_addr} on {local_addr}");
400
401 match tls_acceptor.accept(stream).await {
402 Ok(stream) => rpc_handler(TokioIo::new(stream), service).await,
403 Err(err) => warn!("Error accepting incoming TLS connection: {err:?}"),
404 }
405 });
406 }
407
408 fn protocol(&self) -> Protocol {
409 Protocol::Grpcs
410 }
411}
412
413impl Debug for TlsRpcAcceptor {
414 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
415 #[derive(Debug)]
416 struct TlsAcceptor {}
417
418 fmt.debug_struct("TlsRpcAcceptor")
419 .field("service", &self.service)
420 .field("tls_acceptor", &TlsAcceptor {})
421 .finish()
422 }
423}
424
425#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, EnumDisplay)]
427#[strum(serialize_all = "snake_case", ascii_case_insensitive)]
428pub enum Protocol {
429 Http,
431 Https,
433 Grpc,
435 Grpcs,
437}
438
439impl Protocol {
440 pub const GRPCS_DEFAULT_PORT: u16 = 530;
442 pub const GRPC_DEFAULT_PORT: u16 = 50051;
444 pub const HTTPS_DEFAULT_PORT: u16 = 443;
446 pub const HTTP_DEFAULT_PORT: u16 = 80;
448
449 #[must_use]
451 pub const fn default_port(self) -> u16 {
452 match self {
453 Self::Http => Self::HTTP_DEFAULT_PORT,
454 Self::Https => Self::HTTPS_DEFAULT_PORT,
455 Self::Grpc => Self::GRPC_DEFAULT_PORT,
456 Self::Grpcs => Self::GRPCS_DEFAULT_PORT,
457 }
458 }
459}
460
461#[derive(Debug)]
467pub struct Listener {
468 pub addr: Option<IpAddr>,
473 pub port: u16,
476 pub proto: Protocol,
479 handle: JoinHandle<()>,
480}
481
482impl Listener {
483 #[allow(clippy::unused_async)] pub async fn new(
517 addr: Option<IpAddr>,
518 port: Option<u16>,
519 acceptor: &'static impl Acceptor<TcpStream>,
520 ) -> Result<Self, IoError> {
521 let proto = acceptor.protocol();
522 let port = port.unwrap_or_else(|| proto.default_port());
523 let socket_addr = (addr.unwrap_or(IpAddr::V6(Ipv6Addr::UNSPECIFIED)), port).into();
524
525 let socket = Socket::new(
526 Domain::for_address(socket_addr),
527 Type::STREAM,
528 Some(SocketProtocol::TCP),
529 )?;
530
531 socket.set_reuse_address(cfg!(unix))?;
536 if socket_addr.is_ipv6() {
540 socket.set_only_v6(addr.is_some())?;
541 }
542 socket.set_nonblocking(true)?;
544 socket.set_nodelay(true)?;
546
547 socket.bind(&socket_addr.into())?;
548 socket.listen(LISTENER_TCP_BACKLOG_SIZE)?;
549 let listener = TcpListener::from_std(socket.into())?;
550
551 let handle = spawn(async move {
552 loop {
553 match listener.accept().await {
554 Ok((stream, remote_addr)) => {
555 acceptor.accept(stream, socket_addr, remote_addr).await;
556 }
557 Err(err) => {
558 warn!("Error accepting TCP connection on {socket_addr}: {err:?}");
559 }
560 }
561 }
562 });
563
564 debug!("Opened new listener on {}", ListenAddress {
565 protocol: proto,
566 address: addr,
567 port: Some(port),
568 });
569
570 Ok(Self {
571 addr,
572 port,
573 proto,
574 handle,
575 })
576 }
577
578 #[must_use]
580 pub const fn listen_address(&self) -> ListenAddress {
581 ListenAddress {
582 protocol: self.proto,
583 address: self.addr,
584 port: Some(self.port),
585 }
586 }
587}
588
589impl Drop for Listener {
590 fn drop(&mut self) {
597 trace!("Closing listener on {}", self.listen_address());
598
599 self.handle.abort();
600
601 while !self.handle.is_finished() {
602 thread::yield_now();
603 }
604
605 debug!("Closed listener on {}", self.listen_address());
606 }
607}
608
609pub async fn store_setup(config: &Config, example_redirect: bool) -> Result<Store, anyhow::Error> {
617 let store = Store::new(config.store(), &config.store_config()).await?;
618
619 if example_redirect {
620 store
621 .set_redirect(Id::try_from(Id::MAX)?, Link::new("https://example.com/")?)
622 .await?;
623 store
624 .set_vanity(Normalized::new("example"), Id::try_from(Id::MAX)?)
625 .await?;
626 }
627
628 Ok(store)
629}
630
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    /// An acceptor that discards every connection; used to exercise `Listener` alone.
    #[derive(Debug, Copy, Clone)]
    struct UnAcceptor;

    #[async_trait::async_trait]
    impl Acceptor<TcpStream> for UnAcceptor {
        async fn accept(&self, _stream: TcpStream, _local: SocketAddr, _remote: SocketAddr) {
            spawn(async {});
        }

        fn protocol(&self) -> Protocol {
            Protocol::Http
        }
    }

    /// Dropping a listener must be quick and must free its port, so the same
    /// address can be bound again right away.
    #[tokio::test(flavor = "multi_thread")]
    async fn listener_new_drop() {
        let addr = Some(IpAddr::from([127, 0, 0, 1]));
        let port = Some(8000);

        let first = Listener::new(addr, port, &UnAcceptor).await.unwrap();

        let timer = Instant::now();
        drop(first);
        let drop_time = timer.elapsed();

        let _second = Listener::new(addr, port, &UnAcceptor).await.unwrap();

        let limit = Duration::from_millis(if cfg!(debug_assertions) { 100 } else { 1 });
        assert!(dbg!(drop_time) < limit);
    }

    /// The example vanity path exists only when the example redirect is requested.
    #[tokio::test]
    async fn fn_store_setup() {
        let seeded = store_setup(&Config::new(None), true).await.unwrap();
        let unseeded = store_setup(&Config::new(None), false).await.unwrap();

        assert_eq!(
            seeded.get_vanity("example".into()).await.unwrap(),
            Some(Id::MAX.try_into().unwrap())
        );
        assert_eq!(unseeded.get_vanity("example".into()).await.unwrap(), None);
    }
}