Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add high-priority stream flag using SetWriteDeadline magic value #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ type Session struct {
// sendCh is used to send messages
sendCh chan []byte

// highPrioSendCh is used to send messages for streams marked high priority.
highPrioSendCh chan []byte

// pingCh and pingCh are used to send pings and pongs
pongCh, pingCh chan uint32

Expand Down Expand Up @@ -144,6 +147,7 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int, newMemo
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan []byte, 64),
highPrioSendCh: make(chan []byte, 64),
pongCh: make(chan uint32, config.PingBacklog),
pingCh: make(chan uint32),
recvDoneCh: make(chan struct{}),
Expand Down Expand Up @@ -326,7 +330,7 @@ func (s *Session) exitErr(err error) {
// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
func (s *Session) GoAway() error {
return s.sendMsg(s.goAway(goAwayNormal), nil, nil)
return s.sendMsg(s.goAway(goAwayNormal), nil, nil, false)
}

// goAway is used to send a goAway message
Expand Down Expand Up @@ -483,7 +487,7 @@ func (s *Session) extendKeepalive() {
}

// send sends the header and body.
func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) error {
func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}, highPriority bool) error {
select {
case <-s.shutdownCh:
return s.shutdownErr
Expand All @@ -495,11 +499,16 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
copy(buf[:headerSize], hdr[:])
copy(buf[headerSize:], body)

sendCh := s.sendCh
if highPriority {
sendCh = s.highPrioSendCh
}

select {
case <-s.shutdownCh:
pool.Put(buf)
return s.shutdownErr
case s.sendCh <- buf:
case sendCh <- buf:
return nil
case <-deadline:
pool.Put(buf)
Expand Down Expand Up @@ -579,9 +588,9 @@ func (s *Session) sendLoop() (err error) {
hdr := encode(typePing, flagACK, 0, pingID)
copy(buf, hdr[:])
default:
// Then send normal data.
// Next, the highPrioSendCh gets to send before other streams, if data is available.
select {
case buf = <-s.sendCh:
case buf = <-s.highPrioSendCh:
case pingID := <-s.pingCh:
buf = pool.Get(headerSize)
hdr := encode(typePing, flagSYN, 0, pingID)
Expand All @@ -590,8 +599,22 @@ func (s *Session) sendLoop() (err error) {
buf = pool.Get(headerSize)
hdr := encode(typePing, flagACK, 0, pingID)
copy(buf, hdr[:])
case <-s.shutdownCh:
return nil
default:
// Then send normal data.
select {
case buf = <-s.highPrioSendCh:
case buf = <-s.sendCh:
case pingID := <-s.pingCh:
buf = pool.Get(headerSize)
hdr := encode(typePing, flagSYN, 0, pingID)
copy(buf, hdr[:])
case pingID := <-s.pongCh:
buf = pool.Get(headerSize)
hdr := encode(typePing, flagACK, 0, pingID)
copy(buf, hdr[:])
case <-s.shutdownCh:
return nil
}
// default:
// select {
// case buf = <-s.sendCh:
Expand Down Expand Up @@ -734,7 +757,7 @@ func (s *Session) handleStreamMessage(hdr header) error {

// Read the new data
if err := stream.readData(hdr, flags, s.reader); err != nil {
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil {
s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
}
return err
Expand Down Expand Up @@ -802,7 +825,7 @@ func (s *Session) incomingStream(id uint32) error {
// Reject immediately if we are doing a go away
if atomic.LoadInt32(&s.localGoAway) == 1 {
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
return s.sendMsg(hdr, nil, nil, false)
}

// Allocate a new stream
Expand All @@ -821,7 +844,7 @@ func (s *Session) incomingStream(id uint32) error {
// Check if stream already exists
if _, ok := s.streams[id]; ok {
s.logger.Printf("[ERR] yamux: duplicate stream declared")
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil {
s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
}
span.Done()
Expand All @@ -833,7 +856,7 @@ func (s *Session) incomingStream(id uint32) error {
s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
defer span.Done()
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
return s.sendMsg(hdr, nil, nil, false)
}

s.numIncomingStreams++
Expand All @@ -850,7 +873,7 @@ func (s *Session) incomingStream(id uint32) error {
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
s.deleteStream(id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
return s.sendMsg(hdr, nil, nil, false)
}
}

Expand Down
6 changes: 3 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ func TestSession_sendMsg_Timeout(t *testing.T) {

hdr := encode(typePing, flagACK, 0, 0)
for {
err := client.sendMsg(hdr, nil, nil)
err := client.sendMsg(hdr, nil, nil, false)
if err == nil {
continue
} else if err == ErrConnectionWriteTimeout {
Expand All @@ -1345,14 +1345,14 @@ func TestWindowOverflow(t *testing.T) {
defer server.Close()

hdr1 := encode(typeData, flagSYN, i, 0)
_ = client.sendMsg(hdr1, nil, nil)
_ = client.sendMsg(hdr1, nil, nil, false)
s, err := server.AcceptStream()
if err != nil {
t.Fatal(err)
}
msg := make([]byte, client.config.MaxStreamWindowSize*2)
hdr2 := encode(typeData, 0, i, uint32(len(msg)))
_ = client.sendMsg(hdr2, msg, nil)
_ = client.sendMsg(hdr2, msg, nil, false)
_, err = io.ReadAll(s)
if err == nil {
t.Fatal("expected to read no data")
Expand Down
21 changes: 16 additions & 5 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ const (
halfReset
)

// HighPriorityWriteDeadlineMagicValue is a special value that can be passed to
// SetWriteDeadline to indicate that this stream should get to send its data
// before other streams.
var HighPriorityWriteDeadlineMagicValue = time.Unix(1<<60, 0)

// Stream is used to represent a logical stream
// within a session.
type Stream struct {
Expand All @@ -49,6 +54,8 @@ type Stream struct {
sendNotifyCh chan struct{}

readDeadline, writeDeadline pipeDeadline

highPriority bool
}

// newStream is used to construct a new stream within a given session for an ID.
Expand All @@ -70,6 +77,7 @@ func newStream(session *Session, id uint32, state streamState, initialWindow uin
epochStart: time.Now(),
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
highPriority: false,
}
return s
}
Expand Down Expand Up @@ -179,7 +187,7 @@ START:

// Send the header
hdr = encode(typeData, flags, s.id, max)
if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait()); err != nil {
if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait(), s.highPriority); err != nil {
return 0, err
}

Expand Down Expand Up @@ -238,21 +246,21 @@ func (s *Stream) sendWindowUpdate(deadline <-chan struct{}) error {

s.epochStart = now
hdr := encode(typeWindowUpdate, flags, s.id, delta)
return s.session.sendMsg(hdr, nil, deadline)
return s.session.sendMsg(hdr, nil, deadline, s.highPriority)
}

// sendClose is used to send a FIN
func (s *Stream) sendClose() error {
flags := s.sendFlags()
flags |= flagFIN
hdr := encode(typeWindowUpdate, flags, s.id, 0)
return s.session.sendMsg(hdr, nil, nil)
return s.session.sendMsg(hdr, nil, nil, s.highPriority)
}

// sendReset is used to send a RST
func (s *Stream) sendReset() error {
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
return s.session.sendMsg(hdr, nil, nil)
return s.session.sendMsg(hdr, nil, nil, s.highPriority)
}

// Reset resets the stream (forcibly closes the stream)
Expand Down Expand Up @@ -490,7 +498,10 @@ func (s *Stream) SetReadDeadline(t time.Time) error {
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.stateLock.Lock()
defer s.stateLock.Unlock()
if s.writeState == halfOpen {
// handle magic time.Time value to signal this is a high-priority stream.
if t.Equal(HighPriorityWriteDeadlineMagicValue) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why introduce this mechanism instead of addding an additional function to enable high priority?

s.highPriority = true
} else if s.writeState == halfOpen {
s.writeDeadline.set(t)
}
return nil
Expand Down