[go: up one dir, main page]

cordic/
lib.rs

1//! Implementations of special functions based on the CORDIC algorithm.
2
3#![no_std]
4
5// Used to make the tests build and run.
6#[cfg(test)]
7#[macro_use]
8extern crate std;
9
10mod cordic_number;
11
12pub use cordic_number::CordicNumber;
13use fixed::types::U0F64;
14use core::convert::TryInto;
15
16const ATAN_TABLE: &[u8] = include_bytes!("tables/cordic_atan.table");
17const EXP_MINUS_ONE_TABLE: &[u8] = include_bytes!("tables/cordic_exp_minus_one.table");
18
19fn lookup_table(table: &[u8], index: u8) -> U0F64 {
20    let i = index as usize * 8;
21    U0F64::from_bits(u64::from_le_bytes(table[i..(i + 8)].try_into().unwrap()))
22}
23
24// See cordit1 from http://www.voidware.com/cordic.htm
25fn cordic_circular<T: CordicNumber>(mut x: T, mut y: T, mut z: T, vecmode: T) -> (T, T, T) {
26    let _0 = T::zero();
27    let _2 = T::one() + T::one();
28
29    for i in 0..T::num_fract_bits() {
30        if vecmode >= _0 && y < vecmode || vecmode < _0 && z >= _0 {
31            let x1 = x - (y >> i);
32            y = y + (x >> i);
33            x = x1;
34            z = z - T::from_u0f64(lookup_table(ATAN_TABLE, i));
35        } else {
36            let x1 = x + (y >> i);
37            y = y - (x >> i);
38            x = x1;
39            z = z + T::from_u0f64(lookup_table(ATAN_TABLE, i));
40        }
41    }
42
43    (x, y, z)
44}
45
46fn gain_cordic<T: CordicNumber>() -> T {
47    cordic_circular(T::one(), T::zero(), T::zero(), -T::one()).0
48}
49
50/// Compute simultaneously the sinus and cosine of the given fixed-point number.
51pub fn sin_cos<T: CordicNumber>(mut angle: T) -> (T, T) {
52    let mut negative = false;
53
54    while angle > T::frac_pi_2() {
55        angle -= T::pi();
56        negative = !negative;
57    }
58
59    while angle < -T::frac_pi_2() {
60        angle += T::pi();
61        negative = !negative;
62    }
63
64    let inv_gain = T::one() / gain_cordic(); // FIXME: precompute this.
65    let res = cordic_circular(inv_gain, T::zero(), angle, -T::one());
66
67    if negative {
68        (-res.1, -res.0)
69    } else {
70        (res.1, res.0)
71    }
72}
73
74/// Compute the sinus of the given fixed-point number.
75pub fn sin<T: CordicNumber>(angle: T) -> T {
76    sin_cos(angle).0
77}
78
79/// Compute the cosinus of the given fixed-point number.
80pub fn cos<T: CordicNumber>(angle: T) -> T {
81    sin_cos(angle).1
82}
83
84/// Compute the tangent of the given fixed-point number.
85pub fn tan<T: CordicNumber>(angle: T) -> T {
86    let (sin, cos) = sin_cos(angle);
87    sin / cos
88}
89
90/// Compute the arc-sinus of the given fixed-point number.
91pub fn asin<T: CordicNumber>(mut val: T) -> T {
92    // For asin, we use a double-rotation approach to reduce errors.
93    // NOTE: see https://stackoverflow.com/questions/25976656/cordic-arcsine-implementation-fails
94    // for details about the innacuracy of CORDIC for asin.
95
96    let mut theta = T::zero();
97    let mut z = (T::one(), T::zero());
98    let niter = T::num_fract_bits();
99
100    for j in 0..niter {
101        let sign_x = if z.0 < T::zero() { -T::one() } else { T::one() };
102        let sigma = if z.1 <= val { sign_x } else { -sign_x };
103        let rotate = |(x, y)| (x - ((y >> j) * sigma), y + ((x >> j) * sigma));
104        z = rotate(rotate(z));
105
106        let angle = T::from_u0f64(lookup_table(ATAN_TABLE, j));
107        theta = theta + ((angle + angle) * sigma);
108        val = val + (val >> (j + j));
109    }
110
111    theta
112}
113
114/// Compute the arc-cosine of the given fixed-point number.
115pub fn acos<T: CordicNumber>(val: T) -> T {
116    T::frac_pi_2() - asin(val)
117}
118
119/// Compute the arc-tangent of the given fixed-point number.
120pub fn atan<T: CordicNumber>(val: T) -> T {
121    cordic_circular(T::one(), val, T::zero(), T::zero()).2
122}
123
124/// Compute the arc-tangent of `y/x` with quadrant correction.
125pub fn atan2<T: CordicNumber>(y: T, x: T) -> T {
126    if x == T::zero() {
127        if y < T::zero() {
128            return -T::frac_pi_2();
129        } else {
130            return T::frac_pi_2();
131        }
132    }
133
134    if y == T::zero() {
135        if x >= T::zero() {
136            return T::zero();
137        } else {
138            return T::pi();
139        }
140    }
141
142    match (x < T::zero(), y < T::zero()) {
143        (false, false) => atan(y / x),
144        (false, true) => -atan(-y / x),
145        (true, false) => T::pi() - atan(y / -x),
146        (true, true) => atan(y / x) - T::pi(),
147    }
148}
149
150/// Compute the exponential root of the given fixed-point number.
151pub fn exp<T: CordicNumber>(x: T) -> T {
152    assert!(
153        T::num_fract_bits() <= 128,
154        "Exp is not supported for more than 128 decimals."
155    );
156    let _0 = T::zero();
157    let _1 = T::one();
158    let _3 = T::one() + T::one() + T::one();
159    let mut int_part = x.floor();
160    let mut dec_part = x - int_part;
161    let mut poweroftwo = T::half();
162    let mut w = [false; 128];
163
164    for i in 0..T::num_fract_bits() {
165        if poweroftwo < dec_part {
166            w[i as usize] = true;
167            dec_part -= poweroftwo;
168        }
169
170        poweroftwo = poweroftwo >> 1;
171    }
172
173    let mut fx = _1;
174
175    for i in 0..T::num_fract_bits() {
176        if w[i as usize] {
177            let ai = T::from_u0f64(lookup_table(EXP_MINUS_ONE_TABLE, i)) + T::one();
178            fx = fx * ai;
179        }
180    }
181
182    let f4 = _1 + (dec_part >> 2);
183    let f3 = _1 + (dec_part / _3) * f4;
184    let f2 = _1 + (dec_part >> 1) * f3;
185    let f1 = _1 + dec_part * f2;
186    fx = fx * f1;
187
188    if int_part < _0 {
189        while int_part != _0 {
190            fx = fx / T::e();
191            int_part += _1;
192        }
193    } else {
194        while int_part != _0 {
195            fx = fx * T::e();
196            int_part -= _1;
197        }
198    }
199
200    fx
201}
202
203/// Compute the square root of the given fixed-point number.
204pub fn sqrt<T: CordicNumber>(x: T) -> T {
205    if x == T::zero() || x == T::one() {
206        return x;
207    }
208
209    let mut pow2 = T::one();
210    let mut result;
211
212    if x < T::one() {
213        while x <= pow2 * pow2 {
214            pow2 = pow2 >> 1;
215        }
216
217        result = pow2;
218    } else {
219        // x >= T::one()
220        while pow2 * pow2 <= x {
221            pow2 = pow2 << 1;
222        }
223
224        result = pow2 >> 1;
225    }
226
227    for _ in 0..T::num_bits() {
228        pow2 = pow2 >> 1;
229        let next_result = result + pow2;
230        if next_result * next_result <= x {
231            result = next_result;
232        }
233    }
234
235    result
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use fixed::types::I48F16;

    /// Panic with a descriptive message when `computed` is further than `max_err`
    /// from `expected`. A NaN `expected` (e.g. `f64::asin` outside `[-1, 1]`) never
    /// panics, because `NaN > max_err` is false.
    fn assert_approx_eq<T: core::fmt::Display>(
        input: T,
        computed: f64,
        expected: f64,
        max_err: f64,
    ) {
        let err = (computed - expected).abs();
        if err > max_err {
            panic!(
                "mismatch for input {}: computed {}, expected {}",
                input, computed, expected
            );
        }
    }

    // Generates a coarse test over [-10, 10) in steps of 0.1 and an exhaustive test
    // over every I48F16 bit pattern in [0, 2^20) and its negation, comparing
    // `$trigf` against the `f64` method of the same name.
    macro_rules! test_trig(
        ($test: ident, $test_comprehensive: ident, $trigf: ident, $max_err: expr) => {
            #[test]
            fn $test() {
                for i in -100..100 {
                    let fx = f64::from(i) * 0.1_f64;
                    let x: I48F16 = I48F16::from_num(fx);
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }

            #[test]
            fn $test_comprehensive() {
                for i in 0..(1 << 20) {
                    let x = I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);

                    // Test negative numbers too.
                    let x = -I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }
        }
    );

    test_trig!(test_sin, test_sin_comprehensive, sin, 0.001);
    test_trig!(test_cos, test_cos_comprehensive, cos, 0.001);
    test_trig!(test_atan, test_atan_comprehensive, atan, 0.001);

    #[test]
    fn test_asin() {
        for i in 0..(1 << 17) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);

            // Test negative numbers too.
            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);
        }
    }

    #[test]
    fn test_acos() {
        for i in 0..(1 << 17) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);

            // Test negative numbers too.
            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);
        }
    }

    #[test]
    fn test_sqrt() {
        for i in 0..(1 << 20) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, sqrt(x).to_num(), fx.sqrt(), 0.01);
        }
    }

    #[test]
    fn test_exp() {
        for i in 0..(1 << 18) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, exp(x).to_num(), fx.exp(), 0.01);
        }
    }
}