Commit

Add symnet package

LisIva committed Mar 15, 2024
1 parent 83b86ca commit 6b05cba
Showing 8 changed files with 567 additions and 0 deletions.
187 changes: 187 additions & 0 deletions symnet/expr.py
@@ -0,0 +1,187 @@
import numpy as np
import torch
import sympy
ISINSTALLMATLAB = True
try:
import matlab
except ModuleNotFoundError:
ISINSTALLMATLAB = False
matlab = None

__all__ = ['poly',]

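# SymNet-style multiplicative network: each hidden layer forms two linear
# combinations of the current channels and appends their elementwise product
# as a new channel; the final linear layer maps all channels to one output.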
class poly(torch.nn.Module):
def __init__(self, hidden_layers, channel_num, channel_names=None, normalization_weight=None):
super(poly, self).__init__()
self.hidden_layers = hidden_layers
self.channel_num = channel_num
if channel_names is None:
channel_names = list('u'+str(i) for i in range(self.channel_num))
self.channel_names = channel_names
layer = []
for k in range(hidden_layers):
module = torch.nn.Linear(channel_num+k,2).to(dtype=torch.float64)
module.weight.data.fill_(0)
module.bias.data.fill_(0)
self.add_module('layer'+str(k), module)
layer.append(self.__getattr__('layer'+str(k)))
module = torch.nn.Linear(channel_num+hidden_layers, 1).to(dtype=torch.float64)
module.weight.data.fill_(0)
module.bias.data.fill_(0)
self.add_module('layer_final', module)
layer.append(self.__getattr__('layer_final'))
self.layer = tuple(layer)
nw = torch.ones(channel_num).to(dtype=torch.float64)
        if normalization_weight is not None:
            if not isinstance(normalization_weight, torch.Tensor):
                normalization_weight = torch.from_numpy(np.array(normalization_weight))
            nw = normalization_weight.to(dtype=torch.float64).view(channel_num)
self.register_buffer('_nw', nw)
@property
def channels(self):
channels = sympy.symbols(self.channel_names)
return channels
def renormalize(self, nw):
        if nw is not None and not isinstance(nw, torch.Tensor):
            nw = torch.from_numpy(np.array(nw))
nw1 = nw.view(self.channel_num)
nw1 = nw1.to(self._nw)
nw0 = self._nw
scale = nw0/nw1
self._nw.data = nw1
for L in self.layer:
L.weight.data[:,:self.channel_num] *= scale
return None
def _cast2numpy(self, layer):
weight,bias = layer.weight.data.cpu().numpy(), \
layer.bias.data.cpu().numpy()
return weight,bias
def _cast2matsym(self, layer, eng):
weight,bias = self._cast2numpy(layer)
weight,bias = weight.tolist(),bias.tolist()
weight,bias = matlab.double(weight),matlab.double(bias)
eng.workspace['weight'],eng.workspace['bias'] = weight,bias
eng.workspace['weight'] = eng.eval("sym(weight,'d')")
eng.workspace['bias'] = eng.eval("sym(bias,'d')")
return None
def _cast2symbol(self, layer):
weight,bias = self._cast2numpy(layer)
weight,bias = sympy.Matrix(weight),sympy.Matrix(bias)
return weight,bias
def _sympychop(self, o, calprec):
cdict = o.expand().as_coefficients_dict()
o = 0
for k,v in cdict.items():
if abs(v)>0.1**calprec:
o = o+k*v
return o
def _matsymchop(self, o, calprec, eng):
eng.eval('[c,t] = coeffs('+o+');', nargout=0)
eng.eval('c = double(c);', nargout=0)
eng.eval('c(abs(c)<1e-'+calprec+') = 0;', nargout=0)
eng.eval(o+" = sum(sym(c, 'd').*t);", nargout=0)
return None
def expression(self, calprec=6, eng=None, isexpand=True):
if eng is None:
channels = sympy.symbols(self.channel_names)
for i in range(self.channel_num):
channels[i] = self._nw[i].item()*channels[i]
channels = sympy.Matrix([channels,])
for k in range(self.hidden_layers):
weight,bias = self._cast2symbol(self.layer[k])
o = weight*channels.transpose()+bias
if isexpand:
o[0] = self._sympychop(o[0], calprec)
o[1] = self._sympychop(o[1], calprec)
channels = list(channels)+[o[0]*o[1],]
channels = sympy.Matrix([channels,])
weight,bias = self._cast2symbol(self.layer[-1])
o = (weight*channels.transpose()+bias)[0]
if isexpand:
o = o.expand()
o = self._sympychop(o, calprec)
return o
else:
calprec = str(calprec)
eng.clear(nargout=0)
eng.syms(self.channel_names, nargout=0)
channels = ""
for c in self.channel_names:
channels = channels+" "+c
eng.eval('syms'+channels,nargout=0)
channels = "["+channels+"].'"
eng.workspace['channels'] = eng.eval(channels)
eng.workspace['nw'] = matlab.double(self._nw.data.cpu().numpy().tolist())
eng.eval("channels = channels.*nw.';", nargout=0)
for k in range(self.hidden_layers):
self._cast2matsym(self.layer[k], eng)
eng.eval("o = weight*channels+bias';", nargout=0)
eng.eval('o = o(1)*o(2);', nargout=0)
if isexpand:
eng.eval('o = expand(o);', nargout=0)
self._matsymchop('o', calprec, eng)
eng.eval('channels = [channels;o];', nargout=0)
self._cast2matsym(self.layer[-1],eng)
eng.eval("o = weight*channels+bias';", nargout=0)
if isexpand:
eng.eval("o = expand(o);", nargout=0)
self._matsymchop('o', calprec, eng)
return eng.workspace['o']
def coeffs(self, calprec=6, eng=None, o=None, iprint=0):
if eng is None:
if o is None:
o = self.expression(calprec, eng=None, isexpand=True)
cdict = o.as_coefficients_dict()
t = np.array(list(cdict.keys()))
c = np.array(list(cdict.values()), dtype=np.float64)
I = np.abs(c).argsort()[::-1]
t = list(t[I])
c = c[I]
if iprint > 0:
print(o)
else:
if o is None:
self.expression(calprec, eng=eng, isexpand=True)
else:
eng.workspace['o'] = eng.expand(o)
eng.eval('[c,t] = coeffs(o);', nargout=0)
eng.eval('c = double(c);', nargout=0)
eng.eval("[~,I] = sort(abs(c), 'descend'); c = c(I); t = t(I);", nargout=0)
eng.eval('m = cell(numel(t),1);', nargout=0)
eng.eval('for i=1:numel(t) m(i) = {char(t(i))}; end', nargout=0)
if iprint > 0:
eng.eval('disp(o)', nargout=0)
t = list(eng.workspace['m'])
c = np.array(eng.workspace['c']).flatten()
return t,c
def symboleval(self,inputs,eng=None,o=None):
if isinstance(inputs, torch.Tensor):
inputs = inputs.data.cpu().numpy()
if isinstance(inputs, np.ndarray):
inputs = list(inputs)
assert len(inputs) == len(self.channel_names)
if eng is None:
if o is None:
o = self.expression()
return o.subs(dict(zip(self.channels,inputs)))
else:
if o is None:
o = self.expression(eng=eng)
channels = "["
for c in self.channel_names:
channels = channels+" "+c
channels = channels+"].'"
eng.workspace['channels'] = eng.eval(channels)
eng.workspace['tmp'] = o
eng.workspace['tmpv'] = matlab.double(inputs)
eng.eval("tmpresults = double(subs(tmp,channels.',tmpv));",nargout=0)
return np.array(eng.workspace['tmpresults'])
def forward(self, inputs):
outputs = inputs*self._nw
for k in range(self.hidden_layers):
o = self.layer[k](outputs)
outputs = torch.cat([outputs,o[...,:1]*o[...,1:]], dim=-1)
outputs = self.layer[-1](outputs)
return outputs[...,0]
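
A minimal usage sketch (illustrative only; the channel names, sizes, and random data are assumptions, not part of the package):

import torch
import symnet.expr as expr
from symnet.initparams import initexpr

model = expr.poly(hidden_layers=2, channel_num=2, channel_names=['u', 'ux'])
initexpr(model)                                   # small random init; a fresh poly starts at all zeros
inputs = torch.randn(16, 2, dtype=torch.float64)  # a batch of (u, ux) channel values
outputs = model(inputs)                           # shape (16,)
print(model.expression(calprec=6))                # sympy polynomial in u and ux
terms, coefs = model.coeffs(calprec=6)            # terms sorted by |coefficient|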
144 changes: 144 additions & 0 deletions symnet/initcoefficients.py
@@ -0,0 +1,144 @@
import symnet.expr as expr
from symnet.preproc_input import prepare_batches
from symnet.prepare_left_side import init_left_term,get_left_pool
from symnet.initparams import initexpr
import numpy as np
import torch
from symnet.loss import loss
from symnet.preproc_output import *

import seaborn as sns
import matplotlib.pyplot as plt
from sympy import Symbol, Pow, Mul


def clean_names(left_name, names: list):
new_names = names.copy()
idx = None
if left_name in new_names:
idx = new_names.index(left_name)
new_names.remove(left_name)

return new_names, idx


def train_model(input_names, x_train, y_train, sparsity):

def closure():
lbfgs.zero_grad()
tloss = loss(model, y_train, x_train, block=1, sparsity=sparsity)
tloss.backward()
return tloss

model = expr.poly(2, channel_num=len(input_names), channel_names=input_names)
initexpr(model)
lbfgs = torch.optim.LBFGS(params=model.parameters(), max_iter=2000, line_search_fn='strong_wolfe')
model.train()
lbfgs.step(closure)
last_step_loss = loss(model, y_train, x_train, block=1, sparsity=sparsity)

return model, last_step_loss.item()
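
A hedged example of calling train_model (synthetic tensors; the channel names and sizes are made up for illustration):

import torch

x_train = torch.randn(256, 2, dtype=torch.float64)  # two input channels
y_train = torch.randn(256, dtype=torch.float64)     # target left-hand side
model, final_loss = train_model(['u', 'du/dx1'], x_train, y_train, sparsity=0.1)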


def right_matrices_coef(matrices, names: list[str], csym, tsym):
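    # Build one matrix per term: map each sympy factor to its token matrix
    # (Symbol -> matrix, Pow -> repeated product, Mul -> product over factors)
    # and scale by the term's coefficient csym[i].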
token_matrix = {}
for i in range(len(names)):
token_matrix[Symbol(names[i])] = matrices[i]

right_side = []
for i in range(len(csym)):
total_mx = 1
if type(tsym[i]) == Mul:
if tsym[i] == Mul(Symbol("u"), Symbol("du/dx2")):
u_ux_ind = i
lbls = tsym[i].args
for lbl in lbls:
if type(lbl) == Symbol:
total_mx *= token_matrix.get(lbl)
else:
for j in range(lbl.args[1]):
total_mx *= token_matrix.get(lbl.args[0])
elif type(tsym[i]) == Symbol:
total_mx *= token_matrix.get(tsym[i])
elif type(tsym[i]) == Pow:
for j in range(tsym[i].args[1]):
total_mx *= token_matrix.get(tsym[i].args[0])
total_mx *= csym[i]
right_side.append(total_mx)

u_ux = 1
for lbl in (Symbol("u"), Symbol("du/dx2")):
u_ux *= token_matrix.get(lbl)
right_u_ux = csym[u_ux_ind] * u_ux
return right_side, right_u_ux, u_ux


def select_model1(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens):
models = []
losses = []
for left_side_name in left_pool:
m_input_names, idx = clean_names(left_side_name, input_names)
x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens)
model, last_loss = train_model(m_input_names, x_train, y_train, sparsity)

tsym, csym = model.coeffs(calprec=16)
losses.append(last_loss)
models.append(model)

idx = losses.index(min(losses))
return models[idx], left_pool[idx]


def select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens):
models = []
losses = []
for left_side_name in left_pool:
m_input_names, idx = clean_names(left_side_name, input_names)
x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens)
model, last_loss = train_model(m_input_names, x_train, y_train, sparsity)
losses.append(last_loss)
models.append(model)

idx = losses.index(min(losses))
return models[idx], left_pool[idx]


def save_fig(csym, add_left=True):
distr = np.fabs(csym.copy())
if add_left:
distr = np.append(distr, (distr[0] + distr[1]) / 2)
distr.sort()
distr = distr[::-1]

fig, ax = plt.subplots(figsize=(16, 8))
ax.set_ylim(0, np.max(distr) + 0.01)
sns.barplot(x=np.arange(len(distr)), y=distr, orient="v", ax=ax)
plt.grid()
# plt.show()
plt.yticks(fontsize=50)
plt.savefig(f'symnet_distr{len(distr)}.png', transparent=True)


def get_csym_tsym(u, derivs, shape, input_names, pool_names, sparsity=0.1, additional_tokens=None,
max_deriv_order=None):
"""
    Processes only a single dependent variable (u).
"""

left_pool = get_left_pool(max_deriv_order)
model, left_side_name = select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens)
tsym, csym = model.coeffs(calprec=16)
# save_fig(csym)
pool_sym_ls = cast_to_symbols(pool_names)
csym_pool_ls = get_csym_pool(tsym, csym, pool_sym_ls, left_side_name)
# save_fig(np.array(csym_pool_ls), add_left=False)
return dict(zip(pool_sym_ls, csym_pool_ls)), pool_sym_ls
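
An end-to-end sketch, assuming u is sampled on a (t, x) grid and derivs stacks precomputed derivative fields column-wise; all names, shapes, and the max_deriv_order value below are assumptions, since prepare_batches and get_left_pool are defined in files not shown here:

import numpy as np
from symnet.initcoefficients import get_csym_tsym

u = np.random.rand(64, 64)            # u(t, x) on a 64x64 grid
derivs = np.random.rand(64 * 64, 2)   # e.g. columns du/dx1, du/dx2
coeffs_by_term, pool_terms = get_csym_tsym(
    u, derivs, (64, 64),
    input_names=['du/dx1', 'du/dx2'],
    pool_names=['u', 'du/dx1', 'du/dx2'],
    sparsity=0.1,
    max_deriv_order=(1, 1),
)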
6 changes: 6 additions & 0 deletions symnet/initparams.py
@@ -0,0 +1,6 @@
import torch


def initexpr(model):
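    # overwrite every parameter with small Gaussian noise (std 0.1), so training
    # does not start from poly's all-zero weight initialization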
for p in model.parameters():
p.data = torch.randn(*p.shape,dtype=p.dtype,device=p.device)*1e-1
31 changes: 31 additions & 0 deletions symnet/loss.py
@@ -0,0 +1,31 @@
import torch


def _sparse_loss(model):
"""
SymNet regularization
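
    Huber-style smoothed L1: parameters with |p| < s contribute p**2 / (2*s),
    the rest contribute |p| - s/2, so the penalty is differentiable at zero.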
"""
loss = 0
s = 1e-3
for p in list(model.parameters()):
p = p.abs()
loss = loss+((p<s).to(p)*0.5/s*p**2).sum()+((p>=s).to(p)*(p-s/2)).sum()
return loss


def loss(model, u_left, u_right, block, sparsity):
stepnum = block if block >= 1 else 1

dataloss = 0
sparseloss = _sparse_loss(model)

u_der = u_left
for steps in range(1, stepnum + 1):
u_dertmp = model(u_right)

dataloss = dataloss + \
torch.mean((u_dertmp - u_der) ** 2)
# layerweight[steps-1]*torch.mean(((uttmp-u_obs[steps])/(steps*dt))**2)
# ut = u_right
loss = dataloss + stepnum * sparsity * sparseloss
return loss