From a28cd4f9adb20d0cb2feee53294d1f0e9d2ae62a Mon Sep 17 00:00:00 2001 From: Yaoda Liu Date: Fri, 15 Sep 2023 20:59:30 +0800 Subject: [PATCH] feat: replace empty hook funcs with server hooks (#201) --- server/server.go | 115 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 30 deletions(-) diff --git a/server/server.go b/server/server.go index 327b60a0..c55abc33 100644 --- a/server/server.go +++ b/server/server.go @@ -1213,8 +1213,11 @@ func (srv *server) initPluginHooks() error { } } if onAcceptWrappers != nil { - onAccept := func(ctx context.Context, conn net.Conn) bool { - return true + onAccept := srv.hooks.OnAccept + if onAccept == nil { + onAccept = func(ctx context.Context, conn net.Conn) bool { + return true + } } for i := len(onAcceptWrappers); i > 0; i-- { onAccept = onAcceptWrappers[i-1](onAccept) @@ -1222,8 +1225,11 @@ func (srv *server) initPluginHooks() error { srv.hooks.OnAccept = onAccept } if onBasicAuthWrappers != nil { - onBasicAuth := func(ctx context.Context, client Client, req *ConnectRequest) error { - return nil + onBasicAuth := srv.hooks.OnBasicAuth + if onBasicAuth == nil { + onBasicAuth = func(ctx context.Context, client Client, req *ConnectRequest) error { + return nil + } } for i := len(onBasicAuthWrappers); i > 0; i-- { onBasicAuth = onBasicAuthWrappers[i-1](onBasicAuth) @@ -1231,10 +1237,13 @@ func (srv *server) initPluginHooks() error { srv.hooks.OnBasicAuth = onBasicAuth } if onEnhancedAuthWrappers != nil { - onEnhancedAuth := func(ctx context.Context, client Client, req *ConnectRequest) (resp *EnhancedAuthResponse, err error) { - return &EnhancedAuthResponse{ - Continue: false, - }, nil + onEnhancedAuth := srv.hooks.OnEnhancedAuth + if onEnhancedAuth == nil { + onEnhancedAuth = func(ctx context.Context, client Client, req *ConnectRequest) (resp *EnhancedAuthResponse, err error) { + return &EnhancedAuthResponse{ + Continue: false, + }, nil + } } for i := len(onEnhancedAuthWrappers); i > 0; i-- { onEnhancedAuth = onEnhancedAuthWrappers[i-1](onEnhancedAuth) @@ -1243,36 +1252,51 @@ func (srv *server) initPluginHooks() error { } if onConnectedWrappers != nil { - onConnected := func(ctx context.Context, client Client) {} + onConnected := srv.hooks.OnConnected + if onConnected == nil { + onConnected = func(ctx context.Context, client Client) {} + } for i := len(onConnectedWrappers); i > 0; i-- { onConnected = onConnectedWrappers[i-1](onConnected) } srv.hooks.OnConnected = onConnected } if onSessionCreatedWrapper != nil { - onSessionCreated := func(ctx context.Context, client Client) {} + onSessionCreated := srv.hooks.OnSessionCreated + if onSessionCreated == nil { + onSessionCreated = func(ctx context.Context, client Client) {} + } for i := len(onSessionCreatedWrapper); i > 0; i-- { onSessionCreated = onSessionCreatedWrapper[i-1](onSessionCreated) } srv.hooks.OnSessionCreated = onSessionCreated } if onSessionResumedWrapper != nil { - onSessionResumed := func(ctx context.Context, client Client) {} + onSessionResumed := srv.hooks.OnSessionResumed + if onSessionResumed == nil { + onSessionResumed = func(ctx context.Context, client Client) {} + } for i := len(onSessionResumedWrapper); i > 0; i-- { onSessionResumed = onSessionResumedWrapper[i-1](onSessionResumed) } srv.hooks.OnSessionResumed = onSessionResumed } if onSessionTerminatedWrapper != nil { - onSessionTerminated := func(ctx context.Context, clientID string, reason SessionTerminatedReason) {} + onSessionTerminated := srv.hooks.OnSessionTerminated + if onSessionTerminated == nil { + onSessionTerminated = func(ctx context.Context, clientID string, reason SessionTerminatedReason) {} + } for i := len(onSessionTerminatedWrapper); i > 0; i-- { onSessionTerminated = onSessionTerminatedWrapper[i-1](onSessionTerminated) } srv.hooks.OnSessionTerminated = onSessionTerminated } if onSubscribeWrappers != nil { - onSubscribe := func(ctx context.Context, client Client, req *SubscribeRequest) error { - return nil + onSubscribe := srv.hooks.OnSubscribe + if onSubscribe == nil { + onSubscribe = func(ctx context.Context, client Client, req *SubscribeRequest) error { + return nil + } } for i := len(onSubscribeWrappers); i > 0; i-- { onSubscribe = onSubscribeWrappers[i-1](onSubscribe) @@ -1280,15 +1304,22 @@ func (srv *server) initPluginHooks() error { srv.hooks.OnSubscribe = onSubscribe } if onSubscribedWrappers != nil { - onSubscribed := func(ctx context.Context, client Client, subscription *gmqtt.Subscription) {} + onSubscribed := srv.hooks.OnSubscribed + if onSubscribed == nil { + onSubscribed = func(ctx context.Context, client Client, subscription *gmqtt.Subscription) {} + } + for i := len(onSubscribedWrappers); i > 0; i-- { onSubscribed = onSubscribedWrappers[i-1](onSubscribed) } srv.hooks.OnSubscribed = onSubscribed } if onUnsubscribeWrappers != nil { - onUnsubscribe := func(ctx context.Context, client Client, req *UnsubscribeRequest) error { - return nil + onUnsubscribe := srv.hooks.OnUnsubscribe + if onUnsubscribe == nil { + onUnsubscribe = func(ctx context.Context, client Client, req *UnsubscribeRequest) error { + return nil + } } for i := len(onUnsubscribeWrappers); i > 0; i-- { onUnsubscribe = onUnsubscribeWrappers[i-1](onUnsubscribe) @@ -1296,15 +1327,21 @@ func (srv *server) initPluginHooks() error { srv.hooks.OnUnsubscribe = onUnsubscribe } if onUnsubscribedWrappers != nil { - onUnsubscribed := func(ctx context.Context, client Client, topicName string) {} + onUnsubscribed := srv.hooks.OnUnsubscribed + if onUnsubscribed == nil { + onUnsubscribed = func(ctx context.Context, client Client, topicName string) {} + } for i := len(onUnsubscribedWrappers); i > 0; i-- { onUnsubscribed = onUnsubscribedWrappers[i-1](onUnsubscribed) } srv.hooks.OnUnsubscribed = onUnsubscribed } if onMsgArrivedWrappers != nil { - onMsgArrived := func(ctx context.Context, client Client, req *MsgArrivedRequest) error { - return nil + onMsgArrived := srv.hooks.OnMsgArrived + if onMsgArrived == nil { + onMsgArrived = func(ctx context.Context, client Client, req *MsgArrivedRequest) error { + return nil + } } for i := len(onMsgArrivedWrappers); i > 0; i-- { onMsgArrived = onMsgArrivedWrappers[i-1](onMsgArrived) @@ -1312,42 +1349,60 @@ func (srv *server) initPluginHooks() error { srv.hooks.OnMsgArrived = onMsgArrived } if OnDeliveredWrappers != nil { - OnDelivered := func(ctx context.Context, client Client, msg *gmqtt.Message) {} + onDelivered := srv.hooks.OnDelivered + if onDelivered == nil { + onDelivered = func(ctx context.Context, client Client, msg *gmqtt.Message) {} + } for i := len(OnDeliveredWrappers); i > 0; i-- { - OnDelivered = OnDeliveredWrappers[i-1](OnDelivered) + onDelivered = OnDeliveredWrappers[i-1](onDelivered) } - srv.hooks.OnDelivered = OnDelivered + srv.hooks.OnDelivered = onDelivered } if OnClosedWrappers != nil { - OnClosed := func(ctx context.Context, client Client, err error) {} + onClosed := srv.hooks.OnClosed + if onClosed == nil { + onClosed = func(ctx context.Context, client Client, err error) {} + } for i := len(OnClosedWrappers); i > 0; i-- { - OnClosed = OnClosedWrappers[i-1](OnClosed) + onClosed = OnClosedWrappers[i-1](onClosed) } - srv.hooks.OnClosed = OnClosed + srv.hooks.OnClosed = onClosed } if onStopWrappers != nil { - onStop := func(ctx context.Context) {} + onStop := srv.hooks.OnStop + if onStop == nil { + onStop = func(ctx context.Context) {} + } for i := len(onStopWrappers); i > 0; i-- { onStop = onStopWrappers[i-1](onStop) } srv.hooks.OnStop = onStop } if onMsgDroppedWrappers != nil { - onMsgDropped := func(ctx context.Context, clientID string, msg *gmqtt.Message, err error) {} + onMsgDropped := srv.hooks.OnMsgDropped + if onMsgDropped == nil { + onMsgDropped = func(ctx context.Context, clientID string, msg *gmqtt.Message, err error) {} + } for i := len(onMsgDroppedWrappers); i > 0; i-- { onMsgDropped = onMsgDroppedWrappers[i-1](onMsgDropped) } srv.hooks.OnMsgDropped = onMsgDropped } if onWillPublishWrappers != nil { - onWillPublish := func(ctx context.Context, clientID string, req *WillMsgRequest) {} + onWillPublish := srv.hooks.OnWillPublish + if onWillPublish == nil { + onWillPublish = func(ctx context.Context, clientID string, req *WillMsgRequest) {} + } for i := len(onWillPublishWrappers); i > 0; i-- { onWillPublish = onWillPublishWrappers[i-1](onWillPublish) } srv.hooks.OnWillPublish = onWillPublish } if onWillPublishedWrappers != nil { - onWillPublished := func(ctx context.Context, clientID string, msg *gmqtt.Message) {} + onWillPublished := srv.hooks.OnWillPublished + if onWillPublished == nil { + onWillPublished = func(ctx context.Context, clientID string, msg *gmqtt.Message) {} + } for i := len(onWillPublishedWrappers); i > 0; i-- { onWillPublished = onWillPublishedWrappers[i-1](onWillPublished) }