diff --git a/examples/3_Advanced/strain_simple.py b/examples/3_Advanced/strain_simple.py index 51b60b84b..f11fac807 100644 --- a/examples/3_Advanced/strain_simple.py +++ b/examples/3_Advanced/strain_simple.py @@ -17,7 +17,7 @@ from simsopt.objectives import SquaredFlux from simsopt.objectives import QuadraticPenalty from simsopt.geo.strain_optimization_classes import StrainOpt -from simsopt.geo.finitebuild import create_multifilament_grid +from simsopt.geo import create_multifilament_grid, ZeroRotation, FramedCurveFrenet, FrameRotation # from exportcoils import export_coils, import_coils, import_coils_fb, export_coils_fb from simsopt.configs import get_ncsx_data import matplotlib.pyplot as plt @@ -26,20 +26,14 @@ curve = curves[0] # Set up the winding pack -numfilaments_n = 1 # number of filaments in normal direction -numfilaments_b = 1 # number of filaments in bi-normal direction -gapsize_n = 0.02 # gap between filaments in normal direction -gapsize_b = 0.03 # gap between filaments in bi-normal direction rot_order = 5 # order of the Fourier expression for the rotation of the filament pack -scale = 1 width = 12 -# use sum here to concatenate lists -filaments = create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, - gapsize_b, rotation_order=rot_order, frame='frenet') +rotation = FrameRotation(curve.quadpoints,rot_order) +framedcurve = FramedCurveFrenet(curve, rotation) -strain = StrainOpt(filaments[0], width=width) +strain = StrainOpt(framedcurve, width=width) tor = strain.binormal_curvature_strain() @@ -48,8 +42,9 @@ plt.show() # tor = strain.torsional_strain() -tor_frame = filaments[0].frame_torsion() +tor_frame = framedcurve.frame_torsion() tor_curve = curve.torsion() +dl = curve.incremental_arclength() plt.figure() plt.plot(tor_frame) diff --git a/src/simsopt/field/magneticfieldclasses.py b/src/simsopt/field/magneticfieldclasses.py index 78928a162..2de2f0343 100644 --- a/src/simsopt/field/magneticfieldclasses.py +++ b/src/simsopt/field/magneticfieldclasses.py @@ -257,8 +257,8 @@ def __init__(self, phi_str): self.phi_str = phi_str self.phi_parsed = parse_expr(phi_str) R, Z, Phi = sp.symbols('R Z phi') - self.Blambdify = sp.lambdify((R, Z, Phi), [self.phi_parsed.diff(R)+1e-30*Phi*R*Z,\ - self.phi_parsed.diff(Phi)/R+1e-30*Phi*R*Z,\ + self.Blambdify = sp.lambdify((R, Z, Phi), [self.phi_parsed.diff(R)+1e-30*Phi*R*Z, \ + self.phi_parsed.diff(Phi)/R+1e-30*Phi*R*Z, \ self.phi_parsed.diff(Z)+1e-30*Phi*R*Z]) self.dBlambdify_by_dX = sp.lambdify( (R, Z, Phi), diff --git a/src/simsopt/geo/__init__.py b/src/simsopt/geo/__init__.py index bbc47fc74..12eca158b 100644 --- a/src/simsopt/geo/__init__.py +++ b/src/simsopt/geo/__init__.py @@ -8,7 +8,7 @@ from .curvexyzfourier import * from .curveperturbed import * from .curveobjectives import * - +from .framedcurve import * from .finitebuild import * from .plotting import * diff --git a/src/simsopt/geo/finitebuild.py b/src/simsopt/geo/finitebuild.py index 442980a36..8e45f4ac1 100644 --- a/src/simsopt/geo/finitebuild.py +++ b/src/simsopt/geo/finitebuild.py @@ -7,32 +7,29 @@ from .._core.derivative import Derivative from .curve import Curve from .jit import jit +from .framedcurve import FramedCurve """ The functions and classes in this model are used to deal with multifilament approximation of finite build coils. """ -__all__ = ['create_multifilament_grid', - 'CurveFilament', 'FilamentRotation', 'ZeroRotation'] +__all__ = ['create_multifilament_grid','CurveFilament'] -class CurveFilament(sopp.Curve, Curve): - def __init__(self, curve, dn, db, rotation=None): +class CurveFilament(FramedCurve): + + def __init__(self, framedcurve, dn, db): """ - Implementation of the centroid frame introduced in - Singh et al, "Optimization of finite-build stellarator coils", - Journal of Plasma Physics 86 (2020), - doi:10.1017/S0022377820000756. Given a curve, one defines a normal and - binormal vector and then creates a grid of curves by shifting along the - normal and binormal vector. In addition, we specify an angle along the - curve that allows us to optimise for the rotation of the winding pack. + Given a FramedCurve, defining a normal and + binormal vector, create a grid of curves by shifting + along the normal and binormal vector. - The idea is explained well in Figure 1 in the reference above. + The idea is explained well in Figure 1 in the reference: - Note that "normal" and "binormal" in the function arguments here - refer not to the Frenet frame but rather to the "coil centroid - frame" defined by Singh et al., before rotation. + Singh et al, "Optimization of finite-build stellarator coils", + Journal of Plasma Physics 86 (2020), + doi:10.1017/S0022377820000756. Args: curve: the underlying curve @@ -40,19 +37,13 @@ def __init__(self, curve, dn, db, rotation=None): db: how far to move in binormal direction rotation: angle along the curve to rotate the frame. """ - self.curve = curve - sopp.Curve.__init__(self, curve.quadpoints) - deps = [curve] - if rotation is not None: - deps.append(rotation) - Curve.__init__(self, depends_on=deps) - self.curve = curve + self.curve = framedcurve.curve self.dn = dn self.db = db - if rotation is None: - rotation = ZeroRotation(curve.quadpoints) - self.rotation = rotation - + self.rotation = framedcurve.rotation + self.framedcurve = framedcurve + FramedCurve.__init__(self, self.curve, self.rotation) + def recompute_bell(self, parent=None): self.invalidate_cache() @@ -67,62 +58,19 @@ def gammadash_impl(self, gammadash): td, nd, bd = self.rotated_frame_dash() gammadash[:] = self.curve.gammadash() + self.dn * nd + self.db * bd - -# class FramedCurve(Optimizable): - -# def __init__(self, curve, rotation=None): -# self.curve = curve -# self.rotation = rotation -# super().__init__(depends_on=[curve, rotation]) - -class CurveFilamentFrenet(CurveFilament): - - def __init__(self, curve, dn, db, rotation=None): - CurveFilament.__init__(self, curve, dn, db, rotation=None) - self.torsion = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: torsion_pure_frenet( - gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash)) - self.rotated_frame = rotated_frenet_frame - self.binormal_curvature = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: binormal_curvature_pure_frenet( - gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash)) - - def rotated_frame_dash(self): - return rotated_frenet_frame_dash( - self.curve.gamma(), self.curve.gammadash(), self.curve.gammadashdash(), self.curve.gammadashdashdash(), - self.rotation.alpha(self.curve.quadpoints), self.rotation.alphadash(self.curve.quadpoints) - ) - - def frame_torsion(self): - """Exports frame torsion along a curve""" - gamma = self.curve.gamma() - d1gamma = self.curve.gammadash() - d2gamma = self.curve.gammadashdash() - d3gamma = self.curve.gammadashdashdash() - alpha = self.rotation.alpha(self.curve.quadpoints) - alphadash = self.rotation.alphadash(self.curve.quadpoints) - return self.torsion(gamma, d1gamma, d2gamma, d3gamma, alpha, alphadash) - - def frame_binormal_curvature(self): - gamma = self.curve.gamma() - d1gamma = self.curve.gammadash() - d2gamma = self.curve.gammadashdash() - d3gamma = self.curve.gammadashdashdash() - alpha = self.rotation.alpha(self.curve.quadpoints) - alphadash = self.rotation.alphadash(self.curve.quadpoints) - return self.binormal_curvature(gamma, d1gamma, d2gamma, d3gamma, alpha, alphadash) - def dgamma_by_dcoeff_vjp(self, v): g = self.curve.gamma() gd = self.curve.gammadash() gdd = self.curve.gammadashdash() a = self.rotation.alpha(self.curve.quadpoints) zero = np.zeros_like(v) - vg = rotated_frenet_frame_dcoeff_vjp0( + vg = self.framedcurve.rotated_frame_dcoeff_vjp0( g, gd, gdd, a, (zero, self.dn*v, self.db*v)) - vgd = rotated_frenet_frame_dcoeff_vjp1( + vgd = self.framedcurve.rotated_frame_dcoeff_vjp1( g, gd, gdd, a, (zero, self.dn*v, self.db*v)) - vgdd = rotated_frenet_frame_dcoeff_vjp2( + vgdd = self.framedcurve.rotated_frame_dcoeff_vjp2( g, gd, gdd, a, (zero, self.dn*v, self.db*v)) - va = rotated_frenet_frame_dcoeff_vjp3( + va = self.framedcurve.rotated_frame_dcoeff_vjp3( g, gd, gdd, a, (zero, self.dn*v, self.db*v)) return self.curve.dgamma_by_dcoeff_vjp(v + vg) \ + self.curve.dgammadash_by_dcoeff_vjp(vgd) \ @@ -138,17 +86,17 @@ def dgammadash_by_dcoeff_vjp(self, v): ad = self.rotation.alphadash(self.curve.quadpoints) zero = np.zeros_like(v) - vg = rotated_frenet_frame_dash_dcoeff_vjp0( + vg = self.framedcurve.rotated_frame_dash_dcoeff_vjp0( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) - vgd = rotated_frenet_frame_dash_dcoeff_vjp1( + vgd = self.framedcurve.rotated_frame_dash_dcoeff_vjp1( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) - vgdd = rotated_frenet_frame_dash_dcoeff_vjp2( + vgdd = self.framedcurve.rotated_frame_dash_dcoeff_vjp2( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) - vgddd = rotated_frenet_frame_dash_dcoeff_vjp3( + vgddd = self.framedcurve.rotated_frame_dash_dcoeff_vjp3( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) - va = rotated_frenet_frame_dash_dcoeff_vjp4( + va = self.framedcurve.rotated_frame_dash_dcoeff_vjp4( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) - vad = rotated_frenet_frame_dash_dcoeff_vjp5( + vad = self.framedcurve.rotated_frame_dash_dcoeff_vjp5( g, gd, gdd, gddd, a, ad, (zero, self.dn*v, self.db*v)) return self.curve.dgamma_by_dcoeff_vjp(vg) \ + self.curve.dgammadash_by_dcoeff_vjp(v+vgd) \ @@ -158,83 +106,14 @@ def dgammadash_by_dcoeff_vjp(self, v): + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, vad) -class CurveFilamentCentroid(CurveFilament): - - def __init__(self, curve, dn, db, rotation=None): - CurveFilament.__init__(self, curve, dn, db, rotation=None) - self.rotated_frame = rotated_centroid_frame - self.dgamma_by_dcoeff_vjp = dgamma_by_dcoeff_vjp_centroid - self.dgammadash_by_dcoeff_vjp = dgammadash_by_dcoeff_vjp_centroid - self.torsion = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash: torsion_pure_centroid( - gamma, gammadash, gammadashdash, alpha, alphadash)) - self.binormal_curvature = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash: binormal_curvature_pure_centroid( - gamma, gammadash, gammadashdash, alpha, alphadash)) - - def frame_torsion(self): - """Exports frame torsion along a curve""" - gamma = self.curve.curve.gamma() - d1gamma = self.curve.curve.gammadash() - d2gamma = self.curve.curve.gammadashdash() - d3gamma = self.curve.curve.gammadashdashdash() - alpha = self.curve.rotation.alpha(self.curve.quadpoints) - alphadash = self.curve.rotation.alphadash(self.curve.quadpoints) - return self.torsion(gamma, d1gamma, d2gamma, alpha, alphadash) - - def frame_binormal_curvature(self): - gamma = self.curve.curve.gamma() - d1gamma = self.curve.curve.gammadash() - d2gamma = self.curve.curve.gammadashdash() - d3gamma = self.curve.curve.gammadashdashdash() - alpha = self.curve.rotation.alpha(self.curve.quadpoints) - alphadash = self.curve.rotation.alphadash(self.curve.quadpoints) - return self.binormal_curvature(gamma, d1gamma, d2gamma, alpha, alphadash) - - - def rotated_frame_dash(self): - return rotated_centroid_frame_dash( - self.curve.gamma(), self.curve.gammadash() , self.curve.gammadashdashdash(), - self.rotation.alpha(self.curve.quadpoints), self.rotation.alphadash(self.curve.quadpoints) - ) - - def dgammadash_by_dcoeff_vjp(self, v): - g = self.curve.gamma() - gd = self.curve.gammadash() - gdd = self.curve.gammadashdash() - a = self.rotation.alpha(self.curve.quadpoints) - ad = self.rotation.alphadash(self.curve.quadpoints) - zero = np.zeros_like(v) - - vg = rotated_centroid_frame_dash_dcoeff_vjp0(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vgd = rotated_centroid_frame_dash_dcoeff_vjp1(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vgdd = rotated_centroid_frame_dash_dcoeff_vjp2(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - va = rotated_centroid_frame_dash_dcoeff_vjp3(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vad = rotated_centroid_frame_dash_dcoeff_vjp4(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - return self.curve.dgamma_by_dcoeff_vjp(vg) \ - + self.curve.dgammadash_by_dcoeff_vjp(v+vgd) \ - + self.curve.dgammadashdash_by_dcoeff_vjp(vgdd) \ - + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, va) \ - + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, vad) - - def dgamma_by_dcoeff_vjp(self, v): - g = self.curve.gamma() - gd = self.curve.gammadash() - a = self.rotation.alpha(self.curve.quadpoints) - zero = np.zeros_like(v) - vg = rotated_centroid_frame_dcoeff_vjp0(g, gd, a, (zero, self.dn*v, self.db*v)) - vgd = rotated_centroid_frame_dcoeff_vjp1(g, gd, a, (zero, self.dn*v, self.db*v)) - va = rotated_centroid_frame_dcoeff_vjp2(g, gd, a, (zero, self.dn*v, self.db*v)) - return self.curve.dgamma_by_dcoeff_vjp(v + vg) \ - + self.curve.dgammadash_by_dcoeff_vjp(vgd) \ - + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, va) - def create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, - rotation_order=None, rotation_scaling=None, frame='centroid'): + rotation_order=None, rotation_scaling=None, frame='centroid'): """ Create a regular grid of ``numfilaments_n * numfilaments_b`` many filaments to approximate a finite-build coil. Note that "normal" and "binormal" in the function arguments here - refer not to the Frenet frame but rather to the "coil centroid + refer to either the Frenet frame or the "coil centroid frame" defined by Singh et al., before rotation. Args: @@ -249,8 +128,9 @@ def create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, scaling improves the convergence of first order optimization algorithms. If ``None``, then the default of ``1 / max(gapsize_n, gapsize_b)`` is used. + frame: orthonormal frame to define normal and binormal before rotation (either 'centroid' or 'frenet') """ - assert frame in ['centroid','frenet'] + assert frame in ['centroid', 'frenet'] if numfilaments_n % 2 == 1: shifts_n = np.arange(numfilaments_n) - numfilaments_n//2 else: @@ -267,289 +147,15 @@ def create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, if rotation_order is None: rotation = ZeroRotation(curve.quadpoints) else: - rotation = FilamentRotation(curve.quadpoints, rotation_order, scale=rotation_scaling) + rotation = FrameRotation(curve.quadpoints, rotation_order, scale=rotation_scaling) + if frame == 'frenet': + framedcurve = FramedCurveFrenet(curve, rotation) + else: + framedcurve = FramedCurveCentroid(curve, rotation) + filaments = [] for i in range(numfilaments_n): for j in range(numfilaments_b): - if frame=='frenet': - filaments.append(CurveFilamentFrenet(curve, shifts_n[i], shifts_b[j], rotation)) - else: - filaments.append(CurveFilamentCentroid(curve, shifts_n[i], shifts_b[j], rotation)) + filaments.append(CurveFilament(framedcurve, shifts_n[i], shifts_b[j])) return filaments - -class FilamentRotation(Optimizable): - - def __init__(self, quadpoints, order, scale=1., dofs=None): - """ - The rotation of the multifilament pack; alpha in Figure 1 of - Singh et al, "Optimization of finite-build stellarator coils", - Journal of Plasma Physics 86 (2020), - doi:10.1017/S0022377820000756 - """ - self.order = order - if dofs is None: - super().__init__(x0=np.zeros((2*order+1, ))) - else: - super().__init__(dofs=dofs) - self.quadpoints = quadpoints - self.scale = scale - self.jac = rotation_dcoeff(quadpoints, order) - self.jacdash = rotationdash_dcoeff(quadpoints, order) - self.jax_alpha = jit(lambda dofs, points: jaxrotation_pure(dofs, points, self.order)) - self.jax_alphadash = jit(lambda dofs, points: jaxrotationdash_pure(dofs, points, self.order)) - - def alpha(self, quadpoints): - return self.scale * self.jax_alpha(self._dofs.full_x, quadpoints) - - def alphadash(self, quadpoints): - return self.scale * self.jax_alphadash(self._dofs.full_x, quadpoints) - - def dalpha_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({self: self.scale * sopp.vjp(v, self.jac)}) - - def dalphadash_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({self: self.scale * sopp.vjp(v, self.jacdash)}) - - -class ZeroRotation(Optimizable): - - def __init__(self, quadpoints): - """ - Dummy class that just returns zero for the rotation angle. Equivalent to using - - .. code-block:: python - - rot = FilamentRotation(...) - rot.fix_all() - - """ - super().__init__() - self.zero = np.zeros((quadpoints.size, )) - - def alpha(self, quadpoints): - return self.zero - - def alphadash(self, quadpoints): - return self.zero - - def dalpha_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({}) - - def dalphadash_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({}) - - -@jit -def rotated_centroid_frame(gamma, gammadash, alpha): - t = gammadash - t *= 1./jnp.linalg.norm(gammadash, axis=1)[:, None] - R = jnp.mean(gamma, axis=0) # centroid - delta = gamma - R[None, :] - n = delta - jnp.sum(delta * t, axis=1)[:, None] * t - n *= 1./jnp.linalg.norm(n, axis=1)[:, None] - b = jnp.cross(t, n, axis=1) - - # now rotate the frame by alpha - nn = jnp.cos(alpha)[:, None] * n - jnp.sin(alpha)[:, None] * b - bb = jnp.sin(alpha)[:, None] * n + jnp.cos(alpha)[:, None] * b - return t, nn, bb - - -rotated_centroid_frame_dash = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: jvp(rotated_centroid_frame, - (gamma, gammadash, alpha), - (gammadash, gammadashdash, alphadash))[1]) - -rotated_centroid_frame_dcoeff_vjp0 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda g: rotated_centroid_frame(g, gammadash, alpha), gamma)[1](v)[0]) - -rotated_centroid_frame_dcoeff_vjp1 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda gd: rotated_centroid_frame(gamma, gd, alpha), gammadash)[1](v)[0]) - -rotated_centroid_frame_dcoeff_vjp2 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda a: rotated_centroid_frame(gamma, gammadash, a), alpha)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp0 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda g: rotated_centroid_frame_dash(g, gammadash, gammadashdash, alpha, alphadash), gamma)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp1 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda gd: rotated_centroid_frame_dash(gamma, gd, gammadashdash, alpha, alphadash), gammadash)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp2 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda gdd: rotated_centroid_frame_dash(gamma, gammadash, gdd, alpha, alphadash), gammadashdash)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp3 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda a: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, a, alphadash), alpha)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp4 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda ad: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, alpha, ad), alphadash)[1](v)[0]) - -@jit -def rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha): - """Frenet frame of a curve rotated by a angle that varies along the coil path""" - - N = gamma.shape[0] - t, n, b = (np.zeros((N, 3)), np.zeros((N, 3)), np.zeros((N, 3))) - t = gammadash - t *= 1./jnp.linalg.norm(gammadash, axis=1)[:, None] - - tdash = (1./jnp.linalg.norm(gammadash, axis=1)[:, None])**2 * (jnp.linalg.norm(gammadash, axis=1)[:, None] * gammadashdash - - (inner(gammadash, gammadashdash)/jnp.linalg.norm(gammadash, axis=1))[:, None] * gammadash) - - n = tdash - n *= 1/jnp.linalg.norm(tdash, axis=1)[:, None] - b = jnp.cross(t, n, axis=1) - # now rotate the frame by alpha - nn = jnp.cos(alpha)[:, None] * n - jnp.sin(alpha)[:, None] * b - bb = jnp.sin(alpha)[:, None] * n + jnp.cos(alpha)[:, None] * b - - return t, nn, bb - -rotated_frenet_frame_dash = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: jvp(rotated_frenet_frame, - (gamma, gammadash, - gammadashdash, alpha), - (gammadash, gammadashdash, gammadashdashdash, alphadash))[1]) - -rotated_frenet_frame_dcoeff_vjp0 = jit( - lambda gamma, gammadash, gammadashdash, alpha, v: vjp( - lambda g: rotated_frenet_frame(g, gammadash, gammadashdash, alpha), gamma)[1](v)[0]) - -rotated_frenet_frame_dcoeff_vjp1 = jit( - lambda gamma, gammadash, gammadashdash, alpha, v: vjp( - lambda gd: rotated_frenet_frame(gamma, gd, gammadashdash, alpha), gammadash)[1](v)[0]) - -rotated_frenet_frame_dcoeff_vjp2 = jit( - lambda gamma, gammadash, gammadashdash, alpha, v: vjp( - lambda gdd: rotated_frenet_frame(gamma, gammadash, gdd, alpha), gammadashdash)[1](v)[0]) - -rotated_frenet_frame_dcoeff_vjp3 = jit( - lambda gamma, gammadash, gammadashdash, alpha, v: vjp( - lambda a: rotated_frenet_frame(gamma, gammadash, gammadashdash, a), alpha)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp0 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda g: rotated_frenet_frame_dash(g, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash), gamma)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp1 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda gd: rotated_frenet_frame_dash(gamma, gd, gammadashdash, gammadashdashdash, alpha, alphadash), gammadash)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp2 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda gdd: rotated_frenet_frame_dash(gamma, gammadash, gdd, gammadashdashdash, alpha, alphadash), gammadashdash)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp3 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda gddd: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gddd, alpha, alphadash), gammadashdashdash)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp4 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda a: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gammadashdashdash, a, alphadash), alpha)[1](v)[0]) - -rotated_frenet_frame_dash_dcoeff_vjp5 = jit( - lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( - lambda ad: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, ad), alphadash)[1](v)[0]) - - -def jaxrotation_pure(dofs, points, order): - rotation = jnp.zeros((len(points), )) - rotation += dofs[0] - for j in range(1, order+1): - rotation += dofs[2*j-1] * jnp.sin(2*np.pi*j*points) - rotation += dofs[2*j] * jnp.cos(2*np.pi*j*points) - return rotation - - -def jaxrotationdash_pure(dofs, points, order): - rotation = jnp.zeros((len(points), )) - for j in range(1, order+1): - rotation += dofs[2*j-1] * 2*np.pi*j*jnp.cos(2*np.pi*j*points) - rotation -= dofs[2*j] * 2*np.pi*j*jnp.sin(2*np.pi*j*points) - return rotation - - -def rotation_dcoeff(points, order): - jac = np.zeros((len(points), 2*order+1)) - jac[:, 0] = 1 - for j in range(1, order+1): - jac[:, 2*j-1] = np.sin(2*np.pi*j*points) - jac[:, 2*j+0] = np.cos(2*np.pi*j*points) - return jac - - -def rotationdash_dcoeff(points, order): - jac = np.zeros((len(points), 2*order+1)) - for j in range(1, order+1): - jac[:, 2*j-1] = +2*np.pi*j*np.cos(2*np.pi*j*points) - jac[:, 2*j+0] = -2*np.pi*j*np.sin(2*np.pi*j*points) - return jac - -def inner(a, b): - """Inner product for arrays of shape (N, 3)""" - return np.sum(a*b, axis=1) - -torsion2vjp0 = jit(lambda ndash, b, v: vjp( - lambda nd: torsion_pure(nd, b), ndash)[1](v)[0]) -torsion2vjp1 = jit(lambda ndash, b, v: vjp( - lambda bi: torsion_pure(ndash, bi), b)[1](v)[0]) - -def binormal_curvature_pure(tdash, b): - """Implements binormal currvature for optimization""" - binormal_curvature = inner(tdash, b) - return binormal_curvature - -def torsion_pure_frenet(gamma, gammadash, gammadashdash, gammadashdashdash, - alpha, alphadash): - """Torsion function for export/evaulate coil sets""" - - _, _, b = rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha) - _, ndash, _ = rotated_frenet_frame_dash( - gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash) - - ndash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] - return inner(ndash, b) - -def binormal_curvature_pure_frenet(gamma, gammadash, gammadashdash, gammadashdashdash, - alpha, alphadash): - - """Binormal curvature function for export/evaulate coil sets.""" - - _, _, b = rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha) - tdash, _, _ = rotated_frenet_frame_dash( - gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash) - - tdash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] - return inner(tdash, b) - -def torsion_pure_centroid(gamma, gammadash, gammadashdash, - alpha, alphadash): - """Torsion function for export/evaulate coil sets""" - - _, _, b = rotated_centroid_frame(gamma, gammadash, alpha) - _, ndash, _ = rotated_centroid_frame_dash( - gamma, gammadash, gammadashdash, alpha, alphadash) - - ndash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] - return inner(ndash, b) - -def binormal_curvature_pure_centroid(gamma, gammadash, gammadashdash, - alpha, alphadash): - - """Binormal curvature function for export/evaulate coil sets.""" - - _, _, b = rotated_centroid_frame(gamma, gammadash, alpha) - tdash, _, _ = rotated_centroid_frame_dash( - gamma, gammadash, gammadashdash, alpha, alphadash) - - tdash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] - return inner(tdash, b) diff --git a/src/simsopt/geo/strain_optimization_classes.py b/src/simsopt/geo/strain_optimization_classes.py index 22f5df66c..7ec0dff66 100644 --- a/src/simsopt/geo/strain_optimization_classes.py +++ b/src/simsopt/geo/strain_optimization_classes.py @@ -7,7 +7,7 @@ from jax import vjp, jvp, grad import simsoptpp as sopp from simsopt.geo.jit import jit -from simsopt.geo import ZeroRotation, FilamentRotation, Curve +from simsopt.geo import ZeroRotation, Curve from simsopt._core import Optimizable from simsopt._core.derivative import derivative_dec from simsopt.geo.curveobjectives import Lp_curvature_pure @@ -31,8 +31,8 @@ class StrainOpt(Optimizable): """Class for strain optimization""" - def __init__(self, curvefilament, width=3): - self.curvefilament = curvefilament + def __init__(self, framedcurve, width=3): + self.framedcurve = framedcurve self.width = width self.J_jax = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, width: strain_opt_pure( gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, width)) @@ -48,16 +48,16 @@ def __init__(self, curvefilament, width=3): self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, width)) self.thisgrad5 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, width: grad( self.J_jax, argnums=5)(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, width)) - super().__init__(depends_on=[curvefilament]) + super().__init__(depends_on=[framedcurve]) def torsional_strain(self): """Exports torsion along a coil for a StrainOpt object""" - torsion = self.curvefilament.frame_torsion() - return torsion**2 * self.width**2 / 12 # From 2020 Paz-Soldan + torsion = self.framedcurve.frame_torsion() + return torsion**2 * self.width**2 / 12 # From 2020 Paz-Soldan def binormal_curvature_strain(self): - binormal_curvature = self.curvefilament.frame_binormal_curvature() - return (self.width/2)*binormal_curvature # From 2020 Paz-Soldan + binormal_curvature = self.framedcurve.frame_binormal_curvature() + return (self.width/2)*binormal_curvature # From 2020 Paz-Soldan # def J(self): # """