diff --git a/src/main.rs b/src/main.rs index 819460a..51c3134 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,16 @@ #![feature(let_chains)] use std::path::Path; -use std::process::Command; use std::sync::mpsc; use std::sync::{ - Arc, Mutex, + Arc, Mutex, RwLock, atomic::{AtomicUsize, Ordering}, }; use clap::Parser; use log::{debug, trace}; use serde::{Deserialize, Serialize}; +use tokio::process::Command; use tokio::sync::Semaphore; const UPSTREAM_CACHES: &[&str] = &["https://cache.nixos.org"]; @@ -30,13 +30,15 @@ struct PathInfo { impl PathInfo { // find derivations related to package - fn from_package(package: &str) -> Vec { + async fn from_package(package: &str) -> Vec { let path_infos = Command::new("nix") .arg("path-info") .arg("--derivation") + .arg("--recursive") .arg("--json") .arg(package) .output() + .await .expect("path-info failed"); let path_infos: Vec = serde_json::from_slice(&path_infos.stdout).unwrap(); @@ -45,7 +47,7 @@ impl PathInfo { } // find store paths related to derivation - fn get_store_paths(&self) -> Vec { + async fn get_store_paths(&self) -> Vec { let mut store_paths: Vec = Vec::new(); let nix_store_cmd = Command::new("nix-store") .arg("--query") @@ -53,6 +55,7 @@ impl PathInfo { .arg("--include-outputs") .arg(&self.path) .output() + .await .expect("nix-store cmd failed"); let nix_store_out = String::from_utf8(nix_store_cmd.stdout).unwrap(); @@ -85,6 +88,10 @@ struct Cli { /// Concurrent uploaders #[arg(long, default_value_t = 10)] uploader_concurrency: u8, + + /// Concurrent nix-store commands to run + #[arg(long, default_value_t = 50)] + nix_store_concurrency: u8, } #[tokio::main] @@ -102,15 +109,33 @@ async fn main() { debug!("upstream caches: {:#?}", upstream_caches); println!("querying nix path-info"); - let path_infos = PathInfo::from_package(package); + let derivations = PathInfo::from_package(package).await; + println!("got {} derivations", derivations.len()); println!("querying nix-store"); - let store_paths = path_infos[0].get_store_paths(); + let mut handles = Vec::new(); + let concurrency = Arc::new(Semaphore::new(cli.nix_store_concurrency.into())); + let store_paths = Arc::new(RwLock::new(Vec::new())); + + for derivation in derivations { + let store_paths = Arc::clone(&store_paths); + let permit = Arc::clone(&concurrency); + handles.push(tokio::spawn(async move { + let _permit = permit.acquire_owned().await.unwrap(); + let paths = derivation.get_store_paths().await; + store_paths.write().unwrap().extend(paths); + })); + } + // resolve store paths for all derivations before we move on + for handle in handles { + handle.await.unwrap(); + } + println!("got {} store paths", store_paths.read().unwrap().len()); + let (cacheable_tx, cacheable_rx) = mpsc::channel(); - let mut handles = Vec::new(); - println!("spawning check_upstream"); + handles = Vec::new(); handles.push(tokio::spawn(async move { check_upstream( store_paths, @@ -134,19 +159,22 @@ async fn main() { // filter out store paths that exist in upstream caches async fn check_upstream( - store_paths: Vec, + store_paths: Arc>>, cacheable_tx: mpsc::Sender, concurrency: u8, upstream_caches: Arc>, ) { - let concurrent = Semaphore::new(concurrency.into()); + let concurrency = Arc::new(Semaphore::new(concurrency.into())); + let c_store_paths = Arc::clone(&store_paths); + let store_paths = c_store_paths.read().unwrap().clone(); for store_path in store_paths { - let _ = concurrent.acquire().await.unwrap(); let tx = cacheable_tx.clone(); let upstream_caches = Arc::clone(&upstream_caches); + let concurrency = Arc::clone(&concurrency); tokio::spawn(async move { + let _permit = concurrency.acquire().await.unwrap(); let basename = Path::new(&store_path) .file_name() .unwrap() @@ -185,16 +213,18 @@ async fn check_upstream( async fn uploader(cacheable_rx: mpsc::Receiver, binary_cache: String, concurrency: u8) { let upload_count = Arc::new(AtomicUsize::new(0)); let failures: Arc>> = Arc::new(Mutex::new(Vec::new())); - let concurrent = Semaphore::new(concurrency.into()); + let concurrency = Arc::new(Semaphore::new(concurrency.into())); let mut handles = Vec::new(); + loop { if let Ok(path_to_upload) = cacheable_rx.recv() { - let _ = concurrent.acquire().await.unwrap(); + let concurrency = Arc::clone(&concurrency); let failures = Arc::clone(&failures); let binary_cache = binary_cache.clone(); let upload_count = Arc::clone(&upload_count); handles.push(tokio::spawn(async move { + let _permit = concurrency.acquire().await.unwrap(); println!("uploading: {}", path_to_upload); if Command::new("nix") .arg("copy") @@ -202,6 +232,7 @@ async fn uploader(cacheable_rx: mpsc::Receiver, binary_cache: String, co .arg(&binary_cache) .arg(&path_to_upload) .output() + .await .is_err() { println!("WARN: upload failed: {}", path_to_upload);