Skip to content

Commit

Permalink
stash function calls in gss()
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Feb 26, 2020
1 parent 57047d8 commit 60ba1b2
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ gss <- function(func,
lower = 0,
upper = 1,
max_iterations = 50,
tolerance = 1e-8) {
tolerance = .Machine$double.eps ^ 0.25) {

lower <- tf$constant(lower, tf_float(), shape(1))
upper <- tf$constant(upper, tf_float(), shape(1))
Expand All @@ -431,14 +431,15 @@ gss <- function(func,
d <- golden_ratio * width
x1 <- right - d
x2 <- left + d
f1 <- func(x1)
f2 <- func(x2)

values <- list(left, right, x1, x2, width, iter)
values <- list(left, right, x1, x2, f1, f2, width, iter)

# start loop
body <- function(left, right, x1, x2, width, iter) {
body <- function(left, right, x1, x2, f1, f2, width, iter) {

# prep lists of vectors for whether steps are above x1 or below x2

# order: lower, upper, x1, x2, width

# if the minimum is below x2, shift the bounds and reuse x1 as the new x2
Expand All @@ -465,22 +466,28 @@ gss <- function(func,
below <- tf$stack(below, axis = 1L)
above <- tf$stack(above, axis = 1L)

# can we stash and reuse these and reduce the number of function evaluations?
f1 <- func(x1)
f2 <- func(x2)

status <- tf$where(tf$greater(f2, f1), below, above)
is_below <- tf$greater(f2, f1)
status <- tf$where(is_below, below, above)

left <- status[, 0]
right <- status[, 1]
x1 <- status[, 2]
x2 <- status[, 3]
width <- status[, 4]

list(left, right, x1, x2, width, iter + 1L)
# either recompute f1 (f2) at new location, or use f2 (f1) in its place

# this is a bit convoluted, but ensures function calls don't need to be
# duplicated, whilst maintaining the vectorisation
x_to_evaluate <- tf$where(is_below, x1, x2)
new_f <- func(x_to_evaluate)
new_f1 <- tf$where(is_below, new_f, f2)
new_f2 <- tf$where(is_below, f1, new_f)

list(left, right, x1, x2, new_f1, new_f2, width, iter + 1L)
}

cond <- function(left, right, x1, x2, width, iter) {
cond <- function(left, right, x1, x2, f1, f2, width, iter) {
not_converged <- tf$less(tol, tf$abs(width))
not_all_converged <- tf$reduce_any(not_converged)
in_time <- tf$less(iter, maxiter)
Expand All @@ -490,12 +497,12 @@ gss <- function(func,
out <- tf$while_loop(cond, body, values)

# get minimum value
width <- out[[5]]
width <- out[[7]]
min <- out[[1]] + width / fl(2)

list(minimum = min,
width = width,
iterations = out[[6]])
iterations = out[[8]])

}

Expand Down

0 comments on commit 60ba1b2

Please sign in to comment.