From: Joshua Simmons
Date: Sat, 8 Oct 2022 11:03:44 +0000 (+0200)
Subject: Improve pool implementation
X-Git-Url: https://git.nega.tv//gitweb.cgi?a=commitdiff_plain;h=3ad81a357afe3c3292e13a050755ac5d6d3ef2e3;p=josh%2Fnarcissus

Improve pool implementation

Move virtual memory failures into the `virtual_reserve` and `virtual_free`
signatures so we can give better assert messages on failure.

Handle growth more accurately, allowing us to use the entire table capacity
before asserting.

Reserve a bit in the generation counter to keep track of whether a slot is
full by incrementing the counter on both allocation and deallocation. This
means exported handles can only ever have an odd generation counter.

Assert when given a handle that has an invalid generation counter (where the
counter implies it would be pointing to an empty slot).

Both of the above changes together mean it's no longer possible to create a
reference to an uninitialized slot, even when manually messing with the
handle or mixing handles between different pools.
---

diff --git a/narcissus-core/src/pool.rs b/narcissus-core/src/pool.rs
index f5450bb..b04381f 100644
--- a/narcissus-core/src/pool.rs
+++ b/narcissus-core/src/pool.rs
@@ -7,15 +7,19 @@ use crate::{
 /// Each handle uses `GEN_BITS` bits of per-slot generation counter. Looking up a handle with the
 /// correct index but an incorrect generation will yield `None`.
 const GEN_BITS: u32 = 9;
+
 /// Each handle uses `IDX_BITS` bits of index used to select a slot. This limits the maximum
 /// capacity of the table to `2 ^ IDX_BITS - 1`.
 const IDX_BITS: u32 = 23;
-const MAX_CAPACITY: usize = 1 << IDX_BITS as usize;
+const MAX_IDX: usize = 1 << IDX_BITS as usize;
+/// Reserve the last slot for the null handle.
+const MAX_CAP: usize = MAX_IDX - 1;
+
 const PAGE_SIZE: usize = 4096;

 /// Keep at least `MIN_FREE_SLOTS` available at all times in order to ensure a minimum of
-/// `MIN_FREE_SLOTS * 2 ^ GEN_BITS` create-delete cycles are required before a duplicate handle is
+/// `MIN_FREE_SLOTS * 2 ^ (GEN_BITS - 1)` create-delete cycles are required before a duplicate handle is
 /// generated.
 const MIN_FREE_SLOTS: usize = 512;
@@ -42,7 +46,13 @@ impl Default for Handle {
 impl Handle {
     /// Create a handle from the given encode_multiplier, generation counter and slot index.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the generation counter is even, as that would reference an empty slot.
     fn encode(encode_multiplier: u32, generation: u32, slot_index: SlotIndex) -> Self {
+        assert!(generation & 1 == 1);
+
         let value = (generation & GEN_MASK) << GEN_SHIFT | (slot_index.0 & IDX_MASK) << IDX_SHIFT;
         // Invert bits so that the all-bits-set pattern, the null handle, becomes zero.
         let value = !value;
@@ -53,6 +63,10 @@ impl Handle {
     }

     /// Return a tuple containing the generation counter and slot index from an encoded handle and decode multiplier.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the generation counter is even, as that would reference an empty slot.
     fn decode(self, decode_multiplier: u32) -> (u32, SlotIndex) {
         let value = self.0;
         // Undo the bit mix from the encode step by multiplying by the multiplicative inverse of the encode_multiplier.
@@ -60,7 +74,12 @@ impl Handle {
         // Invert bits so zero, the null handle, becomes all bits set.
         let value = !value;
         let generation = (value >> GEN_SHIFT) & GEN_MASK;
-        let slot_index = SlotIndex(value >> IDX_SHIFT & IDX_MASK);
+        let slot_index = SlotIndex((value >> IDX_SHIFT) & IDX_MASK);
+
+        // An invalid generation counter here means either the handle itself has been corrupted, or that it's from
+        // another pool.
+        assert!(generation & 1 == 1, "invalid generation counter");
+
         (generation, slot_index)
     }
@@ -91,7 +110,16 @@ struct SlotIndex(u32);
 #[derive(Clone, Copy, PartialEq, Eq)]
 struct ValueIndex(u32);

+impl ValueIndex {
+    fn invalid() -> Self {
+        Self(!0)
+    }
+}
+
 /// Packed value storing the generation and value index for each slot in the indirection table.
+///
+/// The least-significant bit of the generation counter serves to indicate whether the slot is occupied. If it's 1,
+/// the slot contains a valid entry. If it's 0, the slot is empty.
 struct Slot {
     value_index_and_gen: u32,
 }
@@ -99,7 +127,8 @@ struct Slot {
 impl Slot {
     const fn new() -> Self {
         Self {
-            value_index_and_gen: 0xffff_ffff,
+            // Clear the generation counter, but leave the index bits set.
+            value_index_and_gen: IDX_MASK,
         }
     }
@@ -113,16 +142,25 @@ impl Slot {
         ValueIndex((self.value_index_and_gen >> IDX_SHIFT) & IDX_MASK)
     }

-    /// Sets the slot's value index without modifying the generation.
-    fn set_value_index(&mut self, value_index: ValueIndex) {
+    /// Updates the slot's value index without modifying the generation.
+    fn update_value_index(&mut self, value_index: ValueIndex) {
         debug_assert!(value_index.0 & IDX_MASK == value_index.0);
         self.value_index_and_gen =
             self.generation() << GEN_SHIFT | (value_index.0 & IDX_MASK) << IDX_SHIFT;
     }

-    /// Clears the slot, resetting the value_mask to all bits set and incrementing the generation counter.
-    fn clear_value_index(&mut self) {
+    /// Sets the slot's value index, incrementing the generation counter.
+    fn set_value_index(&mut self, value_index: ValueIndex) {
         let new_generation = self.generation().wrapping_add(1);
+        self.value_index_and_gen =
+            (new_generation & GEN_MASK) << GEN_SHIFT | (value_index.0 & IDX_MASK) << IDX_SHIFT;
+    }
+
+    /// Clears the slot's value index, incrementing the generation counter.
+    fn clear_value_index(&mut self) {
+        // Since we're clearing we need to reset the generation to one referencing an empty slot. But we still want to
+        // invalidate old handles.
+        let new_generation = (self.generation() | 1).wrapping_add(1);
         self.value_index_and_gen =
             (new_generation & GEN_MASK) << GEN_SHIFT | IDX_MASK << IDX_SHIFT;
     }
 }
@@ -145,18 +183,27 @@ impl FreeSlots {
         }
     }

+    #[inline]
     fn head(&self) -> usize {
         self.head & (self.cap - 1)
     }

+    #[inline]
     fn tail(&self) -> usize {
         self.tail & (self.cap - 1)
     }

+    #[inline]
     fn len(&self) -> usize {
         self.head.wrapping_sub(self.tail)
     }

+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    #[inline]
     fn is_full(&self) -> bool {
         self.len() == self.cap
     }
@@ -172,8 +219,7 @@ impl FreeSlots {
     }

     fn pop(&mut self) -> Option<SlotIndex> {
-        // If we don't have enough free slots we need to add some more.
-        if self.len() < MIN_FREE_SLOTS {
+        if self.is_empty() {
             return None;
         }
         let tail = self.tail();
@@ -186,7 +232,7 @@ impl FreeSlots {
         // Free slots must always be a power of two so that the modular arithmetic for indexing
         // works out correctly.
         debug_assert!(self.cap == 0 || self.cap.is_power_of_two());
-        assert!(self.cap < MAX_CAPACITY);
+        assert!(self.cap <= MAX_IDX, "freelist overflow");
         let new_cap = if self.cap == 0 { 1024 } else { self.cap << 1 };

         unsafe {
@@ -208,6 +254,9 @@ impl FreeSlots {
     }
 }

+// Make sure the slots array always grows by a single page.
+const SLOT_GROWTH_AMOUNT: usize = PAGE_SIZE / std::mem::size_of::<Slot>();
+
 /// Indirection table mapping slot indices stored in handles to values in the values array.
 ///
 /// Also contains the generation counter for each slot.
@@ -239,11 +288,17 @@ impl Slots {
         }
     }

+    /// Attempts to grow the slots array.
+    ///
+    /// Returns a tuple containing the old len and new len, or None if the array was already at capacity.
     #[cold]
-    fn grow(&mut self) -> (u32, u32) {
+    fn try_grow(&mut self) -> Option<(u32, u32)> {
         let len = self.len;
-        let new_len = std::cmp::min(len + MIN_FREE_SLOTS * 2, MAX_CAPACITY);
-        assert!(new_len > len);
+        let new_len = std::cmp::min(len + SLOT_GROWTH_AMOUNT, MAX_CAP);
+        if new_len <= len {
+            return None;
+        }
+
         unsafe {
             virtual_commit(
                 self.ptr.as_ptr().add(len) as _,
@@ -253,8 +308,9 @@ impl Slots {
                 std::ptr::write(self.ptr.as_ptr().add(new_slot_index), Slot::new());
             }
         }
+
         self.len = new_len;
-        (len as u32, new_len as u32)
+        Some((len as u32, new_len as u32))
     }
 }
@@ -335,7 +391,7 @@ impl<T> Values<T> {
             slots
                 .get_mut(last_slot_index)
                 .unwrap()
-                .set_value_index(value_index);
+                .update_value_index(value_index);
         }

         let value_index = value_index.0 as usize;
@@ -374,27 +430,34 @@ impl<T> Values<T> {
         unsafe { ptr.add(value_index).as_mut().unwrap() }
     }

-    /// Expands the values area by a fixed amount. Commiting the previously reserved virtual memory.
+    /// Expands the values region by a fixed amount, committing the previously reserved virtual memory.
     #[cold]
     fn grow(&mut self) {
-        let new_cap = std::cmp::min(self.cap + 1024, MAX_CAPACITY);
-        assert!(new_cap > self.cap);
+        let new_cap = std::cmp::min(self.cap + 1024, MAX_CAP);
         let grow_region = new_cap - self.cap;
-        unsafe {
-            virtual_commit(
-                self.values_ptr.as_ptr().add(self.len) as _,
-                grow_region * size_of::<T>(),
-            );
-            virtual_commit(
-                self.slots_ptr.as_ptr().add(self.len) as _,
-                grow_region * size_of::<SlotIndex>(),
-            );
+
+        if grow_region > 0 {
+            unsafe {
+                virtual_commit(
+                    self.values_ptr.as_ptr().add(self.len) as _,
+                    grow_region * size_of::<T>(),
+                );
+                virtual_commit(
+                    self.slots_ptr.as_ptr().add(self.len) as _,
+                    grow_region * size_of::<SlotIndex>(),
+                );
+            }
         }
+
         self.cap = new_cap;
     }
 }

 /// A pool for allocating objects of type T and associating them with a POD `Handle`.
+///
+/// We make a basic attempt to ensure that mixing handles from different pools will either assert or return `None`.
+/// However, it's possible that a lookup using a handle from another pool accidentally returns a valid object. Even
+/// then the pool has no memory unsafety, as it will only ever return valid objects that it actually contains.
 pub struct Pool<T> {
     encode_multiplier: u32,
     decode_multiplier: u32,
@@ -414,22 +477,22 @@ impl<T> Pool<T> {
         let mut mapping_size = 0;

         let free_slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<SlotIndex>();
+        mapping_size += MAX_IDX * size_of::<SlotIndex>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);

         let slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<Slot>();
+        mapping_size += MAX_IDX * size_of::<Slot>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);

         let value_slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<SlotIndex>();
+        mapping_size += MAX_CAP * size_of::<SlotIndex>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);

         let values_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<T>();
+        mapping_size += MAX_CAP * size_of::<T>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);

-        let mapping_base = virtual_reserve(mapping_size);
+        let mapping_base = virtual_reserve(mapping_size).expect("failed to map memory");
         let free_slots = unsafe { mapping_base.add(free_slots_offset) } as _;
         let slots = unsafe { mapping_base.add(slots_offset) } as _;
         let value_slots = unsafe { mapping_base.add(value_slots_offset) } as _;
@@ -476,43 +539,38 @@ impl<T> Pool<T> {
     }

     /// Inserts a value into the pool, returning a handle that represents it.
+    #[must_use]
     pub fn insert(&mut self, value: T) -> Handle {
         let value_index = self.values.push(value);

-        let slot_index = match self.free_slots.pop() {
-            Some(slot_index) => slot_index,
-            None => {
-                // We need to grow the slots array if there are insufficient free slots.
-                let (lo, hi) = self.slots.grow();
-                for free_slot_index in (lo + 1)..hi {
+        if self.free_slots.len() < MIN_FREE_SLOTS {
+            // We need to grow the slots array if there are insufficient free slots.
+            // This is a no-op if we're already at the max capacity of the pool, which weakens the use-after-free
+            // detection.
+            if let Some((lo, hi)) = self.slots.try_grow() {
+                for free_slot_index in lo..hi {
                     self.free_slots.push(SlotIndex(free_slot_index));
                 }
-                SlotIndex(lo)
             }
-        };
+        }
+
+        let slot_index = self.free_slots.pop().expect("pool capacity exceeded");

         self.values.set_slot(value_index, slot_index);

         let slot = self.slots.get_mut(slot_index).unwrap();
-        let generation = slot.generation();
         slot.set_value_index(value_index);
-
-        Handle::encode(self.encode_multiplier, generation, slot_index)
+        Handle::encode(self.encode_multiplier, slot.generation(), slot_index)
     }

     /// Removes a value from the pool, returning the value associated with the handle if it was previously valid.
     pub fn remove(&mut self, handle: Handle) -> Option<T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);
         if let Some(slot) = self.slots.get_mut(slot_index) {
             if slot.generation() == generation {
                 self.free_slots.push(slot_index);
                 let value_index = slot.value_index();
-                slot.clear_value_index();
+                slot.set_value_index(ValueIndex::invalid());
                 return Some(self.values.swap_remove(value_index, &mut self.slots));
             }
         }
@@ -522,10 +580,6 @@ impl<T> Pool<T> {

     /// Returns a mutable reference to the value corresponding to the handle.
     pub fn get_mut(&mut self, handle: Handle) -> Option<&mut T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);

         if let Some(slot) = self.slots.get(slot_index) {
@@ -539,10 +593,6 @@ impl<T> Pool<T> {

     /// Returns a reference to the value corresponding to the handle.
     pub fn get(&self, handle: Handle) -> Option<&T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);

         if let Some(slot) = self.slots.get(slot_index) {
@@ -586,7 +636,7 @@ impl<T> Drop for Pool<T> {
                 self.values.len,
             );
             std::ptr::drop_in_place(to_drop);
-            virtual_free(self.mapping_base, self.mapping_size);
+            virtual_free(self.mapping_base, self.mapping_size).expect("failed to unmap memory");
         }
     }
 }
@@ -601,92 +651,74 @@ impl<T> Default for Pool<T> {
 mod tests {
     use std::sync::atomic::{AtomicU32, Ordering};

-    use super::{Handle, Pool, MAX_CAPACITY, MIN_FREE_SLOTS};
+    use super::{Handle, Pool, MAX_CAP};
+
+    #[test]
+    fn lookup_null() {
+        let mut pool = Pool::new();
+        assert!(pool.get(Handle::null()).is_none());
+        let _ = pool.insert(0);
+        assert!(pool.get(Handle::null()).is_none());
+    }

     #[test]
-    fn basics() {
+    fn insert_lookup_remove() {
         let mut pool = Pool::new();
         assert_eq!(pool.get(Handle::null()), None);
-        let one = pool.insert(1);
-        let two = pool.insert(2);
-        let three = pool.insert(3);
+
+        let handles: [Handle; 500] = std::array::from_fn(|i| pool.insert(i));

         for _ in 0..20 {
-            let handles = (0..300_000).map(|_| pool.insert(9)).collect::<Vec<_>>();
-            for handle in &handles {
-                assert_eq!(pool.remove(*handle), Some(9));
+            let handles = (0..300_000).map(|i| pool.insert(i)).collect::<Vec<_>>();
+            for (i, &handle) in handles.iter().enumerate() {
+                assert_eq!(pool.get(handle), Some(&i));
+                assert_eq!(pool.remove(handle), Some(i));
+                assert_eq!(pool.get(handle), None);
+                assert_eq!(pool.remove(handle), None);
             }
         }

-        assert_eq!(pool.get(one), Some(&1));
-        assert_eq!(pool.get(two), Some(&2));
-        assert_eq!(pool.get(three), Some(&3));
-        assert_eq!(pool.remove(one), Some(1));
-        assert_eq!(pool.remove(two), Some(2));
-        assert_eq!(pool.remove(three), Some(3));
-        assert_eq!(pool.remove(one), None);
-        assert_eq!(pool.remove(two), None);
-        assert_eq!(pool.remove(three), None);
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.get(handle), Some(&i));
+            assert_eq!(pool.remove(handle), Some(i));
+            assert_eq!(pool.get(handle), None);
+            assert_eq!(pool.remove(handle), None);
+        }
+
+        assert_eq!(pool.get(Handle::null()), None);
     }

     // This test is based on randomness in the base address of the pool so disable it by default to
     // avoid flaky tests in CI.
+    // We make a basic attempt to ensure that mixing handles from different pools will either assert or return None.
     #[test]
    #[ignore]
+    #[should_panic]
     fn test_pool_randomiser() {
         let mut pool_1 = Pool::new();
         let mut pool_2 = Pool::new();
-
         let handle_1 = pool_1.insert(1);
         let handle_2 = pool_2.insert(1);
         assert_ne!(handle_1, handle_2);
+        assert_eq!(pool_1.get(handle_2), None);
+        assert_eq!(pool_2.get(handle_1), None);
     }

-    // This test is based on randomness in the base address of the pool so disable it by default to
-    // avoid flaky tests in CI.
-    #[test]
-    #[ignore]
-    #[should_panic]
-    fn test_pool_randomiser_fail() {
-        let mut pool_1 = Pool::new();
-        let mut pool_2 = Pool::new();
-        let handle_1 = pool_1.insert(1);
-        let _handle_2 = pool_2.insert(1);
-        pool_2.get(handle_1).unwrap();
-    }
-
-    // Fills the entire pool which is slow in debug mode, so ignore this test.
     #[test]
-    #[ignore]
     fn capacity() {
-        #[derive(Clone, Copy)]
-        struct Chonk {
-            value: usize,
-            _pad: [u8; 4096 - std::mem::size_of::<usize>()],
-        }
+        let mut pool = Pool::new();

-        impl Chonk {
-            fn new(value: usize) -> Self {
-                Self {
-                    value,
-                    _pad: [0; 4096 - std::mem::size_of::<usize>()],
-                }
-            }
-        }
+        let handles = (0..MAX_CAP).map(|i| pool.insert(i)).collect::<Vec<_>>();

-        impl PartialEq for Chonk {
-            fn eq(&self, rhs: &Self) -> bool {
-                self.value == rhs.value
-            }
-        }
+        assert_eq!(pool.len(), MAX_CAP);

-        let mut pool = Pool::new();
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.get(handle), Some(&i));
+        }

-        for i in 0..MAX_CAPACITY - MIN_FREE_SLOTS {
-            let chonk = Chonk::new(i);
-            let handle = pool.insert(chonk);
-            assert!(pool.get(handle) == Some(&chonk));
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.remove(handle), Some(i));
         }

-        assert_eq!(pool.len(), MAX_CAPACITY - MIN_FREE_SLOTS);
+        assert!(pool.is_empty());
     }

     #[test]
diff --git a/narcissus-core/src/virtual_mem.rs b/narcissus-core/src/virtual_mem.rs
index 8a51535..3ad3b3a 100644
--- a/narcissus-core/src/virtual_mem.rs
+++ b/narcissus-core/src/virtual_mem.rs
@@ -5,13 +5,9 @@ use crate::libc;
 /// Size will be rounded up to align with the system's page size.
 ///
 /// The range is valid but inaccessible before calling `virtual_commit`.
-///
-/// # Panics
-///
-/// Panics if mapping fails.
 #[cold]
 #[inline(never)]
-pub fn virtual_reserve(size: usize) -> *mut std::ffi::c_void {
+pub fn virtual_reserve(size: usize) -> Result<*mut std::ffi::c_void, ()> {
     let ptr = unsafe {
         libc::mmap(
             std::ptr::null_mut(),
@@ -23,9 +19,11 @@ pub fn virtual_reserve(size: usize) -> *mut std::ffi::c_void {
         )
     };

-    assert!(ptr != libc::MAP_FAILED && !ptr.is_null());
-
-    ptr
+    if ptr == libc::MAP_FAILED || ptr.is_null() {
+        Err(())
+    } else {
+        Ok(ptr)
+    }
 }

 /// Commit (part of) a previously reserved memory range.
@@ -55,13 +53,13 @@ pub unsafe fn virtual_commit(ptr: *mut std::ffi::c_void, size: usize) {
 ///
 /// - Must point to an existing reservation created by [`virtual_reserve`].
 /// - `size` must be within range of that reservation.
-///
-/// # Panics
-///
-/// Panics if the range could not be unmapped.
 #[cold]
 #[inline(never)]
-pub unsafe fn virtual_free(ptr: *mut std::ffi::c_void, size: usize) {
+pub unsafe fn virtual_free(ptr: *mut std::ffi::c_void, size: usize) -> Result<(), ()> {
     let result = libc::munmap(ptr, size);
-    assert!(result == 0);
+    if result != 0 {
+        Err(())
+    } else {
+        Ok(())
+    }
 }
diff --git a/narcissus-core/src/virtual_vec/raw_vec.rs b/narcissus-core/src/virtual_vec/raw_vec.rs
index 688cf0a..88352d7 100644
--- a/narcissus-core/src/virtual_vec/raw_vec.rs
+++ b/narcissus-core/src/virtual_vec/raw_vec.rs
@@ -32,7 +32,8 @@ impl<T> VirtualRawVec<T> {
         // Check overflow of rounding operation.
         assert!(max_capacity_bytes <= (std::usize::MAX - (align - 1)));

-        let ptr = unsafe { NonNull::new_unchecked(virtual_reserve(max_capacity_bytes) as *mut T) };
+        let ptr = virtual_reserve(max_capacity_bytes).expect("mapping failed");
+        let ptr = unsafe { NonNull::new_unchecked(ptr as *mut T) };

         Self {
             ptr,
@@ -129,7 +130,8 @@ impl<T> Drop for VirtualRawVec<T> {
             virtual_free(
                 self.ptr.as_ptr() as *mut std::ffi::c_void,
                 self.max_cap * size_of::<T>(),
-            );
+            )
+            .expect("failed to unmap memory");
         }
     }
 }
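
For reference, here is the parity scheme from the patch in isolation. This is a minimal
standalone sketch, assuming the same GEN_BITS/IDX_BITS split as above but with a fixed
bit layout and no handle bit-mixing; `occupy` and `clear` are illustrative stand-ins
for the patch's `set_value_index` and `clear_value_index`, not the pool's actual API:

    const GEN_BITS: u32 = 9;
    const IDX_BITS: u32 = 23;
    const GEN_MASK: u32 = (1 << GEN_BITS) - 1;
    const IDX_MASK: u32 = (1 << IDX_BITS) - 1;

    struct Slot {
        value_index_and_gen: u32,
    }

    impl Slot {
        // An empty slot starts at generation 0; even generations always mean "empty".
        const fn new() -> Self {
            Self { value_index_and_gen: IDX_MASK }
        }

        fn generation(&self) -> u32 {
            (self.value_index_and_gen >> IDX_BITS) & GEN_MASK
        }

        // Allocation bumps an even generation to an odd one.
        fn occupy(&mut self, value_index: u32) {
            let gen = self.generation().wrapping_add(1) & GEN_MASK;
            self.value_index_and_gen = gen << IDX_BITS | (value_index & IDX_MASK);
        }

        // Deallocation bumps the odd generation back to an even one, so every
        // previously exported (odd) generation for this slot stops matching.
        fn clear(&mut self) {
            let gen = (self.generation() | 1).wrapping_add(1) & GEN_MASK;
            self.value_index_and_gen = gen << IDX_BITS | IDX_MASK;
        }
    }

    fn main() {
        let mut slot = Slot::new();
        assert_eq!(slot.generation() & 1, 0); // even: no handle can reference this
        slot.occupy(42);
        assert_eq!(slot.generation(), 1); // odd: what exported handles carry
        slot.clear();
        assert_eq!(slot.generation(), 2); // even again: the old handle is stale
    }

Since exported handles only ever carry odd generations and empty slots only ever hold
even ones, a lookup can never alias an empty slot, no matter how the handle was forged.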
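
The `encode`/`decode` bit mix sits on top of this: multiply the packed value by an odd
per-pool constant, and undo it by multiplying by that constant's multiplicative inverse
modulo 2^32. A sketch of the round trip follows, with an arbitrary example multiplier
standing in for whatever a real pool derives at construction:

    // Multiplicative inverse of an odd u32 modulo 2^32 via Newton's iteration:
    // x = a is already correct in the low 3 bits, and each step doubles the
    // number of correct bits, so four steps cover all 32.
    fn inverse_mod_2_32(a: u32) -> u32 {
        assert!(a & 1 == 1, "only odd multipliers are invertible");
        let mut x = a;
        for _ in 0..4 {
            x = x.wrapping_mul(2u32.wrapping_sub(a.wrapping_mul(x)));
        }
        x
    }

    fn main() {
        let encode_multiplier: u32 = 0x9e37_79b9; // arbitrary odd example value
        let decode_multiplier = inverse_mod_2_32(encode_multiplier);
        assert_eq!(encode_multiplier.wrapping_mul(decode_multiplier), 1);

        let packed: u32 = 0x0123_4567; // stand-in for the generation/index bits
        let mixed = packed.wrapping_mul(encode_multiplier);
        assert_eq!(mixed.wrapping_mul(decode_multiplier), packed);
    }

Multiplication by an odd constant is a bijection on u32, so the mix loses no
information; it only makes handles from pools with different multipliers mutually
incompatible in practice, which is what the randomiser tests above exercise.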