diff --git a/common.go b/common.go index 43286691..f7c287da 100644 --- a/common.go +++ b/common.go @@ -18,9 +18,11 @@ import ( "context" "crypto/tls" "os" + "strconv" "sync" "time" + "github.com/henrylee2cn/goutil" "github.com/henrylee2cn/goutil/pool" "github.com/henrylee2cn/teleport/socket" "github.com/henrylee2cn/teleport/utils" @@ -52,21 +54,21 @@ func TypeText(typ byte) string { // Internal Framework Rerror code. // Note: Recommended custom code is greater than 1000. const ( - CodeUnknownError = -1 - CodeConnClosed = 102 - CodeWriteFailed = 104 - CodeDialFailed = 105 - CodeBadPacket = 400 - CodeUnauthorized = 401 - CodeNotFound = 404 - CodePtypeNotAllowed = 405 - CodeHandleTimeout = 408 - CodeBadGateway = 502 + CodeUnknownError = -1 + CodeConnClosed = 102 + CodeWriteFailed = 104 + CodeDialFailed = 105 + CodeBadPacket = 400 + CodeUnauthorized = 401 + CodeNotFound = 404 + CodePtypeNotAllowed = 405 + CodeHandleTimeout = 408 + CodeInternalServerError = 500 + CodeBadGateway = 502 // CodeConflict = 409 // CodeUnsupportedTx = 410 // CodeUnsupportedCodecType = 415 - // CodeInternalServerError = 500 // CodeServiceUnavailable = 503 // CodeGatewayTimeout = 504 // CodeVariantAlsoNegotiates = 506 @@ -96,6 +98,8 @@ func CodeText(rerrCode int32) string { return "Handle Timeout" case CodePtypeNotAllowed: return "Packet Type Not Allowed" + case CodeInternalServerError: + return "Internal Server Error" case CodeBadGateway: return "Bad Gateway" case CodeUnknownError: @@ -115,6 +119,7 @@ var ( rerrNotFound = NewRerror(CodeNotFound, CodeText(CodeNotFound), "") rerrCodePtypeNotAllowed = NewRerror(CodePtypeNotAllowed, CodeText(CodePtypeNotAllowed), "") rerrHandleTimeout = NewRerror(CodeHandleTimeout, CodeText(CodeHandleTimeout), "") + rerrInternalServerError = NewRerror(CodeInternalServerError, CodeText(CodeInternalServerError), "") ) // IsConnRerror determines whether the error is a connection error @@ -135,6 +140,8 @@ const ( MetaRealId = "X-Real-ID" // MetaRealIp real IP metadata key MetaRealIp = "X-Real-IP" + // MetaAcceptBodyCodec the key of body codec that the sender wishes to accept + MetaAcceptBodyCodec = "X-Accept-Body-Codec" ) // WithRealId sets the real ID to metadata. @@ -147,6 +154,24 @@ func WithRealIp(ip string) socket.PacketSetting { return socket.WithAddMeta(MetaRealIp, ip) } +// WithAcceptBodyCodec sets the body codec that the sender wishes to accept. +func WithAcceptBodyCodec(bodyCodec byte) socket.PacketSetting { + return socket.WithAddMeta(MetaAcceptBodyCodec, strconv.FormatUint(uint64(bodyCodec), 10)) +} + +// GetAcceptBodyCodec gets the body codec that the sender wishes to accept. +func GetAcceptBodyCodec(meta *utils.Args) (byte, bool) { + s := meta.Peek(MetaAcceptBodyCodec) + if len(s) == 0 || len(s) > 3 { + return 0, false + } + b, err := strconv.ParseUint(goutil.BytesToString(s), 10, 8) + if err != nil { + return 0, false + } + return byte(b), true +} + // WithContext sets the packet handling context. // func WithContext(ctx context.Context) socket.PacketSetting var WithContext = socket.WithContext diff --git a/context.go b/context.go index 478442b1..3fee15f7 100644 --- a/context.go +++ b/context.go @@ -557,23 +557,38 @@ func (c *readHandleCtx) handlePull() { } // reply pull - rerr := c.pluginContainer.PreWriteReply(c) - if rerr == nil { - c.handleErr = rerr - } - _, rerr = c.sess.write(c.output) + c.pluginContainer.PreWriteReply(c) + _, rerr := c.sess.write(c.output) if rerr != nil { if c.handleErr == nil { c.handleErr = rerr } - // rerr.SetToMeta(c.output.Meta()) + if rerr != rerrConnClosed { + c.output.SetBody(nil) + rerr2 := rerrInternalServerError.Copy() + rerr2.Detail = rerr.Detail + rerr2.SetToMeta(c.output.Meta()) + c.sess.write(c.output) + } return } - rerr = c.pluginContainer.PostWriteReply(c) - if c.handleErr == nil { - c.handleErr = rerr + c.pluginContainer.PostWriteReply(c) +} + +func (c *readHandleCtx) setReplyBody(body interface{}) { + c.output.SetBody(body) + if c.output.BodyCodec() != codec.NilCodecId { + return + } + acceptBodyCodec, ok := GetAcceptBodyCodec(c.input.Meta()) + if ok { + if _, err := codec.Get(acceptBodyCodec); err == nil { + c.output.SetBodyCodec(acceptBodyCodec) + return + } } + c.output.SetBodyCodec(c.input.BodyCodec()) } func (c *readHandleCtx) bindReply(header socket.Header) interface{} { diff --git a/plugin.go b/plugin.go index 46a4ccc5..5efee8a6 100644 --- a/plugin.go +++ b/plugin.go @@ -354,31 +354,29 @@ func (p *PluginContainer) PostWritePull(ctx WriteCtx) *Rerror { } // PreWriteReply executes the defined plugins before writing REPLY packet. -func (p *PluginContainer) PreWriteReply(ctx WriteCtx) *Rerror { +func (p *PluginContainer) PreWriteReply(ctx WriteCtx) { var rerr *Rerror for _, plugin := range p.plugins { if _plugin, ok := plugin.(PreWriteReplyPlugin); ok { if rerr = _plugin.PreWriteReply(ctx); rerr != nil { Errorf("%s-PreWriteReplyPlugin(%s)", plugin.Name(), rerr.String()) - return rerr + return } } } - return nil } // PostWriteReply executes the defined plugins after successful writing REPLY packet. -func (p *PluginContainer) PostWriteReply(ctx WriteCtx) *Rerror { +func (p *PluginContainer) PostWriteReply(ctx WriteCtx) { var rerr *Rerror for _, plugin := range p.plugins { if _plugin, ok := plugin.(PostWriteReplyPlugin); ok { if rerr = _plugin.PostWriteReply(ctx); rerr != nil { Errorf("%s-PostWriteReplyPlugin(%s)", plugin.Name(), rerr.String()) - return rerr + return } } } - return nil } // PreWritePush executes the defined plugins before writing PUSH packet. diff --git a/router.go b/router.go index 491a5ff0..8138ab44 100644 --- a/router.go +++ b/router.go @@ -24,7 +24,6 @@ import ( "github.com/henrylee2cn/goutil" "github.com/henrylee2cn/goutil/errors" - "github.com/henrylee2cn/teleport/codec" ) /** @@ -285,10 +284,7 @@ func (r *Router) SetUnknownPull(fn func(UnknownPullCtx) (interface{}, *Rerror), ctx.handleErr = rerr rerr.SetToMeta(ctx.output.Meta()) } else { - ctx.output.SetBody(body) - if ctx.output.BodyCodec() == codec.NilCodecId { - ctx.output.SetBodyCodec(ctx.input.BodyCodec()) - } + ctx.setReplyBody(body) } }, } @@ -444,10 +440,7 @@ func makePullHandlersFromStruct(pathPrefix string, pullCtrlStruct interface{}, p ctx.handleErr = rerr rerr.SetToMeta(ctx.output.Meta()) } else { - ctx.output.SetBody(rets[0].Interface()) - if ctx.output.BodyCodec() == codec.NilCodecId { - ctx.output.SetBodyCodec(ctx.input.BodyCodec()) - } + ctx.setReplyBody(rets[0].Interface()) } pool.Put(obj) } @@ -529,10 +522,7 @@ func makePullHandlersFromFunc(pathPrefix string, pullHandleFunc interface{}, plu ctx.handleErr = rerr rerr.SetToMeta(ctx.output.Meta()) } else { - ctx.output.SetBody(rets[0].Interface()) - if ctx.output.BodyCodec() == codec.NilCodecId { - ctx.output.SetBodyCodec(ctx.input.BodyCodec()) - } + ctx.setReplyBody(rets[0].Interface()) } } @@ -573,10 +563,7 @@ func makePullHandlersFromFunc(pathPrefix string, pullHandleFunc interface{}, plu ctx.handleErr = rerr rerr.SetToMeta(ctx.output.Meta()) } else { - ctx.output.SetBody(rets[0].Interface()) - if ctx.output.BodyCodec() == codec.NilCodecId { - ctx.output.SetBodyCodec(ctx.input.BodyCodec()) - } + ctx.setReplyBody(rets[0].Interface()) } pool.Put(obj) } diff --git a/samples/simple/client.go b/samples/simple/client.go index 13e0d342..813f05f2 100644 --- a/samples/simple/client.go +++ b/samples/simple/client.go @@ -18,6 +18,7 @@ func main() { rerr := sess.Pull("/math/add?push_status=yes", []int{1, 2, 3, 4, 5}, &reply, + // tp.WithAcceptBodyCodec('j'), ).Rerror() if rerr != nil { diff --git a/session.go b/session.go index e3cddbd1..27051b7b 100644 --- a/session.go +++ b/session.go @@ -702,6 +702,8 @@ func (s *session) write(packet *socket.Packet) (net.Conn, *Rerror) { return conn, rerrConnClosed } + Debugf("write error: %s", err.Error()) + ERR: rerr = rerrWriteFailed.Copy() rerr.Detail = err.Error()