Second step: script to module
Second step: script to module
In this second step, we put the algorithm in its own function. We keep the “cos” example,
but move it inside an if __name__ == "__main__": guard, so that it executes only when the
script is run directly, for example with a python root2.py invocation, and not when it is imported.
The settings are variables declared at the module level. This is not considered good programming practice (module-level settings behave like global variables: every call to the function shares them), but it is simpler at this stage.
The script root2.py is now a Python module that can be imported in other files:
from math import cos
from root2 import brent
print(brent(cos, 0.0, 3.0))
Here is the code, with minimal changes.
"""
Root-finding solver based on Brent's method
See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from math import cos
from typing import Callable, Optional
# Solver settings. brent() looks these names up at module level each time it runs,
# so rebinding them on the module changes the behaviour of subsequent calls.
x_rel_tol = 1e-12  #: X relative tolerance (convergence test on the bracket width)
x_abs_tol = 1e-12  #: X absolute tolerance (convergence test on the bracket width)
y_tol = 1e-12  #: Y tolerance (convergence test on abs(f(x)) at the best estimate)
verbose = True  #: Whether to display progress (printed to standard output)
def brent(f: Callable[[float], float], a: float, b: float) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    The tolerances and the progress display are controlled by the module-level settings
    ``x_rel_tol``, ``x_abs_tol``, ``y_tol`` and ``verbose``, which are read when the function runs.
    ``f`` is evaluated once at each end of the interval, then once per iteration.

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root

    Returns:
        The approximate root (the end of the final bracket with the smallest ``abs(f(x))``)

    Raises:
        AssertionError: If ``f(a)`` and ``f(b)`` are both strictly positive or both strictly
            negative, i.e. the root is not bracketed
    """
    fa = f(a)  #: f(a)
    fb = f(b)  #: f(b)
    # Kept as an assert so that callers keep seeing AssertionError; note that asserts are
    # skipped when Python runs with -O
    assert fa * fb <= 0, "Root not bracketed"
    if abs(fa) < abs(fb):
        # force abs(fa) >= abs(fb), make sure that b is the best root approximation known so far
        # and a is the contrapoint
        b, a = a, b
        fb, fa = fa, fb
    c = a  #: Previous iterate
    fc = fa  #: f(c)
    d = a  #: Iterate before the previous iterate
    last_step: Optional[str] = None  #: Kind of step taken by the previous iteration
    step: Optional[str] = None  #: Kind of step taken by the current iteration
    iteration = 1  #: Current iteration number (1-based)
    if verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        dx = abs(a - b)
        dy = abs(fa - fb)
        print(f"{iteration}\t{b:.3e}\t{fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{last_step}")

    # An end of the bracket that is already an exact root needs no iteration. Returning it
    # right away also avoids the 0/0 secant step that occurs when both ends are exact roots.
    converged = fb == 0  #: Whether we have converged
    if converged and verbose:
        print("Exact root found")

    while not converged:
        iteration = iteration + 1
        last_step = step

        if fa != fc and fb != fc:
            # perform a quadratic interpolation step
            # https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation
            L0 = (a * fb * fc) / ((fa - fb) * (fa - fc))
            L1 = (b * fa * fc) / ((fb - fa) * (fb - fc))
            L2 = (c * fb * fa) / ((fc - fa) * (fc - fb))
            s = L0 + L1 + L2
            step = "quadratic"
        else:
            # perform a secant step
            s = b - fb * (b - a) / (fb - fa)
            step = "secant"

        # Fall back to bisection when the interpolated point is outside the accepted range
        # between (3a+b)/4 and b, or when the recent steps are not shrinking fast enough
        perform_bisection = False
        if a <= b and not ((3 * a + b) / 4 <= s <= b):
            perform_bisection = True
        elif b <= a and not (b <= s <= (3 * a + b) / 4):
            perform_bisection = True
        elif last_step == "bisection" and abs(s - b) >= abs(b - c) / 2:
            perform_bisection = True
        elif last_step != "bisection" and abs(a - b) >= abs(c - d) / 2:
            perform_bisection = True
        elif last_step == "bisection" and abs(b - c) < x_abs_tol:
            perform_bisection = True
        elif last_step != "bisection" and abs(c - d) < x_abs_tol:
            perform_bisection = True

        if perform_bisection:
            # perform a bisection step
            s = min(a, b) + abs(b - a) / 2
            step = "bisection"

        fs = f(s)
        d = c
        c = b
        fc = fb

        # check which point to replace to maintain (a,b) have different signs;
        # fa and fs already hold f(a) and f(s), so f is not evaluated again here
        if fa * fs < 0:
            b = s
            fb = fs
        else:
            a = s
            fa = fs

        # keep b as the best guess
        if abs(fa) < abs(fb):
            b, a = a, b
            fb, fa = fa, fb

        # Checks convergence
        if fb == 0:
            converged = True
            if verbose:
                print("Exact root found")
        x_delta = abs(a - b)
        if x_delta <= x_abs_tol:
            converged = True
            if verbose:
                print("Met x_abs_tol criterion")
        # Scale by the largest magnitude instead of dividing by max(a, b): this stays valid
        # for negative brackets and cannot divide by zero
        if x_delta <= x_rel_tol * max(abs(a), abs(b)):
            converged = True
            if verbose:
                print("Met x_rel_tol criterion")
        y_delta = abs(fb)
        if y_delta <= y_tol:
            converged = True
            if verbose:
                print("Met y_tol criterion")

        if verbose:
            dx = abs(a - b)
            dy = abs(fa - fb)
            # NOTE: the "step" column shows last_step, i.e. the kind of step taken by the
            # previous iteration, not the one that produced this row
            print(f"{iteration}\t{b:.3e}\t{fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{last_step}")
    return b
if __name__ == "__main__":
    # Demo run: find the zero of cos bracketed by 0 and 3, then print it
    root = brent(cos, 0.0, 3.0)
    print(root)
and we display the execution below.
$ python typing/root2.py
Iter	x		f(x)		delta(x)	delta(f(x))	step
1	3.000e+00	-9.900e-01	3.000e+00	1.990e+00	None
2	1.500e+00	7.074e-02	1.500e+00	1.061e+00	None
3	1.600e+00	-2.923e-02	1.000e-01	9.997e-02	bisection
4	1.571e+00	1.434e-05	2.925e-02	2.924e-02	secant
5	1.571e+00	-2.042e-09	1.434e-05	1.434e-05	secant
Met y_tol criterion
6	1.571e+00	6.123e-17	2.042e-09	2.042e-09	secant
1.5707963267948966