1use crate::support::{
45 CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status, cold_path,
46};
47
48#[inline]
49pub fn sqrt<F>(x: F) -> F
50where
51 F: Float + SqrtHelper,
52 F::Int: HInt,
53 F::Int: From<u8>,
54 F::Int: From<F::ISet2>,
55 F::Int: CastInto<F::ISet1>,
56 F::Int: CastInto<F::ISet2>,
57 u32: CastInto<F::Int>,
58{
59 sqrt_round(x, Round::Nearest).val
60}
61
62#[inline]
63pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
64where
65 F: Float + SqrtHelper,
66 F::Int: HInt,
67 F::Int: From<u8>,
68 F::Int: From<F::ISet2>,
69 F::Int: CastInto<F::ISet1>,
70 F::Int: CastInto<F::ISet2>,
71 u32: CastInto<F::Int>,
72{
73 let zero = IntTy::<F>::ZERO;
74 let one = IntTy::<F>::ONE;
75
76 let mut ix = x.to_bits();
77
78 let noshift = F::BITS <= u32::BITS;
81 let (mut top, special_case) = if noshift {
82 let exp_lsb = one << F::SIG_BITS;
83 let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
84 (Exp::NoShift(()), special_case)
85 } else {
86 let top = u32::cast_from(ix >> F::SIG_BITS);
87 let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
88 (Exp::Shifted(top), special_case)
89 };
90
91 if special_case {
93 cold_path();
94
95 if ix << 1 == zero {
97 return FpResult::ok(x);
98 }
99
100 if ix == F::EXP_MASK {
102 return FpResult::ok(x);
103 }
104
105 if ix > F::EXP_MASK {
107 return FpResult::new(F::NAN, Status::INVALID);
108 }
109
110 let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
112 ix = scaled.to_bits();
113 match top {
114 Exp::Shifted(ref mut v) => {
115 *v = scaled.ex();
116 *v = (*v).wrapping_sub(F::SIG_BITS);
117 }
118 Exp::NoShift(()) => {
119 ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
120 }
121 }
122 }
123
124 let (m_u2, exp) = match top {
129 Exp::Shifted(top) => {
130 let mut e = top;
132 let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
134 let even = (e & 1) != 0;
135 if even {
136 m_u2 >>= 1;
137 }
138 e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
139 (m_u2, Exp::Shifted(e))
140 }
141 Exp::NoShift(()) => {
142 let even = ix & (one << F::SIG_BITS) != zero;
143
144 let mut e_noshift = ix >> 1;
146 e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
148 e_noshift &= F::EXP_MASK;
149
150 let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
151 let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
152 let m_u2 = if even { m0 } else { m1 };
153
154 (m_u2, Exp::NoShift(e_noshift))
155 }
156 };
157
158 let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;
160
161 let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
164 let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();
165
166 let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);
168
169 let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
171 let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
172 let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);
173
174 let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
176 let s_u2: F::Int = m_u2;
177 let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);
178
179 let mut m = s_u2 >> (F::EXP_BITS - 2);
181
182 let shift = 2 * F::SIG_BITS - (F::BITS - 2);
202
203 let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
205 let d1 = m.wrapping_sub(d0);
207 m += d1 >> (F::BITS - 1);
208 m &= F::SIG_MASK;
209
210 match exp {
211 Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
212 Exp::NoShift(e) => m |= e,
213 };
214
215 let mut y = F::from_bits(m);
216
217 if F::BITS > 16 {
219 let d2 = d1.wrapping_add(m).wrapping_add(one);
222 let mut tiny = if d2 == zero {
223 cold_path();
224 zero
225 } else {
226 F::IMPLICIT_BIT
227 };
228
229 tiny |= (d1 ^ d2) & F::SIGN_MASK;
230 let t = F::from_bits(tiny);
231 y = y + t;
232 }
233
234 FpResult::ok(y)
235}
236
237fn wmulh<I: HInt>(a: I, b: I) -> I {
239 a.widen_mul(b).hi()
240}
241
242#[inline]
253fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
254where
255 F: SqrtHelper,
256 I: HInt + From<u8>,
257{
258 let three_u2 = I::from(0b11u8) << (I::BITS - 2);
259 let mut u_u0 = r_u0;
260
261 for i in 0..count {
262 s_u2 = wmulh(s_u2, u_u0);
265
266 if i > 0 && (!final_set || i + 1 < count) {
275 s_u2 <<= 1;
276 }
277
278 let d_u2 = wmulh(s_u2, r_u0);
280 u_u0 = three_u2.wrapping_sub(d_u2);
281
282 r_u0 = wmulh(r_u0, u_u0) << 1;
284 }
285
286 (r_u0, s_u2)
287}
288
/// Where the result exponent lives while [`sqrt_round`] runs.
///
/// A small value-like enum, so it derives `Clone`/`Copy` (cheap to pass and match by value) and
/// `Debug` (for diagnostics).
#[derive(Clone, Copy, Debug)]
enum Exp<T> {
    /// The exponent has been extracted into an LSB-aligned `u32`.
    Shifted(u32),
    /// The exponent stays in its natural position within the float's bit pattern.
    NoShift(T),
}
297
298pub trait SqrtHelper: Float {
300 type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
302 type ISet2: HInt + From<Self::ISet1> + From<u8>;
304
305 const SET1_ROUNDS: u32 = 0;
307 const SET2_ROUNDS: u32 = 0;
309 const FINAL_ROUNDS: u32;
311}
312
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    // Only full-width rounds run. The staged integer types are placeholders.
    type ISet1 = u16;
    type ISet2 = u16;

    const SET1_ROUNDS: u32 = 0;
    const SET2_ROUNDS: u32 = 0;
    const FINAL_ROUNDS: u32 = 2;
}
320
321impl SqrtHelper for f32 {
322 type ISet1 = u32; type ISet2 = u32; const FINAL_ROUNDS: u32 = 3;
326}
327
328impl SqrtHelper for f64 {
329 type ISet1 = u32; type ISet2 = u32;
331
332 const SET2_ROUNDS: u32 = 2;
333 const FINAL_ROUNDS: u32 = 2;
334}
335
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    // One round at 32 bits, two at 64 bits, then two at the full 128-bit width.
    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;

    type ISet1 = u32;
    type ISet2 = u64;
}
345
#[rustfmt::skip]
/// Initial estimates of `1 / sqrt(m)` for `m` in `[1, 4)`, as 16-bit unsigned fixed point with
/// all bits fractional (entry / 2^16). They seed the Goldschmidt iterations in `sqrt_round`.
///
/// Indexed by the lowest bit of the biased exponent followed by the top 6 significand bits:
/// - entries `0..64` cover `m` in `[2, 4)` in steps of 1/32 (exponent LSB clear);
/// - entries `64..128` cover `m` in `[1, 2)` in steps of 1/64 (exponent LSB set).
///
/// Each entry approximates `1/sqrt` near the midpoint of its sub-interval, e.g.
/// `0xb451 / 2^16 ~= 0.7044 ~= 1/sqrt(2.015625)` and `0xff02 / 2^16 ~= 0.9961 ~= 1/sqrt(1.0078)`.
/// NOTE(review): presumably taken from musl's `__rsqrt_tab`; confirm before changing values.
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the special inputs of `sqrt_round`.
    /// - Inputs that must produce NaN do so with `Status::INVALID`.
    /// - Values that are their own square root (signed zeros, `+inf`) are returned bit-for-bit
    ///   with `Status::OK`.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Negative inputs and NaN.
        let nan = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];

        // Inputs that are their own square root.
        let roundtrip = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for x in nan {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert!(val.is_nan());
            assert!(status == Status::INVALID);
        }

        for x in roundtrip {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert_biteq!(val, x);
            assert!(status == Status::OK);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        assert_biteq!(sqrt(100.0f16), 10.0);
        assert_biteq!(sqrt(4.0f16), 2.0);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        let cases = [
            (f16::PI, 0x3f17_u16),
            // 10_000.0 -> 100.0
            (f16::from_bits(0x70e2), 0x5640_u16),
            // Subnormal input: exercises the normalization path.
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            // sqrt(2) = 0x1.6a0...p+0, rounded to 10 fraction bits.
            (2.0f16, 0x3da8_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            // `#06x`: "0x" plus 4 hex digits, the full width of an f16 bit pattern.
            assert_biteq!(
                sqrt(input),
                f16::from_bits(output),
                "input: {input:?} ({:#06x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f32() {
        assert_biteq!(sqrt(100.0f32), 10.0);
        assert_biteq!(sqrt(4.0f32), 2.0);
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            // Subnormal input: exercises the normalization path.
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (2.0f32, 0x3fb504f3_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            // `#010x`: "0x" plus 8 hex digits, the full width of an f32 bit pattern.
            assert_biteq!(
                sqrt(input),
                f32::from_bits(output),
                "input: {input:?} ({:#010x})",
                input.to_bits()
            );
        }
    }

    #[test]
    fn sanity_check_f64() {
        assert_biteq!(sqrt(100.0f64), 10.0);
        assert_biteq!(sqrt(4.0f64), 2.0);
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            // Subnormal input: exercises the normalization path.
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (2.0, 0x3ff6a09e667f3bcd_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            // `#018x`: "0x" plus 16 hex digits, the full width of an f64 bit pattern.
            assert_biteq!(
                sqrt(input),
                f64::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        assert_biteq!(sqrt(100.0f128), 10.0);
        assert_biteq!(sqrt(4.0f128), 2.0);
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            (
                // 10_000.0 -> 100.0
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                // Subnormal input: exercises the normalization path.
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (2.0f128, 0x3fff6a09e667f3bcc908b2fb1366ea95_u128),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            // `#034x`: "0x" plus 32 hex digits, the full width of an f128 bit pattern.
            assert_biteq!(
                sqrt(input),
                f128::from_bits(output),
                "input: {input:?} ({:#034x})",
                input.to_bits()
            );
        }
    }
}