blob: 96a47d7c41a0f488c956c3dbbdccd1cdfb2eeea9 [file] [log] [blame]
#![allow(clippy::await_holding_lock)]
use crate::code_lock::{check_new_key, LOCKS};
/// Runs `function` while the lock stored under `name` in `LOCKS` is held and
/// hands the function's `Result` back to the caller unchanged.
///
/// `check_new_key` is called first so that an entry for `name` is expected to
/// be present by the time it is looked up.
///
/// # Panics
///
/// Panics with `"key to be set"` if the lookup in `LOCKS` yields nothing.
#[doc(hidden)]
pub fn local_serial_core_with_return<E>(
    name: &str,
    function: fn() -> Result<(), E>,
) -> Result<(), E> {
    check_new_key(name);

    let entry = LOCKS.get(name).expect("key to be set");
    // The guard must be bound to a real name: a bare `_` pattern would drop it
    // straight away instead of keeping it until the end of this scope.
    let _held = entry.lock();

    function()
}
/// Runs `function` while the lock stored under `name` in `LOCKS` is held.
///
/// `check_new_key` is called first so that an entry for `name` is expected to
/// be present by the time it is looked up.
///
/// # Panics
///
/// Panics with `"key to be set"` if the lookup in `LOCKS` yields nothing.
#[doc(hidden)]
pub fn local_serial_core(name: &str, function: fn()) {
    check_new_key(name);

    let entry = LOCKS.get(name).expect("key to be set");
    // The guard must be bound to a real name: a bare `_` pattern would drop it
    // straight away instead of keeping it until the end of this scope.
    let _held = entry.lock();

    function();
}
/// Awaits `fut` while the lock stored under `name` in `LOCKS` is held and
/// hands the future's `Result` back to the caller unchanged.
///
/// Only compiled when the `async` feature is enabled. The guard stays alive
/// across the `.await`, which is why the file allows
/// `clippy::await_holding_lock`.
///
/// # Panics
///
/// Panics with `"key to be set"` if the lookup in `LOCKS` yields nothing.
#[doc(hidden)]
#[cfg(feature = "async")]
pub async fn local_async_serial_core_with_return<E>(
    name: &str,
    fut: impl std::future::Future<Output = Result<(), E>>,
) -> Result<(), E> {
    check_new_key(name);

    let entry = LOCKS.get(name).expect("key to be set");
    // The guard must be bound to a real name: a bare `_` pattern would drop it
    // straight away instead of keeping it until the future has completed.
    let _held = entry.lock();

    fut.await
}
/// Awaits `fut` while the lock stored under `name` in `LOCKS` is held.
///
/// Only compiled when the `async` feature is enabled. The guard stays alive
/// across the `.await`, which is why the file allows
/// `clippy::await_holding_lock`.
///
/// # Panics
///
/// Panics with `"key to be set"` if the lookup in `LOCKS` yields nothing.
#[doc(hidden)]
#[cfg(feature = "async")]
pub async fn local_async_serial_core(name: &str, fut: impl std::future::Future<Output = ()>) {
    check_new_key(name);

    let entry = LOCKS.get(name).expect("key to be set");
    // The guard must be bound to a real name: a bare `_` pattern would drop it
    // straight away instead of keeping it until the future has completed.
    let _held = entry.lock();

    fut.await;
}
#[cfg(test)]
#[allow(clippy::print_stdout)]
mod tests {
    use super::local_serial_core;
    use crate::code_lock::{check_new_key, LOCKS};
    use itertools::Itertools;
    use parking_lot::RwLock;
    use std::{
        sync::{Arc, Barrier},
        thread,
        time::Duration,
    };

    /// Releases `count` threads at once through a barrier, has each of them
    /// call `check_new_key("foo")` and record the `id` of the entry it then
    /// finds, and checks that every thread recorded the same id.
    #[test]
    fn test_hammer_check_new_key() {
        let count = 100;
        let ptrs = Arc::new(RwLock::new(Vec::new()));
        let barrier = Arc::new(Barrier::new(count));

        let threads: Vec<_> = (0..count)
            .map(|_| {
                let local_locks = LOCKS.clone();
                let local_ptrs = Arc::clone(&ptrs);
                let gate = Arc::clone(&barrier);
                thread::spawn(move || {
                    // First rendezvous: all threads hit `check_new_key` together.
                    gate.wait();
                    check_new_key("foo");
                    {
                        let entry = local_locks.get("foo").expect("read didn't work");
                        let mutex = entry.value();
                        let mut recorded = local_ptrs
                            .try_write_for(Duration::from_secs(1))
                            .expect("write lock didn't work");
                        recorded.push(mutex.id);
                    }
                    // Second rendezvous: nobody exits before everyone has recorded.
                    gate.wait();
                })
            })
            .collect();

        threads
            .into_iter()
            .for_each(|handle| handle.join().expect("thread join worked"));

        let seen = ptrs
            .try_read_recursive_for(Duration::from_secs(1))
            .expect("ptrs read work");
        assert_eq!(seen.len(), count);
        println!("{:?}", seen);
        assert_eq!(seen.iter().unique().count(), 1);
    }

    /// A panic raised inside the serialised function must not leave the
    /// `"assert"` lock held afterwards.
    #[test]
    fn unlock_on_assert() {
        let _ = std::panic::catch_unwind(|| {
            local_serial_core("assert", || {
                assert!(false);
            })
        });

        let entry = LOCKS.get("assert").unwrap();
        assert!(!entry.is_locked());
    }
}