# Source code for root3

"""
Root-finding solver based on Brent's method

See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations

from math import cos
from typing import Callable, Optional, Tuple

#: Relative tolerance on the distance between the bracket endpoints ``a`` and ``b``
x_rel_tol: float = 1.0e-12
#: Absolute tolerance on the distance between the bracket endpoints ``a`` and ``b``
x_abs_tol: float = 1.0e-12
#: Tolerance on ``|f(b)|`` at the best guess ``b``
y_tol: float = 1.0e-12
#: Whether :func:`brent` prints a progress table while iterating
verbose: bool = True


def test_convergence(a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
    """
    Checks convergence of a root finding method

    The criteria are tried in this order, and the first one satisfied is reported:

    1. ``fb`` is exactly zero;
    2. the bracket width ``|a - b|`` is at most :data:`x_abs_tol`;
    3. the bracket width is at most :data:`x_rel_tol` times ``max(|a|, |b|)``;
    4. ``|fb|`` is at most :data:`y_tol`.

    Args:
        a: Contrapoint
        b: Best guess
        fb: f(b)

    Returns:
        Whether the root finding converges to the specified tolerance and why
        (the reason is ``None`` when it has not converged)
    """
    if fb == 0:
        return (True, "Exact root found")
    x_delta = abs(a - b)
    if x_delta <= x_abs_tol:
        return (True, "Met x_abs_tol criterion")
    # Scale by the larger endpoint *magnitude* (the same rule as math.isclose) and multiply
    # rather than divide. The test then stays valid for negative endpoints and never
    # divides by zero when an endpoint is 0.
    if x_delta <= x_rel_tol * max(abs(a), abs(b)):
        return (True, "Met x_rel_tol criterion")
    if abs(fb) <= y_tol:
        return (True, "Met y_tol criterion")
    return (False, None)


# Because of the ``test_`` prefix, pytest would collect this helper as a test (and fail on
# the missing ``a``/``b``/``fb`` fixtures) wherever it is imported into a test module.
test_convergence.__test__ = False  # type: ignore[attr-defined]


def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Computes an approximation for a zero of a 1D function from three function values

    ``x`` is interpolated as a quadratic polynomial in ``y`` through the points
    ``(fa, a)``, ``(fb, b)`` and ``(fc, c)`` (Lagrange form), and evaluated at ``y = 0``.

    Note:
        The values ``fa``, ``fb``, ``fc`` need all to be distinct; otherwise the formula
        divides by zero.

    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_

    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)

    Returns:
        An approximation of the zero
    """
    # One Lagrange basis term per point; operand order is kept as-is so that
    # floating-point results are reproduced exactly.
    term_a = (a * fb * fc) / ((fa - fb) * (fa - fc))
    term_b = (b * fa * fc) / ((fb - fa) * (fb - fc))
    term_c = (c * fb * fa) / ((fc - fa) * (fc - fb))
    return term_a + term_b + term_c
def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Computes an approximation for a zero of a 1D function from two function values

    Returns the zero of the straight line through ``(a, fa)`` and ``(b, fb)``.

    Note:
        The values ``fa`` and ``fb`` need to have a different sign (in particular they
        must differ, otherwise the formula divides by zero).

    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)

    Returns:
        An approximation of the zero
    """
    return b - fb * (b - a) / (fb - fa)


def bisection_step(a: float, b: float) -> float:
    """
    Computes an approximation for a zero of a 1D function as the midpoint of an interval

    Note:
        The values ``f(a)`` and ``f(b)`` (not needed in the code) need to have a different
        sign.

    Args:
        a: First x coordinate
        b: Second x coordinate

    Returns:
        The midpoint of ``a`` and ``b``
    """
    return min(a, b) + abs(b - a) / 2


def brent(f: Callable[[float], float], a: float, b: float) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    Each iteration tries inverse quadratic interpolation (or a secant step when the value at
    the previous iterate coincides with ``f(a)`` or ``f(b)``). It falls back to bisection
    whenever the interpolated point fails the safeguards of
    `<https://en.wikipedia.org/wiki/Brent%27s_method>`_. Iteration stops as soon as
    :func:`test_convergence` reports convergence, which is also checked once before
    iterating. When :data:`verbose` is true, a table of the iterations and the reason for
    stopping are printed.

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root

    Returns:
        The approximate root

    Raises:
        AssertionError: If ``f(a) * f(b) > 0``, i.e. the root is not bracketed
    """
    fa = f(a)
    fb = f(b)
    # NOTE: kept as an assert (callers may rely on AssertionError), but note that it is
    # stripped when Python runs with -O.
    assert fa * fb <= 0, "Root not bracketed"
    if abs(fa) < abs(fb):
        # force abs(fa) >= abs(fb): b is the best root approximation known so far
        # and a is the contrapoint
        a, b = b, a
        fa, fb = fb, fa
    c, fc = a, fa  # previous iterate and f(c)
    d = a  # iterate before the previous one (only its x coordinate is ever used)
    # "mflag" of the reference algorithm: whether the last step was a bisection. It starts
    # set, so the first iteration compares against |b - c| rather than the not yet
    # meaningful |c - d|.
    bisected_last = True
    step: Optional[str] = None  # kind of step taken by the current iteration
    iteration = 1  # current iteration number (1-based)

    def print_state() -> None:
        """
        Prints information about an iteration, including the step that produced it
        """
        dx = abs(a - b)
        dy = abs(fa - fb)
        print(f"{iteration}\t{b:.3e}\t{fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{step}")

    # An endpoint may already be an (exact) root: check before spending any iteration.
    converged, reason = test_convergence(a, b, fb)
    if verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state()
    while not converged:
        iteration += 1
        if fa != fc and fb != fc:
            s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
            step = "quadratic"
        else:
            s = secant_step(a, b, fa, fb)
            step = "secant"

        # Safeguards of the reference algorithm: s must lie between (3a + b) / 4 and b, and
        # the step |s - b| must keep shrinking relative to the earlier steps.
        # Otherwise bisect.
        quarter = (3 * a + b) / 4
        perform_bisection = (
            not (min(quarter, b) <= s <= max(quarter, b))
            or (bisected_last and abs(s - b) >= abs(b - c) / 2)
            or (not bisected_last and abs(s - b) >= abs(c - d) / 2)
            or (bisected_last and abs(b - c) < x_abs_tol)
            or (not bisected_last and abs(c - d) < x_abs_tol)
        )
        if perform_bisection:
            s = bisection_step(a, b)
            step = "bisection"
        bisected_last = perform_bisection

        fs = f(s)
        d, c, fc = c, b, fb
        # check which point to replace to maintain (a,b) have different signs;
        # the cached values avoid re-evaluating f
        if fa * fs < 0:
            b, fb = s, fs
        else:
            a, fa = s, fs
        # keep b as the best guess
        if abs(fa) < abs(fb):
            a, b = b, a
            fa, fb = fb, fa
        converged, reason = test_convergence(a, b, fb)
        if verbose:
            print_state()
    if verbose:
        assert reason is not None
        print(reason)
    return b


if __name__ == "__main__":
    print(brent(cos, 0.0, 3.0))