1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
#include "ocean/math/Math.h"
#include "ocean/math/Numeric.h"

#include "ocean/base/Utilities.h"

#include <cmath>
#include <complex>
18 namespace Ocean
19 {
// Forward declaration.
template <typename T> class EquationT;

/**
 * Definition of the Equation class, depending on the OCEAN_MATH_USE_SINGLE_PRECISION define either with single or double precision float data type.
 * @see EquationT
 * @ingroup math
 */
#ifdef OCEAN_MATH_USE_SINGLE_PRECISION
typedef EquationT<float> Equation;
#else
typedef EquationT<double> Equation;
#endif

/**
 * Definition of the Equation class using double values
 * @see EquationT
 * @ingroup math
 */
typedef EquationT<double> EquationD;

/**
 * Definition of the Equation class using float values
 * @see EquationT
 * @ingroup math
 */
typedef EquationT<float> EquationF;
/**
 * This class provides several functions to solve polynomial equations of different degree using floating point values with the precision specified by type T.
 * All functions determine real solutions only.
 * @tparam T Type of passed floating point values
 * @see Equation, EquationF, EquationD.
 * @ingroup math
 */
template <typename T>
class EquationT
{
	public:

		/**
		 * Solves a linear equation with the form:<br>
		 * ax + b = 0
		 * @param a A parameter, with range (-infinity, infinity) \ {0}, (must not be 0)
		 * @param b B parameter, with range (-infinity, infinity)
		 * @param x Resulting solution, not modified if the function fails
		 * @return True, if succeeded; False, if 'a' is zero (within epsilon)
		 */
		static bool solveLinear(const T a, const T b, T& x);

		/**
		 * Solves a quadratic equation with the form:<br>
		 * ax^2 + bx + c = 0
		 * Both solutions are identical in case of a double root.
		 * @param a A parameter, with range (-infinity, infinity) \ {0}, (must not be 0)
		 * @param b B parameter, with range (-infinity, infinity)
		 * @param c C parameter, with range (-infinity, infinity)
		 * @param x1 First resulting solution, with range (-infinity, infinity), not modified if the function fails
		 * @param x2 Second resulting solution, with range (-infinity, infinity), not modified if the function fails
		 * @return True, if succeeded; False, if 'a' is zero (within epsilon) or if the discriminant b^2 - 4ac is negative so that no real solution exists
		 */
		static bool solveQuadratic(const T a, const T b, const T c, T& x1, T& x2);

		/**
		 * Solves a cubic equation with the form:<br>
		 * ax^3 + bx^2 + cx + d = 0
		 * The solutions are written to x1, x2, x3 in this order; output parameters beyond the returned number of solutions are not modified.
		 * @param a A parameter, with range (-infinity, infinity) \ {0}, (must not be 0)
		 * @param b B parameter, with range (-infinity, infinity)
		 * @param c C parameter, with range (-infinity, infinity)
		 * @param d D parameter, with range (-infinity, infinity)
		 * @param x1 First resulting solution, with range (-infinity, infinity)
		 * @param x2 Second resulting solution, with range (-infinity, infinity), set only if three solutions are returned
		 * @param x3 Third resulting solution, with range (-infinity, infinity), set only if three solutions are returned
		 * @return Number of real solutions: 0 if 'a' is zero (within epsilon), otherwise 1 or 3
		 */
		static unsigned int solveCubic(const T a, const T b, const T c, const T d, T& x1, T& x2, T& x3);

		/**
		 * Solves a quartic equation with the form:<br>
		 * ax^4 + bx^3 + cx^2 + dx + e = 0
		 * Only candidates which are real (within epsilon) and which satisfy the equation (within weak epsilon) are returned; repeated roots may appear several times.
		 * @param a A parameter, with range (-infinity, infinity) \ {0}, (must not be 0)
		 * @param b B parameter, with range (-infinity, infinity)
		 * @param c C parameter, with range (-infinity, infinity)
		 * @param d D parameter, with range (-infinity, infinity)
		 * @param e E parameter, with range (-infinity, infinity)
		 * @param x Array with at least four scalar values receiving the (at most) four solutions, must be valid; only the first (returned number) elements are written
		 * @return Number of real solutions, with range [0, 4]
		 */
		static unsigned int solveQuartic(const T a, const T b, const T c, const T d, const T e, T* x);
};
106 template <typename T>
107 bool EquationT<T>::solveLinear(const T a, const T b, T& x)
108 {
109  ocean_assert(NumericT<T>::isNotEqualEps(a));
111  // ax + b = 0
114  {
115  return false;
116  }
118  x = -b / a;
119  return true;
120 }
122 template <typename T>
123 bool EquationT<T>::solveQuadratic(const T a, const T b, const T c, T& x1, T& x2)
124 {
125  ocean_assert(NumericT<T>::isNotEqualEps(a));
127  // ax^2 + bx + c = 0
128  // see Numerical Recipes in C++
131  {
132  return false;
133  }
135  const T value = b * b - 4 * a * c;
136  if (!NumericT<T>::isAbove(value, 0))
137  {
138  return false;
139  }
141  const T q = T(-0.5) * (b + ((value > T(0)) ? NumericT<T>::copySign(NumericT<T>::sqrt(value), b) : T(0)));
144  {
145  x1 = 0;
146  x2 = 0;
147  return true;
148  }
150  x1 = q / a;
151  x2 = c / q;
153 #ifdef OCEAN_DEBUG
154  if (std::is_same<T, float>::value)
155  {
156  // for 32 bit float values we have to weaken the zero accuracy
157  ocean_assert(NumericT<T>::isEqual(a * x1 * x1 + b * x1 + c, T(0), NumericT<T>::weakEps() * NumericT<T>::abs(x1)));
158  ocean_assert(NumericT<T>::isEqual(a * x2 * x2 + b * x2 + c, T(0), NumericT<T>::weakEps() * NumericT<T>::abs(x2)));
159  }
160  else
161  {
162  ocean_assert(NumericT<T>::isWeakEqualEps(a * x1 * x1 + b * x1 + c));
163  ocean_assert(NumericT<T>::isWeakEqualEps(a * x2 * x2 + b * x2 + c));
164  }
165 #endif
167  return true;
168 }
170 template <typename T>
171 unsigned int EquationT<T>::solveCubic(const T a, const T b, const T c, const T d, T& x1, T& x2, T& x3)
172 {
173  ocean_assert(NumericT<T>::isNotEqualEps(a));
175  // ax^3 + bx^2 + cx + d = 0
176  // see Numerical Recipes in C++
179  {
180  return 0u;
181  }
183  const T a1 = 1 / a;
184  const T alpha = b * a1;
185  const T beta = c * a1;
186  const T gamma = d * a1;
188  // x^3 + alpha x^2 + beta x + gamma = 0
190  // alpha2 = alpha^2
191  const T alpha2 = alpha * alpha;
193  // q = (alpha^2 - 3b) / 9
194  const T q = (alpha2 - 3 * beta) * T(0.11111111111111111111111111111111);
196  // r = (2 alpha^3 - 9 alpha beta + 27gamma) / 54
197  const T r = (2 * alpha2 * alpha - 9 * alpha * beta + 27 * gamma) * T(0.018518518518518518518518518518519);
199  // r2 = r^2
200  const T r2 = r * r;
202  // q3 = q^3
203  const T q3 = q * q * q;
205  if (r2 <= q3 + Numeric::eps() && q > NumericT<T>::eps())
206  {
207  const T sqrtQ = NumericT<T>::sqrt(q);
209  // angle = arccos(r / sqrt(q^3))
210  // angle_3 = angle / 3
211  const T angle_3 = NumericT<T>::acos(minmax<T>(-1, r / (q * sqrtQ), 1)) * T(0.33333333333333333333333333333333);
213  // alpha_3 = alpha / 3
214  const T alpha_3 = alpha * T(0.33333333333333333333333333333333);
216  const T factor = -2 * sqrtQ;
218  // x1 = -2 sqrt(q) * cos(angle / 3) - alpha / 3
219  x1 = factor * NumericT<T>::cos(angle_3) - alpha_3;
221  // x2 = -2 sqrt(q) * cos((angle + 2pi) / 3) - alpha / 3
222  x2 = factor * NumericT<T>::cos(angle_3 + T(2.0943951023931954923084289221863)) - alpha_3;
224  // x3 = -2 sqrt(q) * cos((angle - 2pi) / 3) - alpha / 3
225  x3 = factor * NumericT<T>::cos(angle_3 - T(2.0943951023931954923084289221863)) - alpha_3;
227 #ifdef OCEAN_DEBUG
229  const T value1 = a * x1 * x1 * x1 + b * x1 * x1 + c * x1 + d;
230  const T value2 = a * x2 * x2 * x2 + b * x2 * x2 + c * x2 + d;
231  const T value3 = a * x3 * x3 * x3 + b * x3 * x3 + c * x3 + d;
233  // the accuracy for 32 bit float values may be very poor so that we cannot define any assert
234  if (!std::is_same<T, float>::value)
235  {
236  ocean_assert(NumericT<T>::isEqual(value1, T(0), T(1e-3)));
237  ocean_assert(NumericT<T>::isEqual(value2, T(0), T(1e-3)));
238  ocean_assert(NumericT<T>::isEqual(value3, T(0), T(1e-3)));
239  }
240 #endif
242  return 3u;
243  }
245  ocean_assert(r2 - q3 >= -Numeric::eps());
247  // m = -sign(r) * [abs(r) + sqrt(r^2 - q^3)]^(1/3)
248  const T m = -NumericT<T>::copySign(pow(NumericT<T>::abs(r) + NumericT<T>::sqrt(std::max(T(0), r2 - q3)), T(0.33333333333333333333333333333333)), r);
250  // n = 0, if m == 0
251  // n = q / m, if m != 0
252  const T n = NumericT<T>::isEqualEps(m) ? 0 : q / m;
254  // x1 = (m + n) - alpha / 3
255  x1 = m + n - alpha * T(0.33333333333333333333333333333333);
257 #ifdef OCEAN_DEBUG
259  const T value1 = a * x1 * x1 * x1 + b * x1 * x1 + c * x1 + d;
261  // the accuracy for 32 bit float values may be very poor so that we cannot define any assert
262  if (!std::is_same<T, float>::value)
263  {
264  ocean_assert(NumericT<T>::isEqual(value1, T(0), T(1e-3)));
265  }
267 #endif
269  return 1u;
270 }
272 template <typename T>
273 unsigned int EquationT<T>::solveQuartic(const T a, const T b, const T c, const T d, const T e, T* x)
274 {
275  ocean_assert(NumericT<T>::isNotEqualEps(a));
276  ocean_assert(x != nullptr);
278  // ax^4 + bx^3 + cx^2 + dx + e = 0
281  {
282  return 0u;
283  }
285  // simplification using substitution:
286  /// y^4 + alpha * y^2 + beta * y + gamma = 0
288  // 1 / a
289  const T a1 = T(1.0) / a;
291  // b / a
292  const T b_a = b * a1;
294  // c / a
295  const T c_a = c * a1;
297  // (b * b) / (a * a)
298  const T b_a2 = b_a * b_a;
300  // (b * b * b) / (a * a * a)
301  const T b_a3 = b_a2 * b_a;
303  // d / a
304  const T d_a = d * a1;
306  //const T alpha = T(-0.375) * (b / a) * (b / a) + c / a;
307  const T alpha = T(-0.375) * b_a2 + c_a;
309  //const T beta = T(0.125) * (b / a) * (b / a) * (b / a) - T(0.5) * (b / a) * (c / a) + (d / a);
310  const T beta = T(0.125) * b_a3 - T(0.5) * b_a * c_a + d_a;
312  //const T gamma = T(-0.01171875) * (b / a) * (b / a) * (b / a) * (b / a) + T(0.0625) * (b / a) * (b / a) * (c / a) - T(0.25) * (b / a) * (d / a) + e / a;
313  const T gamma = T(-0.01171875) * b_a3 * b_a + T(0.0625) * b_a2 * c_a - T(0.25) * b_a * d_a + e * a1;
315  // y^4 + alpha y^2 + beta y + gamma = 0
317  if (NumericT<T>::isEqualEps(beta))
318  {
319  const std::complex<T> cx1 = std::complex<T>(T(-0.25)) * std::complex<T>(b / a) + NumericT<T>::sqrt(std::complex<T>(T(0.5)) * (std::complex<T>(-alpha) + NumericT<T>::sqrt(std::complex<T>(alpha * alpha - 4 * gamma))));
320  const std::complex<T> cx2 = std::complex<T>(T(-0.25)) * std::complex<T>(b / a) + NumericT<T>::sqrt(std::complex<T>(T(0.5)) * (std::complex<T>(-alpha) - NumericT<T>::sqrt(std::complex<T>(alpha * alpha - 4 * gamma))));
321  const std::complex<T> cx3 = std::complex<T>(T(-0.25)) * std::complex<T>(b / a) - NumericT<T>::sqrt(std::complex<T>(T(0.5)) * (std::complex<T>(-alpha) + NumericT<T>::sqrt(std::complex<T>(alpha * alpha - 4 * gamma))));
322  const std::complex<T> cx4 = std::complex<T>(T(-0.25)) * std::complex<T>(b / a) - NumericT<T>::sqrt(std::complex<T>(T(0.5)) * (std::complex<T>(-alpha) - NumericT<T>::sqrt(std::complex<T>(alpha * alpha - 4 * gamma))));
324  ocean_assert((std::is_same<T, float>::value) || NumericT<T>::isWeakEqualEps(cx1 * cx1 * cx1 * cx1 * a + cx1 * cx1 * cx1 * b + cx1 * cx1 * c + cx1 * d + e));
325  ocean_assert((std::is_same<T, float>::value) || NumericT<T>::isWeakEqualEps(cx2 * cx2 * cx2 * cx2 * a + cx2 * cx2 * cx2 * b + cx2 * cx2 * c + cx2 * d + e));
326  ocean_assert((std::is_same<T, float>::value) || NumericT<T>::isWeakEqualEps(cx3 * cx3 * cx3 * cx3 * a + cx3 * cx3 * cx3 * b + cx3 * cx3 * c + cx3 * d + e));
327  ocean_assert((std::is_same<T, float>::value) || NumericT<T>::isWeakEqualEps(cx4 * cx4 * cx4 * cx4 * a + cx4 * cx4 * cx4 * b + cx4 * cx4 * c + cx4 * d + e));
329  unsigned int solutions = 0u;
331  if (NumericT<T>::isEqualEps(cx1.imag()))
332  {
333  const T solution = cx1.real();
334  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
335  {
336  x[solutions++] = solution;
337  }
338  }
340  if (NumericT<T>::isEqualEps(cx2.imag()))
341  {
342  const T solution = cx2.real();
343  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
344  {
345  x[solutions++] = solution;
346  }
347  }
349  if (NumericT<T>::isEqualEps(cx3.imag()))
350  {
351  const T solution = cx3.real();
352  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
353  {
354  x[solutions++] = solution;
355  }
356  }
358  if (NumericT<T>::isEqualEps(cx4.imag()))
359  {
360  const T solution = cx4.real();
361  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
362  {
363  x[solutions++] = solution;
364  }
365  }
367  return solutions;
368  }
370  //const std::complex<T> p(-(alpha * alpha) / T(12.0) - gamma);
371  const std::complex<T> p(T(-0.08333333333333333333333333333333) * alpha * alpha - gamma);
373  //const std::complex<T> q(-(alpha * alpha * alpha) / T(108.0) + (alpha * gamma) / T(3.0) - (beta * beta) / T(8.0));
374  const std::complex<T> q(T(-0.00925925925925925925925925925926) * alpha * alpha * alpha + T(0.33333333333333333333333333333333) * alpha * gamma - T(0.125) * beta * beta);
376  const std::complex<T> qqSqr = NumericT<T>::sqrt(std::complex<T>(T(0.25)) * q * q + std::complex<T>(T(0.03703703703703703703703703703704)) * p * p * p);
378  const std::complex<T> r(std::complex<T>(T(-0.5)) * q + qqSqr);
380  {
381  return 0u;
382  }
384  const std::complex<T> u = NumericT<T>::pow(std::complex<T>(r), T(0.33333333333333333333333333333333));
387  {
388  return 0u;
389  }
391  std::complex<T> y;
392  if (NumericT<T>::isEqualEps(u.real()) && NumericT<T>::isEqualEps(u.imag()))
393  {
394  y = std::complex<T>(T(-0.83333333333333333333333333333333) * alpha) + u - NumericT<T>::pow(std::complex<T>(q), T(0.33333333333333333333333333333333));
395  }
396  else
397  {
398  y = std::complex<T>(T(-0.83333333333333333333333333333333) * alpha) + u - p / (std::complex<T>(3) * u);
399  }
401  //const std::complex<T> w_(NumericT<T>::sqrt(std::complex<T>(alpha) + std::complex<T>(2) * y));
402  const std::complex<T> w(NumericT<T>::sqrt(std::complex<T>(T(0.25) * alpha) + std::complex<T>(T(0.5)) * y));
404  //const std::complex<T> cx1 = std::complex<T>(T(-0.25)) * std::complex<T>(b / a) + std::complex<T>(T(0.5)) * (w + NumericT<T>::sqrt(std::complex<T>(-1) * (std::complex<T>(3) * std::complex<T>(alpha) + std::complex<T>(2) * y + std::complex<T>(2) * std::complex<T>(beta) / w)));
407  {
408  return 0u;
409  }
411  const std::complex<T> beta2_w(std::complex<T>(T(-0.25) * beta) / w);
412  const std::complex<T> alpha3y2 = std::complex<T>(T(-0.75) * alpha) - std::complex<T>(T(0.5)) * y;
413  const std::complex<T> b_a4(T(-0.25) * b_a);
415  const std::complex<T> sqrtPositive(NumericT<T>::sqrt(alpha3y2 + beta2_w));
416  const std::complex<T> sqrtNegative(NumericT<T>::sqrt(alpha3y2 - beta2_w));
418  const std::complex<T> cx1 = b_a4 + w + sqrtPositive;
419  const std::complex<T> cx2 = b_a4 + w - sqrtPositive;
420  const std::complex<T> cx3 = b_a4 - w + sqrtNegative;
421  const std::complex<T> cx4 = b_a4 - w - sqrtNegative;
423  unsigned int solutions = 0u;
425  if (NumericT<T>::isEqualEps(cx1.imag()))
426  {
427  const T solution = cx1.real();
428  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
429  {
430  x[solutions++] = solution;
431  }
432  }
434  if (NumericT<T>::isEqualEps(cx2.imag()))
435  {
436  const T solution = cx2.real();
437  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
438  {
439  x[solutions++] = solution;
440  }
441  }
443  if (NumericT<T>::isEqualEps(cx3.imag()))
444  {
445  const T solution = cx3.real();
446  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
447  {
448  x[solutions++] = solution;
449  }
450  }
452  if (NumericT<T>::isEqualEps(cx4.imag()))
453  {
454  const T solution = cx4.real();
455  if (NumericT<T>::isWeakEqualEps(solution * solution * solution * solution * a + solution * solution * solution * b + solution * solution * c + solution * d + e))
456  {
457  x[solutions++] = solution;
458  }
459  }
461  return solutions;
462 }
464 }
