Commit

declutter args for mlp generator
szcf-weiya committed Jul 31, 2024
1 parent 7039357 commit dff0396
Showing 2 changed files with 57 additions and 37 deletions.
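For orientation, here is a minimal sketch (not part of the commit) of how the decluttered `check_CI` keyword arguments below might be passed. The argument names and defaults are taken from the diff; the call site and the concrete values are illustrative assumptions only.

# hypothetical call, assuming `check_CI` from src/boot.jl is in scope
check_CI(n = 100, σ = 0.1, f = exp,
         patience = 5, patience0 = 5,        # smaller early-stopping windows (new defaults)
         disable_early_stopping = true,      # train for the full number of epochs
         eval_sigma_adaptive = false,        # use `model0` to evaluate sigma
         add_cv = false, add_loocv = false,  # skip the CV/LOOCV error curves when n is large
         cmdconvert = "convert",             # e.g. "ssh sandbox convert" on a GPU node
         fig = true, figfolder = "~")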
68 changes: 44 additions & 24 deletions src/boot.jl
@@ -51,10 +51,15 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
nhidden = 1000, depth = 2,
gpu_id = 0,
model_file = nothing,
patience = 100,
patience = 5, patience0 = 5,
sort_in_nn = true, # only flux
check_acc = false, # reduce to check_acc
fig = true, figfolder = "~", kw...
cmdconvert = "convert", # on chpc-gpu node, we can use `ssh sandbox convert `
fig = true, figfolder = "~",
disable_early_stopping = true,
eval_sigma_adaptive = false,
add_cv = false, add_loocv = false, # used when comparing error curves from the generator and the bootstrap; not recommended when n is large (e.g., the running-time experiments)
kw...
)
timestamp = replace(strip(read(`date -Iseconds`, String)), ":" => "_")
if check_acc
@@ -71,7 +76,9 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
Yhat0 = zeros(nλ, n)
Yhat = zeros(nλ, n)
## julia's progress bar has been overridden by tqdm's progress bar
ident = "$f-n$n"
for i = 1:nrep
identi = "$ident-$i"
if nrep == 1
x = rand(MersenneTwister(seed), n) * 2 .- 1
y = f.(x) + randn(MersenneTwister(seed+1), n) * σ # to avoid the same random seed
@@ -110,7 +117,7 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
LOSS = vcat(loss0, loss)
else # backend = "pytorch"
if !(@isdefined PyCall)
error("PyCall/PyTorch is not properly loaded, please use Flux backend or re-install PyCall")
error("Did you run `__init_pytorch__()`? PyCall/PyTorch is not properly loaded, please use Flux backend or re-install PyCall")
end
M = K
model_file = "model-$f--n$n-J$J-nhidden$nhidden-$i-$seed-$timestamp.pt"
@@ -123,11 +130,13 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
nhidden = nhidden, depth = depth,
niter_per_epoch = niter_per_epoch,
model_file = model_file,
patience = patience,
patience = patience, patience0 = patience0, disable_early_stopping = disable_early_stopping,
eval_sigma_adaptive = eval_sigma_adaptive,
nepoch0 = nepoch0, λl = λs[1], λu = λs[end])
end
if fig
savefig(plot(log.(LOSS)), "$figfolder/loss-$f--$i.png")
serialize("$figfolder/loss-$identi.sil", LOSS)
savefig(plot(log.(LOSS)), "$figfolder/loss-$identi.png")
end
else
# gpu is much faster
@@ -141,10 +150,12 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
fit_err = zeros(nλ, n)
RES_YCI0 = Array{Any, 1}(undef, nλ)
for (j, λ) in enumerate(λs)
res_time[i, 3] += @elapsed begin
_, YCI = MonotoneSplines.ci_mono_ss(x, y, λ, prop_nknots=prop_nknots, B = nB)
if nepoch > 0
res_time[i, 3] += @elapsed begin
_, YCI = MonotoneSplines.ci_mono_ss(x, y, λ, prop_nknots=prop_nknots, B = nB)
end
RES_YCI0[j] = YCI
end
RES_YCI0[j] = YCI
res_time[i, 4] += @elapsed begin
yhat = Ghat(y, λ)
end
@@ -160,7 +171,10 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
end
end
if fig
savefig(fitfig, "$figfolder/fit-$f--$i.png")
savefig(fitfig, "$figfolder/fit-$identi.png")
end
if nepoch == 0 # do not run the CI part
continue
end
res_time[i, 5] = @elapsed begin
RES_YCI, cov_hat = sample_G_λ(Ghat, y, λs, nB = nB)
@@ -178,22 +192,26 @@ function check_CI(; n = 100, σ = 0.1, f = exp, seed = 1234,
Err_boot[i, :, 1] .= mean(fit_err, dims=2)[:]
Err_boot[i, :, 2] .= mean(cov_hat, dims=2)[:]
Err_boot[i, :, 3] .= Err_boot[i, :, 1] + 2 * Err_boot[i, :, 2]
errs, _, _, _ = cv_mono_ss(x, y, λs, nfold = 10)
errs2, _, _, _ = cv_mono_ss(x, y, λs, nfold = n)
if add_cv
errs, _, _, _ = cv_mono_ss(x, y, λs, nfold = 10)
end
if add_loocv
errs2, _, _, _ = cv_mono_ss(x, y, λs, nfold = n)
end
if fig
errfig = plot(log.(λs), Err_boot[i, :, 3], label = "err + 2cov")
plot!(errfig, log.(λs), Err_boot[i, :, 1], label = "err")
plot!(errfig, log.(λs), errs, label = "10 fold CV")
plot!(errfig, log.(λs), errs2, label = "LOOCV")
savefig(errfig, "$figfolder/err-$f--$i.png")
savefig(plot(log.(λs), cp), "$figfolder/cp-$f--$i.png")
end
if fig # TODO: cannot generalize
if strip(read(`hostname`, String)) == "chpc-gpu019"
run(`ssh sandbox convert $figfolder/cp-$f--$i.png $figfolder/err-$f--$i.png $figfolder/fit-$f--$i.png $figfolder/loss-$f--$i.png +append $figfolder/$f--$i.png`)
else
run(`convert $figfolder/cp-$f--$i.png $figfolder/err-$f--$i.png $figfolder/fit-$f--$i.png $figfolder/loss-$f--$i.png +append $figfolder/$f--$i.png`)
if add_cv
plot!(errfig, log.(λs), errs, label = "10 fold CV")
end
if add_loocv
plot!(errfig, log.(λs), errs2, label = "LOOCV")
end
savefig(errfig, "$figfolder/err-$identi.png")
savefig(plot(log.(λs), cp), "$figfolder/cp-$identi.png")
end
if fig
run(`$cmdconvert $figfolder/cp-$identi.png $figfolder/err-$identi.png $figfolder/fit-$identi.png $figfolder/loss-$identi.png +append $figfolder/$identi.png`)
end
end
serialize("$f-n$n-nrep$nrep-B$nB-K$K-nepoch$nepoch-$timestamp.sil", [res_covprob, res_overlap, res_err, res_time, Err_boot])
@@ -507,7 +525,7 @@ function py_train_G_lambda(y::AbstractVector, B::AbstractMatrix, L::AbstractMatr
η = 0.001, η0 = 0.001,
K0 = 10, K = 10,
nhidden = 1000, depth = 2,
patience = 100, patience0 = 100, disable_early_stopping = true, # deprecated
patience = 5, patience0 = 5, disable_early_stopping = true, # for nepoch
nepoch0 = 100, nepoch = 100,
λl = 1e-9, λu = 1e-4,
use_torchsort = false, sort_reg_strength = 0.1,
@@ -516,7 +534,8 @@ function py_train_G_lambda(y::AbstractVector, B::AbstractMatrix, L::AbstractMatr
niter_per_epoch = 100,
disable_tqdm = false,
λs_opt_train = nothing, λs_opt_val = nothing,
βs_opt_train = nothing, βs_opt_val = nothing
βs_opt_train = nothing, βs_opt_val = nothing,
eval_sigma_adaptive = false,
)
Ghat, LOSS, LOSS1 = _py_boot."train_G_lambda"(Float32.(y), Float32.(B), Float32.(L), eta = η, K = K,
K0 = K0,
@@ -532,7 +551,8 @@ function py_train_G_lambda(y::AbstractVector, B::AbstractMatrix, L::AbstractMatr
nhidden = nhidden, depth = depth,
disable_tqdm = disable_tqdm,
lams_opt_train = λs_opt_train, lams_opt_val = λs_opt_val,
betas_opt_train = βs_opt_train, betas_opt_val = βs_opt_val
betas_opt_train = βs_opt_train, betas_opt_val = βs_opt_val,
eval_sigma_adaptive = eval_sigma_adaptive
)#::Tuple{PyObject, PyArray}
#println(typeof(py_ret)) #Tuple{PyCall.PyObject, Matrix{Float32}}
# ....................... # Tuple{PyCall.PyObject, PyCall.PyArray{Float32, 2}}
26 changes: 13 additions & 13 deletions src/boot.py
@@ -37,8 +37,7 @@ def train_G_lambda(y, B, L, K = 10, K0 = 10,
nhidden = 1000, depth = 2,
nepoch = 100, nepoch0 = 100,
niter_per_epoch = 100,
gamma = 0.9, decay_step = 5, # TODO: how to proper set schedule? or just discard?
lam_lo = 1e-9, lam_up = 1e-4, use_torchsort = False,
lam_lo = 1e-9, lam_up = 1e-4, use_torchsort = False, sample_log_lam = True,
sort_reg_strength = 0.1, gpu_id = 0,
patience = 100, patience0 = 100, disable_early_stopping = True, # TODO: early stopping
eval_sigma_adaptive = False, # if False, use `model0` to evaluate sigma
@@ -47,7 +46,9 @@
lams_opt_train = None, lams_opt_val = None, # each lam corresponds to a beta (dim: N)
betas_opt_train = None, betas_opt_val = None, # evaluate the loss between the OPT solution and GpBS solution here (dim NxJ)
disable_tqdm = False):
#
# avoid boundary effect when evaluating
lam_lo = lam_lo * 0.9
lam_up = lam_up * 1.1
device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id != -1 else "cpu"
y = torch.from_numpy(y[None, :]).to(device, non_blocking=True)
B = torch.from_numpy(B).to(device, non_blocking=True)
@@ -61,10 +62,6 @@
model = Model(n+dim_lam, J, nhidden, depth, use_torchsort, sort_reg_strength).to(device)
opt1 = torch.optim.Adam(model.parameters(), lr = eta0, amsgrad = amsgrad)
opt2 = torch.optim.Adam(model.parameters(), lr = eta, amsgrad = amsgrad)
#sch1 = torch.optim.lr_scheduler.StepLR(opt1, gamma = gamma, step_size = decay_step)
# sch1 = torch.optim.lr_scheduler.ReduceLROnPlateau(opt1, 'min', factor = gamma, patience = patience, cooldown = cooldown, min_lr = 1e-7, threshold = 1e-5)
# sch1 = torch.optim.lr_scheduler.CyclicLR(opt1, 1e-6, eta0, cycle_momentum=False, mode = "exp_range", gamma = gamma)
sch2 = torch.optim.lr_scheduler.StepLR(opt2, gamma = gamma, step_size = decay_step)
loss_fn = nn.functional.mse_loss
# just realized that pytorch also did not do sort in batch
LOSS = torch.zeros(nepoch, 4).to(device)
@@ -81,7 +78,10 @@ def aug(lam):
# else:
# lams = torch.rand((K, 1)).to(device) * (lam_up - lam_lo) + lam_lo
for ii in pbar0:
lams = torch.rand((K, 1), device = device) * (lam_up - lam_lo) + lam_lo
if sample_log_lam:
lams = torch.exp(torch.rand((K, 1), device = device) * (np.log(lam_up) - np.log(lam_lo)) + np.log(lam_lo))
else:
lams = torch.rand((K, 1), device = device) * (lam_up - lam_lo) + lam_lo
ys = torch.cat((y.repeat( (K, 1) ), lams, torch.pow(lams, 1/3), torch.exp(lams), torch.sqrt(lams),
torch.log(lams), 10*lams, torch.square(lams), torch.pow(lams, 3)), dim = 1) # repeat works regardless of y has been augmented via `y[None, :]`
betas = model(ys)
@@ -123,7 +123,6 @@ def aug(lam):
LOSS0[epoch, i+1] = loss_fn(ypred, y) + lam * torch.square(torch.matmul(beta, L)).mean() * J / n
print(f"epoch = {epoch}, L(lam) = {LOSS0[epoch, 0]:.6f}, L(lam_lo) = {LOSS0[epoch, 1]:.6f}, L(lam_up) = {LOSS0[epoch, 2]:.6f}")

# sch1.step()
if not disable_early_stopping:
early_stopping0(LOSS0[epoch, 1:].mean(), model)
if early_stopping0.early_stop:
@@ -143,7 +142,10 @@ def aug(lam):
for ii in pbar:
if step2_use_tensor:
# construct tensor
lam = torch.rand((K0, 1, 1)) * (lam_up - lam_lo) + lam_lo
if sample_log_lam:
lam = torch.exp(torch.rand((K0, 1, 1)) * (np.log(lam_up) - np.log(lam_lo)) + np.log(lam_lo))
else:
lam = torch.rand((K0, 1, 1)) * (lam_up - lam_lo) + lam_lo
aug_lam = torch.cat(aug(lam), dim=2).to(device, non_blocking=True) # K0 x 1 x 8
ylam = torch.cat((y.repeat(K0, 1, 1), aug_lam), dim=2) # K0 x 1 x (n+8)
# K0 x 1 x J
@@ -189,7 +191,6 @@ def aug(lam):
opt2.zero_grad()
loss2.backward()
opt2.step()
# sch2.step()
train_loss.append(loss2.item())
pbar.set_postfix(iter = ii, loss = loss2.item())
if ii == niter_per_epoch - 1:
@@ -202,8 +203,7 @@ def aug(lam):
ypred = torch.matmul(beta, B.t())
LOSS[epoch, i+1] = loss_fn(ypred, y) + lam * torch.square(torch.matmul(beta, L)).mean() * J / n

sch2.step()
print(f"epoch = {epoch}, L(lam) = {LOSS[epoch, 0]:.6f}, L(lam_lo) = {LOSS[epoch, 1]:.6f}, L(lam_up) = {LOSS[epoch, 2]:.6f}, lr = {sch2.get_last_lr()}")
print(f"epoch = {epoch}, L(lam) = {LOSS[epoch, 0]:.6f}, L(lam_lo) = {LOSS[epoch, 1]:.6f}, L(lam_up) = {LOSS[epoch, 2]:.6f}")
if not disable_early_stopping:
early_stopping(LOSS[epoch, 1:].mean(), model)
if early_stopping.early_stop:
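The new `sample_log_lam` switch (default True) draws λ uniformly on the log scale rather than the linear scale; with bounds as far apart as lam_lo = 1e-9 and lam_up = 1e-4, a linear-uniform draw essentially never lands near the lower boundary. A minimal sketch of the two sampling schemes in Julia (variable names are illustrative, not from the repository):

# log-uniform: uniform in log λ, then exponentiate (what sample_log_lam does)
λl, λu = 1e-9, 1e-4
K = 10
λ_log = exp.(rand(K) .* (log(λu) - log(λl)) .+ log(λl))
# linear-uniform: roughly 90% of draws fall within a factor of 10 of λu
λ_lin = rand(K) .* (λu - λl) .+ λl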
