use std::{alloc::Layout, cell::Cell, mem::MaybeUninit, ptr::NonNull};
-use crate::oom;
+use crate::{align_offset, oom};
#[derive(Debug)]
pub struct AllocError;
#[inline(always)]
fn new_stack(page: *mut PageFooter) -> PagePointer {
+ debug_assert!(page as usize & (std::mem::align_of::<PageFooter>() - 1) == 0);
PagePointer(((page as usize) | 0x1) as *mut PageFooter)
}
#[inline(always)]
fn new_heap(page: *mut PageFooter) -> PagePointer {
+ debug_assert!(page as usize & (std::mem::align_of::<PageFooter>() - 1) == 0);
PagePointer(page)
}
// Cannot wrap due to guard above.
let bump = bump.wrapping_sub(layout.size());
- let remainder = bump as usize & (layout.align() - 1);
- // Cannot have a remainder greater than the magnitude of the value, so this
- // cannot wrap.
- let bump = bump.wrapping_sub(remainder);
+ // Align down, mask so can't wrap.
+ let bump = (bump as usize & !(layout.align() - 1)) as *mut u8;
+
+ debug_assert!(bump as usize & (layout.align() - 1) == 0);
if bump >= base {
// Cannot be null because `base` cannot be null (derived from `NonNull<u8>`).
let new_page_size = new_page_size.max(PAGE_MIN_SIZE).min(PAGE_MAX_SIZE);
// Ensure that after all that, the given page is large enough to hold the thing
// we're trying to allocate.
- let new_page_size = new_page_size.max(layout.size() + layout.align() + PAGE_FOOTER_SIZE);
+ let new_page_size = new_page_size.max(layout.size() + (layout.align() - 1) + PAGE_FOOTER_SIZE);
+ // Round up to page footer alignment.
+ let new_page_size = align_offset(new_page_size, std::mem::align_of::<PageFooter>());
let size_without_footer = new_page_size - PAGE_FOOTER_SIZE;
debug_assert_ne!(size_without_footer, 0);
let layout = layout_from_size_align(new_page_size, std::mem::align_of::<PageFooter>());
let base_ptr = std::alloc::alloc(layout);
let base = NonNull::new(base_ptr)?;
- let bump = NonNull::new_unchecked(base_ptr.add(size_without_footer));
+ let bump = base_ptr.add(size_without_footer);
+ let bump = NonNull::new_unchecked(bump);
let footer = bump.as_ptr() as *mut PageFooter;
+
debug_assert_ne!(base, bump);
debug_assert!(base < bump);