arc on downloaded_data

This commit is contained in:
2025-01-05 16:53:16 +08:00
parent 07ee03b321
commit c64e6573d9

View File

@@ -19,10 +19,10 @@ struct Downloader {
#[derive(Debug)] #[derive(Debug)]
struct DownloadState { struct DownloadState {
notify: Notify, notify: Notify,
result: Mutex<Option<Result<DownloadData, (StatusCode, String)>>>, result: Mutex<Option<Result<Arc<DownloadData>, (StatusCode, String)>>>,
} }
#[derive(Debug, Clone)] #[derive(Debug)]
struct DownloadData { struct DownloadData {
body: Bytes, body: Bytes,
headers: reqwest::header::HeaderMap, headers: reqwest::header::HeaderMap,
@@ -35,7 +35,7 @@ impl Downloader {
} }
} }
async fn download(&self, target_url: &str) -> Result<DownloadData, (StatusCode, String)> { async fn download(&self, target_url: &str) -> Result<Arc<DownloadData>, (StatusCode, String)> {
// check if the url is already downloading // check if the url is already downloading
{ {
let states = self.states.lock().await; let states = self.states.lock().await;
@@ -82,18 +82,17 @@ impl Downloader {
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let download_data = Arc::new(DownloadData { body, headers });
// notify all waiters // notify all waiters
{ {
let mut states = self.states.lock().await; let mut states = self.states.lock().await;
let state = states.remove(target_url).unwrap(); let state = states.remove(target_url).unwrap();
state.result.lock().await.replace(Ok(DownloadData { state.result.lock().await.replace(Ok(download_data.clone()));
body: body.clone(),
headers: headers.clone(),
}));
state.notify.notify_waiters(); state.notify.notify_waiters();
} }
Ok(DownloadData { body, headers }) Ok(download_data.into())
} }
} }
#[tokio::main] #[tokio::main]
@@ -147,7 +146,7 @@ async fn api_proxy_image(
} }
let downlaod_begin = std::time::Instant::now(); let downlaod_begin = std::time::Instant::now();
let DownloadData { body, headers } = downloader.download(target_url).await?; let DownloadData { body, headers } = &*downloader.download(target_url).await?;
let ori_content_type = headers.get("Content-Type").unwrap().to_str().unwrap(); let ori_content_type = headers.get("Content-Type").unwrap().to_str().unwrap();
if !ori_content_type.starts_with("image/") { if !ori_content_type.starts_with("image/") {
return Err(( return Err((