Third step: breaking the code into functions
Contents
Third step: breaking the code into functions
In this third step, we break the big brent function into smaller pieces. The benefits in
this example are minimal: the resulting code is longer, and whether it is simpler to read
is debatable.
Nevertheless, this is an illustration of the process.
When breaking up code, moving data around can become cumbersome: function signatures can grow long. Note that the settings are still declared as top-level module variables; such practices make it difficult to understand where values come from.
We documented the functions using Google-style docstrings.
Example Sphinx autodoc output
Such documentation strings are understood by Sphinx automatically (Google-style docstrings are parsed by the `sphinx.ext.napoleon` extension).
- `root3.inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)` [source]

  Computes an approximation for a zero of a 1D function from three function values

  Note: The values `fa`, `fb`, `fc` need all to be distinct. See https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation

  - Parameters:
    - a (float) – First x coordinate
    - b (float) – Second x coordinate
    - c (float) – Third x coordinate
    - fa (float) – f(a)
    - fb (float) – f(b)
    - fc (float) – f(c)
  - Return type: float
  - Returns: An approximation of the zero
Python code
"""
Root-finding solver based on Brent's method
See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations
from math import cos
from typing import Callable, Optional, Tuple
x_rel_tol: float = 1e-12  #: X relative tolerance
x_abs_tol: float = 1e-12  #: X absolute tolerance
y_tol: float = 1e-12  #: Y tolerance
verbose: bool = True  #: Whether to display progress


def test_convergence(a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
    """
    Checks convergence of a root finding method

    The criteria are tried in this order: exact root (``fb == 0``), ``|a - b| <= x_abs_tol``,
    ``|a - b| <= x_rel_tol * max(|a|, |b|)`` and ``|fb| <= y_tol``. The tolerances are read
    from the module-level settings.

    Args:
        a: Contrapoint
        b: Best guess
        fb: f(b)

    Returns:
        ``(True, reason)`` where ``reason`` names the first criterion that is met,
        or ``(False, None)`` if none is met
    """
    if fb == 0:
        return (True, "Exact root found")
    x_delta = abs(a - b)
    if x_delta <= x_abs_tol:
        return (True, "Met x_abs_tol criterion")
    # Scale by the largest magnitude rather than ``max(a, b)``: the latter is negative for an
    # interval below zero (so any ratio would "pass") and zero when an endpoint is 0 while the
    # other is negative (ZeroDivisionError). Multiplying instead of dividing keeps the
    # comparison well defined even when both endpoints are 0.
    if x_delta <= x_rel_tol * max(abs(a), abs(b)):
        return (True, "Met x_rel_tol criterion")
    y_delta = abs(fb)
    if y_delta <= y_tol:
        return (True, "Met y_tol criterion")
    return (False, None)


# The name starts with ``test_``: without this flag pytest would try to collect the function as
# a test case (looking up its parameters as fixtures) whenever it is imported into a test module.
test_convergence.__test__ = False  # type: ignore[attr-defined]
def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Computes an approximation for a zero of a 1D function from three function values

    The nodes ``(fa, a)``, ``(fb, b)`` and ``(fc, c)`` are interpolated by a quadratic in ``y``,
    written in Lagrange form, and that quadratic is evaluated at ``y = 0``.

    Note:
        The values ``fa``, ``fb``, ``fc`` need all to be distinct, otherwise one of the
        Lagrange denominators vanishes.

    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_

    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)

    Returns:
        An approximation of the zero
    """
    # Lagrange denominators, one per interpolation node
    denom_a = (fa - fb) * (fa - fc)
    denom_b = (fb - fa) * (fb - fc)
    denom_c = (fc - fa) * (fc - fb)
    # Each term is the node's x coordinate times its basis polynomial evaluated at y = 0
    term_a = (a * fb * fc) / denom_a
    term_b = (b * fa * fc) / denom_b
    term_c = (c * fb * fa) / denom_c
    return term_a + term_b + term_c
def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Computes an approximation for a zero of a 1D function from two function values

    The straight line through ``(a, fa)`` and ``(b, fb)`` is intersected with the x axis.

    Note:
        The values ``fa`` and ``fb`` need to have a different sign. The formula itself
        divides by ``fb - fa``, so at the very least they must differ.

    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)

    Returns:
        An approximation of the zero
    """
    # Distance from b to the x-axis crossing of the secant line
    correction = fb * (b - a) / (fb - fa)
    return b - correction
def bisection_step(a: float, b: float) -> float:
    """
    Computes an approximation for a zero of a 1D function as the midpoint of ``a`` and ``b``

    Note:
        The values ``f(a)`` and ``f(b)`` (not needed in the code) need to have a different
        sign for the midpoint to approximate a zero enclosed by the interval.

    Args:
        a: First x coordinate
        b: Second x coordinate

    Returns:
        The midpoint of the interval delimited by ``a`` and ``b`` (in either order)
    """
    lower = min(a, b)
    half_width = abs(b - a) / 2
    return lower + half_width
def _needs_bisection(
    a: float, b: float, c: float, d: float, s: float, last_step: Optional[str]
) -> bool:
    """
    Decides whether Brent's method must reject an interpolated candidate and bisect instead

    Args:
        a: Contrapoint
        b: Best guess
        c: Previous iterate
        d: Iterate before the previous iterate
        s: Candidate produced by inverse quadratic interpolation or by the secant method
        last_step: Kind of step taken in the previous iteration (``"bisection"``,
            ``"quadratic"`` or ``"secant"``), or ``None`` before any step was taken

    Returns:
        Whether a bisection step should be used instead of ``s``
    """
    # The candidate must lie between (3a + b) / 4 and b
    bound = (3 * a + b) / 4
    if a <= b:
        if not (bound <= s <= b):
            return True
    elif not (b <= s <= bound):
        return True
    # The move |s - b| must be smaller than half of an earlier move (|b - c| right after a
    # bisection, |c - d| otherwise), and that earlier move must not already be below x_abs_tol.
    if last_step == "bisection":
        return abs(s - b) >= abs(b - c) / 2 or abs(b - c) < x_abs_tol
    return abs(s - b) >= abs(c - d) / 2 or abs(c - d) < x_abs_tol


def brent(f: Callable[[float], float], a: float, b: float) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    Each iteration tries inverse quadratic interpolation (or the secant method when two function
    values coincide) and falls back to bisection when the candidate is not acceptable.
    Convergence is decided by :func:`test_convergence`; progress is printed when the module-level
    ``verbose`` flag is set.

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root

    Returns:
        The approximate root

    Raises:
        AssertionError: If ``f(a) * f(b) > 0`` (root not bracketed), when assertions are enabled
    """
    fa = f(a)  # f(a)
    fb = f(b)  # f(b)
    assert fa * fb <= 0, "Root not bracketed"
    if abs(fa) < abs(fb):
        # force abs(fa) >= abs(fb), make sure that b is the best root approximation known so far
        # and a is the contrapoint
        b, a = a, b
        fb, fa = fa, fb
    c = a  # Previous iterate
    fc = fa  # f(c)
    d = a  # Iterate before the previous iterate
    last_step: Optional[str] = None  # Kind of step taken in the previous iteration
    step: Optional[str] = None  # Kind of step taken in the current iteration
    iteration = 1  # Current iteration number (1-based); named so as not to shadow ``iter``
    # Test the initial bracket as well: if it already satisfies a criterion (e.g. f(b) == 0) we
    # must not take a step, since f(a) == f(b) == 0 would make the secant step divide by zero.
    converged, reason = test_convergence(a, b, fb)

    def print_state() -> None:
        """
        Prints information about an iteration
        """
        dx = abs(a - b)
        dy = abs(fa - fb)
        # NOTE(review): the last column prints ``last_step`` (the step taken in the previous
        # iteration) although the header calls it "step"; ``step`` was presumably intended.
        # Left unchanged to preserve the existing output.
        print(f"{iteration}\t{b:.3e}\t{fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{last_step}")

    if verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state()
    while not converged:
        iteration += 1
        last_step = step
        # Inverse quadratic interpolation needs three distinct function values
        if fa != fc and fb != fc:
            s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
            step = "quadratic"
        else:
            s = secant_step(a, b, fa, fb)
            step = "secant"
        if _needs_bisection(a, b, c, d, s, last_step):
            s = bisection_step(a, b)
            step = "bisection"
        fs = f(s)
        # Shift the history of iterates: d <- c <- b
        d = c
        c = b
        fc = fb
        # Replace the endpoint that keeps a sign change between a and b; ``fa`` and ``fs``
        # already hold f(a) and f(s), so f is not evaluated again
        if fa * fs < 0:
            b = s
            fb = fs
        else:
            a = s
            fa = fs
        # keep b as the best guess
        if abs(fa) < abs(fb):
            b, a = a, b
            fb, fa = fa, fb
        converged, reason = test_convergence(a, b, fb)
        if verbose:
            print_state()
    if verbose:
        assert reason is not None
        print(reason)
    return b
if __name__ == "__main__":
    # Find the zero of cos enclosed by [0, 3] (pi / 2) and print it
    root = brent(cos, 0.0, 3.0)
    print(root)
Sample execution
The output of running the script is shown below.
$ python typing/root3.py
Iter	x		f(x)		delta(x)	delta(f(x))	step
1	3.000e+00	-9.900e-01	3.000e+00	1.990e+00	None
2	1.500e+00	7.074e-02	1.500e+00	1.061e+00	None
3	1.600e+00	-2.923e-02	1.000e-01	9.997e-02	bisection
4	1.571e+00	1.434e-05	2.925e-02	2.924e-02	secant
5	1.571e+00	-2.042e-09	1.434e-05	1.434e-05	secant
6	1.571e+00	6.123e-17	2.042e-09	2.042e-09	secant
Met y_tol criterion
1.5707963267948966