Fifth step: dataclasses

Python dataclasses are an excellent alternative to dicts when passing data around. I use them by default for the classes in my code unless I have a specific reason not to (and I cannot think of one at the moment).

In the code below, we moved the two dicts into proper dataclasses. The following sections describe the advantages we gained.

Mutable vs immutable data

We can declare dataclasses to be frozen. This means that their values cannot be modified after they are constructed, i.e. their instances are immutable objects.
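As a minimal sketch of what "frozen" means in practice, the snippet below tries to assign to a field of a frozen instance. The Tolerances class and the try_to_modify helper are hypothetical names made up for this illustration; they are not part of the root-finding code.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class Tolerances:
    """Hypothetical example class, not part of the root-finding code."""

    x_abs_tol: float = 1e-12
    y_tol: float = 1e-12


def try_to_modify(t: Tolerances) -> bool:
    """Returns True if assigning to a field was rejected."""
    try:
        t.y_tol = 1e-6  # frozen dataclasses raise on attribute assignment
    except FrozenInstanceError:
        return True
    return False


tol = Tolerances()
rejected = try_to_modify(tol)
print(rejected, tol.y_tol)
```

The assignment is refused at runtime and the original value is left untouched.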

Distinguishing between mutable and immutable pieces of data makes it much easier to understand how data flows through a program.

In our code, the settings are immutable, as they are not modified during code execution (rather, different Settings instances can be constructed and used at different times).

The BrentState class is immutable as well. The Brent step method returns a new updated instance at every step.
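The "return a new instance at every step" pattern can be sketched with dataclasses.replace from the standard library. The HalvingState class and the step function are hypothetical stand-ins for illustration; the real BrentState is constructed field by field in the listing further down.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HalvingState:
    """Hypothetical toy state: a value and an iteration counter."""

    value: float
    iter: int


def step(state: HalvingState) -> HalvingState:
    """Returns an updated copy; the input state is left unchanged."""
    return replace(state, value=state.value / 2, iter=state.iter + 1)


s0 = HalvingState(value=8.0, iter=1)
s1 = step(s0)
print(s0)
print(s1)
```

Because s0 is never modified, earlier states can be kept around for logging or debugging without defensive copies.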

This approach is feasible here because the data concerned is quite small (a few floating-point numbers). When dealing with larger datasets, we will often modify the data in place instead (a standard linear algebra example is the in-place LU decomposition of a matrix).

To make a dataclass mutable, simply omit the (frozen=True) argument.
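For contrast, here is a small sketch of a mutable dataclass that is updated in place. The Samples class is a hypothetical example, chosen only to show that a plain @dataclass allows its fields to change after construction.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Samples:
    """Hypothetical mutable container, updated in place."""

    values: List[float] = field(default_factory=list)
    count: int = 0

    def add(self, x: float) -> None:
        """Appends a sample, modifying this instance."""
        self.values.append(x)
        self.count += 1


samples = Samples()
samples.add(1.5)
samples.add(2.5)
print(samples)
```

Note the use of field(default_factory=list): a mutable default such as a list must be created per instance rather than shared.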

Methods

Dataclasses can have methods. This enables us to group functions related to a piece of data.

When some dataclass fields are derived from the values of other fields, as is the case for the initial state of Brent’s algorithm, we can use static methods instead of fiddling with the __init__ method. This approach works better with frozen dataclasses, inheritance, and default parameters.
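A reduced sketch of this factory pattern follows. The Bracket class is a hypothetical, simplified cousin of BrentState: its make method computes the derived fields fa and fb, so the generated __init__ stays untouched.

```python
from dataclasses import dataclass
from math import cos
from typing import Callable


@dataclass(frozen=True)
class Bracket:
    """Hypothetical example: an interval with its function values."""

    a: float
    b: float
    fa: float  # derived: f(a)
    fb: float  # derived: f(b)

    @staticmethod
    def make(f: Callable[[float], float], a: float, b: float) -> "Bracket":
        """Computes the derived fields, then calls the generated __init__."""
        return Bracket(a=a, b=b, fa=f(a), fb=f(b))


bracket = Bracket.make(cos, 0.0, 3.0)
print(bracket)
```

Callers who already know the function values can still use the plain constructor, which is handy in tests.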

Also, we can provide additional static methods such as read_from_json to construct instances from various sources.
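The listing below does not include a read_from_json method, so the following is only a sketch of what one could look like. The SolverSettings class, its fields, and the choice to take the JSON text as a string are all assumptions made for this illustration.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class SolverSettings:
    """Hypothetical settings class used only for this sketch."""

    x_abs_tol: float = 1e-12
    verbose: bool = False

    @staticmethod
    def read_from_json(text: str) -> "SolverSettings":
        """Builds an instance from a JSON object whose keys match the field names."""
        data = json.loads(text)
        # unknown keys make the constructor raise a TypeError
        return SolverSettings(**data)


settings = SolverSettings.read_from_json('{"x_abs_tol": 1e-9, "verbose": true}')
print(settings)
```

Keys missing from the JSON object simply fall back to the field defaults.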

We use the __post_init__ method to verify the soundness of the constructed object, again using assertions.
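To show the mechanism in isolation, here is a minimal sketch with a hypothetical Tolerance class. The generated __init__ calls __post_init__ after setting the fields, so an invalid instance never reaches the caller. Keep in mind that assert statements are skipped when Python runs with the -O flag.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tolerance:
    """Hypothetical example class with a validated field."""

    value: float

    def __post_init__(self) -> None:
        # runs right after the generated __init__ has set the fields
        assert self.value >= 0, "Tolerance must be non-negative"


def is_valid(value: float) -> bool:
    """Returns True if a Tolerance can be constructed from the value."""
    try:
        Tolerance(value)
    except AssertionError:
        return False
    return True


print(is_valid(1e-12), is_valid(-1.0))
```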

Documenting invariants

An OOP principle is to maintain class invariants, i.e. constraints that are satisfied once an instance is constructed and are preserved during code execution.

For example, in our BrentState class, we preserve the fact that b is the best-known approximation so far and that f(a) and f(b) have opposite signs.
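The two invariants can be written down as a small predicate. The invariants_hold function below is a hypothetical helper for illustration, taking the stored values fa = f(a) and fb = f(b); it uses the same fa * fb <= 0 bracketing test as the make method in the listing.

```python
def invariants_hold(fa: float, fb: float) -> bool:
    """Checks both invariants on the stored function values."""
    opposite_signs = fa * fb <= 0  # the root is bracketed
    b_is_best = abs(fb) <= abs(fa)  # b is the best guess
    return opposite_signs and b_is_best


both_hold = invariants_hold(fa=1.0, fb=-0.5)
b_not_best = invariants_hold(fa=0.5, fb=-1.0)
not_bracketed = invariants_hold(fa=1.0, fb=0.5)
print(both_hold, b_not_best, not_bracketed)
```

A check of this kind could be called from tests on every state produced by a step function.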

Example Sphinx autodoc output

class root5.Settings(x_rel_tol=1e-12, x_abs_tol=1e-12, y_tol=1e-12, verbose=False)[source]

Settings for the root-finding algorithm

converged(a, b, fb)[source]

Checks convergence of a root finding method

Parameters
  • a (float) – Contrapoint

  • b (float) – Best guess

  • fb (float) – f(b)

Return type

Tuple[bool, Optional[str]]

Returns

Whether the root finding converges to the specified tolerance and why

static default_verbose()[source]

Returns default verbose settings

Return type

Settings

verbose: bool = False

Whether to display progress

x_abs_tol: float = 1e-12

X absolute tolerance

x_rel_tol: float = 1e-12

X relative tolerance

y_tol: float = 1e-12

Y tolerance

class root5.BrentState(a, b, c, d, fa, fb, fc, last_step, iter)[source]

State for Brent’s method

Note

We have the following invariants.

  • fa = f(a) and fb = f(b) have opposite signs

  • abs(fb) <= abs(fa) so that b is the best guess

a: float

Contrapoint

b: float

Current iterate, best root approximation known so far

c: float

Previous iterate

d: float

Iterate before the previous iterate

fa: float

f(a)

fb: float

f(b)

fc: float

f(c)

iter: int

Current iteration number (1-based)

last_step: Optional[Union[Literal['quadratic'], Literal['secant'], Literal['bisection']]]

Type of previous step

static make(f, a, b)[source]

Initializes a state from an interval that brackets a root of the given function

Parameters
  • f (Callable[[float], float]) – Function to find the root of

  • a (float) – First x coordinate

  • b (float) – Second x coordinate

Return type

BrentState

Returns

The initial state

Python code

"""
Root-finding solver based on Brent's method

See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations

from dataclasses import dataclass
from math import cos
from typing import Callable, Literal, Optional, Tuple, Union


@dataclass(frozen=True)
class Settings:
    """
    Settings for the root-finding algorithm
    """

    x_rel_tol: float = 1e-12  #: X relative tolerance
    x_abs_tol: float = 1e-12  #: X absolute tolerance
    y_tol: float = 1e-12  #: Y tolerance
    verbose: bool = False  #: Whether to display progress

    def __post_init__(self) -> None:
        """
        Verifies sanity of parameters
        """
        assert self.x_rel_tol >= 0
        assert self.x_abs_tol >= 0
        assert self.y_tol >= 0
        assert (
            self.x_rel_tol > 0 or self.x_abs_tol > 0 or self.y_tol > 0
        ), "At least one convergence criterion must be set"

    def converged(self, a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
        """
        Checks convergence of a root finding method

        Args:
            a: Contrapoint
            b: Best guess
            fb: f(b)

        Returns:
            Whether the root finding converges to the specified tolerance and why
        """
        if fb == 0:
            return (True, "Exact root found")
        x_delta = abs(a - b)
        if x_delta <= self.x_abs_tol:
            return (True, "Met x_abs_tol criterion")
        # scale by the largest magnitude: avoids dividing by zero or by a negative value
        if x_delta <= self.x_rel_tol * max(abs(a), abs(b)):
            return (True, "Met x_rel_tol criterion")
        y_delta = abs(fb)
        if y_delta <= self.y_tol:
            return (True, "Met y_tol criterion")
        return (False, None)

    @staticmethod
    def default_verbose() -> Settings:
        """
        Returns default verbose settings
        """
        return Settings(verbose=True)


def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Computes an approximation for a zero of a 1D function from three function values

    Note:
        The values ``fa``, ``fb``, ``fc`` need all to be distinct.

    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_

    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)

    Returns:
        An approximation of the zero
    """
    L0 = (a * fb * fc) / ((fa - fb) * (fa - fc))
    L1 = (b * fa * fc) / ((fb - fa) * (fb - fc))
    L2 = (c * fb * fa) / ((fc - fa) * (fc - fb))
    return L0 + L1 + L2


def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Computes an approximation for a zero of a 1D function from two function values

    Note:
        The values ``fa`` and ``fb`` need to have different signs.

    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)

    Returns:
        An approximation of the zero
    """
    return b - fb * (b - a) / (fb - fa)


def bisection_step(a: float, b: float) -> float:
    """
    Computes the midpoint of an interval that brackets a zero of a 1D function

    Note:
        The values ``f(a)`` and ``f(b)`` (not needed in the code) need to have different signs.

    Args:
        a: First x coordinate
        b: Second x coordinate

    Returns:
        An approximation of the zero
    """
    return min(a, b) + abs(b - a) / 2


@dataclass(frozen=True)
class BrentState:
    """
    State for Brent's method

    Note:
        We have the following invariants.

        - ``fa = f(a)`` and ``fb = f(b)`` have opposite signs

        - ``abs(fb) <= abs(fa)`` so that ``b`` is the best guess
    """

    a: float  #: Contrapoint
    b: float  #: Current iterate, best root approximation known so far
    c: float  #: Previous iterate
    d: float  #: Iterate before the previous iterate
    fa: float  #: f(a)
    fb: float  #: f(b)
    fc: float  #: f(c)
    last_step: Optional[
        Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]
    ]  #: Type of previous step
    iter: int  #: Current iteration number (1-based)

    @staticmethod
    def make(f: Callable[[float], float], a: float, b: float) -> BrentState:
        """
        Initializes a state from an interval that brackets a root of the given function

        Args:
            f: Function to find the root of
            a: First x coordinate
            b: Second x coordinate

        Returns:
            The initial state
        """
        fa = f(a)
        fb = f(b)
        # check the first invariant
        assert fa * fb <= 0, "Root not bracketed"
        if abs(fa) < abs(fb):
            # force the second invariant
            b, a = a, b
            fb, fa = fa, fb
        c, fc = a, fa
        d = a  # f(d) is never used, so it is not stored
        return BrentState(a=a, b=b, c=c, d=d, fa=fa, fb=fb, fc=fc, last_step=None, iter=1)


def brent_step(f: Callable[[float], float], state: BrentState, delta: float) -> BrentState:
    """
    Performs a step of Brent's method

    Args:
        f: Function to find the root of
        state: Previous state
        delta: x absolute tolerance

    Returns:
        New state
    """
    a, b, c, d = state.a, state.b, state.c, state.d
    fa, fb, fc = state.fa, state.fb, state.fc
    last_step = state.last_step
    iter = state.iter
    step: Optional[Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]] = None
    if fa != fc and fb != fc:
        s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
        step = "quadratic"
    else:
        s = secant_step(a, b, fa, fb)
        step = "secant"
    perform_bisection = False
    if a <= b and not ((3 * a + b) / 4 <= s <= b):
        perform_bisection = True
    elif b <= a and not (b <= s <= (3 * a + b) / 4):
        perform_bisection = True
    elif last_step == "bisection" and abs(s - b) >= abs(b - c) / 2:
        perform_bisection = True
    elif last_step != "bisection" and abs(s - b) >= abs(c - d) / 2:
        perform_bisection = True
    elif last_step == "bisection" and abs(b - c) < delta:
        perform_bisection = True
    elif last_step != "bisection" and abs(c - d) < delta:
        perform_bisection = True
    if perform_bisection:
        s = bisection_step(a, b)
        step = "bisection"
    fs = f(s)
    d = c
    c = b
    fc = fb
    # check which point to replace so that f(a) and f(b) keep opposite signs;
    # reuse the stored values instead of evaluating f again
    if fa * fs < 0:
        b = s
        fb = fs
    else:
        a = s
        fa = fs
    # keep b as the best guess
    if abs(fa) < abs(fb):
        b, a = a, b
        fb, fa = fa, fb
    return BrentState(a=a, b=b, c=c, d=d, fa=fa, fb=fb, fc=fc, last_step=step, iter=iter + 1)


def brent(
    f: Callable[[float], float], a: float, b: float, settings: Settings = Settings()
) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero

    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root
        settings: Algorithm settings

    Returns:
        The approximate root
    """
    state = BrentState.make(f, a, b)
    converged = False
    reason = None

    def print_state(s: BrentState) -> None:
        """
        Prints information about an iteration
        """
        dx = abs(s.a - s.b)
        dy = abs(s.fa - s.fb)
        ls: Optional[str] = s.last_step
        if ls is None:
            ls = ""
        print(f"{s.iter}\t{s.b:.3e}\t{s.fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{ls}")

    if settings.verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state(state)

    while not converged:
        state = brent_step(f, state, settings.x_abs_tol)
        converged, reason = settings.converged(state.a, state.b, state.fb)
        if settings.verbose:
            print_state(state)

    if settings.verbose:
        assert reason is not None
        print(reason)

    return state.b


if __name__ == "__main__":
    print(brent(cos, 0.0, 3.0, Settings.default_verbose()))

Sample execution

We display the execution below.

$ python typing/root5.py
Iter	x		f(x)		delta(x)	delta(f(x))	step
1	3.000e+00	-9.900e-01	3.000e+00	1.990e+00	
2	1.500e+00	7.074e-02	1.500e+00	1.061e+00	bisection
3	1.600e+00	-2.923e-02	1.000e-01	9.997e-02	secant
4	1.571e+00	1.434e-05	2.925e-02	2.924e-02	secant
5	1.571e+00	-2.042e-09	1.434e-05	1.434e-05	secant
6	1.571e+00	6.123e-17	2.042e-09	2.042e-09	secant
Met y_tol criterion
1.5707963267948966