1use std::{
27 collections::HashMap,
28 fmt::{Debug, Formatter, Result as FmtResult},
29};
30
31use anyhow::{anyhow, Result};
32use async_trait::async_trait;
33use fred::{
34 bytes_utils::Str,
35 prelude::*,
36 types::{ClusterDiscoveryPolicy, RespVersion},
37};
38use links_id::Id;
39use links_normalized::{Link, Normalized};
40use tokio::try_join;
41use tracing::instrument;
42
43use super::BackendType;
44use crate::{
45 stats::{Statistic, StatisticDescription, StatisticValue},
46 store::StoreBackend,
47};
48
49pub struct Store {
79 pool: RedisPool,
80}
81
82impl Debug for Store {
83 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
84 f.debug_struct("Store").finish_non_exhaustive()
85 }
86}
87
88#[async_trait]
89impl StoreBackend for Store {
90 fn store_type() -> BackendType
91 where
92 Self: Sized,
93 {
94 BackendType::Redis
95 }
96
97 fn get_store_type(&self) -> BackendType {
98 BackendType::Redis
99 }
100
101 #[instrument(level = "trace", ret, err)]
102 async fn new(config: &HashMap<String, String>) -> Result<Self> {
103 let server_config = if config.get("cluster").map_or(Ok(false), |s| s.parse())? {
104 ServerConfig::Clustered {
105 hosts: config
106 .get("connect")
107 .ok_or_else(|| anyhow!("missing connect option"))?
108 .split(',')
109 .map(|s| {
110 s.trim()
111 .split_once(':')
112 .map(|v| {
113 let host = Str::from(v.0);
114
115 Ok(Server {
116 host: host.clone(),
117 port: v.1.parse::<u16>()?,
118 tls_server_name: Some(host),
119 })
120 })
121 .ok_or_else(|| anyhow!("couldn't parse connect value"))?
122 })
123 .collect::<Result<_, anyhow::Error>>()?,
124 policy: ClusterDiscoveryPolicy::ConfigEndpoint,
125 }
126 } else {
127 let (host, port) = config
128 .get("connect")
129 .map(|s| {
130 s.split_once(':')
131 .map::<Result<_, anyhow::Error>, _>(|v| {
132 Ok((Str::from(v.0), v.1.parse::<u16>()?))
133 })
134 .ok_or_else(|| anyhow!("couldn't parse connect value"))?
135 })
136 .ok_or_else(|| anyhow!("missing connect option"))??;
137
138 ServerConfig::Centralized {
139 server: Server {
140 host: host.clone(),
141 port,
142 tls_server_name: Some(host),
143 },
144 }
145 };
146
147 let pool_config = RedisConfig {
148 username: config.get("username").cloned(),
149 password: config.get("password").cloned(),
150 server: server_config,
151 version: RespVersion::RESP3,
152 database: config.get("database").map(|s| s.parse()).transpose()?,
153 tracing: TracingConfig {
154 enabled: true,
155 ..Default::default()
156 },
157 tls: if config.get("tls").map_or(Ok(false), |s| s.parse())? {
158 Some(TlsConnector::default_rustls()?.into())
159 } else {
160 None
161 },
162 ..RedisConfig::default()
163 };
164
165 let pool = RedisPool::new(
166 pool_config,
167 None,
168 None,
169 Some(ReconnectPolicy::new_constant(0, 100)),
170 config
171 .get("pool_size")
172 .map(|s| s.parse())
173 .transpose()?
174 .unwrap_or(8),
175 )?;
176
177 pool.connect();
178 pool.wait_for_connect().await?;
179
180 Ok(Self { pool })
181 }
182
183 #[instrument(level = "trace", ret, err)]
184 async fn get_redirect(&self, from: Id) -> Result<Option<Link>> {
185 Ok(self.pool.get(format!("links:redirect:{from}")).await?)
186 }
187
188 #[instrument(level = "trace", ret, err)]
189 async fn set_redirect(&self, from: Id, to: Link) -> Result<Option<Link>> {
190 Ok(self
191 .pool
192 .set(
193 format!("links:redirect:{from}"),
194 to.into_string(),
195 None,
196 None,
197 true,
198 )
199 .await?)
200 }
201
202 #[instrument(level = "trace", ret, err)]
203 async fn rem_redirect(&self, from: Id) -> Result<Option<Link>> {
204 Ok(self.pool.getdel(format!("links:redirect:{from}")).await?)
205 }
206
207 #[instrument(level = "trace", ret, err)]
208 async fn get_vanity(&self, from: Normalized) -> Result<Option<Id>> {
209 Ok(self.pool.get(format!("links:vanity:{from}")).await?)
210 }
211
212 #[instrument(level = "trace", ret, err)]
213 async fn set_vanity(&self, from: Normalized, to: Id) -> Result<Option<Id>> {
214 Ok(self
215 .pool
216 .set(
217 format!("links:vanity:{from}"),
218 to.to_string(),
219 None,
220 None,
221 true,
222 )
223 .await?)
224 }
225
226 #[instrument(level = "trace", ret, err)]
227 async fn rem_vanity(&self, from: Normalized) -> Result<Option<Id>> {
228 Ok(self.pool.getdel(format!("links:vanity:{from}")).await?)
229 }
230
231 #[instrument(level = "trace", ret, err)]
232 async fn get_statistics(
233 &self,
234 description: StatisticDescription,
235 ) -> Result<Vec<(Statistic, StatisticValue)>> {
236 let mut keys = Vec::with_capacity(5);
237
238 keys.push("links:stat-all".to_string());
239
240 if let Some(link) = description.link {
241 keys.push(format!("links:stat-link:{link}"));
242 }
243
244 if let Some(stat_type) = description.stat_type {
245 keys.push(format!("links:stat-type:{stat_type}"));
246 }
247
248 if let Some(data) = description.data {
249 keys.push(format!("links:stat-data:{data}"));
250 }
251
252 if let Some(time) = description.time {
253 keys.push(format!("links:stat-time:{time}"));
254 }
255
256 let stats: Vec<Statistic> = self
257 .pool
258 .sinter::<Vec<String>, _>(keys)
259 .await?
260 .into_iter()
261 .filter_map(|s| serde_json::from_str(&s).ok())
262 .collect();
263
264 let stat_keys = stats
265 .iter()
266 .map(
267 |Statistic {
268 link,
269 stat_type,
270 time,
271 data,
272 }| format!("links:stat:{link}:{stat_type}:{time}:{data}"),
273 )
274 .collect::<Vec<String>>();
275
276 let values: Vec<Option<u64>> = if stat_keys.is_empty() {
277 Vec::new()
278 } else {
279 self.pool.mget(stat_keys).await?
280 };
281
282 let res = stats
283 .into_iter()
284 .zip(values.into_iter())
285 .filter_map(|(s, v)| Some((s, StatisticValue::new(v?)?)))
286 .collect();
287
288 Ok(res)
289 }
290
291 #[instrument(level = "trace", ret, err)]
292 async fn incr_statistic(&self, statistic: Statistic) -> Result<Option<StatisticValue>> {
293 let stat_json = serde_json::to_string(&statistic)?;
294
295 let Statistic {
296 link,
297 stat_type,
298 data,
299 time,
300 } = statistic;
301
302 let values: Vec<RedisValue> = self
303 .pool
304 .incr(format!("links:stat:{link}:{stat_type}:{time}:{data}"))
305 .await?;
306
307 Box::pin(async {
308 try_join!(
309 self.pool
310 .sadd::<(), _, _>("links:stat-all".to_string(), &stat_json),
311 self.pool
312 .sadd::<(), _, _>(format!("links:stat-link:{link}"), &stat_json),
313 self.pool
314 .sadd::<(), _, _>(format!("links:stat-type:{stat_type}"), &stat_json),
315 self.pool
316 .sadd::<(), _, _>(format!("links:stat-data:{data}"), &stat_json),
317 self.pool
318 .sadd::<(), _, _>(format!("links:stat-time:{time}"), &stat_json),
319 )
320 })
321 .await?;
322
323 Ok(values
324 .first()
325 .and_then(RedisValue::as_u64)
326 .and_then(StatisticValue::new))
327 }
328
329 #[instrument(level = "trace", ret, err)]
330 async fn rem_statistics(
331 &self,
332 description: StatisticDescription,
333 ) -> Result<Vec<(Statistic, StatisticValue)>> {
334 let mut keys = Vec::with_capacity(5);
335
336 keys.push("links:stat-all".to_string());
337
338 if let Some(link) = description.link {
339 keys.push(format!("links:stat-link:{link}"));
340 }
341
342 if let Some(stat_type) = description.stat_type {
343 keys.push(format!("links:stat-type:{stat_type}"));
344 }
345
346 if let Some(data) = description.data {
347 keys.push(format!("links:stat-data:{data}"));
348 }
349
350 if let Some(time) = description.time {
351 keys.push(format!("links:stat-time:{time}"));
352 }
353
354 let stats_json: Vec<String> = self.pool.sinter(keys.clone()).await?;
355 let stats: Vec<Statistic> = stats_json
356 .iter()
357 .filter_map(|s| serde_json::from_str(s).ok())
358 .collect();
359
360 let stat_keys = stats
361 .iter()
362 .map(
363 |Statistic {
364 link,
365 stat_type,
366 time,
367 data,
368 }| format!("links:stat:{link}:{stat_type}:{time}:{data}"),
369 )
370 .collect::<Vec<String>>();
371
372 let values: Vec<Option<u64>> = if stat_keys.is_empty() {
373 Vec::new()
374 } else {
375 let values = self.pool.mget(stat_keys.clone()).await?;
376 let () = self.pool.del(stat_keys).await?;
377 for key in keys {
378 let () = self.pool.srem(key, stats_json.clone()).await?;
379 }
380 values
381 };
382
383 let res = stats
384 .into_iter()
385 .zip(values.into_iter())
386 .filter_map(|(s, v)| Some((s, StatisticValue::new(v?)?)))
387 .collect();
388
389 Ok(res)
390 }
391}
392
#[cfg(all(test, feature = "test-redis"))]
mod tests {
    use std::collections::HashMap;

    use super::Store;
    use crate::store::{tests, StoreBackend as _};

    /// Connects a store to the Redis server these tests expect at `localhost:6379`.
    async fn connect_store() -> Store {
        let config = HashMap::from([("connect".to_string(), "localhost:6379".to_string())]);
        Store::new(&config).await.unwrap()
    }

    #[test]
    fn store_type() {
        tests::store_type::<Store>();
    }

    #[tokio::test]
    async fn get_store_type() {
        let store = connect_store().await;
        tests::get_store_type::<Store>(&store);
    }

    /// Emits one `#[tokio::test]` per name. Each test runs the shared async test of the same
    /// name from `crate::store::tests` against a freshly connected store.
    macro_rules! shared_async_tests {
        ($($name:ident),+ $(,)?) => {
            $(
                #[tokio::test]
                async fn $name() {
                    let store = connect_store().await;
                    tests::$name(&store).await;
                }
            )+
        };
    }

    shared_async_tests!(
        get_redirect,
        set_redirect,
        rem_redirect,
        get_vanity,
        set_vanity,
        rem_vanity,
        get_statistics,
        incr_statistic,
        rem_statistics,
    );
}
470}