Skip to content

Commit

Permalink
improve docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Kleinhenz committed Aug 28, 2024
1 parent fc3fd0f commit bdb201d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
44 changes: 44 additions & 0 deletions src/beignet/_bisect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,50 @@ def bisect(
device=None,
**_,
) -> Tensor | tuple[Tensor, RootSolutionInfo]:
"""Find the root of a scalar (elementwise) function using bisection.
This method is slow but guarenteed to converge.
Parameters
----------
func: Callable
Function to find a root of. Called as `f(x, *args)`.
The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`.
Handling *args via broadcasting is acceptable.
*args
Extra arguments to be passed to `func`.
lower: float | Tensor
Lower bracket for root
upper: float | Tensor
Upper bracket for root
rtol: float | None = None
Relative tolerance
atol: float | None = None
Absolute tolerance
maxiter: int = 100
Maximum number of iterations
return_solution_info: bool = False
Whether to return a `RootSolutionInfo` object
dtype = None
if upper/lower are passed as floats instead of tensors
use this dtype when constructing the tensor.
device = None
if upper/lower are passed as floats instead of tensors
use this device when constructing the tensor.
Returns
-------
Tensor | tuple[Tensor, RootSolutionInfo]
"""
a = torch.as_tensor(lower, dtype=dtype, device=device)
b = torch.as_tensor(upper, dtype=dtype, device=device)
a, b, *args = torch.broadcast_tensors(a, b, *args)
Expand Down
52 changes: 52 additions & 0 deletions src/beignet/_chandrupatla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,58 @@ def chandrupatla(
device=None,
**_,
) -> Tensor | tuple[Tensor, RootSolutionInfo]:
"""Find the root of a scalar (elementwise) function using chandrupatla method.
This method is slow but guarenteed to converge.
Parameters
----------
func: Callable
Function to find a root of. Called as `f(x, *args)`.
The function must operate element wise, i.e. `f(x[i]) == f(x)[i]`.
Handling *args via broadcasting is acceptable.
*args
Extra arguments to be passed to `func`.
lower: float | Tensor
Lower bracket for root
upper: float | Tensor
Upper bracket for root
rtol: float | None = None
Relative tolerance
atol: float | None = None
Absolute tolerance
maxiter: int = 100
Maximum number of iterations
return_solution_info: bool = False
Whether to return a `RootSolutionInfo` object
dtype = None
if upper/lower are passed as floats instead of tensors
use this dtype when constructing the tensor.
device = None
if upper/lower are passed as floats instead of tensors
use this device when constructing the tensor.
Returns
-------
Tensor | tuple[Tensor, RootSolutionInfo]
References
----------
[1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for
finding the zero of a nonlinear function without using derivatives.
Advances in Engineering Software, 28.3:145-149, 1997.
"""
# maintain three points a,b,c for inverse quadratic interpolation
# we will keep (a,b) as the bracketing interval
a = torch.as_tensor(lower, dtype=dtype, device=device)
Expand Down
10 changes: 9 additions & 1 deletion src/beignet/_root_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def root_scalar(
method: Literal["bisect", "chandrupatla"] = "chandrupatla",
implicit_diff: bool = True,
options: dict | None = None,
):
) -> Tensor | tuple[Tensor, RootSolutionInfo]:
"""
Find the root of a scalar (elementwise) function.
Expand All @@ -35,13 +35,21 @@ def root_scalar(
method: Literal["bisect", "chandrupatla"] = "chandrupatla"
Solver method to use.
* bisect: `beignet.bisect`
* chandrupatla: `beignet.chandrupatla`
See docstring of underlying solvers for description of options dict.
implicit_diff: bool = True
If true, the solver is wrapped in `beignet.func.custom_scalar_root` which
enables gradients with respect to *args using implicit differentiation.
options: dict | None = None
A dictionary of options that are passed through to the solver as keyword args.
Returns
-------
Tensor | tuple[Tensor, RootSolutionInfo]
"""
if options is None:
options = {}
Expand Down

0 comments on commit bdb201d

Please sign in to comment.