From fc3fd0f0340c1784c47407244f80adbfe94cdd81 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Wed, 28 Aug 2024 10:51:42 -0700 Subject: [PATCH] add docstring to root_scalar --- src/beignet/_root_scalar.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/beignet/_root_scalar.py b/src/beignet/_root_scalar.py index f8885224df..8f57a16420 100644 --- a/src/beignet/_root_scalar.py +++ b/src/beignet/_root_scalar.py @@ -16,10 +16,36 @@ class RootSolutionInfo: def root_scalar( func: Callable, *args, - method: Literal["bisect"] | Literal["chandrupatla"] = "chandrupatla", + method: Literal["bisect", "chandrupatla"] = "chandrupatla", implicit_diff: bool = True, - options: dict, + options: dict | None = None, ): + """ + Find the root of a scalar (elementwise) function. + + Parameters + ---------- + func: Callable + Function to find a root of. Called as `f(x, *args)`. + The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`. + Handling *args via broadcasting is acceptable. + + *args + Extra arguments to be passed to `func`. + + method: Literal["bisect", "chandrupatla"] = "chandrupatla" + Solver method to use. + + implicit_diff: bool = True + If true, the solver is wrapped in `beignet.func.custom_scalar_root` which + enables gradients with respect to *args using implicit differentiation. + + options: dict | None = None + A dictionary of options that are passed through to the solver as keyword args. + """ + if options is None: + options = {} + if method == "bisect": solver = beignet.bisect elif method == "chandrupatla":