# Source code for module root5

"""
Root-finding solver based on Brent's method

See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations

from dataclasses import dataclass
from math import cos
from typing import Callable, Literal, Optional, Tuple, Union


@dataclass(frozen=True)
class Settings:
    """
    Settings for the root-finding algorithm
    """

    x_rel_tol: float = 1e-12  #: X relative tolerance
    x_abs_tol: float = 1e-12  #: X absolute tolerance
    y_tol: float = 1e-12  #: Y tolerance
    verbose: bool = False  #: Whether to display progress

    def __post_init__(self) -> None:
        """
        Verifies sanity of parameters

        Raises:
            AssertionError: if a tolerance is negative or all tolerances are zero
        """
        assert self.x_rel_tol >= 0
        assert self.x_abs_tol >= 0
        assert self.y_tol >= 0
        assert (
            self.x_rel_tol > 0 or self.x_abs_tol > 0 or self.y_tol > 0
        ), "At least one convergence criteria must be set"

    def converged(self, a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
        """
        Checks convergence of a root finding method

        The criteria are tested in this order: exact root, absolute x distance,
        relative x distance (scaled by ``max(abs(a), abs(b))``), then ``abs(fb)``.

        Args:
            a: Contrapoint
            b: Best guess
            fb: f(b)

        Returns:
            Whether the root finding converges to the specified tolerance and why
        """
        if fb == 0:
            return (True, "Exact root found")
        x_delta = abs(a - b)
        if x_delta <= self.x_abs_tol:
            return (True, "Met x_abs_tol criterion")
        # Scale by the larger *magnitude*: max(a, b) is zero or negative for
        # brackets at or below 0, which either divided by zero or made every
        # negative bracket "converge" immediately. Multiplying instead of
        # dividing avoids any division by zero.
        if x_delta <= self.x_rel_tol * max(abs(a), abs(b)):
            return (True, "Met x_rel_tol criterion")
        if abs(fb) <= self.y_tol:
            return (True, "Met y_tol criterion")
        return (False, None)

    @staticmethod
    def default_verbose() -> Settings:
        """
        Returns default settings with progress display enabled
        """
        return Settings(verbose=True)
def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Estimates a zero of a 1D function by inverse quadratic interpolation

    The three points ``(fa, a)``, ``(fb, b)``, ``(fc, c)`` define a quadratic
    giving x as a function of y (Lagrange form). That quadratic is evaluated
    at ``y = 0``.

    Note:
        ``fa``, ``fb`` and ``fc`` must be pairwise distinct, otherwise one of
        the denominators vanishes.

    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_

    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)

    Returns:
        The estimated zero
    """
    term_a = (a * fb * fc) / ((fa - fb) * (fa - fc))
    term_b = (b * fa * fc) / ((fb - fa) * (fb - fc))
    term_c = (c * fb * fa) / ((fc - fa) * (fc - fb))
    partial = term_a + term_b
    return partial + term_c


def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Estimates a zero of a 1D function from the line through two points

    Note:
        ``fa`` and ``fb`` are expected to have opposite signs (they must at
        least differ, or the slope is undefined).

    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)

    Returns:
        Where the secant line through ``(a, fa)`` and ``(b, fb)`` crosses zero
    """
    correction = fb * (b - a) / (fb - fa)
    return b - correction


def bisection_step(a: float, b: float) -> float:
    """
    Returns the midpoint of the interval spanned by ``a`` and ``b``

    Note:
        Only meaningful as a root estimate when ``f(a)`` and ``f(b)`` have
        opposite signs. The function values themselves are not needed here.

    Args:
        a: First x coordinate
        b: Second x coordinate

    Returns:
        The midpoint between ``a`` and ``b``
    """
    left = min(a, b)
    half_width = abs(b - a) / 2
    return left + half_width
@dataclass(frozen=True)
class BrentState:
    """
    State for Brent's method

    Note:
        We have the following invariants.

        - ``fa = f(a)`` and ``fb = f(b)`` have opposite signs (or one of them is zero)
        - ``abs(fb) <= abs(fa)`` so that ``b`` is the best guess
    """

    a: float  #: Contrapoint
    b: float  #: Current iterate, best root approximation known so far
    c: float  #: Previous iterate
    d: float  #: Iterate before the previous iterate
    fa: float  #: f(a)
    fb: float  #: f(b)
    fc: float  #: f(c)
    last_step: Optional[
        Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]
    ]  #: Type of previous step
    iter: int  #: Current iteration number (1-based)

    @staticmethod
    def make(f: Callable[[float], float], a: float, b: float) -> BrentState:
        """
        Initializes a state from an interval that brackets a root of the given function

        Args:
            f: Function to find the root of
            a: First x coordinate
            b: Second x coordinate

        Returns:
            The initial state

        Raises:
            AssertionError: if ``f(a)`` and ``f(b)`` do not bracket a root
        """
        fa = f(a)
        fb = f(b)
        # First invariant. Compare signs instead of testing fa * fb <= 0: the
        # product of two tiny same-sign values can underflow to 0.0 and
        # wrongly pass the check.
        assert (fa <= 0 <= fb) or (fb <= 0 <= fa), "Root not bracketed"
        if abs(fa) < abs(fb):
            # Second invariant: b must be the endpoint with the smaller |f|
            a, b = b, a
            fa, fb = fb, fa
        # There is no previous iterate yet, so c (and d) start at the contrapoint.
        return BrentState(
            a=a, b=b, c=a, d=a, fa=fa, fb=fb, fc=fa, last_step=None, iter=1
        )
def brent_step(f: Callable[[float], float], state: BrentState, delta: float) -> BrentState:
    """
    Performs a step of Brent's method

    Args:
        f: Function to find the root of
        state: Previous state (must satisfy the ``BrentState`` invariants)
        delta: x absolute tolerance

    Returns:
        New state
    """
    a, b, c, d = state.a, state.b, state.c, state.d
    fa, fb, fc = state.fa, state.fb, state.fc
    # "mflag" of the reference algorithm: the previous step was a bisection.
    # The first step (last_step is None) counts as following a bisection, as
    # in the reference algorithm. At that point d == c == a is not a real
    # iterate, and the "not mflag" tests would always force a bisection.
    mflag = state.last_step is None or state.last_step == "bisection"

    step: Optional[Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]]
    if fa != fc and fb != fc:
        s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
        step = "quadratic"
    else:
        s = secant_step(a, b, fa, fb)
        step = "secant"

    # Fall back to bisection when the interpolated point is unacceptable
    # (conditions 1-5 of Brent's method).
    quarter = (3 * a + b) / 4
    perform_bisection = (
        not (min(quarter, b) <= s <= max(quarter, b))
        or (mflag and abs(s - b) >= abs(b - c) / 2)
        or (not mflag and abs(s - b) >= abs(c - d) / 2)
        or (mflag and abs(b - c) < delta)
        or (not mflag and abs(c - d) < delta)
    )
    if perform_bisection:
        s = bisection_step(a, b)
        step = "bisection"

    fs = f(s)
    d = c
    c, fc = b, fb
    # Replace the endpoint that keeps (a, b) bracketing the root. Use the
    # cached values (no extra evaluations of f). Test signs directly rather
    # than fa * fs < 0, which can underflow to 0 and break the bracket.
    if (fa < 0 < fs) or (fs < 0 < fa):
        b, fb = s, fs
    else:
        a, fa = s, fs
    # keep b as the best guess
    if abs(fa) < abs(fb):
        a, b = b, a
        fa, fb = fb, fa
    return BrentState(
        a=a, b=b, c=c, d=d, fa=fa, fb=fb, fc=fc, last_step=step, iter=state.iter + 1
    )


def brent(
    f: Callable[[float], float], a: float, b: float, settings: Settings = Settings()
) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root
        settings: Algorithm settings

    Returns:
        The approximate root
    """
    state = BrentState.make(f, a, b)

    def print_state(s: BrentState) -> None:
        """
        Prints information about an iteration
        """
        dx = abs(s.a - s.b)
        dy = abs(s.fa - s.fb)
        ls: str = s.last_step if s.last_step is not None else ""
        print(f"{s.iter}\t{s.b:.3e}\t{s.fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{ls}")

    if settings.verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state(state)

    # Check the initial bracket before stepping. If it already satisfies the
    # tolerances there is nothing to do. This also prevents a 0/0 secant step
    # when both endpoints are exact roots.
    converged, reason = settings.converged(state.a, state.b, state.fb)
    while not converged:
        state = brent_step(f, state, settings.x_abs_tol)
        converged, reason = settings.converged(state.a, state.b, state.fb)
        if settings.verbose:
            print_state(state)

    if settings.verbose:
        assert reason is not None
        print(reason)
    return state.b


if __name__ == "__main__":
    print(brent(cos, 0.0, 3.0, Settings.default_verbose()))