]> git.nega.tv - josh/narcissus/commitdiff
Improve pool implementation
authorJoshua Simmons <josh@nega.tv>
Sat, 8 Oct 2022 11:03:44 +0000 (13:03 +0200)
committerJoshua Simmons <josh@nega.tv>
Sat, 8 Oct 2022 11:03:44 +0000 (13:03 +0200)
Move virtual memory failures into the `virtual_reserve` and
`virtual_commit` signatures so we can give better assert messages on
failure.

Handle growth more accurately, allowing us to use the entire table
capacity before asserting.

Reserve a bit in the generation counter to keep track of whether a slot
is full by incrementing the counter on both allocation and deallocation.
This means exported handles can only ever have an odd generation
counter.

Assert when given a handle that has an invalid generation counter
(where the counter implies it would be pointing to an empty slot).

Both above changes together means it's no longer possible to create a
reference to an uninitialized slot, even when manually messing with the
handle or mixing handles between different pools.

narcissus-core/src/pool.rs
narcissus-core/src/virtual_mem.rs
narcissus-core/src/virtual_vec/raw_vec.rs

index f5450bb4cead7284ef129168b3ee8dbef064e7b5..b04381f9a2a84c9796b047ee0680d7c9906d030d 100644 (file)
@@ -7,15 +7,19 @@ use crate::{
 /// Each handle uses `GEN_BITS` bits of per-slot generation counter. Looking up a handle with the
 /// correct index but an incorrect generation will yield `None`.
 const GEN_BITS: u32 = 9;
+
 /// Each handle uses `IDX_BITS` bits of index used to select a slot. This limits the maximum
 /// capacity of the table to `2 ^ IDX_BITS - 1`.
 const IDX_BITS: u32 = 23;
 
-const MAX_CAPACITY: usize = 1 << IDX_BITS as usize;
+const MAX_IDX: usize = 1 << IDX_BITS as usize;
+/// Reserve the last slot for the null handle.
+const MAX_CAP: usize = MAX_IDX - 1;
+
 const PAGE_SIZE: usize = 4096;
 
 /// Keep at least `MIN_FREE_SLOTS` available at all times in order to ensure a minimum of
-/// `MIN_FREE_SLOTS * 2 ^ GEN_BITS` create-delete cycles are required before a duplicate handle is
+/// `MIN_FREE_SLOTS * 2 ^ (GEN_BITS - 1)` create-delete cycles are required before a duplicate handle is
 /// generated.
 const MIN_FREE_SLOTS: usize = 512;
 
@@ -42,7 +46,13 @@ impl Default for Handle {
 
 impl Handle {
     /// Create a handle from the given encode_multiplier, generation counter and slot index.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the generation counter is even, as that would reference an empty slot.
     fn encode(encode_multiplier: u32, generation: u32, slot_index: SlotIndex) -> Self {
+        assert!(generation & 1 == 1);
+
         let value = (generation & GEN_MASK) << GEN_SHIFT | (slot_index.0 & IDX_MASK) << IDX_SHIFT;
         // Invert bits so that the all bits set, the null handle, becomes zero.
         let value = !value;
@@ -53,6 +63,10 @@ impl Handle {
     }
 
     /// Return a tuple containing the generation counter and slot index from an encoded handle and decode multiplier.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the generation counter is even, as that would reference an empty slot.
     fn decode(self, decode_multiplier: u32) -> (u32, SlotIndex) {
         let value = self.0;
         // Undo the bit mix from the encode step by multiplying by the multiplicative inverse of the encode_multiplier.
@@ -60,7 +74,12 @@ impl Handle {
         // Invert bits so zero, the null handle, becomes all bits set.
         let value = !value;
         let generation = (value >> GEN_SHIFT) & GEN_MASK;
-        let slot_index = SlotIndex(value >> IDX_SHIFT & IDX_MASK);
+        let slot_index = SlotIndex((value >> IDX_SHIFT) & IDX_MASK);
+
+        // An invalid generation counter here means either the handle itself has been corrupted, or that it's from
+        // another pool.
+        assert!(generation & 1 == 1, "invalid generation counter");
+
         (generation, slot_index)
     }
 
@@ -91,7 +110,16 @@ struct SlotIndex(u32);
 #[derive(Clone, Copy, PartialEq, Eq)]
 struct ValueIndex(u32);
 
+impl ValueIndex {
+    fn invalid() -> Self {
+        Self(!0)
+    }
+}
+
 /// Packed value storing the generation and value index for each slot in the indirection table.
+///
+/// The least-significant bit of the generation counter serves to indicate whether the slot is occupied. If it's 1,
+/// the slot contains a valid entry. If it's 0, the slot is invalid.
 struct Slot {
     value_index_and_gen: u32,
 }
@@ -99,7 +127,8 @@ struct Slot {
 impl Slot {
     const fn new() -> Self {
         Self {
-            value_index_and_gen: 0xffff_ffff,
+            // Clear the generation counter, but leave the index bits set.
+            value_index_and_gen: IDX_MASK,
         }
     }
 
@@ -113,16 +142,25 @@ impl Slot {
         ValueIndex((self.value_index_and_gen >> IDX_SHIFT) & IDX_MASK)
     }
 
-    /// Sets the slot's value index without modifying the generation.
-    fn set_value_index(&mut self, value_index: ValueIndex) {
+    /// Updates the slot's value index without modifying the generation.
+    fn update_value_index(&mut self, value_index: ValueIndex) {
         debug_assert!(value_index.0 & IDX_MASK == value_index.0);
         self.value_index_and_gen =
             self.generation() << GEN_SHIFT | (value_index.0 & IDX_MASK) << IDX_SHIFT;
     }
 
-    /// Clears the slot, resetting the value_mask to all bits set and incrementing the generation counter.
-    fn clear_value_index(&mut self) {
+    /// Sets the slot's value index, incrementing the generation counter.
+    fn set_value_index(&mut self, value_index: ValueIndex) {
         let new_generation = self.generation().wrapping_add(1);
+        self.value_index_and_gen =
+            (new_generation & GEN_MASK) << GEN_SHIFT | (value_index.0 & IDX_MASK) << IDX_SHIFT;
+    }
+
+    /// Clears the slot's value index, incrementing the generation counter.
+    fn clear_value_index(&mut self) {
+        // Since we're clearing we need to reset the generation to one referencing an empty slot. But we still want to
+        // invalidate old handles.
+        let new_generation = (self.generation() | 1).wrapping_add(1);
         self.value_index_and_gen = (new_generation & GEN_MASK) << GEN_SHIFT | IDX_MASK << IDX_SHIFT;
     }
 }
@@ -145,18 +183,27 @@ impl FreeSlots {
         }
     }
 
+    #[inline]
     fn head(&self) -> usize {
         self.head & (self.cap - 1)
     }
 
+    #[inline]
     fn tail(&self) -> usize {
         self.tail & (self.cap - 1)
     }
 
+    #[inline]
     fn len(&self) -> usize {
         self.head.wrapping_sub(self.tail)
     }
 
+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    #[inline]
     fn is_full(&self) -> bool {
         self.len() == self.cap
     }
@@ -172,8 +219,7 @@ impl FreeSlots {
     }
 
     fn pop(&mut self) -> Option<SlotIndex> {
-        // If we don't have enough free slots we need to add some more.
-        if self.len() < MIN_FREE_SLOTS {
+        if self.is_empty() {
             return None;
         }
         let tail = self.tail();
@@ -186,7 +232,7 @@ impl FreeSlots {
         // Free slots must always be a power of two so that the modular arithmetic for indexing
         // works out correctly.
         debug_assert!(self.cap == 0 || self.cap.is_power_of_two());
-        assert!(self.cap < MAX_CAPACITY);
+        assert!(self.cap <= MAX_IDX, "freelist overflow");
 
         let new_cap = if self.cap == 0 { 1024 } else { self.cap << 1 };
         unsafe {
@@ -208,6 +254,9 @@ impl FreeSlots {
     }
 }
 
+// Make sure the slots array always grows by a single page.
+const SLOT_GROWTH_AMOUNT: usize = PAGE_SIZE / std::mem::size_of::<Slot>();
+
 /// Indirection table mapping slot indices stored in handles to values in the values array.
 ///
 /// Also contains the generation counter for each slot.
@@ -239,11 +288,17 @@ impl Slots {
         }
     }
 
+    /// Attempts to grow the slots array.
+    ///
+    /// Returns a tuple containing the old len and new len, or None if the array was already at capacity.
     #[cold]
-    fn grow(&mut self) -> (u32, u32) {
+    fn try_grow(&mut self) -> Option<(u32, u32)> {
         let len = self.len;
-        let new_len = std::cmp::min(len + MIN_FREE_SLOTS * 2, MAX_CAPACITY);
-        assert!(new_len > len);
+        let new_len = std::cmp::min(len + SLOT_GROWTH_AMOUNT, MAX_CAP);
+        if new_len <= len {
+            return None;
+        }
+
         unsafe {
             virtual_commit(
                 self.ptr.as_ptr().add(len) as _,
@@ -253,8 +308,9 @@ impl Slots {
                 std::ptr::write(self.ptr.as_ptr().add(new_slot_index), Slot::new());
             }
         }
+
         self.len = new_len;
-        (len as u32, new_len as u32)
+        Some((len as u32, new_len as u32))
     }
 }
 
@@ -335,7 +391,7 @@ impl<T> Values<T> {
             slots
                 .get_mut(last_slot_index)
                 .unwrap()
-                .set_value_index(value_index);
+                .update_value_index(value_index);
         }
 
         let value_index = value_index.0 as usize;
@@ -374,27 +430,34 @@ impl<T> Values<T> {
         unsafe { ptr.add(value_index).as_mut().unwrap() }
     }
 
-    /// Expands the values area by a fixed amount. Commiting the previously reserved virtual memory.
+    /// Expands the values region by a fixed amount.
     #[cold]
     fn grow(&mut self) {
-        let new_cap = std::cmp::min(self.cap + 1024, MAX_CAPACITY);
-        assert!(new_cap > self.cap);
+        let new_cap = std::cmp::min(self.cap + 1024, MAX_CAP);
         let grow_region = new_cap - self.cap;
-        unsafe {
-            virtual_commit(
-                self.values_ptr.as_ptr().add(self.len) as _,
-                grow_region * size_of::<T>(),
-            );
-            virtual_commit(
-                self.slots_ptr.as_ptr().add(self.len) as _,
-                grow_region * size_of::<SlotIndex>(),
-            );
+
+        if grow_region > 0 {
+            unsafe {
+                virtual_commit(
+                    self.values_ptr.as_ptr().add(self.len) as _,
+                    grow_region * size_of::<T>(),
+                );
+                virtual_commit(
+                    self.slots_ptr.as_ptr().add(self.len) as _,
+                    grow_region * size_of::<SlotIndex>(),
+                );
+            }
         }
+
         self.cap = new_cap;
     }
 }
 
 /// A pool for allocating objects of type T and associating them with a POD `Handle`.
+///
+/// We do a basic attempt to ensure that mixing handles from different pools with either assert or return None. However
+/// it's possible that by accident lookup using a handle from another pool will return a valid object. The pool will
+/// not have memory unsafety in this case however, as it will only return valid objects from the pool.
 pub struct Pool<T> {
     encode_multiplier: u32,
     decode_multiplier: u32,
@@ -414,22 +477,22 @@ impl<T> Pool<T> {
         let mut mapping_size = 0;
 
         let free_slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<u32>();
+        mapping_size += MAX_IDX * size_of::<u32>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);
 
         let slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<Slot>();
+        mapping_size += MAX_IDX * size_of::<Slot>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);
 
         let value_slots_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<u32>();
+        mapping_size += MAX_CAP * size_of::<u32>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);
 
         let values_offset = mapping_size;
-        mapping_size += MAX_CAPACITY * size_of::<T>();
+        mapping_size += MAX_CAP * size_of::<T>();
         mapping_size = align_offset(mapping_size, PAGE_SIZE);
 
-        let mapping_base = virtual_reserve(mapping_size);
+        let mapping_base = virtual_reserve(mapping_size).expect("failed to map memory");
         let free_slots = unsafe { mapping_base.add(free_slots_offset) } as _;
         let slots = unsafe { mapping_base.add(slots_offset) } as _;
         let value_slots = unsafe { mapping_base.add(value_slots_offset) } as _;
@@ -476,43 +539,38 @@ impl<T> Pool<T> {
     }
 
     /// Inserts a value into the pool, returning a handle that represents it.
+    #[must_use]
     pub fn insert(&mut self, value: T) -> Handle {
         let value_index = self.values.push(value);
 
-        let slot_index = match self.free_slots.pop() {
-            Some(slot_index) => slot_index,
-            None => {
-                // We need to grow the slots array if there are insufficient free slots.
-                let (lo, hi) = self.slots.grow();
-                for free_slot_index in (lo + 1)..hi {
+        if self.free_slots.len() < MIN_FREE_SLOTS {
+            // We need to grow the slots array if there are insufficient free slots.
+            // This is a no-op if we're already at the max capacity of the pool, which weakens the use-after-free
+            // detection.
+            if let Some((lo, hi)) = self.slots.try_grow() {
+                for free_slot_index in lo..hi {
                     self.free_slots.push(SlotIndex(free_slot_index));
                 }
-                SlotIndex(lo)
             }
-        };
+        }
 
+        let slot_index = self.free_slots.pop().expect("pool capacity exceeded");
         self.values.set_slot(value_index, slot_index);
 
         let slot = self.slots.get_mut(slot_index).unwrap();
-        let generation = slot.generation();
         slot.set_value_index(value_index);
-
-        Handle::encode(self.encode_multiplier, generation, slot_index)
+        Handle::encode(self.encode_multiplier, slot.generation(), slot_index)
     }
 
     /// Removes a value from the pool, returning the value associated with the handle if it was previously valid.
     pub fn remove(&mut self, handle: Handle) -> Option<T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);
 
         if let Some(slot) = self.slots.get_mut(slot_index) {
             if slot.generation() == generation {
                 self.free_slots.push(slot_index);
                 let value_index = slot.value_index();
-                slot.clear_value_index();
+                slot.set_value_index(ValueIndex::invalid());
                 return Some(self.values.swap_remove(value_index, &mut self.slots));
             }
         }
@@ -522,10 +580,6 @@ impl<T> Pool<T> {
 
     /// Returns a mutable reference to the value corresponding to the handle.
     pub fn get_mut(&mut self, handle: Handle) -> Option<&mut T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);
 
         if let Some(slot) = self.slots.get(slot_index) {
@@ -539,10 +593,6 @@ impl<T> Pool<T> {
 
     /// Returns a reference to the value corresponding to the handle.
     pub fn get(&self, handle: Handle) -> Option<&T> {
-        if handle.is_null() {
-            return None;
-        }
-
         let (generation, slot_index) = handle.decode(self.decode_multiplier);
 
         if let Some(slot) = self.slots.get(slot_index) {
@@ -586,7 +636,7 @@ impl<T> Drop for Pool<T> {
                 self.values.len,
             );
             std::ptr::drop_in_place(to_drop);
-            virtual_free(self.mapping_base, self.mapping_size);
+            virtual_free(self.mapping_base, self.mapping_size).expect("failed to unmap memory");
         }
     }
 }
@@ -601,92 +651,74 @@ impl<T> Default for Pool<T> {
 mod tests {
     use std::sync::atomic::{AtomicU32, Ordering};
 
-    use super::{Handle, Pool, MAX_CAPACITY, MIN_FREE_SLOTS};
+    use super::{Handle, Pool, MAX_CAP};
+
+    #[test]
+    fn lookup_null() {
+        let mut pool = Pool::new();
+        assert!(pool.get(Handle::null()).is_none());
+        let _ = pool.insert(0);
+        assert!(pool.get(Handle::null()).is_none());
+    }
 
     #[test]
-    fn basics() {
+    fn insert_lookup_remove() {
         let mut pool = Pool::new();
         assert_eq!(pool.get(Handle::null()), None);
-        let one = pool.insert(1);
-        let two = pool.insert(2);
-        let three = pool.insert(3);
+
+        let handles: [Handle; 500] = std::array::from_fn(|i| pool.insert(i));
         for _ in 0..20 {
-            let handles = (0..300_000).map(|_| pool.insert(9)).collect::<Vec<_>>();
-            for handle in &handles {
-                assert_eq!(pool.remove(*handle), Some(9));
+            let handles = (0..300_000).map(|i| pool.insert(i)).collect::<Vec<_>>();
+            for (i, &handle) in handles.iter().enumerate() {
+                assert_eq!(pool.get(handle), Some(&i));
+                assert_eq!(pool.remove(handle), Some(i));
+                assert_eq!(pool.get(handle), None);
+                assert_eq!(pool.remove(handle), None);
             }
         }
-        assert_eq!(pool.get(one), Some(&1));
-        assert_eq!(pool.get(two), Some(&2));
-        assert_eq!(pool.get(three), Some(&3));
-        assert_eq!(pool.remove(one), Some(1));
-        assert_eq!(pool.remove(two), Some(2));
-        assert_eq!(pool.remove(three), Some(3));
-        assert_eq!(pool.remove(one), None);
-        assert_eq!(pool.remove(two), None);
-        assert_eq!(pool.remove(three), None);
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.get(handle), Some(&i));
+            assert_eq!(pool.remove(handle), Some(i));
+            assert_eq!(pool.get(handle), None);
+            assert_eq!(pool.remove(handle), None);
+        }
+
+        assert_eq!(pool.get(Handle::null()), None);
     }
 
     // This test is based on randomness in the base address of the pool so disable it by default to
     // avoid flaky tests in CI.
+    // We do a basic attempt to ensure that mixing handles from different pools with either assert or return None.
     #[test]
     #[ignore]
+    #[should_panic]
     fn test_pool_randomiser() {
         let mut pool_1 = Pool::new();
         let mut pool_2 = Pool::new();
-
         let handle_1 = pool_1.insert(1);
         let handle_2 = pool_2.insert(1);
         assert_ne!(handle_1, handle_2);
+        assert_eq!(pool_1.get(handle_2), None);
+        assert_eq!(pool_2.get(handle_1), None);
     }
 
-    // This test is based on randomness in the base address of the pool so disable it by default to
-    // avoid flaky tests in CI.
-    #[test]
-    #[ignore]
-    #[should_panic]
-    fn test_pool_randomiser_fail() {
-        let mut pool_1 = Pool::new();
-        let mut pool_2 = Pool::new();
-        let handle_1 = pool_1.insert(1);
-        let _handle_2 = pool_2.insert(1);
-        pool_2.get(handle_1).unwrap();
-    }
-
-    // Fills the entire pool which is slow in debug mode, so ignore this test.
     #[test]
-    #[ignore]
     fn capacity() {
-        #[derive(Clone, Copy)]
-        struct Chonk {
-            value: usize,
-            _pad: [u8; 4096 - std::mem::size_of::<usize>()],
-        }
+        let mut pool = Pool::new();
 
-        impl Chonk {
-            fn new(value: usize) -> Self {
-                Self {
-                    value,
-                    _pad: [0; 4096 - std::mem::size_of::<usize>()],
-                }
-            }
-        }
+        let handles = (0..MAX_CAP).map(|i| pool.insert(i)).collect::<Vec<_>>();
 
-        impl PartialEq for Chonk {
-            fn eq(&self, rhs: &Self) -> bool {
-                self.value == rhs.value
-            }
-        }
+        assert_eq!(pool.len(), MAX_CAP);
 
-        let mut pool = Pool::new();
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.get(handle), Some(&i));
+        }
 
-        for i in 0..MAX_CAPACITY - MIN_FREE_SLOTS {
-            let chonk = Chonk::new(i);
-            let handle = pool.insert(chonk);
-            assert!(pool.get(handle) == Some(&chonk));
+        for (i, &handle) in handles.iter().enumerate() {
+            assert_eq!(pool.remove(handle), Some(i));
         }
 
-        assert_eq!(pool.len(), MAX_CAPACITY - MIN_FREE_SLOTS);
+        assert!(pool.is_empty());
     }
 
     #[test]
index 8a515355c63087de8824ced7032a65d94d2e7c60..3ad3b3af62a17b941aeab4672b56eebf8992e19f 100644 (file)
@@ -5,13 +5,9 @@ use crate::libc;
 /// Size will be rounded up to align with the system's page size.
 ///
 /// The range is valid but inaccessible before calling `virtual_commit`.
-///
-/// # Panics
-///
-/// Panics if mapping fails.
 #[cold]
 #[inline(never)]
-pub fn virtual_reserve(size: usize) -> *mut std::ffi::c_void {
+pub fn virtual_reserve(size: usize) -> Result<*mut std::ffi::c_void, ()> {
     let ptr = unsafe {
         libc::mmap(
             std::ptr::null_mut(),
@@ -23,9 +19,11 @@ pub fn virtual_reserve(size: usize) -> *mut std::ffi::c_void {
         )
     };
 
-    assert!(ptr != libc::MAP_FAILED && !ptr.is_null());
-
-    ptr
+    if ptr == libc::MAP_FAILED || ptr.is_null() {
+        Err(())
+    } else {
+        Ok(ptr)
+    }
 }
 
 /// Commit (part of) a previously reserved memory range.
@@ -55,13 +53,13 @@ pub unsafe fn virtual_commit(ptr: *mut std::ffi::c_void, size: usize) {
 ///
 /// - Must point to an existing assignment created by [`virtual_reserve`].
 /// - `size` must be within range of that reservation.
-///
-/// # Panics
-///
-/// Panics if the range could not be unmapped.
 #[cold]
 #[inline(never)]
-pub unsafe fn virtual_free(ptr: *mut std::ffi::c_void, size: usize) {
+pub unsafe fn virtual_free(ptr: *mut std::ffi::c_void, size: usize) -> Result<(), ()> {
     let result = libc::munmap(ptr, size);
-    assert!(result == 0);
+    if result != 0 {
+        Err(())
+    } else {
+        Ok(())
+    }
 }
index 688cf0a1af42831a7d28fd489c78d3d41239d8d8..88352d7325745648b3638143064aa2976d4c2edc 100644 (file)
@@ -32,7 +32,8 @@ impl<T> VirtualRawVec<T> {
         // Check overflow of rounding operation.
         assert!(max_capacity_bytes <= (std::usize::MAX - (align - 1)));
 
-        let ptr = unsafe { NonNull::new_unchecked(virtual_reserve(max_capacity_bytes) as *mut T) };
+        let ptr = virtual_reserve(max_capacity_bytes).expect("mapping failed");
+        let ptr = unsafe { NonNull::new_unchecked(ptr as *mut T) };
 
         Self {
             ptr,
@@ -129,7 +130,8 @@ impl<T> Drop for VirtualRawVec<T> {
             virtual_free(
                 self.ptr.as_ptr() as *mut std::ffi::c_void,
                 self.max_cap * size_of::<T>(),
-            );
+            )
+            .expect("failed to unmap memory");
         }
     }
 }