Third step: breaking the code into functions

In this third step, we break the big brent function into smaller pieces. The benefits in this example are minimal: the resulting code is longer, and whether it is simpler to read is debatable.

Nevertheless, this is an illustration of the process.

When breaking up code, moving data around can become cumbersome, and function signatures can grow large. Note that the settings are still top-level module declarations. Relying on such globals makes it difficult to see where the values used inside a function come from.
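One common way to keep signatures manageable while making the origin of each value explicit is to group the settings in a single object and pass it as an argument. The sketch below is only an illustration of that idea: the `Settings` class and the `has_converged` function are hypothetical names that are not part of the module listed further down, and the defaults simply mirror its module-level values.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Settings:
    """Hypothetical container for the solver settings (not part of root3)."""

    x_rel_tol: float = 1e-12  # X relative tolerance
    x_abs_tol: float = 1e-12  # X absolute tolerance
    y_tol: float = 1e-12  # Y tolerance
    verbose: bool = True  # Whether to display progress


def has_converged(
    a: float, b: float, fb: float, settings: Settings
) -> Tuple[bool, Optional[str]]:
    """Hypothetical convergence check that receives its settings explicitly."""
    if fb == 0:
        return (True, "Exact root found")
    x_delta = abs(a - b)
    if x_delta <= settings.x_abs_tol:
        return (True, "Met x_abs_tol criterion")
    if x_delta <= settings.x_rel_tol * max(abs(a), abs(b)):
        return (True, "Met x_rel_tol criterion")
    if abs(fb) <= settings.y_tol:
        return (True, "Met y_tol criterion")
    return (False, None)


# The caller decides which tolerances apply; nothing is read from module globals.
loose = Settings(y_tol=1e-3, verbose=False)
print(has_converged(1.0, 2.0, 5e-4, loose))
print(has_converged(1.0, 2.0, 5e-4, Settings()))
```

With this layout, the same residual can be judged as converged under the loose settings and as not converged under the defaults, and the call site shows which settings were used.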

We documented the functions using Google-style docstrings.

Example Sphinx autodoc output

Sphinx processes such documentation strings automatically (Google-style sections are handled by its sphinx.ext.napoleon extension).

root3.inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)

Computes an approximation for a zero of a 1D function from three function values

Note

The values fa, fb, fc must all be distinct.

See https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation

Parameters

  • a (float) – First x coordinate
  • b (float) – Second x coordinate
  • c (float) – Third x coordinate
  • fa (float) – f(a)
  • fb (float) – f(b)
  • fc (float) – f(c)

Return type

float

Returns

An approximation of the zero
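For reference, output of this kind is typically obtained by enabling two extensions that ship with Sphinx. The excerpt below is a minimal sketch of a `conf.py`, not the configuration actually used for this page. In particular, the `autodoc_typehints` setting is an assumption about how the `float` types ended up in the parameter list.

```python
# conf.py (excerpt): a minimal sketch, not the configuration used for this page

extensions = [
    "sphinx.ext.autodoc",  # pulls documentation from the docstrings
    "sphinx.ext.napoleon",  # parses Google-style (and NumPy-style) docstrings
]

# Assumption: move the type hints from the signature into the descriptions
autodoc_typehints = "description"
```

With such a configuration, an `autofunction` directive naming `root3.inverse_quadratic_interpolation_step` in a reStructuredText page would render the entry for that function.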

Python code

"""
Root-finding solver based on Brent's method

See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations

from math import cos
from typing import Callable, Optional, Tuple

x_rel_tol: float = 1e-12  #: X relative tolerance
x_abs_tol: float = 1e-12  #: X absolute tolerance
y_tol: float = 1e-12  #: Y tolerance
verbose: bool = True  #: Whether to display progress


def test_convergence(a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
    """
    Checks convergence of a root finding method

    Args:
        a: Contrapoint
        b: Best guess
        fb: f(b)

    Returns:
        Whether the root finding converges to the specified tolerance and why
    """
    if fb == 0:
        return (True, "Exact root found")
    x_delta = abs(a - b)
    if x_delta <= x_abs_tol:
        return (True, "Met x_abs_tol criterion")
    # compare against the magnitudes of a and b; this also avoids a division by zero
    if x_delta <= x_rel_tol * max(abs(a), abs(b)):
        return (True, "Met x_rel_tol criterion")
    y_delta = abs(fb)
    if y_delta <= y_tol:
        return (True, "Met y_tol criterion")
    return (False, None)


def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Computes an approximation for a zero of a 1D function from three function values

    Note:
        The values ``fa``, ``fb``, ``fc`` must all be distinct.

    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_

    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)

    Returns:
        An approximation of the zero
    """
    L0 = (a * fb * fc) / ((fa - fb) * (fa - fc))
    L1 = (b * fa * fc) / ((fb - fa) * (fb - fc))
    L2 = (c * fb * fa) / ((fc - fa) * (fc - fb))
    return L0 + L1 + L2


def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Computes an approximation for a zero of a 1D function from two function values

    Note:
        The values ``fa`` and ``fb`` must have different signs.

    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)

    Returns:
        An approximation of the zero
    """
    return b - fb * (b - a) / (fb - fa)


def bisection_step(a: float, b: float) -> float:
    """
    Computes the midpoint of an interval as an approximation for a zero of a 1D function

    Note:
        The values ``f(a)`` and ``f(b)`` (not needed in the code) must have different signs.

    Args:
        a: First x coordinate
        b: Second x coordinate

    Returns:
        An approximation of the zero
    """
    return min(a, b) + abs(b - a) / 2


def brent(f: Callable[[float], float], a: float, b: float) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root

    Returns:
        The approximate root
    """

    fa = f(a)  #: f(a)
    fb = f(b)  #: f(b)

    assert fa * fb <= 0, "Root not bracketed"

    if abs(fa) < abs(fb):
        # force abs(fa) >= abs(fb), make sure that b is the best root approximation known so far
        # and a is the contrapoint
        b, a = a, b
        fb, fa = fa, fb

    c = a  #: Previous iterate
    fc = fa  #: f(c)
    d = a  #: Iterate before the previous iterate
    fd = fa  #: f(d)
    last_step: Optional[str] = None
    step: Optional[str] = None
    iter = 1  #: Current iteration number (1-based)
    converged = False  #: Whether a convergence criterion has been met
    reason: Optional[str] = None  #: Which criterion has been met

    def print_state():
        """
        Prints information about an iteration
        """
        dx = abs(a - b)
        dy = abs(fa - fb)
        print(f"{iter}\t{b:.3e}\t{fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{last_step}")

    if verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state()

    while not converged:
        iter = iter + 1
        last_step = step
        if fa != fc and fb != fc:
            s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
            step = "quadratic"
        else:
            s = secant_step(a, b, fa, fb)
            step = "secant"
        perform_bisection = False
        if a <= b and not ((3 * a + b) / 4 <= s <= b):
            perform_bisection = True
        elif b <= a and not (b <= s <= (3 * a + b) / 4):
            perform_bisection = True
        elif last_step == "bisection" and abs(s - b) >= abs(b - c) / 2:
            perform_bisection = True
        elif last_step != "bisection" and abs(a - b) >= abs(c - d) / 2:
            perform_bisection = True
        elif last_step == "bisection" and abs(b - c) < x_abs_tol:
            perform_bisection = True
        elif last_step != "bisection" and abs(c - d) < x_abs_tol:
            perform_bisection = True
        if perform_bisection:
            s = bisection_step(a, b)
            step = "bisection"
        fs = f(s)
        d = c
        c = b
        fc = fb
        # check which point to replace so that f(a) and f(b) keep opposite signs;
        # reuse the stored value fa and the new value fs instead of re-evaluating f
        if fa * fs < 0:
            b = s
            fb = fs
        else:
            a = s
            fa = fs
        # keep b as the best guess
        if abs(fa) < abs(fb):
            b, a = a, b
            fb, fa = fa, fb
        converged, reason = test_convergence(a, b, fb)
        if verbose:
            print_state()

    if verbose:
        assert reason is not None
        print(reason)

    return b


if __name__ == "__main__":
    print(brent(cos, 0.0, 3.0))
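A practical consequence of the split is that each step function can be exercised on its own, without running the whole solver. The sketch below copies the three step functions from the listing above (without type hints and docstrings) so that it runs standalone. It checks them on the line f(x) = 2x - 1, a test function chosen here purely for illustration. Both the secant step and the inverse quadratic step should recover its root 0.5, because the inverse of a linear function is itself linear.

```python
def inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc):
    # Same formula as in the listing above
    L0 = (a * fb * fc) / ((fa - fb) * (fa - fc))
    L1 = (b * fa * fc) / ((fb - fa) * (fb - fc))
    L2 = (c * fb * fa) / ((fc - fa) * (fc - fb))
    return L0 + L1 + L2


def secant_step(a, b, fa, fb):
    # Same formula as in the listing above
    return b - fb * (b - a) / (fb - fa)


def bisection_step(a, b):
    # Same formula as in the listing above
    return min(a, b) + abs(b - a) / 2


def line(x):
    # Illustrative test function (an assumption of this sketch); its root is 0.5
    return 2 * x - 1


a, b, c = 0.0, 1.0, 2.0
fa, fb, fc = line(a), line(b), line(c)

secant_root = secant_step(a, b, fa, fb)
quadratic_root = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
midpoint = bisection_step(0.0, 3.0)  # the interval used in the __main__ block

print(secant_root, quadratic_root, midpoint)
```

The midpoint 1.5 is also the value of x reported at iteration 2 of the sample execution.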

Sample execution

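Running the module on cos over the interval [0, 3] produces the output below. The step column shows last_step, that is, the kind of step taken in the previous iteration.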

$ python typing/root3.py
Iter	x		f(x)		delta(x)	delta(f(x))	step
1	3.000e+00	-9.900e-01	3.000e+00	1.990e+00	None
2	1.500e+00	7.074e-02	1.500e+00	1.061e+00	None
3	1.600e+00	-2.923e-02	1.000e-01	9.997e-02	bisection
4	1.571e+00	1.434e-05	2.925e-02	2.924e-02	secant
5	1.571e+00	-2.042e-09	1.434e-05	1.434e-05	secant
6	1.571e+00	6.123e-17	2.042e-09	2.042e-09	secant
Met y_tol criterion
1.5707963267948966
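As a quick sanity check on this result, the printed root can be compared with the exact zero of cos in [0, 3], which is π/2. The snippet below is a small sketch that hard-codes the printed value and the default y_tol of 1e-12. The relative tolerance passed to `isclose` is an arbitrary choice made for this check.

```python
from math import cos, isclose, pi

root = 1.5707963267948966  # value printed by the sample run above
y_tol = 1e-12  # default Y tolerance of the module

# The exact zero of cos in [0, 3] is pi / 2
matches_pi_half = isclose(root, pi / 2, rel_tol=1e-12)

# The residual lies below y_tol, consistent with the reported criterion
residual_ok = abs(cos(root)) <= y_tol

print(matches_pi_half, residual_ok)
```

Both checks are expected to succeed, which agrees with the "Met y_tol criterion" message in the transcript.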