Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

work in progress, fix github issues #238

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion code/nnv/engine/nn/layers/MaxPooling2DLayer.m
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ function set_padding(obj, padding)
if n > 0
for i=1:n
m1 = length(images);
images = obj.stepSplitMultipleInputs(images, pad_image, split_pos(i, :, :), max_index{split_pos(i, 1), split_pos(i, 2), split_pos(i, 3)}, []);
images = obj.stepSplitMultipleInputs(images, pad_image, split_pos(i, :, :), max_index{split_pos(i, 1), split_pos(i, 2), split_pos(i, 3)}, [], lp_solver);
m2 = length(images);
if strcmp(dis_opt, 'display')
fprintf('\nSplit %d images into %d images', m1, m2);
Expand Down
2 changes: 2 additions & 0 deletions code/nnv/engine/utils/lpsolver.m
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
% Define solver parameters
params = struct; % for now, leave default options/params
params.OutputFlag = 0; % no display
params.OptimalityTol = 1e-09;
params.FeasibilityTol = 1e-09;
result = gurobi(model, params);
fval = result.objval; % get fval value from results
% get exitflag and match those of linprog for easier parsing
Expand Down
32 changes: 9 additions & 23 deletions code/nnv/examples/Submission/WiP_3d/functions/add_voxels.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,18 @@
% Return a VolumeStar of a brightening attack on a few pixels

% Initialize vars
ct = 0; % keep track of pixels modified
flag = 0; % determine when to stop modifying pixels
vol = single(vol);
at_vol = vol;

% Create brightening attack
for i=1:size(vol,1)
for j=1:size(vol,2)
for k=1:size(vol,3)
if vol(i,j,k) < threshold
at_vol(i,j,k) = 255;
ct = ct + 1;
if ct >= max_pixels
flag = 1;
break;
end
end
end
if flag == 1
break
end
end
if flag == 1
break;
end
end
% we can find the edge of the shape
shape = edge3(vol,'approxcanny',0.6); % this should be okay for this data, but let's test it

% select a random pixel
idxs = setdiff(find(shape), find(vol));
voxels = min(voxels,length(idxs));

% For now, we can select the first ones
at_vol(idxs(1:voxels)) = 255;

% Define input set as VolumeStar
dif_vol = -vol + at_vol;
Expand Down
31 changes: 31 additions & 0 deletions code/nnv/examples/Submission/WiP_3d/functions/remove_voxels.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
function I = remove_voxels(vol, voxels, noise_disturbance)
% noise_disturnamce can be kept fixed here, more interesting on number
% of voxels changed

% Return a VolumeStar of a brightening attack on a few pixels

% Initialize vars
vol = single(vol);
at_vol = vol;

% we can find the edge of the shape
shape = edge3(vol,'approxcanny',0.6); % this should be okay for this data, but let's test it

% select a random pixel
idxs = intersect(find(shape),find(vol));
voxels = min(voxels,length(idxs));

% For now, we can select the first ones
at_vol(idxs(1:voxels)) = 0;

% Define input set as VolumeStar
dif_vol = -vol + at_vol;
noise = dif_vol;
V(:,:,:,:,1) = vol; % center of set
V(:,:,:,:,2) = noise; % basis vectors
C = [1; -1]; % constraints
d = [1; noise_disturbance-1]; % constraints
I = VolumeStar(V, C, d, 1-noise_disturbance, 1); % input set


end

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
% Check what type of attack to consider
if strcmp(attack.Name, 'add') || strcmp(attack.Name, 'remove')
max_pixels = attack.max_pixels;
threshold = attack.threshold;
noise_disturbance = attack.noise_de;
else
error("Adversarial attack not supported.");
end

% Choose attack
if strcmp(attack.Name, 'add')
I = add_voxels(vol, max_pixels, threshold, noise_disturbance);
I = add_voxels(vol, max_pixels, noise_disturbance);
elseif strcmp(attack.Name, 'remove')
I = remove_voxels(vol, max_pixels, threshold, noise_disturbance);
I = remove_voxels(vol, max_pixels, noise_disturbance);
end

% Begin analysis
Expand Down
39 changes: 39 additions & 0 deletions code/nnv/examples/Submission/WiP_3d/other/NAV/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# NAV Benchmark

## Property:
The control goal is to navigate a robot to a goal region while avoiding an obstacle.
Time horizon: `t = 6s`. Control period: `0.2s`.

Initial states:

x1 = [2.9, 3.1]
x2 = [2.9, 3.1]
x3 = [0, 0]
x4 = [0, 0]

Dynamic system: [dynamics.m](./dynamics.m)

Goal region ( t=6 ):

x1 = [-0.5, 0.5]
x2 = [-0.5, 0.5]
x3 = [-Inf, Inf]
x4 = [-Inf, Inf]

Obstacle ( always ):

x1 = [1, 2]
x2 = [1, 2]
x3 = [-Inf, Inf]
x4 = [-Inf, Inf]

## Networks:

We provide two networks:
- The first network is trained with standard (point-based) reinforcement learning: `nn-nav-point.onnx`
- The second network is trained set-based to improve its verifiable robustness by integrating reachability analysis into the training process: `nn-nav-set.onnx`

Reference set-based training: https://arxiv.org/abs/2401.14961


11 changes: 11 additions & 0 deletions code/nnv/examples/Submission/WiP_3d/other/NAV/dynamics.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function dx = dynamics(x,u)

dx = [
x(3)*cos(x(4));
x(3)*sin(x(4));
u(1);
u(2)
];

end

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
135 changes: 135 additions & 0 deletions code/nnv/examples/Submission/WiP_3d/other/NAV/reach_point.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
function rT = reach_point()

%% Reachability analysis of NAV Benchmark

%% Load Components

% Load the controller
%netonnx = importNetworkFromONNX('networks/nn-nav-point.onnx', "InputDataFormats", "BC");
netonnx = importONNXNetwork('networks/nn-nav-point.onnx', "InputDataFormats", "BC");

% Load plant
reachStep = 0.02;
controlPeriod = 0.2;
% plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4));
% plant.set_tensorOrder(2);
% plant.set_taylorTerms(3);
% plant.set_zonotopeOrder(100);
% plant.set_intermediateOrder(50);

%% Reachability analysis

% Initial set
lb = [2.9; 2.9; 0; 0];
ub = [3.1; 3.1; 0; 0];
init_set = Box(lb,ub);
init = init_set.partition([1 2],[50 50]);

% Reachability options
num_steps = 20;
reachOptions.reachMethod = 'approx-star';

N = length(init);
disp("Verifying "+string(N)+" samples...")

mkdir('tmp');
parpool("Processes"); % initialize parallel process

% Execute reachabilty analysis
t = tic;
parfor j = 1:length(init)
% Get NNV network
net = matlab2nnv(netonnx);
% Create plant
plant = NonLinearODE(4, 2, @dynamics, reachStep, controlPeriod, eye(4));
plant.set_tensorOrder(2);
plant.set_taylorTerms(3);
plant.set_zonotopeOrder(100);
plant.set_intermediateOrder(50);
% Get initial conditions
init_set = init(j).toStar;
%reachSub = init_set;
for i = 1:num_steps
% Compute controller output set
input_set = net.reach(init_set,reachOptions);

% Compute plant reachable set
init_set = plantReach(plant, init_set, input_set,'lin');
end
toc(t);
parsave("tmp/reachSet"+string(j)+".mat",plant);
end
rT = toc(t); % get reach time
disp("Finished reachability...")

% Shut Down Current Parallel Pool
poolobj = gcp('nocreate');
delete(poolobj);

% Save results
if is_codeocean
save('/results/logs/nav_point.mat', 'rT','-v7.3');
else
save('nav_point.mat', 'rT','-v7.3');
end


%% Visualize results
setFiles = dir('tmp/*.mat');

t = tic;

f = figure;
rectangle('Position',[-0.5,-0.5,1,1],'FaceColor',[0 0.5 0 0.5],'EdgeColor','y', 'LineWidth', 0.1); % goal region
hold on;
rectangle('Position',[1,1,1,1],'FaceColor',[0.7 0 0 0.8], 'EdgeColor','r', 'LineWidth', 0.1); % obstacle
grid;
for K = 1 : length(setFiles)
if ~mod(K,50)
disp("Plotting partition "+string(K)+" ...");
toc(t)
pause(0.01); % to ensure it prints
end
res = load("tmp/"+setFiles(K).name);
plant = res.plant;
for k=1:length(plant.cora_set)
plot(plant.cora_set{k}, [1,2], 'b', 'Unify', true);
end
end
hold on;
xlabel('x1');
ylabel('x2');

disp("Finished plotting all reach sets");

%% Save figure
if is_codeocean
saveas(f,'/results/logs/nav_point.png');
% exportgraphics(f,'/results/logs/nav-set.pdf', 'ContentType', 'vector');
else
saveas(f,'nav_point_21.png');
% exportgraphics(f,'nav-set.pdf','ContentType', 'vector');
end

% Save results
if is_codeocean
save('/results/logs/nav_point.mat','rT','-v7.3');
else
save('nav_point.mat', 'rT','-v7.3');
end

end

%% Helper function
function init_set = plantReach(plant,init_set,input_set,algoC)
nS = length(init_set); % based on approx-star, number of sets should be equal
ss = [];
for k=1:nS
ss =[ss plant.stepReachStar(init_set(k), input_set(k),algoC)];
end
init_set = ss;
end

function parsave(fname, plant) % trick to save while on parpool
save(fname, 'plant')
end
Loading
Loading