blob: 70f89af0fa929ffb5070e3a91d75c26e25ab6c1c [file] [log] [blame]
use crate::{
errors::{IntoArrayError, NotEqualError},
InOut,
};
use core::{marker::PhantomData, slice};
use generic_array::{ArrayLength, GenericArray};
/// Custom slice type which references one immutable (input) slice and one
/// mutable (output) slice of equal length. Input and output slices are
/// either the same (in-place processing) or do not overlap at all.
pub struct InOutBuf<'inp, 'out, T> {
    /// Start of the input region; only ever read through.
    pub(crate) in_ptr: *const T,
    /// Start of the output region; may be identical to `in_ptr` for
    /// in-place buffers.
    pub(crate) out_ptr: *mut T,
    /// Number of `T` elements in each of the two regions.
    pub(crate) len: usize,
    /// Carries the borrow lifetimes: a shared borrow for `'inp` and a
    /// unique borrow for `'out`, so the raw pointers cannot outlive them.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
#[inline(always)]
fn from(buf: &'a mut [T]) -> Self {
let p = buf.as_mut_ptr();
Self {
in_ptr: p,
out_ptr: p,
len: buf.len(),
_pd: PhantomData,
}
}
}
impl<'a, T> InOutBuf<'a, 'a, T> {
/// Create `InOutBuf` from a single mutable reference.
#[inline(always)]
pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
let p = val as *mut T;
Self {
in_ptr: p,
out_ptr: p,
len: 1,
_pd: PhantomData,
}
}
}
impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
    type Item = InOut<'inp, 'out, T>;
    type IntoIter = InOutBufIter<'inp, 'out, T>;

    /// Consumes the buffer and returns an iterator that starts at the first
    /// element and yields one [`InOut`] per position.
    #[inline(always)]
    fn into_iter(self) -> Self::IntoIter {
        let start = 0;
        InOutBufIter {
            buf: self,
            pos: start,
        }
    }
}
impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
    /// Create a one-element `InOutBuf` from a pair of immutable and mutable
    /// references.
    #[inline(always)]
    pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
        // A live `&T` and `&mut T` can never alias, so the "same or
        // non-overlapping" invariant holds trivially.
        Self {
            in_ptr: in_val as *const T,
            out_ptr: out_val as *mut T,
            len: 1,
            _pd: PhantomData,
        }
    }

    /// Create `InOutBuf` from immutable and mutable slices.
    ///
    /// # Errors
    /// Returns [`NotEqualError`] if the lengths of the two slices differ.
    #[inline(always)]
    pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
        if in_buf.len() != out_buf.len() {
            return Err(NotEqualError);
        }
        // The borrow checker guarantees `in_buf` and `out_buf` do not overlap.
        Ok(Self {
            in_ptr: in_buf.as_ptr(),
            out_ptr: out_buf.as_mut_ptr(),
            len: in_buf.len(),
            _pd: PhantomData,
        })
    }

    /// Get length of the inner buffers.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer has a length of 0.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns `InOut` for the element at position `pos`.
    ///
    /// # Panics
    /// If `pos` is greater than or equal to the buffer length.
    #[inline(always)]
    pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
        assert!(pos < self.len);
        // SAFETY: `pos < self.len` was just checked, so both offsets stay
        // within the `len`-element regions the pointers are valid for.
        unsafe {
            InOut {
                in_ptr: self.in_ptr.add(pos),
                out_ptr: self.out_ptr.add(pos),
                _pd: PhantomData,
            }
        }
    }

    /// Get input slice.
    #[inline(always)]
    pub fn get_in<'a>(&'a self) -> &'a [T] {
        // SAFETY: `in_ptr` is valid for `len` reads. The returned slice
        // borrows `self` immutably, so methods that write the output (which
        // take `&mut self`) cannot run while it is alive, even when input
        // and output are the same memory.
        unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
    }

    /// Get output slice.
    #[inline(always)]
    pub fn get_out<'a>(&'a mut self) -> &'a mut [T] {
        // SAFETY: `out_ptr` is valid for `len` reads and writes, and
        // `&mut self` gives exclusive access for `'a`.
        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
    }

    /// Consume self and return the output slice with lifetime `'out`.
    #[inline(always)]
    pub fn into_out(self) -> &'out mut [T] {
        // SAFETY: `out_ptr` is valid for `len` reads and writes for `'out`,
        // and consuming `self` relinquishes every other access path.
        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
    }

    /// Get raw input and output pointers, consuming `self`.
    ///
    /// The length is not returned; query [`len`](Self::len) beforehand if
    /// it is needed.
    #[inline(always)]
    pub fn into_raw(self) -> (*const T, *mut T) {
        (self.in_ptr, self.out_ptr)
    }

    /// Reborrow `self` with a shorter lifetime `'a`.
    #[inline(always)]
    pub fn reborrow<'a>(&'a mut self) -> InOutBuf<'a, 'a, T> {
        InOutBuf {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            len: self.len,
            _pd: PhantomData,
        }
    }

    /// Create [`InOutBuf`] from raw input and output pointers.
    ///
    /// # Safety
    /// Behavior is undefined if any of the following conditions are violated:
    /// - `in_ptr` must point to `len` consecutive properly initialized values
    ///   of type `T` and must be valid for reads for
    ///   `len * mem::size_of::<T>()` many bytes.
    /// - `out_ptr` must point to `len` consecutive properly initialized values
    ///   of type `T` and must be valid for both reads and writes for
    ///   `len * mem::size_of::<T>()` many bytes.
    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
    ///   them must not be accessed through any other pointer (not derived from
    ///   the return value) for the duration of lifetimes `'inp` and `'out`.
    ///   Both read and write accesses are forbidden.
    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
    ///   `out_ptr` must not be accessed through any other pointer (not derived
    ///   from the return value) for the duration of lifetime `'out`. Both read
    ///   and write accesses are forbidden. The memory referenced by `in_ptr`
    ///   must not be mutated for the duration of lifetime `'inp`, except
    ///   inside an `UnsafeCell`.
    /// - The total size `len * mem::size_of::<T>()` must be no larger than
    ///   `isize::MAX`.
    #[inline(always)]
    pub unsafe fn from_raw(
        in_ptr: *const T,
        out_ptr: *mut T,
        len: usize,
    ) -> InOutBuf<'inp, 'out, T> {
        Self {
            in_ptr,
            out_ptr,
            len,
            _pd: PhantomData,
        }
    }

    /// Divides one buffer into two at `mid` index.
    ///
    /// The first will contain all indices from `[0, mid)` (excluding
    /// the index `mid` itself) and the second will contain all
    /// indices from `[mid, len)` (excluding the index `len` itself).
    ///
    /// # Panics
    ///
    /// Panics if `mid > len`.
    #[inline(always)]
    pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
        assert!(mid <= self.len);
        // SAFETY: `mid <= self.len`, so the offset pointers are at most one
        // past the end of their regions, which `add` permits.
        let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
        let head = InOutBuf {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            len: mid,
            _pd: PhantomData,
        };
        let tail = InOutBuf {
            in_ptr: tail_in_ptr,
            out_ptr: tail_out_ptr,
            len: self.len - mid,
            _pd: PhantomData,
        };
        (head, tail)
    }

    /// Partition buffer into 2 parts: buffer of arrays and tail.
    ///
    /// The first part holds `len / N` arrays of `N` elements each; the
    /// second holds the remaining `len % N` elements.
    ///
    /// # Panics
    /// If `N` is zero (the number of chunks is computed as `len / N`).
    #[inline(always)]
    pub fn into_chunks<N: ArrayLength<T>>(
        self,
    ) -> (
        InOutBuf<'inp, 'out, GenericArray<T, N>>,
        InOutBuf<'inp, 'out, T>,
    ) {
        let n_chunks = self.len / N::USIZE;
        // `n_chunks * N <= len`, so this split never panics.
        let (head, tail) = self.split_at(n_chunks * N::USIZE);
        // Reinterpreting `head` as `n_chunks` arrays relies on
        // `GenericArray<T, N>` having the same layout as `[T; N]`.
        let chunks = InOutBuf {
            in_ptr: head.in_ptr as *const GenericArray<T, N>,
            out_ptr: head.out_ptr as *mut GenericArray<T, N>,
            len: n_chunks,
            _pd: PhantomData,
        };
        (chunks, tail)
    }
}
impl<'inp, 'out> InOutBuf<'inp, 'out, u8> {
    /// XORs `data` with values behind the input slice and writes the
    /// result to the output slice.
    ///
    /// Works for both in-place (input and output are the same memory) and
    /// separate buffers: each input byte is read before the corresponding
    /// output byte is written.
    ///
    /// # Panics
    /// If `data` length is not equal to the buffer length.
    #[inline(always)]
    pub fn xor_in2out(&mut self, data: &[u8]) {
        assert_eq!(self.len(), data.len());
        for (i, &d) in data.iter().enumerate() {
            // SAFETY: `i < data.len() == self.len()`, so both offsets are in
            // bounds. Raw pointers are used instead of slices because input
            // and output may be the same memory, and a `&[u8]`/`&mut [u8]`
            // pair over one region would be aliasing UB.
            unsafe {
                let in_ptr = self.in_ptr.add(i);
                let out_ptr = self.out_ptr.add(i);
                *out_ptr = *in_ptr ^ d;
            }
        }
    }
}
impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T>
where
N: ArrayLength<T>,
{
type Error = IntoArrayError;
#[inline(always)]
fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> {
if self.len() == N::USIZE {
Ok(InOut {
in_ptr: self.in_ptr as *const _,
out_ptr: self.out_ptr as *mut _,
_pd: PhantomData,
})
} else {
Err(IntoArrayError)
}
}
}
/// Iterator over [`InOutBuf`], yielding one `InOut` per element.
pub struct InOutBufIter<'inp, 'out, T> {
    /// The buffer being walked; owned by the iterator.
    buf: InOutBuf<'inp, 'out, T>,
    /// Index of the next element to yield.
    pos: usize,
}
impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
    type Item = InOut<'inp, 'out, T>;

    /// Yields the `InOut` for the current position and advances by one.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        // `>=` rather than `==` so an out-of-range cursor can never lead to
        // out-of-bounds pointer arithmetic.
        if self.pos >= self.buf.len() {
            return None;
        }
        // SAFETY: `self.pos < self.buf.len()`, so both offsets are in bounds.
        // `pos` only increases, so each element is yielded at most once and
        // no two yielded items refer to the same output element.
        let res = unsafe {
            InOut {
                in_ptr: self.buf.in_ptr.add(self.pos),
                out_ptr: self.buf.out_ptr.add(self.pos),
                _pd: PhantomData,
            }
        };
        self.pos += 1;
        Some(res)
    }

    /// Reports the exact number of elements still to be yielded, so that
    /// adapters such as `collect` can preallocate.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.buf.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}