diff --git a/symnet/expr.py b/symnet/expr.py new file mode 100644 index 0000000..d69a540 --- /dev/null +++ b/symnet/expr.py @@ -0,0 +1,187 @@ +import numpy as np +import torch +import sympy +ISINSTALLMATLAB = True +try: + import matlab +except ModuleNotFoundError: + ISINSTALLMATLAB = False + matlab = None + +__all__ = ['poly',] + +class poly(torch.nn.Module): + def __init__(self, hidden_layers, channel_num, channel_names=None, normalization_weight=None): + super(poly, self).__init__() + self.hidden_layers = hidden_layers + self.channel_num = channel_num + if channel_names is None: + channel_names = list('u'+str(i) for i in range(self.channel_num)) + self.channel_names = channel_names + layer = [] + for k in range(hidden_layers): + module = torch.nn.Linear(channel_num+k,2).to(dtype=torch.float64) + module.weight.data.fill_(0) + module.bias.data.fill_(0) + self.add_module('layer'+str(k), module) + layer.append(self.__getattr__('layer'+str(k))) + module = torch.nn.Linear(channel_num+hidden_layers, 1).to(dtype=torch.float64) + module.weight.data.fill_(0) + module.bias.data.fill_(0) + self.add_module('layer_final', module) + layer.append(self.__getattr__('layer_final')) + self.layer = tuple(layer) + nw = torch.ones(channel_num).to(dtype=torch.float64) + if (not isinstance(normalization_weight, torch.Tensor)) and (not normalization_weight is None): + normalization_weight = np.array(normalization_weight) + normalization_weight = torch.from_numpy(normalization_weight).to(dtype=torch.float64) + normalization_weight = normalization_weight.view(channel_num) + nw = normalization_weight + self.register_buffer('_nw', nw) + @property + def channels(self): + channels = sympy.symbols(self.channel_names) + return channels + def renormalize(self, nw): + if (not isinstance(nw, torch.Tensor)) and (not nw is None): + nw = np.array(nw) + nw = torch.from_numpy(nw) + nw1 = nw.view(self.channel_num) + nw1 = nw1.to(self._nw) + nw0 = self._nw + scale = nw0/nw1 + self._nw.data = nw1 + for L in self.layer: + L.weight.data[:,:self.channel_num] *= scale + return None + def _cast2numpy(self, layer): + weight,bias = layer.weight.data.cpu().numpy(), \ + layer.bias.data.cpu().numpy() + return weight,bias + def _cast2matsym(self, layer, eng): + weight,bias = self._cast2numpy(layer) + weight,bias = weight.tolist(),bias.tolist() + weight,bias = matlab.double(weight),matlab.double(bias) + eng.workspace['weight'],eng.workspace['bias'] = weight,bias + eng.workspace['weight'] = eng.eval("sym(weight,'d')") + eng.workspace['bias'] = eng.eval("sym(bias,'d')") + return None + def _cast2symbol(self, layer): + weight,bias = self._cast2numpy(layer) + weight,bias = sympy.Matrix(weight),sympy.Matrix(bias) + return weight,bias + def _sympychop(self, o, calprec): + cdict = o.expand().as_coefficients_dict() + o = 0 + for k,v in cdict.items(): + if abs(v)>0.1**calprec: + o = o+k*v + return o + def _matsymchop(self, o, calprec, eng): + eng.eval('[c,t] = coeffs('+o+');', nargout=0) + eng.eval('c = double(c);', nargout=0) + eng.eval('c(abs(c)<1e-'+calprec+') = 0;', nargout=0) + eng.eval(o+" = sum(sym(c, 'd').*t);", nargout=0) + return None + def expression(self, calprec=6, eng=None, isexpand=True): + if eng is None: + channels = sympy.symbols(self.channel_names) + for i in range(self.channel_num): + channels[i] = self._nw[i].item()*channels[i] + channels = sympy.Matrix([channels,]) + for k in range(self.hidden_layers): + weight,bias = self._cast2symbol(self.layer[k]) + o = weight*channels.transpose()+bias + if isexpand: + o[0] = self._sympychop(o[0], calprec) + o[1] = self._sympychop(o[1], calprec) + channels = list(channels)+[o[0]*o[1],] + channels = sympy.Matrix([channels,]) + weight,bias = self._cast2symbol(self.layer[-1]) + o = (weight*channels.transpose()+bias)[0] + if isexpand: + o = o.expand() + o = self._sympychop(o, calprec) + return o + else: + calprec = str(calprec) + eng.clear(nargout=0) + eng.syms(self.channel_names, nargout=0) + channels = "" + for c in self.channel_names: + channels = channels+" "+c + eng.eval('syms'+channels,nargout=0) + channels = "["+channels+"].'" + eng.workspace['channels'] = eng.eval(channels) + eng.workspace['nw'] = matlab.double(self._nw.data.cpu().numpy().tolist()) + eng.eval("channels = channels.*nw.';", nargout=0) + for k in range(self.hidden_layers): + self._cast2matsym(self.layer[k], eng) + eng.eval("o = weight*channels+bias';", nargout=0) + eng.eval('o = o(1)*o(2);', nargout=0) + if isexpand: + eng.eval('o = expand(o);', nargout=0) + self._matsymchop('o', calprec, eng) + eng.eval('channels = [channels;o];', nargout=0) + self._cast2matsym(self.layer[-1],eng) + eng.eval("o = weight*channels+bias';", nargout=0) + if isexpand: + eng.eval("o = expand(o);", nargout=0) + self._matsymchop('o', calprec, eng) + return eng.workspace['o'] + def coeffs(self, calprec=6, eng=None, o=None, iprint=0): + if eng is None: + if o is None: + o = self.expression(calprec, eng=None, isexpand=True) + cdict = o.as_coefficients_dict() + t = np.array(list(cdict.keys())) + c = np.array(list(cdict.values()), dtype=np.float64) + I = np.abs(c).argsort()[::-1] + t = list(t[I]) + c = c[I] + if iprint > 0: + print(o) + else: + if o is None: + self.expression(calprec, eng=eng, isexpand=True) + else: + eng.workspace['o'] = eng.expand(o) + eng.eval('[c,t] = coeffs(o);', nargout=0) + eng.eval('c = double(c);', nargout=0) + eng.eval("[~,I] = sort(abs(c), 'descend'); c = c(I); t = t(I);", nargout=0) + eng.eval('m = cell(numel(t),1);', nargout=0) + eng.eval('for i=1:numel(t) m(i) = {char(t(i))}; end', nargout=0) + if iprint > 0: + eng.eval('disp(o)', nargout=0) + t = list(eng.workspace['m']) + c = np.array(eng.workspace['c']).flatten() + return t,c + def symboleval(self,inputs,eng=None,o=None): + if isinstance(inputs, torch.Tensor): + inputs = inputs.data.cpu().numpy() + if isinstance(inputs, np.ndarray): + inputs = list(inputs) + assert len(inputs) == len(self.channel_names) + if eng is None: + if o is None: + o = self.expression() + return o.subs(dict(zip(self.channels,inputs))) + else: + if o is None: + o = self.expression(eng=eng) + channels = "[" + for c in self.channel_names: + channels = channels+" "+c + channels = channels+"].'" + eng.workspace['channels'] = eng.eval(channels) + eng.workspace['tmp'] = o + eng.workspace['tmpv'] = matlab.double(inputs) + eng.eval("tmpresults = double(subs(tmp,channels.',tmpv));",nargout=0) + return np.array(eng.workspace['tmpresults']) + def forward(self, inputs): + outputs = inputs*self._nw + for k in range(self.hidden_layers): + o = self.layer[k](outputs) + outputs = torch.cat([outputs,o[...,:1]*o[...,1:]], dim=-1) + outputs = self.layer[-1](outputs) + return outputs[...,0] diff --git a/symnet/initcoefficients.py b/symnet/initcoefficients.py new file mode 100644 index 0000000..680cf2a --- /dev/null +++ b/symnet/initcoefficients.py @@ -0,0 +1,144 @@ +import symnet.expr as expr +from symnet.preproc_input import prepare_batches +from symnet.prepare_left_side import init_left_term,get_left_pool +from symnet.initparams import initexpr +import torch +from symnet.loss import loss +from symnet.preproc_output import * + +import seaborn as sns +import matplotlib.pyplot as plt +from sympy import Symbol, Pow, Mul + + +def clean_names(left_name, names: list): + new_names = names.copy() + idx = None + if left_name in new_names: + idx = new_names.index(left_name) + new_names.remove(left_name) + + return new_names, idx + + +def train_model(input_names, x_train, y_train, sparsity): + + def closure(): + lbfgs.zero_grad() + tloss = loss(model, y_train, x_train, block=1, sparsity=sparsity) + tloss.backward() + return tloss + + model = expr.poly(2, channel_num=len(input_names), channel_names=input_names) + initexpr(model) + lbfgs = torch.optim.LBFGS(params=model.parameters(), max_iter=2000, line_search_fn='strong_wolfe') + model.train() + lbfgs.step(closure) + last_step_loss = loss(model, y_train, x_train, block=1, sparsity=sparsity) + + return model, last_step_loss.item() + + + + + + + +def right_matrices_coef(matrices, names: list[str], csym, tsym): + token_matrix = {} + for i in range(len(names)): + token_matrix[Symbol(names[i])] = matrices[i] + + right_side = [] + for i in range(len(csym)): + total_mx = 1 + if type(tsym[i]) == Mul: + if tsym[i] == Mul(Symbol("u"), Symbol("du/dx2")): + u_ux_ind = i + lbls = tsym[i].args + for lbl in lbls: + if type(lbl) == Symbol: + total_mx *= token_matrix.get(lbl) + else: + for j in range(lbl.args[1]): + total_mx *= token_matrix.get(lbl.args[0]) + elif type(tsym[i]) == Symbol: + total_mx *= token_matrix.get(tsym[i]) + elif type(tsym[i]) == Pow: + for j in range(tsym[i].args[1]): + total_mx *= token_matrix.get(tsym[i].args[0]) + total_mx *= csym[i] + right_side.append(total_mx) + + u_ux = 1 + for lbl in (Symbol("u"), Symbol("du/dx2")): + u_ux *= token_matrix.get(lbl) + right_u_ux = csym[u_ux_ind] * u_ux + diff1 = np.fabs((np.abs(csym[u_ux_ind]) - 1) * u_ux) + return right_side, right_u_ux, u_ux + + +def select_model1(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens): + models = [] + losses = [] + for left_side_name in left_pool: + m_input_names, idx = clean_names(left_side_name, input_names) + x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens) + model, last_loss = train_model(m_input_names, x_train, y_train, sparsity) + + tsym, csym = model.coeffs(calprec=16) + losses.append(last_loss) + models.append(model) + + idx = losses.index(min(losses)) + return models[idx], left_pool[idx] + + + + + + +def select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens): + models = [] + losses = [] + for left_side_name in left_pool: + m_input_names, idx = clean_names(left_side_name, input_names) + x_train, y_train = prepare_batches(u, derivs, shape, idx, additional_tokens=additional_tokens) + model, last_loss = train_model(m_input_names, x_train, y_train, sparsity) + losses.append(last_loss) + models.append(model) + + idx = losses.index(min(losses)) + return models[idx], left_pool[idx] + + +def save_fig(csym, add_left=True): + distr = np.fabs(csym.copy()) + if add_left: + distr = np.append(distr, (distr[0] + distr[1]) / 2) + distr.sort() + distr = distr[::-1] + + fig, ax = plt.subplots(figsize=(16, 8)) + ax.set_ylim(0, np.max(distr) + 0.01) + sns.barplot(x=np.arange(len(distr)), y=distr, orient="v", ax=ax) + plt.grid() + # plt.show() + plt.yticks(fontsize=50) + plt.savefig(f'symnet_distr{len(distr)}.png', transparent=True) + + +def get_csym_tsym(u, derivs, shape, input_names, pool_names, sparsity=0.1, additional_tokens=None, + max_deriv_order=None): + """ + Can process only one variable! (u) + """ + + left_pool = get_left_pool(max_deriv_order) + model, left_side_name = select_model(input_names, left_pool, u, derivs, shape, sparsity, additional_tokens) + tsym, csym = model.coeffs(calprec=16) + # save_fig(csym) + pool_sym_ls = cast_to_symbols(pool_names) + csym_pool_ls = get_csym_pool(tsym, csym, pool_sym_ls, left_side_name) + # save_fig(np.array(csym_pool_ls), add_left=False) + return dict(zip(pool_sym_ls, csym_pool_ls)), pool_sym_ls diff --git a/symnet/initparams.py b/symnet/initparams.py new file mode 100644 index 0000000..946daea --- /dev/null +++ b/symnet/initparams.py @@ -0,0 +1,6 @@ +import torch + + +def initexpr(model): + for p in model.parameters(): + p.data = torch.randn(*p.shape,dtype=p.dtype,device=p.device)*1e-1 \ No newline at end of file diff --git a/symnet/loss.py b/symnet/loss.py new file mode 100644 index 0000000..e0f9fb3 --- /dev/null +++ b/symnet/loss.py @@ -0,0 +1,31 @@ +import torch + + +def _sparse_loss(model): + """ + SymNet regularization + """ + loss = 0 + s = 1e-3 + for p in list(model.parameters()): + p = p.abs() + loss = loss+((p=s).to(p)*(p-s/2)).sum() + return loss + + +def loss(model, u_left, u_right, block, sparsity): + stepnum = block if block >= 1 else 1 + + dataloss = 0 + sparseloss = _sparse_loss(model) + + u_der = u_left + for steps in range(1, stepnum + 1): + u_dertmp = model(u_right) + + dataloss = dataloss + \ + torch.mean((u_dertmp - u_der) ** 2) + # layerweight[steps-1]*torch.mean(((uttmp-u_obs[steps])/(steps*dt))**2) + # ut = u_right + loss = dataloss + stepnum * sparsity * sparseloss + return loss \ No newline at end of file diff --git a/symnet/pool_terms.py b/symnet/pool_terms.py new file mode 100644 index 0000000..a9ab03d --- /dev/null +++ b/symnet/pool_terms.py @@ -0,0 +1,65 @@ +from symnet.initcoefficients import get_csym_tsym +from sympy import Symbol, Mul +# import seaborn as sns +# import matplotlib.pyplot as plt +import itertools + + +class PoolTerms: + def __init__(self, max_factors_in_term, families): + + self.pool_sym_dict = None + self.pool_sym_ls = None + self.pool_dict = None + + token_ls = [] + for family in families: + token_ls += family.tokens + self.term_ls = [] + for i in range(1, max_factors_in_term + 1): + self.term_ls += list(itertools.combinations(token_ls, i)) + + def set_initial_distr(self, u, derivs, shape, names, families, grids, max_deriv_order): + if len(families) == 1: + self.pool_dict, self.pool_sym_ls = \ + get_csym_tsym(u, derivs, shape, names, pool_names=self.term_ls, + max_deriv_order=max_deriv_order) + else: + additional_tokens = _prepare_additional_tokens(families, grids) + names = _prepare_names(names, families) + self.pool_dict, self.pool_sym_ls = \ + get_csym_tsym(u, derivs, shape, names, + pool_names=self.term_ls, additional_tokens=additional_tokens, + max_deriv_order=max_deriv_order) + self.pool_sym_dict = dict(zip(self.pool_sym_ls, self.term_ls)) + + +def _prepare_names(names, families): + names_c = names.copy() + for i in range(1, len(families)): + names_c += families[i].tokens + return names_c + + +# TODO: Обработать общий случай additional_tokens +def _prepare_additional_tokens(families, grids): + mx_ls = [] + for i in range(1, len(families)): + assert len(families[i]) == 1, "Can't process family consisting from more than one token" + + name = families[i].tokens[0] + fun = families[i]._evaluator._evaluator.evaluation_functions.get(name) + mx = fun(*grids, **{'power': families[i].token_params.get('power')[0]}) + mx_ls.append(mx) + return mx_ls + + +def to_symbolic(term): + if type(term.cache_label[0]) == tuple: + labels = [] + for label in term.cache_label: + labels.append(str(label[0])) + symlabels = list(map(lambda token: Symbol(token), labels)) + return Mul(*symlabels) + else: + return Symbol(str(term.cache_label[0])) diff --git a/symnet/prepare_left_side.py b/symnet/prepare_left_side.py new file mode 100644 index 0000000..12a26aa --- /dev/null +++ b/symnet/prepare_left_side.py @@ -0,0 +1,16 @@ +import torch +import numpy as np + + +def get_left_pool(max_deriv_order): + left_pool = ["du/dx1"] + if max_deriv_order[0] > 1: + for i in range(2, max_deriv_order[0]+1): + left_pool.append(f"d^{i}u/dx1^{i}") + return left_pool + + +def init_left_term(families): + labels = families[0].tokens.copy() + labels.remove('u') + return (np.random.choice(labels), ) diff --git a/symnet/preproc_input.py b/symnet/preproc_input.py new file mode 100644 index 0000000..6e760fe --- /dev/null +++ b/symnet/preproc_input.py @@ -0,0 +1,42 @@ +import torch +import numpy as np + + +# TODO: случай когда idx == None +def prepare_batches(u, derivs, shape, idx, additional_tokens=None): + u = np.reshape(u, (shape[0], shape[1], 1)) + if len(derivs.shape) != 3: + derivs = np.reshape(derivs, (shape[0], shape[1], derivs.shape[1])) + mxs = [u, derivs] + + if additional_tokens is not None: + add_mx = np.array(additional_tokens) + mxs.append(np.reshape(add_mx, (shape[0], shape[1], len(additional_tokens)))) + input_matrices = np.concatenate(mxs, axis=2) + + return _create_batch(input_matrices, 32, left_idx=idx) + + +def _create_batch(matrices, batch_size, left_idx): + n_batch_row = matrices[:, :, 0].shape[0] // batch_size + in_row_indent = matrices[:, :, 0].shape[0] % batch_size // 2 + n_batch_col = matrices[:, :, 0].shape[1] // batch_size + in_col_indent = matrices[:, :, 0].shape[1] % batch_size // 2 + + def pack_token(k): + elem_ls = [] + for i in range(n_batch_row): + for j in range(n_batch_col): + elem_ls.append(matrices[ + in_row_indent + i * batch_size:in_row_indent + (i + 1) * batch_size, + in_col_indent + j * batch_size:in_col_indent + (j + 1) * batch_size, k]) + return elem_ls + + all_tokens_ls = [] + for l in range(matrices.shape[2]): + if l == left_idx: + left_side = torch.from_numpy(np.asarray(pack_token(l))) + else: + all_tokens_ls.append(pack_token(l)) + right_side = torch.from_numpy(np.asarray(all_tokens_ls)) + return torch.permute(right_side, (1, 2, 3, 0)), left_side \ No newline at end of file diff --git a/symnet/preproc_output.py b/symnet/preproc_output.py new file mode 100644 index 0000000..8bf66cb --- /dev/null +++ b/symnet/preproc_output.py @@ -0,0 +1,76 @@ +from sympy import Mul, Symbol +import numpy as np + + +# def get_csym_pool(tsym: list, csym: list, pool_ls: list, left_side_name: tuple[str]): +# +# symnet_dict = dict(zip(tsym, csym)) +# csym_pool_ls = [] +# for tsym_pool_el in pool_ls: +# csym_pool_ls.append(symnet_dict.get(tsym_pool_el, 1e-6)) +# +# left_idx = pool_ls.index(left_to_sym(left_side_name)) +# csym_pool_ls[left_idx] = left_csym(csym) +# return csym_pool_ls + +def get_csym_pool(tsym: list, csym: list, pool_ls: list, left_side_name: str): + + symnet_dict = dict(zip(tsym, csym)) + csym_pool_ls = [] + for tsym_pool_el in pool_ls: + csym_pool_ls.append(symnet_dict.get(tsym_pool_el, 1e-6)) + + left_idx = pool_ls.index(Symbol(left_side_name)) + csym_pool_ls[left_idx] = left_csym(csym) + return csym_pool_ls + + +def left_csym(csym): + if len(csym) > 1: + return (np.fabs(csym[0]) + np.fabs(csym[1])) / 2 + else: + return csym[0] + + +def left_to_sym(left_term: tuple[str]): + term_symbolic = list(map(lambda u: Symbol(u), left_term)) + return Mul(*term_symbolic) + + +def cast_to_symbols(pool_names: list[tuple[str]]): + + pool_ls = [] + for name in pool_names: + term_symbolic = list(map(lambda u: Symbol(u), name)) + pool_ls.append(Mul(*term_symbolic)) + return pool_ls + + +def to_symbolic(term): + if type(term.cache_label[0]) == tuple: + labels = [] + for label in term.cache_label: + labels.append(str(label[0])) + symlabels = list(map(lambda token: Symbol(token), labels)) + return Mul(*symlabels) + else: + return Symbol(str(term.cache_label[0])) + + +def get_cross_distr(custom_cross_prob, start_idx, end_idx_exclude): + mmf = 2.4 + values = list(custom_cross_prob.values()) + csym_arr = np.fabs(np.array(values)) + + if np.max(csym_arr) / np.min(csym_arr) > 2.6: + min_max_coeff = mmf * np.min(csym_arr) - np.max(csym_arr) + smoothing_factor = min_max_coeff / (min_max_coeff - (mmf - 1) * np.average(csym_arr)) + uniform_csym = np.array([np.sum(csym_arr) / len(csym_arr)] * len(csym_arr)) + + smoothed_array = (1 - smoothing_factor) * csym_arr + smoothing_factor * uniform_csym + inv = 1 / smoothed_array + else: + inv = 1 / csym_arr + inv_norm = inv / np.sum(inv) + + return dict(zip([i for i in range(start_idx, end_idx_exclude)], inv_norm.tolist())) \ No newline at end of file