#![warn(missing_docs)]
use parking_lot::{Condvar, Mutex};
use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd};
use std::collections::BinaryHeap;
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{self, AtomicBool};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use crate::thunk::Thunk;
mod thunk;
/// A handle to a job scheduled on a [`ScheduledThreadPool`].
///
/// The handle can only be used to cancel the job. Dropping the handle does
/// *not* cancel the job.
pub struct JobHandle(Arc<AtomicBool>);

impl JobHandle {
    /// Cancels the job.
    ///
    /// A job that has not started yet will be skipped instead of run. A
    /// periodic job that is currently running finishes its current run, but
    /// no further runs are started. Calling this more than once is harmless.
    pub fn cancel(&self) {
        self.0.store(true, atomic::Ordering::SeqCst);
    }
}
enum JobType {
Once(Thunk<'static>),
FixedRate {
f: Box<FnMut() + Send + 'static>,
rate: Duration,
},
FixedDelay {
f: Box<FnMut() + Send + 'static>,
delay: Duration,
},
}
/// A unit of work waiting in the pool's queue.
struct Job {
    /// What to run, and how to reschedule it if it is periodic.
    type_: JobType,
    /// The earliest instant at which the job may be started; the queue is
    /// ordered by this field.
    time: Instant,
    /// Flag shared with the caller's `JobHandle`; once set, the job is
    /// skipped instead of run.
    canceled: Arc<AtomicBool>,
}
impl PartialOrd for Job {
    /// Always defers to the total order from `Ord`, so the two can never
    /// disagree.
    fn partial_cmp(&self, other: &Job) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for Job {
    /// Orders jobs by *descending* scheduled time.
    ///
    /// `BinaryHeap` is a max-heap, so inverting the comparison puts the job
    /// with the earliest `time` at the top of the queue.
    fn cmp(&self, other: &Job) -> Ordering {
        other.time.cmp(&self.time)
    }
}
impl PartialEq for Job {
    /// Jobs are equal when they are scheduled for the same instant. This
    /// matches the time-only ordering used by `Ord`.
    fn eq(&self, other: &Job) -> bool {
        self.time.eq(&other.time)
    }
}
/// Time-based equality is a full equivalence relation, so `Job` can be
/// totally ordered.
impl Eq for Job {}
/// Mutable pool state, guarded by `SharedPool::inner`.
struct InnerPool {
    /// Pending jobs. Because `Job`'s `Ord` is reversed, the job due soonest
    /// is at the top of the heap.
    queue: BinaryHeap<Job>,
    /// Set when the owning `ScheduledThreadPool` is dropped. After that, newly
    /// submitted jobs are discarded, and workers exit once the queue is empty.
    shutdown: bool,
}
/// State shared between the pool handle and all of its worker threads.
struct SharedPool {
    /// The job queue and the shutdown flag.
    inner: Mutex<InnerPool>,
    /// Notified when the queue's earliest deadline changes or the pool shuts
    /// down; workers block on it while waiting for a job to become due.
    cvar: Condvar,
}
impl SharedPool {
    /// Queues `job`, waking the workers if it becomes the next job due.
    ///
    /// After shutdown the job is dropped without running. At that point only
    /// the workers themselves can still submit (when rescheduling a periodic
    /// job), because the pool handle is already gone.
    fn run(&self, job: Job) {
        let mut guard = self.inner.lock();
        if guard.shutdown {
            return;
        }

        // Workers sleep until the current head's deadline. If this job is due
        // sooner, or the queue was empty, they must re-check. The lock is still
        // held, so no woken worker can look at the queue before the push.
        let becomes_head = match guard.queue.peek() {
            Some(head) => head.time > job.time,
            None => true,
        };
        guard.queue.push(job);
        if becomes_head {
            self.cvar.notify_all();
        }
    }
}
/// A pool of worker threads that run jobs at specific times.
///
/// Jobs can run once (immediately or after a delay) or periodically, either
/// at a fixed rate or with a fixed delay between runs.
///
/// Dropping the pool stops periodic jobs from being rescheduled. Jobs that are
/// already queued, including the next pending run of a periodic job, still run
/// before the worker threads exit.
pub struct ScheduledThreadPool {
    shared: Arc<SharedPool>,
}
impl Drop for ScheduledThreadPool {
    /// Marks the pool as shut down, then wakes every worker so that idle ones
    /// notice and exit.
    fn drop(&mut self) {
        {
            let mut inner = self.shared.inner.lock();
            inner.shutdown = true;
        } // lock released here, before the notification
        self.shared.cvar.notify_all();
    }
}
impl ScheduledThreadPool {
pub fn new(num_threads: usize) -> ScheduledThreadPool {
ScheduledThreadPool::new_inner(None, num_threads)
}
pub fn with_name(thread_name: &str, num_threads: usize) -> ScheduledThreadPool {
ScheduledThreadPool::new_inner(Some(thread_name), num_threads)
}
fn new_inner(thread_name: Option<&str>, num_threads: usize) -> ScheduledThreadPool {
assert!(num_threads > 0, "num_threads must be positive");
let inner = InnerPool {
queue: BinaryHeap::new(),
shutdown: false,
};
let shared = SharedPool {
inner: Mutex::new(inner),
cvar: Condvar::new(),
};
let pool = ScheduledThreadPool {
shared: Arc::new(shared),
};
for i in 0..num_threads {
Worker::start(
thread_name.map(|n| n.replace("{}", &i.to_string())),
pool.shared.clone(),
);
}
pool
}
pub fn execute<F>(&self, job: F) -> JobHandle
where
F: FnOnce() + Send + 'static,
{
self.execute_after(Duration::from_secs(0), job)
}
pub fn execute_after<F>(&self, delay: Duration, job: F) -> JobHandle
where
F: FnOnce() + Send + 'static,
{
let canceled = Arc::new(AtomicBool::new(false));
let job = Job {
type_: JobType::Once(Thunk::new(job)),
time: Instant::now() + delay,
canceled: canceled.clone(),
};
self.shared.run(job);
JobHandle(canceled)
}
pub fn execute_at_fixed_rate<F>(
&self,
initial_delay: Duration,
rate: Duration,
f: F,
) -> JobHandle
where
F: FnMut() + Send + 'static,
{
let canceled = Arc::new(AtomicBool::new(false));
let job = Job {
type_: JobType::FixedRate {
f: Box::new(f),
rate: rate,
},
time: Instant::now() + initial_delay,
canceled: canceled.clone(),
};
self.shared.run(job);
JobHandle(canceled)
}
pub fn execute_with_fixed_delay<F>(
&self,
initial_delay: Duration,
delay: Duration,
f: F,
) -> JobHandle
where
F: FnMut() + Send + 'static,
{
let canceled = Arc::new(AtomicBool::new(false));
let job = Job {
type_: JobType::FixedDelay {
f: Box::new(f),
delay: delay,
},
time: Instant::now() + initial_delay,
canceled: canceled.clone(),
};
self.shared.run(job);
JobHandle(canceled)
}
}
/// State owned by one worker thread: its reference to the shared pool.
struct Worker {
    shared: Arc<SharedPool>,
}
impl Worker {
fn start(name: Option<String>, shared: Arc<SharedPool>) {
let mut worker = Worker { shared: shared };
let mut thread = thread::Builder::new();
if let Some(name) = name {
thread = thread.name(name);
}
thread.spawn(move || worker.run()).unwrap();
}
fn run(&mut self) {
while let Some(job) = self.get_job() {
let _ = panic::catch_unwind(AssertUnwindSafe(|| self.run_job(job)));
}
}
fn get_job(&self) -> Option<Job> {
enum Need {
Wait,
WaitTimeout(Duration),
}
let mut inner = self.shared.inner.lock();
loop {
let now = Instant::now();
let need = match inner.queue.peek() {
None if inner.shutdown => return None,
None => Need::Wait,
Some(e) if e.time <= now => break,
Some(e) => Need::WaitTimeout(e.time - now),
};
match need {
Need::Wait => self.shared.cvar.wait(&mut inner),
Need::WaitTimeout(t) => {
self.shared.cvar.wait_until(&mut inner, now + t);
}
};
}
Some(inner.queue.pop().unwrap())
}
fn run_job(&self, job: Job) {
if job.canceled.load(atomic::Ordering::SeqCst) {
return;
}
match job.type_ {
JobType::Once(f) => f.invoke(()),
JobType::FixedRate { mut f, rate } => {
f();
let new_job = Job {
type_: JobType::FixedRate { f: f, rate: rate },
time: job.time + rate,
canceled: job.canceled,
};
self.shared.run(new_job)
}
JobType::FixedDelay { mut f, delay } => {
f();
let new_job = Job {
type_: JobType::FixedDelay { f: f, delay: delay },
time: Instant::now() + delay,
canceled: job.canceled,
};
self.shared.run(new_job)
}
}
}
}
#[cfg(test)]
mod test {
    use std::sync::mpsc::channel;
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::ScheduledThreadPool;

    const TEST_TASKS: usize = 4;

    #[test]
    fn test_works() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            pool.execute(move || {
                tx.send(1usize).unwrap();
            });
        }
        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
    }

    #[test]
    #[should_panic(expected = "num_threads must be positive")]
    fn test_zero_tasks_panic() {
        ScheduledThreadPool::new(0);
    }

    #[test]
    fn test_recovery_from_subtask_panic() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);

        // Force every worker to run (and panic in) one job at the same time.
        let waiter = Arc::new(Barrier::new(TEST_TASKS));
        for _ in 0..TEST_TASKS {
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                panic!();
            });
        }

        // This barrier can only be passed if every worker survived its panic.
        let (tx, rx) = channel();
        let waiter = Arc::new(Barrier::new(TEST_TASKS));
        for _ in 0..TEST_TASKS {
            let tx = tx.clone();
            let waiter = waiter.clone();
            pool.execute(move || {
                waiter.wait();
                tx.send(1usize).unwrap();
            });
        }

        assert_eq!(rx.iter().take(TEST_TASKS).sum::<usize>(), TEST_TASKS);
    }

    #[test]
    fn test_execute_after() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        assert_eq!(2, rx.recv().unwrap());
        assert_eq!(1, rx.recv().unwrap());
    }

    #[test]
    fn test_jobs_complete_after_drop() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let tx1 = tx.clone();
        pool.execute_after(Duration::from_secs(1), move || tx1.send(1usize).unwrap());
        pool.execute_after(Duration::from_millis(500), move || tx.send(2usize).unwrap());

        drop(pool);

        assert_eq!(2, rx.recv().unwrap());
        assert_eq!(1, rx.recv().unwrap());
    }

    /// Schedules a periodic job through `schedule`. From inside its second
    /// run, the job drops the last reference to the pool; the test then
    /// asserts that no third run happens, which shows up as its sender being
    /// dropped.
    fn assert_periodic_job_stops_after_drop<S>(schedule: S)
    where
        S: FnOnce(&ScheduledThreadPool, Box<dyn FnMut() + Send + 'static>),
    {
        let pool = Arc::new(ScheduledThreadPool::new(TEST_TASKS));
        let (tx, rx) = channel();
        let (tx2, rx2) = channel();

        let mut pool2 = Some(pool.clone());
        let mut i = 0i32;
        schedule(
            &*pool,
            Box::new(move || {
                i += 1;
                tx.send(i).unwrap();
                rx2.recv().unwrap();
                if i == 2 {
                    drop(pool2.take().unwrap());
                }
            }),
        );
        drop(pool);

        assert_eq!(Ok(1), rx.recv());
        tx2.send(()).unwrap();
        assert_eq!(Ok(2), rx.recv());
        tx2.send(()).unwrap();
        assert!(rx.recv().is_err());
    }

    #[test]
    fn test_fixed_rate_jobs_stop_after_drop() {
        assert_periodic_job_stops_after_drop(|pool, f| {
            pool.execute_at_fixed_rate(Duration::from_millis(500), Duration::from_millis(500), f);
        });
    }

    #[test]
    fn test_fixed_delay_jobs_stop_after_drop() {
        assert_periodic_job_stops_after_drop(|pool, f| {
            pool.execute_with_fixed_delay(
                Duration::from_millis(500),
                Duration::from_millis(500),
                f,
            );
        });
    }

    #[test]
    fn cancellation() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let handle = pool.execute_at_fixed_rate(
            Duration::from_millis(500),
            Duration::from_millis(500),
            move || {
                tx.send(()).unwrap();
            },
        );

        rx.recv().unwrap();
        handle.cancel();
        assert!(rx.recv().is_err());
    }

    #[test]
    fn cancellation_before_first_run() {
        let pool = ScheduledThreadPool::new(TEST_TASKS);
        let (tx, rx) = channel();

        let handle = pool.execute_after(Duration::from_millis(200), move || {
            tx.send(()).unwrap();
        });
        handle.cancel();

        // The skipped job is dropped together with its sender, so nothing
        // is ever received.
        assert!(rx.recv().is_err());
    }
}