Skip to content

Commit

Permalink
updated pool/unpool bug reports
Browse files Browse the repository at this point in the history
  • Loading branch information
denizyuret committed Dec 12, 2020
1 parent b01c548 commit ae149dd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
15 changes: 9 additions & 6 deletions src/ops20/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,6 @@ function checkpoolopts(x, window, padding, stride, mode, maxpoolingNanOpt, alpha
@warn "Pool maxpoolingNanOpt=0 not yet implemented in NNlib, using 1 instead. See https://github.com/FluxML/NNlib.jl/issues/218" maxlog=1
maxpoolingNanOpt = 1
end
if padding != 0 && x isa Array
@warn "Pool padding is buggy in NNlib, use with caution. See https://github.com/FluxML/NNlib.jl/issues/229" maxlog=1
end
return (mode, maxpoolingNanOpt)
end

Expand All @@ -191,11 +188,17 @@ end
Perform the reverse of pooling: `x == pool(unpool(x;o...); o...)`
"""
function unpool(x; window=2, alpha=1, o...) # padding=0, stride=window, mode=0, maxpoolingNanOpt=0
function unpool(x; window=2, padding=0, stride=window, mode=0, maxpoolingNanOpt=1, alpha=1)
if mode == 1 && x isa Array
@warn "unpool(mode=1), which uses poolx(mode=2) is not supported on the CPU; performing unpool(mode=2) instead, see https://github.com/FluxML/NNlib.jl/issues/218" maxlog=1
end
w = prod(psize(window,x))
y = similar(x,updims(x; window=window, o...))
y = similar(x,updims(x; window, padding, stride, mode, maxpoolingNanOpt, alpha))
# pool0=>unpool1, pool1=>unpool2, pool2=>unpool1
mode = (mode==0 ? 1 : mode==1 ? 2 : mode==2 ? 1 : mode==3 ? 1 : error("Unknown unpool mode $mode"))
alpha = 1/alpha
# Leave unpool as a non-primitive, it is just a poolx call
poolx(y,x,x.*w; o..., window=window, mode=1, alpha=1/alpha)
poolx(y,x,x.*w; window, padding, stride, mode, maxpoolingNanOpt, alpha)
end


Expand Down
24 changes: 12 additions & 12 deletions test/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ struct M370; layer; end;
@test gradcheck(pool, ax; kw=[(:mode,1),(:padding,1)])
@test gradcheck(unpool, ax; kw=[(:mode,1),(:padding,1)])
@test isapprox(pool(unpool(ax;mode=1);mode=1),ax)
@test_broken isapprox(pool(unpool(ax;mode=1,padding=1);mode=1,padding=1),ax)
@test_broken isapprox(pool(unpool(ax;mode=1,padding=1);mode=1,padding=1),ax) # unpool(mode=1) uses pool(mode=2) NNlib#218
@test gradcheck(conv41, (aw,ax); rtol=TOL, kw=[(:mode,1),(:padding,1)])
@test gradcheck(deconv41, (ad,ax); rtol=TOL, kw=[(:mode,1),(:padding,1)])

### mode=2 (only for pool) -- is not supported in NNlib #218
# @test gradcheck(pool, ax; kw=[(:mode,2),(:padding,1)])
# @test gradcheck(unpool, ax; kw=[(:mode,2),(:padding,1)])
# @test isapprox(pool(unpool(ax;mode=2);mode=2),ax)
# @test isapprox(pool(unpool(ax;mode=2,padding=1);mode=2,padding=1),ax)
@test gradcheck(pool, ax; kw=[(:mode,2),(:padding,1)])
@test gradcheck(unpool, ax; kw=[(:mode,2),(:padding,1)])
@test isapprox(pool(unpool(ax;mode=2);mode=2),ax)
@test_broken isapprox(pool(unpool(ax;mode=2,padding=1);mode=2,padding=1),ax) # pool(mode=2) not supported NNlib#218

### alpha=2 (default=1)
@test gradcheck(pool, ax; kw=[(:alpha,2)])
Expand All @@ -110,7 +110,7 @@ struct M370; layer; end;
@test gradcheck(pool, ax; kw=[(:alpha,2),(:mode,1),(:padding,1)])
@test gradcheck(unpool, ax; kw=[(:alpha,2),(:mode,1),(:padding,1)])
@test isapprox(pool(unpool(ax;alpha=2,mode=1);alpha=2,mode=1),ax)
@test_broken isapprox(pool(unpool(ax;alpha=2,mode=1,padding=1);alpha=2,mode=1,padding=1),ax)
@test_broken isapprox(pool(unpool(ax;alpha=2,mode=1,padding=1);alpha=2,mode=1,padding=1),ax) # unpool(mode=1) uses pool(mode=2) unsupported by NNlib#218
@test gradcheck(conv41, (aw,ax); rtol=TOL, kw=[(:alpha,2)])
@test gradcheck(deconv41, (ad,ax); rtol=TOL, kw=[(:alpha,2)])
end
Expand Down Expand Up @@ -197,17 +197,17 @@ struct M370; layer; end;
### mode=1 (default=0)
@test isapprox(pool(kx;mode=1,padding=1), pool(ax;mode=1,padding=1))
@test gradcheck(pool, kx; kw=[(:mode,1),(:padding,1)])
@test isapprox(unpool(kx;mode=1,padding=1), unpool(ax;mode=1,padding=1))
@test_broken isapprox(unpool(kx;mode=1,padding=1), unpool(ax;mode=1,padding=1)) # unpool(mode=1) uses pool(mode=2) unsupported by NNlib#218
@test gradcheck(unpool, kx; kw=[(:mode,1),(:padding,1)])
@test isapprox(conv4(kw,kx;mode=1,padding=1), conv4(aw,ax;mode=1,padding=1))
@test gradcheck(conv41, (kw,kx); rtol=TOL, kw=[(:mode,1),(:padding,1)])
@test isapprox(deconv4(kd,kx;mode=1,padding=1), deconv4(ad,ax;mode=1,padding=1))
@test gradcheck(deconv41, (kd,kx); rtol=TOL, kw=[(:mode,1),(:padding,1)])

### mode=2 (only for pool)
# @test isapprox(pool(kx;mode=2,padding=1), pool(ax;mode=2,padding=1)) ## mode=2 is not supported in NNlib #218.
@test_broken isapprox(pool(kx;mode=2,padding=1), pool(ax;mode=2,padding=1)) ## mode=2 is not supported in NNlib #218.
@test gradcheck(pool, kx; kw=[(:mode,2),(:padding,1)])
# @test isapprox(unpool(kx;mode=2,padding=1), unpool(ax;mode=2,padding=1)) ## mode=2 is not supported in NNlib #218.
@test isapprox(unpool(kx;mode=2,padding=1), unpool(ax;mode=2,padding=1))
@test gradcheck(unpool, kx; kw=[(:mode,2),(:padding,1)])

### alpha=2 (default=1)
Expand All @@ -217,12 +217,12 @@ struct M370; layer; end;
@test gradcheck(unpool, kx; kw=[(:alpha,2)])
@test isapprox(pool(kx;alpha=2,mode=1,padding=1), pool(ax;alpha=2,mode=1,padding=1))
@test gradcheck(pool, kx; kw=[(:alpha,2),(:mode,1),(:padding,1)])
@test isapprox(unpool(kx;alpha=2,mode=1,padding=1), unpool(ax;alpha=2,mode=1,padding=1))
@test_broken isapprox(unpool(kx;alpha=2,mode=1,padding=1), unpool(ax;alpha=2,mode=1,padding=1)) # unpool(mode=1) uses pool(mode=2) unsupported by NNlib#218
@test gradcheck(unpool, kx; kw=[(:alpha,2),(:mode,1),(:padding,1)])

# @test isapprox(pool(kx;alpha=2,mode=2,padding=1), pool(ax;alpha=2,mode=2,padding=1)) ## mode=2 is not supported in NNlib #218.
@test_broken isapprox(pool(kx;alpha=2,mode=2,padding=1), pool(ax;alpha=2,mode=2,padding=1)) ## broken: mode=2 is not supported in NNlib #218.
@test gradcheck(pool, kx; kw=[(:alpha,2),(:mode,2),(:padding,1)])
# @test isapprox(unpool(kx;alpha=2,mode=2,padding=1), unpool(ax;alpha=2,mode=2,padding=1)) ## mode=2 is not supported in NNlib #218.
@test isapprox(unpool(kx;alpha=2,mode=2,padding=1), unpool(ax;alpha=2,mode=2,padding=1))
@test gradcheck(unpool, kx; kw=[(:alpha,2),(:mode,2),(:padding,1)])

@test isapprox(conv4(kw,kx;alpha=2), conv4(aw,ax;alpha=2))
Expand Down

0 comments on commit ae149dd

Please sign in to comment.