Fifth step: dataclasses
Python dataclasses are excellent alternatives to dicts when passing data around. In fact, I use them by default for the classes in my code, unless I have a reason not to (though I cannot think of one at the moment).
In the code below, we replaced the two dicts with proper dataclasses; the following sections describe the advantages we gained.
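As a minimal illustration of one such advantage, here is a dict with a misspelled key next to a cut-down, hypothetical two-field version of the Settings class. The dict silently accepts the typo, whereas the dataclass rejects it at construction time (and mypy reports the bad call before the code even runs).

```python
from dataclasses import dataclass

# With a dict, a misspelled key is silently stored and later ignored.
settings_dict = {"x_abs_tol": 1e-12, "verbsoe": True}
print(settings_dict.get("verbose", False))  # False: the typo went unnoticed


@dataclass(frozen=True)
class Settings:
    """Cut-down, hypothetical version of the Settings class."""

    x_abs_tol: float = 1e-12
    verbose: bool = False


# With a dataclass, the generated __init__ only accepts the declared fields.
rejected = False
try:
    # mypy reports this call as well; the ignore comment lets the demo run.
    Settings(x_abs_tol=1e-12, verbsoe=True)  # type: ignore
except TypeError as exc:
    rejected = True
    print(f"rejected: {exc}")
```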
Mutable vs immutable data
We can declare dataclasses to be frozen: their fields cannot be reassigned after construction, i.e. their instances are immutable objects. (The immutability is shallow: a mutable object stored in a field, such as a list, can still be modified in place.)
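A minimal sketch with a hypothetical Point class: assigning to a field of a frozen instance raises dataclasses.FrozenInstanceError at run time, and mypy also reports the assignment statically.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Point:
    """Hypothetical immutable 2D point."""

    x: float
    y: float


p = Point(1.0, 2.0)
modified = True
try:
    # mypy flags this line too, since the fields of a frozen dataclass are read-only.
    p.x = 3.0  # type: ignore
except dataclasses.FrozenInstanceError:
    modified = False
print(p, modified)  # Point(x=1.0, y=2.0) False
```

FrozenInstanceError is a subclass of AttributeError, so existing code catching AttributeError also catches it.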
Making a distinction between mutable and immutable pieces of data is great when we need to understand how data flows through a program.
In our code, the settings are immutable, as they do not need to change during code execution (rather, different Settings instances can be constructed and used at different times).
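Because Settings instances are frozen, a single instance can safely be shared, for example as a default argument value, as the brent function does with `settings: Settings = Settings()`. The sketch below uses a hypothetical cut-down Settings and a hypothetical run function. The default instance is created once, when the function is defined, and reused by every call; with a mutable default such as a dict, one call could modify it and the next call would see the change, but freezing rules that out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Cut-down, hypothetical version of the Settings class."""

    y_tol: float = 1e-12
    verbose: bool = False


def run(settings: Settings = Settings()) -> Settings:
    # Hypothetical solver entry point: it simply returns the settings it used.
    return settings


quiet = run()  # uses the shared default instance
loud = run(Settings(verbose=True))  # a separate instance for this call only
print(quiet is run(), loud.verbose)  # True True
```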
The BrentState class is immutable as well: the brent_step function returns a new, updated instance at every step.
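The brent_step function builds each new BrentState by passing every field to the constructor. For states with many fields, the standard library function dataclasses.replace expresses the same "copy with changes" pattern more concisely; here is a minimal sketch on a hypothetical HalvingState class.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class HalvingState:
    """Hypothetical state of an iteration that halves a value at every step."""

    value: float
    iteration: int


def halving_step(state: HalvingState) -> HalvingState:
    # replace() returns a new instance, copying every field not given as a keyword;
    # the previous state is left untouched.
    return replace(state, value=state.value / 2, iteration=state.iteration + 1)


s0 = HalvingState(value=8.0, iteration=1)
s1 = halving_step(s0)
print(s0)  # HalvingState(value=8.0, iteration=1)
print(s1)  # HalvingState(value=4.0, iteration=2)
```

Since replace() goes through the generated __init__, any __post_init__ checks run again on the new instance.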
However, creating a new instance at every step is feasible here only because the data involved is small (a few floating-point numbers). When dealing with larger datasets, we will often modify the data in place instead (a standard linear algebra example is the in-place LU decomposition of a matrix).
To get a mutable dataclass, simply omit the frozen=True argument of the @dataclass decorator.
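For contrast, here is a minimal sketch of a mutable dataclass (no frozen=True) whose method updates its data in place; the History class is hypothetical. Note the field(default_factory=list): dataclasses reject a plain mutable default such as `[]` with a ValueError, precisely so that instances do not end up sharing a single list.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass  # no frozen=True: fields can be modified after construction
class History:
    """Hypothetical record of the iterates visited by a solver."""

    xs: List[float] = field(default_factory=list)

    def record(self, x: float) -> None:
        # In-place update: the same instance (and the same list) is modified.
        self.xs.append(x)


h = History()
h.record(3.0)
h.record(1.5)
print(h)  # History(xs=[3.0, 1.5])
```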
Methods
Dataclasses can have methods. This enables us to group functions related to a piece of data.
When some dataclass fields are derived from the values of other fields, as is the case for the initial state of Brent’s algorithm, we can use static methods instead of fiddling with the __init__ method.
This works better with frozen dataclasses, inheritance, default parameters, etc.
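BrentState.make follows this pattern. Here is a smaller sketch of the same idea on a hypothetical Bracket class: the static method computes the derived fields fa and fb, so the generated __init__ is left alone and the class can stay frozen. (A hand-written __init__ in a frozen dataclass would instead have to set each field through object.__setattr__.)

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Bracket:
    """Hypothetical interval [a, b] stored together with f(a) and f(b)."""

    a: float
    b: float
    fa: float  # derived: f(a)
    fb: float  # derived: f(b)

    @staticmethod
    def make(f: Callable[[float], float], a: float, b: float) -> Bracket:
        # The derived fields are computed here; the generated __init__ is unchanged.
        return Bracket(a=a, b=b, fa=f(a), fb=f(b))


br = Bracket.make(lambda x: x * x - 2.0, 0.0, 2.0)
print(br)  # Bracket(a=0.0, b=2.0, fa=-2.0, fb=2.0)
```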
Also, we can provide additional static methods such as read_from_json to construct instances
from various sources.
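The code in this step does not define read_from_json. Here is a minimal sketch of what it could look like on a cut-down, hypothetical Settings class, assuming the keys of the JSON object match the field names exactly.

```python
from __future__ import annotations

import json
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Cut-down, hypothetical version of the Settings class."""

    x_abs_tol: float = 1e-12
    verbose: bool = False

    @staticmethod
    def read_from_json(text: str) -> Settings:
        # Hypothetical alternative constructor: unknown keys make the generated
        # __init__ raise a TypeError, and missing keys fall back to the defaults.
        return Settings(**json.loads(text))


s = Settings.read_from_json('{"x_abs_tol": 1e-9, "verbose": true}')
print(s)  # Settings(x_abs_tol=1e-09, verbose=True)
```

The types of the JSON values are not checked here: `"verbose": "yes"` would be stored as a string unless __post_init__ validates it.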
We use the __post_init__ method to verify the soundness of the constructed object,
again using assertions.
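To see such a check in action, here is a sketch on a hypothetical cut-down Settings: an invalid value is rejected as soon as the instance is constructed. Keep in mind that assert statements are removed when Python runs with the -O flag, so these checks act as a debugging aid rather than as input validation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Cut-down, hypothetical version of the Settings class."""

    x_abs_tol: float = 1e-12
    y_tol: float = 1e-12

    def __post_init__(self) -> None:
        # Called by the generated __init__ once all fields have been set.
        assert self.x_abs_tol >= 0, "x_abs_tol must be non-negative"
        assert self.y_tol >= 0, "y_tol must be non-negative"


message = ""
try:
    Settings(x_abs_tol=-1.0)
except AssertionError as exc:
    message = str(exc)
print(message)  # x_abs_tol must be non-negative
```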
Documenting invariants
An OOP principle is to maintain class invariants, i.e. constraints that are satisfied when an instance is constructed and that are preserved during code execution.
For example, in our BrentState class, we preserve the fact that b is the best-known
approximation so far and that f(a) and f(b) have opposite signs.
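In the code, these invariants are documented in the BrentState docstring; BrentState.make asserts the first one and enforces the second by swapping. One option, sketched here on a hypothetical Bracket class, is to assert both invariants in __post_init__. Combined with frozen=True, an invariant checked once at construction then holds for the lifetime of the instance, since its fields cannot be reassigned through normal attribute assignment.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Bracket:
    """Hypothetical class: [a, b] brackets a root and b is the best guess."""

    a: float
    b: float
    fa: float  # f(a)
    fb: float  # f(b)

    def __post_init__(self) -> None:
        # Invariant 1: f(a) and f(b) have opposite signs (or one of them is zero).
        assert self.fa * self.fb <= 0, "root not bracketed"
        # Invariant 2: b is at least as good a guess as a, judging by |f|.
        assert abs(self.fb) <= abs(self.fa), "b is not the best guess"


good = Bracket(a=0.0, b=2.0, fa=-3.0, fb=1.0)
violation = ""
try:
    Bracket(a=0.0, b=2.0, fa=1.0, fb=-3.0)  # the roles of a and b are swapped
except AssertionError as exc:
    violation = str(exc)
print(violation)  # b is not the best guess
```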
Example Sphinx autodoc output
- class root5.Settings(x_rel_tol=1e-12, x_abs_tol=1e-12, y_tol=1e-12, verbose=False)
 Settings for the root-finding algorithm
- converged(a, b, fb)
 Checks convergence of a root finding method
 Parameters:
  a (float) – Contrapoint
  b (float) – Best guess
  fb (float) – f(b)
 Return type:
  Tuple[bool, Optional[str]]
 Returns:
  Whether the root finding converges to the specified tolerance and why
- verbose: bool = False
 Whether to display progress
- x_abs_tol: float = 1e-12
 X absolute tolerance
- x_rel_tol: float = 1e-12
 X relative tolerance
- y_tol: float = 1e-12
 Y tolerance
- class root5.BrentState(a, b, c, d, fa, fb, fc, last_step, iter)
 State for Brent’s method
 Note: We have the following invariants.
  - fa = f(a) and fb = f(b) have opposite signs
  - abs(fb) <= abs(fa) so that b is the best guess
- a: float
 Contrapoint
- b: float
 Current iterate, best root approximation known so far
- c: float
 Previous iterate
- d: float
 Iterate before the previous iterate
- fa: float
 f(a)
- fb: float
 f(b)
- fc: float
 f(c)
- iter: int
 Current iteration number (1-based)
- last_step: Optional[Union[Literal['quadratic'], Literal['secant'], Literal['bisection']]]
 Type of previous step
Python code
"""
Root-finding solver based on Brent's method
See `<https://en.wikipedia.org/wiki/Brent%27s_method>`_
"""
from __future__ import annotations
from dataclasses import dataclass
from math import cos
from typing import Callable, Literal, Optional, Tuple, Union
@dataclass(frozen=True)
class Settings:
    """
    Settings for the root-finding algorithm
    """
    x_rel_tol: float = 1e-12  #: X relative tolerance
    x_abs_tol: float = 1e-12  #: X absolute tolerance
    y_tol: float = 1e-12  #: Y tolerance
    verbose: bool = False  #: Whether to display progress
    def __post_init__(self) -> None:
        """
        Verifies sanity of parameters
        """
        assert self.x_rel_tol >= 0
        assert self.x_abs_tol >= 0
        assert self.y_tol >= 0
        assert (
            self.x_rel_tol > 0 or self.x_abs_tol > 0 or self.y_tol > 0
        ), "At least one convergence criteria must be set"
    def converged(self, a: float, b: float, fb: float) -> Tuple[bool, Optional[str]]:
        """
        Checks convergence of a root finding method
        Args:
            a: Contrapoint
            b: Best guess
            fb: f(b)
        Returns:
            Whether the root finding converges to the specified tolerance and why
        """
        if fb == 0:
            return (True, "Exact root found")
        x_delta = abs(a - b)
        if x_delta <= self.x_abs_tol:
            return (True, "Met x_abs_tol criterion")
        # relative to the magnitude of the iterates (a == b == 0 is caught by the check above)
        if x_delta / max(abs(a), abs(b)) <= self.x_rel_tol:
            return (True, "Met x_rel_tol criterion")
        y_delta = abs(fb)
        if y_delta <= self.y_tol:
            return (True, "Met y_tol criterion")
        return (False, None)
    @staticmethod
    def default_verbose() -> Settings:
        """
        Returns default verbose settings
        """
        return Settings(verbose=True)
def inverse_quadratic_interpolation_step(
    a: float, b: float, c: float, fa: float, fb: float, fc: float
) -> float:
    """
    Computes an approximation for a zero of a 1D function from three function values
    Note:
        The values ``fa``, ``fb``, ``fc`` must all be distinct.
    See `<https://en.wikipedia.org/wiki/Inverse_quadratic_interpolation>`_
    Args:
        a: First x coordinate
        b: Second x coordinate
        c: Third x coordinate
        fa: f(a)
        fb: f(b)
        fc: f(c)
    Returns:
        An approximation of the zero
    """
    L0 = (a * fb * fc) / ((fa - fb) * (fa - fc))
    L1 = (b * fa * fc) / ((fb - fa) * (fb - fc))
    L2 = (c * fb * fa) / ((fc - fa) * (fc - fb))
    return L0 + L1 + L2
def secant_step(a: float, b: float, fa: float, fb: float) -> float:
    """
    Computes an approximation for a zero of a 1D function from two function values
    Note:
        The values ``fa`` and ``fb`` must have opposite signs.
    Args:
        a: First x coordinate
        b: Second x coordinate
        fa: f(a)
        fb: f(b)
    Returns:
        An approximation of the zero
    """
    return b - fb * (b - a) / (fb - fa)
def bisection_step(a: float, b: float) -> float:
    """
    Computes an approximation for a zero of a 1D function as the midpoint of two x coordinates
    Note:
        The values ``f(a)`` and ``f(b)`` must have opposite signs (they are not needed in the computation itself).
    Args:
        a: First x coordinate
        b: Second x coordinate
    Returns:
        An approximation of the zero
    """
    return min(a, b) + abs(b - a) / 2
@dataclass(frozen=True)
class BrentState:
    """
    State for Brent's method
    Note:
        We have the following invariants.
        - ``fa = f(a)`` and ``fb = f(b)`` have opposite signs
        - ``abs(fb) <= abs(fa)`` so that ``b`` is the best guess
    """
    a: float  #: Contrapoint
    b: float  #: Current iterate, best root approximation known so far
    c: float  #: Previous iterate
    d: float  #: Iterate before the previous iterate
    fa: float  #: f(a)
    fb: float  #: f(b)
    fc: float  #: f(c)
    last_step: Optional[
        Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]
    ]  #: Type of previous step
    iter: int  #: Current iteration number (1-based)
    @staticmethod
    def make(f: Callable[[float], float], a: float, b: float) -> BrentState:
        """
        Initializes a state from an interval that brackets a root of the given function
        Args:
            f: Function to find the root of
            a: First x coordinate
            b: Second x coordinate
        Returns:
            The initial state
        """
        fa = f(a)
        fb = f(b)
        # check the first invariant
        assert fa * fb <= 0, "Root not bracketed"
        if abs(fa) < abs(fb):
            # force the second invariant
            b, a = a, b
            fb, fa = fa, fb
        c, fc = a, fa
        # no iterate before c yet: with c == d, the first brent_step bisects
        d = a
        return BrentState(a=a, b=b, c=c, d=d, fa=fa, fb=fb, fc=fc, last_step=None, iter=1)
def brent_step(f: Callable[[float], float], state: BrentState, delta: float) -> BrentState:
    """
    Performs a step of Brent's method
    Args:
        f: Function to find the root of
        state: Previous state
        delta: x absolute tolerance
    Returns:
        New state
    """
    a, b, c, d = state.a, state.b, state.c, state.d
    fa, fb, fc = state.fa, state.fb, state.fc
    last_step = state.last_step
    iter = state.iter
    step: Optional[Union[Literal["quadratic"], Literal["secant"], Literal["bisection"]]] = None
    if fa != fc and fb != fc:
        s = inverse_quadratic_interpolation_step(a, b, c, fa, fb, fc)
        step = "quadratic"
    else:
        s = secant_step(a, b, fa, fb)
        step = "secant"
    perform_bisection = False
    if a <= b and not ((3 * a + b) / 4 <= s <= b):
        perform_bisection = True
    elif b <= a and not (b <= s <= (3 * a + b) / 4):
        perform_bisection = True
    elif last_step == "bisection" and abs(s - b) >= abs(b - c) / 2:
        perform_bisection = True
    elif last_step != "bisection" and abs(s - b) >= abs(c - d) / 2:
        perform_bisection = True
    elif last_step == "bisection" and abs(b - c) < delta:
        perform_bisection = True
    elif last_step != "bisection" and abs(c - d) < delta:
        perform_bisection = True
    if perform_bisection:
        s = bisection_step(a, b)
        step = "bisection"
    fs = f(s)
    d = c
    c = b
    fc = fb
    # replace a or b by s so that f(a) and f(b) keep opposite signs
    # (reuse the cached values instead of evaluating f again)
    if fa * fs < 0:
        b = s
        fb = fs
    else:
        a = s
        fa = fs
    # keep b as the best guess
    if abs(fa) < abs(fb):
        b, a = a, b
        fb, fa = fa, fb
    return BrentState(a=a, b=b, c=c, d=d, fa=fa, fb=fb, fc=fc, last_step=step, iter=iter + 1)
def brent(
    f: Callable[[float], float], a: float, b: float, settings: Settings = Settings()
) -> float:
    """
    Finds the root of a function using Brent's method, starting from an interval enclosing the zero
    Args:
        f: Function to find the root of
        a: First x coordinate enclosing the root
        b: Second x coordinate enclosing the root
        settings: Algorithm settings
    Returns:
        The approximate root
    """
    state = BrentState.make(f, a, b)
    converged = False
    reason: Optional[str] = None
    def print_state(s: BrentState) -> None:
        """
        Prints information about an iteration
        """
        dx = abs(s.a - s.b)
        dy = abs(s.fa - s.fb)
        ls: Optional[str] = s.last_step
        if ls is None:
            ls = ""
        print(f"{s.iter}\t{s.b:.3e}\t{s.fb:.3e}\t{dx:.3e}\t{dy:.3e}\t{ls}")
    if settings.verbose:
        print("Iter\tx\t\tf(x)\t\tdelta(x)\tdelta(f(x))\tstep")
        print_state(state)
    while not converged:
        state = brent_step(f, state, settings.x_abs_tol)
        converged, reason = settings.converged(state.a, state.b, state.fb)
        if settings.verbose:
            print_state(state)
    if settings.verbose:
        assert reason is not None
        print(reason)
    return state.b
if __name__ == "__main__":
    print(brent(cos, 0.0, 3.0, Settings.default_verbose()))
Sample execution
Running the script produces the following output.
$ python typing/root5.py
Iter	x		f(x)		delta(x)	delta(f(x))	step
1	3.000e+00	-9.900e-01	3.000e+00	1.990e+00	
2	1.500e+00	7.074e-02	1.500e+00	1.061e+00	bisection
3	1.600e+00	-2.923e-02	1.000e-01	9.997e-02	secant
4	1.571e+00	1.434e-05	2.925e-02	2.924e-02	secant
5	1.571e+00	-2.042e-09	1.434e-05	1.434e-05	secant
6	1.571e+00	6.123e-17	2.042e-09	2.042e-09	secant
Met y_tol criterion
1.5707963267948966