diff --git a/src/main.rs b/src/main.rs index 9a0562d..27abf93 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ #![feature(let_chains)] use std::path::Path; -use std::sync::mpsc; use std::sync::{ Arc, Mutex, RwLock, atomic::{AtomicUsize, Ordering}, @@ -11,7 +10,7 @@ use clap::Parser; use log::{debug, trace}; use serde::{Deserialize, Serialize}; use tokio::process::Command; -use tokio::sync::Semaphore; +use tokio::sync::{Semaphore, mpsc}; const UPSTREAM_CACHES: &[&str] = &["https://cache.nixos.org"]; @@ -138,7 +137,7 @@ async fn main() { } println!("got {} store paths", store_paths.read().unwrap().len()); - let (cacheable_tx, cacheable_rx) = mpsc::channel(); + let (cacheable_tx, mut cacheable_rx) = mpsc::channel(cli.uploader_concurrency.into()); println!("spawning check_upstream"); handles = Vec::new(); @@ -154,7 +153,7 @@ async fn main() { println!("spawning uploader"); handles.push(tokio::spawn(async move { - uploader(cacheable_rx, binary_cache, cli.uploader_concurrency).await; + uploader(&mut cacheable_rx, binary_cache, cli.uploader_concurrency).await; })); // make sure all threads are done @@ -210,20 +209,24 @@ async fn check_upstream( } if !hit { trace!("sending {}", store_path); - tx.send(store_path).unwrap(); + tx.send(store_path).await.unwrap(); } }); } } -async fn uploader(cacheable_rx: mpsc::Receiver, binary_cache: String, concurrency: u8) { +async fn uploader( + cacheable_rx: &mut mpsc::Receiver, + binary_cache: String, + concurrency: u8, +) { let upload_count = Arc::new(AtomicUsize::new(0)); let failures: Arc>> = Arc::new(Mutex::new(Vec::new())); let concurrency = Arc::new(Semaphore::new(concurrency.into())); let mut handles = Vec::new(); loop { - if let Ok(path_to_upload) = cacheable_rx.recv() { + if let Some(path_to_upload) = cacheable_rx.recv().await { let concurrency = Arc::clone(&concurrency); let failures = Arc::clone(&failures); let binary_cache = binary_cache.clone();