diff --git a/src/beignet/_bisect.py b/src/beignet/_bisect.py index 1c9eec1507..8a5a8d16e5 100644 --- a/src/beignet/_bisect.py +++ b/src/beignet/_bisect.py @@ -19,6 +19,50 @@ def bisect( device=None, **_, ) -> Tensor | tuple[Tensor, RootSolutionInfo]: + """Find the root of a scalar (elementwise) function using bisection. + + This method is slow but guarenteed to converge. + + Parameters + ---------- + func: Callable + Function to find a root of. Called as `f(x, *args)`. + The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`. + Handling *args via broadcasting is acceptable. + + *args + Extra arguments to be passed to `func`. + + lower: float | Tensor + Lower bracket for root + + upper: float | Tensor + Upper bracket for root + + rtol: float | None = None + Relative tolerance + + atol: float | None = None + Absolute tolerance + + maxiter: int = 100 + Maximum number of iterations + + return_solution_info: bool = False + Whether to return a `RootSolutionInfo` object + + dtype = None + if upper/lower are passed as floats instead of tensors + use this dtype when constructing the tensor. + + device = None + if upper/lower are passed as floats instead of tensors + use this device when constructing the tensor. + + Returns + ------- + Tensor | tuple[Tensor, RootSolutionInfo] + """ a = torch.as_tensor(lower, dtype=dtype, device=device) b = torch.as_tensor(upper, dtype=dtype, device=device) a, b, *args = torch.broadcast_tensors(a, b, *args) diff --git a/src/beignet/_chandrupatla.py b/src/beignet/_chandrupatla.py index a6aa9145a1..9f06d1e2e7 100644 --- a/src/beignet/_chandrupatla.py +++ b/src/beignet/_chandrupatla.py @@ -19,6 +19,58 @@ def chandrupatla( device=None, **_, ) -> Tensor | tuple[Tensor, RootSolutionInfo]: + """Find the root of a scalar (elementwise) function using chandrupatla method. + + This method is slow but guarenteed to converge. + + Parameters + ---------- + func: Callable + Function to find a root of. Called as `f(x, *args)`. + The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`. + Handling *args via broadcasting is acceptable. + + *args + Extra arguments to be passed to `func`. + + lower: float | Tensor + Lower bracket for root + + upper: float | Tensor + Upper bracket for root + + rtol: float | None = None + Relative tolerance + + atol: float | None = None + Absolute tolerance + + maxiter: int = 100 + Maximum number of iterations + + return_solution_info: bool = False + Whether to return a `RootSolutionInfo` object + + dtype = None + if upper/lower are passed as floats instead of tensors + use this dtype when constructing the tensor. + + device = None + if upper/lower are passed as floats instead of tensors + use this device when constructing the tensor. + + Returns + ------- + Tensor | tuple[Tensor, RootSolutionInfo] + + + References + ---------- + + [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for + finding the zero of a nonlinear function without using derivatives. + Advances in Engineering Software, 28.3:145-149, 1997. + """ # maintain three points a,b,c for inverse quadratic interpolation # we will keep (a,b) as the bracketing interval a = torch.as_tensor(lower, dtype=dtype, device=device) diff --git a/src/beignet/_root_scalar.py b/src/beignet/_root_scalar.py index 8f57a16420..e54b4aa3ef 100644 --- a/src/beignet/_root_scalar.py +++ b/src/beignet/_root_scalar.py @@ -19,7 +19,7 @@ def root_scalar( method: Literal["bisect", "chandrupatla"] = "chandrupatla", implicit_diff: bool = True, options: dict | None = None, -): +) -> Tensor | tuple[Tensor, RootSolutionInfo]: """ Find the root of a scalar (elementwise) function. @@ -35,6 +35,9 @@ def root_scalar( method: Literal["bisect", "chandrupatla"] = "chandrupatla" Solver method to use. + * bisect: `beignet.bisect` + * chandrupatla: `beignet.chandrupatla` + See docstring of underlying solvers for description of options dict. implicit_diff: bool = True If true, the solver is wrapped in `beignet.func.custom_scalar_root` which @@ -42,6 +45,11 @@ def root_scalar( options: dict | None = None A dictionary of options that are passed through to the solver as keyword args. + + + Returns + ------- + Tensor | tuple[Tensor, RootSolutionInfo] """ if options is None: options = {}