jax_newton_raphson

A simple Newton-Raphson optimizer in JAX.
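Newton-Raphson minimization repeatedly solves H(x) step = grad f(x) and updates x to x - step, where H is the Hessian of f. The sketch below shows that iteration on a flat 1-D array using `jax.grad` and `jax.hessian`. It illustrates the technique only and is not this package's implementation. The helper `newton_minimize`, its `tol`/`max_iter` parameters, and the `cosh` test function are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def newton_minimize(f, x0, tol=1e-6, max_iter=50):
    """Hypothetical helper: minimize scalar f from a 1-D array x0 with plain Newton steps."""
    grad_f = jax.grad(f)
    hess_f = jax.hessian(f)
    x = jnp.asarray(x0, dtype=jnp.float32)
    for _ in range(max_iter):
        g = grad_f(x)
        # Stop once the gradient is (numerically) zero.
        if jnp.linalg.norm(g) < tol:
            break
        # Solve H @ step = g rather than forming the inverse Hessian explicitly.
        step = jnp.linalg.solve(hess_f(x), g)
        x = x - step
    return x


# Example objective (an assumption): a smooth convex function minimized at x = [1, 1].
x_min = newton_minimize(lambda x: jnp.sum(jnp.cosh(x - 1.0)), jnp.zeros(2))
print(x_min)
```

Solving the linear system instead of inverting the Hessian is the usual choice because it is cheaper and numerically more stable. Plain Newton steps assume the Hessian is invertible near the iterates; a production optimizer would typically add safeguards such as a line search or damping.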

Install

```shell
pip install git+https://github.com/thisiscam/jax_newton_raphson
```

Usage

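```python
import collections

import jax_newton_raphson as jnr

# The parameters are a namedtuple, which JAX treats as a pytree.
Params = collections.namedtuple("Params", "x y")


def f(params: Params):
    # Convex quadratic with its minimum at x = y = 0.
    return params.x**2 + params.y**2


print(jnr.minimize(f, initial_guess=Params(-0.1, 0.1)))
```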
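Because `initial_guess` is a namedtuple rather than a flat array, a Newton step has to map the pytree to a vector, take the step there, and map the result back. The sketch below shows one way to do that with JAX's `jax.flatten_util.ravel_pytree`. It is an illustration under that assumption, not a description of how `jnr.minimize` works internally; `newton_step` is a hypothetical helper. Since `f` is quadratic, a single Newton step should land at the minimum up to floating-point rounding.

```python
import collections

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

Params = collections.namedtuple("Params", "x y")


def f(params: Params):
    return params.x**2 + params.y**2


def newton_step(f, params):
    """Hypothetical helper: one Newton-Raphson step on a pytree of parameters."""
    # Flatten the pytree into a 1-D vector plus a function that rebuilds it.
    flat, unravel = ravel_pytree(params)

    # View f as a function of the flat vector so the Hessian is a square matrix.
    def f_flat(v):
        return f(unravel(v))

    g = jax.grad(f_flat)(flat)
    h = jax.hessian(f_flat)(flat)
    # Take the step in flat space, then restore the original Params structure.
    return unravel(flat - jnp.linalg.solve(h, g))


guess = Params(jnp.float32(-0.1), jnp.float32(0.1))
result = newton_step(f, guess)
print(result)  # a Params instance with x and y at (numerically) zero
```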