diff --git a/src/main.rs b/src/main.rs index a7eb618..add8884 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(let_chains)] + use std::process::{Command, Stdio}; use std::sync::mpsc; use std::{env, path::Path}; @@ -5,7 +7,7 @@ use std::{env, path::Path}; use log::{debug, trace}; use serde::{Deserialize, Serialize}; use serde_json; -use tokio; +use tokio::sync::Semaphore; const UPSTREAM_CACHES: &'static [&'static str] = &[ "https://cache.nixos.org", @@ -75,71 +77,91 @@ async fn main() { let store_paths = path_infos[0].get_store_paths(); let (cacheable_tx, cacheable_rx) = mpsc::channel(); + let mut handles = Vec::new(); + println!("spawning check_upstream"); - tokio::spawn(async move { + handles.push(tokio::spawn(async move { check_upstream(store_paths, cacheable_tx).await; - }); + })); println!("spawning uploader"); - tokio::spawn(async move { + handles.push(tokio::spawn(async move { uploader(cacheable_rx).await; - }).await.unwrap(); + })); + + // make sure all threads are done + for handle in handles { + handle.await.unwrap(); + } } // filter out store paths that exist in upstream caches async fn check_upstream(store_paths: Vec<String>, cacheable_tx: mpsc::Sender<String>) { + let concurrent = Semaphore::new(50); for store_path in store_paths { - let basename = Path::new(&store_path) - .file_name() - .unwrap() - .to_str() - .unwrap() - .to_string(); - let hash = basename.split("-").nth(0).unwrap(); - - let mut hit = false; - for upstream in UPSTREAM_CACHES { - let mut uri = String::from(*upstream); - uri.push_str(format!("/{}.narinfo", hash).as_str()); - - let res_status = reqwest::Client::new() - .head(uri) - .send() - .await + let _ = concurrent.acquire().await.unwrap(); + let tx = cacheable_tx.clone(); + tokio::spawn(async move { + let basename = Path::new(&store_path) + .file_name() .unwrap() - .status(); + .to_str() + .unwrap() + .to_string(); + let hash = basename.split("-").nth(0).unwrap(); - if res_status.is_success() { - debug!("{} was a hit upstream: 
{}", store_path, upstream); - hit = true; - break; + let mut hit = false; + for upstream in UPSTREAM_CACHES { + let mut uri = String::from(*upstream); + uri.push_str(format!("/{}.narinfo", hash).as_str()); + + let res_status = reqwest::Client::new() + .head(uri) + .send() + .await + .map(|x| x.status()); + + if let Ok(res_status) = res_status && res_status.is_success() { + println!("{} was a hit upstream: {}", store_path, upstream); + hit = true; + break; + } } - } - if !hit { - trace!("sending {}", store_path); - cacheable_tx.send(store_path).unwrap(); - } + if !hit { + trace!("sending {}", store_path); + tx.send(store_path).unwrap(); + } + }); } } async fn uploader(cacheable_rx: mpsc::Receiver<String>) { let mut count = 0; + let concurrent = Semaphore::new(10); + let mut handles = Vec::new(); loop { if let Ok(path_to_upload) = cacheable_rx.recv() { - trace!("to upload: {}", path_to_upload); - if Command::new("nix") - .arg("copy") - .arg("--to") - .arg("s3://nixcache?endpoint=s3.cy7.sh&secret-key=/home/yt/cache-priv-key.pem") - .arg(&path_to_upload) - .output() - .is_err() - { - println!("WARN: upload failed: {}", path_to_upload); - } else { - count += 1; - } + let _ = concurrent.acquire().await.unwrap(); + handles.push(tokio::spawn(async move { + println!("uploading: {}", path_to_upload); + if Command::new("nix") + .arg("copy") + .arg("--to") + .arg("s3://nixcache?endpoint=s3.cy7.sh&secret-key=/home/yt/cache-priv-key.pem") + .arg(&path_to_upload) + .output() + .is_err() + { + println!("WARN: upload failed: {}", path_to_upload); + } else { + count += 1; + } + })); } else { + // make sure all threads are done + for handle in handles { + handle.await.unwrap(); + } println!("uploaded {} paths", count); break; }