Skip to content

Commit

Permalink
Fix for later numpy version
Browse files Browse the repository at this point in the history
  • Loading branch information
Maslyaev committed Nov 13, 2023
1 parent 2704b80 commit aef227a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
5 changes: 3 additions & 2 deletions epde/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ class VerboseManager:
show_iter_idx : bool
iter_fitness : bool
iter_stats : bool
show_ann_loss : bool
show_warnings : bool

def init_verbose(plot_DE_solutions : bool = False, show_iter_idx : bool = True,
show_iter_fitness : bool = False, show_iter_stats : bool = False,
show_warnings : bool = False):
show_ann_loss : bool = False, show_warnings : bool = False):
"""
Method for initialized of manager for output in text form
Expand All @@ -110,4 +111,4 @@ def init_verbose(plot_DE_solutions : bool = False, show_iter_idx : bool = True,
if not show_warnings:
warnings.filterwarnings("ignore")
verbose = VerboseManager(plot_DE_solutions, show_iter_idx, show_iter_fitness,
show_iter_stats, show_warnings)
show_iter_stats, show_ann_loss, show_warnings)
10 changes: 8 additions & 2 deletions epde/optimizers/moeadd/moeadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

import numpy as np
import warnings
from itertools import chain

from typing import Union

def flatten_chain(matrix):
return list(chain.from_iterable(matrix))

# from copy import deepcopy
# from functools import reduce

Expand Down Expand Up @@ -154,7 +159,8 @@ def delete_point(self, point):
self.population = population_cleared

def get_stats(self):
return np.array([[element.obj_fun for element in level] for level in self.levels])
return np.array(flatten_chain([[element.obj_fun for element in level]
for level in self.levels]))

def fit_convex_hull(self):
"""
Expand Down Expand Up @@ -448,4 +454,4 @@ def plot_pareto(self, dimensions:list, **visualizer_kwargs):
assert len(dimensions) == 2, 'Current approach supports only two dimensional plots'
visualizer = ParetoVisualizer(self.pareto_levels)
visualizer.plot_pareto(dimensions = dimensions, **visualizer_kwargs)


4 changes: 3 additions & 1 deletion epde/preprocessing/smoothers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
device = torch.device('cpu')

import epde.globals as global_var

class AbstractSmoother(ABC):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -89,7 +90,8 @@ def __call__(self, data, grid, epochs_max=1e3, loss_mean=1000, batch_frac=0.5,
if loss_mean < min_loss:
best_model = model
min_loss = loss_mean
print('Surface training t={}, loss={}'.format(t, loss_mean))
if global_var.verbose.show_ann_loss:
print('Surface training t={}, loss={}'.format(t, loss_mean))
t += 1

data_approx = best_model(grid_flattened).detach().numpy().reshape(original_shape)
Expand Down
4 changes: 3 additions & 1 deletion epde/supplementary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
device = torch.device('cpu')

import matplotlib.pyplot as plt
import epde.globals as global_var

def exp_form(a, sign_num: int = 4):
if np.isclose(a, 0):
Expand Down Expand Up @@ -95,7 +96,8 @@ def train_ann(grids: list, data: np.ndarray, epochs_max: int = 500):
best_model = model
min_loss = loss_mean
losses.append(loss_mean)
print('Surface training t={}, loss={}'.format(t, loss_mean))
if global_var.verbose.show_ann_loss:
print('Surface training t={}, loss={}'.format(t, loss_mean))
t += 1
print_loss = True
if print_loss:
Expand Down
15 changes: 9 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from pathlib import Path
import pathlib

here = pathlib.Path(__file__).parent.resolve()
HERE = pathlib.Path(__file__).parent.resolve()
README = Path(HERE, 'README.rst').read_text(encoding='utf-8')
SHORT_DESCRIPTION = 'Data-driven dynamical system and differential equations discovery framework'

# Get the long description from the README file
long_description = (here / 'README.rst').read_text(encoding='utf-8')


def read(*names, **kwargs):
Expand All @@ -35,8 +36,10 @@ def get_requirements():

setup(
name = 'epde',
version = '1.2.9',
description = 'EPDE package',
version = '1.2.14',
description = SHORT_DESCRIPTION,
long_description="PLACEHOLDER",
# long_description_content_type='text/x-rst',
author = 'Mikhail Maslyaev',
author_email = '[email protected]',
classifiers = [
Expand All @@ -45,7 +48,7 @@ def get_requirements():
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
],
packages = find_packages(include = ['epde', 'epde.cache', 'epde.interface', 'epde.moeadd',
packages = find_packages(include = ['epde', 'epde.cache', 'epde.interface',
'epde.optimizers', 'epde.optimizers.moeadd',
'epde.optimizers.single_criterion', 'epde.operators.common',
'epde.operators', 'epde.operators.utils',
Expand All @@ -54,5 +57,5 @@ def get_requirements():
'epde.operators.singleobjective', 'epde.preprocessing',
'epde.parametric', 'epde.structure', 'epde.solver']),
include_package_data = True,
python_requires =' >=3.8'
python_requires =' >=3.8',
)

0 comments on commit aef227a

Please sign in to comment.