Skip to content

Commit

Permalink
fix(transport): correctly release UDS locker file (#2305)
Browse files Browse the repository at this point in the history
* fix(transport): correctly release UDS locker file

* use callback function to do some jobs after create listener
  • Loading branch information
yin1999 authored and yuhan6665 committed Aug 26, 2023
1 parent 2d5475f commit 10d6b06
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 49 deletions.
5 changes: 0 additions & 5 deletions transport/internet/grpc/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ type Listener struct {
handler internet.ConnHandler
local net.Addr
config *Config
locker *internet.FileLocker // for unix domain socket

s *grpc.Server
}
Expand Down Expand Up @@ -110,10 +109,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
return
}
locker := ctx.Value(address.Domain())
if locker != nil {
listener.locker = locker.(*internet.FileLocker)
}
} else { // tcp
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Expand Down
8 changes: 0 additions & 8 deletions transport/internet/http/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ type Listener struct {
handler internet.ConnHandler
local net.Addr
config *Config
locker *internet.FileLocker // for unix domain socket
}

func (l *Listener) Addr() net.Addr {
return l.local
}

func (l *Listener) Close() error {
if l.locker != nil {
l.locker.Release()
}
return l.server.Close()
}

Expand Down Expand Up @@ -180,10 +176,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
return
}
locker := ctx.Value(address.Domain())
if locker != nil {
listener.locker = locker.(*internet.FileLocker)
}
} else { // tcp
streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Expand Down
69 changes: 49 additions & 20 deletions transport/internet/system_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ type DefaultListener struct {
controllers []control.Func
}

type combinedListener struct {
net.Listener
locker *FileLocker // for unix domain socket
}

func (cl *combinedListener) Close() error {
if cl.locker != nil {
cl.locker.Release()
cl.locker = nil
}
return cl.Listener.Close()
}

func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []control.Func) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
Expand All @@ -44,6 +57,10 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
var lc net.ListenConfig
var network, address string
// callback is called after the Listen function returns
callback := func(l net.Listener, err error) (net.Listener, error) {
return l, err
}

switch addr := addr.(type) {
case *net.TCPAddr:
Expand All @@ -58,23 +75,6 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
network = addr.Network()
address = addr.Name

if s := strings.Split(address, ","); len(s) == 2 {
address = s[0]
perm, perr := strconv.ParseUint(s[1], 8, 32)
if perr != nil {
return nil, newError("failed to parse permission: " + s[1]).Base(perr)
}

defer func(file string, permission os.FileMode) {
if err == nil {
cerr := os.Chmod(address, permission)
if cerr != nil {
err = newError("failed to set permission for " + file).Base(cerr)
}
}
}(address, os.FileMode(perm))
}

if (runtime.GOOS == "linux" || runtime.GOOS == "android") && address[0] == '@' {
// linux abstract unix domain socket is lockfree
if len(address) > 1 && address[1] == '@' {
Expand All @@ -84,19 +84,48 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
address = string(fullAddr)
}
} else {
// split permission from address
var filePerm *os.FileMode
if s := strings.Split(address, ","); len(s) == 2 {
address = s[0]
perm, perr := strconv.ParseUint(s[1], 8, 32)
if perr != nil {
return nil, newError("failed to parse permission: " + s[1]).Base(perr)
}

mode := os.FileMode(perm)
filePerm = &mode
}
// normal unix domain socket needs lock
locker := &FileLocker{
path: address + ".lock",
}
err := locker.Acquire()
if err != nil {
if err := locker.Acquire(); err != nil {
return nil, err
}
ctx = context.WithValue(ctx, address, locker)

// set callback to combine listener and set permission
callback = func(l net.Listener, err error) (net.Listener, error) {
if err != nil {
locker.Release()
return l, err
}
l = &combinedListener{Listener: l, locker: locker}
if filePerm == nil {
return l, nil
}
err = os.Chmod(address, *filePerm)
if err != nil {
l.Close()
return nil, newError("failed to set permission for " + address).Base(err)
}
return l, nil
}
}
}

l, err = lc.Listen(ctx, network, address)
l, err = callback(l, err)
if sockopt != nil && sockopt.AcceptProxyProtocol {
policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
Expand Down
8 changes: 0 additions & 8 deletions transport/internet/tcp/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type Listener struct {
authConfig internet.ConnectionAuthenticator
config *Config
addConn internet.ConnHandler
locker *internet.FileLocker // for unix domain socket
}

// ListenTCP creates a new Listener based on configurations.
Expand All @@ -51,10 +50,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe
return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
}
newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
locker := ctx.Value(address.Domain())
if locker != nil {
l.locker = locker.(*internet.FileLocker)
}
} else {
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Expand Down Expand Up @@ -133,9 +128,6 @@ func (v *Listener) Addr() net.Addr {

// Close implements internet.Listener.Close.
func (v *Listener) Close() error {
if v.locker != nil {
v.locker.Release()
}
return v.listener.Close()
}

Expand Down
8 changes: 0 additions & 8 deletions transport/internet/websocket/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ type Listener struct {
listener net.Listener
config *Config
addConn internet.ConnHandler
locker *internet.FileLocker // for unix domain socket
}

func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
Expand All @@ -101,10 +100,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
}
newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
locker := ctx.Value(address.Domain())
if locker != nil {
l.locker = locker.(*internet.FileLocker)
}
} else { // tcp
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Expand Down Expand Up @@ -153,9 +148,6 @@ func (ln *Listener) Addr() net.Addr {

// Close implements net.Listener.Close().
func (ln *Listener) Close() error {
if ln.locker != nil {
ln.locker.Release()
}
return ln.listener.Close()
}

Expand Down

0 comments on commit 10d6b06

Please sign in to comment.