blob: d770880ca3f8a397fadc7f51a43c276cafc5a5a5 [file] [log] [blame]
use crate::offset_from;
use std::cmp::Ordering::{Greater, Less};
use std::slice::{from_raw_parts, from_raw_parts_mut};
use std::{fmt, marker};
/// Generates the inherent helpers and the iterator trait implementations
/// shared by the binary-search group-by-key iterators.
///
/// * `$name` – the iterator type; it must have `ptr`/`end` raw-pointer fields
///   delimiting the not-yet-yielded remainder and a `func` key-function field.
/// * `$elem` – the yielded item type (`&'a [T]` or `&'a mut [T]`).
/// * `$mkslice` – builds an `$elem` from a pointer and a length
///   (`from_raw_parts` or `from_raw_parts_mut`).
macro_rules! binary_group_by_key {
    (struct $name:ident, $elem:ty, $mkslice:ident) => {
        impl<'a, T: 'a, F> $name<'a, T, F> {
            /// Returns `true` when no group is left to yield.
            #[inline]
            pub fn is_empty(&self) -> bool {
                self.ptr == self.end
            }

            /// Returns the number of elements that have not been yielded yet.
            #[inline]
            pub fn remainder_len(&self) -> usize {
                // SAFETY: `ptr` and `end` delimit the remainder of the borrowed
                // slice, with `ptr <= end`.
                unsafe { offset_from(self.end, self.ptr) }
            }
        }

        impl<'a, T: 'a, F, K> std::iter::Iterator for $name<'a, T, F>
        where
            F: FnMut(&T) -> K,
            K: PartialEq,
        {
            type Item = $elem;

            /// Yields the leading group: the first remaining element plus every
            /// following element whose key equals its key.
            #[inline]
            fn next(&mut self) -> Option<Self::Item> {
                if self.is_empty() {
                    return None;
                }

                let len = self.remainder_len();
                // SAFETY: the remainder is non-empty, so `ptr` points to an element.
                let first = unsafe { &*self.ptr };
                // SAFETY: `len >= 1`, so `[ptr + 1, end)` lies inside the remainder;
                // it does not overlap `first`.
                let tail = unsafe { $mkslice(self.ptr.add(1), len - 1) };

                // The first element's key does not change during the search, so
                // compute it once instead of once per probe.
                let first_key = (self.func)(first);
                // Elements sharing the first key order as `Less` and the others as
                // `Greater`. The comparator never answers `Equal`, so the search
                // always returns `Err(i)`, where `i` is the length of the tail's
                // leading run of elements with the first key. This is only
                // meaningful when elements sharing a key are stored contiguously.
                let index = tail
                    .binary_search_by(|x: &T| {
                        if first_key == (self.func)(x) {
                            Less
                        } else {
                            Greater
                        }
                    })
                    .unwrap_err();

                // SAFETY: `index <= len - 1`, so `[ptr, ptr + index + 1)` lies inside
                // the remainder and is disjoint from every group yielded earlier.
                let group = unsafe { $mkslice(self.ptr, index + 1) };
                // SAFETY: `index + 1 <= len`, so the new `ptr` stays `<= end`.
                self.ptr = unsafe { self.ptr.add(index + 1) };
                Some(group)
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                if self.is_empty() {
                    return (0, Some(0));
                }
                // A non-empty remainder holds at least one group and at most
                // one group per element.
                let len = self.remainder_len();
                (1, Some(len))
            }

            fn last(mut self) -> Option<Self::Item> {
                // The last group is exactly what `next_back` yields first.
                self.next_back()
            }
        }

        impl<'a, T: 'a, F, K> std::iter::DoubleEndedIterator for $name<'a, T, F>
        where
            F: FnMut(&T) -> K,
            K: PartialEq,
        {
            /// Yields the trailing group: the last remaining element plus every
            /// preceding element whose key equals its key.
            #[inline]
            fn next_back(&mut self) -> Option<Self::Item> {
                if self.is_empty() {
                    return None;
                }

                let len = self.remainder_len();
                // SAFETY: the remainder is non-empty, so `end - 1` points to an element.
                let last = unsafe { &*self.end.sub(1) };
                // SAFETY: `[ptr, end - 1)` lies inside the remainder and does not
                // overlap `last`.
                let head = unsafe { $mkslice(self.ptr, len - 1) };

                // Invariant during the search, so computed once (as in `next`).
                let last_key = (self.func)(last);
                // Mirror image of `next`: elements sharing the last key order as
                // `Greater` and the others as `Less`, so `Err(i)` is the index where
                // the head's trailing run of elements with the last key starts.
                let index = head
                    .binary_search_by(|x: &T| {
                        if last_key == (self.func)(x) {
                            Greater
                        } else {
                            Less
                        }
                    })
                    .unwrap_err();

                // SAFETY: `index <= len - 1`, so `[ptr + index, end)` lies inside the
                // remainder and is disjoint from every group yielded earlier.
                let group = unsafe { $mkslice(self.ptr.add(index), len - index) };
                // SAFETY: `len - index <= len`, so the new `end` stays `>= ptr`.
                self.end = unsafe { self.end.sub(len - index) };
                Some(group)
            }
        }

        // Once `ptr == end`, both `next` and `next_back` keep returning `None`.
        impl<'a, T: 'a, F, K> std::iter::FusedIterator for $name<'a, T, F>
        where
            F: FnMut(&T) -> K,
            K: PartialEq,
        {
        }
    };
}
/// An iterator that will return non-overlapping groups in the slice using *binary search*.
///
/// It will give an element to the given function, producing a key and comparing
/// the keys to determine groups.
///
/// Each group boundary is located with a binary search over the remaining
/// elements, so all elements that share a key must be stored contiguously;
/// otherwise the resulting grouping is unspecified.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BinaryGroupByKey<'a, T, F> {
    // Start of the remainder that has not been yielded yet.
    ptr: *const T,
    // One past the last element of the remainder.
    end: *const T,
    // Key function applied to elements to decide group membership.
    func: F,
    // Ties the iterator to the lifetime of the borrowed slice.
    _phantom: marker::PhantomData<&'a T>,
}
impl<'a, T: 'a, F> BinaryGroupByKey<'a, T, F> {
    /// Creates an iterator over the groups of `slice`.
    ///
    /// `func` maps each element to a key, and runs of elements whose keys
    /// compare equal are yielded together as one group. Group ends are found
    /// by binary search, so elements sharing a key are expected to be contiguous.
    pub fn new(slice: &'a [T], func: F) -> Self {
        let start = slice.as_ptr();
        // SAFETY: offsetting the start by the slice length yields the
        // one-past-the-end pointer of the same slice.
        let stop = unsafe { start.add(slice.len()) };
        BinaryGroupByKey {
            ptr: start,
            end: stop,
            func,
            _phantom: marker::PhantomData,
        }
    }
}
impl<'a, T, F> BinaryGroupByKey<'a, T, F> {
    /// Returns the part of the original slice that the iterator has not
    /// yielded yet.
    pub fn remainder(&self) -> &[T] {
        // SAFETY: `ptr..end` is the not-yet-yielded sub-range of the borrowed slice.
        unsafe { from_raw_parts(self.ptr, self.remainder_len()) }
    }
}
impl<'a, T: 'a + fmt::Debug, F> fmt::Debug for BinaryGroupByKey<'a, T, F> {
    /// Shows only the elements not yet yielded; the key function is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let remainder = self.remainder();
        let mut out = f.debug_struct("BinaryGroupByKey");
        out.field("remainder", &remainder);
        out.finish()
    }
}
// Generate `is_empty`, `remainder_len` and the iterator trait impls for the
// shared-slice variant, which yields `&'a [T]` groups.
binary_group_by_key!(struct BinaryGroupByKey, &'a [T], from_raw_parts);
/// An iterator that will return non-overlapping *mutable* groups
/// in the slice using *binary search*.
///
/// It will give an element to the given function, producing a key and comparing
/// the keys to determine groups.
///
/// Each group boundary is located with a binary search over the remaining
/// elements, so all elements that share a key must be stored contiguously;
/// otherwise the resulting grouping is unspecified.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BinaryGroupByKeyMut<'a, T, F> {
    // Start of the remainder that has not been yielded yet.
    ptr: *mut T,
    // One past the last element of the remainder.
    end: *mut T,
    // Key function applied to elements to decide group membership.
    func: F,
    // Ties the iterator to the lifetime of the exclusively borrowed slice.
    _phantom: marker::PhantomData<&'a mut T>,
}
impl<'a, T: 'a, F> BinaryGroupByKeyMut<'a, T, F> {
    /// Creates an iterator over the mutable groups of `slice`.
    ///
    /// `func` maps each element to a key, and runs of elements whose keys
    /// compare equal are yielded together as one mutable group. Group ends are
    /// found by binary search, so elements sharing a key are expected to be
    /// contiguous.
    pub fn new(slice: &'a mut [T], func: F) -> Self {
        let len = slice.len();
        let start = slice.as_mut_ptr();
        BinaryGroupByKeyMut {
            // SAFETY: `start + len` is the one-past-the-end pointer of `slice`.
            end: unsafe { start.add(len) },
            ptr: start,
            func,
            _phantom: marker::PhantomData,
        }
    }
}
impl<'a, T, F> BinaryGroupByKeyMut<'a, T, F> {
    /// Returns a shared view of the part of the original slice that the
    /// iterator has not yielded yet.
    pub fn remainder(&self) -> &[T] {
        let (start, count) = (self.ptr, self.remainder_len());
        // SAFETY: `start..start + count` is the not-yet-yielded sub-range, which
        // no mutable group handed out so far overlaps.
        unsafe { from_raw_parts(start, count) }
    }
}
impl<'a, T: 'a + fmt::Debug, F> fmt::Debug for BinaryGroupByKeyMut<'a, T, F> {
    /// Shows only the elements not yet yielded; the key function is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let remainder = self.remainder();
        let mut out = f.debug_struct("BinaryGroupByKeyMut");
        out.field("remainder", &remainder);
        out.finish()
    }
}
// Generate `is_empty`, `remainder_len` and the iterator trait impls for the
// mutable-slice variant, which yields `&'a mut [T]` groups.
binary_group_by_key!(struct BinaryGroupByKeyMut, &'a mut [T], from_raw_parts_mut);