}
}
+/// Returns `x` if `t` is `false` and `y` if `t` is `true`, avoiding branches.
+///
+/// When compiled for x86_64 with SSE4.1 enabled and AVX-512F disabled, the choice is made
+/// with the `blendvps` intrinsic. Every other configuration uses a plain `if`/`else` and
+/// leaves the lowering to the compiler.
+#[must_use]
+#[inline(always)]
+fn select_f32(x: f32, y: f32, t: bool) -> f32 {
+    // With avx512 the compiler tends to emit masked moves anyway, so don't bother being clever.
+    // The intrinsics used below live in `core::arch::x86_64`, so every other architecture
+    // (including 32-bit x86 built with SSE4.1) has to take this portable path as well.
+    #[cfg(not(all(
+        target_arch = "x86_64",
+        target_feature = "sse4.1",
+        not(target_feature = "avx512f")
+    )))]
+    {
+        if t {
+            y
+        } else {
+            x
+        }
+    }
+
+    // SAFETY: this block is only compiled when SSE4.1 (and with it the older SSE/SSE2
+    // instructions used here) is enabled at compile time, and the intrinsics only take
+    // by-value arguments, so no pointers are dereferenced.
+    #[cfg(all(
+        target_arch = "x86_64",
+        target_feature = "sse4.1",
+        not(target_feature = "avx512f")
+    ))]
+    unsafe {
+        use core::arch::x86_64::{
+            _mm_blendv_ps, _mm_castsi128_ps, _mm_cvtsi32_si128, _mm_cvtss_f32, _mm_set_ss,
+        };
+
+        // `blendvps` selects by the sign bit of the mask lane: 0 keeps `x`, -1 (all ones)
+        // takes `y`. Only the lowest lane is read back.
+        let mask = _mm_castsi128_ps(_mm_cvtsi32_si128(-i32::from(t)));
+        _mm_cvtss_f32(_mm_blendv_ps(_mm_set_ss(x), _mm_set_ss(y), mask))
+    }
+}
+
#[macro_export]
macro_rules! impl_shared {
($name:ty, $t:ty, $n:expr) => {
#[cfg(test)]
mod tests {
- use crate::{dequantize_unorm_u8, quantize_unorm_u8};
+ use crate::{dequantize_unorm_u8, quantize_unorm_u8, select_f32};
#[test]
fn quantize_dequantize() {
assert_eq!(dequantize_unorm_u8(255), 1.0);
assert_eq!(dequantize_unorm_u8(0), 0.0);
}
+
+    /// Checks that `select_f32` returns the first value for `false` and the second for
+    /// `true`, including for values where a bitwise select must not alter sign or class.
+    #[test]
+    fn select() {
+        // Ordinary finite values.
+        assert_eq!(select_f32(1.0, 2.0, true), 2.0);
+        assert_eq!(select_f32(1.0, 2.0, false), 1.0);
+
+        // The sign of zero must survive; `==` treats 0.0 and -0.0 as equal, so compare bits.
+        assert_eq!(select_f32(-0.0, 0.0, false).to_bits(), (-0.0_f32).to_bits());
+        assert_eq!(select_f32(-0.0, 0.0, true).to_bits(), 0.0_f32.to_bits());
+
+        // Non-finite values pass through unchanged.
+        assert_eq!(select_f32(f32::INFINITY, f32::NEG_INFINITY, true), f32::NEG_INFINITY);
+        assert_eq!(select_f32(f32::INFINITY, f32::NEG_INFINITY, false), f32::INFINITY);
+        assert!(select_f32(0.0, f32::NAN, true).is_nan());
+        assert!(!select_f32(0.0, f32::NAN, false).is_nan());
+    }
}
//
// Sollya code for generating these polynomials is in `doc/sincostan.sollya`
-use crate::f32_to_i32;
+use crate::{f32_to_i32, select_f32};
// constants for sin(pi x), cos(pi x) for x on [-1/4,1/4]
const F32_SIN_PI_7_K: [f32; 3] = unsafe {
])
};
-#[inline(always)]
-fn mulsign_f32(x: f32, s: u32) -> f32 {
- f32::from_bits(x.to_bits() ^ s)
-}
-
/// Simultaneously computes the sine and cosine of `a` expressed in multiples of
/// *pi* radians, or half-turns.
///
let i = f32_to_i32(r) as u32;
let r = r.mul_add(-0.5, a);
- let sx = (i >> 1) << 31;
- let sy = (i << 31) ^ sx;
-
- // Core approximation.
let r2 = r * r;
- let r = mulsign_f32(r, sy);
+ // Reconstruct signs early.
+ let sign_x = (i >> 1) << 31;
+ let sign_y = sign_x ^ i << 31;
+ let r_sign = r.copysign(f32::from_bits(r.to_bits() ^ sign_y));
+ let r2_sign = r2.copysign(f32::from_bits(r2.to_bits() ^ sign_x));
+ let one_sign = 1.0_f32.copysign(f32::from_bits(sign_x));
+
+ // Core approximation.
let c = C[3];
let c = c.mul_add(r2, C[2]);
let c = c.mul_add(r2, C[1]);
let c = c.mul_add(r2, C[0]);
- let c = c.mul_add(r2, 1.0);
- let c = mulsign_f32(c, sx);
+ let c = c.mul_add(r2_sign, one_sign);
let s = S[2];
let s = s.mul_add(r2, S[1]);
let s = s.mul_add(r2, S[0]);
- let s = r.mul_add(std::f32::consts::PI, r * r2.mul_add(s, -8.742278e-8));
+ let s = r_sign.mul_add(std::f32::consts::PI, r_sign * r2.mul_add(s, -8.742278e-8));
- let (s, c) = if i & 1 != 0 { (c, s) } else { (s, c) };
+ let t = s;
+ let s = select_f32(s, c, i & 1 != 0);
+ let c = select_f32(c, t, i & 1 != 0);
// IEEE-754: sin_pi(+n) is +0 and sin_pi(-n) is -0 for positive integers n
let s = if a == a.floor() { a * 0.0 } else { s };
#[cfg(test)]
mod tests {
- use crate::sin_cos_pi_f32;
+ use super::sin_cos_pi_f32;
#[test]
fn basics() {