diff --git a/.clusterfuzzlite/Dockerfile b/.clusterfuzzlite/Dockerfile index b57da8caa..d9fedb9a5 100644 --- a/.clusterfuzzlite/Dockerfile +++ b/.clusterfuzzlite/Dockerfile @@ -3,7 +3,7 @@ FROM gcr.io/oss-fuzz-base/base-builder-go:v1 ARG TARGETPLATFORM RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}" -ENV GOVERSION=1.20.7 +ENV GOVERSION=1.22.0 RUN platform=$(echo ${TARGETPLATFORM} | tr '/' '-') && \ filename="go${GOVERSION}.${platform}.tar.gz" && \ diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..5ace4600a --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.golangci.yml b/.golangci.yml index 1315759bc..469d54cfb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -2,16 +2,6 @@ run: skip-files: - internal/handshake/cipher_suite.go linters-settings: - depguard: - type: blacklist - packages: - - github.com/marten-seemann/qtls - - github.com/quic-go/qtls-go1-19 - - github.com/quic-go/qtls-go1-20 - packages-with-error-message: - - github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls" - - github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls" misspell: ignore-words: - ect @@ -20,7 +10,6 @@ linters: disable-all: true enable: - asciicheck - - depguard - exhaustive - exportloopref - goimports diff --git a/client.go b/client.go index c70937f56..70f7143bc 100644 --- a/client.go +++ b/client.go @@ -29,7 +29,7 @@ type client struct { initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool - version protocol.VersionNumber + version protocol.Version handshakeChan chan struct{} @@ -237,7 +237,7 @@ func (c *client) dial(ctx context.Context) error { select { case <-ctx.Done(): - c.conn.shutdown() + c.conn.destroy(nil) return context.Cause(ctx) case err := <-errorChan: return err diff --git a/client_test.go b/client_test.go index c55953f38..3b1da3b02 100644 --- a/client_test.go +++ b/client_test.go @@ -47,7 +47,7 @@ var _ = Describe("Client", func() { tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn ) @@ -61,7 +61,7 @@ var _ = Describe("Client", func() { Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { return tr }, - Versions: []protocol.VersionNumber{protocol.Version1}, + Versions: []protocol.Version{protocol.Version1}, } Eventually(areConnsRunning).Should(BeFalse()) packetConn = NewMockSendConn(mockCtrl) @@ -88,7 +88,7 @@ var _ = Describe("Client", func() { AfterEach(func() { if s, ok := cl.conn.(*connection); ok { - s.shutdown() + s.destroy(nil) } Eventually(areConnsRunning).Should(BeFalse()) }) @@ -126,11 +126,11 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { Expect(enable0RTT).To(BeFalse()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().run().Do(func() error { close(run); return nil }) c := make(chan struct{}) close(c) conn.EXPECT().HandshakeComplete().Return(c) @@ -163,11 +163,11 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { Expect(enable0RTT).To(BeTrue()) conn := NewMockQUICConn(mockCtrl) - conn.EXPECT().run().Do(func() { close(done) }) + conn.EXPECT().run().Do(func() error { close(done); return nil }) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) conn.EXPECT().earlyConnReady().Return(readyChan) return conn @@ -200,7 +200,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - _ protocol.VersionNumber, + _ protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run().Return(testErr) @@ -266,9 +266,9 @@ var _ = Describe("Client", func() { }) It("creates new connections with the right parameters", func() { - config := &Config{Versions: []protocol.VersionNumber{protocol.Version1}} + config := &Config{Versions: []protocol.Version{protocol.Version1}} c := make(chan struct{}) - var version protocol.VersionNumber + var version protocol.Version var conf *Config done := make(chan struct{}) newClientConnection = func( @@ -285,7 +285,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - versionP protocol.VersionNumber, + versionP protocol.Version, ) quicConn { version = versionP conf = configP @@ -328,7 +328,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, - versionP protocol.VersionNumber, + versionP protocol.Version, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) @@ -352,7 +352,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.Version1}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.Version{protocol.Version1}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr(context.Background(), "localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) diff --git a/closed_conn.go b/closed_conn.go index 071071c7b..1dc144258 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -4,7 +4,6 @@ import ( "math/bits" "net" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" ) @@ -12,9 +11,8 @@ import ( // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. type closedLocalConn struct { - counter uint32 - perspective protocol.Perspective - logger utils.Logger + counter uint32 + logger utils.Logger sendPacket func(net.Addr, packetInfo) } @@ -22,11 +20,10 @@ type closedLocalConn struct { var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { +func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler { return &closedLocalConn{ - sendPacket: sendPacket, - perspective: pers, - logger: logger, + sendPacket: sendPacket, + logger: logger, } } @@ -41,24 +38,20 @@ func (c *closedLocalConn) handlePacket(p receivedPacket) { c.sendPacket(p.remoteAddr, p.info) } -func (c *closedLocalConn) shutdown() {} -func (c *closedLocalConn) destroy(error) {} -func (c *closedLocalConn) getPerspective() protocol.Perspective { return c.perspective } +func (c *closedLocalConn) destroy(error) {} +func (c *closedLocalConn) closeWithTransportError(TransportErrorCode) {} // A closedRemoteConn is a connection that was closed remotely. // For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. -type closedRemoteConn struct { - perspective protocol.Perspective -} +type closedRemoteConn struct{} var _ packetHandler = &closedRemoteConn{} -func newClosedRemoteConn(pers protocol.Perspective) packetHandler { - return &closedRemoteConn{perspective: pers} +func newClosedRemoteConn() packetHandler { + return &closedRemoteConn{} } -func (s *closedRemoteConn) handlePacket(receivedPacket) {} -func (s *closedRemoteConn) shutdown() {} -func (s *closedRemoteConn) destroy(error) {} -func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } +func (c *closedRemoteConn) handlePacket(receivedPacket) {} +func (c *closedRemoteConn) destroy(error) {} +func (c *closedRemoteConn) closeWithTransportError(TransportErrorCode) {} diff --git a/closed_conn_test.go b/closed_conn_test.go index 40337998c..d36d664b6 100644 --- a/closed_conn_test.go +++ b/closed_conn_test.go @@ -3,7 +3,6 @@ package quic import ( "net" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" . "github.com/onsi/ginkgo/v2" @@ -11,20 +10,9 @@ import ( ) var _ = Describe("Closed local connection", func() { - It("tells its perspective", func() { - conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger) - Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) - // stop the connection - conn.shutdown() - }) - It("repeats the packet containing the CONNECTION_CLOSE frame", func() { written := make(chan net.Addr, 1) - conn := newClosedLocalConn( - func(addr net.Addr, _ packetInfo) { written <- addr }, - protocol.PerspectiveClient, - utils.DefaultLogger, - ) + conn := newClosedLocalConn(func(addr net.Addr, _ packetInfo) { written <- addr }, utils.DefaultLogger) addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} for i := 1; i <= 20; i++ { conn.handlePacket(receivedPacket{remoteAddr: addr}) diff --git a/codecov.yml b/codecov.yml index a24c7a15e..59e4b58f6 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,6 +5,8 @@ coverage: - interop/ - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go + - internal/testdata + - testutils/ - fuzzing/ - metrics/ status: diff --git a/config.go b/config.go index 501ed1a07..414404883 100644 --- a/config.go +++ b/config.go @@ -2,7 +2,6 @@ package quic import ( "fmt" - "net" "time" "github.com/refraction-networking/uquic/internal/protocol" @@ -49,16 +48,6 @@ func validateConfig(config *Config) error { return nil } -// populateServerConfig populates fields in the quic.Config with their default values, if none are set -// it may be called with nil -func populateServerConfig(config *Config) *Config { - config = populateConfig(config) - if config.RequireAddressValidation == nil { - config.RequireAddressValidation = func(net.Addr) bool { return false } - } - return config -} - // populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateConfig(config *Config) *Config { @@ -111,7 +100,6 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index e0eef4304..b34575c0b 100644 --- a/config_test.go +++ b/config_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "reflect" "time" @@ -23,7 +22,7 @@ var _ = Describe("Config", func() { }) It("validates a config with normal values", func() { - conf := populateServerConfig(&Config{ + conf := populateConfig(&Config{ MaxIncomingStreams: 5, MaxStreamReceiveWindow: 10, }) @@ -69,7 +68,7 @@ var _ = Describe("Config", func() { case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": // Can't compare functions. case "Versions": - f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) + f.Set(reflect.ValueOf([]Version{1, 2, 3})) case "ConnectionIDLength": f.Set(reflect.ValueOf(8)) case "ConnectionIDGenerator": @@ -118,19 +117,16 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool + var calledAllowConnectionWindowIncrease, calledTracer bool c1 := &Config{ GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, - RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, } c2 := c1.Clone() - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) _, err := c2.GetConfigForClient(&ClientHelloInfo{}) @@ -145,29 +141,15 @@ var _ = Describe("Config", func() { }) It("returns a copy", func() { - c1 := &Config{ - MaxIncomingStreams: 100, - RequireAddressValidation: func(net.Addr) bool { return true }, - } + c1 := &Config{MaxIncomingStreams: 100} c2 := c1.Clone() c2.MaxIncomingStreams = 200 - c2.RequireAddressValidation = func(net.Addr) bool { return false } Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) }) }) Context("populating", func() { - It("populates function fields", func() { - var calledAddrValidation bool - c1 := &Config{} - c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } - c2 := populateConfig(c1) - c2.RequireAddressValidation(&net.UDPAddr{}) - Expect(calledAddrValidation).To(BeTrue()) - }) - It("copies non-function fields", func() { c := configWithNonZeroNonFunctionFields() Expect(populateConfig(c)).To(Equal(c)) @@ -186,10 +168,5 @@ var _ = Describe("Config", func() { Expect(c.DisablePathMTUDiscovery).To(BeFalse()) Expect(c.GetConfigForClient).To(BeNil()) }) - - It("populates empty fields with default values, for the server", func() { - c := populateServerConfig(&Config{}) - Expect(c.RequireAddressValidation).ToNot(BeNil()) - }) }) }) diff --git a/conn_id_generator.go b/conn_id_generator.go index aab6bb422..5f64b8a5b 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -5,7 +5,6 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" ) @@ -20,7 +19,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) + replaceWithClosed func([]protocol.ConnectionID, []byte) queueControlFrame func(wire.Frame) } @@ -31,7 +30,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), + replaceWithClosed func([]protocol.ConnectionID, []byte), queueControlFrame func(wire.Frame), generator ConnectionIDGenerator, ) *connIDGenerator { @@ -60,7 +59,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error { // transport parameter. // We currently don't send the preferred_address transport parameter, // so we can issue (limit - 1) connection IDs. - for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ { + for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ { if err := m.issueNewConnID(); err != nil { return err } @@ -127,7 +126,7 @@ func (m *connIDGenerator) RemoveAll() { } } -func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { +func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, *m.initialClientDestConnID) @@ -135,5 +134,5 @@ func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, pers, connClose) + m.replaceWithClosed(connIDs, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 36e5db655..2e8a27856 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -41,9 +41,7 @@ var _ = Describe("Connection ID Generator", func() { connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { - replacedWithClosed = append(replacedWithClosed, cs...) - }, + func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, &protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()}, ) @@ -177,7 +175,7 @@ var _ = Describe("Connection ID Generator", func() { It("replaces with a closed connection for all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(4)) - g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar")) + g.ReplaceWithClosed([]byte("foobar")) Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID)) Expect(replacedWithClosed).To(ContainElement(initialConnID)) diff --git a/conn_id_manager.go b/conn_id_manager.go index 49ca79ec8..86f013c9a 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -155,7 +155,7 @@ func (h *connIDManager) updateConnectionID() { h.queueControlFrame(&wire.RetireConnectionIDFrame{ SequenceNumber: h.activeSequenceNumber, }) - h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber) + h.highestRetired = max(h.highestRetired, h.activeSequenceNumber) if h.activeStatelessResetToken != nil { h.removeStatelessResetToken(*h.activeStatelessResetToken) } diff --git a/connection.go b/connection.go index c7e0d4a35..cdf53169b 100644 --- a/connection.go +++ b/connection.go @@ -26,7 +26,7 @@ import ( ) type unpacker interface { - UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) + UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } @@ -94,7 +94,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) + ReplaceWithClosed([]protocol.ConnectionID, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -107,7 +107,7 @@ type closeError struct { type errCloseForRecreating struct { nextPacketNumber protocol.PacketNumber - nextVersion protocol.VersionNumber + nextVersion protocol.Version } func (e *errCloseForRecreating) Error() string { @@ -129,7 +129,7 @@ type connection struct { srcConnIDLen int perspective protocol.Perspective - version protocol.VersionNumber + version protocol.Version config *Config conn sendConn @@ -178,6 +178,7 @@ type connection struct { earlyConnReadyChan chan struct{} sentFirstPacket bool + droppedInitialKeys bool handshakeComplete bool handshakeConfirmed bool @@ -236,7 +237,7 @@ var newConnection = func( tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -308,7 +309,7 @@ var newConnection = func( RetrySourceConnectionID: retrySrcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } @@ -349,7 +350,7 @@ var newClientConnection = func( tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, - v protocol.VersionNumber, + v protocol.Version, ) quicConn { s := &connection{ conn: conn, @@ -417,7 +418,7 @@ var newClientConnection = func( InitialSourceConnectionID: srcConnID, } if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + params.MaxDatagramFrameSize = wire.MaxDatagramSize } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } @@ -457,7 +458,7 @@ func (s *connection) preSetup() { s.handshakeStream = newCryptoStream() s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue() - s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) + s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams) s.rttStats = &utils.RTTStats{} s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ByteCount(s.config.InitialConnectionReceiveWindow), @@ -524,6 +525,9 @@ func (s *connection) run() error { runLoop: for { + if s.framer.QueuedTooManyControlFrames() { + s.closeLocal(&qerr.TransportError{ErrorCode: InternalError}) + } // Close immediately if requested select { case closeErr = <-s.closeChan: @@ -633,7 +637,7 @@ runLoop: sendQueueAvailable = s.sendQueue.Available() continue } - if err := s.triggerSending(); err != nil { + if err := s.triggerSending(now); err != nil { s.closeLocal(err) } if s.sendQueue.WouldBlock() { @@ -685,7 +689,7 @@ func (s *connection) ConnectionState() ConnectionState { // Time when the connection should time out func (s *connection) nextIdleTimeoutTime() time.Time { - idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3) + idleTimeout := max(s.idleTimeout, s.rttStats.PTO(true)*3) return s.idleTimeoutStartTime().Add(idleTimeout) } @@ -695,7 +699,7 @@ func (s *connection) nextKeepAliveTime() time.Time { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { return time.Time{} } - keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) + keepAliveInterval := max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) return s.lastPacketReceivedTime.Add(keepAliveInterval) } @@ -735,6 +739,10 @@ func (s *connection) handleHandshakeComplete() error { s.connIDManager.SetHandshakeComplete() s.connIDGenerator.SetHandshakeComplete() + if s.tracer != nil && s.tracer.ChoseALPN != nil { + s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol) + } + // The server applies transport parameters right away, but the client side has to wait for handshake completion. // During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets. if s.perspective == protocol.PerspectiveClient { @@ -780,7 +788,7 @@ func (s *connection) handleHandshakeConfirmed() error { if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount } - s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize)) + s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize)) } return nil } @@ -808,14 +816,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) break @@ -830,7 +838,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion } - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), dropReason) } s.logger.Debugf("error parsing packet: %s", err) break @@ -839,7 +847,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if hdr.Version != s.version { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) break @@ -898,7 +906,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) + s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate) } return false } @@ -944,7 +952,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) + s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) return false @@ -952,7 +960,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) + s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } return false } @@ -968,10 +976,10 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) packet.hdr.Log(s.logger) } - if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { + if pn := packet.hdr.PacketNumber; s.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate) } return false } @@ -987,7 +995,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P switch err { case handshake.ErrKeysDropped: if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) case handshake.ErrKeysNotYetAvailable: @@ -1003,7 +1011,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) default: @@ -1011,7 +1019,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) } else { @@ -1026,14 +1034,14 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") return false @@ -1041,7 +1049,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") return false @@ -1056,7 +1064,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) + s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") return false @@ -1089,7 +1097,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) } return } @@ -1097,7 +1105,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) return @@ -1106,7 +1114,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. // This might be a packet sent by an attacker, or it was corrupted. @@ -1148,7 +1156,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( if !s.receivedFirstPacket { s.receivedFirstPacket = true if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { - var clientVersions, serverVersions []protocol.VersionNumber + var clientVersions, serverVersions []protocol.Version switch s.perspective { case protocol.PerspectiveClient: clientVersions = s.config.Versions @@ -1185,7 +1193,8 @@ func (s *connection) handleUnpackedLongHeaderPacket( } } - if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake { + if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the server side, Initial keys are dropped as soon as the first Handshake packet is received. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -1347,7 +1356,7 @@ func (s *connection) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } } } @@ -1526,7 +1535,7 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr } func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { - if f.Length(s.version) > protocol.MaxDatagramFrameSize { + if f.Length(s.version) > wire.MaxDatagramSize { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "DATAGRAM frame too large", @@ -1572,13 +1581,6 @@ func (s *connection) closeRemote(e error) { }) } -// Close the connection. It sends a NO_ERROR application error. -// It waits until the run loop has stopped before returning -func (s *connection) shutdown() { - s.closeLocal(nil) - <-s.ctx.Done() -} - func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { s.closeLocal(&qerr.ApplicationError{ ErrorCode: code, @@ -1588,6 +1590,11 @@ func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) erro return nil } +func (s *connection) closeWithTransportError(code TransportErrorCode) { + s.closeLocal(&qerr.TransportError{ErrorCode: code}) + <-s.ctx.Done() +} + func (s *connection) handleCloseError(closeErr *closeError) { e := closeErr.err if e == nil { @@ -1632,7 +1639,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) + s.connIDGenerator.ReplaceWithClosed(nil) return } if closeErr.immediate { @@ -1649,7 +1656,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) + s.connIDGenerator.ReplaceWithClosed(connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { @@ -1661,6 +1668,7 @@ func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) erro //nolint:exhaustive // only Initial and 0-RTT need special treatment switch encLevel { case protocol.EncryptionInitial: + s.droppedInitialKeys = true s.cryptoStreamHandler.DiscardInitialKeys() case protocol.Encryption0RTT: s.streamsMap.ResetFor0RTT() @@ -1755,7 +1763,7 @@ func (s *connection) applyTransportParameters() { params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) - s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) + s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) @@ -1771,9 +1779,8 @@ func (s *connection) applyTransportParameters() { } } -func (s *connection) triggerSending() error { +func (s *connection) triggerSending(now time.Time) error { s.pacingDeadline = time.Time{} - now := time.Now() sendMode := s.sentPacketHandler.SendMode(now) //nolint:exhaustive // No need to handle pacing limited here. @@ -1805,7 +1812,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOHandshake: if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil { return err @@ -1814,7 +1821,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) case ackhandler.SendPTOAppData: if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil { return err @@ -1823,7 +1830,7 @@ func (s *connection) triggerSending() error { s.scheduleSending() return nil } - return s.triggerSending() + return s.triggerSending(now) default: return fmt.Errorf("BUG: invalid send mode %d", sendMode) } @@ -1992,7 +1999,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if packet == nil { return nil } - return s.sendPackedCoalescedPacket(packet, ecn, time.Now()) + return s.sendPackedCoalescedPacket(packet, ecn, now) } ecn := s.sentPacketHandler.ECNMode(true) @@ -2078,7 +2085,8 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn prot largestAcked = p.ack.LargestAcked() } s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false) - if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { + if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake && + !s.droppedInitialKeys { // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // See Section 4.9.1 of RFC 9001. if err := s.dropEncryptionLevel(protocol.EncryptionInitial); err != nil { @@ -2309,7 +2317,7 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if s.tracer != nil && s.tracer.DroppedPacket != nil { - s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) + s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return @@ -2347,21 +2355,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } } -func (s *connection) SendMessage(p []byte) error { +func (s *connection) SendDatagram(p []byte) error { if !s.supportsDatagrams() { return errors.New("datagram support disabled") } f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { - return errors.New("message too large") + return &DatagramTooLargeError{ + PeerMaxDatagramFrameSize: int64(s.peerParams.MaxDatagramFrameSize), + } } f.Data = make([]byte, len(p)) copy(f.Data, p) - return s.datagramQueue.AddAndWait(f) + return s.datagramQueue.Add(f) } -func (s *connection) ReceiveMessage(ctx context.Context) ([]byte, error) { +func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) { if !s.config.EnableDatagrams { return nil, errors.New("datagram support disabled") } @@ -2376,11 +2386,7 @@ func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *connection) getPerspective() protocol.Perspective { - return s.perspective -} - -func (s *connection) GetVersion() protocol.VersionNumber { +func (s *connection) GetVersion() protocol.Version { return s.version } diff --git a/connection_test.go b/connection_test.go index 87cf05c95..9947e5860 100644 --- a/connection_test.go +++ b/connection_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "net/netip" "runtime/pprof" "strings" "time" @@ -21,10 +22,10 @@ import ( mocklogging "github.com/refraction-networking/uquic/internal/mocks/logging" "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/testutils" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" "github.com/refraction-networking/uquic/logging" + "github.com/refraction-networking/uquic/testutils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -76,7 +77,7 @@ var _ = Describe("Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ContainElement(srcConnID)) if len(connIDs) > 1 { Expect(connIDs).To(ContainElement(clientDestConnID)) @@ -84,8 +85,8 @@ var _ = Describe("Connection", func() { }) } - expectAppendPacket := func(packer *MockPacker, p shortHeaderPacket, b []byte) *gomock.Call { - return packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), Version1).DoAndReturn(func(buf *packetBuffer, _ protocol.ByteCount, _ protocol.VersionNumber) (shortHeaderPacket, error) { + expectAppendPacket := func(packer *MockPacker, p shortHeaderPacket, b []byte) *MockPackerAppendPacketCall { + return packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), Version1).DoAndReturn(func(buf *packetBuffer, _ protocol.ByteCount, _ protocol.Version) (shortHeaderPacket, error) { buf.Data = append(buf.Data, b...) return p, nil }) @@ -118,7 +119,7 @@ var _ = Describe("Connection", func() { srcConnID, &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, - populateServerConfig(&Config{DisablePathMTUDiscovery: true}), + populateConfig(&Config{DisablePathMTUDiscovery: true}), &tls.Config{}, tokenGenerator, false, @@ -346,7 +347,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -375,7 +376,7 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) }) cryptoSetup.EXPECT().Close() @@ -410,7 +411,7 @@ var _ = Describe("Connection", func() { It("tells its versions", func() { conn.version = 4242 - Expect(conn.GetVersion()).To(Equal(protocol.VersionNumber(4242))) + Expect(conn.GetVersion()).To(Equal(protocol.Version(4242))) }) Context("closing", func() { @@ -450,7 +451,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("connection close")...) - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.Version) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeEquivalentTo(qerr.NoError)) Expect(e.ErrorMessage).To(BeEmpty()) return &coalescedPacket{buffer: buffer}, nil @@ -465,7 +466,7 @@ var _ = Describe("Connection", func() { }), tracer.EXPECT().Close(), ) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(areConnsRunning).Should(BeFalse()) Expect(conn.Context().Done()).To(BeClosed()) }) @@ -479,8 +480,8 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() - conn.shutdown() + conn.CloseWithError(0, "") + conn.CloseWithError(0, "") Eventually(areConnsRunning).Should(BeFalse()) Expect(conn.Context().Done()).To(BeClosed()) }) @@ -551,29 +552,6 @@ var _ = Describe("Connection", func() { } }) - It("cancels the context when the run loop exists", func() { - runConn() - streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - cryptoSetup.EXPECT().Close() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) - returned := make(chan struct{}) - go func() { - defer GinkgoRecover() - ctx := conn.Context() - <-ctx.Done() - Expect(ctx.Err()).To(MatchError(context.Canceled)) - close(returned) - }() - Consistently(returned).ShouldNot(BeClosed()) - mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - tracer.EXPECT().ClosedConnection(gomock.Any()) - tracer.EXPECT().Close() - conn.shutdown() - Eventually(returned).Should(BeClosed()) - Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled)) - }) - It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { unpacker := NewMockUnpacker(mockCtrl) conn.handshakeConfirmed = true @@ -581,7 +559,7 @@ var _ = Describe("Connection", func() { runConn() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() b, err := wire.AppendShortHeader(nil, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne) Expect(err).ToNot(HaveOccurred()) @@ -687,7 +665,7 @@ var _ = Describe("Connection", func() { Version: conn.version, Token: []byte("foobar"), }}, make([]byte, 16) /* Retry integrity tag */) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -697,7 +675,7 @@ var _ = Describe("Connection", func() { protocol.ArbitraryLenConnectionID(destConnID.Bytes()), conn.config.Versions, ) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(receivedPacket{ data: b, buffer: getPacketBuffer(), @@ -713,7 +691,7 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen2, }, nil) p.data[0] ^= 0x40 // unset the QUIC bit - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -725,12 +703,12 @@ var _ = Describe("Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnsupportedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("drops packets with an unsupported version", func() { - origSupportedVersions := make([]protocol.VersionNumber, len(protocol.SupportedVersions)) + origSupportedVersions := make([]protocol.Version, len(protocol.SupportedVersions)) copy(origSupportedVersions, protocol.SupportedVersions) defer func() { protocol.SupportedVersions = origSupportedVersions @@ -746,7 +724,7 @@ var _ = Describe("Connection", func() { }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -812,7 +790,7 @@ var _ = Describe("Connection", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) conn.receivedPacketHandler = rph - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) Expect(conn.handlePacketImpl(packet)).To(BeFalse()) }) @@ -838,7 +816,7 @@ var _ = Describe("Connection", func() { PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropPayloadDecryptError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) conn.handlePacket(p) Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return @@ -956,7 +934,7 @@ var _ = Describe("Connection", func() { runErr <- conn.run() }() expectReplaceWithClosed() - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, gomock.Any(), logging.PacketDropHeaderParseError) conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return @@ -964,7 +942,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1029,7 +1007,7 @@ var _ = Describe("Connection", func() { Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. p2 := getLongHeaderPacket(hdr2, nil) - tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) + tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) Expect(conn.handlePacketImpl(p2)).To(BeFalse()) }) @@ -1088,7 +1066,7 @@ var _ = Describe("Connection", func() { It("cuts packets to the right length", func() { hdrLen, packet := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1105,7 +1083,7 @@ var _ = Describe("Connection", func() { It("handles coalesced packets", func() { hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) packet1.ecn = protocol.ECT1 - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1117,7 +1095,7 @@ var _ = Describe("Connection", func() { }, nil }) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1144,7 +1122,7 @@ var _ = Describe("Connection", func() { hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) gomock.InOrder( unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).Return(nil, handshake.ErrKeysNotYetAvailable), - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1170,7 +1148,7 @@ var _ = Describe("Connection", func() { wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) Expect(srcConnID).ToNot(Equal(wrongConnID)) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { Expect(data).To(HaveLen(hdrLen1 + 456 - 3)) return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, @@ -1184,7 +1162,7 @@ var _ = Describe("Connection", func() { // don't EXPECT any more calls to unpacker.UnpackLongHeader() gomock.InOrder( tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet1.data)), gomock.Any(), gomock.Any()), - tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), ) packet1.data = append(packet1.data, packet2.data...) Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) @@ -1219,7 +1197,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(connDone).Should(BeClosed()) }) @@ -1270,7 +1248,6 @@ var _ = Describe("Connection", func() { sph.EXPECT().ECNMode(true).AnyTimes() runConn() packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() - conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) @@ -1281,7 +1258,10 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck) sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes() done := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) { close(done) }) + packer.EXPECT().PackCoalescedPacket(true, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { + close(done) + return nil, nil + }) runConn() conn.scheduleSending() Eventually(done).Should(BeClosed()) @@ -1323,7 +1303,7 @@ var _ = Describe("Connection", func() { Context(fmt.Sprintf("sending %s probe packets", encLevel), func() { var sendMode ackhandler.SendMode - var getFrame func(protocol.ByteCount, protocol.VersionNumber) wire.Frame + var getFrame func(protocol.ByteCount, protocol.Version) wire.Frame BeforeEach(func() { //nolint:exhaustive @@ -1420,7 +1400,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1809,7 +1789,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1915,7 +1895,7 @@ var _ = Describe("Connection", func() { ) sent := make(chan struct{}) - mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) { close(sent) }) + mconn.EXPECT().Write([]byte("foobar"), uint16(0), protocol.ECT1).Do(func([]byte, uint16, protocol.ECN) error { close(sent); return nil }) go func() { defer GinkgoRecover() @@ -1935,7 +1915,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -1944,6 +1924,7 @@ var _ = Describe("Connection", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + tracer.EXPECT().ChoseALPN(gomock.Any()) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode(gomock.Any()).AnyTimes() @@ -1952,6 +1933,7 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Retire(clientDestConnID) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() + cryptoSetup.EXPECT().ConnectionState() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) Expect(conn.handleHandshakeComplete()).To(Succeed()) @@ -1964,8 +1946,10 @@ var _ = Describe("Connection", func() { connRunner.EXPECT().Retire(clientDestConnID) conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial) tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake) + tracer.EXPECT().ChoseALPN(gomock.Any()) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) + cryptoSetup.EXPECT().ConnectionState() handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx).ShouldNot(BeClosed()) @@ -2019,10 +2003,11 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ChoseALPN(gomock.Any()) conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *packetBuffer, _ protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { + packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(_ *packetBuffer, _ protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, v) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -2039,6 +2024,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() + cryptoSetup.EXPECT().ConnectionState() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handleHandshakeComplete()).To(Succeed()) conn.run() @@ -2051,7 +2037,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2065,13 +2051,11 @@ var _ = Describe("Connection", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - expectReplaceWithClosed() - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.destroy(nil) Eventually(done).Should(BeClosed()) Expect(context.Cause(conn.Context())).To(MatchError(context.Canceled)) }) @@ -2157,7 +2141,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2165,7 +2149,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(5 * time.Second) conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2178,7 +2162,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(time.Hour) conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { close(sent) return nil, nil }) @@ -2211,7 +2195,7 @@ var _ = Describe("Connection", func() { pto := conn.rttStats.PTO(true) conn.lastPacketReceivedTime = time.Now() sentPingTimeChan := make(chan time.Time) - packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).Do(func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) { sentPingTimeChan <- time.Now() return nil, nil }) @@ -2282,7 +2266,7 @@ var _ = Describe("Connection", func() { conn.config.HandshakeIdleTimeout = 9999 * time.Second conn.config.MaxIdleTimeout = 9999 * time.Second conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) - packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.VersionNumber) (*coalescedPacket, error) { + packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(e *qerr.ApplicationError, _ protocol.ByteCount, _ protocol.Version) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeZero()) return &coalescedPacket{buffer: getPacketBuffer()}, nil }) @@ -2308,7 +2292,7 @@ var _ = Describe("Connection", func() { expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2349,6 +2333,7 @@ var _ = Describe("Connection", func() { ) cryptoSetup.EXPECT().Close() gomock.InOrder( + tracer.EXPECT().ChoseALPN(gomock.Any()), tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { Expect(e).To(MatchError(&IdleTimeoutError{})) @@ -2364,6 +2349,7 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) + cryptoSetup.EXPECT().ConnectionState() Expect(conn.handleHandshakeComplete()).To(Succeed()) err := conn.run() nerr, ok := err.(net.Error) @@ -2393,7 +2379,7 @@ var _ = Describe("Connection", func() { mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) }) @@ -2428,7 +2414,7 @@ var _ = Describe("Connection", func() { It("stores up to MaxConnUnprocessedPackets packets", func() { done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.ByteCount, logging.PacketDropReason) { + tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) { close(done) }) // Nothing here should block @@ -2573,7 +2559,7 @@ var _ = Describe("Client Connection", func() { It("changes the connection ID when receiving the first packet from the server", func() { unpacker := NewMockUnpacker(mockCtrl) - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, data []byte, _ protocol.Version) (*unpackedPacket, error) { return &unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, hdr: &wire.ExtendedHeader{Header: *hdr}, @@ -2582,7 +2568,10 @@ var _ = Describe("Client Connection", func() { }) conn.unpacker = unpacker done := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) { close(done) }) + packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { + close(done) + return nil, nil + }) newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ @@ -2606,11 +2595,11 @@ var _ = Describe("Client Connection", func() { // make sure the go routine returns packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) cryptoSetup.EXPECT().Close() - connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()) mconn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) time.Sleep(200 * time.Millisecond) }) @@ -2626,7 +2615,7 @@ var _ = Describe("Client Connection", func() { }) Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) // now receive a packet with the original source connection ID - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.VersionNumber) (*unpackedPacket, error) { + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any(), conn.version).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte, _ protocol.Version) (*unpackedPacket, error) { return &unpackedPacket{ hdr: &wire.ExtendedHeader{Header: *hdr}, data: []byte{0}, @@ -2672,9 +2661,10 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() running := make(chan struct{}) - cryptoSetup.EXPECT().StartHandshake().Do(func() { + cryptoSetup.EXPECT().StartHandshake().Do(func() error { close(running) conn.closeLocal(errors.New("early error")) + return nil }) cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) cryptoSetup.EXPECT().Close() @@ -2705,7 +2695,7 @@ var _ = Describe("Client Connection", func() { }) Context("handling Version Negotiation", func() { - getVNP := func(versions ...protocol.VersionNumber) receivedPacket { + getVNP := func(versions ...protocol.Version) receivedPacket { b := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), protocol.ArbitraryLenConnectionID(destConnID.Bytes()), @@ -2723,7 +2713,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) - conn.config.Versions = []protocol.VersionNumber{1234, 4321} + conn.config.Versions = []protocol.Version{1234, 4321} errChan := make(chan error, 1) start := make(chan struct{}) go func() { @@ -2738,8 +2728,8 @@ var _ = Describe("Client Connection", func() { connRunner.EXPECT().Remove(srcConnID) tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.VersionNumber) { Expect(versions).To(And( - ContainElement(protocol.VersionNumber(4321)), - ContainElement(protocol.VersionNumber(1337)), + ContainElement(protocol.Version(4321)), + ContainElement(protocol.Version(1337)), )) }) cryptoSetup.EXPECT().Close() @@ -2750,7 +2740,7 @@ var _ = Describe("Client Connection", func() { Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&errCloseForRecreating{})) recreateErr := err.(*errCloseForRecreating) - Expect(recreateErr.nextVersion).To(Equal(protocol.VersionNumber(4321))) + Expect(recreateErr.nextVersion).To(Equal(protocol.Version(4321))) Expect(recreateErr.nextPacketNumber).To(Equal(protocol.PacketNumber(128))) }) @@ -2784,14 +2774,14 @@ var _ = Describe("Client Connection", func() { It("ignores Version Negotiation packets that offer the current version", func() { p := getVNP(conn.version) - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores unparseable Version Negotiation packets", func() { p := getVNP(conn.version) p.data = p.data[:len(p.data)-2] - tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2840,14 +2830,14 @@ var _ = Describe("Client Connection", func() { It("ignores Retry packets after receiving a regular packet", func() { conn.receivedFirstPacket = true p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores Retry packets if the server didn't change the connection ID", func() { retryHdr.SrcConnectionID = destConnID p := getPacket(retryHdr, getRetryTag(retryHdr)) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -2855,7 +2845,7 @@ var _ = Describe("Client Connection", func() { tag := getRetryTag(retryHdr) tag[0]++ p := getPacket(retryHdr, tag) - tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropPayloadDecryptError) + tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2890,6 +2880,8 @@ var _ = Describe("Client Connection", func() { defer GinkgoRecover() Expect(conn.handleHandshakeComplete()).To(Succeed()) }) + tracer.EXPECT().ChoseALPN(gomock.Any()).MaxTimes(1) + cryptoSetup.EXPECT().ConnectionState().MaxTimes(1) errChan <- conn.run() close(errChan) }() @@ -2897,7 +2889,7 @@ var _ = Describe("Client Connection", func() { expectClose := func(applicationClose, errored bool) { if !closed && !errored { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()) if applicationClose { packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } else { @@ -2914,7 +2906,7 @@ var _ = Describe("Client Connection", func() { } AfterEach(func() { - conn.shutdown() + conn.CloseWithError(0, "") Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(errChan).Should(BeClosed()) }) @@ -2924,8 +2916,8 @@ var _ = Describe("Client Connection", func() { OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, PreferredAddress: &wire.PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42), + IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13), ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, @@ -2958,7 +2950,7 @@ var _ = Describe("Client Connection", func() { Eventually(processed).Should(BeClosed()) // close first expectClose(true, false) - conn.shutdown() + conn.CloseWithError(0, "") // then check. Avoids race condition when accessing idleTimeout Expect(conn.idleTimeout).To(Equal(18 * time.Second)) }) @@ -3153,7 +3145,7 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) @@ -3168,7 +3160,7 @@ var _ = Describe("Client Connection", func() { PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) - tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, p.Size(), gomock.Any()) + tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), gomock.Any()) Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) @@ -3176,7 +3168,7 @@ var _ = Describe("Client Connection", func() { // the connection to immediately break down It("fails on Initial-level ACK for unsent packet", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{ack}, protocol.PerspectiveServer, conn.version) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) @@ -3188,7 +3180,7 @@ var _ = Describe("Client Connection", func() { IsApplicationError: true, ReasonPhrase: "mitm attacker", } - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, destConnID, []wire.Frame{connCloseFrame}, protocol.PerspectiveServer, conn.version) tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) @@ -3206,8 +3198,8 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ReceivedRetry(gomock.Any()) conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) - initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) - tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) + initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.connIDManager.Get(), nil, protocol.PerspectiveServer, conn.version) + tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), gomock.Any()) Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) }) diff --git a/crypto_stream.go b/crypto_stream.go index 4ad097ce5..c0f26d435 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -6,7 +6,6 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/qerr" - "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" ) @@ -56,7 +55,7 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // could e.g. be a retransmission return nil } - s.highestOffset = utils.Max(s.highestOffset, highestOffset) + s.highestOffset = max(s.highestOffset, highestOffset) if err := s.queue.Push(f.Data, f.Offset, nil); err != nil { return err } @@ -99,7 +98,7 @@ func (s *cryptoStreamImpl) HasData() bool { func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame { f := &wire.CryptoFrame{Offset: s.writeOffset} - n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) + n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf))) f.Data = s.writeBuf[:n] s.writeBuf = s.writeBuf[n:] s.writeOffset += n diff --git a/datagram_queue.go b/datagram_queue.go index fbd6a251e..979f3b866 100644 --- a/datagram_queue.go +++ b/datagram_queue.go @@ -4,14 +4,20 @@ import ( "context" "sync" - "github.com/refraction-networking/uquic/internal/protocol" "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/internal/utils/ringbuffer" "github.com/refraction-networking/uquic/internal/wire" ) +const ( + maxDatagramSendQueueLen = 32 + maxDatagramRcvQueueLen = 128 +) + type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - nextFrame *wire.DatagramFrame + sendMx sync.Mutex + sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame] + sent chan struct{} // used to notify Add that a datagram was dequeued rcvMx sync.Mutex rcvQueue [][]byte @@ -22,60 +28,65 @@ type datagramQueue struct { hasData func() - dequeued chan struct{} - logger utils.Logger } func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvd: make(chan struct{}, 1), - dequeued: make(chan struct{}), - closed: make(chan struct{}), - logger: logger, + hasData: hasData, + rcvd: make(chan struct{}, 1), + sent: make(chan struct{}, 1), + closed: make(chan struct{}), + logger: logger, } } -// AddAndWait queues a new DATAGRAM frame for sending. -// It blocks until the frame has been dequeued. -func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { - select { - case h.sendQueue <- f: - h.hasData() - case <-h.closed: - return h.closeErr - } +// Add queues a new DATAGRAM frame for sending. +// Up to 32 DATAGRAM frames will be queued. +// Once that limit is reached, Add blocks until the queue size has reduced. +func (h *datagramQueue) Add(f *wire.DatagramFrame) error { + h.sendMx.Lock() - select { - case <-h.dequeued: - return nil - case <-h.closed: - return h.closeErr + for { + if h.sendQueue.Len() < maxDatagramSendQueueLen { + h.sendQueue.PushBack(f) + h.sendMx.Unlock() + h.hasData() + return nil + } + select { + case <-h.sent: // drain the queue so we don't loop immediately + default: + } + h.sendMx.Unlock() + select { + case <-h.closed: + return h.closeErr + case <-h.sent: + } + h.sendMx.Lock() } } // Peek gets the next DATAGRAM frame for sending. // If actually sent out, Pop needs to be called before the next call to Peek. func (h *datagramQueue) Peek() *wire.DatagramFrame { - if h.nextFrame != nil { - return h.nextFrame - } - select { - case h.nextFrame = <-h.sendQueue: - h.dequeued <- struct{}{} - default: + h.sendMx.Lock() + defer h.sendMx.Unlock() + if h.sendQueue.Empty() { return nil } - return h.nextFrame + return h.sendQueue.PeekFront() } func (h *datagramQueue) Pop() { - if h.nextFrame == nil { - panic("datagramQueue BUG: Pop called for nil frame") + h.sendMx.Lock() + defer h.sendMx.Unlock() + _ = h.sendQueue.PopFront() + select { + case h.sent <- struct{}{}: + default: } - h.nextFrame = nil } // HandleDatagramFrame handles a received DATAGRAM frame. @@ -84,7 +95,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { copy(data, f.Data) var queued bool h.rcvMx.Lock() - if len(h.rcvQueue) < protocol.DatagramRcvQueueLen { + if len(h.rcvQueue) < maxDatagramRcvQueueLen { h.rcvQueue = append(h.rcvQueue, data) queued = true select { @@ -94,7 +105,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { } h.rcvMx.Unlock() if !queued && h.logger.Debug() { - h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data)) + h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data)) } } diff --git a/datagram_queue_test.go b/datagram_queue_test.go index f581f4294..3d8efb4e1 100644 --- a/datagram_queue_test.go +++ b/datagram_queue_test.go @@ -3,6 +3,7 @@ package quic import ( "context" "errors" + "time" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/internal/wire" @@ -26,55 +27,65 @@ var _ = Describe("Datagram Queue", func() { }) It("queues a datagram", func() { - done := make(chan struct{}) frame := &wire.DatagramFrame{Data: []byte("foobar")} - go func() { - defer GinkgoRecover() - defer close(done) - Expect(queue.AddAndWait(frame)).To(Succeed()) - }() - - Eventually(queued).Should(HaveLen(1)) - Consistently(done).ShouldNot(BeClosed()) + Expect(queue.Add(frame)).To(Succeed()) + Expect(queued).To(HaveLen(1)) f := queue.Peek() Expect(f.Data).To(Equal([]byte("foobar"))) - Eventually(done).Should(BeClosed()) queue.Pop() Expect(queue.Peek()).To(BeNil()) }) - It("returns the same datagram multiple times, when Pop isn't called", func() { - sent := make(chan struct{}, 1) + It("blocks when the maximum number of datagrams have been queued", func() { + for i := 0; i < maxDatagramSendQueueLen; i++ { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte{0}})).To(Succeed()) + } + errChan := make(chan error, 1) go func() { defer GinkgoRecover() - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) - sent <- struct{}{} - Expect(queue.AddAndWait(&wire.DatagramFrame{Data: []byte("bar")})).To(Succeed()) - sent <- struct{}{} + errChan <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }() + Consistently(errChan, 50*time.Millisecond).ShouldNot(Receive()) + Expect(queue.Peek()).ToNot(BeNil()) + Consistently(errChan, 50*time.Millisecond).ShouldNot(Receive()) + queue.Pop() + Eventually(errChan).Should(Receive(BeNil())) + for i := 1; i < maxDatagramSendQueueLen; i++ { + queue.Pop() + } + f := queue.Peek() + Expect(f).ToNot(BeNil()) + Expect(f.Data).To(Equal([]byte("foobar"))) + }) + + It("returns the same datagram multiple times, when Pop isn't called", func() { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("bar")})).To(Succeed()) - Eventually(queued).Should(HaveLen(1)) + Eventually(queued).Should(HaveLen(2)) f := queue.Peek() Expect(f.Data).To(Equal([]byte("foo"))) - Eventually(sent).Should(Receive()) Expect(queue.Peek()).To(Equal(f)) Expect(queue.Peek()).To(Equal(f)) queue.Pop() - Eventually(func() *wire.DatagramFrame { f = queue.Peek(); return f }).ShouldNot(BeNil()) f = queue.Peek() + Expect(f).ToNot(BeNil()) Expect(f.Data).To(Equal([]byte("bar"))) }) It("closes", func() { + for i := 0; i < maxDatagramSendQueueLen; i++ { + Expect(queue.Add(&wire.DatagramFrame{Data: []byte("foo")})).To(Succeed()) + } errChan := make(chan error, 1) go func() { defer GinkgoRecover() - errChan <- queue.AddAndWait(&wire.DatagramFrame{Data: []byte("foobar")}) + errChan <- queue.Add(&wire.DatagramFrame{Data: []byte("foo")}) }() - - Consistently(errChan).ShouldNot(Receive()) - queue.CloseWithError(errors.New("test error")) - Eventually(errChan).Should(Receive(MatchError("test error"))) + Consistently(errChan, 25*time.Millisecond).ShouldNot(Receive()) + testErr := errors.New("test error") + queue.CloseWithError(testErr) + Eventually(errChan).Should(Receive(MatchError(testErr))) }) }) diff --git a/errors.go b/errors.go index fba74e864..c4ffcf873 100644 --- a/errors.go +++ b/errors.go @@ -61,3 +61,15 @@ func (e *StreamError) Error() string { } return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode) } + +// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent. +type DatagramTooLargeError struct { + PeerMaxDatagramFrameSize int64 +} + +func (e *DatagramTooLargeError) Is(target error) bool { + _, ok := target.(*DatagramTooLargeError) + return ok +} + +func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" } diff --git a/example/client/main.go b/example/client/main.go index 18a09a191..50a100f83 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -1,12 +1,9 @@ package main import ( - "bufio" "bytes" - "context" "crypto/x509" "flag" - "fmt" "io" "log" "net/http" @@ -18,29 +15,16 @@ import ( quic "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/http3" "github.com/refraction-networking/uquic/internal/testdata" - "github.com/refraction-networking/uquic/internal/utils" - "github.com/refraction-networking/uquic/logging" "github.com/refraction-networking/uquic/qlog" ) func main() { - verbose := flag.Bool("v", false, "verbose") quiet := flag.Bool("q", false, "don't print the data") keyLogFile := flag.String("keylog", "", "key log file") insecure := flag.Bool("insecure", false, "skip certificate verification") - enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") flag.Parse() urls := flag.Args() - logger := utils.DefaultLogger - - if *verbose { - logger.SetLogLevel(utils.LogLevelDebug) - } else { - logger.SetLogLevel(utils.LogLevelInfo) - } - logger.SetLogTimeFormat("") - var keyLog io.Writer if len(*keyLogFile) > 0 { f, err := os.Create(*keyLogFile) @@ -57,25 +41,15 @@ func main() { } testdata.AddRootCA(pool) - var qconf quic.Config - if *enableQlog { - qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - filename := fmt.Sprintf("client_%x.qlog", connID) - f, err := os.Create(filename) - if err != nil { - log.Fatal(err) - } - log.Printf("Creating qlog file %s.\n", filename) - return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) - } - } roundTripper := &http3.RoundTripper{ TLSClientConfig: &tls.Config{ RootCAs: pool, InsecureSkipVerify: *insecure, KeyLogWriter: keyLog, }, - QuicConfig: &qconf, + QuicConfig: &quic.Config{ + Tracer: qlog.DefaultTracer, + }, } defer roundTripper.Close() hclient := &http.Client{ @@ -85,13 +59,13 @@ func main() { var wg sync.WaitGroup wg.Add(len(urls)) for _, addr := range urls { - logger.Infof("GET %s", addr) + log.Printf("GET %s", addr) go func(addr string) { rsp, err := hclient.Get(addr) if err != nil { log.Fatal(err) } - logger.Infof("Got response for %s: %#v", addr, rsp) + log.Printf("Got response for %s: %#v", addr, rsp) body := &bytes.Buffer{} _, err = io.Copy(body, rsp.Body) @@ -99,10 +73,9 @@ func main() { log.Fatal(err) } if *quiet { - logger.Infof("Response Body: %d bytes", body.Len()) + log.Printf("Response Body: %d bytes", body.Len()) } else { - logger.Infof("Response Body:") - logger.Infof("%s", body.Bytes()) + log.Printf("Response Body (%d bytes):\n%s", body.Len(), body.Bytes()) } wg.Done() }(addr) diff --git a/example/echo/echo.go b/example/echo/echo.go index e8962a777..77603a0cf 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -37,14 +37,19 @@ func echoServer() error { if err != nil { return err } + defer listener.Close() + conn, err := listener.Accept(context.Background()) if err != nil { return err } + stream, err := conn.AcceptStream(context.Background()) if err != nil { panic(err) } + defer stream.Close() + // Echo through the loggingWriter _, err = io.Copy(loggingWriter{stream}, stream) return err @@ -59,11 +64,13 @@ func clientMain() error { if err != nil { return err } + defer conn.CloseWithError(0, "") stream, err := conn.OpenStreamSync(context.Background()) if err != nil { return err } + defer stream.Close() fmt.Printf("Client: Sending '%s'\n", message) _, err = stream.Write([]byte(message)) diff --git a/example/main.go b/example/main.go index a77af3398..ce3ef51f9 100644 --- a/example/main.go +++ b/example/main.go @@ -1,8 +1,6 @@ package main import ( - "bufio" - "context" "crypto/md5" "errors" "flag" @@ -11,7 +9,6 @@ import ( "log" "mime/multipart" "net/http" - "os" "strconv" "strings" "sync" @@ -21,8 +18,6 @@ import ( quic "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/http3" "github.com/refraction-networking/uquic/internal/testdata" - "github.com/refraction-networking/uquic/internal/utils" - "github.com/refraction-networking/uquic/logging" "github.com/refraction-networking/uquic/qlog" ) @@ -121,7 +116,7 @@ func setupHandler(www string) http.Handler { err = errors.New("couldn't get uploaded file size") } } - utils.DefaultLogger.Infof("Error receiving upload: %#v", err) + log.Printf("Error receiving upload: %#v", err) } io.WriteString(w, `

@@ -139,57 +134,45 @@ func main() { }() // runtime.SetBlockProfileRate(1) - verbose := flag.Bool("v", false, "verbose") bs := binds{} flag.Var(&bs, "bind", "bind to") www := flag.String("www", "", "www data") tcp := flag.Bool("tcp", false, "also listen on TCP") - enableQlog := flag.Bool("qlog", false, "output a qlog (in the same directory)") + key := flag.String("key", "", "TLS key (requires -cert option)") + cert := flag.String("cert", "", "TLS certificate (requires -key option)") flag.Parse() - logger := utils.DefaultLogger - - if *verbose { - logger.SetLogLevel(utils.LogLevelDebug) - } else { - logger.SetLogLevel(utils.LogLevelInfo) - } - logger.SetLogTimeFormat("") - if len(bs) == 0 { bs = binds{"localhost:6121"} } handler := setupHandler(*www) - quicConf := &quic.Config{} - if *enableQlog { - quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - filename := fmt.Sprintf("server_%x.qlog", connID) - f, err := os.Create(filename) - if err != nil { - log.Fatal(err) - } - log.Printf("Creating qlog file %s.\n", filename) - return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(f), f), p, connID) - } - } var wg sync.WaitGroup wg.Add(len(bs)) + var certFile, keyFile string + if *key != "" && *cert != "" { + keyFile = *key + certFile = *cert + } else { + certFile, keyFile = testdata.GetCertificatePaths() + } for _, b := range bs { + fmt.Println("listening on", b) bCap := b go func() { var err error if *tcp { - certFile, keyFile := testdata.GetCertificatePaths() err = http3.ListenAndServe(bCap, certFile, keyFile, handler) } else { server := http3.Server{ - Handler: handler, - Addr: bCap, - QuicConfig: quicConf, + Handler: handler, + Addr: bCap, + QuicConfig: &quic.Config{ + Tracer: qlog.DefaultTracer, + }, } - err = server.ListenAndServeTLS(testdata.GetCertificatePaths()) + err = server.ListenAndServeTLS(certFile, keyFile) } if err != nil { fmt.Println(err) diff --git a/framer.go b/framer.go index 3a3c415ba..e64b43062 100644 --- a/framer.go +++ b/framer.go @@ -15,14 +15,26 @@ type framer interface { HasData() bool QueueControlFrame(wire.Frame) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) Handle0RTTRejection() error + + // QueuedTooManyControlFrames says if the control frame queue exceeded its maximum queue length. + // This is a hack. + // It is easier to implement than propagating an error return value in QueueControlFrame. + // The correct solution would be to queue frames with their respective structs. + // See https://github.com/quic-go/quic-go/issues/4271 for the queueing of stream-related control frames. + QueuedTooManyControlFrames() bool } +const ( + maxPathResponses = 256 + maxControlFrames = 16 << 10 +) + type framerI struct { mutex sync.Mutex @@ -31,8 +43,10 @@ type framerI struct { activeStreams map[protocol.StreamID]struct{} streamQueue ringbuffer.RingBuffer[protocol.StreamID] - controlFrameMutex sync.Mutex - controlFrames []wire.Frame + controlFrameMutex sync.Mutex + controlFrames []wire.Frame + pathResponses []*wire.PathResponseFrame + queuedTooManyControlFrames bool } var _ framer = &framerI{} @@ -52,20 +66,48 @@ func (f *framerI) HasData() bool { return true } f.controlFrameMutex.Lock() - hasData = len(f.controlFrames) > 0 - f.controlFrameMutex.Unlock() - return hasData + defer f.controlFrameMutex.Unlock() + return len(f.controlFrames) > 0 || len(f.pathResponses) > 0 } func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Lock() + defer f.controlFrameMutex.Unlock() + + if pr, ok := frame.(*wire.PathResponseFrame); ok { + // Only queue up to maxPathResponses PATH_RESPONSE frames. + // This limit should be high enough to never be hit in practice, + // unless the peer is doing something malicious. + if len(f.pathResponses) >= maxPathResponses { + return + } + f.pathResponses = append(f.pathResponses, pr) + return + } + // This is a hack. + if len(f.controlFrames) >= maxControlFrames { + f.queuedTooManyControlFrames = true + return + } f.controlFrames = append(f.controlFrames, frame) - f.controlFrameMutex.Unlock() } -func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { - var length protocol.ByteCount +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { f.controlFrameMutex.Lock() + defer f.controlFrameMutex.Unlock() + + var length protocol.ByteCount + // add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet + if len(f.pathResponses) > 0 { + frame := f.pathResponses[0] + frameLen := frame.Length(v) + if frameLen <= maxLen { + frames = append(frames, ackhandler.Frame{Frame: frame}) + length += frameLen + f.pathResponses = f.pathResponses[1:] + } + } + for len(f.controlFrames) > 0 { frame := f.controlFrames[len(f.controlFrames)-1] frameLen := frame.Length(v) @@ -76,10 +118,13 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } - f.controlFrameMutex.Unlock() return frames, length } +func (f *framerI) QueuedTooManyControlFrames() bool { + return f.queuedTooManyControlFrames +} + func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Lock() if _, ok := f.activeStreams[id]; !ok { @@ -89,7 +134,7 @@ func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { startLen := len(frames) var length protocol.ByteCount f.mutex.Lock() diff --git a/framer_test.go b/framer_test.go index 1c4944dad..44e860060 100644 --- a/framer_test.go +++ b/framer_test.go @@ -2,7 +2,8 @@ package quic import ( "bytes" - "math/rand" + + "golang.org/x/exp/rand" "github.com/refraction-networking/uquic/internal/ackhandler" "github.com/refraction-networking/uquic/internal/protocol" @@ -23,7 +24,7 @@ var _ = Describe("Framer", func() { framer framer stream1, stream2 *MockSendStreamI streamGetter *MockStreamGetter - version protocol.VersionNumber + version protocol.Version ) BeforeEach(func() { @@ -108,6 +109,72 @@ var _ = Describe("Framer", func() { Expect(fs).To(HaveLen(2)) Expect(length).To(Equal(ping.Length(version) + ncid.Length(version))) }) + + It("detects when too many frames are queued", func() { + for i := 0; i < maxControlFrames-1; i++ { + framer.QueueControlFrame(&wire.PingFrame{}) + framer.QueueControlFrame(&wire.PingFrame{}) + Expect(framer.QueuedTooManyControlFrames()).To(BeFalse()) + frames, _ := framer.AppendControlFrames([]ackhandler.Frame{}, 1, protocol.Version1) + Expect(frames).To(HaveLen(1)) + Expect(framer.(*framerI).controlFrames).To(HaveLen(i + 1)) + } + framer.QueueControlFrame(&wire.PingFrame{}) + Expect(framer.QueuedTooManyControlFrames()).To(BeFalse()) + Expect(framer.(*framerI).controlFrames).To(HaveLen(maxControlFrames)) + framer.QueueControlFrame(&wire.PingFrame{}) + Expect(framer.QueuedTooManyControlFrames()).To(BeTrue()) + Expect(framer.(*framerI).controlFrames).To(HaveLen(maxControlFrames)) + }) + }) + + Context("handling PATH_RESPONSE frames", func() { + It("packs a single PATH_RESPONSE per packet", func() { + f1 := &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} + f2 := &wire.PathResponseFrame{Data: [8]byte{2, 3, 4, 5, 6, 7, 8, 9}} + cf1 := &wire.DataBlockedFrame{MaximumData: 1337} + cf2 := &wire.HandshakeDoneFrame{} + framer.QueueControlFrame(f1) + framer.QueueControlFrame(f2) + framer.QueueControlFrame(cf1) + framer.QueueControlFrame(cf2) + // the first packet should contain a single PATH_RESPONSE frame, but all the other control frames + Expect(framer.HasData()).To(BeTrue()) + frames, length := framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) + Expect(frames).To(HaveLen(3)) + Expect(frames[0].Frame).To(Equal(f1)) + Expect([]wire.Frame{frames[1].Frame, frames[2].Frame}).To(ContainElement(cf1)) + Expect([]wire.Frame{frames[1].Frame, frames[2].Frame}).To(ContainElement(cf2)) + Expect(length).To(Equal(f1.Length(protocol.Version1) + cf1.Length(protocol.Version1) + cf2.Length(protocol.Version1))) + // the second packet should contain the other PATH_RESPONSE frame + Expect(framer.HasData()).To(BeTrue()) + frames, length = framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(f2)) + Expect(length).To(Equal(f2.Length(protocol.Version1))) + Expect(framer.HasData()).To(BeFalse()) + }) + + It("limits the number of queued PATH_RESPONSE frames", func() { + var pathResponses []*wire.PathResponseFrame + for i := 0; i < 2*maxPathResponses; i++ { + var f wire.PathResponseFrame + rand.Read(f.Data[:]) + pathResponses = append(pathResponses, &f) + framer.QueueControlFrame(&f) + } + for i := 0; i < maxPathResponses; i++ { + Expect(framer.HasData()).To(BeTrue()) + frames, length := framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) + Expect(frames).To(HaveLen(1)) + Expect(frames[0].Frame).To(Equal(pathResponses[i])) + Expect(length).To(Equal(pathResponses[i].Length(protocol.Version1))) + } + Expect(framer.HasData()).To(BeFalse()) + frames, length := framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1) + Expect(frames).To(BeEmpty()) + Expect(length).To(BeZero()) + }) }) Context("popping STREAM frames", func() { @@ -296,7 +363,7 @@ var _ = Describe("Framer", func() { It("pops maximum size STREAM frames", func() { for i := protocol.MinStreamFrameSize; i < 2000; i++ { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id1, DataLenPresent: true, @@ -318,7 +385,7 @@ var _ = Describe("Framer", func() { for i := 2 * protocol.MinStreamFrameSize; i < 2000; i++ { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) streamGetter.EXPECT().GetOrOpenSendStream(id2).Return(stream2, nil) - stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { + stream1.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id2, DataLenPresent: true, @@ -326,7 +393,7 @@ var _ = Describe("Framer", func() { f.Data = make([]byte, f.MaxDataLen(protocol.MinStreamFrameSize, v)) return ackhandler.StreamFrame{Frame: f}, true, false }) - stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { + stream2.EXPECT().popStreamFrame(gomock.Any(), protocol.Version1).DoAndReturn(func(size protocol.ByteCount, v protocol.Version) (ackhandler.StreamFrame, bool, bool) { f := &wire.StreamFrame{ StreamID: id2, DataLenPresent: true, diff --git a/fuzzing/frames/fuzz.go b/fuzzing/frames/fuzz.go index d22d363b6..1d0d47e71 100644 --- a/fuzzing/frames/fuzz.go +++ b/fuzzing/frames/fuzz.go @@ -56,22 +56,23 @@ func Fuzz(data []byte) int { continue } } + validateFrame(f) startLen := len(b) parsedLen := initialLen - len(data) b, err = f.Append(b, version) if err != nil { - panic(fmt.Sprintf("Error writing frame %#v: %s", f, err)) + panic(fmt.Sprintf("error writing frame %#v: %s", f, err)) } frameLen := protocol.ByteCount(len(b) - startLen) if f.Length(version) != frameLen { - panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) + panic(fmt.Sprintf("inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version))) } if sf, ok := f.(*wire.StreamFrame); ok { sf.PutBack() } if frameLen > protocol.ByteCount(parsedLen) { - panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen)) + panic(fmt.Sprintf("serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen)) } } @@ -80,3 +81,52 @@ func Fuzz(data []byte) int { } return 1 } + +func validateFrame(frame wire.Frame) { + switch f := frame.(type) { + case *wire.StreamFrame: + if protocol.ByteCount(len(f.Data)) != f.DataLen() { + panic("STREAM frame: inconsistent data length") + } + case *wire.AckFrame: + if f.DelayTime < 0 { + panic(fmt.Sprintf("invalid ACK delay_time: %s", f.DelayTime)) + } + if f.LargestAcked() < f.LowestAcked() { + panic("ACK: largest acknowledged is smaller than lowest acknowledged") + } + for _, r := range f.AckRanges { + if r.Largest < 0 || r.Smallest < 0 { + panic("ACK range contains a negative packet number") + } + } + if !f.AcksPacket(f.LargestAcked()) { + panic("ACK frame claims that largest acknowledged is not acknowledged") + } + if !f.AcksPacket(f.LowestAcked()) { + panic("ACK frame claims that lowest acknowledged is not acknowledged") + } + _ = f.AcksPacket(100) + _ = f.AcksPacket((f.LargestAcked() + f.LowestAcked()) / 2) + case *wire.NewConnectionIDFrame: + if f.ConnectionID.Len() < 1 || f.ConnectionID.Len() > 20 { + panic(fmt.Sprintf("invalid NEW_CONNECTION_ID frame length: %s", f.ConnectionID)) + } + case *wire.NewTokenFrame: + if len(f.Token) == 0 { + panic("NEW_TOKEN frame with an empty token") + } + case *wire.MaxStreamsFrame: + if f.MaxStreamNum > protocol.MaxStreamCount { + panic("MAX_STREAMS frame with an invalid Maximum Streams value") + } + case *wire.StreamsBlockedFrame: + if f.StreamLimit > protocol.MaxStreamCount { + panic("STREAMS_BLOCKED frame with an invalid Maximum Streams value") + } + case *wire.ConnectionCloseFrame: + if f.IsApplicationError && f.FrameType != 0 { + panic("CONNECTION_CLOSE for an application error containing a frame type") + } + } +} diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 799d2a6c0..e67eaf14d 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -20,9 +20,9 @@ func getRandomData(l int) []byte { } func getVNP(src, dest protocol.ArbitraryLenConnectionID, numVersions int) []byte { - versions := make([]protocol.VersionNumber, numVersions) + versions := make([]protocol.Version, numVersions) for i := 0; i < numVersions; i++ { - versions[i] = protocol.VersionNumber(rand.Uint32()) + versions[i] = protocol.Version(rand.Uint32()) } return wire.ComposeVersionNegotiation(src, dest, versions) } diff --git a/fuzzing/transportparameters/cmd/corpus.go b/fuzzing/transportparameters/cmd/corpus.go index fdbb3d300..b85aa17d4 100644 --- a/fuzzing/transportparameters/cmd/corpus.go +++ b/fuzzing/transportparameters/cmd/corpus.go @@ -3,7 +3,7 @@ package main import ( "log" "math" - "net" + "net/netip" "time" "golang.org/x/exp/rand" @@ -59,11 +59,13 @@ func main() { if rand.Int()%2 == 0 { var token protocol.StatelessResetToken rand.Read(token[:]) + var ip4 [4]byte + rand.Read(ip4[:]) + var ip6 [16]byte + rand.Read(ip6[:]) tp.PreferredAddress = &wire.PreferredAddress{ - IPv4: net.IPv4(uint8(rand.Int()), uint8(rand.Int()), uint8(rand.Int()), uint8(rand.Int())), - IPv4Port: uint16(rand.Int()), - IPv6: net.IP(getRandomData(16)), - IPv6Port: uint16(rand.Int()), + IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), uint16(rand.Int())), + IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), uint16(rand.Int())), ConnectionID: protocol.ParseConnectionID(getRandomData(rand.Intn(21))), StatelessResetToken: token, } diff --git a/fuzzing/transportparameters/fuzz.go b/fuzzing/transportparameters/fuzz.go index 6e645f587..dfca93ad4 100644 --- a/fuzzing/transportparameters/fuzz.go +++ b/fuzzing/transportparameters/fuzz.go @@ -2,6 +2,7 @@ package transportparameters import ( "bytes" + "errors" "fmt" "github.com/refraction-networking/uquic/fuzzing/internal/helper" @@ -26,23 +27,29 @@ func Fuzz(data []byte) int { return fuzzTransportParameters(data[PrefixLen:], helper.NthBit(data[0], 1)) } -func fuzzTransportParameters(data []byte, isServer bool) int { - perspective := protocol.PerspectiveClient - if isServer { - perspective = protocol.PerspectiveServer +func fuzzTransportParameters(data []byte, sentByServer bool) int { + sentBy := protocol.PerspectiveClient + if sentByServer { + sentBy = protocol.PerspectiveServer } tp := &wire.TransportParameters{} - if err := tp.Unmarshal(data, perspective); err != nil { + if err := tp.Unmarshal(data, sentBy); err != nil { return 0 } _ = tp.String() + if err := validateTransportParameters(tp, sentBy); err != nil { + panic(err) + } tp2 := &wire.TransportParameters{} - if err := tp2.Unmarshal(tp.Marshal(perspective), perspective); err != nil { + if err := tp2.Unmarshal(tp.Marshal(sentBy), sentBy); err != nil { fmt.Printf("%#v\n", tp) panic(err) } + if err := validateTransportParameters(tp2, sentBy); err != nil { + panic(err) + } return 1 } @@ -58,3 +65,34 @@ func fuzzTransportParametersForSessionTicket(data []byte) int { } return 1 } + +func validateTransportParameters(tp *wire.TransportParameters, sentBy protocol.Perspective) error { + if sentBy == protocol.PerspectiveClient && tp.StatelessResetToken != nil { + return errors.New("client's transport parameters contained stateless reset token") + } + if tp.MaxIdleTimeout < 0 { + return fmt.Errorf("negative max_idle_timeout: %s", tp.MaxIdleTimeout) + } + if tp.AckDelayExponent > 20 { + return fmt.Errorf("invalid ack_delay_exponent: %d", tp.AckDelayExponent) + } + if tp.MaxUDPPayloadSize < 1200 { + return fmt.Errorf("invalid max_udp_payload_size: %d", tp.MaxUDPPayloadSize) + } + if tp.ActiveConnectionIDLimit < 2 { + return fmt.Errorf("invalid active_connection_id_limit: %d", tp.ActiveConnectionIDLimit) + } + if tp.OriginalDestinationConnectionID.Len() > 20 { + return fmt.Errorf("invalid original_destination_connection_id length: %s", tp.InitialSourceConnectionID) + } + if tp.InitialSourceConnectionID.Len() > 20 { + return fmt.Errorf("invalid initial_source_connection_id length: %s", tp.InitialSourceConnectionID) + } + if tp.RetrySourceConnectionID != nil && tp.RetrySourceConnectionID.Len() > 20 { + return fmt.Errorf("invalid retry_source_connection_id length: %s", tp.RetrySourceConnectionID) + } + if tp.PreferredAddress != nil && tp.PreferredAddress.ConnectionID.Len() > 20 { + return fmt.Errorf("invalid preferred_address connection ID length: %s", tp.PreferredAddress.ConnectionID) + } + return nil +} diff --git a/go.mod b/go.mod index 510caebaf..3c5f73f30 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,14 @@ require ( github.com/onsi/ginkgo/v2 v2.13.0 github.com/onsi/gomega v1.29.0 github.com/quic-go/qpack v0.4.0 - github.com/refraction-networking/utls v1.5.4 + github.com/refraction-networking/utls v1.6.4 go.uber.org/mock v0.4.0 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20231006140011-7918f672742d - golang.org/x/net v0.23.0 - golang.org/x/sync v0.4.0 - golang.org/x/sys v0.18.0 + golang.org/x/crypto v0.22.0 + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f + golang.org/x/net v0.24.0 + golang.org/x/sync v0.7.0 + golang.org/x/sys v0.19.0 + golang.org/x/time v0.5.0 ) require ( @@ -26,10 +27,10 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20231023181126-ff6d637d2a7b // indirect - github.com/klauspost/compress v1.17.2 // indirect - github.com/quic-go/quic-go v0.42.0 // indirect - golang.org/x/mod v0.13.0 // indirect + github.com/klauspost/compress v1.17.4 // indirect + golang.org/x/mod v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.14.0 // indirect + golang.org/x/tools v0.20.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4e1040290..8f6fb2e2a 100644 --- a/go.sum +++ b/go.sum @@ -67,8 +67,8 @@ github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= -github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -97,10 +97,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= -github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= -github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o= -github.com/refraction-networking/utls v1.5.4/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw= +github.com/refraction-networking/utls v1.6.4 h1:aeynTroaYn7y+mFtqv8D0bQ4bw0y9nJHneGxJ7lvRDM= +github.com/refraction-networking/utls v1.6.4/go.mod h1:2VL2xfiqgFAZtJKeUTlf+PSYFs3Eu7km0gCtXJ3m8zs= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -143,18 +141,18 @@ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= -golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -165,8 +163,8 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -177,29 +175,31 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= -golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= +golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= +golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= diff --git a/http3/README.md b/http3/README.md new file mode 100644 index 000000000..a6128516f --- /dev/null +++ b/http3/README.md @@ -0,0 +1,104 @@ +# HTTP/3 + +[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go/http3)](https://pkg.go.dev/github.com/quic-go/quic-go/http3) + +This package implements HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). +It aims to provide feature parity with the standard library's HTTP/1.1 and HTTP/2 implementation. + +## Serving HTTP/3 + +The easiest way to start an HTTP/3 server is using +```go +mux := http.NewServeMux() +// ... add HTTP handlers to mux ... +// If mux is nil, the http.DefaultServeMux is used. +http3.ListenAndServeQUIC("0.0.0.0:443", "/path/to/cert", "/path/to/key", mux) +``` + +`ListenAndServeQUIC` is a convenience function. For more configurability, set up an `http3.Server` explicitly: +```go +server := http3.Server{ + Handler: mux, + Addr: "0.0.0.0:443", + TLSConfig: http3.ConfigureTLSConfig(&tls.Config{}), // use your tls.Config here + QuicConfig: &quic.Config{}, +} +err := server.ListenAndServe() +``` + +The `http3.Server` provides a number of configuration options, please refer to the [documentation](https://pkg.go.dev/github.com/quic-go/quic-go/http3#Server) for a complete list. The `QuicConfig` is used to configure the underlying QUIC connection. More details can be found in the documentation of the QUIC package. + +It is also possible to manually set up a `quic.Transport`, and then pass the listener to the server. This is useful when you want to set configuration options on the `quic.Transport`. +```go +tr := quic.Transport{Conn: conn} +tlsConf := http3.ConfigureTLSConfig(&tls.Config{}) // use your tls.Config here +quicConf := &quic.Config{} // QUIC connection options +server := http3.Server{} +ln, _ := tr.ListenEarly(tlsConf, quicConf) +server.ServeListener(ln) +``` + +Alternatively, it is also possible to pass fully established QUIC connections to the HTTP/3 server. This is useful if the QUIC server offers multiple ALPNs (via `NextProtos` in the `tls.Config`). +```go +tr := quic.Transport{Conn: conn} +tlsConf := http3.ConfigureTLSConfig(&tls.Config{}) // use your tls.Config here +quicConf := &quic.Config{} // QUIC connection options +server := http3.Server{} +// alternatively, use tr.ListenEarly to accept 0-RTT connections +ln, _ := tr.Listen(tlsConf, quicConf) +for { + c, _ := ln.Accept() + switch c.ConnectionState().TLS.NegotiatedProtocol { + case http3.NextProtoH3: + go server.ServeQUICConn(c) + // ... handle other protocols ... + } +} +``` + +## Dialing HTTP/3 + +This package provides a `http.RoundTripper` implementation that can be used on the `http.Client`: + +```go +&http3.RoundTripper{ + TLSClientConfig: &tls.Config{}, // set a TLS client config, if desired + QuicConfig: &quic.Config{}, // QUIC connection options +} +defer roundTripper.Close() +client := &http.Client{ + Transport: roundTripper, +} +``` + +The `http3.RoundTripper` provides a number of configuration options, please refer to the [documentation](https://pkg.go.dev/github.com/quic-go/quic-go/http3#RoundTripper) for a complete list. + +To use a custom `quic.Transport`, the function used to dial new QUIC connections can be configured: +```go +tr := quic.Transport{} +roundTripper := &http3.RoundTripper{ + TLSClientConfig: &tls.Config{}, // set a TLS client config, if desired + QuicConfig: &quic.Config{}, // QUIC connection options + Dial: func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return tr.DialEarly(ctx, a, tlsConf, quicConf) + }, +} +``` + +## Using the same UDP Socket for Server and Roundtripper + +Since QUIC demultiplexes packets based on their connection IDs, it is possible allows running a QUIC server and client on the same UDP socket. This also works when using HTTP/3: HTTP requests can be sent from the same socket that a server is listening on. + +To achieve this using this package, first initialize a single `quic.Transport`, and pass a `quic.EarlyListner` obtained from that transport to `http3.Server.ServeListener`, and use the `DialEarly` function of the transport as the `Dial` function for the `http3.RoundTripper`. + +## QPACK + +HTTP/3 utilizes QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) for efficient HTTP header field compression. Our implementation, available at[quic-go/qpack](https://github.com/quic-go/qpack), provides a minimal implementation of the protocol. + +While the current implementation is a fully interoperable implementation of the QPACK protocol, it only uses the static compression table. The dynamic table would allow for more effective compression of frequently transmitted header fields. This can be particularly beneficial in scenarios where headers have considerable redundancy or in high-throughput environments. + +If you think that your application would benefit from higher compression efficiency, or if you're interested in contributing improvements here, please let us know in [#2424](https://github.com/quic-go/quic-go/issues/2424). diff --git a/http3/client.go b/http3/client.go index bac73f325..ececd95f0 100644 --- a/http3/client.go +++ b/http3/client.go @@ -61,6 +61,9 @@ type client struct { dialer dialFunc handshakeErr error + receivedSettings chan struct{} // closed once the server's SETTINGS frame was processed + settings *Settings // set once receivedSettings is closed + requestWriter *requestWriter decoder *qpack.Decoder @@ -76,10 +79,14 @@ var _ roundTripCloser = &client{} func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() + conf.EnableDatagrams = opts.EnableDatagram + } + if opts.EnableDatagram && !conf.EnableDatagrams { + return nil, errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") } if len(conf.Versions) == 0 { conf = conf.Clone() - conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]} + conf.Versions = []quic.Version{protocol.SupportedVersions[0]} } if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") @@ -87,7 +94,6 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con if conf.MaxIncomingStreams == 0 { conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams } - conf.EnableDatagrams = opts.EnableDatagram logger := utils.DefaultLogger.WithPrefix("h3 client") if tlsConf == nil { @@ -107,14 +113,15 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} return &client{ - hostname: authorityAddr("https", hostname), - tlsConf: tlsConf, - requestWriter: newRequestWriter(logger), - decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: conf, - opts: opts, - dialer: dialer, - logger: logger, + hostname: authorityAddr("https", hostname), + tlsConf: tlsConf, + requestWriter: newRequestWriter(logger), + receivedSettings: make(chan struct{}), + decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), + config: conf, + opts: opts, + dialer: dialer, + logger: logger, }, nil } @@ -183,6 +190,8 @@ func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { } func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { + var rcvdControlStream atomic.Bool + for { str, err := conn.AcceptUniStream(context.Background()) if err != nil { @@ -217,6 +226,11 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } + // Only a single control stream is allowed. + if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return + } f, err := parseNextFrame(str, nil) if err != nil { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") @@ -227,6 +241,12 @@ func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") return } + c.settings = &Settings{ + EnableDatagram: sf.Datagram, + EnableExtendedConnect: sf.ExtendedConnect, + Other: sf.Other, + } + close(c.receivedSettings) if !sf.Datagram { return } @@ -257,6 +277,15 @@ func (c *client) maxHeaderBytes() uint64 { // RoundTripOpt executes a request and returns a response func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + rsp, err := c.roundTripOpt(req, opt) + if err != nil && req.Context().Err() != nil { + // if the context was canceled, return the context cancellation error + err = req.Context().Err() + } + return rsp, err +} + +func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) } @@ -283,6 +312,18 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } } + if opt.CheckSettings != nil { + // wait for the server's SETTINGS frame to arrive + select { + case <-c.receivedSettings: + case <-conn.Context().Done(): + return nil, context.Cause(conn.Context()) + } + if err := opt.CheckSettings(*c.settings); err != nil { + return nil, err + } + } + str, err := conn.OpenStreamSync(req.Context()) if err != nil { return nil, err diff --git a/http3/client_test.go b/http3/client_test.go index 28504f30c..80e98b92c 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -11,11 +11,12 @@ import ( "sync" "time" - quic "github.com/refraction-networking/uquic" tls "github.com/refraction-networking/utls" + quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" @@ -56,7 +57,7 @@ var _ = Describe("Client", func() { It("rejects quic.Configs that allow multiple QUIC versions", func() { qconf := &quic.Config{ - Versions: []quic.VersionNumber{protocol.Version2, protocol.Version1}, + Versions: []quic.Version{protocol.Version2, protocol.Version1}, } _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) @@ -69,7 +70,7 @@ var _ = Describe("Client", func() { dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) - Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) + Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) dialAddrCalled = true return nil, errors.New("test done") } @@ -214,9 +215,10 @@ var _ = Describe("Client", func() { testDone = make(chan struct{}) settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) + return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) @@ -340,9 +342,10 @@ var _ = Describe("Client", func() { testDone = make(chan struct{}) settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) + return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) @@ -441,19 +444,19 @@ var _ = Describe("Client", func() { conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) - testDone := make(chan struct{}) + testDone := make(chan struct{}, 1) BeforeEach(func() { settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() close(settingsFrameWritten) + return len(b), nil }) conn = mockquic.NewMockEarlyConnection(mockCtrl) conn.EXPECT().OpenUniStream().Return(controlStr, nil) conn.EXPECT().HandshakeComplete().Return(handshakeChan) - conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return conn, nil } @@ -469,10 +472,15 @@ var _ = Describe("Client", func() { It("parses the SETTINGS frame", func() { b := quicvarint.Append(nil, streamTypeControlStream) - b = (&settingsFrame{}).Append(b) + b = (&settingsFrame{ + Datagram: true, + ExtendedConnect: true, + Other: map[uint64]uint64{1337: 42}, + }).Append(b) r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -480,11 +488,72 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + conn.EXPECT().Context().Return(context.Background()) + _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { + defer GinkgoRecover() + Expect(settings.EnableDatagram).To(BeTrue()) + Expect(settings.EnableExtendedConnect).To(BeTrue()) + Expect(settings.Other).To(HaveLen(1)) + Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) + return nil + }}) Expect(err).To(MatchError("done")) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) + It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + // Don't EXPECT any call to OpenStreamSync. + // When the SETTINGS are rejected, we don't even open the request stream. + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().Context().Return(context.Background()) + _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { + return errors.New("wrong settings") + }}) + Expect(err).To(MatchError("wrong settings")) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("rejects duplicate control streams", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{}).Append(b) + r1 := bytes.NewReader(b) + controlStr1 := mockquic.NewMockStream(mockCtrl) + controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() + r2 := bytes.NewReader(b) + controlStr2 := mockquic.NewMockStream(mockCtrl) + controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() + done := make(chan struct{}) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { + close(done) + return nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr1, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr2, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-done + return nil, errors.New("test done") + }) + _, err := cl.RoundTripOpt(req, RoundTripOpt{}) + Expect(err).To(HaveOccurred()) + Eventually(done).Should(BeClosed()) + }) + for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { streamType := t name := "encoder" @@ -497,6 +566,7 @@ var _ = Describe("Client", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) @@ -510,15 +580,14 @@ var _ = Describe("Client", func() { }) } - It("resets streams Other than the control stream and the QPACK streams", func() { + It("resets streams other than the control stream and the QPACK streams", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) - str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { - close(done) - }) + str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) @@ -537,6 +606,8 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -545,22 +616,53 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) Eventually(done).Should(BeClosed()) }) + It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&dataFrame{}).Append(b) + r := bytes.NewReader(b) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + + // Don't EXPECT any calls to OpenStreamSync. + // We fail before we even get the chance to open the request stream. + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + doneCtx, doneCancel := context.WithCancelCause(context.Background()) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { + doneCancel(errors.New("done")) + return nil + }) + conn.EXPECT().Context().Return(doneCtx).Times(2) + var checked bool + _, err := cl.RoundTripOpt(req, RoundTripOpt{ + CheckSettings: func(Settings) error { checked = true; return nil }, + }) + Expect(checked).To(BeFalse()) + Expect(err).To(MatchError("done")) + Eventually(doneCtx.Done()).Should(BeClosed()) + }) + It("errors when parsing the frame on the control stream fails", func() { b := quicvarint.Append(nil, streamTypeControlStream) b = (&settingsFrame{}).Append(b) r := bytes.NewReader(b[:len(b)-1]) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -569,10 +671,9 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error { close(done) + return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) @@ -583,6 +684,7 @@ var _ = Describe("Client", func() { buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -591,10 +693,9 @@ var _ = Describe("Client", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeIDError)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) @@ -608,6 +709,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) @@ -617,11 +719,9 @@ var _ = Describe("Client", func() { }) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) - Expect(reason).To(Equal("missing QUIC Datagram support")) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("done")) @@ -670,13 +770,14 @@ var _ = Describe("Client", func() { BeforeEach(func() { settingsFrameWritten = make(chan struct{}) controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { defer GinkgoRecover() r := bytes.NewReader(b) streamType, err := quicvarint.Read(r) Expect(err).ToNot(HaveOccurred()) Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) close(settingsFrameWritten) + return len(b), nil }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) conn = mockquic.NewMockEarlyConnection(mockCtrl) @@ -778,7 +879,7 @@ var _ = Describe("Client", func() { It("sends a request", func() { done := make(chan struct{}) gomock.InOrder( - str.EXPECT().Close().Do(func() { close(done) }), + str.EXPECT().Close().Do(func() error { close(done); return nil }), str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors ) // the response body is sent asynchronously, while already reading the response @@ -832,7 +933,7 @@ var _ = Describe("Client", func() { return 0, errors.New("test done") }) closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Close().Do(func() error { close(closed); return nil }) _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("test done")) Eventually(closed).Should(BeClosed()) @@ -843,7 +944,7 @@ var _ = Describe("Client", func() { conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) closed := make(chan struct{}) r := bytes.NewReader(b) - str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) @@ -861,7 +962,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(HaveOccurred()) @@ -873,7 +974,7 @@ var _ = Describe("Client", func() { r := bytes.NewReader(b) str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) closed := make(chan struct{}) - str.EXPECT().Close().Do(func() { close(closed) }) + str.EXPECT().Close().Do(func() error { close(closed); return nil }) str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() _, err := cl.RoundTripOpt(req, RoundTripOpt{}) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) @@ -926,7 +1027,7 @@ var _ = Describe("Client", func() { return 0, errors.New("test done") }) _, err := cl.RoundTripOpt(req, roundTripOpt) - Expect(err).To(MatchError("test done")) + Expect(err).To(MatchError(context.Canceled)) Eventually(done).Should(BeClosed()) }) }) diff --git a/http3/frames.go b/http3/frames.go index 7f6d0fe8d..961694af5 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -88,11 +88,18 @@ func (f *headersFrame) Append(b []byte) []byte { return quicvarint.Append(b, f.Length) } -const settingDatagram = 0x33 +const ( + // Extended CONNECT, RFC 9220 + settingExtendedConnect = 0x8 + // HTTP Datagrams, RFC 9297 + settingDatagram = 0x33 +) type settingsFrame struct { - Datagram bool - Other map[uint64]uint64 // all settings that we don't explicitly recognize + Datagram bool // HTTP Datagrams, RFC 9297 + ExtendedConnect bool // Extended CONNECT, RFC 9220 + + Other map[uint64]uint64 // all settings that we don't explicitly recognize } func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { @@ -108,7 +115,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } frame := &settingsFrame{} b := bytes.NewReader(buf) - var readDatagram bool + var readDatagram, readExtendedConnect bool for b.Len() > 0 { id, err := quicvarint.Read(b) if err != nil { // should not happen. We allocated the whole frame already. @@ -120,13 +127,22 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } switch id { + case settingExtendedConnect: + if readExtendedConnect { + return nil, fmt.Errorf("duplicate setting: %d", id) + } + readExtendedConnect = true + if val != 0 && val != 1 { + return nil, fmt.Errorf("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: %d", val) + } + frame.ExtendedConnect = val == 1 case settingDatagram: if readDatagram { return nil, fmt.Errorf("duplicate setting: %d", id) } readDatagram = true if val != 0 && val != 1 { - return nil, fmt.Errorf("invalid value for H3_DATAGRAM: %d", val) + return nil, fmt.Errorf("invalid value for SETTINGS_H3_DATAGRAM: %d", val) } frame.Datagram = val == 1 default: @@ -151,11 +167,18 @@ func (f *settingsFrame) Append(b []byte) []byte { if f.Datagram { l += quicvarint.Len(settingDatagram) + quicvarint.Len(1) } + if f.ExtendedConnect { + l += quicvarint.Len(settingExtendedConnect) + quicvarint.Len(1) + } b = quicvarint.Append(b, uint64(l)) if f.Datagram { b = quicvarint.Append(b, settingDatagram) b = quicvarint.Append(b, 1) } + if f.ExtendedConnect { + b = quicvarint.Append(b, settingExtendedConnect) + b = quicvarint.Append(b, 1) + } for id, val := range f.Other { b = quicvarint.Append(b, id) b = quicvarint.Append(b, val) diff --git a/http3/frames_test.go b/http3/frames_test.go index ed56c74af..cf9fa27c8 100644 --- a/http3/frames_test.go +++ b/http3/frames_test.go @@ -127,8 +127,8 @@ var _ = Describe("Frames", func() { } }) - Context("H3_DATAGRAM", func() { - It("reads the H3_DATAGRAM value", func() { + Context("HTTP Datagrams", func() { + It("reads the SETTINGS_H3_DATAGRAM value", func() { settings := quicvarint.Append(nil, settingDatagram) settings = quicvarint.Append(settings, 1) data := quicvarint.Append(nil, 4) // type byte @@ -141,7 +141,7 @@ var _ = Describe("Frames", func() { Expect(sf.Datagram).To(BeTrue()) }) - It("rejects duplicate H3_DATAGRAM entries", func() { + It("rejects duplicate SETTINGS_H3_DATAGRAM entries", func() { settings := quicvarint.Append(nil, settingDatagram) settings = quicvarint.Append(settings, 1) settings = quicvarint.Append(settings, settingDatagram) @@ -153,23 +153,67 @@ var _ = Describe("Frames", func() { Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) }) - It("rejects invalid values for the H3_DATAGRAM entry", func() { + It("rejects invalid values for the SETTINGS_H3_DATAGRAM entry", func() { settings := quicvarint.Append(nil, settingDatagram) settings = quicvarint.Append(settings, 1337) data := quicvarint.Append(nil, 4) // type byte data = quicvarint.Append(data, uint64(len(settings))) data = append(data, settings...) _, err := parseNextFrame(bytes.NewReader(data), nil) - Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) + Expect(err).To(MatchError("invalid value for SETTINGS_H3_DATAGRAM: 1337")) }) - It("writes the H3_DATAGRAM setting", func() { + It("writes the SETTINGS_H3_DATAGRAM setting", func() { sf := &settingsFrame{Datagram: true} frame, err := parseNextFrame(bytes.NewReader(sf.Append(nil)), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) }) + + Context("Extended Connect", func() { + It("reads the SETTINGS_ENABLE_CONNECT_PROTOCOL value", func() { + settings := quicvarint.Append(nil, settingExtendedConnect) + settings = quicvarint.Append(settings, 1) + data := quicvarint.Append(nil, 4) // type byte + data = quicvarint.Append(data, uint64(len(settings))) + data = append(data, settings...) + f, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) + sf := f.(*settingsFrame) + Expect(sf.ExtendedConnect).To(BeTrue()) + }) + + It("rejects duplicate SETTINGS_ENABLE_CONNECT_PROTOCOL entries", func() { + settings := quicvarint.Append(nil, settingExtendedConnect) + settings = quicvarint.Append(settings, 1) + settings = quicvarint.Append(settings, settingExtendedConnect) + settings = quicvarint.Append(settings, 1) + data := quicvarint.Append(nil, 4) // type byte + data = quicvarint.Append(data, uint64(len(settings))) + data = append(data, settings...) + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingExtendedConnect))) + }) + + It("rejects invalid values for the SETTINGS_ENABLE_CONNECT_PROTOCOL entry", func() { + settings := quicvarint.Append(nil, settingExtendedConnect) + settings = quicvarint.Append(settings, 1337) + data := quicvarint.Append(nil, 4) // type byte + data = quicvarint.Append(data, uint64(len(settings))) + data = append(data, settings...) + _, err := parseNextFrame(bytes.NewReader(data), nil) + Expect(err).To(MatchError("invalid value for SETTINGS_ENABLE_CONNECT_PROTOCOL: 1337")) + }) + + It("writes the SETTINGS_ENABLE_CONNECT_PROTOCOL setting", func() { + sf := &settingsFrame{ExtendedConnect: true} + frame, err := parseNextFrame(bytes.NewReader(sf.Append(nil)), nil) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(sf)) + }) + }) }) Context("hijacking", func() { diff --git a/http3/headers.go b/http3/headers.go index 79c070b55..040667178 100644 --- a/http3/headers.go +++ b/http3/headers.go @@ -126,9 +126,14 @@ func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) return nil, errors.New(":path, :authority and :method must not be empty") } + if !isExtendedConnected && len(hdr.Protocol) > 0 { + return nil, errors.New(":protocol must be empty") + } + var u *url.URL var requestURI string - var protocol string + + protocol := "HTTP/3.0" if isConnect { u = &url.URL{} @@ -137,15 +142,14 @@ func requestFromHeaders(headerFields []qpack.HeaderField) (*http.Request, error) if err != nil { return nil, err } + protocol = hdr.Protocol } else { u.Path = hdr.Path } u.Scheme = hdr.Scheme u.Host = hdr.Authority requestURI = hdr.Authority - protocol = hdr.Protocol } else { - protocol = "HTTP/3.0" u, err = url.ParseRequestURI(hdr.Path) if err != nil { return nil, fmt.Errorf("invalid content length: %w", err) diff --git a/http3/headers_test.go b/http3/headers_test.go index 70fac9974..1fb6ccc58 100644 --- a/http3/headers_test.go +++ b/http3/headers_test.go @@ -212,6 +212,17 @@ var _ = Describe("Request", func() { Expect(err).To(MatchError(":path, :authority and :method must not be empty")) }) + It("errors with invalid protocol", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: "GET"}, + {Name: ":protocol", Value: "connect-udp"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":protocol must be empty")) + }) + Context("regular HTTP CONNECT", func() { It("handles CONNECT method", func() { headers := []qpack.HeaderField{ @@ -221,6 +232,7 @@ var _ = Describe("Request", func() { req, err := requestFromHeaders(headers) Expect(err).NotTo(HaveOccurred()) Expect(req.Method).To(Equal(http.MethodConnect)) + Expect(req.Proto).To(Equal("HTTP/3.0")) Expect(req.RequestURI).To(Equal("quic.clemente.io")) }) diff --git a/http3/http_stream.go b/http3/http_stream.go index df62254e7..d01d61986 100644 --- a/http3/http_stream.go +++ b/http3/http_stream.go @@ -5,7 +5,6 @@ import ( "fmt" quic "github.com/refraction-networking/uquic" - "github.com/refraction-networking/uquic/internal/utils" ) // A Stream is a HTTP/3 stream. @@ -115,7 +114,7 @@ func (s *lengthLimitedStream) Read(b []byte) (int, error) { if err := s.checkContentLengthViolation(); err != nil { return 0, err } - n, err := s.stream.Read(b[:utils.Min(int64(len(b)), s.contentLength-s.read)]) + n, err := s.stream.Read(b[:min(int64(len(b)), s.contentLength-s.read)]) s.read += int64(n) if err := s.checkContentLengthViolation(); err != nil { return n, err diff --git a/http3/mock_quic_early_listener_test.go b/http3/mock_quic_early_listener_test.go index 3995c5302..81cd2e2e3 100644 --- a/http3/mock_quic_early_listener_test.go +++ b/http3/mock_quic_early_listener_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener +// mockgen -typed -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener // + // Package http3 is a generated GoMock package. package http3 @@ -50,9 +51,33 @@ func (m *MockQUICEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnecti } // Accept indicates an expected call of Accept. -func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 any) *gomock.Call { +func (mr *MockQUICEarlyListenerMockRecorder) Accept(arg0 any) *MockQUICEarlyListenerAcceptCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICEarlyListener)(nil).Accept), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockQUICEarlyListener)(nil).Accept), arg0) + return &MockQUICEarlyListenerAcceptCall{Call: call} +} + +// MockQUICEarlyListenerAcceptCall wrap *gomock.Call +type MockQUICEarlyListenerAcceptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICEarlyListenerAcceptCall) Return(arg0 quic.EarlyConnection, arg1 error) *MockQUICEarlyListenerAcceptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICEarlyListenerAcceptCall) Do(f func(context.Context) (quic.EarlyConnection, error)) *MockQUICEarlyListenerAcceptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICEarlyListenerAcceptCall) DoAndReturn(f func(context.Context) (quic.EarlyConnection, error)) *MockQUICEarlyListenerAcceptCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Addr mocks base method. @@ -64,9 +89,33 @@ func (m *MockQUICEarlyListener) Addr() net.Addr { } // Addr indicates an expected call of Addr. -func (mr *MockQUICEarlyListenerMockRecorder) Addr() *gomock.Call { +func (mr *MockQUICEarlyListenerMockRecorder) Addr() *MockQUICEarlyListenerAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockQUICEarlyListener)(nil).Addr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockQUICEarlyListener)(nil).Addr)) + return &MockQUICEarlyListenerAddrCall{Call: call} +} + +// MockQUICEarlyListenerAddrCall wrap *gomock.Call +type MockQUICEarlyListenerAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICEarlyListenerAddrCall) Return(arg0 net.Addr) *MockQUICEarlyListenerAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICEarlyListenerAddrCall) Do(f func() net.Addr) *MockQUICEarlyListenerAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICEarlyListenerAddrCall) DoAndReturn(f func() net.Addr) *MockQUICEarlyListenerAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -78,7 +127,31 @@ func (m *MockQUICEarlyListener) Close() error { } // Close indicates an expected call of Close. -func (mr *MockQUICEarlyListenerMockRecorder) Close() *gomock.Call { +func (mr *MockQUICEarlyListenerMockRecorder) Close() *MockQUICEarlyListenerCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQUICEarlyListener)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockQUICEarlyListener)(nil).Close)) + return &MockQUICEarlyListenerCloseCall{Call: call} +} + +// MockQUICEarlyListenerCloseCall wrap *gomock.Call +type MockQUICEarlyListenerCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICEarlyListenerCloseCall) Return(arg0 error) *MockQUICEarlyListenerCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICEarlyListenerCloseCall) Do(f func() error) *MockQUICEarlyListenerCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICEarlyListenerCloseCall) DoAndReturn(f func() error) *MockQUICEarlyListenerCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/http3/mock_roundtripcloser_test.go b/http3/mock_roundtripcloser_test.go index 9f580b881..b76245ebe 100644 --- a/http3/mock_roundtripcloser_test.go +++ b/http3/mock_roundtripcloser_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser +// mockgen -typed -build_flags=-tags=gomock -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser // + // Package http3 is a generated GoMock package. package http3 @@ -47,9 +48,33 @@ func (m *MockRoundTripCloser) Close() error { } // Close indicates an expected call of Close. -func (mr *MockRoundTripCloserMockRecorder) Close() *gomock.Call { +func (mr *MockRoundTripCloserMockRecorder) Close() *MockRoundTripCloserCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRoundTripCloser)(nil).Close)) + return &MockRoundTripCloserCloseCall{Call: call} +} + +// MockRoundTripCloserCloseCall wrap *gomock.Call +type MockRoundTripCloserCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRoundTripCloserCloseCall) Return(arg0 error) *MockRoundTripCloserCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRoundTripCloserCloseCall) Do(f func() error) *MockRoundTripCloserCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRoundTripCloserCloseCall) DoAndReturn(f func() error) *MockRoundTripCloserCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandshakeComplete mocks base method. @@ -61,9 +86,33 @@ func (m *MockRoundTripCloser) HandshakeComplete() bool { } // HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *gomock.Call { +func (mr *MockRoundTripCloserMockRecorder) HandshakeComplete() *MockRoundTripCloserHandshakeCompleteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockRoundTripCloser)(nil).HandshakeComplete)) + return &MockRoundTripCloserHandshakeCompleteCall{Call: call} +} + +// MockRoundTripCloserHandshakeCompleteCall wrap *gomock.Call +type MockRoundTripCloserHandshakeCompleteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRoundTripCloserHandshakeCompleteCall) Return(arg0 bool) *MockRoundTripCloserHandshakeCompleteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRoundTripCloserHandshakeCompleteCall) Do(f func() bool) *MockRoundTripCloserHandshakeCompleteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRoundTripCloserHandshakeCompleteCall) DoAndReturn(f func() bool) *MockRoundTripCloserHandshakeCompleteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RoundTripOpt mocks base method. @@ -76,7 +125,31 @@ func (m *MockRoundTripCloser) RoundTripOpt(arg0 *http.Request, arg1 RoundTripOpt } // RoundTripOpt indicates an expected call of RoundTripOpt. -func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 any) *gomock.Call { +func (mr *MockRoundTripCloserMockRecorder) RoundTripOpt(arg0, arg1 any) *MockRoundTripCloserRoundTripOptCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RoundTripOpt", reflect.TypeOf((*MockRoundTripCloser)(nil).RoundTripOpt), arg0, arg1) + return &MockRoundTripCloserRoundTripOptCall{Call: call} +} + +// MockRoundTripCloserRoundTripOptCall wrap *gomock.Call +type MockRoundTripCloserRoundTripOptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRoundTripCloserRoundTripOptCall) Return(arg0 *http.Response, arg1 error) *MockRoundTripCloserRoundTripOptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRoundTripCloserRoundTripOptCall) Do(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockRoundTripCloserRoundTripOptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRoundTripCloserRoundTripOptCall) DoAndReturn(f func(*http.Request, RoundTripOpt) (*http.Response, error)) *MockRoundTripCloserRoundTripOptCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/http3/mockgen.go b/http3/mockgen.go index cf4b917a6..57af7972c 100644 --- a/http3/mockgen.go +++ b/http3/mockgen.go @@ -2,7 +2,7 @@ package http3 -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/refraction-networking/uquic/http3 RoundTripCloser" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" type RoundTripCloser = roundTripCloser -//go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/refraction-networking/uquic/http3 QUICEarlyListener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener" diff --git a/http3/response_writer.go b/http3/response_writer.go index ed58e9d1b..c86c37b4d 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -67,9 +67,10 @@ type responseWriter struct { bufferedStr *bufio.Writer buf []byte - headerWritten bool contentLen int64 // if handler set valid Content-Length header numWritten int64 // bytes written + headerWritten bool + isHead bool } var ( @@ -162,6 +163,10 @@ func (w *responseWriter) Write(p []byte) (int, error) { return 0, http.ErrContentLength } + if w.isHead { + return len(p), nil + } + df := &dataFrame{Length: uint64(len(p))} w.buf = w.buf[:0] w.buf = df.Append(w.buf) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 3ef5ddd2a..3d5885b71 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -17,6 +17,30 @@ import ( "golang.org/x/net/http/httpguts" ) +// Settings are HTTP/3 settings that apply to the underlying connection. +type Settings struct { + // Support for HTTP/3 datagrams (RFC 9297) + EnableDatagram bool + // Extended CONNECT, RFC 9220 + EnableExtendedConnect bool + // Other settings, defined by the application + Other map[uint64]uint64 +} + +// RoundTripOpt are options for the Transport.RoundTripOpt method. +type RoundTripOpt struct { + // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. + // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. + OnlyCachedConn bool + // DontCloseRequestStream controls whether the request stream is closed after sending the request. + // If set, context cancellations have no effect after the response headers are received. + DontCloseRequestStream bool + // CheckSettings is run before the request is sent to the server. + // If not yet received, it blocks until the server's SETTINGS frame is received. + // If an error is returned, the request won't be sent to the server, and the error is returned. + CheckSettings func(Settings) error +} + type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool @@ -50,9 +74,8 @@ type RoundTripper struct { // If nil, reasonable default values will be used. QuicConfig *quic.Config - // Enable support for HTTP/3 datagrams. - // If set to true, QuicConfig.EnableDatagram will be set. - // See https://datatracker.ietf.org/doc/html/rfc9297. + // Enable support for HTTP/3 datagrams (RFC 9297). + // If a QuicConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. EnableDatagrams bool // Additional HTTP/3 settings. @@ -89,16 +112,6 @@ type RoundTripper struct { transport *quic.Transport } -// RoundTripOpt are options for the Transport.RoundTripOpt method. -type RoundTripOpt struct { - // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. - // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. - OnlyCachedConn bool - // DontCloseRequestStream controls whether the request stream is closed after sending the request. - // If set, context cancellations have no effect after the response headers are received. - DontCloseRequestStream bool -} - var ( _ http.RoundTripper = &RoundTripper{} _ io.Closer = &RoundTripper{} @@ -204,6 +217,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr MaxHeaderBytes: r.MaxResponseHeaderBytes, StreamHijacker: r.StreamHijacker, UniStreamHijacker: r.UniStreamHijacker, + AdditionalSettings: r.AdditionalSettings, }, r.QuicConfig, dial, diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index e685aa618..ca3c65355 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -72,6 +72,21 @@ var _ = Describe("RoundTripper", func() { Expect(err).To(MatchError(testErr)) }) + It("creates new clients with additional settings", func() { + testErr := errors.New("test err") + req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) + Expect(err).ToNot(HaveOccurred()) + rt.AdditionalSettings = map[uint64]uint64{1337: 42} + rt.newClient = func(_ string, _ *tls.Config, opts *roundTripperOpts, conf *quic.Config, _ dialFunc) (roundTripCloser, error) { + cl := NewMockRoundTripCloser(mockCtrl) + cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) + Expect(opts.AdditionalSettings).To(HaveKeyWithValue(uint64(1337), uint64(42))) + return cl, nil + } + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError(testErr)) + }) + It("uses the quic.Config, if provided", func() { config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} var receivedConfig *quic.Config @@ -85,6 +100,19 @@ var _ = Describe("RoundTripper", func() { Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) }) + It("requires quic.Config.EnableDatagram if HTTP datagrams are enabled", func() { + rt.QuicConfig = &quic.Config{EnableDatagrams: false} + rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { + return nil, errors.New("handshake error") + } + rt.EnableDatagrams = true + _, err := rt.RoundTrip(req) + Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled")) + rt.QuicConfig.EnableDatagrams = true + _, err = rt.RoundTrip(req) + Expect(err).To(MatchError("handshake error")) + }) + It("uses the custom dialer, if provided", func() { var dialed bool dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { diff --git a/http3/server.go b/http3/server.go index 20a0323a9..3f19b7603 100644 --- a/http3/server.go +++ b/http3/server.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" tls "github.com/refraction-networking/utls" @@ -32,6 +33,7 @@ var ( quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { return quic.ListenAddrEarly(addr, tlsConf, config) } + errPanicked = errors.New("panicked") ) // NextProtoH3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. @@ -56,7 +58,7 @@ type QUICEarlyListener interface { var _ QUICEarlyListener = &quic.EarlyListener{} -func versionToALPN(v protocol.VersionNumber) string { +func versionToALPN(v protocol.Version) string { //nolint:exhaustive // These are all the versions we care about. switch v { case protocol.Version1, protocol.Version2: @@ -78,7 +80,7 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { // determine the ALPN from the QUIC version used proto := NextProtoH3 val := ch.Context().Value(quic.QUICVersionContextKey) - if v, ok := val.(quic.VersionNumber); ok { + if v, ok := val.(quic.Version); ok { proto = versionToALPN(v) } config := tlsConf @@ -117,6 +119,16 @@ func (k *contextKey) String() string { return "quic-go/http3 context value " + k // type *http3.Server. var ServerContextKey = &contextKey{"http3-server"} +// RemoteAddrContextKey is a context key. It can be used in +// HTTP handlers with Context.Value to access the remote +// address of the connection. The associated value will be of +// type net.Addr. +// +// Use this value instead of [http.Request.RemoteAddr] if you +// require access to the remote address of the connection rather +// than its string representation. +var RemoteAddrContextKey = &contextKey{"remote-addr"} + type requestError struct { err error streamErr ErrCode @@ -202,6 +214,11 @@ type Server struct { // In that case, the stream type will not be set. UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) + // ConnContext optionally specifies a function that modifies + // the context used for a new connection c. The provided ctx + // has a ServerContextKey value. + ConnContext func(ctx context.Context, c quic.Connection) context.Context + mutex sync.RWMutex listeners map[*QUICEarlyListener]listenerInfo @@ -275,7 +292,7 @@ func (s *Server) ServeListener(ln QUICEarlyListener) error { } go func() { if err := s.handleConn(conn); err != nil { - s.logger.Debugf(err.Error()) + s.logger.Debugf("handling connection failed: %s", err) } }() } @@ -409,10 +426,11 @@ func (s *Server) addListener(l *QUICEarlyListener) error { s.listeners = make(map[*QUICEarlyListener]listenerInfo) } - if port, err := extractPort((*l).Addr().String()); err == nil { + laddr := (*l).Addr() + if port, err := extractPort(laddr.String()); err == nil { s.listeners[l] = listenerInfo{port} } else { - s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.logger.Errorf("Unable to extract port from listener %s, will not be announced using SetQuicHeaders: %s", laddr, err) s.listeners[l] = listenerInfo{} } s.generateAltSvcHeader() @@ -436,7 +454,11 @@ func (s *Server) handleConn(conn quic.Connection) error { } b := make([]byte, 0, 64) b = quicvarint.Append(b, streamTypeControlStream) // stream type - b = (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Append(b) + b = (&settingsFrame{ + Datagram: s.EnableDatagrams, + ExtendedConnect: true, + Other: s.AdditionalSettings, + }).Append(b) str.Write(b) go s.handleUnidirectionalStreams(conn) @@ -479,6 +501,8 @@ func (s *Server) handleConn(conn quic.Connection) error { } func (s *Server) handleUnidirectionalStreams(conn quic.Connection) { + var rcvdControlStream atomic.Bool + for { str, err := conn.AcceptUniStream(context.Background()) if err != nil { @@ -512,6 +536,11 @@ func (s *Server) handleUnidirectionalStreams(conn quic.Connection) { str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) return } + // Only a single control stream is allowed. + if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { + conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") + return + } f, err := parseNextFrame(str, nil) if err != nil { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") @@ -617,8 +646,18 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q ctx := str.Context() ctx = context.WithValue(ctx, ServerContextKey, s) ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) + ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr()) + if s.ConnContext != nil { + ctx = s.ConnContext(ctx, conn) + if ctx == nil { + panic("http3: ConnContext returned nil") + } + } req = req.WithContext(ctx) r := newResponseWriter(str, conn, s.logger) + if req.Method == http.MethodHead { + r.isHead = true + } handler := s.Handler if handler == nil { handler = http.DefaultServeMux @@ -658,6 +697,11 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q } // If the EOF was read by the handler, CancelRead() is a no-op. str.CancelRead(quic.StreamErrorCode(ErrCodeNoError)) + + // abort the stream when there is a panic + if panicked { + return newStreamError(ErrCodeInternalError, errPanicked) + } return requestError{} } @@ -722,7 +766,7 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er return server.ListenAndServeTLS(certFile, keyFile) } -// ListenAndServe listens on the given network address for both, TLS and QUIC +// ListenAndServe listens on the given network address for both TLS/TCP and QUIC // connections in parallel. It returns if one of the two returns an error. // http.DefaultServeMux is used when handler is nil. // The correct Alt-Svc headers for QUIC are set. @@ -764,8 +808,8 @@ func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error Handler: handler, } - hErr := make(chan error) - qErr := make(chan error) + hErr := make(chan error, 1) + qErr := make(chan error, 1) go func() { hErr <- http.ListenAndServeTLS(addr, certFile, keyFile, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { quicServer.SetQuicHeaders(w.Header()) diff --git a/http3/server_test.go b/http3/server_test.go index 5ab313dea..ce284a311 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -17,6 +17,7 @@ import ( quic "github.com/refraction-networking/uquic" mockquic "github.com/refraction-networking/uquic/internal/mocks/quic" "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/qerr" "github.com/refraction-networking/uquic/internal/testdata" "github.com/refraction-networking/uquic/internal/utils" "github.com/refraction-networking/uquic/quicvarint" @@ -68,11 +69,15 @@ var _ = Describe("Server", func() { s *Server origQuicListenAddr = quicListenAddr ) + type testConnContextKey string BeforeEach(func() { s = &Server{ TLSConfig: testdata.GetTLSConfig(), logger: utils.DefaultLogger, + ConnContext: func(ctx context.Context, c quic.Connection) context.Context { + return context.WithValue(ctx, testConnContextKey("test"), c) + }, } origQuicListenAddr = quicListenAddr }) @@ -164,6 +169,7 @@ var _ = Describe("Server", func() { Expect(req.Host).To(Equal("www.example.com")) Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) + Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil)) }) It("returns 200 with an empty handler", func() { @@ -222,6 +228,45 @@ var _ = Describe("Server", func() { Expect(hfs).To(HaveLen(3)) }) + It("response to HEAD request should not have body", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) + Expect(err).ToNot(HaveOccurred()) + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(headRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(responseBuf.Bytes()).To(HaveLen(0)) + }) + + It("response to HEAD request should also do content sniffing", func() { + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("")) + }) + + headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) + Expect(err).ToNot(HaveOccurred()) + responseBuf := &bytes.Buffer{} + setRequest(encodeRequest(headRequest)) + str.EXPECT().Context().Return(reqContext) + str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() + str.EXPECT().CancelRead(gomock.Any()) + serr := s.handleRequest(conn, str, qpackDecoder, nil) + Expect(serr.err).ToNot(HaveOccurred()) + hfs := decodeHeader(responseBuf) + Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) + Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) + Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) + }) + It("handles a aborting handler", func() { s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic(http.ErrAbortHandler) @@ -234,7 +279,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) + Expect(serr.err).To(MatchError(errPanicked)) Expect(responseBuf.Bytes()).To(HaveLen(0)) }) @@ -250,7 +295,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(gomock.Any()) serr := s.handleRequest(conn, str, qpackDecoder, nil) - Expect(serr.err).ToNot(HaveOccurred()) + Expect(serr.err).To(MatchError(errPanicked)) Expect(responseBuf.Bytes()).To(HaveLen(0)) }) @@ -454,7 +499,7 @@ var _ = Describe("Server", func() { Context("control stream handling", func() { var conn *mockquic.MockEarlyConnection - testDone := make(chan struct{}) + testDone := make(chan struct{}, 1) BeforeEach(func() { conn = mockquic.NewMockEarlyConnection(mockCtrl) @@ -485,6 +530,34 @@ var _ = Describe("Server", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) + It("rejects duplicate control streams", func() { + b := quicvarint.Append(nil, streamTypeControlStream) + b = (&settingsFrame{}).Append(b) + r1 := bytes.NewReader(b) + controlStr1 := mockquic.NewMockStream(mockCtrl) + controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() + r2 := bytes.NewReader(b) + controlStr2 := mockquic.NewMockStream(mockCtrl) + controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() + done := make(chan struct{}) + conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { + close(done) + return nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr1, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return controlStr2, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-done + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(done).Should(BeClosed()) + }) + for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { streamType := t name := "encoder" @@ -514,9 +587,7 @@ var _ = Describe("Server", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() done := make(chan struct{}) - str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { - close(done) - }) + str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil @@ -543,10 +614,9 @@ var _ = Describe("Server", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -566,10 +636,9 @@ var _ = Describe("Server", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -589,10 +658,9 @@ var _ = Describe("Server", func() { return nil, errors.New("test done") }) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeStreamCreationError)) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -614,11 +682,9 @@ var _ = Describe("Server", func() { }) conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { - defer GinkgoRecover() - Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) - Expect(reason).To(Equal("missing QUIC Datagram support")) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -664,7 +730,7 @@ var _ = Describe("Server", func() { str.EXPECT().Context().Return(reqContext) str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) - str.EXPECT().Close().Do(func() { close(done) }) + str.EXPECT().Close().Do(func() error { close(done); return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -739,9 +805,9 @@ var _ = Describe("Server", func() { }).AnyTimes() done := make(chan struct{}) - conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { - Expect(code).To(Equal(quic.ApplicationErrorCode(ErrCodeFrameUnexpected))) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { close(done) + return nil }) s.handleConn(conn) Eventually(done).Should(BeClosed()) @@ -819,7 +885,7 @@ var _ = Describe("Server", func() { Context("setting http headers", func() { BeforeEach(func() { - s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}} + s.QuicConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}} }) var ln1 QUICEarlyListener @@ -880,7 +946,7 @@ var _ = Describe("Server", func() { }) It("works if the quic.Config sets QUIC versions", func() { - s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version2} + s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version2} addListener(":443", &ln1) checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) removeListener(&ln1) @@ -919,7 +985,7 @@ var _ = Describe("Server", func() { }) It("doesn't duplicate Alt-Svc values", func() { - s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version1} + s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version1} addListener(":443", &ln1) checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) removeListener(&ln1) @@ -960,7 +1026,7 @@ var _ = Describe("Server", func() { Context("ConfigureTLSConfig", func() { It("advertises v1 by default", func() { conf := ConfigureTLSConfig(testdata.GetTLSConfig()) - ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) @@ -988,7 +1054,7 @@ var _ = Describe("Server", func() { }, } - ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) @@ -1001,7 +1067,7 @@ var _ = Describe("Server", func() { tlsConf := testdata.GetTLSConfig() tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } - ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) @@ -1019,7 +1085,7 @@ var _ = Describe("Server", func() { }, } - ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) Expect(err).ToNot(HaveOccurred()) defer ln.Close() c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) @@ -1051,7 +1117,7 @@ var _ = Describe("Server", func() { } stopAccept := make(chan struct{}) - ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept return nil, errors.New("closed") }) @@ -1064,7 +1130,7 @@ var _ = Describe("Server", func() { }() Consistently(done).ShouldNot(BeClosed()) - ln.EXPECT().Close().Do(func() { close(stopAccept) }) + ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) Expect(s.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1086,13 +1152,13 @@ var _ = Describe("Server", func() { } stopAccept1 := make(chan struct{}) - ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept1 return nil, errors.New("closed") }) ln1.EXPECT().Addr() // generate alt-svc headers stopAccept2 := make(chan struct{}) - ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept2 return nil, errors.New("closed") }) @@ -1113,8 +1179,8 @@ var _ = Describe("Server", func() { Consistently(done1).ShouldNot(BeClosed()) Expect(done2).ToNot(BeClosed()) - ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) - ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) + ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) Expect(s.Close()).To(Succeed()) Eventually(done1).Should(BeClosed()) Eventually(done2).Should(BeClosed()) @@ -1139,7 +1205,7 @@ var _ = Describe("Server", func() { s := &Server{} stopAccept := make(chan struct{}) - ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept return nil, errors.New("closed") }) @@ -1153,7 +1219,7 @@ var _ = Describe("Server", func() { Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) Consistently(done).ShouldNot(BeClosed()) - ln.EXPECT().Close().Do(func() { close(stopAccept) }) + ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) Expect(s.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1173,13 +1239,13 @@ var _ = Describe("Server", func() { s := &Server{} stopAccept1 := make(chan struct{}) - ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept1 return nil, errors.New("closed") }) ln1.EXPECT().Addr() // generate alt-svc headers stopAccept2 := make(chan struct{}) - ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { <-stopAccept2 return nil, errors.New("closed") }) @@ -1201,8 +1267,8 @@ var _ = Describe("Server", func() { Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) Consistently(done1).ShouldNot(BeClosed()) Expect(done2).ToNot(BeClosed()) - ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) - ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) + ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) Expect(s.Close()).To(Succeed()) Eventually(done1).Should(BeClosed()) Eventually(done2).Should(BeClosed()) diff --git a/integrationtests/gomodvendor/go.mod b/integrationtests/gomodvendor/go.mod index a47ba6439..106f783ed 100644 --- a/integrationtests/gomodvendor/go.mod +++ b/integrationtests/gomodvendor/go.mod @@ -1,8 +1,23 @@ module test -go 1.16 +go 1.21 // The version doesn't matter here, as we're replacing it with the currently checked out code anyway. -require github.com/refraction-networking/uquic v0.21.0 +require github.com/quic-go/quic-go v0.21.0 -replace github.com/refraction-networking/uquic => ../../ +require ( + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/quic-go/qpack v0.4.0 // indirect + go.uber.org/mock v0.4.0 // indirect + golang.org/x/crypto v0.4.0 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect + golang.org/x/mod v0.11.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + golang.org/x/tools v0.9.1 // indirect +) + +replace github.com/quic-go/quic-go => ../../ diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 29427a8ce..a006faf08 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -1,364 +1,53 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= -dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= -dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= -dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= -dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= -git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= -github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= -github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= -github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= -github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= -github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= -github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= -github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= -github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= -github.com/onsi/ginkgo/v2 v2.1.4/go.mod h1:um6tUpWM/cxCK3/FK8BXqEiUMUwRgSM4JXG47RKZmLU= -github.com/onsi/ginkgo/v2 v2.1.6/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= -github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= -github.com/onsi/ginkgo/v2 v2.5.0/go.mod h1:Luc4sArBICYCS8THh8v3i3i5CuSZO+RaQRaJoeNwomw= -github.com/onsi/ginkgo/v2 v2.7.0/go.mod h1:yjiuMwPokqY1XauOgju45q3sJt6VzQ/Fict1LFVcsAo= -github.com/onsi/ginkgo/v2 v2.8.1/go.mod h1:N1/NbDngAFcSLdyZ+/aYTYGSlq9qMCS/cNKGJjy+csc= -github.com/onsi/ginkgo/v2 v2.9.0/go.mod h1:4xkjoL/tZv4SMWeww56BU5kAt19mVB47gTWxmrTcxyk= -github.com/onsi/ginkgo/v2 v2.9.1/go.mod h1:FEcmzVcCHl+4o9bQZVab+4dC9+j+91t2FHSzmGAPfuo= -github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= -github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= -github.com/onsi/gomega v1.20.1/go.mod h1:DtrZpjmvpn2mPm4YWQa0/ALMDj9v4YxLgojwPeREyVo= -github.com/onsi/gomega v1.21.1/go.mod h1:iYAIXgPSaDHak0LCMA+AWBpIKBr8WZicMxnE8luStNc= -github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= -github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= -github.com/onsi/gomega v1.24.1/go.mod h1:3AOiACssS3/MajrniINInwbfOOtfZvplPzuRSmvt1jM= -github.com/onsi/gomega v1.26.0/go.mod h1:r+zV744Re+DiYCIPRlYOTxn0YkOLcAnW8k1xXdMPGhM= -github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw= -github.com/onsi/gomega v1.27.3/go.mod h1:5vG284IBtfDAmDyrK+eGyZmUgUlmi+Wngqo557cZ6Gw= -github.com/onsi/gomega v1.27.4/go.mod h1:riYq/GJKh8hhoM01HN6Vmuy93AarCXCBGpvFDK3q3fQ= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= -github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg= -github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= -github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= -github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= -github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= -github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= -github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= -github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= -github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= -github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= -github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= -github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= -github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= -github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= -github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= -github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= -github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= -github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= -github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= -github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= -github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= -github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= -github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= -github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= -github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= -go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= -go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= -golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= -golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= -golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= -golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= -google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= -google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= -sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 54c78abd0..8fac63dca 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -31,7 +31,7 @@ var _ = Describe("Stream Cancellations", func() { server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) - var canceledCounter int32 + var canceledCounter atomic.Int32 go func() { defer GinkgoRecover() var wg sync.WaitGroup @@ -50,18 +50,18 @@ var _ = Describe("Stream Cancellations", func() { ErrorCode: quic.StreamErrorCode(str.StreamID()), Remote: true, })) - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) return } if err := str.Close(); err != nil { Expect(err).To(MatchError(fmt.Sprintf("close called for canceled stream %d", str.StreamID()))) - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) return } }() } wg.Wait() - numCanceledStreamsChan <- atomic.LoadInt32(&canceledCounter) + numCanceledStreamsChan <- canceledCounter.Load() }() return numCanceledStreamsChan } @@ -80,7 +80,7 @@ var _ = Describe("Stream Cancellations", func() { ) Expect(err).ToNot(HaveOccurred()) - var canceledCounter int32 + var canceledCounter atomic.Int32 var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { @@ -91,7 +91,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) resetErr := quic.StreamErrorCode(str.StreamID()) str.CancelRead(resetErr) _, err := str.Read([]byte{0}) @@ -113,7 +113,7 @@ var _ = Describe("Stream Cancellations", func() { Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) Expect(conn.CloseWithError(0, "")).To(Succeed()) - clientCanceledCounter := atomic.LoadInt32(&canceledCounter) + clientCanceledCounter := canceledCounter.Load() // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter)) fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams) @@ -132,7 +132,7 @@ var _ = Describe("Stream Cancellations", func() { ) Expect(err).ToNot(HaveOccurred()) - var canceledCounter int32 + var canceledCounter atomic.Int32 var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { @@ -148,7 +148,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) str.CancelRead(quic.StreamErrorCode(str.StreamID())) Expect(data).To(Equal(PRData[:length])) - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) return } data, err := io.ReadAll(str) @@ -162,7 +162,7 @@ var _ = Describe("Stream Cancellations", func() { Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) Expect(conn.CloseWithError(0, "")).To(Succeed()) - clientCanceledCounter := atomic.LoadInt32(&canceledCounter) + clientCanceledCounter := canceledCounter.Load() // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. Expect(clientCanceledCounter).To(BeNumerically(">=", serverCanceledCounter)) fmt.Fprintf(GinkgoWriter, "Canceled reading on %d of %d streams.\n", clientCanceledCounter, numStreams) @@ -185,7 +185,7 @@ var _ = Describe("Stream Cancellations", func() { var wg sync.WaitGroup wg.Add(numStreams) - var counter int32 + var counter atomic.Int32 for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() @@ -199,7 +199,7 @@ var _ = Describe("Stream Cancellations", func() { defer close(done) b := make([]byte, 32) if _, err := str.Read(b); err != nil { - atomic.AddInt32(&counter, 1) + counter.Add(1) Expect(err).To(Equal(&quic.StreamError{ StreamID: str.StreamID(), ErrorCode: 1234, @@ -214,7 +214,7 @@ var _ = Describe("Stream Cancellations", func() { } wg.Wait() Expect(conn.CloseWithError(0, "")).To(Succeed()) - numCanceled := atomic.LoadInt32(&counter) + numCanceled := counter.Load() fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams) Expect(numCanceled).ToNot(BeZero()) Eventually(serverCanceledCounterChan).Should(Receive()) @@ -232,7 +232,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) var wg sync.WaitGroup - var counter int32 + var counter atomic.Int32 wg.Add(numStreams) for i := 0; i < numStreams; i++ { go func() { @@ -242,7 +242,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) if err != nil { - atomic.AddInt32(&counter, 1) + counter.Add(1) Expect(err).To(MatchError(&quic.StreamError{ StreamID: str.StreamID(), ErrorCode: quic.StreamErrorCode(str.StreamID()), @@ -254,7 +254,7 @@ var _ = Describe("Stream Cancellations", func() { } wg.Wait() - streamCount := atomic.LoadInt32(&counter) + streamCount := counter.Load() fmt.Fprintf(GinkgoWriter, "Canceled writing on %d of %d streams\n", streamCount, numStreams) Expect(streamCount).To(BeNumerically(">", numStreams/10)) Expect(numStreams - streamCount).To(BeNumerically(">", numStreams/10)) @@ -267,7 +267,7 @@ var _ = Describe("Stream Cancellations", func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) - var canceledCounter int32 + var canceledCounter atomic.Int32 go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) @@ -280,7 +280,7 @@ var _ = Describe("Stream Cancellations", func() { // cancel about 2/3 of the streams if rand.Int31()%3 != 0 { str.CancelWrite(quic.StreamErrorCode(str.StreamID())) - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) return } _, err = str.Write(PRData) @@ -291,14 +291,14 @@ var _ = Describe("Stream Cancellations", func() { }() clientCanceledStreams := runClient(server) - Expect(clientCanceledStreams).To(Equal(atomic.LoadInt32(&canceledCounter))) + Expect(clientCanceledStreams).To(Equal(canceledCounter.Load())) }) It("downloads when the server cancels some streams after sending some data", func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) - var canceledCounter int32 + var canceledCounter atomic.Int32 go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) @@ -314,7 +314,7 @@ var _ = Describe("Stream Cancellations", func() { _, err = str.Write(PRData[:length]) Expect(err).ToNot(HaveOccurred()) str.CancelWrite(quic.StreamErrorCode(str.StreamID())) - atomic.AddInt32(&canceledCounter, 1) + canceledCounter.Add(1) return } _, err = str.Write(PRData) @@ -325,7 +325,7 @@ var _ = Describe("Stream Cancellations", func() { }() clientCanceledStreams := runClient(server) - Expect(clientCanceledStreams).To(Equal(atomic.LoadInt32(&canceledCounter))) + Expect(clientCanceledStreams).To(Equal(canceledCounter.Load())) }) }) @@ -378,7 +378,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) var wg sync.WaitGroup - var counter int32 + var counter atomic.Int32 wg.Add(numStreams) for i := 0; i < numStreams; i++ { go func() { @@ -399,13 +399,13 @@ var _ = Describe("Stream Cancellations", func() { })) return } - atomic.AddInt32(&counter, 1) + counter.Add(1) Expect(data).To(Equal(PRData)) }() } wg.Wait() - count := atomic.LoadInt32(&counter) + count := counter.Load() Expect(count).To(BeNumerically(">", numStreams/15)) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) @@ -464,7 +464,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) var wg sync.WaitGroup - var counter int32 + var counter atomic.Int32 wg.Add(numStreams) for i := 0; i < numStreams; i++ { go func() { @@ -495,14 +495,14 @@ var _ = Describe("Stream Cancellations", func() { return } - atomic.AddInt32(&counter, 1) + counter.Add(1) Expect(data).To(Equal(PRData)) }() } wg.Wait() Eventually(done).Should(BeClosed()) - count := atomic.LoadInt32(&counter) + count := counter.Load() Expect(count).To(BeNumerically(">", numStreams/15)) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) @@ -543,7 +543,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) var numToAccept int - var counter int32 + var counter atomic.Int32 var wg sync.WaitGroup wg.Add(numStreams) for numToAccept < numStreams { @@ -561,7 +561,7 @@ var _ = Describe("Stream Cancellations", func() { str, err := conn.AcceptUniStream(ctx) if err != nil { if err.Error() == "context canceled" { - atomic.AddInt32(&counter, 1) + counter.Add(1) } return } @@ -573,7 +573,7 @@ var _ = Describe("Stream Cancellations", func() { } wg.Wait() - count := atomic.LoadInt32(&counter) + count := counter.Load() fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count) Expect(count).To(BeNumerically(">", numStreams/2)) Expect(conn.CloseWithError(0, "")).To(Succeed()) @@ -589,7 +589,7 @@ var _ = Describe("Stream Cancellations", func() { Expect(err).ToNot(HaveOccurred()) msg := make(chan struct{}, 1) - var numCanceled int32 + var numCanceled atomic.Int32 go func() { defer GinkgoRecover() defer close(msg) @@ -603,7 +603,7 @@ var _ = Describe("Stream Cancellations", func() { str, err := conn.OpenUniStreamSync(ctx) if err != nil { Expect(err).To(MatchError(context.DeadlineExceeded)) - atomic.AddInt32(&numCanceled, 1) + numCanceled.Add(1) select { case msg <- struct{}{}: default: @@ -644,7 +644,7 @@ var _ = Describe("Stream Cancellations", func() { } wg.Wait() - count := atomic.LoadInt32(&numCanceled) + count := numCanceled.Load() fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count) Expect(count).To(BeNumerically(">=", numStreams-maxIncomingStreams)) Expect(conn.CloseWithError(0, "")).To(Succeed()) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index fd73451d5..3b388b8fa 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -50,6 +50,7 @@ var _ = Describe("Connection ID lengths tests", func() { ConnectionIDLength: connIDLen, ConnectionIDGenerator: connIDGenerator, } + addTracer(tr) ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) go func() { @@ -92,6 +93,7 @@ var _ = Describe("Connection ID lengths tests", func() { ConnectionIDLength: connIDLen, ConnectionIDGenerator: connIDGenerator, } + addTracer(tr) defer tr.Close() cl, err := tr.Dial( context.Background(), diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 8703e4b4f..72ca00c65 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -19,11 +19,12 @@ import ( ) var _ = Describe("Datagram test", func() { - const num = 100 + const concurrentSends = 100 + const maxDatagramSize = 250 var ( serverConn, clientConn *net.UDPConn - dropped, total int32 + dropped, total atomic.Int32 ) startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) (port int, closeFn func()) { @@ -47,19 +48,24 @@ var _ = Describe("Datagram test", func() { if expectDatagramSupport { Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - if enableDatagram { + f := &wire.DatagramFrame{DataLenPresent: true} var wg sync.WaitGroup - wg.Add(num) - for i := 0; i < num; i++ { + wg.Add(concurrentSends) + for i := 0; i < concurrentSends; i++ { go func(i int) { defer GinkgoRecover() defer wg.Done() b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(i)) - Expect(conn.SendMessage(b)).To(Succeed()) + Expect(conn.SendDatagram(b)).To(Succeed()) }(i) } + maxDatagramMessageSize := f.MaxDataLen(maxDatagramSize, conn.ConnectionState().Version) + b := make([]byte, maxDatagramMessageSize+1) + Expect(conn.SendDatagram(b)).To(MatchError(&quic.DatagramTooLargeError{ + PeerMaxDatagramFrameSize: int64(maxDatagramMessageSize), + })) wg.Wait() } } else { @@ -81,9 +87,9 @@ var _ = Describe("Datagram test", func() { } drop := mrand.Int()%10 == 0 if drop { - atomic.AddInt32(&dropped, 1) + dropped.Add(1) } - atomic.AddInt32(&total, 1) + total.Add(1) return drop }, }) @@ -103,6 +109,8 @@ var _ = Describe("Datagram test", func() { }) It("sends datagrams", func() { + oldMaxDatagramSize := wire.MaxDatagramSize + wire.MaxDatagramSize = maxDatagramSize proxyPort, close := startServerAndProxy(true, true) defer close() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) @@ -120,22 +128,23 @@ var _ = Describe("Datagram test", func() { for { // Close the connection if no message is received for 100 ms. timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { conn.CloseWithError(0, "") }) - if _, err := conn.ReceiveMessage(context.Background()); err != nil { + if _, err := conn.ReceiveDatagram(context.Background()); err != nil { break } timer.Stop() counter++ } - numDropped := int(atomic.LoadInt32(&dropped)) - expVal := num - numDropped - fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total)) - fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num) + numDropped := int(dropped.Load()) + expVal := concurrentSends - numDropped + fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, total.Load()) + fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, concurrentSends) Expect(counter).To(And( BeNumerically(">", expVal*9/10), - BeNumerically("<", num), + BeNumerically("<", concurrentSends), )) Eventually(conn.Context().Done).Should(BeClosed()) + wire.MaxDatagramSize = oldMaxDatagramSize }) It("server can disable datagram", func() { @@ -170,7 +179,7 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) - Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) + Expect(conn.SendDatagram([]byte{0})).To(HaveOccurred()) close() conn.CloseWithError(0, "") diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 95914ac5c..336cd2083 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -13,7 +13,7 @@ import ( ) var _ = Describe("Stream deadline tests", func() { - setup := func() (*quic.Listener, quic.Stream, quic.Stream) { + setup := func() (serverStr, clientStr quic.Stream, close func()) { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) strChan := make(chan quic.SendStream) @@ -35,19 +35,21 @@ var _ = Describe("Stream deadline tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - clientStr, err := conn.OpenStream() + clientStr, err = conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream Expect(err).ToNot(HaveOccurred()) - var serverStr quic.Stream Eventually(strChan).Should(Receive(&serverStr)) - return server, serverStr, clientStr + return serverStr, clientStr, func() { + Expect(server.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + } } Context("read deadlines", func() { It("completes a transfer when the deadline is set", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond done := make(chan struct{}) @@ -81,8 +83,8 @@ var _ = Describe("Stream deadline tests", func() { }) It("completes a transfer when the deadline is set concurrently", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond go func() { @@ -131,8 +133,8 @@ var _ = Describe("Stream deadline tests", func() { Context("write deadlines", func() { It("completes a transfer when the deadline is set", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond done := make(chan struct{}) @@ -164,8 +166,8 @@ var _ = Describe("Stream deadline tests", func() { }) It("completes a transfer when the deadline is set concurrently", func() { - server, serverStr, clientStr := setup() - defer server.Close() + serverStr, clientStr, closeFn := setup() + defer closeFn() const timeout = time.Millisecond readDone := make(chan struct{}) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index 14f0a6181..874e5293b 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -67,14 +67,14 @@ var _ = Describe("Drop Tests", func() { fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay) startTime := time.Now() - var numDroppedPackets int32 + var numDroppedPackets atomic.Int32 startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { if !d.Is(direction) { return false } drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration)) if drop { - atomic.AddInt32(&numDroppedPackets, 1) + numDroppedPackets.Add(1) } return drop }) @@ -114,7 +114,7 @@ var _ = Describe("Drop Tests", func() { Expect(b[0]).To(Equal(i)) } close(done) - numDropped := atomic.LoadInt32(&numDroppedPackets) + numDropped := numDroppedPackets.Load() fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped) Expect(numDropped).To(BeNumerically(">", 0)) }) diff --git a/integrationtests/self/go119_test.go b/integrationtests/self/go119_test.go deleted file mode 100644 index c676693da..000000000 --- a/integrationtests/self/go119_test.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build go1.19 && !go1.20 - -package self_test - -import ( - "errors" - "net/http" - "time" -) - -const go120 = false - -var errNotSupported = errors.New("not supported") - -func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { - return errNotSupported -} - -func setWriteDeadline(w http.ResponseWriter, deadline time.Time) error { - return errNotSupported -} diff --git a/integrationtests/self/go120_test.go b/integrationtests/self/go120_test.go deleted file mode 100644 index 88eb4a7ed..000000000 --- a/integrationtests/self/go120_test.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build go1.20 - -package self_test - -import ( - "net/http" - "time" -) - -const go120 = true - -func setReadDeadline(w http.ResponseWriter, deadline time.Time) error { - rc := http.NewResponseController(w) - - return rc.SetReadDeadline(deadline) -} - -func setWriteDeadline(w http.ResponseWriter, deadline time.Time) error { - rc := http.NewResponseController(w) - - return rc.SetWriteDeadline(deadline) -} diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index c099d9fa9..4048ab382 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -10,13 +10,12 @@ import ( "sync/atomic" "time" - quic "github.com/refraction-networking/uquic" tls "github.com/refraction-networking/utls" - "github.com/refraction-networking/uquic/quicvarint" - + quic "github.com/refraction-networking/uquic" quicproxy "github.com/refraction-networking/uquic/integrationtests/tools/proxy" "github.com/refraction-networking/uquic/internal/wire" + "github.com/refraction-networking/uquic/quicvarint" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -27,23 +26,17 @@ var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.Di type applicationProtocol struct { name string - run func() + run func(ln *quic.Listener, port int) } var _ = Describe("Handshake drop tests", func() { - var ( - proxy *quicproxy.QuicProxy - ln *quic.Listener - ) - data := GeneratePRData(5000) const timeout = 2 * time.Minute - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) (ln *quic.Listener, proxyPort int, closeFn func()) { conf := getQuicConfig(&quic.Config{ - MaxIdleTimeout: timeout, - HandshakeIdleTimeout: timeout, - RequireAddressValidation: func(net.Addr) bool { return doRetry }, + MaxIdleTimeout: timeout, + HandshakeIdleTimeout: timeout, }) var tlsConf *tls.Config if longCertChain { @@ -51,11 +44,18 @@ var _ = Describe("Handshake drop tests", func() { } else { tlsConf = getTLSConfig() } - var err error - ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{Conn: conn} + if doRetry { + tr.VerifySourceAddress = func(net.Addr) bool { return true } + } + ln, err = tr.Listen(tlsConf, conf) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DropPacket: dropCallback, DelayPacket: func(dir quicproxy.Direction, packet []byte) time.Duration { @@ -63,11 +63,18 @@ var _ = Describe("Handshake drop tests", func() { }, }) Expect(err).ToNot(HaveOccurred()) + + return ln, proxy.LocalPort(), func() { + ln.Close() + tr.Close() + conn.Close() + proxy.Close() + } } clientSpeaksFirst := &applicationProtocol{ name: "client speaks first", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -83,7 +90,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -106,7 +113,7 @@ var _ = Describe("Handshake drop tests", func() { serverSpeaksFirst := &applicationProtocol{ name: "server speaks first", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -121,7 +128,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -144,7 +151,7 @@ var _ = Describe("Handshake drop tests", func() { nobodySpeaks := &applicationProtocol{ name: "nobody speaks", - run: func() { + run: func(ln *quic.Listener, port int) { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -154,7 +161,7 @@ var _ = Describe("Handshake drop tests", func() { }() conn, err := quic.DialAddr( context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, @@ -170,11 +177,6 @@ var _ = Describe("Handshake drop tests", func() { }, } - AfterEach(func() { - Expect(ln.Close()).To(Succeed()) - Expect(proxy.Close()).To(Succeed()) - }) - for _, d := range directions { direction := d @@ -195,35 +197,37 @@ var _ = Describe("Handshake drop tests", func() { Context(app.name, func() { It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var incoming, outgoing atomic.Int32 + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { var p int32 //nolint:exhaustive switch d { case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) + p = incoming.Add(1) case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) + p = outgoing.Add(1) } return p == 1 && d.Is(direction) }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var incoming, outgoing atomic.Int32 + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { var p int32 //nolint:exhaustive switch d { case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) + p = incoming.Add(1) case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) + p = outgoing.Add(1) } return p == 2 && d.Is(direction) }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { @@ -231,7 +235,7 @@ var _ = Describe("Handshake drop tests", func() { var mx sync.Mutex var incoming, outgoing int - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { drop := mrand.Int63n(int64(3)) == 0 mx.Lock() @@ -261,7 +265,8 @@ var _ = Describe("Handshake drop tests", func() { } return drop }, doRetry, longCertChain) - app.run() + defer closeFn() + app.run(ln, proxyPort) }) }) } @@ -282,13 +287,14 @@ var _ = Describe("Handshake drop tests", func() { uint64(27+31*(1000+mrand.Int63()/31)) % quicvarint.Max: b, } - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + ln, proxyPort, closeFn := startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { if d == quicproxy.DirectionOutgoing { return false } return mrand.Intn(3) == 0 }, false, false) - clientSpeaksFirst.run() + defer closeFn() + clientSpeaksFirst.run(ln, proxyPort) }) } }) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index ef453a984..ea2f09503 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -55,9 +55,19 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for verifying the source address // 1 RTT for the TLS handshake - It("is forward-secure after 2 RTTs", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + It("is forward-secure after 2 RTTs with Retry", func() { + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + VerifySourceAddress: func(net.Addr) bool { return true }, + } + addTracer(tr) + defer tr.Close() + ln, err := tr.Listen(serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -67,7 +77,10 @@ var _ = Describe("Handshake RTT tests", func() { context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) { + Expect(info.AddrVerified).To(BeTrue()) + return nil, nil + }}), ) Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") @@ -85,7 +98,10 @@ var _ = Describe("Handshake RTT tests", func() { context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(nil), + getQuicConfig(&quic.Config{GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) { + Expect(info.AddrVerified).To(BeFalse()) + return nil, nil + }}), ) Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index c7e4cee7c..6f1fbfde1 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -80,6 +80,26 @@ var _ = Describe("Handshake tests", func() { }() } + It("returns the context cancellation error on timeouts", func() { + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) + defer cancel() + errChan := make(chan error, 1) + go func() { + _, err := quic.DialAddr( + ctx, + "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway + getTLSClientConfig(), + getQuicConfig(nil), + ) + errChan <- err + }() + + var err error + Eventually(errChan).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(context.DeadlineExceeded)) + }) + It("returns the cancellation reason when a dial is canceled", func() { ctx, cancel := context.WithCancelCause(context.Background()) errChan := make(chan error, 1) @@ -150,13 +170,14 @@ var _ = Describe("Handshake tests", func() { Context("Certificate validation", func() { It("accepts the certificate", func() { runServer(getTLSConfig()) - _, err := quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") }) It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { @@ -185,6 +206,7 @@ var _ = Describe("Handshake tests", func() { getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") Eventually(done).Should(BeClosed()) Expect(server.Addr()).To(Equal(local)) Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) @@ -194,13 +216,14 @@ var _ = Describe("Handshake tests", func() { It("works with a long certificate chain", func() { runServer(getTLSConfigWithLongCertChain()) - _, err := quic.DialAddr( + conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") }) It("errors if the server name doesn't match", func() { @@ -276,7 +299,7 @@ var _ = Describe("Handshake tests", func() { }) }) - Context("rate limiting", func() { + Context("queuening and accepting connections", func() { var ( server *quic.Listener pconn net.PacketConn @@ -301,7 +324,10 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) pconn, err = net.ListenUDP("udp", laddr) Expect(err).ToNot(HaveOccurred()) - dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} + dialer = &quic.Transport{ + Conn: pconn, + ConnectionIDLength: 4, + } }) AfterEach(func() { @@ -318,8 +344,11 @@ var _ = Describe("Handshake tests", func() { } time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued - _, err := dial() - Expect(err).To(HaveOccurred()) + conn, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn.AcceptStream(ctx) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) @@ -328,18 +357,21 @@ var _ = Describe("Handshake tests", func() { _, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // dial again, and expect that this dial succeeds - conn, err := dial() + conn2, err := dial() Expect(err).ToNot(HaveOccurred()) - defer conn.CloseWithError(0, "") + defer conn2.CloseWithError(0, "") time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued - _, err = dial() - Expect(err).To(HaveOccurred()) + conn3, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn3.AcceptStream(ctx) Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) }) - It("removes closed connections from the accept queue", func() { + It("also returns closed connections from the accept queue", func() { firstConn, err := dial() Expect(err).ToNot(HaveOccurred()) @@ -350,25 +382,79 @@ var _ = Describe("Handshake tests", func() { } time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued - _, err = dial() - Expect(err).To(HaveOccurred()) + conn, err := dial() + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn.AcceptStream(ctx) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) // Now close the one of the connection that are waiting to be accepted. - // This should free one spot in the queue. - Expect(firstConn.CloseWithError(0, "")) + const appErrCode quic.ApplicationErrorCode = 12345 + Expect(firstConn.CloseWithError(appErrCode, "")) Eventually(firstConn.Context().Done()).Should(BeClosed()) time.Sleep(scaleDuration(200 * time.Millisecond)) - // dial again, and expect that this dial succeeds - _, err = dial() + // dial again, and expect that this fails again + conn2, err := dial() Expect(err).ToNot(HaveOccurred()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, err = conn2.AcceptStream(ctx) + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) - _, err = dial() - Expect(err).To(HaveOccurred()) + // now accept all connections + var closedConn quic.Connection + for i := 0; i < protocol.MaxAcceptQueueSize; i++ { + conn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + if conn.Context().Err() != nil { + if closedConn != nil { + Fail("only expected a single closed connection") + } + closedConn = conn + } + } + Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection + _, err = closedConn.AcceptStream(context.Background()) + var appErr *quic.ApplicationError + Expect(errors.As(err, &appErr)).To(BeTrue()) + Expect(appErr.ErrorCode).To(Equal(appErrCode)) + }) + + It("closes handshaking connections when the server is closed", func() { + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + tr := &quic.Transport{Conn: udpConn} + addTracer(tr) + defer tr.Close() + tlsConf := &tls.Config{} + done := make(chan struct{}) + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + <-done + return nil, errors.New("closed") + } + ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + + errChan := make(chan error, 1) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + go func() { + defer GinkgoRecover() + _, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) + errChan <- err + }() + time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued + Expect(ln.Close()).To(Succeed()) + close(done) + err = <-errChan + var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) }) @@ -474,14 +560,25 @@ var _ = Describe("Handshake tests", func() { It("rejects invalid Retry token with the INVALID_TOKEN error", func() { const rtt = 10 * time.Millisecond - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + // The validity period of the retry token is the handshake timeout, // which is twice the handshake idle timeout. // By setting the handshake timeout shorter than the RTT, the token will have expired by the time // it reaches the server. serverConfig.HandshakeIdleTimeout = rtt / 5 - server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + VerifySourceAddress: func(net.Addr) bool { return true }, + } + addTracer(tr) + defer tr.Close() + server, err := tr.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) defer server.Close() diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go index e7dca2b19..ec4df8145 100644 --- a/integrationtests/self/hotswap_test.go +++ b/integrationtests/self/hotswap_test.go @@ -20,7 +20,7 @@ import ( type listenerWrapper struct { http3.QUICEarlyListener listenerClosed bool - count int32 + count atomic.Int32 } func (ln *listenerWrapper) Close() error { @@ -29,14 +29,18 @@ func (ln *listenerWrapper) Close() error { } func (ln *listenerWrapper) Faker() *fakeClosingListener { - atomic.AddInt32(&ln.count, 1) + ln.count.Add(1) ctx, cancel := context.WithCancel(context.Background()) - return &fakeClosingListener{ln, 0, ctx, cancel} + return &fakeClosingListener{ + listenerWrapper: ln, + ctx: ctx, + cancel: cancel, + } } type fakeClosingListener struct { *listenerWrapper - closed int32 + closed atomic.Bool ctx context.Context cancel context.CancelFunc } @@ -47,9 +51,9 @@ func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection } func (ln *fakeClosingListener) Close() error { - if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { + if ln.closed.CompareAndSwap(false, true) { ln.cancel() - if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { + if ln.listenerWrapper.count.Add(-1) == 0 { ln.listenerWrapper.Close() } } @@ -145,8 +149,8 @@ var _ = Describe("HTTP3 Server hotswap test", func() { // and only the fake listener should be closed Expect(server1.Close()).NotTo(HaveOccurred()) Eventually(stoppedServing1).Should(BeClosed()) - Expect(fake1.closed).To(Equal(int32(1))) - Expect(fake2.closed).To(Equal(int32(0))) + Expect(fake1.closed.Load()).To(BeTrue()) + Expect(fake2.closed.Load()).To(BeFalse()) Expect(ln.listenerClosed).ToNot(BeTrue()) Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) @@ -161,7 +165,7 @@ var _ = Describe("HTTP3 Server hotswap test", func() { // close the other server - both the fake and the actual listeners must close now Expect(server2.Close()).NotTo(HaveOccurred()) Eventually(stoppedServing2).Should(BeClosed()) - Expect(fake2.closed).To(Equal(int32(1))) + Expect(fake2.closed.Load()).To(BeTrue()) Expect(ln.listenerClosed).To(BeTrue()) }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index c0d4f22ec..db9787fbb 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -140,6 +140,26 @@ var _ = Describe("HTTP tests", func() { Expect(resp.Header.Get("Content-Length")).To(Equal(strconv.Itoa(len("foobar")))) }) + It("detects stream errors when server panics when writing response", func() { + respChan := make(chan struct{}) + mux.HandleFunc("/writing_and_panicking", func(w http.ResponseWriter, r *http.Request) { + // no recover here as it will interfere with the handler + w.Write([]byte("foobar")) + w.(http.Flusher).Flush() + // wait for the client to receive the response + <-respChan + panic(http.ErrAbortHandler) + }) + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/writing_and_panicking", port)) + close(respChan) + Expect(err).ToNot(HaveOccurred()) + body, err := io.ReadAll(resp.Body) + Expect(err).To(HaveOccurred()) + // the body will be a prefix of what's written + Expect(bytes.HasPrefix([]byte("foobar"), body)).To(BeTrue()) + }) + It("requests to different servers with the same udpconn", func() { resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remoteAddr", port)) Expect(err).ToNot(HaveOccurred()) @@ -299,6 +319,21 @@ var _ = Describe("HTTP tests", func() { Expect(string(body)).To(Equal("Hello, World!\n")) }) + It("handles context cancellations", func() { + mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + }) + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://localhost:%d/cancel", port), nil) + Expect(err).ToNot(HaveOccurred()) + time.AfterFunc(50*time.Millisecond, cancel) + + _, err = client.Do(req) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(context.Canceled)) + }) + It("cancels requests", func() { handlerCalled := make(chan struct{}) mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { @@ -432,55 +467,109 @@ var _ = Describe("HTTP tests", func() { Eventually(done).Should(BeClosed()) }) - if go120 { - It("supports read deadlines", func() { - mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setReadDeadline(w, time.Now().Add(deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) + It("supports read deadlines", func() { + mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + rc := http.NewResponseController(w) + Expect(rc.SetReadDeadline(time.Now().Add(deadlineDelay))).To(Succeed()) - body, err := io.ReadAll(r.Body) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(body).To(ContainSubstring("aa")) + body, err := io.ReadAll(r.Body) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(body).To(ContainSubstring("aa")) - w.Write([]byte("ok")) - }) + w.Write([]byte("ok")) + }) - expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post( - fmt.Sprintf("https://localhost:%d/read-deadline", port), - "text/plain", - neverEnding('a'), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) + expectedEnd := time.Now().Add(deadlineDelay) + resp, err := client.Post( + fmt.Sprintf("https://localhost:%d/read-deadline", port), + "text/plain", + neverEnding('a'), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(Equal("ok")) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(Equal("ok")) + }) + + It("supports write deadlines", func() { + mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + rc := http.NewResponseController(w) + Expect(rc.SetWriteDeadline(time.Now().Add(deadlineDelay))).To(Succeed()) + + _, err := io.Copy(w, neverEnding('a')) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) }) - It("supports write deadlines", func() { - mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) + expectedEnd := time.Now().Add(deadlineDelay) - _, err = io.Copy(w, neverEnding('a')) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - }) + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) - expectedEnd := time.Now().Add(deadlineDelay) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(ContainSubstring("aa")) + }) - resp, err := client.Get(fmt.Sprintf("https://localhost:%d/write-deadline", port)) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) + It("sets remote address", func() { + mux.HandleFunc("/remote-addr", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + _, ok := r.Context().Value(http3.RemoteAddrContextKey).(net.Addr) + Expect(ok).To(BeTrue()) + }) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(ContainSubstring("aa")) + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/remote-addr", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + }) + + It("sets conn context", func() { + type ctxKey int + server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context { + serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server) + Expect(ok).To(BeTrue()) + Expect(serv).To(Equal(server)) + + ctx = context.WithValue(ctx, ctxKey(0), "Hello") + ctx = context.WithValue(ctx, ctxKey(1), c) + return ctx + } + mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + v, ok := r.Context().Value(ctxKey(0)).(string) + Expect(ok).To(BeTrue()) + Expect(v).To(Equal("Hello")) + + c, ok := r.Context().Value(ctxKey(1)).(quic.Connection) + Expect(ok).To(BeTrue()) + Expect(c).ToNot(BeNil()) + + serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server) + Expect(ok).To(BeTrue()) + Expect(serv).To(Equal(server)) }) - } + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + }) + + It("checks the server's settings", func() { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://localhost:%d/hello", port), nil) + Expect(err).ToNot(HaveOccurred()) + testErr := errors.New("test error") + _, err = rt.RoundTripOpt(req, http3.RoundTripOpt{CheckSettings: func(settings http3.Settings) error { + Expect(settings.EnableExtendedConnect).To(BeTrue()) + Expect(settings.EnableDatagram).To(BeFalse()) + Expect(settings.Other).To(BeEmpty()) + return testErr + }}) + Expect(err).To(MatchError(err)) + }) }) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index b9491b253..e18dced12 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -15,8 +15,8 @@ import ( quic "github.com/refraction-networking/uquic" quicproxy "github.com/refraction-networking/uquic/integrationtests/tools/proxy" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/testutils" "github.com/refraction-networking/uquic/internal/wire" + "github.com/refraction-networking/uquic/testutils" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -32,7 +32,7 @@ var _ = Describe("MITM test", func() { serverConfig *quic.Config ) - startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { + startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback, forceAddressValidation bool) (proxyPort int, closeFn func()) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) c, err := net.ListenUDP("udp", addr) @@ -41,6 +41,10 @@ var _ = Describe("MITM test", func() { Conn: c, ConnectionIDLength: connIDLen, } + addTracer(serverTransport) + if forceAddressValidation { + serverTransport.VerifySourceAddress = func(net.Addr) bool { return true } + } ln, err := serverTransport.Listen(getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -83,6 +87,7 @@ var _ = Describe("MITM test", func() { Conn: clientUDPConn, ConnectionIDLength: connIDLen, } + addTracer(clientTransport) }) Context("unsuccessful attacks", func() { @@ -153,7 +158,7 @@ var _ = Describe("MITM test", func() { } runTest := func(delayCb quicproxy.DelayCallback) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) + proxyPort, closeFn := startServerAndProxy(delayCb, nil, false) defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) @@ -196,7 +201,7 @@ var _ = Describe("MITM test", func() { }) runTest := func(dropCb quicproxy.DropCallback) { - proxyPort, closeFn := startServerAndProxy(nil, dropCb) + proxyPort, closeFn := startServerAndProxy(nil, dropCb, false) defer closeFn() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) @@ -244,17 +249,17 @@ var _ = Describe("MITM test", func() { Context("corrupting packets", func() { const idleTimeout = time.Second - var numCorrupted, numPackets int32 + var numCorrupted, numPackets atomic.Int32 BeforeEach(func() { - numCorrupted = 0 - numPackets = 0 + numCorrupted.Store(0) + numPackets.Store(0) serverConfig.MaxIdleTimeout = idleTimeout }) AfterEach(func() { - num := atomic.LoadInt32(&numCorrupted) - fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets)) + num := numCorrupted.Load() + fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, numPackets.Load()) Expect(num).To(BeNumerically(">=", 1)) // If the packet containing the CONNECTION_CLOSE is corrupted, // we have to wait for the connection to time out. @@ -266,13 +271,13 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionIncoming { - atomic.AddInt32(&numPackets, 1) + numPackets.Add(1) if rand.Intn(interval) == 0 { pos := rand.Intn(len(raw)) raw[pos] = byte(rand.Intn(256)) _, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) + numCorrupted.Add(1) return true } } @@ -286,13 +291,13 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionOutgoing { - atomic.AddInt32(&numPackets, 1) + numPackets.Add(1) if rand.Intn(interval) == 0 { pos := rand.Intn(len(raw)) raw[pos] = byte(rand.Intn(256)) _, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) + numCorrupted.Add(1) return true } } @@ -310,17 +315,16 @@ var _ = Describe("MITM test", func() { const rtt = 20 * time.Millisecond - runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { - proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil) + runTest := func(proxyPort int) (closeFn func(), err error) { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) _, err = clientTransport.Dial( context.Background(), raddr, getTLSClientConfig(), - getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}), + getQuicConfig(&quic.Config{HandshakeIdleTimeout: scaleDuration(200 * time.Millisecond)}), ) - return func() { clientTransport.Close(); serverCloseFn() }, err + return func() { clientTransport.Close() }, err } // fails immediately because client connection closes when it can't find compatible version @@ -338,7 +342,7 @@ var _ = Describe("MITM test", func() { } // Create fake version negotiation packet with no supported versions - versions := []protocol.VersionNumber{} + versions := []protocol.Version{} packet := wire.ComposeVersionNegotiation( protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), @@ -352,7 +356,9 @@ var _ = Describe("MITM test", func() { } return rtt / 2 } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) vnErr := &quic.VersionNegotiationError{} @@ -363,8 +369,7 @@ var _ = Describe("MITM test", func() { // times out, because client doesn't accept subsequent real retry packets from server // as it has already accepted a retry. // TODO: determine behavior when server does not send Retry packets - It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + It("fails when a forged Retry packet with modified Source Connection ID is sent to client", func() { var initialPacketIntercepted bool done := make(chan struct{}) delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { @@ -388,7 +393,9 @@ var _ = Describe("MITM test", func() { } return rtt / 2 } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, true) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) @@ -412,13 +419,15 @@ var _ = Describe("MITM test", func() { } defer close(done) injected = true - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, nil, protocol.PerspectiveServer, hdr.Version) _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return rtt } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) @@ -442,13 +451,15 @@ var _ = Describe("MITM test", func() { injected = true // Fake Initial with ACK for packet 2 (unsent) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.DestConnectionID, []wire.Frame{ack}, protocol.PerspectiveServer, hdr.Version) _, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return rtt } - closeFn, err := runTest(delayCb) + proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil, false) + defer serverCloseFn() + closeFn, err := runTest(proxyPort) defer closeFn() Expect(err).To(HaveOccurred()) var transportErr *quic.TransportError diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index b508f9aa7..438f0e254 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -73,6 +73,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -108,6 +109,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) done1 := make(chan struct{}) done2 := make(chan struct{}) @@ -138,6 +140,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn.Close() tr := &quic.Transport{Conn: conn} + addTracer(tr) server, err := tr.Listen( getTLSConfig(), getQuicConfig(nil), @@ -166,6 +169,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn1.Close() tr1 := &quic.Transport{Conn: conn1} + addTracer(tr1) addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) @@ -173,6 +177,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn2.Close() tr2 := &quic.Transport{Conn: conn2} + addTracer(tr2) server1, err := tr1.Listen( getTLSConfig(), @@ -219,6 +224,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn1.Close() tr1 := &quic.Transport{Conn: conn1} + addTracer(tr1) addr2, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) @@ -226,6 +232,7 @@ var _ = Describe("Multiplexing", func() { Expect(err).ToNot(HaveOccurred()) defer conn2.Close() tr2 := &quic.Transport{Conn: conn2} + addTracer(tr2) server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) @@ -250,6 +257,9 @@ var _ = Describe("Multiplexing", func() { b := make([]byte, packetLen) rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) + if ctx.Err() != nil { // ctx canceled while Read was executing + return + } Expect(err).ToNot(HaveOccurred()) sentPackets.Add(1) } diff --git a/integrationtests/self/qlog_dir_test.go b/integrationtests/self/qlog_dir_test.go new file mode 100644 index 000000000..529409bca --- /dev/null +++ b/integrationtests/self/qlog_dir_test.go @@ -0,0 +1,90 @@ +package self_test + +import ( + "context" + "os" + "path" + "regexp" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + quic "github.com/refraction-networking/uquic" + "github.com/refraction-networking/uquic/qlog" +) + +var _ = Describe("qlog dir tests", Serial, func() { + var originalQlogDirValue string + var tempTestDirPath string + + BeforeEach(func() { + originalQlogDirValue = os.Getenv("QLOGDIR") + var err error + tempTestDirPath, err = os.MkdirTemp("", "temp_test_dir") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.Setenv("QLOGDIR", originalQlogDirValue) + Expect(err).ToNot(HaveOccurred()) + err = os.RemoveAll(tempTestDirPath) + Expect(err).ToNot(HaveOccurred()) + }) + + handshake := func() { + serverStopped := make(chan struct{}) + server, err := quic.ListenAddr( + "localhost:0", + getTLSConfig(), + &quic.Config{ + Tracer: qlog.DefaultTracer, + }, + ) + Expect(err).ToNot(HaveOccurred()) + + go func() { + defer GinkgoRecover() + defer close(serverStopped) + for { + if _, err := server.Accept(context.Background()); err != nil { + return + } + } + }() + + conn, err := quic.DialAddr( + context.Background(), + server.Addr().String(), + getTLSClientConfig(), + &quic.Config{ + Tracer: qlog.DefaultTracer, + }, + ) + Expect(err).ToNot(HaveOccurred()) + conn.CloseWithError(0, "") + server.Close() + <-serverStopped + } + + It("environment variable is set", func() { + qlogDir := path.Join(tempTestDirPath, "qlogs") + err := os.Setenv("QLOGDIR", qlogDir) + Expect(err).ToNot(HaveOccurred()) + handshake() + _, err = os.Stat(tempTestDirPath) + qlogDirCreated := !os.IsNotExist(err) + Expect(qlogDirCreated).To(BeTrue()) + childs, err := os.ReadDir(qlogDir) + Expect(err).ToNot(HaveOccurred()) + Expect(len(childs)).To(Equal(2)) + odcids := make([]string, 0) + vantagePoints := make([]string, 0) + qlogFileNameRegexp := regexp.MustCompile(`^([0-f]+)_(client|server).qlog$`) + for _, child := range childs { + matches := qlogFileNameRegexp.FindStringSubmatch(child.Name()) + odcids = append(odcids, matches[1]) + vantagePoints = append(vantagePoints, matches[2]) + } + Expect(odcids[0]).To(Equal(odcids[1])) + Expect(vantagePoints).To(ContainElements("client", "server")) + }) +}) diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index d38ae9a74..d731a250f 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -52,22 +52,23 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") var sessionKey string Eventually(puts).Should(Receive(&sessionKey)) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) serverConn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, @@ -75,11 +76,12 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) - Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue()) serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + conn2.CloseWithError(0, "") }) It("doesn't use session resumption, if the config disables it", func() { @@ -94,15 +96,16 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") Consistently(puts).ShouldNot(Receive()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -110,14 +113,15 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn2.CloseWithError(0, "") serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) @@ -142,7 +146,7 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - conn, err := quic.DialAddr( + conn1, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, @@ -150,7 +154,8 @@ var _ = Describe("TLS session resumption", func() { ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn1.CloseWithError(0, "") ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -158,14 +163,15 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - conn, err = quic.DialAddr( + conn2, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse()) + defer conn2.CloseWithError(0, "") serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index d63bedcec..d1ee2c1e4 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -87,10 +87,9 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer enableQlog bool - version quic.VersionNumber + version quic.Version tlsConfig *tls.Config tlsConfigLongChain *tls.Config tlsClientConfig *tls.Config @@ -139,9 +138,6 @@ func init() { } var _ = BeforeSuite(func() { - if enableQlog { - qlogTracer = tools.NewQlogger(GinkgoWriter) - } switch versionParam { case "1": version = quic.Version1 @@ -151,7 +147,7 @@ var _ = BeforeSuite(func() { Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam)) } fmt.Printf("Using QUIC version: %s\n", version) - protocol.SupportedVersions = []quic.VersionNumber{version} + protocol.SupportedVersions = []quic.Version{version} }) func getTLSConfig() *tls.Config { @@ -176,28 +172,48 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } else { conf = conf.Clone() } - if enableQlog { - if conf.Tracer == nil { - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - return logging.NewMultiplexedConnectionTracer( - qlogTracer(ctx, p, connID), - // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere - &logging.ConnectionTracer{}, - ) - } - } else if qlogTracer != nil { - origTracer := conf.Tracer - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - return logging.NewMultiplexedConnectionTracer( - qlogTracer(ctx, p, connID), - origTracer(ctx, p, connID), - ) - } + if !enableQlog { + return conf + } + if conf.Tracer == nil { + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.ConnectionTracer{}, + ) } + return conf + } + origTracer := conf.Tracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), + origTracer(ctx, p, connID), + ) } return conf } +func addTracer(tr *quic.Transport) { + if !enableQlog { + return + } + if tr.Tracer == nil { + tr.Tracer = logging.NewMultiplexedTracer( + tools.QlogTracer(GinkgoWriter), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.Tracer{}, + ) + return + } + origTracer := tr.Tracer + tr.Tracer = logging.NewMultiplexedTracer( + tools.QlogTracer(GinkgoWriter), + origTracer, + ) +} + var _ = BeforeEach(func() { log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index fc5db68ae..45b6ed079 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -22,12 +22,12 @@ type faultyConn struct { net.PacketConn MaxPackets int32 - counter int32 + counter atomic.Int32 } func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) { n, addr, err := c.PacketConn.ReadFrom(p) - counter := atomic.AddInt32(&c.counter, 1) + counter := c.counter.Add(1) if counter <= c.MaxPackets { return n, addr, err } @@ -35,7 +35,7 @@ func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) { } func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) { - counter := atomic.AddInt32(&c.counter, 1) + counter := c.counter.Add(1) if counter <= c.MaxPackets { return c.PacketConn.WriteTo(p, addr) } @@ -185,11 +185,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn conn.AcceptStream(context.Background()) // blocks until the connection is closed close(serverConnClosed) }() @@ -240,7 +242,7 @@ var _ = Describe("Timeout tests", func() { Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) @@ -266,11 +268,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn <-conn.Context().Done() // block until the connection is closed close(serverConnClosed) }() @@ -309,7 +313,7 @@ var _ = Describe("Timeout tests", func() { Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) }) @@ -325,11 +329,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() + serverConnChan := make(chan quic.Connection, 1) serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) + serverConnChan <- conn conn.AcceptStream(context.Background()) // blocks until the connection is closed close(serverConnClosed) }() @@ -370,7 +376,7 @@ var _ = Describe("Timeout tests", func() { _, err = str.Write([]byte("foobar")) checkTimeoutError(err) - Expect(server.Close()).To(Succeed()) + (<-serverConnChan).CloseWithError(0, "") Eventually(serverConnClosed).Should(BeClosed()) }) diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index cdd5c0d8c..27f418a0c 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -19,7 +19,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Handshake tests", func() { +var _ = Describe("Tracer tests", func() { addTracers := func(pers protocol.Perspective, conf *quic.Config) *quic.Config { enableQlog := mrand.Int()%3 != 0 enableCustomTracer := mrand.Int()%3 != 0 @@ -30,10 +30,10 @@ var _ = Describe("Handshake tests", func() { if enableQlog { tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections - fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) + fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %s\n", p, connID) return nil } - fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %x\n", p, connID) + fmt.Fprintf(GinkgoWriter, "%s qlog tracing connection %s\n", p, connID) return qlog.NewConnectionTracer(utils.NewBufferedWriteCloser(bufio.NewWriter(&bytes.Buffer{}), io.NopCloser(nil)), p, connID) }) } diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index ef0f08424..2c177c13d 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -142,5 +142,6 @@ var _ = Describe("Unidirectional Streams", func() { runReceivingPeer(client) <-done1 <-done2 + client.CloseWithError(0, "") }) }) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go deleted file mode 100644 index b59726529..000000000 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ /dev/null @@ -1,916 +0,0 @@ -//go:build !go1.21 - -package self_test - -import ( - "context" - "fmt" - "io" - mrand "math/rand" - "net" - "sync" - "sync/atomic" - "time" - - quic "github.com/refraction-networking/uquic" - tls "github.com/refraction-networking/utls" - - quicproxy "github.com/refraction-networking/uquic/integrationtests/tools/proxy" - "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/wire" - "github.com/refraction-networking/uquic/logging" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("0-RTT", func() { - rtt := scaleDuration(5 * time.Millisecond) - - runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { - var num0RTTPackets uint32 // to be used as an atomic - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - for len(data) > 0 { - if !wire.IsLongHeaderPacket(data[0]) { - break - } - hdr, _, rest, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) - break - } - data = rest - } - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - - return proxy, &num0RTTPackets - } - - dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { - tlsConf := getTLSConfig() - if serverConf == nil { - serverConf = getQuicConfig(nil) - } - serverConf.Allow0RTT = true - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - serverConf, - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - // dial the first connection in order to receive a session ticket - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - <-conn.Context().Done() - }() - - clientConf := getTLSClientConfig() - gets := make(chan string, 100) - puts := make(chan string, 100) - clientConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(100), gets, puts) - conn, err := quic.DialAddr( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - Eventually(puts).Should(Receive()) - // received the session ticket. We're done here. - Expect(conn.CloseWithError(0, "")).To(Succeed()) - Eventually(done).Should(BeClosed()) - return tlsConf, clientConf - } - - transfer0RTTData := func( - ln *quic.EarlyListener, - proxyPort int, - connIDLen int, - clientTLSConf *tls.Config, - clientConf *quic.Config, - testdata []byte, // data to transfer - ) { - // accept the second connection, and receive the data sent in 0-RTT - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(testdata)) - Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) - <-conn.Context().Done() - close(done) - }() - - if clientConf == nil { - clientConf = getQuicConfig(nil) - } - var conn quic.EarlyConnection - if connIDLen == 0 { - var err error - conn, err = quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxyPort), - clientTLSConf, - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) - } else { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - udpConn, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - defer udpConn.Close() - tr := &quic.Transport{ - Conn: udpConn, - ConnectionIDLength: connIDLen, - } - defer tr.Close() - conn, err = tr.DialEarly( - context.Background(), - &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort}, - clientTLSConf, - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) - } - defer conn.CloseWithError(0, "") - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(testdata) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - <-conn.HandshakeComplete() - Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) - io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn - conn.CloseWithError(0, "") - Eventually(done).Should(BeClosed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - } - - check0RTTRejected := func( - ln *quic.EarlyListener, - proxyPort int, - clientConf *tls.Config, - ) { - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxyPort), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(make([]byte, 3000)) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) - - // make sure the server doesn't process the data - ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) - defer cancel() - serverConn, err := ln.Accept(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) - _, err = serverConn.AcceptUniStream(ctx) - Expect(err).To(Equal(context.DeadlineExceeded)) - Expect(serverConn.CloseWithError(0, "")).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - } - - // can be used to extract 0-RTT from a packetCounter - get0RTTPackets := func(packets []packet) []protocol.PacketNumber { - var zeroRTTPackets []protocol.PacketNumber - for _, p := range packets { - if p.hdr.Type == protocol.PacketType0RTT { - zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) - } - } - return zeroRTTPackets - } - - for _, l := range []int{0, 15} { - connIDLen := l - - It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { - tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - transfer0RTTData( - ln, - proxy.LocalPort(), - connIDLen, - clientTLSConf, - getQuicConfig(nil), - PRData, - ) - - var numNewConnIDs int - for _, p := range counter.getRcvdLongHeaderPackets() { - for _, f := range p.frames { - if _, ok := f.(*logging.NewConnectionIDFrame); ok { - numNewConnIDs++ - } - } - } - if connIDLen == 0 { - Expect(numNewConnIDs).To(BeZero()) - } else { - Expect(numNewConnIDs).ToNot(BeZero()) - } - - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) - Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) - }) - } - - // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. - It("waits for a connection until the handshake is done", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - zeroRTTData := GeneratePRData(5 << 10) - oneRTTData := PRData - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - // now accept the second connection, and receive the 0-RTT data - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(zeroRTTData)) - str, err = conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err = io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(oneRTTData)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - }() - - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - firstStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = firstStr.Write(zeroRTTData) - Expect(err).ToNot(HaveOccurred()) - Expect(firstStr.Close()).To(Succeed()) - - // wait for the handshake to complete - Eventually(conn.HandshakeComplete()).Should(BeClosed()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(PRData) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - <-conn.Context().Done() - - // check that 0-RTT packets only contain STREAM frames for the first stream - var num0RTT int - for _, p := range counter.getRcvdLongHeaderPackets() { - if p.hdr.Header.Type != protocol.PacketType0RTT { - continue - } - for _, f := range p.frames { - sf, ok := f.(*logging.StreamFrame) - if !ok { - continue - } - num0RTT++ - Expect(sf.StreamID).To(Equal(firstStr.StreamID())) - } - } - fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - }) - - It("transfers 0-RTT data, when 0-RTT packets are lost", func() { - var ( - num0RTTPackets uint32 // to be used as an atomic - num0RTTDropped uint32 - ) - - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - if wire.IsLongHeaderPacket(data[0]) { - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) - } - } - return rtt / 2 - }, - DropPacket: func(_ quicproxy.Direction, data []byte) bool { - if !wire.IsLongHeaderPacket(data[0]) { - return false - } - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - // drop 25% of the 0-RTT packets - drop := mrand.Intn(4) == 0 - if drop { - atomic.AddUint32(&num0RTTDropped, 1) - } - return drop - } - return false - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - - num0RTT := atomic.LoadUint32(&num0RTTPackets) - numDropped := atomic.LoadUint32(&num0RTTDropped) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) - Expect(numDropped).ToNot(BeZero()) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) - }) - - It("retransmits all 0-RTT data when the server performs a Retry", func() { - var mutex sync.Mutex - var firstConnID, secondConnID *protocol.ConnectionID - var firstCounter, secondCounter protocol.ByteCount - - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { - for len(data) > 0 { - hdr, _, rest, err := wire.ParsePacket(data) - if err != nil { - return - } - data = rest - if hdr.Type == protocol.PacketType0RTT { - n += hdr.Length - 16 /* AEAD tag */ - } - } - return - } - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return true }, - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { - connID, err := wire.ParseConnectionID(data, 0) - Expect(err).ToNot(HaveOccurred()) - - mutex.Lock() - defer mutex.Unlock() - - if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { - if firstConnID == nil { - firstConnID = &connID - firstCounter += zeroRTTBytes - } else if firstConnID != nil && *firstConnID == connID { - Expect(secondConnID).To(BeNil()) - firstCounter += zeroRTTBytes - } else if secondConnID == nil { - secondConnID = &connID - secondCounter += zeroRTTBytes - } else if secondConnID != nil && *secondConnID == connID { - secondCounter += zeroRTTBytes - } else { - Fail("received 3 connection IDs on 0-RTT packets") - } - } - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets - - mutex.Lock() - defer mutex.Unlock() - Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra - Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) - Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) - }) - - It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { - const maxStreams = 1 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: maxStreams, - })) - - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: maxStreams + 1, - Allow0RTT: true, - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - // The client remembers the old limit and refuses to open a new stream. - _, err = conn.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, err = conn.OpenUniStreamSync(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - }) - - It("rejects 0-RTT when the server's stream limit decreased", func() { - const maxStreams = 42 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingStreams: maxStreams, - })) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - MaxIncomingStreams: maxStreams - 1, - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - check0RTTRejected(ln, proxy.LocalPort(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - It("rejects 0-RTT when the ALPN changed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - // now close the listener and dial new connection with a different ALPN - // clientConf.NextProtos = []string{"new-alpn"} - clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") - tlsConf.NextProtos = []string{"new-alpn"} - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - check0RTTRejected(ln, proxy.LocalPort(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - It("rejects 0-RTT when the application doesn't allow it", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - // now close the listener and dial new connection with a different ALPN - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: false, // application rejects 0-RTT - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - check0RTTRejected(ln, proxy.LocalPort(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - DescribeTable("flow control limits", - func(addFlowControlLimit func(*quic.Config, uint64)) { - counter, tracer := newPacketTracer() - firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) - addFlowControlLimit(firstConf, 3) - tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) - - secondConf := getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }) - addFlowControlLimit(secondConf, 100) - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - secondConf, - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(written) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - - Eventually(written).Should(BeClosed()) - - serverConn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - rstr, err := serverConn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(rstr) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue()) - Expect(serverConn.CloseWithError(0, "")).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - - var processedFirst bool - for _, p := range counter.getRcvdLongHeaderPackets() { - for _, f := range p.frames { - if sf, ok := f.(*logging.StreamFrame); ok { - if !processedFirst { - // The first STREAM should have been sent in a 0-RTT packet. - // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. - Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(sf.Length).To(BeEquivalentTo(3)) - processedFirst = true - } else { - Fail("STREAM was shouldn't have been sent in 0-RTT") - } - } - } - } - }, - Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), - Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), - ) - - for _, l := range []int{0, 15} { - connIDLen := l - - It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - // now dial new connection with different transport parameters - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: 1, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(nil), - ) - Expect(err).ToNot(HaveOccurred()) - // The client remembers that it was allowed to open 2 uni-directional streams. - firstStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}, 2) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := firstStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() - secondStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := secondStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, err = conn.AcceptStream(ctx) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - Eventually(written).Should(Receive()) - Eventually(written).Should(Receive()) - _, err = firstStr.Write([]byte("foobar")) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - _, err = conn.OpenUniStream() - Expect(err).To(MatchError(quic.Err0RTTRejected)) - - _, err = conn.AcceptStream(ctx) - Expect(err).To(Equal(quic.Err0RTTRejected)) - - newConn := conn.NextConnection() - str, err := newConn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = newConn.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - _, err = str.Write([]byte("second flight")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - } - - It("queues 0-RTT packets, if the Initial is delayed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), - DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { - if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client - return rtt/2 + rtt - } - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - - Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) - Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) - }) - - It("sends 0-RTT datagrams", func() { - tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - EnableDatagrams: true, - })) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - EnableDatagrams: true, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - // second connection - sentMessage := GeneratePRData(100) - var receivedMessage []byte - received := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(received) - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - receivedMessage, err = conn.ReceiveMessage(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) - }() - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientTLSConf, - getQuicConfig(&quic.Config{ - EnableDatagrams: true, - }), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - Expect(conn.SendMessage(sentMessage)).To(Succeed()) - <-conn.HandshakeComplete() - <-received - - Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) - Expect(receivedMessage).To(Equal(sentMessage)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) - Expect(zeroRTTPackets).To(HaveLen(1)) - }) - - It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { - tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - EnableDatagrams: true, - })) - - counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - EnableDatagrams: false, - Tracer: newTracer(tracer), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - // second connection - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.ReceiveMessage(context.Background()) - Expect(err.Error()).To(Equal("datagram support disabled")) - <-conn.HandshakeComplete() - Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) - }() - conn, err := quic.DialAddrEarly( - context.Background(), - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientTLSConf, - getQuicConfig(&quic.Config{ - EnableDatagrams: true, - }), - ) - Expect(err).ToNot(HaveOccurred()) - // the client can temporarily send datagrams but the server doesn't process them. - Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) - <-conn.HandshakeComplete() - - Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) - Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) -}) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index e5ad5d4d3..0d1ebee48 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -1,12 +1,9 @@ -//go:build go1.21 - package self_test import ( "context" "fmt" "io" - mrand "math/rand" "net" "sync" "sync/atomic" @@ -57,8 +54,8 @@ func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionSt var _ = Describe("0-RTT", func() { rtt := scaleDuration(5 * time.Millisecond) - runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { - var num0RTTPackets uint32 // to be used as an atomic + runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *atomic.Uint32) { + var num0RTTPackets atomic.Uint32 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { @@ -69,7 +66,7 @@ var _ = Describe("0-RTT", func() { hdr, _, rest, err := wire.ParsePacket(data) Expect(err).ToNot(HaveOccurred()) if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) + num0RTTPackets.Add(1) break } data = rest @@ -179,6 +176,7 @@ var _ = Describe("0-RTT", func() { Conn: udpConn, ConnectionIDLength: connIDLen, } + addTracer(tr) defer tr.Close() conn, err = tr.DialEarly( context.Background(), @@ -290,7 +288,7 @@ var _ = Describe("0-RTT", func() { Expect(numNewConnIDs).ToNot(BeZero()) } - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) @@ -383,10 +381,7 @@ var _ = Describe("0-RTT", func() { }) It("transfers 0-RTT data, when 0-RTT packets are lost", func() { - var ( - num0RTTPackets uint32 // to be used as an atomic - num0RTTDropped uint32 - ) + var num0RTTPackets, numDropped atomic.Uint32 tlsConf := getTLSConfig() clientConf := getTLSClientConfig() @@ -405,17 +400,8 @@ var _ = Describe("0-RTT", func() { defer ln.Close() proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - if wire.IsLongHeaderPacket(data[0]) { - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) - } - } - return rtt / 2 - }, + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, DropPacket: func(_ quicproxy.Direction, data []byte) bool { if !wire.IsLongHeaderPacket(data[0]) { return false @@ -423,10 +409,11 @@ var _ = Describe("0-RTT", func() { hdr, _, _, err := wire.ParsePacket(data) Expect(err).ToNot(HaveOccurred()) if hdr.Type == protocol.PacketType0RTT { + count := num0RTTPackets.Add(1) // drop 25% of the 0-RTT packets - drop := mrand.Intn(4) == 0 + drop := count%4 == 0 if drop { - atomic.AddUint32(&num0RTTDropped, 1) + numDropped.Add(1) } return drop } @@ -438,10 +425,9 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - num0RTT := atomic.LoadUint32(&num0RTTPackets) - numDropped := atomic.LoadUint32(&num0RTTDropped) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) - Expect(numDropped).ToNot(BeZero()) + num0RTT := num0RTTPackets.Load() + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped.Load()) + Expect(numDropped.Load()).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) @@ -470,14 +456,20 @@ var _ = Describe("0-RTT", func() { } counter, tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", + laddr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", laddr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + VerifySourceAddress: func(net.Addr) bool { return true }, + } + addTracer(tr) + defer tr.Close() + ln, err := tr.ListenEarly( tlsConf, - getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return true }, - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true, Tracer: newTracer(tracer)}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -524,6 +516,39 @@ var _ = Describe("0-RTT", func() { Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) + It("doesn't use 0-RTT when Dial is used for the resumed connection", func() { + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(nil), clientConf) + + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{Allow0RTT: true}), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + Expect(num0RTTPackets.Load()).To(BeZero()) + + serverConn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + }) + It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { const maxStreams = 1 tlsConf := getTLSConfig() @@ -595,7 +620,7 @@ var _ = Describe("0-RTT", func() { check0RTTRejected(ln, proxy.LocalPort(), clientConf) // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) @@ -628,7 +653,7 @@ var _ = Describe("0-RTT", func() { check0RTTRejected(ln, proxy.LocalPort(), clientConf) // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) @@ -657,12 +682,55 @@ var _ = Describe("0-RTT", func() { check0RTTRejected(ln, proxy.LocalPort(), clientConf) // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) + It("doesn't use 0-RTT, if the server didn't enable it", func() { + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + defer server.Close() + + gets := make(chan string, 100) + puts := make(chan string, 100) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) + tlsConf := getTLSClientConfig() + tlsConf.ClientSessionCache = cache + conn1, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn1.CloseWithError(0, "") + var sessionKey string + Eventually(puts).Should(Receive(&sessionKey)) + Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse()) + + serverConn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) + + conn2, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(gets).To(Receive(Equal(sessionKey))) + Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue()) + + serverConn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + conn2.CloseWithError(0, "") + }) + DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { counter, tracer := newPacketTracer() @@ -813,7 +881,7 @@ var _ = Describe("0-RTT", func() { Expect(conn.CloseWithError(0, "")).To(Succeed()) // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) @@ -961,7 +1029,7 @@ var _ = Describe("0-RTT", func() { defer close(received) conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - receivedMessage, err = conn.ReceiveMessage(context.Background()) + receivedMessage, err = conn.ReceiveDatagram(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) }() @@ -975,14 +1043,14 @@ var _ = Describe("0-RTT", func() { ) Expect(err).ToNot(HaveOccurred()) Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - Expect(conn.SendMessage(sentMessage)).To(Succeed()) + Expect(conn.SendDatagram(sentMessage)).To(Succeed()) <-conn.HandshakeComplete() <-received Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(receivedMessage).To(Equal(sentMessage)) - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) @@ -1017,7 +1085,7 @@ var _ = Describe("0-RTT", func() { defer GinkgoRecover() conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - _, err = conn.ReceiveMessage(context.Background()) + _, err = conn.ReceiveDatagram(context.Background()) Expect(err.Error()).To(Equal("datagram support disabled")) <-conn.HandshakeComplete() Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) @@ -1033,13 +1101,13 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) // the client can temporarily send datagrams but the server doesn't process them. Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + Expect(conn.SendDatagram(make([]byte, 100))).To(Succeed()) <-conn.HandshakeComplete() Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) Expect(conn.CloseWithError(0, "")).To(Succeed()) - num0RTT := atomic.LoadUint32(num0RTTPackets) + num0RTT := num0RTTPackets.Load() fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index bef42f289..d43e20f69 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -141,7 +141,7 @@ var _ = Describe("QUIC Proxy", func() { Context("Proxy tests", func() { var ( serverConn *net.UDPConn - serverNumPacketsSent int32 + serverNumPacketsSent atomic.Int32 serverReceivedPackets chan packetData clientConn *net.UDPConn proxy *QuicProxy @@ -159,9 +159,9 @@ var _ = Describe("QUIC Proxy", func() { BeforeEach(func() { stoppedReading = make(chan struct{}) serverReceivedPackets = make(chan packetData, 100) - atomic.StoreInt32(&serverNumPacketsSent, 0) + serverNumPacketsSent.Store(0) - // setup a dump UDP server + // set up a dump UDP server // in production this would be a QUIC server raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") Expect(err).ToNot(HaveOccurred()) @@ -181,7 +181,7 @@ var _ = Describe("QUIC Proxy", func() { data := buf[0:n] serverReceivedPackets <- packetData(data) // echo the packet - atomic.AddInt32(&serverNumPacketsSent, 1) + serverNumPacketsSent.Add(1) serverConn.WriteToUDP(data, addr) } }() @@ -236,7 +236,7 @@ var _ = Describe("QUIC Proxy", func() { }() Eventually(serverReceivedPackets).Should(HaveLen(2)) - Expect(atomic.LoadInt32(&serverNumPacketsSent)).To(BeEquivalentTo(2)) + Expect(serverNumPacketsSent.Load()).To(BeEquivalentTo(2)) Eventually(clientReceivedPackets).Should(HaveLen(2)) Expect(string(<-clientReceivedPackets)).To(ContainSubstring("foobar")) Expect(string(<-clientReceivedPackets)).To(ContainSubstring("decafbad")) @@ -245,14 +245,14 @@ var _ = Describe("QUIC Proxy", func() { Context("Drop Callbacks", func() { It("drops incoming packets", func() { - var counter int32 + var counter atomic.Int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), DropPacket: func(d Direction, _ []byte) bool { if d != DirectionIncoming { return false } - return atomic.AddInt32(&counter, 1)%2 == 1 + return counter.Add(1)%2 == 1 }, } startProxy(opts) @@ -267,14 +267,14 @@ var _ = Describe("QUIC Proxy", func() { It("drops outgoing packets", func() { const numPackets = 6 - var counter int32 + var counter atomic.Int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), DropPacket: func(d Direction, _ []byte) bool { if d != DirectionOutgoing { return false } - return atomic.AddInt32(&counter, 1)%2 == 1 + return counter.Add(1)%2 == 1 }, } startProxy(opts) @@ -315,7 +315,7 @@ var _ = Describe("QUIC Proxy", func() { } It("delays incoming packets", func() { - var counter int32 + var counter atomic.Int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), // delay packet 1 by 200 ms @@ -325,7 +325,7 @@ var _ = Describe("QUIC Proxy", func() { if d == DirectionOutgoing { return 0 } - p := atomic.AddInt32(&counter, 1) + p := counter.Add(1) return time.Duration(p) * delay }, } @@ -349,7 +349,7 @@ var _ = Describe("QUIC Proxy", func() { }) It("handles reordered packets", func() { - var counter int32 + var counter atomic.Int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), // delay packet 1 by 600 ms @@ -359,7 +359,7 @@ var _ = Describe("QUIC Proxy", func() { if d == DirectionOutgoing { return 0 } - p := atomic.AddInt32(&counter, 1) + p := counter.Add(1) return 600*time.Millisecond - time.Duration(p-1)*delay }, } @@ -407,7 +407,7 @@ var _ = Describe("QUIC Proxy", func() { It("delays outgoing packets", func() { const numPackets = 3 - var counter int32 + var counter atomic.Int32 opts := &Opts{ RemoteAddr: serverConn.LocalAddr().String(), // delay packet 1 by 200 ms @@ -417,7 +417,7 @@ var _ = Describe("QUIC Proxy", func() { if d == DirectionIncoming { return 0 } - p := atomic.AddInt32(&counter, 1) + p := counter.Add(1) return time.Duration(p) * delay }, } diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index 80d6476a0..8e0e86feb 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -7,6 +7,7 @@ import ( "io" "log" "os" + "time" quic "github.com/refraction-networking/uquic" "github.com/refraction-networking/uquic/internal/utils" @@ -14,13 +15,21 @@ import ( "github.com/refraction-networking/uquic/qlog" ) -func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { +func QlogTracer(logger io.Writer) *logging.Tracer { + filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05")) + fmt.Fprintf(logger, "Creating %s.\n", filename) + f, err := os.Create(filename) + if err != nil { + log.Fatalf("failed to create qlog file: %s", err) + return nil + } + bw := bufio.NewWriter(f) + return qlog.NewTracer(utils.NewBufferedWriteCloser(bw, f)) +} + +func NewQlogConnectionTracer(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { - role := "server" - if p == logging.PerspectiveClient { - role = "client" - } - filename := fmt.Sprintf("log_%x_%s.qlog", connID.Bytes(), role) + filename := fmt.Sprintf("log_%s_%s.qlog", connID, p.String()) fmt.Fprintf(logger, "Creating %s.\n", filename) f, err := os.Create(filename) if err != nil { diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index eeef9ea35..e9fe8da24 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -19,7 +19,7 @@ import ( ) type versioner interface { - GetVersion() protocol.VersionNumber + GetVersion() protocol.Version } type result struct { @@ -69,11 +69,11 @@ var _ = Describe("Handshake tests", func() { } } - var supportedVersions []protocol.VersionNumber + var supportedVersions []protocol.Version BeforeEach(func() { - supportedVersions = append([]quic.VersionNumber{}, protocol.SupportedVersions...) - protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{7, 8, 9, 10}...) + supportedVersions = append([]quic.Version{}, protocol.SupportedVersions...) + protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.Version{7, 8, 9, 10}...) }) AfterEach(func() { @@ -86,7 +86,7 @@ var _ = Describe("Handshake tests", func() { // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig := &quic.Config{} - serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} + serverConfig.Versions = []protocol.Version{7, 8, protocol.SupportedVersions[0], 9} serverResult, serverTracer := newVersionNegotiationTracer() serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer @@ -126,7 +126,7 @@ var _ = Describe("Handshake tests", func() { } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() - clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10} clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), @@ -170,7 +170,7 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) defer ln.Close() - clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} + clientVersions := []protocol.Version{7, 8, 9, protocol.SupportedVersions[0], 10} clientResult, clientTracer := newVersionNegotiationTracer() _, err = quic.DialAddr( context.Background(), diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index e1c7ef70d..20cfa7728 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -65,7 +65,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config { if !enableQlog { return c } - qlogger := tools.NewQlogger(GinkgoWriter) + qlogger := tools.NewQlogConnectionTracer(GinkgoWriter) if c.Tracer == nil { c.Tracer = qlogger } else if qlogger != nil { diff --git a/interface.go b/interface.go index 68fccd680..a918b15ec 100644 --- a/interface.go +++ b/interface.go @@ -17,8 +17,12 @@ import ( // The StreamID is the ID of a QUIC stream. type StreamID = protocol.StreamID +// A Version is a QUIC version number. +type Version = protocol.Version + // A VersionNumber is a QUIC version number. -type VersionNumber = protocol.VersionNumber +// Deprecated: VersionNumber was renamed to Version. +type VersionNumber = Version const ( // Version1 is RFC 9000 @@ -160,6 +164,9 @@ type Connection interface { OpenStream() (Stream, error) // OpenStreamSync opens a new bidirectional QUIC stream. // It blocks until a new stream can be opened. + // There is no signaling to the peer about new streams: + // The peer can only accept the stream after data has been sent on the stream, + // or the stream has been reset or closed. // If the error is non-nil, it satisfies the net.Error interface. // If the connection was closed due to a timeout, Timeout() will be true. OpenStreamSync(context.Context) (Stream, error) @@ -188,10 +195,14 @@ type Connection interface { // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState - // SendMessage sends a message as a datagram, as specified in RFC 9221. - SendMessage([]byte) error - // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. - ReceiveMessage(context.Context) ([]byte, error) + // SendDatagram sends a message using a QUIC datagram, as specified in RFC 9221. + // There is no delivery guarantee for DATAGRAM frames, they are not retransmitted if lost. + // The payload of the datagram needs to fit into a single QUIC packet. + // In addition, a datagram may be dropped before being sent out if the available packet size suddenly decreases. + // If the payload is too large to be sent at the current time, a DatagramTooLargeError is returned. + SendDatagram(payload []byte) error + // ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221. + ReceiveDatagram(context.Context) ([]byte, error) } // An EarlyConnection is a connection that is handshaking. @@ -252,7 +263,7 @@ type Config struct { GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. - Versions []VersionNumber + Versions []Version // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // If we don't receive any packet from the peer within this time, the connection attempt is aborted. // Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted. @@ -264,11 +275,6 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // RequireAddressValidation determines if a QUIC Retry packet is sent. - // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. - // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. - // If not set, every client is forced to prove its remote address. - RequireAddressValidation func(net.Addr) bool // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set @@ -328,8 +334,15 @@ type Config struct { Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } +// ClientHelloInfo contains information about an incoming connection attempt. type ClientHelloInfo struct { + // RemoteAddr is the remote address on the Initial packet. + // Unless AddrVerified is set, the address is not yet verified, and could be a spoofed IP address. RemoteAddr net.Addr + // AddrVerified says if the remote address was verified using QUIC's Retry mechanism. + // Note that the Retry mechanism costs one network roundtrip, + // and is not performed unless Transport.MaxUnvalidatedHandshakes is surpassed. + AddrVerified bool } // ConnectionState records basic details about a QUIC connection @@ -339,12 +352,12 @@ type ConnectionState struct { // SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated. // This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams). // If datagram support was negotiated, datagrams can be sent and received using the - // SendMessage and ReceiveMessage methods on the Connection. + // SendDatagram and ReceiveDatagram methods on the Connection. SupportsDatagrams bool // Used0RTT says if 0-RTT resumption was used. Used0RTT bool // Version is the QUIC version of the QUIC connection. - Version VersionNumber + Version Version // GSO says if generic segmentation offload is used GSO bool } diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 5f9071a6e..cabf558e4 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -20,5 +20,5 @@ func NewAckHandler( logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) - return sph, newReceivedPacketHandler(sph, rttStats, logger) + return sph, newReceivedPacketHandler(sph, logger) } diff --git a/internal/ackhandler/mock_ecn_handler_test.go b/internal/ackhandler/mock_ecn_handler_test.go index 949268f5e..9f42b74ca 100644 --- a/internal/ackhandler/mock_ecn_handler_test.go +++ b/internal/ackhandler/mock_ecn_handler_test.go @@ -5,6 +5,7 @@ // // mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_ecn_handler_test.go github.com/refraction-networking/uquic/internal/ackhandler ECNHandler // + // Package ackhandler is a generated GoMock package. package ackhandler diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index 2e755658d..d4bf0c5af 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker +// mockgen -typed -build_flags=-tags=gomock -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker // + // Package ackhandler is a generated GoMock package. package ackhandler @@ -47,9 +48,33 @@ func (m *MockSentPacketTracker) GetLowestPacketNotConfirmedAcked() protocol.Pack } // GetLowestPacketNotConfirmedAcked indicates an expected call of GetLowestPacketNotConfirmedAcked. -func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() *gomock.Call { +func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) + return &MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall{Call: call} +} + +// MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall wrap *gomock.Call +type MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall) Return(arg0 protocol.PacketNumber) *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall) Do(f func() protocol.PacketNumber) *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall) DoAndReturn(f func() protocol.PacketNumber) *MockSentPacketTrackerGetLowestPacketNotConfirmedAckedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedPacket mocks base method. @@ -59,7 +84,31 @@ func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 any) *gomock.Call { +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 any) *MockSentPacketTrackerReceivedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) + return &MockSentPacketTrackerReceivedPacketCall{Call: call} +} + +// MockSentPacketTrackerReceivedPacketCall wrap *gomock.Call +type MockSentPacketTrackerReceivedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketTrackerReceivedPacketCall) Return() *MockSentPacketTrackerReceivedPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketTrackerReceivedPacketCall) Do(f func(protocol.EncryptionLevel)) *MockSentPacketTrackerReceivedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketTrackerReceivedPacketCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockSentPacketTrackerReceivedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/ackhandler/mockgen.go b/internal/ackhandler/mockgen.go index b36c0de1e..e58e26c36 100644 --- a/internal/ackhandler/mockgen.go +++ b/internal/ackhandler/mockgen.go @@ -2,7 +2,7 @@ package ackhandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketTracker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" type SentPacketTracker = sentPacketTracker //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/refraction-networking/uquic/internal/ackhandler ECNHandler" diff --git a/internal/ackhandler/packet_number_generator.go b/internal/ackhandler/packet_number_generator.go index b29db772e..7675c1257 100644 --- a/internal/ackhandler/packet_number_generator.go +++ b/internal/ackhandler/packet_number_generator.go @@ -80,5 +80,5 @@ func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) - p.period = utils.Min(2*p.period, p.maxPeriod) + p.period = min(2*p.period, p.maxPeriod) } diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 98c5e6132..3b2b745ca 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -14,23 +14,19 @@ type receivedPacketHandler struct { initialPackets *receivedPacketTracker handshakePackets *receivedPacketTracker - appDataPackets *receivedPacketTracker + appDataPackets appDataReceivedPacketTracker lowest1RTTPacket protocol.PacketNumber } var _ ReceivedPacketHandler = &receivedPacketHandler{} -func newReceivedPacketHandler( - sentPackets sentPacketTracker, - rttStats *utils.RTTStats, - logger utils.Logger, -) ReceivedPacketHandler { +func newReceivedPacketHandler(sentPackets sentPacketTracker, logger utils.Logger) ReceivedPacketHandler { return &receivedPacketHandler{ sentPackets: sentPackets, - initialPackets: newReceivedPacketTracker(rttStats, logger), - handshakePackets: newReceivedPacketTracker(rttStats, logger), - appDataPackets: newReceivedPacketTracker(rttStats, logger), + initialPackets: newReceivedPacketTracker(), + handshakePackets: newReceivedPacketTracker(), + appDataPackets: *newAppDataReceivedPacketTracker(logger), lowest1RTTPacket: protocol.InvalidPacketNumber, } } @@ -88,41 +84,28 @@ func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { } func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { - var initialAlarm, handshakeAlarm time.Time - if h.initialPackets != nil { - initialAlarm = h.initialPackets.GetAlarmTimeout() - } - if h.handshakePackets != nil { - handshakeAlarm = h.handshakePackets.GetAlarmTimeout() - } - oneRTTAlarm := h.appDataPackets.GetAlarmTimeout() - return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm) + return h.appDataPackets.GetAlarmTimeout() } func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame { - var ack *wire.AckFrame //nolint:exhaustive // 0-RTT packets can't contain ACK frames. switch encLevel { case protocol.EncryptionInitial: if h.initialPackets != nil { - ack = h.initialPackets.GetAckFrame(onlyIfQueued) + return h.initialPackets.GetAckFrame() } + return nil case protocol.EncryptionHandshake: if h.handshakePackets != nil { - ack = h.handshakePackets.GetAckFrame(onlyIfQueued) + return h.handshakePackets.GetAckFrame() } + return nil case protocol.Encryption1RTT: - // 0-RTT packets can't contain ACK frames return h.appDataPackets.GetAckFrame(onlyIfQueued) default: + // 0-RTT packets can't contain ACK frames return nil } - // For Initial and Handshake ACKs, the delay time is ignored by the receiver. - // Set it to 0 in order to save bytes. - if ack != nil { - ack.DelayTime = 0 - } - return ack } func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool { diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index a537399b4..a7769c569 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -18,11 +18,7 @@ var _ = Describe("Received Packet Handler", func() { BeforeEach(func() { sentPackets = NewMockSentPacketTracker(mockCtrl) - handler = newReceivedPacketHandler( - sentPackets, - &utils.RTTStats{}, - utils.DefaultLogger, - ) + handler = newReceivedPacketHandler(sentPackets, utils.DefaultLogger) }) It("generates ACKs for different packet number spaces", func() { diff --git a/internal/ackhandler/received_packet_tracker.go b/internal/ackhandler/received_packet_tracker.go index 750d13f93..1a82b2af6 100644 --- a/internal/ackhandler/received_packet_tracker.go +++ b/internal/ackhandler/received_packet_tracker.go @@ -9,74 +9,128 @@ import ( "github.com/refraction-networking/uquic/internal/wire" ) -// number of ack-eliciting packets received before sending an ack. +// The receivedPacketTracker tracks packets for the Initial and Handshake packet number space. +// Every received packet is acknowledged immediately. +type receivedPacketTracker struct { + ect0, ect1, ecnce uint64 + + packetHistory receivedPacketHistory + + lastAck *wire.AckFrame + hasNewAck bool // true as soon as we received an ack-eliciting new packet +} + +func newReceivedPacketTracker() *receivedPacketTracker { + return &receivedPacketTracker{packetHistory: *newReceivedPacketHistory()} +} + +func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { + return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) + } + + //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECN-CE. + switch ecn { + case protocol.ECT0: + h.ect0++ + case protocol.ECT1: + h.ect1++ + case protocol.ECNCE: + h.ecnce++ + } + if !ackEliciting { + return nil + } + h.hasNewAck = true + return nil +} + +func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame { + if !h.hasNewAck { + return nil + } + + // This function always returns the same ACK frame struct, filled with the most recent values. + ack := h.lastAck + if ack == nil { + ack = &wire.AckFrame{} + } + ack.Reset() + ack.ECT0 = h.ect0 + ack.ECT1 = h.ect1 + ack.ECNCE = h.ecnce + ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) + + h.lastAck = ack + h.hasNewAck = false + return ack +} + +func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { + return h.packetHistory.IsPotentiallyDuplicate(pn) +} + +// number of ack-eliciting packets received before sending an ACK const packetsBeforeAck = 2 -type receivedPacketTracker struct { - largestObserved protocol.PacketNumber - ignoreBelow protocol.PacketNumber +// The appDataReceivedPacketTracker tracks packets received in the Application Data packet number space. +// It waits until at least 2 packets were received before queueing an ACK, or until the max_ack_delay was reached. +type appDataReceivedPacketTracker struct { + receivedPacketTracker + largestObservedRcvdTime time.Time - ect0, ect1, ecnce uint64 - packetHistory *receivedPacketHistory + largestObserved protocol.PacketNumber + ignoreBelow protocol.PacketNumber maxAckDelay time.Duration - rttStats *utils.RTTStats - - hasNewAck bool // true as soon as we received an ack-eliciting new packet - ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets + ackQueued bool // true if we need send a new ACK ackElicitingPacketsReceivedSinceLastAck int ackAlarm time.Time - lastAck *wire.AckFrame logger utils.Logger } -func newReceivedPacketTracker( - rttStats *utils.RTTStats, - logger utils.Logger, -) *receivedPacketTracker { - return &receivedPacketTracker{ - packetHistory: newReceivedPacketHistory(), - maxAckDelay: protocol.MaxAckDelay, - rttStats: rttStats, - logger: logger, +func newAppDataReceivedPacketTracker(logger utils.Logger) *appDataReceivedPacketTracker { + h := &appDataReceivedPacketTracker{ + receivedPacketTracker: *newReceivedPacketTracker(), + maxAckDelay: protocol.MaxAckDelay, + logger: logger, } + return h } -func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { - if isNew := h.packetHistory.ReceivedPacket(pn); !isNew { - return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn) +func (h *appDataReceivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error { + if err := h.receivedPacketTracker.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil { + return err } - - isMissing := h.isMissing(pn) if pn >= h.largestObserved { h.largestObserved = pn h.largestObservedRcvdTime = rcvTime } - - if ackEliciting { - h.hasNewAck = true + if !ackEliciting { + return nil } - if ackEliciting { - h.maybeQueueACK(pn, rcvTime, isMissing) + h.ackElicitingPacketsReceivedSinceLastAck++ + isMissing := h.isMissing(pn) + if !h.ackQueued && h.shouldQueueACK(pn, ecn, isMissing) { + h.ackQueued = true + h.ackAlarm = time.Time{} // cancel the ack alarm } - //nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE. - switch ecn { - case protocol.ECT0: - h.ect0++ - case protocol.ECT1: - h.ect1++ - case protocol.ECNCE: - h.ecnce++ + if !h.ackQueued { + // No ACK queued, but we'll need to acknowledge the packet after max_ack_delay. + h.ackAlarm = rcvTime.Add(h.maxAckDelay) + if h.logger.Debug() { + h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) + } } return nil } // IgnoreBelow sets a lower limit for acknowledging packets. // Packets with packet numbers smaller than p will not be acked. -func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { +func (h *appDataReceivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { if pn <= h.ignoreBelow { return } @@ -88,14 +142,14 @@ func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) { } // isMissing says if a packet was reported missing in the last ACK. -func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool { +func (h *appDataReceivedPacketTracker) isMissing(p protocol.PacketNumber) bool { if h.lastAck == nil || p < h.ignoreBelow { return false } return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p) } -func (h *receivedPacketTracker) hasNewMissingPackets() bool { +func (h *appDataReceivedPacketTracker) hasNewMissingPackets() bool { if h.lastAck == nil { return false } @@ -103,31 +157,21 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool { return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 } -// maybeQueueACK queues an ACK, if necessary. -func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { +func (h *appDataReceivedPacketTracker) shouldQueueACK(pn protocol.PacketNumber, ecn protocol.ECN, wasMissing bool) bool { // always acknowledge the first packet if h.lastAck == nil { - if !h.ackQueued { - h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") - } - h.ackQueued = true - return - } - - if h.ackQueued { - return + h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.") + return true } - h.ackElicitingPacketsReceivedSinceLastAck++ - // Send an ACK if this packet was reported missing in an ACK sent before. // Ack decimation with reordering relies on the timer to send an ACK, but if - // missing packets we reported in the previous ack, send an ACK immediately. + // missing packets we reported in the previous ACK, send an ACK immediately. if wasMissing { if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn) } - h.ackQueued = true + return true } // send an ACK every 2 ack-eliciting packets @@ -135,62 +179,42 @@ func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime if h.logger.Debug() { h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck) } - h.ackQueued = true - } else if h.ackAlarm.IsZero() { - if h.logger.Debug() { - h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay) - } - h.ackAlarm = rcvTime.Add(h.maxAckDelay) + return true } - // Queue an ACK if there are new missing packets to report. + // queue an ACK if there are new missing packets to report if h.hasNewMissingPackets() { h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.") - h.ackQueued = true + return true } - if h.ackQueued { - // cancel the ack alarm - h.ackAlarm = time.Time{} + // queue an ACK if the packet was ECN-CE marked + if ecn == protocol.ECNCE { + h.logger.Debugf("\tQueuing ACK because the packet was ECN-CE marked.") + return true } + return false } -func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { - if !h.hasNewAck { - return nil - } +func (h *appDataReceivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { now := time.Now() - if onlyIfQueued { - if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) { + if onlyIfQueued && !h.ackQueued { + if h.ackAlarm.IsZero() || h.ackAlarm.After(now) { return nil } - if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() { + if h.logger.Debug() && !h.ackAlarm.IsZero() { h.logger.Debugf("Sending ACK because the ACK timer expired.") } } - - // This function always returns the same ACK frame struct, filled with the most recent values. - ack := h.lastAck + ack := h.receivedPacketTracker.GetAckFrame() if ack == nil { - ack = &wire.AckFrame{} + return nil } - ack.Reset() - ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime)) - ack.ECT0 = h.ect0 - ack.ECT1 = h.ect1 - ack.ECNCE = h.ecnce - ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) - - h.lastAck = ack - h.ackAlarm = time.Time{} + ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime)) h.ackQueued = false - h.hasNewAck = false + h.ackAlarm = time.Time{} h.ackElicitingPacketsReceivedSinceLastAck = 0 return ack } -func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } - -func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool { - return h.packetHistory.IsPotentiallyDuplicate(pn) -} +func (h *appDataReceivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm } diff --git a/internal/ackhandler/received_packet_tracker_test.go b/internal/ackhandler/received_packet_tracker_test.go index a95e05684..0c7194646 100644 --- a/internal/ackhandler/received_packet_tracker_test.go +++ b/internal/ackhandler/received_packet_tracker_test.go @@ -12,20 +12,70 @@ import ( ) var _ = Describe("Received Packet Tracker", func() { - var ( - tracker *receivedPacketTracker - rttStats *utils.RTTStats - ) + var tracker *receivedPacketTracker BeforeEach(func() { - rttStats = &utils.RTTStats{} - tracker = newReceivedPacketTracker(rttStats, utils.DefaultLogger) + tracker = newReceivedPacketTracker() + }) + + It("acknowledges packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 3}})) + Expect(ack.DelayTime).To(BeZero()) + // now receive another packet + Expect(tracker.ReceivedPacket(protocol.PacketNumber(4), protocol.ECNNon, t.Add(time.Second), true)).To(Succeed()) + ack = tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 4}})) + Expect(ack.DelayTime).To(BeZero()) + }) + + It("also acknowledges delayed packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.LowestAcked()).To(Equal(protocol.PacketNumber(3))) + Expect(ack.DelayTime).To(BeZero()) + // now receive another packet + Expect(tracker.ReceivedPacket(protocol.PacketNumber(1), protocol.ECNNon, t.Add(time.Second), true)).To(Succeed()) + ack = tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(2)) + Expect(ack.AckRanges).To(ContainElement(wire.AckRange{Smallest: 1, Largest: 1})) + Expect(ack.AckRanges).To(ContainElement(wire.AckRange{Smallest: 3, Largest: 3})) + Expect(ack.DelayTime).To(BeZero()) + }) + + It("doesn't trigger ACKs for non-ack-eliciting packets", func() { + t := time.Now().Add(-10 * time.Second) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, false)).To(Succeed()) + Expect(tracker.GetAckFrame()).To(BeNil()) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(4), protocol.ECNNon, t.Add(5*time.Second), false)).To(Succeed()) + Expect(tracker.GetAckFrame()).To(BeNil()) + Expect(tracker.ReceivedPacket(protocol.PacketNumber(5), protocol.ECNNon, t.Add(10*time.Second), true)).To(Succeed()) + ack := tracker.GetAckFrame() + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(Equal([]wire.AckRange{{Smallest: 3, Largest: 5}})) + }) +}) + +var _ = Describe("Application Data Received Packet Tracker", func() { + var tracker *appDataReceivedPacketTracker + + BeforeEach(func() { + tracker = newAppDataReceivedPacketTracker(utils.DefaultLogger) }) Context("accepting packets", func() { It("saves the time when each packet arrived", func() { - Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, time.Now(), true)).To(Succeed()) - Expect(tracker.largestObservedRcvdTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) + t := time.Now() + Expect(tracker.ReceivedPacket(protocol.PacketNumber(3), protocol.ECNNon, t, true)).To(Succeed()) + Expect(tracker.largestObservedRcvdTime).To(Equal(t)) }) It("updates the largestObserved and the largestObservedRcvdTime", func() { @@ -50,6 +100,7 @@ var _ = Describe("Received Packet Tracker", func() { Context("ACKs", func() { Context("queueing ACKs", func() { + // receives and gets ACKs for packet numbers 1 to 10 (including) receiveAndAck10Packets := func() { for i := 1; i <= 10; i++ { Expect(tracker.ReceivedPacket(protocol.PacketNumber(i), protocol.ECNNon, time.Time{}, true)).To(Succeed()) @@ -126,6 +177,16 @@ var _ = Describe("Received Packet Tracker", func() { Expect(tracker.GetAlarmTimeout()).To(Equal(rcvTime.Add(protocol.MaxAckDelay))) }) + It("queues an ACK if the packet was ECN-CE marked", func() { + receiveAndAck10Packets() + Expect(tracker.ReceivedPacket(11, protocol.ECNCE, time.Now(), true)).To(Succeed()) + ack := tracker.GetAckFrame(true) + Expect(ack).ToNot(BeNil()) + Expect(ack.AckRanges).To(HaveLen(1)) + Expect(ack.AckRanges[0].Largest).To(Equal(protocol.PacketNumber(11))) + Expect(ack.ECNCE).To(BeEquivalentTo(1)) + }) + It("queues an ACK if it was reported missing before", func() { receiveAndAck10Packets() Expect(tracker.ReceivedPacket(11, protocol.ECNNon, time.Now(), true)).To(Succeed()) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 716c406cc..2790c6080 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -245,7 +245,7 @@ func (h *sentPacketHandler) SentPacket( pnSpace := h.getPacketNumberSpace(encLevel) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ { + for p := max(0, pnSpace.largestSent+1); p < pn; p++ { h.logger.Debugf("Skipping packet number %d", p) } } @@ -336,7 +336,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // don't use the ack delay for Initial and Handshake packets var ackDelay time.Duration if encLevel == protocol.Encryption1RTT { - ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay()) + ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay()) } h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime) if h.logger.Debug() { @@ -354,7 +354,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En } } - pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked) + pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked) if err := h.detectLostPackets(rcvTime, encLevel); err != nil { return false, err @@ -446,7 +446,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL for _, p := range h.ackedPackets { if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT { - h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1) + h.lowestNotConfirmedAcked = max(h.lowestNotConfirmedAcked, p.LargestAcked+1) } for _, f := range p.Frames { @@ -607,11 +607,11 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E pnSpace := h.getPacketNumberSpace(encLevel) pnSpace.lossTime = time.Time{} - maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) + maxRTT := float64(max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT())) lossDelay := time.Duration(timeThreshold * maxRTT) // Minimum time of granularity before packets are deemed lost. - lossDelay = utils.Max(lossDelay, protocol.TimerGranularity) + lossDelay = max(lossDelay, protocol.TimerGranularity) // Packets sent before this time are deemed lost. lostSendTime := now.Add(-lossDelay) @@ -891,7 +891,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { // Otherwise, we don't know which Initial the Retry was sent in response to. if h.ptoCount == 0 { // Don't set the RTT to a value lower than 5ms here. - h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) + h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } diff --git a/internal/ackhandler/u_ackhandler.go b/internal/ackhandler/u_ackhandler.go index 00a2a8c0a..8a4bdeb65 100644 --- a/internal/ackhandler/u_ackhandler.go +++ b/internal/ackhandler/u_ackhandler.go @@ -20,5 +20,5 @@ func NewUAckHandler( sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) return &uSentPacketHandler{ sentPacketHandler: sph, - }, newReceivedPacketHandler(sph, rttStats, logger) + }, newReceivedPacketHandler(sph, logger) } diff --git a/internal/congestion/cubic.go b/internal/congestion/cubic.go index df18b7be9..4ee24107d 100644 --- a/internal/congestion/cubic.go +++ b/internal/congestion/cubic.go @@ -5,7 +5,6 @@ import ( "time" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" ) // This cubic implementation is based on the one found in Chromiums's QUIC @@ -187,7 +186,7 @@ func (c *Cubic) CongestionWindowAfterAck( targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow } // Limit the CWND increase to half the acked bytes. - targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) + targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) // Increase the window by approximately Alpha * 1 MSS of bytes every // time we ack an estimated tcp window of bytes. For small diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index e5d297644..7649a6984 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -178,7 +178,7 @@ func (c *cubicSender) OnPacketAcked( priorInFlight protocol.ByteCount, eventTime time.Time, ) { - c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber) + c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber) if c.InRecovery() { return } @@ -246,7 +246,7 @@ func (c *cubicSender) maybeIncreaseCwnd( c.numAckedPackets = 0 } } else { - c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) + c.congestionWindow = min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime)) } } diff --git a/internal/congestion/hybrid_slow_start.go b/internal/congestion/hybrid_slow_start.go index ae1909716..af3d76977 100644 --- a/internal/congestion/hybrid_slow_start.go +++ b/internal/congestion/hybrid_slow_start.go @@ -4,7 +4,6 @@ import ( "time" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" ) // Note(pwestin): the magic clamping numbers come from the original code in @@ -75,8 +74,8 @@ func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT ti // Divide minRTT by 8 to get a rtt increase threshold for exiting. minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp) // Ensure the rtt threshold is never less than 2ms or more than 16ms. - minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) - minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond + minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs) + minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) { s.hystartFound = true diff --git a/internal/congestion/pacer.go b/internal/congestion/pacer.go index 3b4334045..79761c41b 100644 --- a/internal/congestion/pacer.go +++ b/internal/congestion/pacer.go @@ -1,11 +1,9 @@ package congestion import ( - "math" "time" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" ) const maxBurstSizePackets = 10 @@ -26,7 +24,7 @@ func newPacer(getBandwidth func() Bandwidth) *pacer { bw := uint64(getBandwidth() / BytesPerSecond) // Use a slightly higher value than the actual measured bandwidth. // RTT variations then won't result in under-utilization of the congestion window. - // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, + // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. return bw * 5 / 4 }, @@ -37,7 +35,7 @@ func newPacer(getBandwidth func() Bandwidth) *pacer { func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) { budget := p.Budget(sendTime) - if size > budget { + if size >= budget { p.budgetAtLastSent = 0 } else { p.budgetAtLastSent = budget - size @@ -53,11 +51,11 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount { if budget < 0 { // protect against overflows budget = protocol.MaxByteCount } - return utils.Min(p.maxBurstSize(), budget) + return min(p.maxBurstSize(), budget) } func (p *pacer) maxBurstSize() protocol.ByteCount { - return utils.Max( + return max( protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9, maxBurstSizePackets*p.maxDatagramSize, ) @@ -69,10 +67,16 @@ func (p *pacer) TimeUntilSend() time.Time { if p.budgetAtLastSent >= p.maxDatagramSize { return time.Time{} } - return p.lastSentTime.Add(utils.Max( - protocol.MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond, - )) + diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) + bw := p.adjustedBandwidth() + // We might need to round up this value. + // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. + d := diff / bw + // this is effectively a math.Ceil, but using only integer math + if diff%bw > 0 { + d++ + } + return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond)) } func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) { diff --git a/internal/congestion/pacer_test.go b/internal/congestion/pacer_test.go index 9843dab18..8d5e8a609 100644 --- a/internal/congestion/pacer_test.go +++ b/internal/congestion/pacer_test.go @@ -101,6 +101,17 @@ var _ = Describe("Pacer", func() { Expect(p.Budget(t.Add(5 * t2.Sub(t)))).To(BeEquivalentTo(5 * packetSize)) }) + It("has enough budget for at least one packet when the timer expires", func() { + t := time.Now() + sendBurst(t) + for bw := uint64(100); bw < uint64(5*initialMaxDatagramSize); bw++ { + bandwidth = bw // reduce the bandwidth to 5 packet per second + t2 := p.TimeUntilSend() + Expect(t2).To(BeTemporally(">", t)) + Expect(p.Budget(t2)).To(BeNumerically(">=", initialMaxDatagramSize)) + } + }) + It("never allows bursts larger than the maximum burst size", func() { t := time.Now() sendBurst(t) diff --git a/internal/flowcontrol/base_flow_controller.go b/internal/flowcontrol/base_flow_controller.go index 537c8dae5..0a75ff6bb 100644 --- a/internal/flowcontrol/base_flow_controller.go +++ b/internal/flowcontrol/base_flow_controller.go @@ -48,10 +48,12 @@ func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) { } // UpdateSendWindow is called after receiving a MAX_{STREAM_}DATA frame. -func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) { +func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) (updated bool) { if offset > c.sendWindow { c.sendWindow = offset + return true } + return false } func (c *baseFlowController) sendWindowSize() protocol.ByteCount { @@ -107,7 +109,7 @@ func (c *baseFlowController) maybeAdjustWindowSize() { now := time.Now() if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) { // window is consumed too fast, try to increase the window size - newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize) + newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize) if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) { c.receiveWindowSize = newSize } diff --git a/internal/flowcontrol/base_flow_controller_test.go b/internal/flowcontrol/base_flow_controller_test.go index db22b88bd..619d4d0dc 100644 --- a/internal/flowcontrol/base_flow_controller_test.go +++ b/internal/flowcontrol/base_flow_controller_test.go @@ -59,9 +59,9 @@ var _ = Describe("Base Flow controller", func() { }) It("does not decrease the flow control window", func() { - controller.UpdateSendWindow(20) + Expect(controller.UpdateSendWindow(20)).To(BeTrue()) Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) - controller.UpdateSendWindow(10) + Expect(controller.UpdateSendWindow(10)).To(BeFalse()) Expect(controller.sendWindowSize()).To(Equal(protocol.ByteCount(20))) }) diff --git a/internal/flowcontrol/connection_flow_controller.go b/internal/flowcontrol/connection_flow_controller.go index 980cbc59d..868ab70ee 100644 --- a/internal/flowcontrol/connection_flow_controller.go +++ b/internal/flowcontrol/connection_flow_controller.go @@ -87,7 +87,7 @@ func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCoun c.mutex.Lock() if inc > c.receiveWindowSize { c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10)) - newSize := utils.Min(inc, c.maxReceiveWindowSize) + newSize := min(inc, c.maxReceiveWindowSize) if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) { c.receiveWindowSize = newSize } diff --git a/internal/flowcontrol/interface.go b/internal/flowcontrol/interface.go index b1a5d59ec..76d890c7e 100644 --- a/internal/flowcontrol/interface.go +++ b/internal/flowcontrol/interface.go @@ -5,7 +5,7 @@ import "github.com/refraction-networking/uquic/internal/protocol" type flowController interface { // for sending SendWindowSize() protocol.ByteCount - UpdateSendWindow(protocol.ByteCount) + UpdateSendWindow(protocol.ByteCount) (updated bool) AddBytesSent(protocol.ByteCount) // for receiving AddBytesRead(protocol.ByteCount) @@ -16,12 +16,11 @@ type flowController interface { // A StreamFlowController is a flow controller for a QUIC stream. type StreamFlowController interface { flowController - // for receiving - // UpdateHighestReceived should be called when a new highest offset is received + // UpdateHighestReceived is called when a new highest offset is received // final has to be to true if this is the final offset of the stream, // as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame UpdateHighestReceived(offset protocol.ByteCount, final bool) error - // Abandon should be called when reading from the stream is aborted early, + // Abandon is called when reading from the stream is aborted early, // and there won't be any further calls to AddBytesRead. Abandon() } diff --git a/internal/flowcontrol/stream_flow_controller.go b/internal/flowcontrol/stream_flow_controller.go index 7db487406..a3c400097 100644 --- a/internal/flowcontrol/stream_flow_controller.go +++ b/internal/flowcontrol/stream_flow_controller.go @@ -123,7 +123,7 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) { } func (c *streamFlowController) SendWindowSize() protocol.ByteCount { - return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) + return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize()) } func (c *streamFlowController) shouldQueueWindowUpdate() bool { diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 6d88bfdaa..41384e030 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -1,14 +1,12 @@ package handshake import ( - "crypto/cipher" "encoding/binary" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" ) -func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { +func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.Version) *xorNonceAEAD { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { @@ -21,28 +19,26 @@ func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumb } type longHeaderSealer struct { - aead cipher.AEAD + aead *xorNonceAEAD headerProtector headerProtector - - // use a single slice to avoid allocations - nonceBuf []byte + nonceBuf [8]byte } var _ LongHeaderSealer = &longHeaderSealer{} -func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer { +func newLongHeaderSealer(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderSealer { + if aead.NonceSize() != 8 { + panic("unexpected nonce size") + } return &longHeaderSealer{ aead: aead, headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), } } func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { - binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - return s.aead.Seal(dst, s.nonceBuf, src, ad) + binary.BigEndian.PutUint64(s.nonceBuf[:], uint64(pn)) + return s.aead.Seal(dst, s.nonceBuf[:], src, ad) } func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) { @@ -54,21 +50,23 @@ func (s *longHeaderSealer) Overhead() int { } type longHeaderOpener struct { - aead cipher.AEAD + aead *xorNonceAEAD headerProtector headerProtector highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) - // use a single slice to avoid allocations - nonceBuf []byte + // use a single array to avoid allocations + nonceBuf [8]byte } var _ LongHeaderOpener = &longHeaderOpener{} -func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener { +func newLongHeaderOpener(aead *xorNonceAEAD, headerProtector headerProtector) LongHeaderOpener { + if aead.NonceSize() != 8 { + panic("unexpected nonce size") + } return &longHeaderOpener{ aead: aead, headerProtector: headerProtector, - nonceBuf: make([]byte, aead.NonceSize()), } } @@ -77,12 +75,10 @@ func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wire } func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) { - binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn)) - // The AEAD we're using here will be the qtls.aeadAESGCM13. - // It uses the nonce provided here and XOR it with the IV. - dec, err := o.aead.Open(dst, o.nonceBuf, src, ad) + binary.BigEndian.PutUint64(o.nonceBuf[:], uint64(pn)) + dec, err := o.aead.Open(dst, o.nonceBuf[:], src, ad) if err == nil { - o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn) + o.highestRcvdPN = max(o.highestRcvdPN, pn) } else { err = ErrDecryptionFailed } diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index f9cf64970..8e37b0424 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -16,7 +16,7 @@ import ( ) var _ = Describe("Long Header AEAD", func() { - for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} { + for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { v := ver Context(fmt.Sprintf("using version %s", v), func() { @@ -34,8 +34,8 @@ var _ = Describe("Long Header AEAD", func() { aead, err := cipher.NewGCM(block) Expect(err).ToNot(HaveOccurred()) - return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), - newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) + return newLongHeaderSealer(&xorNonceAEAD{aead: aead}, newHeaderProtector(cs, hpKey, true, v)), + newLongHeaderOpener(&xorNonceAEAD{aead: aead}, newHeaderProtector(cs, hpKey, true, v)) } Context("message encryption", func() { diff --git a/internal/handshake/cipher_suite.go b/internal/handshake/cipher_suite.go index 368a95d48..5f5e6c6ff 100644 --- a/internal/handshake/cipher_suite.go +++ b/internal/handshake/cipher_suite.go @@ -19,7 +19,7 @@ type cipherSuite struct { ID uint16 Hash crypto.Hash KeyLen int - AEAD func(key, nonceMask []byte) cipher.AEAD + AEAD func(key, nonceMask []byte) *xorNonceAEAD } func (s cipherSuite) IVLen() int { return aeadNonceLength } @@ -37,7 +37,7 @@ func getCipherSuite(id uint16) *cipherSuite { } } -func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { +func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } @@ -55,7 +55,7 @@ func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { return ret } -func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD { +func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD { if len(nonceMask) != aeadNonceLength { panic("tls: internal error: wrong nonce length") } diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 07ade6aec..a7cb5f18b 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "strings" - "sync" "sync/atomic" "time" @@ -26,15 +25,15 @@ type quicVersionContextKey struct{} var QUICVersionContextKey = &quicVersionContextKey{} -const clientSessionStateRevision = 3 +const clientSessionStateRevision = 4 type cryptoSetup struct { tlsConf *tls.Config - conn *qtls.QUICConn + conn *tls.QUICConn events []Event - version protocol.VersionNumber + version protocol.Version ourParams *wire.TransportParameters peerParams *wire.TransportParameters @@ -49,8 +48,6 @@ type cryptoSetup struct { perspective protocol.Perspective - mutex sync.Mutex // protects all members below - handshakeCompleteTime time.Time zeroRTTOpener LongHeaderOpener // only set for the server @@ -80,7 +77,7 @@ func NewCryptoSetupClient( rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, ) CryptoSetup { cs := newCryptoSetup( connID, @@ -94,11 +91,12 @@ func NewCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf + cs.allow0RTT = enable0RTT - cs.conn = qtls.QUICClient(quicConf) + cs.conn = tls.QUICClient(quicConf) cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) return cs @@ -114,7 +112,7 @@ func NewCryptoSetupServer( rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, ) CryptoSetup { cs := newCryptoSetup( connID, @@ -127,12 +125,12 @@ func NewCryptoSetupServer( ) cs.allow0RTT = allow0RTT - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) cs.tlsConf = quicConf.TLSConfig - cs.conn = qtls.QUICServer(quicConf) + cs.conn = tls.QUICServer(quicConf) return cs } @@ -147,6 +145,9 @@ func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr} c, err := gcfc(info) if c != nil { + c = c.Clone() + // This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted. + c.MinVersion = tls.VersionTLS13 // We're returning a tls.Config here, so we need to apply this recursively. addConnToClientHelloInfo(c, localAddr, remoteAddr) } @@ -169,7 +170,7 @@ func newCryptoSetup( tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, - version protocol.VersionNumber, + version protocol.Version, ) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) if tracer != nil && tracer.UpdatedKeyFromTLS != nil { @@ -261,29 +262,28 @@ func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLev } } -func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { +func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { switch ev.Kind { - case qtls.QUICNoEvent: + case tls.QUICNoEvent: return true, nil - case qtls.QUICSetReadSecret: - h.SetReadKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetReadSecret: + h.setReadKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICSetWriteSecret: - h.SetWriteKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetWriteSecret: + h.setWriteKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICTransportParameters: + case tls.QUICTransportParameters: return false, h.handleTransportParameters(ev.Data) - case qtls.QUICTransportParametersRequired: + case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) - // [UQUIC] doesn't expect this and may fail return false, nil - case qtls.QUICRejectedEarlyData: + case tls.QUICRejectedEarlyData: h.rejected0RTT() return false, nil - case qtls.QUICWriteData: - h.WriteRecord(ev.Level, ev.Data) + case tls.QUICWriteData: + h.writeRecord(ev.Level, ev.Data) return false, nil - case qtls.QUICHandshakeDone: + case tls.QUICHandshakeDone: h.handshakeComplete() return false, nil default: @@ -311,41 +311,56 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) error { } // must be called after receiving the transport parameters -func (h *cryptoSetup) marshalDataForSessionState() []byte { +func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) - return h.peerParams.MarshalForSessionTicket(b) + if earlyData { + // only save the transport parameters for 0-RTT enabled session tickets + return h.peerParams.MarshalForSessionTicket(b) + } + return b } -func (h *cryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) +func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { + rtt, tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } - h.zeroRTTParameters = tp + h.rttStats.SetInitialRTT(rtt) + // The session ticket might have been saved from a connection that allowed 0-RTT, + // and therefore contain transport parameters. + // Only use them if 0-RTT is actually used on the new connection. + if tp != nil && h.allow0RTT { + h.zeroRTTParameters = tp + return true + } + return false } -func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { +func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) { r := bytes.NewReader(data) ver, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err } if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) + return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) } - rtt, err := quicvarint.Read(r) + rttEncoded, err := quicvarint.Read(r) if err != nil { - return nil, err + return 0, nil, err + } + rtt := time.Duration(rttEncoded) * time.Microsecond + if !earlyData { + return rtt, nil, nil } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) var tp wire.TransportParameters if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err + return 0, nil, err } - return &tp, nil + return rtt, &tp, nil } func (h *cryptoSetup) getDataForSessionTicket() []byte { @@ -362,7 +377,9 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte { // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: h.allow0RTT, + }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. @@ -374,18 +391,20 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { return nil, err } ev := h.conn.NextEvent() - if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { + if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication { panic("crypto/tls bug: where's my session ticket?") } ticket := ev.Data - if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { + if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent { panic("crypto/tls bug: why more than one ticket?") } return ticket, nil } // handleSessionTicket is called for the server when receiving the client's session ticket. -// It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT. +// It reads parameters from the session ticket and checks whether to accept 0-RTT if the session ticket enabled 0-RTT. +// Note that the fact that the session ticket allows 0-RTT doesn't mean that the actual TLS handshake enables 0-RTT: +// A client may use a 0-RTT enabled session to resume a TLS session without using 0-RTT. func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool { var t sessionTicket if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil { @@ -413,22 +432,19 @@ func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bo func (h *cryptoSetup) rejected0RTT() { h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - h.mutex.Lock() had0RTTKeys := h.zeroRTTSealer != nil h.zeroRTTSealer = nil - h.mutex.Unlock() if had0RTTKeys { h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } -func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *cryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -440,7 +456,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -448,7 +464,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -457,19 +473,17 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr default: panic("unexpected read encryption level") } - h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *cryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -477,7 +491,6 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.mutex.Unlock() if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } @@ -486,7 +499,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t } // don't set used0RTT here. 0-RTT might still get rejected. return - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -494,7 +507,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { @@ -512,21 +525,20 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t default: panic("unexpected write encryption level") } - h.mutex.Unlock() if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } -// WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { +// writeRecord is called when TLS writes data +func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { - case qtls.QUICEncryptionLevelInitial: + case tls.QUICEncryptionLevelInitial: h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) @@ -534,11 +546,9 @@ func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { } func (h *cryptoSetup) DiscardInitialKeys() { - h.mutex.Lock() dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Initial keys.") } @@ -553,22 +563,17 @@ func (h *cryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool - h.mutex.Lock() if h.handshakeOpener != nil { h.handshakeOpener = nil h.handshakeSealer = nil dropped = true } - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Handshake keys.") } } func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialSealer == nil { return nil, ErrKeysDropped } @@ -576,9 +581,6 @@ func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTSealer == nil { return nil, ErrKeysDropped } @@ -586,9 +588,6 @@ func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeSealer == nil { if h.initialSealer == nil { return nil, ErrKeysDropped @@ -599,9 +598,6 @@ func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { } func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if !h.has1RTTSealer { return nil, ErrKeysNotYetAvailable } @@ -609,9 +605,6 @@ func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { } func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialOpener == nil { return nil, ErrKeysDropped } @@ -619,9 +612,6 @@ func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -633,9 +623,6 @@ func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -647,9 +634,6 @@ func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { } func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") @@ -673,7 +657,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState { func wrapError(err error) error { // alert 80 is an internal error - if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { + if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { return qerr.NewLocalCryptoError(uint8(alertErr), err) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 2a57ff91c..94ce5b119 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -7,8 +7,7 @@ import ( "crypto/x509/pkix" "math/big" "net" - "runtime" - "strings" + "reflect" "time" tls "github.com/refraction-networking/utls" @@ -140,32 +139,43 @@ var _ = Describe("Crypto Setup TLS", func() { }, } addConnToClientHelloInfo(tlsConf, local, remote) - _, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) + conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) Expect(err).ToNot(HaveOccurred()) Expect(localAddr).To(Equal(local)) Expect(remoteAddr).To(Equal(remote)) + Expect(conf).ToNot(BeNil()) + Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) }) It("wraps GetConfigForClient, recursively", func() { var localAddr, remoteAddr net.Addr tlsConf := &tls.Config{} + var innerConf *tls.Config + getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam + localAddr = info.Conn.LocalAddr() + remoteAddr = info.Conn.RemoteAddr() + cert := generateCert() + return &cert, nil + } tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { - conf := tlsConf.Clone() - conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - localAddr = info.Conn.LocalAddr() - remoteAddr = info.Conn.RemoteAddr() - cert := generateCert() - return &cert, nil - } - return conf, nil + innerConf = tlsConf.Clone() + // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config + innerConf.MaxVersion = tls.VersionTLS12 + innerConf.GetCertificate = getCert + return innerConf, nil } addConnToClientHelloInfo(tlsConf, local, remote) conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) Expect(err).ToNot(HaveOccurred()) + Expect(conf).ToNot(BeNil()) + Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) Expect(err).ToNot(HaveOccurred()) Expect(localAddr).To(Equal(local)) Expect(remoteAddr).To(Equal(remote)) + // make sure that the tls.Config returned by GetConfigForClient isn't modified + Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) + Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) }) }) @@ -452,9 +462,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(server.ConnectionState().DidResume).To(BeTrue()) Expect(client.ConnectionState().DidResume).To(BeTrue()) Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) - if !strings.Contains(runtime.Version(), "go1.20") { - Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) - } + Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) }) It("doesn't use session resumption if the server disabled it", func() { diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index 21def7130..eb586cbbe 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -18,14 +18,14 @@ type headerProtector interface { DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) } -func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { +func hkdfHeaderProtectionLabel(v protocol.Version) string { if v == protocol.Version2 { return "quicv2 hp" } return "quic hp" } -func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { +func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.Version) headerProtector { hkdfLabel := hkdfHeaderProtectionLabel(v) switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: @@ -38,7 +38,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b } type aesHeaderProtector struct { - mask []byte + mask [16]byte // AES always has a 16 byte block size block cipher.Block isLongHeader bool } @@ -53,7 +53,6 @@ func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeade } return &aesHeaderProtector{ block: block, - mask: make([]byte, block.BlockSize()), isLongHeader: isLongHeader, } } @@ -70,7 +69,7 @@ func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []by if len(sample) != len(p.mask) { panic("invalid sample size") } - p.block.Encrypt(p.mask, sample) + p.block.Encrypt(p.mask[:], sample) if p.isLongHeader { *firstByte ^= p.mask[0] & 0xf } else { diff --git a/internal/handshake/hkdf.go b/internal/handshake/hkdf.go index c4fd86c57..0caf1c8e5 100644 --- a/internal/handshake/hkdf.go +++ b/internal/handshake/hkdf.go @@ -7,7 +7,7 @@ import ( "golang.org/x/crypto/hkdf" ) -// hkdfExpandLabel HKDF expands a label. +// hkdfExpandLabel HKDF expands a label as defined in RFC 8446, section 7.1. // Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the // hkdfExpandLabel in the standard library. func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte { diff --git a/internal/handshake/hkdf_test.go b/internal/handshake/hkdf_test.go index e79a6f1e9..1cc161299 100644 --- a/internal/handshake/hkdf_test.go +++ b/internal/handshake/hkdf_test.go @@ -2,16 +2,67 @@ package handshake import ( "crypto" + "crypto/cipher" + "crypto/tls" + "testing" + _ "unsafe" + + "golang.org/x/exp/rand" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) -var _ = Describe("Initial AEAD using AES-GCM", func() { - // Result generated by running in qtls: - // cipherSuiteTLS13ByID(TLS_AES_128_GCM_SHA256).expandLabel([]byte("secret"), []byte("context"), "label", 42) - It("gets the same results as qtls", func() { - expanded := hkdfExpandLabel(crypto.SHA256, []byte("secret"), []byte("context"), "label", 42) - Expect(expanded).To(Equal([]byte{0x78, 0x87, 0x6a, 0xb5, 0x84, 0xa2, 0x26, 0xb7, 0x8, 0x5a, 0x7b, 0x3a, 0x4c, 0xbb, 0x1e, 0xbc, 0x2f, 0x9b, 0x67, 0xd0, 0x6a, 0xa2, 0x24, 0xb4, 0x7d, 0x29, 0x3c, 0x7a, 0xce, 0xc7, 0xc3, 0x74, 0xcd, 0x59, 0x7a, 0xa8, 0x21, 0x5e, 0xe7, 0xca, 0x1, 0xda})) - }) +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +//go:linkname expandLabel crypto/tls.(*cipherSuiteTLS13).expandLabel +func expandLabel(cs *cipherSuiteTLS13, secret []byte, label string, context []byte, length int) []byte + +var _ = Describe("HKDF", func() { + DescribeTable("gets the same results as crypto/tls", + func(cipherSuite uint16, secret, context []byte, label string, length int) { + cs := cipherSuiteTLS13ByID(cipherSuite) + expected := expandLabel(cs, secret, label, context, length) + expanded := hkdfExpandLabel(cs.Hash, secret, context, label, length) + Expect(expanded).To(Equal(expected)) + }, + Entry("TLS_AES_128_GCM_SHA256", tls.TLS_AES_128_GCM_SHA256, []byte("secret"), []byte("context"), "label", 42), + Entry("TLS_AES_256_GCM_SHA384", tls.TLS_AES_256_GCM_SHA384, []byte("secret"), []byte("context"), "label", 100), + Entry("TLS_CHACHA20_POLY1305_SHA256", tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret"), []byte("context"), "label", 77), + ) }) + +func BenchmarkHKDFExpandLabelStandardLibrary(b *testing.B) { + b.Run("TLS_AES_128_GCM_SHA256", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_AES_128_GCM_SHA256, true) }) + b.Run("TLS_AES_256_GCM_SHA384", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_AES_256_GCM_SHA384, true) }) + b.Run("TLS_CHACHA20_POLY1305_SHA256", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_CHACHA20_POLY1305_SHA256, true) }) +} + +func BenchmarkHKDFExpandLabelOptimized(b *testing.B) { + b.Run("TLS_AES_128_GCM_SHA256", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_AES_128_GCM_SHA256, false) }) + b.Run("TLS_AES_256_GCM_SHA384", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_AES_256_GCM_SHA384, false) }) + b.Run("TLS_CHACHA20_POLY1305_SHA256", func(b *testing.B) { benchmarkHKDFExpandLabel(b, tls.TLS_CHACHA20_POLY1305_SHA256, false) }) +} + +func benchmarkHKDFExpandLabel(b *testing.B, cipherSuite uint16, useStdLib bool) { + b.ReportAllocs() + cs := cipherSuiteTLS13ByID(cipherSuite) + secret := make([]byte, 32) + rand.Read(secret) + b.ResetTimer() + for i := 0; i < b.N; i++ { + if useStdLib { + expandLabel(cs, secret, "label", []byte("context"), 42) + } else { + hkdfExpandLabel(cs.Hash, secret, []byte("context"), "label", 42) + } + } +} diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index dbeef89fd..a1a61f248 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -22,7 +22,7 @@ const ( hkdfLabelIVV2 = "quicv2 iv" ) -func getSalt(v protocol.VersionNumber) []byte { +func getSalt(v protocol.Version) []byte { if v == protocol.Version2 { return quicSaltV2 } @@ -32,7 +32,7 @@ func getSalt(v protocol.VersionNumber) []byte { var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256) // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.Version) (LongHeaderSealer, LongHeaderOpener) { clientSecret, serverSecret := computeSecrets(connID, v) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { @@ -52,14 +52,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) } -func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { +func computeSecrets(connID protocol.ConnectionID, v protocol.Version) (clientSecret, serverSecret []byte) { initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return } -func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { +func computeInitialKeyAndIV(secret []byte, v protocol.Version) (key, iv []byte) { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index 5715dd4d3..0c10a8bcf 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -1,8 +1,11 @@ package handshake import ( - "crypto/rand" + "bytes" "fmt" + "testing" + + "golang.org/x/exp/rand" "github.com/refraction-networking/uquic/internal/protocol" @@ -20,7 +23,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) DescribeTable("computes the client key and IV", - func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { + func(v protocol.Version, expectedClientSecret, expectedKey, expectedIV []byte) { clientSecret, _ := computeSecrets(connID, v) Expect(clientSecret).To(Equal(expectedClientSecret)) key, iv := computeInitialKeyAndIV(clientSecret, v) @@ -42,7 +45,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { ) DescribeTable("computes the server key and IV", - func(v protocol.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { + func(v protocol.Version, expectedServerSecret, expectedKey, expectedIV []byte) { _, serverSecret := computeSecrets(connID, v) Expect(serverSecret).To(Equal(expectedServerSecret)) key, iv := computeInitialKeyAndIV(serverSecret, v) @@ -64,7 +67,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { ) DescribeTable("encrypts the client's Initial", - func(v protocol.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { + func(v protocol.Version, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) data = append(data, make([]byte, 1162-len(data))...) // add PADDING sealed := sealer.Seal(nil, data, 2, header) @@ -97,7 +100,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { ) DescribeTable("encrypts the server's Initial", - func(v protocol.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { + func(v protocol.Version, header, data, expectedSample, expectedHdr, expectedPacket []byte) { sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) sealed := sealer.Seal(nil, data, 1, header) sample := sealed[2 : 2+16] @@ -125,7 +128,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { ), ) - for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} { + for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { v := ver Context(fmt.Sprintf("using version %s", v), func() { @@ -187,3 +190,59 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) } }) + +func BenchmarkInitialAEADCreate(b *testing.B) { + b.ReportAllocs() + connID := protocol.ParseConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) + for i := 0; i < b.N; i++ { + NewInitialAEAD(connID, protocol.PerspectiveServer, protocol.Version1) + } +} + +func BenchmarkInitialAEAD(b *testing.B) { + connectionID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd}) + clientSealer, _ := NewInitialAEAD(connectionID, protocol.PerspectiveClient, protocol.Version1) + _, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, protocol.Version1) + + r := rand.New(rand.NewSource(1)) + packetData := make([]byte, 1200) + r.Read(packetData) + hdr := make([]byte, 50) + r.Read(hdr) + msg := clientSealer.Seal(nil, packetData, 42, hdr) + m, err := serverOpener.Open(nil, msg, 42, hdr) + if err != nil { + b.Fatalf("opening failed: %s", err) + } + if !bytes.Equal(m, packetData) { + b.Fatal("decrypted data doesn't match") + } + + b.Run("opening 100 bytes", func(b *testing.B) { + benchmarkOpen(b, serverOpener, clientSealer.Seal(nil, packetData[:100], 42, hdr), hdr) + }) + b.Run("opening 1200 bytes", func(b *testing.B) { benchmarkOpen(b, serverOpener, msg, hdr) }) + + b.Run("sealing 100 bytes", func(b *testing.B) { benchmarkSeal(b, clientSealer, packetData[:100], hdr) }) + b.Run("sealing 1200 bytes", func(b *testing.B) { benchmarkSeal(b, clientSealer, packetData, hdr) }) +} + +func benchmarkOpen(b *testing.B, aead LongHeaderOpener, msg, hdr []byte) { + b.ReportAllocs() + dst := make([]byte, 0, 1500) + for i := 0; i < b.N; i++ { + dst = dst[:0] + if _, err := aead.Open(dst, msg, 42, hdr); err != nil { + b.Fatalf("opening failed: %s", err) + } + } +} + +func benchmarkSeal(b *testing.B, aead LongHeaderSealer, msg, hdr []byte) { + b.ReportAllocs() + dst := make([]byte, 0, 1500) + for i := 0; i < b.N; i++ { + dst = dst[:0] + aead.Seal(dst, msg, protocol.PacketNumber(i), hdr) + } +} diff --git a/internal/handshake/retry.go b/internal/handshake/retry.go index aa8f92d83..2dd8490f8 100644 --- a/internal/handshake/retry.go +++ b/internal/handshake/retry.go @@ -40,7 +40,7 @@ var ( ) // GetRetryIntegrityTag calculates the integrity tag on a Retry packet -func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte { +func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.Version) *[16]byte { retryMutex.Lock() defer retryMutex.Unlock() diff --git a/internal/handshake/retry_test.go b/internal/handshake/retry_test.go index 98d80b01a..0d70f1c66 100644 --- a/internal/handshake/retry_test.go +++ b/internal/handshake/retry_test.go @@ -28,9 +28,9 @@ var _ = Describe("Retry Integrity Check", func() { }) DescribeTable("using the test vectors", - func(version protocol.VersionNumber, data []byte) { + func(version protocol.Version, data []byte) { v := binary.BigEndian.Uint32(data[1:5]) - Expect(protocol.VersionNumber(v)).To(Equal(version)) + Expect(protocol.Version(v)).To(Equal(version)) connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, version)[:]).To(Equal(data[len(data)-16:])) }, diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go index d6dc16b17..8c7cd5b98 100644 --- a/internal/handshake/u_crypto_setup.go +++ b/internal/handshake/u_crypto_setup.go @@ -1,11 +1,9 @@ package handshake import ( - "bytes" "context" "fmt" "strings" - "sync" "sync/atomic" "time" @@ -20,11 +18,11 @@ import ( type uCryptoSetup struct { tlsConf *tls.Config - conn *qtls.UQUICConn + conn *tls.UQUICConn events []Event - version protocol.VersionNumber + version protocol.Version ourParams *wire.TransportParameters peerParams *wire.TransportParameters @@ -39,8 +37,6 @@ type uCryptoSetup struct { perspective protocol.Perspective - mutex sync.Mutex // protects all members below - handshakeCompleteTime time.Time zeroRTTOpener LongHeaderOpener // only set for the server @@ -71,7 +67,7 @@ func NewUCryptoSetupClient( rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, - version protocol.VersionNumber, + version protocol.Version, chs *tls.ClientHelloSpec, ) CryptoSetup { cs := newUCryptoSetup( @@ -86,11 +82,16 @@ func NewUCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + quicConf := &tls.QUICConfig{TLSConfig: tlsConf} qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf - cs.conn = qtls.UQUICClient(quicConf, chs) + // [UQUIC] + cs.conn = tls.UQUICClient(quicConf, tls.HelloCustom) + if err := cs.conn.ApplyPreset(chs); err != nil { + panic(err) + } + // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this return cs @@ -103,7 +104,7 @@ func newUCryptoSetup( tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, - version protocol.VersionNumber, + version protocol.Version, ) *uCryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) if tracer != nil { @@ -195,29 +196,29 @@ func (h *uCryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLe } } -func (h *uCryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { +func (h *uCryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) { switch ev.Kind { - case qtls.QUICNoEvent: + case tls.QUICNoEvent: return true, nil - case qtls.QUICSetReadSecret: - h.SetReadKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetReadSecret: + h.setReadKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICSetWriteSecret: - h.SetWriteKey(ev.Level, ev.Suite, ev.Data) + case tls.QUICSetWriteSecret: + h.setWriteKey(ev.Level, ev.Suite, ev.Data) return false, nil - case qtls.QUICTransportParameters: + case tls.QUICTransportParameters: return false, h.handleTransportParameters(ev.Data) - case qtls.QUICTransportParametersRequired: + case tls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) // [UQUIC] doesn't expect this and may fail return false, nil - case qtls.QUICRejectedEarlyData: + case tls.QUICRejectedEarlyData: h.rejected0RTT() return false, nil - case qtls.QUICWriteData: - h.WriteRecord(ev.Level, ev.Data) + case tls.QUICWriteData: + h.writeRecord(ev.Level, ev.Data) return false, nil - case qtls.QUICHandshakeDone: + case tls.QUICHandshakeDone: h.handshakeComplete() return false, nil default: @@ -245,48 +246,41 @@ func (h *uCryptoSetup) handleTransportParameters(data []byte) error { } // must be called after receiving the transport parameters -func (h *uCryptoSetup) marshalDataForSessionState() []byte { +func (h *uCryptoSetup) marshalDataForSessionState(earlyData bool) []byte { b := make([]byte, 0, 256) b = quicvarint.Append(b, clientSessionStateRevision) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) - return h.peerParams.MarshalForSessionTicket(b) + if earlyData { + // only save the transport parameters for 0-RTT enabled session tickets + return h.peerParams.MarshalForSessionTicket(b) + } + return b } -func (h *uCryptoSetup) handleDataFromSessionState(data []byte) { - tp, err := h.handleDataFromSessionStateImpl(data) +func (h *uCryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) { + rtt, tp, err := decodeDataFromSessionState(data, earlyData) if err != nil { h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error()) return } - h.zeroRTTParameters = tp -} - -func (h *uCryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) { - r := bytes.NewReader(data) - ver, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - if ver != clientSessionStateRevision { - return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) - } - rtt, err := quicvarint.Read(r) - if err != nil { - return nil, err - } - h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond) - var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { - return nil, err + h.rttStats.SetInitialRTT(rtt) + // The session ticket might have been saved from a connection that allowed 0-RTT, + // and therefore contain transport parameters. + // Only use them if 0-RTT is actually used on the new connection. + if tp != nil && h.allow0RTT { + h.zeroRTTParameters = tp + return true } - return &tp, nil + return false } // GetSessionTicket generates a new session ticket. // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { - if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil { + if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{ + EarlyData: h.allow0RTT, + }); err != nil { // Session tickets might be disabled by tls.Config.SessionTicketsDisabled. // We can't check h.tlsConfig here, since the actual config might have been obtained from // the GetConfigForClient callback. @@ -298,11 +292,11 @@ func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { return nil, err } ev := h.conn.NextEvent() - if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { + if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication { panic("crypto/tls bug: where's my session ticket?") } ticket := ev.Data - if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { + if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent { panic("crypto/tls bug: why more than one ticket?") } return ticket, nil @@ -312,22 +306,19 @@ func (h *uCryptoSetup) GetSessionTicket() ([]byte, error) { func (h *uCryptoSetup) rejected0RTT() { h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.") - h.mutex.Lock() had0RTTKeys := h.zeroRTTSealer != nil h.zeroRTTSealer = nil - h.mutex.Unlock() if had0RTTKeys { h.events = append(h.events, Event{Kind: EventDiscard0RTTKeys}) } } -func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *uCryptoSetup) setReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -339,7 +330,7 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeOpener = newLongHeaderOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -347,7 +338,7 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -356,19 +347,17 @@ func (h *uCryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, t default: panic("unexpected read encryption level") } - h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { +func (h *uCryptoSetup) setWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) - h.mutex.Lock() //nolint:exhaustive // The TLS stack doesn't export Initial keys. switch el { - case qtls.QUICEncryptionLevelEarly: + case tls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -376,16 +365,15 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.mutex.Unlock() if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. return - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.handshakeSealer = newLongHeaderSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -393,7 +381,7 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { @@ -404,28 +392,27 @@ func (h *uCryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } default: panic("unexpected write encryption level") } - h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } -// WriteRecord is called when TLS writes data -func (h *uCryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) { +// writeRecord is called when TLS writes data +func (h *uCryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) { //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { - case qtls.QUICEncryptionLevelInitial: + case tls.QUICEncryptionLevelInitial: h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p}) - case qtls.QUICEncryptionLevelHandshake: + case tls.QUICEncryptionLevelHandshake: h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p}) - case qtls.QUICEncryptionLevelApplication: + case tls.QUICEncryptionLevelApplication: panic("unexpected write") default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) @@ -433,11 +420,9 @@ func (h *uCryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) } func (h *uCryptoSetup) DiscardInitialKeys() { - h.mutex.Lock() dropped := h.initialOpener != nil h.initialOpener = nil h.initialSealer = nil - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Initial keys.") } @@ -452,22 +437,17 @@ func (h *uCryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys var dropped bool - h.mutex.Lock() if h.handshakeOpener != nil { h.handshakeOpener = nil h.handshakeSealer = nil dropped = true } - h.mutex.Unlock() if dropped { h.logger.Debugf("Dropping Handshake keys.") } } func (h *uCryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialSealer == nil { return nil, ErrKeysDropped } @@ -475,9 +455,6 @@ func (h *uCryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTSealer == nil { return nil, ErrKeysDropped } @@ -485,9 +462,6 @@ func (h *uCryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeSealer == nil { if h.initialSealer == nil { return nil, ErrKeysDropped @@ -498,9 +472,6 @@ func (h *uCryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { } func (h *uCryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if !h.has1RTTSealer { return nil, ErrKeysNotYetAvailable } @@ -508,9 +479,6 @@ func (h *uCryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { } func (h *uCryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.initialOpener == nil { return nil, ErrKeysDropped } @@ -518,9 +486,6 @@ func (h *uCryptoSetup) GetInitialOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -532,9 +497,6 @@ func (h *uCryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.handshakeOpener == nil { if h.initialOpener != nil { return nil, ErrKeysNotYetAvailable @@ -546,9 +508,6 @@ func (h *uCryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) { } func (h *uCryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { - h.mutex.Lock() - defer h.mutex.Unlock() - if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 0d9664a0d..771142195 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -60,7 +60,7 @@ type updatableAEAD struct { tracer *logging.ConnectionTracer logger utils.Logger - version protocol.VersionNumber + version protocol.Version // use a single slice to avoid allocations nonceBuf []byte @@ -71,7 +71,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -134,7 +134,7 @@ func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) { // SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. -// For the server, this function is called before SetWriteKey. +// For the server, this function is called before SetReadKey. func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret, a.version) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) @@ -173,7 +173,7 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac } } if err == nil { - a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn) + a.highestRcvdPN = max(a.highestRcvdPN, pn) } return dec, err } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index af8d393f0..0de4542bd 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -21,7 +21,7 @@ import ( var _ = Describe("Updatable AEAD", func() { DescribeTable("ChaCha test vector", - func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) { + func(v protocol.Version, expectedPayload, expectedPacket []byte) { secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) chacha := cipherSuites[2] @@ -49,7 +49,7 @@ var _ = Describe("Updatable AEAD", func() { ), ) - for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} { + for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { v := ver Context(fmt.Sprintf("using version %s", v), func() { diff --git a/internal/mocks/ackhandler/received_packet_handler.go b/internal/mocks/ackhandler/received_packet_handler.go index 8be3e071d..955bc7ae1 100644 --- a/internal/mocks/ackhandler/received_packet_handler.go +++ b/internal/mocks/ackhandler/received_packet_handler.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler +// mockgen -typed -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler // + // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -47,9 +48,33 @@ func (m *MockReceivedPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) DropPackets(arg0 any) *MockReceivedPacketHandlerDropPacketsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockReceivedPacketHandler)(nil).DropPackets), arg0) + return &MockReceivedPacketHandlerDropPacketsCall{Call: call} +} + +// MockReceivedPacketHandlerDropPacketsCall wrap *gomock.Call +type MockReceivedPacketHandlerDropPacketsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceivedPacketHandlerDropPacketsCall) Return() *MockReceivedPacketHandlerDropPacketsCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceivedPacketHandlerDropPacketsCall) Do(f func(protocol.EncryptionLevel)) *MockReceivedPacketHandlerDropPacketsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceivedPacketHandlerDropPacketsCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockReceivedPacketHandlerDropPacketsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetAckFrame mocks base method. @@ -61,9 +86,33 @@ func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel, a } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0, arg1 any) *MockReceivedPacketHandlerGetAckFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0, arg1) + return &MockReceivedPacketHandlerGetAckFrameCall{Call: call} +} + +// MockReceivedPacketHandlerGetAckFrameCall wrap *gomock.Call +type MockReceivedPacketHandlerGetAckFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceivedPacketHandlerGetAckFrameCall) Return(arg0 *wire.AckFrame) *MockReceivedPacketHandlerGetAckFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceivedPacketHandlerGetAckFrameCall) Do(f func(protocol.EncryptionLevel, bool) *wire.AckFrame) *MockReceivedPacketHandlerGetAckFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceivedPacketHandlerGetAckFrameCall) DoAndReturn(f func(protocol.EncryptionLevel, bool) *wire.AckFrame) *MockReceivedPacketHandlerGetAckFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetAlarmTimeout mocks base method. @@ -75,9 +124,33 @@ func (m *MockReceivedPacketHandler) GetAlarmTimeout() time.Time { } // GetAlarmTimeout indicates an expected call of GetAlarmTimeout. -func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) GetAlarmTimeout() *MockReceivedPacketHandlerGetAlarmTimeoutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAlarmTimeout)) + return &MockReceivedPacketHandlerGetAlarmTimeoutCall{Call: call} +} + +// MockReceivedPacketHandlerGetAlarmTimeoutCall wrap *gomock.Call +type MockReceivedPacketHandlerGetAlarmTimeoutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceivedPacketHandlerGetAlarmTimeoutCall) Return(arg0 time.Time) *MockReceivedPacketHandlerGetAlarmTimeoutCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceivedPacketHandlerGetAlarmTimeoutCall) Do(f func() time.Time) *MockReceivedPacketHandlerGetAlarmTimeoutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceivedPacketHandlerGetAlarmTimeoutCall) DoAndReturn(f func() time.Time) *MockReceivedPacketHandlerGetAlarmTimeoutCall { + c.Call = c.Call.DoAndReturn(f) + return c } // IsPotentiallyDuplicate mocks base method. @@ -89,9 +162,33 @@ func (m *MockReceivedPacketHandler) IsPotentiallyDuplicate(arg0 protocol.PacketN } // IsPotentiallyDuplicate indicates an expected call of IsPotentiallyDuplicate. -func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 any) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) IsPotentiallyDuplicate(arg0, arg1 any) *MockReceivedPacketHandlerIsPotentiallyDuplicateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPotentiallyDuplicate", reflect.TypeOf((*MockReceivedPacketHandler)(nil).IsPotentiallyDuplicate), arg0, arg1) + return &MockReceivedPacketHandlerIsPotentiallyDuplicateCall{Call: call} +} + +// MockReceivedPacketHandlerIsPotentiallyDuplicateCall wrap *gomock.Call +type MockReceivedPacketHandlerIsPotentiallyDuplicateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceivedPacketHandlerIsPotentiallyDuplicateCall) Return(arg0 bool) *MockReceivedPacketHandlerIsPotentiallyDuplicateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceivedPacketHandlerIsPotentiallyDuplicateCall) Do(f func(protocol.PacketNumber, protocol.EncryptionLevel) bool) *MockReceivedPacketHandlerIsPotentiallyDuplicateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceivedPacketHandlerIsPotentiallyDuplicateCall) DoAndReturn(f func(protocol.PacketNumber, protocol.EncryptionLevel) bool) *MockReceivedPacketHandlerIsPotentiallyDuplicateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedPacket mocks base method. @@ -103,7 +200,31 @@ func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, a } // ReceivedPacket indicates an expected call of ReceivedPacket. -func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3, arg4 any) *MockReceivedPacketHandlerReceivedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3, arg4) + return &MockReceivedPacketHandlerReceivedPacketCall{Call: call} +} + +// MockReceivedPacketHandlerReceivedPacketCall wrap *gomock.Call +type MockReceivedPacketHandlerReceivedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceivedPacketHandlerReceivedPacketCall) Return(arg0 error) *MockReceivedPacketHandlerReceivedPacketCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceivedPacketHandlerReceivedPacketCall) Do(f func(protocol.PacketNumber, protocol.ECN, protocol.EncryptionLevel, time.Time, bool) error) *MockReceivedPacketHandlerReceivedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceivedPacketHandlerReceivedPacketCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ECN, protocol.EncryptionLevel, time.Time, bool) error) *MockReceivedPacketHandlerReceivedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index a0234086b..89e2b9a3b 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler +// mockgen -typed -build_flags=-tags=gomock -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler // + // Package mockackhandler is a generated GoMock package. package mockackhandler @@ -48,9 +49,33 @@ func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { } // DropPackets indicates an expected call of DropPackets. -func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) DropPackets(arg0 any) *MockSentPacketHandlerDropPacketsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropPackets", reflect.TypeOf((*MockSentPacketHandler)(nil).DropPackets), arg0) + return &MockSentPacketHandlerDropPacketsCall{Call: call} +} + +// MockSentPacketHandlerDropPacketsCall wrap *gomock.Call +type MockSentPacketHandlerDropPacketsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerDropPacketsCall) Return() *MockSentPacketHandlerDropPacketsCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerDropPacketsCall) Do(f func(protocol.EncryptionLevel)) *MockSentPacketHandlerDropPacketsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerDropPacketsCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockSentPacketHandlerDropPacketsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ECNMode mocks base method. @@ -62,9 +87,33 @@ func (m *MockSentPacketHandler) ECNMode(arg0 bool) protocol.ECN { } // ECNMode indicates an expected call of ECNMode. -func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ECNMode(arg0 any) *MockSentPacketHandlerECNModeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNMode", reflect.TypeOf((*MockSentPacketHandler)(nil).ECNMode), arg0) + return &MockSentPacketHandlerECNModeCall{Call: call} +} + +// MockSentPacketHandlerECNModeCall wrap *gomock.Call +type MockSentPacketHandlerECNModeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerECNModeCall) Return(arg0 protocol.ECN) *MockSentPacketHandlerECNModeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerECNModeCall) Do(f func(bool) protocol.ECN) *MockSentPacketHandlerECNModeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerECNModeCall) DoAndReturn(f func(bool) protocol.ECN) *MockSentPacketHandlerECNModeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetLossDetectionTimeout mocks base method. @@ -76,9 +125,33 @@ func (m *MockSentPacketHandler) GetLossDetectionTimeout() time.Time { } // GetLossDetectionTimeout indicates an expected call of GetLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) GetLossDetectionTimeout() *MockSentPacketHandlerGetLossDetectionTimeoutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLossDetectionTimeout)) + return &MockSentPacketHandlerGetLossDetectionTimeoutCall{Call: call} +} + +// MockSentPacketHandlerGetLossDetectionTimeoutCall wrap *gomock.Call +type MockSentPacketHandlerGetLossDetectionTimeoutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) Return(arg0 time.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) Do(f func() time.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerGetLossDetectionTimeoutCall) DoAndReturn(f func() time.Time) *MockSentPacketHandlerGetLossDetectionTimeoutCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OnLossDetectionTimeout mocks base method. @@ -90,9 +163,33 @@ func (m *MockSentPacketHandler) OnLossDetectionTimeout() error { } // OnLossDetectionTimeout indicates an expected call of OnLossDetectionTimeout. -func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) OnLossDetectionTimeout() *MockSentPacketHandlerOnLossDetectionTimeoutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnLossDetectionTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).OnLossDetectionTimeout)) + return &MockSentPacketHandlerOnLossDetectionTimeoutCall{Call: call} +} + +// MockSentPacketHandlerOnLossDetectionTimeoutCall wrap *gomock.Call +type MockSentPacketHandlerOnLossDetectionTimeoutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Return(arg0 error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) Do(f func() error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerOnLossDetectionTimeoutCall) DoAndReturn(f func() error) *MockSentPacketHandlerOnLossDetectionTimeoutCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PeekPacketNumber mocks base method. @@ -105,9 +202,33 @@ func (m *MockSentPacketHandler) PeekPacketNumber(arg0 protocol.EncryptionLevel) } // PeekPacketNumber indicates an expected call of PeekPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PeekPacketNumber(arg0 any) *MockSentPacketHandlerPeekPacketNumberCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeekPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PeekPacketNumber), arg0) + return &MockSentPacketHandlerPeekPacketNumberCall{Call: call} +} + +// MockSentPacketHandlerPeekPacketNumberCall wrap *gomock.Call +type MockSentPacketHandlerPeekPacketNumberCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerPeekPacketNumberCall) Return(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen) *MockSentPacketHandlerPeekPacketNumberCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerPeekPacketNumberCall) Do(f func(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)) *MockSentPacketHandlerPeekPacketNumberCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerPeekPacketNumberCall) DoAndReturn(f func(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)) *MockSentPacketHandlerPeekPacketNumberCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PopPacketNumber mocks base method. @@ -119,9 +240,33 @@ func (m *MockSentPacketHandler) PopPacketNumber(arg0 protocol.EncryptionLevel) p } // PopPacketNumber indicates an expected call of PopPacketNumber. -func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) PopPacketNumber(arg0 any) *MockSentPacketHandlerPopPacketNumberCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopPacketNumber", reflect.TypeOf((*MockSentPacketHandler)(nil).PopPacketNumber), arg0) + return &MockSentPacketHandlerPopPacketNumberCall{Call: call} +} + +// MockSentPacketHandlerPopPacketNumberCall wrap *gomock.Call +type MockSentPacketHandlerPopPacketNumberCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerPopPacketNumberCall) Return(arg0 protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerPopPacketNumberCall) Do(f func(protocol.EncryptionLevel) protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerPopPacketNumberCall) DoAndReturn(f func(protocol.EncryptionLevel) protocol.PacketNumber) *MockSentPacketHandlerPopPacketNumberCall { + c.Call = c.Call.DoAndReturn(f) + return c } // QueueProbePacket mocks base method. @@ -133,9 +278,33 @@ func (m *MockSentPacketHandler) QueueProbePacket(arg0 protocol.EncryptionLevel) } // QueueProbePacket indicates an expected call of QueueProbePacket. -func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) QueueProbePacket(arg0 any) *MockSentPacketHandlerQueueProbePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueProbePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).QueueProbePacket), arg0) + return &MockSentPacketHandlerQueueProbePacketCall{Call: call} +} + +// MockSentPacketHandlerQueueProbePacketCall wrap *gomock.Call +type MockSentPacketHandlerQueueProbePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerQueueProbePacketCall) Return(arg0 bool) *MockSentPacketHandlerQueueProbePacketCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerQueueProbePacketCall) Do(f func(protocol.EncryptionLevel) bool) *MockSentPacketHandlerQueueProbePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerQueueProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel) bool) *MockSentPacketHandlerQueueProbePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedAck mocks base method. @@ -148,9 +317,33 @@ func (m *MockSentPacketHandler) ReceivedAck(arg0 *wire.AckFrame, arg1 protocol.E } // ReceivedAck indicates an expected call of ReceivedAck. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 any) *MockSentPacketHandlerReceivedAckCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) + return &MockSentPacketHandlerReceivedAckCall{Call: call} +} + +// MockSentPacketHandlerReceivedAckCall wrap *gomock.Call +type MockSentPacketHandlerReceivedAckCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerReceivedAckCall) Return(arg0 bool, arg1 error) *MockSentPacketHandlerReceivedAckCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerReceivedAckCall) Do(f func(*wire.AckFrame, protocol.EncryptionLevel, time.Time) (bool, error)) *MockSentPacketHandlerReceivedAckCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerReceivedAckCall) DoAndReturn(f func(*wire.AckFrame, protocol.EncryptionLevel, time.Time) (bool, error)) *MockSentPacketHandlerReceivedAckCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedBytes mocks base method. @@ -160,9 +353,33 @@ func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { } // ReceivedBytes indicates an expected call of ReceivedBytes. -func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 any) *MockSentPacketHandlerReceivedBytesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) + return &MockSentPacketHandlerReceivedBytesCall{Call: call} +} + +// MockSentPacketHandlerReceivedBytesCall wrap *gomock.Call +type MockSentPacketHandlerReceivedBytesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerReceivedBytesCall) Return() *MockSentPacketHandlerReceivedBytesCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerReceivedBytesCall) Do(f func(protocol.ByteCount)) *MockSentPacketHandlerReceivedBytesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerReceivedBytesCall) DoAndReturn(f func(protocol.ByteCount)) *MockSentPacketHandlerReceivedBytesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ResetForRetry mocks base method. @@ -174,9 +391,33 @@ func (m *MockSentPacketHandler) ResetForRetry(arg0 time.Time) error { } // ResetForRetry indicates an expected call of ResetForRetry. -func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) ResetForRetry(arg0 any) *MockSentPacketHandlerResetForRetryCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetForRetry", reflect.TypeOf((*MockSentPacketHandler)(nil).ResetForRetry), arg0) + return &MockSentPacketHandlerResetForRetryCall{Call: call} +} + +// MockSentPacketHandlerResetForRetryCall wrap *gomock.Call +type MockSentPacketHandlerResetForRetryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerResetForRetryCall) Return(arg0 error) *MockSentPacketHandlerResetForRetryCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerResetForRetryCall) Do(f func(time.Time) error) *MockSentPacketHandlerResetForRetryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerResetForRetryCall) DoAndReturn(f func(time.Time) error) *MockSentPacketHandlerResetForRetryCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SendMode mocks base method. @@ -188,9 +429,33 @@ func (m *MockSentPacketHandler) SendMode(arg0 time.Time) ackhandler.SendMode { } // SendMode indicates an expected call of SendMode. -func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SendMode(arg0 any) *MockSentPacketHandlerSendModeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMode", reflect.TypeOf((*MockSentPacketHandler)(nil).SendMode), arg0) + return &MockSentPacketHandlerSendModeCall{Call: call} +} + +// MockSentPacketHandlerSendModeCall wrap *gomock.Call +type MockSentPacketHandlerSendModeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerSendModeCall) Return(arg0 ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerSendModeCall) Do(f func(time.Time) ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerSendModeCall) DoAndReturn(f func(time.Time) ackhandler.SendMode) *MockSentPacketHandlerSendModeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentPacket mocks base method. @@ -200,9 +465,33 @@ func (m *MockSentPacketHandler) SentPacket(arg0 time.Time, arg1, arg2 protocol.P } // SentPacket indicates an expected call of SentPacket. -func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *MockSentPacketHandlerSentPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + return &MockSentPacketHandlerSentPacketCall{Call: call} +} + +// MockSentPacketHandlerSentPacketCall wrap *gomock.Call +type MockSentPacketHandlerSentPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerSentPacketCall) Return() *MockSentPacketHandlerSentPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerSentPacketCall) Do(f func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool)) *MockSentPacketHandlerSentPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerSentPacketCall) DoAndReturn(f func(time.Time, protocol.PacketNumber, protocol.PacketNumber, []ackhandler.StreamFrame, []ackhandler.Frame, protocol.EncryptionLevel, protocol.ECN, protocol.ByteCount, bool)) *MockSentPacketHandlerSentPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetHandshakeConfirmed mocks base method. @@ -212,9 +501,33 @@ func (m *MockSentPacketHandler) SetHandshakeConfirmed() { } // SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SetHandshakeConfirmed() *MockSentPacketHandlerSetHandshakeConfirmedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeConfirmed)) + return &MockSentPacketHandlerSetHandshakeConfirmedCall{Call: call} +} + +// MockSentPacketHandlerSetHandshakeConfirmedCall wrap *gomock.Call +type MockSentPacketHandlerSetHandshakeConfirmedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) Return() *MockSentPacketHandlerSetHandshakeConfirmedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) Do(f func()) *MockSentPacketHandlerSetHandshakeConfirmedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerSetHandshakeConfirmedCall) DoAndReturn(f func()) *MockSentPacketHandlerSetHandshakeConfirmedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetMaxDatagramSize mocks base method. @@ -224,9 +537,33 @@ func (m *MockSentPacketHandler) SetMaxDatagramSize(arg0 protocol.ByteCount) { } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) SetMaxDatagramSize(arg0 any) *MockSentPacketHandlerSetMaxDatagramSizeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSentPacketHandler)(nil).SetMaxDatagramSize), arg0) + return &MockSentPacketHandlerSetMaxDatagramSizeCall{Call: call} +} + +// MockSentPacketHandlerSetMaxDatagramSizeCall wrap *gomock.Call +type MockSentPacketHandlerSetMaxDatagramSizeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) Return() *MockSentPacketHandlerSetMaxDatagramSizeCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) Do(f func(protocol.ByteCount)) *MockSentPacketHandlerSetMaxDatagramSizeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerSetMaxDatagramSizeCall) DoAndReturn(f func(protocol.ByteCount)) *MockSentPacketHandlerSetMaxDatagramSizeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // TimeUntilSend mocks base method. @@ -238,7 +575,31 @@ func (m *MockSentPacketHandler) TimeUntilSend() time.Time { } // TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *gomock.Call { +func (mr *MockSentPacketHandlerMockRecorder) TimeUntilSend() *MockSentPacketHandlerTimeUntilSendCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSentPacketHandler)(nil).TimeUntilSend)) + return &MockSentPacketHandlerTimeUntilSendCall{Call: call} +} + +// MockSentPacketHandlerTimeUntilSendCall wrap *gomock.Call +type MockSentPacketHandlerTimeUntilSendCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSentPacketHandlerTimeUntilSendCall) Return(arg0 time.Time) *MockSentPacketHandlerTimeUntilSendCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSentPacketHandlerTimeUntilSendCall) Do(f func() time.Time) *MockSentPacketHandlerTimeUntilSendCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSentPacketHandlerTimeUntilSendCall) DoAndReturn(f func() time.Time) *MockSentPacketHandlerTimeUntilSendCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/congestion.go b/internal/mocks/congestion.go index 99ffd8169..d1fe76aa6 100644 --- a/internal/mocks/congestion.go +++ b/internal/mocks/congestion.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos // + // Package mocks is a generated GoMock package. package mocks @@ -48,9 +49,33 @@ func (m *MockSendAlgorithmWithDebugInfos) CanSend(arg0 protocol.ByteCount) bool } // CanSend indicates an expected call of CanSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) CanSend(arg0 any) *MockSendAlgorithmWithDebugInfosCanSendCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).CanSend), arg0) + return &MockSendAlgorithmWithDebugInfosCanSendCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosCanSendCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosCanSendCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosCanSendCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosCanSendCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosCanSendCall) Do(f func(protocol.ByteCount) bool) *MockSendAlgorithmWithDebugInfosCanSendCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosCanSendCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockSendAlgorithmWithDebugInfosCanSendCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetCongestionWindow mocks base method. @@ -62,9 +87,33 @@ func (m *MockSendAlgorithmWithDebugInfos) GetCongestionWindow() protocol.ByteCou } // GetCongestionWindow indicates an expected call of GetCongestionWindow. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) GetCongestionWindow() *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCongestionWindow", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).GetCongestionWindow)) + return &MockSendAlgorithmWithDebugInfosGetCongestionWindowCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosGetCongestionWindowCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosGetCongestionWindowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) Return(arg0 protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) Do(f func() protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall) DoAndReturn(f func() protocol.ByteCount) *MockSendAlgorithmWithDebugInfosGetCongestionWindowCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HasPacingBudget mocks base method. @@ -76,9 +125,33 @@ func (m *MockSendAlgorithmWithDebugInfos) HasPacingBudget(arg0 time.Time) bool { } // HasPacingBudget indicates an expected call of HasPacingBudget. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) HasPacingBudget(arg0 any) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasPacingBudget", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).HasPacingBudget), arg0) + return &MockSendAlgorithmWithDebugInfosHasPacingBudgetCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosHasPacingBudgetCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosHasPacingBudgetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) Do(f func(time.Time) bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall) DoAndReturn(f func(time.Time) bool) *MockSendAlgorithmWithDebugInfosHasPacingBudgetCall { + c.Call = c.Call.DoAndReturn(f) + return c } // InRecovery mocks base method. @@ -90,9 +163,33 @@ func (m *MockSendAlgorithmWithDebugInfos) InRecovery() bool { } // InRecovery indicates an expected call of InRecovery. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InRecovery() *MockSendAlgorithmWithDebugInfosInRecoveryCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InRecovery", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InRecovery)) + return &MockSendAlgorithmWithDebugInfosInRecoveryCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosInRecoveryCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosInRecoveryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) Do(f func() bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosInRecoveryCall) DoAndReturn(f func() bool) *MockSendAlgorithmWithDebugInfosInRecoveryCall { + c.Call = c.Call.DoAndReturn(f) + return c } // InSlowStart mocks base method. @@ -104,9 +201,33 @@ func (m *MockSendAlgorithmWithDebugInfos) InSlowStart() bool { } // InSlowStart indicates an expected call of InSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) InSlowStart() *MockSendAlgorithmWithDebugInfosInSlowStartCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).InSlowStart)) + return &MockSendAlgorithmWithDebugInfosInSlowStartCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosInSlowStartCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosInSlowStartCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) Return(arg0 bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) Do(f func() bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosInSlowStartCall) DoAndReturn(f func() bool) *MockSendAlgorithmWithDebugInfosInSlowStartCall { + c.Call = c.Call.DoAndReturn(f) + return c } // MaybeExitSlowStart mocks base method. @@ -116,9 +237,33 @@ func (m *MockSendAlgorithmWithDebugInfos) MaybeExitSlowStart() { } // MaybeExitSlowStart indicates an expected call of MaybeExitSlowStart. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) MaybeExitSlowStart() *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybeExitSlowStart", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).MaybeExitSlowStart)) + return &MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) Return() *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) Do(f func()) *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall) DoAndReturn(f func()) *MockSendAlgorithmWithDebugInfosMaybeExitSlowStartCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OnCongestionEvent mocks base method. @@ -128,9 +273,33 @@ func (m *MockSendAlgorithmWithDebugInfos) OnCongestionEvent(arg0 protocol.Packet } // OnCongestionEvent indicates an expected call of OnCongestionEvent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnCongestionEvent(arg0, arg1, arg2 any) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnCongestionEvent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnCongestionEvent), arg0, arg1, arg2) + return &MockSendAlgorithmWithDebugInfosOnCongestionEventCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosOnCongestionEventCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosOnCongestionEventCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) Return() *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) Do(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosOnCongestionEventCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosOnCongestionEventCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OnPacketAcked mocks base method. @@ -140,9 +309,33 @@ func (m *MockSendAlgorithmWithDebugInfos) OnPacketAcked(arg0 protocol.PacketNumb } // OnPacketAcked indicates an expected call of OnPacketAcked. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketAcked(arg0, arg1, arg2, arg3 any) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketAcked", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketAcked), arg0, arg1, arg2, arg3) + return &MockSendAlgorithmWithDebugInfosOnPacketAckedCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosOnPacketAckedCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosOnPacketAckedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) Return() *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) Do(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount, time.Time)) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosOnPacketAckedCall) DoAndReturn(f func(protocol.PacketNumber, protocol.ByteCount, protocol.ByteCount, time.Time)) *MockSendAlgorithmWithDebugInfosOnPacketAckedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OnPacketSent mocks base method. @@ -152,9 +345,33 @@ func (m *MockSendAlgorithmWithDebugInfos) OnPacketSent(arg0 time.Time, arg1 prot } // OnPacketSent indicates an expected call of OnPacketSent. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnPacketSent(arg0, arg1, arg2, arg3, arg4 any) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPacketSent", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnPacketSent), arg0, arg1, arg2, arg3, arg4) + return &MockSendAlgorithmWithDebugInfosOnPacketSentCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosOnPacketSentCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosOnPacketSentCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) Return() *MockSendAlgorithmWithDebugInfosOnPacketSentCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) Do(f func(time.Time, protocol.ByteCount, protocol.PacketNumber, protocol.ByteCount, bool)) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosOnPacketSentCall) DoAndReturn(f func(time.Time, protocol.ByteCount, protocol.PacketNumber, protocol.ByteCount, bool)) *MockSendAlgorithmWithDebugInfosOnPacketSentCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OnRetransmissionTimeout mocks base method. @@ -164,9 +381,33 @@ func (m *MockSendAlgorithmWithDebugInfos) OnRetransmissionTimeout(arg0 bool) { } // OnRetransmissionTimeout indicates an expected call of OnRetransmissionTimeout. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) OnRetransmissionTimeout(arg0 any) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRetransmissionTimeout", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).OnRetransmissionTimeout), arg0) + return &MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) Return() *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) Do(f func(bool)) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall) DoAndReturn(f func(bool)) *MockSendAlgorithmWithDebugInfosOnRetransmissionTimeoutCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetMaxDatagramSize mocks base method. @@ -176,9 +417,33 @@ func (m *MockSendAlgorithmWithDebugInfos) SetMaxDatagramSize(arg0 protocol.ByteC } // SetMaxDatagramSize indicates an expected call of SetMaxDatagramSize. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) SetMaxDatagramSize(arg0 any) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxDatagramSize", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).SetMaxDatagramSize), arg0) + return &MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) Return() *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) Do(f func(protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall) DoAndReturn(f func(protocol.ByteCount)) *MockSendAlgorithmWithDebugInfosSetMaxDatagramSizeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // TimeUntilSend mocks base method. @@ -190,7 +455,31 @@ func (m *MockSendAlgorithmWithDebugInfos) TimeUntilSend(arg0 protocol.ByteCount) } // TimeUntilSend indicates an expected call of TimeUntilSend. -func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 any) *gomock.Call { +func (mr *MockSendAlgorithmWithDebugInfosMockRecorder) TimeUntilSend(arg0 any) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeUntilSend", reflect.TypeOf((*MockSendAlgorithmWithDebugInfos)(nil).TimeUntilSend), arg0) + return &MockSendAlgorithmWithDebugInfosTimeUntilSendCall{Call: call} +} + +// MockSendAlgorithmWithDebugInfosTimeUntilSendCall wrap *gomock.Call +type MockSendAlgorithmWithDebugInfosTimeUntilSendCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) Return(arg0 time.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) Do(f func(protocol.ByteCount) time.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendAlgorithmWithDebugInfosTimeUntilSendCall) DoAndReturn(f func(protocol.ByteCount) time.Time) *MockSendAlgorithmWithDebugInfosTimeUntilSendCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/connection_flow_controller.go b/internal/mocks/connection_flow_controller.go index a2edfc229..288016a8e 100644 --- a/internal/mocks/connection_flow_controller.go +++ b/internal/mocks/connection_flow_controller.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController // + // Package mocks is a generated GoMock package. package mocks @@ -45,9 +46,33 @@ func (m *MockConnectionFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesRead(arg0 any) *MockConnectionFlowControllerAddBytesReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesRead), arg0) + return &MockConnectionFlowControllerAddBytesReadCall{Call: call} +} + +// MockConnectionFlowControllerAddBytesReadCall wrap *gomock.Call +type MockConnectionFlowControllerAddBytesReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerAddBytesReadCall) Return() *MockConnectionFlowControllerAddBytesReadCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddBytesSent mocks base method. @@ -57,9 +82,33 @@ func (m *MockConnectionFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) AddBytesSent(arg0 any) *MockConnectionFlowControllerAddBytesSentCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockConnectionFlowController)(nil).AddBytesSent), arg0) + return &MockConnectionFlowControllerAddBytesSentCall{Call: call} +} + +// MockConnectionFlowControllerAddBytesSentCall wrap *gomock.Call +type MockConnectionFlowControllerAddBytesSentCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerAddBytesSentCall) Return() *MockConnectionFlowControllerAddBytesSentCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerAddBytesSentCall) Do(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesSentCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerAddBytesSentCall) DoAndReturn(f func(protocol.ByteCount)) *MockConnectionFlowControllerAddBytesSentCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetWindowUpdate mocks base method. @@ -71,9 +120,33 @@ func (m *MockConnectionFlowController) GetWindowUpdate() protocol.ByteCount { } // GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) GetWindowUpdate() *MockConnectionFlowControllerGetWindowUpdateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockConnectionFlowController)(nil).GetWindowUpdate)) + return &MockConnectionFlowControllerGetWindowUpdateCall{Call: call} +} + +// MockConnectionFlowControllerGetWindowUpdateCall wrap *gomock.Call +type MockConnectionFlowControllerGetWindowUpdateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerGetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockConnectionFlowControllerGetWindowUpdateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerGetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockConnectionFlowControllerGetWindowUpdateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerGetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockConnectionFlowControllerGetWindowUpdateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // IsNewlyBlocked mocks base method. @@ -86,9 +159,33 @@ func (m *MockConnectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCoun } // IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) IsNewlyBlocked() *MockConnectionFlowControllerIsNewlyBlockedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockConnectionFlowController)(nil).IsNewlyBlocked)) + return &MockConnectionFlowControllerIsNewlyBlockedCall{Call: call} +} + +// MockConnectionFlowControllerIsNewlyBlockedCall wrap *gomock.Call +type MockConnectionFlowControllerIsNewlyBlockedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerIsNewlyBlockedCall) Return(arg0 bool, arg1 protocol.ByteCount) *MockConnectionFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerIsNewlyBlockedCall) Do(f func() (bool, protocol.ByteCount)) *MockConnectionFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerIsNewlyBlockedCall) DoAndReturn(f func() (bool, protocol.ByteCount)) *MockConnectionFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Reset mocks base method. @@ -100,9 +197,33 @@ func (m *MockConnectionFlowController) Reset() error { } // Reset indicates an expected call of Reset. -func (mr *MockConnectionFlowControllerMockRecorder) Reset() *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) Reset() *MockConnectionFlowControllerResetCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Reset", reflect.TypeOf((*MockConnectionFlowController)(nil).Reset)) + return &MockConnectionFlowControllerResetCall{Call: call} +} + +// MockConnectionFlowControllerResetCall wrap *gomock.Call +type MockConnectionFlowControllerResetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerResetCall) Return(arg0 error) *MockConnectionFlowControllerResetCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerResetCall) Do(f func() error) *MockConnectionFlowControllerResetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerResetCall) DoAndReturn(f func() error) *MockConnectionFlowControllerResetCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SendWindowSize mocks base method. @@ -114,19 +235,69 @@ func (m *MockConnectionFlowController) SendWindowSize() protocol.ByteCount { } // SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) SendWindowSize() *MockConnectionFlowControllerSendWindowSizeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockConnectionFlowController)(nil).SendWindowSize)) + return &MockConnectionFlowControllerSendWindowSizeCall{Call: call} +} + +// MockConnectionFlowControllerSendWindowSizeCall wrap *gomock.Call +type MockConnectionFlowControllerSendWindowSizeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerSendWindowSizeCall) Return(arg0 protocol.ByteCount) *MockConnectionFlowControllerSendWindowSizeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerSendWindowSizeCall) Do(f func() protocol.ByteCount) *MockConnectionFlowControllerSendWindowSizeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerSendWindowSizeCall) DoAndReturn(f func() protocol.ByteCount) *MockConnectionFlowControllerSendWindowSizeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdateSendWindow mocks base method. -func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { +func (m *MockConnectionFlowController) UpdateSendWindow(arg0 protocol.ByteCount) bool { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) + ret := m.ctrl.Call(m, "UpdateSendWindow", arg0) + ret0, _ := ret[0].(bool) + return ret0 } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { +func (mr *MockConnectionFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *MockConnectionFlowControllerUpdateSendWindowCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockConnectionFlowController)(nil).UpdateSendWindow), arg0) + return &MockConnectionFlowControllerUpdateSendWindowCall{Call: call} +} + +// MockConnectionFlowControllerUpdateSendWindowCall wrap *gomock.Call +type MockConnectionFlowControllerUpdateSendWindowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionFlowControllerUpdateSendWindowCall) Return(arg0 bool) *MockConnectionFlowControllerUpdateSendWindowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionFlowControllerUpdateSendWindowCall) Do(f func(protocol.ByteCount) bool) *MockConnectionFlowControllerUpdateSendWindowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionFlowControllerUpdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockConnectionFlowControllerUpdateSendWindowCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index e4151c46c..2df19355b 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup // + // Package mocks is a generated GoMock package. package mocks @@ -46,9 +47,33 @@ func (m *MockCryptoSetup) ChangeConnectionID(arg0 protocol.ConnectionID) { } // ChangeConnectionID indicates an expected call of ChangeConnectionID. -func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 any) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) ChangeConnectionID(arg0 any) *MockCryptoSetupChangeConnectionIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeConnectionID", reflect.TypeOf((*MockCryptoSetup)(nil).ChangeConnectionID), arg0) + return &MockCryptoSetupChangeConnectionIDCall{Call: call} +} + +// MockCryptoSetupChangeConnectionIDCall wrap *gomock.Call +type MockCryptoSetupChangeConnectionIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupChangeConnectionIDCall) Return() *MockCryptoSetupChangeConnectionIDCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupChangeConnectionIDCall) Do(f func(protocol.ConnectionID)) *MockCryptoSetupChangeConnectionIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupChangeConnectionIDCall) DoAndReturn(f func(protocol.ConnectionID)) *MockCryptoSetupChangeConnectionIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -60,9 +85,33 @@ func (m *MockCryptoSetup) Close() error { } // Close indicates an expected call of Close. -func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) Close() *MockCryptoSetupCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCryptoSetup)(nil).Close)) + return &MockCryptoSetupCloseCall{Call: call} +} + +// MockCryptoSetupCloseCall wrap *gomock.Call +type MockCryptoSetupCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupCloseCall) Return(arg0 error) *MockCryptoSetupCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupCloseCall) Do(f func() error) *MockCryptoSetupCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupCloseCall) DoAndReturn(f func() error) *MockCryptoSetupCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ConnectionState mocks base method. @@ -74,9 +123,33 @@ func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState { } // ConnectionState indicates an expected call of ConnectionState. -func (mr *MockCryptoSetupMockRecorder) ConnectionState() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) ConnectionState() *MockCryptoSetupConnectionStateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockCryptoSetup)(nil).ConnectionState)) + return &MockCryptoSetupConnectionStateCall{Call: call} +} + +// MockCryptoSetupConnectionStateCall wrap *gomock.Call +type MockCryptoSetupConnectionStateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupConnectionStateCall) Return(arg0 handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupConnectionStateCall) Do(f func() handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupConnectionStateCall) DoAndReturn(f func() handshake.ConnectionState) *MockCryptoSetupConnectionStateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DiscardInitialKeys mocks base method. @@ -86,9 +159,33 @@ func (m *MockCryptoSetup) DiscardInitialKeys() { } // DiscardInitialKeys indicates an expected call of DiscardInitialKeys. -func (mr *MockCryptoSetupMockRecorder) DiscardInitialKeys() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) DiscardInitialKeys() *MockCryptoSetupDiscardInitialKeysCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscardInitialKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DiscardInitialKeys)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscardInitialKeys", reflect.TypeOf((*MockCryptoSetup)(nil).DiscardInitialKeys)) + return &MockCryptoSetupDiscardInitialKeysCall{Call: call} +} + +// MockCryptoSetupDiscardInitialKeysCall wrap *gomock.Call +type MockCryptoSetupDiscardInitialKeysCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupDiscardInitialKeysCall) Return() *MockCryptoSetupDiscardInitialKeysCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupDiscardInitialKeysCall) Do(f func()) *MockCryptoSetupDiscardInitialKeysCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupDiscardInitialKeysCall) DoAndReturn(f func()) *MockCryptoSetupDiscardInitialKeysCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get0RTTOpener mocks base method. @@ -101,9 +198,33 @@ func (m *MockCryptoSetup) Get0RTTOpener() (handshake.LongHeaderOpener, error) { } // Get0RTTOpener indicates an expected call of Get0RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) Get0RTTOpener() *MockCryptoSetupGet0RTTOpenerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTOpener)) + return &MockCryptoSetupGet0RTTOpenerCall{Call: call} +} + +// MockCryptoSetupGet0RTTOpenerCall wrap *gomock.Call +type MockCryptoSetupGet0RTTOpenerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGet0RTTOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGet0RTTOpenerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGet0RTTOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGet0RTTOpenerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGet0RTTOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGet0RTTOpenerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get0RTTSealer mocks base method. @@ -116,9 +237,33 @@ func (m *MockCryptoSetup) Get0RTTSealer() (handshake.LongHeaderSealer, error) { } // Get0RTTSealer indicates an expected call of Get0RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) Get0RTTSealer() *MockCryptoSetupGet0RTTSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get0RTTSealer)) + return &MockCryptoSetupGet0RTTSealerCall{Call: call} +} + +// MockCryptoSetupGet0RTTSealerCall wrap *gomock.Call +type MockCryptoSetupGet0RTTSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGet0RTTSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGet0RTTSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGet0RTTSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGet0RTTSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGet0RTTSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGet0RTTSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get1RTTOpener mocks base method. @@ -131,9 +276,33 @@ func (m *MockCryptoSetup) Get1RTTOpener() (handshake.ShortHeaderOpener, error) { } // Get1RTTOpener indicates an expected call of Get1RTTOpener. -func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *MockCryptoSetupGet1RTTOpenerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTOpener", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTOpener)) + return &MockCryptoSetupGet1RTTOpenerCall{Call: call} +} + +// MockCryptoSetupGet1RTTOpenerCall wrap *gomock.Call +type MockCryptoSetupGet1RTTOpenerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGet1RTTOpenerCall) Return(arg0 handshake.ShortHeaderOpener, arg1 error) *MockCryptoSetupGet1RTTOpenerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGet1RTTOpenerCall) Do(f func() (handshake.ShortHeaderOpener, error)) *MockCryptoSetupGet1RTTOpenerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGet1RTTOpenerCall) DoAndReturn(f func() (handshake.ShortHeaderOpener, error)) *MockCryptoSetupGet1RTTOpenerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get1RTTSealer mocks base method. @@ -146,9 +315,33 @@ func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { } // Get1RTTSealer indicates an expected call of Get1RTTSealer. -func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) Get1RTTSealer() *MockCryptoSetupGet1RTTSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockCryptoSetup)(nil).Get1RTTSealer)) + return &MockCryptoSetupGet1RTTSealerCall{Call: call} +} + +// MockCryptoSetupGet1RTTSealerCall wrap *gomock.Call +type MockCryptoSetupGet1RTTSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGet1RTTSealerCall) Return(arg0 handshake.ShortHeaderSealer, arg1 error) *MockCryptoSetupGet1RTTSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGet1RTTSealerCall) Do(f func() (handshake.ShortHeaderSealer, error)) *MockCryptoSetupGet1RTTSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGet1RTTSealerCall) DoAndReturn(f func() (handshake.ShortHeaderSealer, error)) *MockCryptoSetupGet1RTTSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetHandshakeOpener mocks base method. @@ -161,9 +354,33 @@ func (m *MockCryptoSetup) GetHandshakeOpener() (handshake.LongHeaderOpener, erro } // GetHandshakeOpener indicates an expected call of GetHandshakeOpener. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *MockCryptoSetupGetHandshakeOpenerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeOpener)) + return &MockCryptoSetupGetHandshakeOpenerCall{Call: call} +} + +// MockCryptoSetupGetHandshakeOpenerCall wrap *gomock.Call +type MockCryptoSetupGetHandshakeOpenerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGetHandshakeOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGetHandshakeOpenerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGetHandshakeOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetHandshakeOpenerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGetHandshakeOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetHandshakeOpenerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetHandshakeSealer mocks base method. @@ -176,9 +393,33 @@ func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, erro } // GetHandshakeSealer indicates an expected call of GetHandshakeSealer. -func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) GetHandshakeSealer() *MockCryptoSetupGetHandshakeSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetHandshakeSealer)) + return &MockCryptoSetupGetHandshakeSealerCall{Call: call} +} + +// MockCryptoSetupGetHandshakeSealerCall wrap *gomock.Call +type MockCryptoSetupGetHandshakeSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGetHandshakeSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGetHandshakeSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGetHandshakeSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetHandshakeSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGetHandshakeSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetHandshakeSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetInitialOpener mocks base method. @@ -191,9 +432,33 @@ func (m *MockCryptoSetup) GetInitialOpener() (handshake.LongHeaderOpener, error) } // GetInitialOpener indicates an expected call of GetInitialOpener. -func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *MockCryptoSetupGetInitialOpenerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialOpener", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialOpener)) + return &MockCryptoSetupGetInitialOpenerCall{Call: call} +} + +// MockCryptoSetupGetInitialOpenerCall wrap *gomock.Call +type MockCryptoSetupGetInitialOpenerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGetInitialOpenerCall) Return(arg0 handshake.LongHeaderOpener, arg1 error) *MockCryptoSetupGetInitialOpenerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGetInitialOpenerCall) Do(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetInitialOpenerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGetInitialOpenerCall) DoAndReturn(f func() (handshake.LongHeaderOpener, error)) *MockCryptoSetupGetInitialOpenerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetInitialSealer mocks base method. @@ -206,9 +471,33 @@ func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) } // GetInitialSealer indicates an expected call of GetInitialSealer. -func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) GetInitialSealer() *MockCryptoSetupGetInitialSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockCryptoSetup)(nil).GetInitialSealer)) + return &MockCryptoSetupGetInitialSealerCall{Call: call} +} + +// MockCryptoSetupGetInitialSealerCall wrap *gomock.Call +type MockCryptoSetupGetInitialSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGetInitialSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockCryptoSetupGetInitialSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGetInitialSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetInitialSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGetInitialSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockCryptoSetupGetInitialSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetSessionTicket mocks base method. @@ -221,9 +510,33 @@ func (m *MockCryptoSetup) GetSessionTicket() ([]byte, error) { } // GetSessionTicket indicates an expected call of GetSessionTicket. -func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *MockCryptoSetupGetSessionTicketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTicket", reflect.TypeOf((*MockCryptoSetup)(nil).GetSessionTicket)) + return &MockCryptoSetupGetSessionTicketCall{Call: call} +} + +// MockCryptoSetupGetSessionTicketCall wrap *gomock.Call +type MockCryptoSetupGetSessionTicketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupGetSessionTicketCall) Return(arg0 []byte, arg1 error) *MockCryptoSetupGetSessionTicketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupGetSessionTicketCall) Do(f func() ([]byte, error)) *MockCryptoSetupGetSessionTicketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupGetSessionTicketCall) DoAndReturn(f func() ([]byte, error)) *MockCryptoSetupGetSessionTicketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandleMessage mocks base method. @@ -235,9 +548,33 @@ func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLev } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 any) *MockCryptoSetupHandleMessageCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) + return &MockCryptoSetupHandleMessageCall{Call: call} +} + +// MockCryptoSetupHandleMessageCall wrap *gomock.Call +type MockCryptoSetupHandleMessageCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupHandleMessageCall) Return(arg0 error) *MockCryptoSetupHandleMessageCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupHandleMessageCall) Do(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoSetupHandleMessageCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupHandleMessageCall) DoAndReturn(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoSetupHandleMessageCall { + c.Call = c.Call.DoAndReturn(f) + return c } // NextEvent mocks base method. @@ -249,9 +586,33 @@ func (m *MockCryptoSetup) NextEvent() handshake.Event { } // NextEvent indicates an expected call of NextEvent. -func (mr *MockCryptoSetupMockRecorder) NextEvent() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) NextEvent() *MockCryptoSetupNextEventCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoSetup)(nil).NextEvent)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoSetup)(nil).NextEvent)) + return &MockCryptoSetupNextEventCall{Call: call} +} + +// MockCryptoSetupNextEventCall wrap *gomock.Call +type MockCryptoSetupNextEventCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupNextEventCall) Return(arg0 handshake.Event) *MockCryptoSetupNextEventCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupNextEventCall) Do(f func() handshake.Event) *MockCryptoSetupNextEventCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupNextEventCall) DoAndReturn(f func() handshake.Event) *MockCryptoSetupNextEventCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetHandshakeConfirmed mocks base method. @@ -261,9 +622,33 @@ func (m *MockCryptoSetup) SetHandshakeConfirmed() { } // SetHandshakeConfirmed indicates an expected call of SetHandshakeConfirmed. -func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) SetHandshakeConfirmed() *MockCryptoSetupSetHandshakeConfirmedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeConfirmed", reflect.TypeOf((*MockCryptoSetup)(nil).SetHandshakeConfirmed)) + return &MockCryptoSetupSetHandshakeConfirmedCall{Call: call} +} + +// MockCryptoSetupSetHandshakeConfirmedCall wrap *gomock.Call +type MockCryptoSetupSetHandshakeConfirmedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupSetHandshakeConfirmedCall) Return() *MockCryptoSetupSetHandshakeConfirmedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupSetHandshakeConfirmedCall) Do(f func()) *MockCryptoSetupSetHandshakeConfirmedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupSetHandshakeConfirmedCall) DoAndReturn(f func()) *MockCryptoSetupSetHandshakeConfirmedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetLargest1RTTAcked mocks base method. @@ -275,9 +660,33 @@ func (m *MockCryptoSetup) SetLargest1RTTAcked(arg0 protocol.PacketNumber) error } // SetLargest1RTTAcked indicates an expected call of SetLargest1RTTAcked. -func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 any) *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 any) *MockCryptoSetupSetLargest1RTTAckedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) + return &MockCryptoSetupSetLargest1RTTAckedCall{Call: call} +} + +// MockCryptoSetupSetLargest1RTTAckedCall wrap *gomock.Call +type MockCryptoSetupSetLargest1RTTAckedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupSetLargest1RTTAckedCall) Return(arg0 error) *MockCryptoSetupSetLargest1RTTAckedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupSetLargest1RTTAckedCall) Do(f func(protocol.PacketNumber) error) *MockCryptoSetupSetLargest1RTTAckedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupSetLargest1RTTAckedCall) DoAndReturn(f func(protocol.PacketNumber) error) *MockCryptoSetupSetLargest1RTTAckedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StartHandshake mocks base method. @@ -289,7 +698,31 @@ func (m *MockCryptoSetup) StartHandshake() error { } // StartHandshake indicates an expected call of StartHandshake. -func (mr *MockCryptoSetupMockRecorder) StartHandshake() *gomock.Call { +func (mr *MockCryptoSetupMockRecorder) StartHandshake() *MockCryptoSetupStartHandshakeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake)) + return &MockCryptoSetupStartHandshakeCall{Call: call} +} + +// MockCryptoSetupStartHandshakeCall wrap *gomock.Call +type MockCryptoSetupStartHandshakeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoSetupStartHandshakeCall) Return(arg0 error) *MockCryptoSetupStartHandshakeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoSetupStartHandshakeCall) Do(f func() error) *MockCryptoSetupStartHandshakeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoSetupStartHandshakeCall) DoAndReturn(f func() error) *MockCryptoSetupStartHandshakeCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index c9728eece..e8ab98ce6 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -56,8 +56,8 @@ func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer BufferedPacket: func(typ logging.PacketType, size logging.ByteCount) { t.BufferedPacket(typ, size) }, - DroppedPacket: func(typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { - t.DroppedPacket(typ, size, reason) + DroppedPacket: func(typ logging.PacketType, pn logging.PacketNumber, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(typ, pn, size, reason) }, UpdatedMetrics: func(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) { t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) @@ -98,6 +98,9 @@ func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { t.ECNStateUpdated(state, trigger) }, + ChoseALPN: func(protocol string) { + t.ChoseALPN(protocol) + }, Close: func() { t.Close() }, diff --git a/internal/mocks/logging/internal/connection_tracer.go b/internal/mocks/logging/internal/connection_tracer.go index b703f980c..9dd6cb0d6 100644 --- a/internal/mocks/logging/internal/connection_tracer.go +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package internal -destination internal/connection_tracer.go github.com/refraction-networking/uquic/internal/mocks/logging ConnectionTracer +// mockgen -typed -build_flags=-tags=gomock -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer // + // Package internal is a generated GoMock package. package internal @@ -50,9 +51,33 @@ func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, } // AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 any) *MockConnectionTracerAcknowledgedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) + return &MockConnectionTracerAcknowledgedPacketCall{Call: call} +} + +// MockConnectionTracerAcknowledgedPacketCall wrap *gomock.Call +type MockConnectionTracerAcknowledgedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerAcknowledgedPacketCall) Return() *MockConnectionTracerAcknowledgedPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerAcknowledgedPacketCall) Do(f func(protocol.EncryptionLevel, protocol.PacketNumber)) *MockConnectionTracerAcknowledgedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerAcknowledgedPacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.PacketNumber)) *MockConnectionTracerAcknowledgedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // BufferedPacket mocks base method. @@ -62,9 +87,69 @@ func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 prot } // BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 any) *MockConnectionTracerBufferedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) + return &MockConnectionTracerBufferedPacketCall{Call: call} +} + +// MockConnectionTracerBufferedPacketCall wrap *gomock.Call +type MockConnectionTracerBufferedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerBufferedPacketCall) Return() *MockConnectionTracerBufferedPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerBufferedPacketCall) Do(f func(logging.PacketType, protocol.ByteCount)) *MockConnectionTracerBufferedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerBufferedPacketCall) DoAndReturn(f func(logging.PacketType, protocol.ByteCount)) *MockConnectionTracerBufferedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// ChoseALPN mocks base method. +func (m *MockConnectionTracer) ChoseALPN(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ChoseALPN", arg0) +} + +// ChoseALPN indicates an expected call of ChoseALPN. +func (mr *MockConnectionTracerMockRecorder) ChoseALPN(arg0 any) *MockConnectionTracerChoseALPNCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChoseALPN", reflect.TypeOf((*MockConnectionTracer)(nil).ChoseALPN), arg0) + return &MockConnectionTracerChoseALPNCall{Call: call} +} + +// MockConnectionTracerChoseALPNCall wrap *gomock.Call +type MockConnectionTracerChoseALPNCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerChoseALPNCall) Return() *MockConnectionTracerChoseALPNCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerChoseALPNCall) Do(f func(string)) *MockConnectionTracerChoseALPNCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerChoseALPNCall) DoAndReturn(f func(string)) *MockConnectionTracerChoseALPNCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -74,9 +159,33 @@ func (m *MockConnectionTracer) Close() { } // Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) Close() *MockConnectionTracerCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) + return &MockConnectionTracerCloseCall{Call: call} +} + +// MockConnectionTracerCloseCall wrap *gomock.Call +type MockConnectionTracerCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerCloseCall) Return() *MockConnectionTracerCloseCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerCloseCall) Do(f func()) *MockConnectionTracerCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerCloseCall) DoAndReturn(f func()) *MockConnectionTracerCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ClosedConnection mocks base method. @@ -86,9 +195,33 @@ func (m *MockConnectionTracer) ClosedConnection(arg0 error) { } // ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 any) *MockConnectionTracerClosedConnectionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) + return &MockConnectionTracerClosedConnectionCall{Call: call} +} + +// MockConnectionTracerClosedConnectionCall wrap *gomock.Call +type MockConnectionTracerClosedConnectionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerClosedConnectionCall) Return() *MockConnectionTracerClosedConnectionCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerClosedConnectionCall) Do(f func(error)) *MockConnectionTracerClosedConnectionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerClosedConnectionCall) DoAndReturn(f func(error)) *MockConnectionTracerClosedConnectionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Debug mocks base method. @@ -98,9 +231,33 @@ func (m *MockConnectionTracer) Debug(arg0, arg1 string) { } // Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 any) *MockConnectionTracerDebugCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) + return &MockConnectionTracerDebugCall{Call: call} +} + +// MockConnectionTracerDebugCall wrap *gomock.Call +type MockConnectionTracerDebugCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerDebugCall) Return() *MockConnectionTracerDebugCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerDebugCall) Do(f func(string, string)) *MockConnectionTracerDebugCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerDebugCall) DoAndReturn(f func(string, string)) *MockConnectionTracerDebugCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DroppedEncryptionLevel mocks base method. @@ -110,9 +267,33 @@ func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLe } // DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 any) *MockConnectionTracerDroppedEncryptionLevelCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) + return &MockConnectionTracerDroppedEncryptionLevelCall{Call: call} +} + +// MockConnectionTracerDroppedEncryptionLevelCall wrap *gomock.Call +type MockConnectionTracerDroppedEncryptionLevelCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerDroppedEncryptionLevelCall) Return() *MockConnectionTracerDroppedEncryptionLevelCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerDroppedEncryptionLevelCall) Do(f func(protocol.EncryptionLevel)) *MockConnectionTracerDroppedEncryptionLevelCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerDroppedEncryptionLevelCall) DoAndReturn(f func(protocol.EncryptionLevel)) *MockConnectionTracerDroppedEncryptionLevelCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DroppedKey mocks base method. @@ -122,21 +303,69 @@ func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { } // DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 any) *MockConnectionTracerDroppedKeyCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) + return &MockConnectionTracerDroppedKeyCall{Call: call} +} + +// MockConnectionTracerDroppedKeyCall wrap *gomock.Call +type MockConnectionTracerDroppedKeyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerDroppedKeyCall) Return() *MockConnectionTracerDroppedKeyCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerDroppedKeyCall) Do(f func(protocol.KeyPhase)) *MockConnectionTracerDroppedKeyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerDroppedKeyCall) DoAndReturn(f func(protocol.KeyPhase)) *MockConnectionTracerDroppedKeyCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.PacketNumber, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 any) *MockConnectionTracerDroppedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) + return &MockConnectionTracerDroppedPacketCall{Call: call} +} + +// MockConnectionTracerDroppedPacketCall wrap *gomock.Call +type MockConnectionTracerDroppedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerDroppedPacketCall) Return() *MockConnectionTracerDroppedPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerDroppedPacketCall) Do(f func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason)) *MockConnectionTracerDroppedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerDroppedPacketCall) DoAndReturn(f func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason)) *MockConnectionTracerDroppedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ECNStateUpdated mocks base method. @@ -146,9 +375,33 @@ func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 loggi } // ECNStateUpdated indicates an expected call of ECNStateUpdated. -func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 any) *MockConnectionTracerECNStateUpdatedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) + return &MockConnectionTracerECNStateUpdatedCall{Call: call} +} + +// MockConnectionTracerECNStateUpdatedCall wrap *gomock.Call +type MockConnectionTracerECNStateUpdatedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerECNStateUpdatedCall) Return() *MockConnectionTracerECNStateUpdatedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerECNStateUpdatedCall) Do(f func(logging.ECNState, logging.ECNStateTrigger)) *MockConnectionTracerECNStateUpdatedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerECNStateUpdatedCall) DoAndReturn(f func(logging.ECNState, logging.ECNStateTrigger)) *MockConnectionTracerECNStateUpdatedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LossTimerCanceled mocks base method. @@ -158,9 +411,33 @@ func (m *MockConnectionTracer) LossTimerCanceled() { } // LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *MockConnectionTracerLossTimerCanceledCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) + return &MockConnectionTracerLossTimerCanceledCall{Call: call} +} + +// MockConnectionTracerLossTimerCanceledCall wrap *gomock.Call +type MockConnectionTracerLossTimerCanceledCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerLossTimerCanceledCall) Return() *MockConnectionTracerLossTimerCanceledCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerLossTimerCanceledCall) Do(f func()) *MockConnectionTracerLossTimerCanceledCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerLossTimerCanceledCall) DoAndReturn(f func()) *MockConnectionTracerLossTimerCanceledCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LossTimerExpired mocks base method. @@ -170,9 +447,33 @@ func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 pro } // LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 any) *MockConnectionTracerLossTimerExpiredCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) + return &MockConnectionTracerLossTimerExpiredCall{Call: call} +} + +// MockConnectionTracerLossTimerExpiredCall wrap *gomock.Call +type MockConnectionTracerLossTimerExpiredCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerLossTimerExpiredCall) Return() *MockConnectionTracerLossTimerExpiredCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerLossTimerExpiredCall) Do(f func(logging.TimerType, protocol.EncryptionLevel)) *MockConnectionTracerLossTimerExpiredCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerLossTimerExpiredCall) DoAndReturn(f func(logging.TimerType, protocol.EncryptionLevel)) *MockConnectionTracerLossTimerExpiredCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LostPacket mocks base method. @@ -182,21 +483,69 @@ func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 pr } // LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 any) *MockConnectionTracerLostPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) + return &MockConnectionTracerLostPacketCall{Call: call} +} + +// MockConnectionTracerLostPacketCall wrap *gomock.Call +type MockConnectionTracerLostPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerLostPacketCall) Return() *MockConnectionTracerLostPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerLostPacketCall) Do(f func(protocol.EncryptionLevel, protocol.PacketNumber, logging.PacketLossReason)) *MockConnectionTracerLostPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerLostPacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.PacketNumber, logging.PacketLossReason)) *MockConnectionTracerLostPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.Version, arg1, arg2 []protocol.Version) { m.ctrl.T.Helper() m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) } // NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 any) *MockConnectionTracerNegotiatedVersionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) + return &MockConnectionTracerNegotiatedVersionCall{Call: call} +} + +// MockConnectionTracerNegotiatedVersionCall wrap *gomock.Call +type MockConnectionTracerNegotiatedVersionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerNegotiatedVersionCall) Return() *MockConnectionTracerNegotiatedVersionCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerNegotiatedVersionCall) Do(f func(protocol.Version, []protocol.Version, []protocol.Version)) *MockConnectionTracerNegotiatedVersionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerNegotiatedVersionCall) DoAndReturn(f func(protocol.Version, []protocol.Version, []protocol.Version)) *MockConnectionTracerNegotiatedVersionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedLongHeaderPacket mocks base method. @@ -206,9 +555,33 @@ func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeade } // ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 any) *MockConnectionTracerReceivedLongHeaderPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) + return &MockConnectionTracerReceivedLongHeaderPacketCall{Call: call} +} + +// MockConnectionTracerReceivedLongHeaderPacketCall wrap *gomock.Call +type MockConnectionTracerReceivedLongHeaderPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerReceivedLongHeaderPacketCall) Return() *MockConnectionTracerReceivedLongHeaderPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerReceivedLongHeaderPacketCall) Do(f func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame)) *MockConnectionTracerReceivedLongHeaderPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerReceivedLongHeaderPacketCall) DoAndReturn(f func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame)) *MockConnectionTracerReceivedLongHeaderPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedRetry mocks base method. @@ -218,9 +591,33 @@ func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { } // ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 any) *MockConnectionTracerReceivedRetryCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) + return &MockConnectionTracerReceivedRetryCall{Call: call} +} + +// MockConnectionTracerReceivedRetryCall wrap *gomock.Call +type MockConnectionTracerReceivedRetryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerReceivedRetryCall) Return() *MockConnectionTracerReceivedRetryCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerReceivedRetryCall) Do(f func(*wire.Header)) *MockConnectionTracerReceivedRetryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerReceivedRetryCall) DoAndReturn(f func(*wire.Header)) *MockConnectionTracerReceivedRetryCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedShortHeaderPacket mocks base method. @@ -230,9 +627,33 @@ func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHead } // ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 any) *MockConnectionTracerReceivedShortHeaderPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) + return &MockConnectionTracerReceivedShortHeaderPacketCall{Call: call} +} + +// MockConnectionTracerReceivedShortHeaderPacketCall wrap *gomock.Call +type MockConnectionTracerReceivedShortHeaderPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerReceivedShortHeaderPacketCall) Return() *MockConnectionTracerReceivedShortHeaderPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerReceivedShortHeaderPacketCall) Do(f func(*logging.ShortHeader, protocol.ByteCount, protocol.ECN, []logging.Frame)) *MockConnectionTracerReceivedShortHeaderPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerReceivedShortHeaderPacketCall) DoAndReturn(f func(*logging.ShortHeader, protocol.ByteCount, protocol.ECN, []logging.Frame)) *MockConnectionTracerReceivedShortHeaderPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedTransportParameters mocks base method. @@ -242,21 +663,69 @@ func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportP } // ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 any) *MockConnectionTracerReceivedTransportParametersCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) + return &MockConnectionTracerReceivedTransportParametersCall{Call: call} +} + +// MockConnectionTracerReceivedTransportParametersCall wrap *gomock.Call +type MockConnectionTracerReceivedTransportParametersCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerReceivedTransportParametersCall) Return() *MockConnectionTracerReceivedTransportParametersCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerReceivedTransportParametersCall) Do(f func(*wire.TransportParameters)) *MockConnectionTracerReceivedTransportParametersCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerReceivedTransportParametersCall) DoAndReturn(f func(*wire.TransportParameters)) *MockConnectionTracerReceivedTransportParametersCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.Version) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 any) *MockConnectionTracerReceivedVersionNegotiationPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) + return &MockConnectionTracerReceivedVersionNegotiationPacketCall{Call: call} +} + +// MockConnectionTracerReceivedVersionNegotiationPacketCall wrap *gomock.Call +type MockConnectionTracerReceivedVersionNegotiationPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerReceivedVersionNegotiationPacketCall) Return() *MockConnectionTracerReceivedVersionNegotiationPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerReceivedVersionNegotiationPacketCall) Do(f func(protocol.ArbitraryLenConnectionID, protocol.ArbitraryLenConnectionID, []protocol.Version)) *MockConnectionTracerReceivedVersionNegotiationPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerReceivedVersionNegotiationPacketCall) DoAndReturn(f func(protocol.ArbitraryLenConnectionID, protocol.ArbitraryLenConnectionID, []protocol.Version)) *MockConnectionTracerReceivedVersionNegotiationPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RestoredTransportParameters mocks base method. @@ -266,9 +735,33 @@ func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportP } // RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 any) *MockConnectionTracerRestoredTransportParametersCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) + return &MockConnectionTracerRestoredTransportParametersCall{Call: call} +} + +// MockConnectionTracerRestoredTransportParametersCall wrap *gomock.Call +type MockConnectionTracerRestoredTransportParametersCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerRestoredTransportParametersCall) Return() *MockConnectionTracerRestoredTransportParametersCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerRestoredTransportParametersCall) Do(f func(*wire.TransportParameters)) *MockConnectionTracerRestoredTransportParametersCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerRestoredTransportParametersCall) DoAndReturn(f func(*wire.TransportParameters)) *MockConnectionTracerRestoredTransportParametersCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentLongHeaderPacket mocks base method. @@ -278,9 +771,33 @@ func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, a } // SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *MockConnectionTracerSentLongHeaderPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) + return &MockConnectionTracerSentLongHeaderPacketCall{Call: call} +} + +// MockConnectionTracerSentLongHeaderPacketCall wrap *gomock.Call +type MockConnectionTracerSentLongHeaderPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerSentLongHeaderPacketCall) Return() *MockConnectionTracerSentLongHeaderPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerSentLongHeaderPacketCall) Do(f func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, *wire.AckFrame, []logging.Frame)) *MockConnectionTracerSentLongHeaderPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerSentLongHeaderPacketCall) DoAndReturn(f func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, *wire.AckFrame, []logging.Frame)) *MockConnectionTracerSentLongHeaderPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentShortHeaderPacket mocks base method. @@ -290,9 +807,33 @@ func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, } // SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 any) *MockConnectionTracerSentShortHeaderPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) + return &MockConnectionTracerSentShortHeaderPacketCall{Call: call} +} + +// MockConnectionTracerSentShortHeaderPacketCall wrap *gomock.Call +type MockConnectionTracerSentShortHeaderPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerSentShortHeaderPacketCall) Return() *MockConnectionTracerSentShortHeaderPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerSentShortHeaderPacketCall) Do(f func(*logging.ShortHeader, protocol.ByteCount, protocol.ECN, *wire.AckFrame, []logging.Frame)) *MockConnectionTracerSentShortHeaderPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerSentShortHeaderPacketCall) DoAndReturn(f func(*logging.ShortHeader, protocol.ByteCount, protocol.ECN, *wire.AckFrame, []logging.Frame)) *MockConnectionTracerSentShortHeaderPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentTransportParameters mocks base method. @@ -302,9 +843,33 @@ func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParam } // SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 any) *MockConnectionTracerSentTransportParametersCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) + return &MockConnectionTracerSentTransportParametersCall{Call: call} +} + +// MockConnectionTracerSentTransportParametersCall wrap *gomock.Call +type MockConnectionTracerSentTransportParametersCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerSentTransportParametersCall) Return() *MockConnectionTracerSentTransportParametersCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerSentTransportParametersCall) Do(f func(*wire.TransportParameters)) *MockConnectionTracerSentTransportParametersCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerSentTransportParametersCall) DoAndReturn(f func(*wire.TransportParameters)) *MockConnectionTracerSentTransportParametersCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetLossTimer mocks base method. @@ -314,9 +879,33 @@ func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protoco } // SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 any) *MockConnectionTracerSetLossTimerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) + return &MockConnectionTracerSetLossTimerCall{Call: call} +} + +// MockConnectionTracerSetLossTimerCall wrap *gomock.Call +type MockConnectionTracerSetLossTimerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerSetLossTimerCall) Return() *MockConnectionTracerSetLossTimerCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerSetLossTimerCall) Do(f func(logging.TimerType, protocol.EncryptionLevel, time.Time)) *MockConnectionTracerSetLossTimerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerSetLossTimerCall) DoAndReturn(f func(logging.TimerType, protocol.EncryptionLevel, time.Time)) *MockConnectionTracerSetLossTimerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StartedConnection mocks base method. @@ -326,9 +915,33 @@ func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 } // StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 any) *MockConnectionTracerStartedConnectionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) + return &MockConnectionTracerStartedConnectionCall{Call: call} +} + +// MockConnectionTracerStartedConnectionCall wrap *gomock.Call +type MockConnectionTracerStartedConnectionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerStartedConnectionCall) Return() *MockConnectionTracerStartedConnectionCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerStartedConnectionCall) Do(f func(net.Addr, net.Addr, protocol.ConnectionID, protocol.ConnectionID)) *MockConnectionTracerStartedConnectionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerStartedConnectionCall) DoAndReturn(f func(net.Addr, net.Addr, protocol.ConnectionID, protocol.ConnectionID)) *MockConnectionTracerStartedConnectionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdatedCongestionState mocks base method. @@ -338,9 +951,33 @@ func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionSta } // UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 any) *MockConnectionTracerUpdatedCongestionStateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) + return &MockConnectionTracerUpdatedCongestionStateCall{Call: call} +} + +// MockConnectionTracerUpdatedCongestionStateCall wrap *gomock.Call +type MockConnectionTracerUpdatedCongestionStateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedCongestionStateCall) Return() *MockConnectionTracerUpdatedCongestionStateCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedCongestionStateCall) Do(f func(logging.CongestionState)) *MockConnectionTracerUpdatedCongestionStateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedCongestionStateCall) DoAndReturn(f func(logging.CongestionState)) *MockConnectionTracerUpdatedCongestionStateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdatedKey mocks base method. @@ -350,9 +987,33 @@ func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { } // UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 any) *MockConnectionTracerUpdatedKeyCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) + return &MockConnectionTracerUpdatedKeyCall{Call: call} +} + +// MockConnectionTracerUpdatedKeyCall wrap *gomock.Call +type MockConnectionTracerUpdatedKeyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedKeyCall) Return() *MockConnectionTracerUpdatedKeyCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedKeyCall) Do(f func(protocol.KeyPhase, bool)) *MockConnectionTracerUpdatedKeyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedKeyCall) DoAndReturn(f func(protocol.KeyPhase, bool)) *MockConnectionTracerUpdatedKeyCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdatedKeyFromTLS mocks base method. @@ -362,9 +1023,33 @@ func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, } // UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 any) *MockConnectionTracerUpdatedKeyFromTLSCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) + return &MockConnectionTracerUpdatedKeyFromTLSCall{Call: call} +} + +// MockConnectionTracerUpdatedKeyFromTLSCall wrap *gomock.Call +type MockConnectionTracerUpdatedKeyFromTLSCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedKeyFromTLSCall) Return() *MockConnectionTracerUpdatedKeyFromTLSCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedKeyFromTLSCall) Do(f func(protocol.EncryptionLevel, protocol.Perspective)) *MockConnectionTracerUpdatedKeyFromTLSCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedKeyFromTLSCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.Perspective)) *MockConnectionTracerUpdatedKeyFromTLSCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdatedMetrics mocks base method. @@ -374,9 +1059,33 @@ func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 p } // UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 any) *MockConnectionTracerUpdatedMetricsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) + return &MockConnectionTracerUpdatedMetricsCall{Call: call} +} + +// MockConnectionTracerUpdatedMetricsCall wrap *gomock.Call +type MockConnectionTracerUpdatedMetricsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedMetricsCall) Return() *MockConnectionTracerUpdatedMetricsCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedMetricsCall) Do(f func(*utils.RTTStats, protocol.ByteCount, protocol.ByteCount, int)) *MockConnectionTracerUpdatedMetricsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedMetricsCall) DoAndReturn(f func(*utils.RTTStats, protocol.ByteCount, protocol.ByteCount, int)) *MockConnectionTracerUpdatedMetricsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdatedPTOCount mocks base method. @@ -386,7 +1095,31 @@ func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { } // UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 any) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 any) *MockConnectionTracerUpdatedPTOCountCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) + return &MockConnectionTracerUpdatedPTOCountCall{Call: call} +} + +// MockConnectionTracerUpdatedPTOCountCall wrap *gomock.Call +type MockConnectionTracerUpdatedPTOCountCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionTracerUpdatedPTOCountCall) Return() *MockConnectionTracerUpdatedPTOCountCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionTracerUpdatedPTOCountCall) Do(f func(uint32)) *MockConnectionTracerUpdatedPTOCountCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionTracerUpdatedPTOCountCall) DoAndReturn(f func(uint32)) *MockConnectionTracerUpdatedPTOCountCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/logging/internal/tracer.go b/internal/mocks/logging/internal/tracer.go index 707abb431..305c74922 100644 --- a/internal/mocks/logging/internal/tracer.go +++ b/internal/mocks/logging/internal/tracer.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package internal -destination internal/tracer.go github.com/refraction-networking/uquic/internal/mocks/logging Tracer +// mockgen -typed -build_flags=-tags=gomock -package internal -destination internal/tracer.go github.com/quic-go/quic-go/internal/mocks/logging Tracer // + // Package internal is a generated GoMock package. package internal @@ -41,6 +42,78 @@ func (m *MockTracer) EXPECT() *MockTracerMockRecorder { return m.recorder } +// Close mocks base method. +func (m *MockTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockTracerMockRecorder) Close() *MockTracerCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockTracer)(nil).Close)) + return &MockTracerCloseCall{Call: call} +} + +// MockTracerCloseCall wrap *gomock.Call +type MockTracerCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTracerCloseCall) Return() *MockTracerCloseCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTracerCloseCall) Do(f func()) *MockTracerCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTracerCloseCall) DoAndReturn(f func()) *MockTracerCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Debug mocks base method. +func (m *MockTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockTracerMockRecorder) Debug(arg0, arg1 any) *MockTracerDebugCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockTracer)(nil).Debug), arg0, arg1) + return &MockTracerDebugCall{Call: call} +} + +// MockTracerDebugCall wrap *gomock.Call +type MockTracerDebugCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTracerDebugCall) Return() *MockTracerDebugCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTracerDebugCall) Do(f func(string, string)) *MockTracerDebugCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTracerDebugCall) DoAndReturn(f func(string, string)) *MockTracerDebugCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // DroppedPacket mocks base method. func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { m.ctrl.T.Helper() @@ -48,9 +121,33 @@ func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 } // DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 any) *MockTracerDroppedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) + return &MockTracerDroppedPacketCall{Call: call} +} + +// MockTracerDroppedPacketCall wrap *gomock.Call +type MockTracerDroppedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTracerDroppedPacketCall) Return() *MockTracerDroppedPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTracerDroppedPacketCall) Do(f func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason)) *MockTracerDroppedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTracerDroppedPacketCall) DoAndReturn(f func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason)) *MockTracerDroppedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentPacket mocks base method. @@ -60,19 +157,67 @@ func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol. } // SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 any) *MockTracerSentPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) + return &MockTracerSentPacketCall{Call: call} +} + +// MockTracerSentPacketCall wrap *gomock.Call +type MockTracerSentPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTracerSentPacketCall) Return() *MockTracerSentPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTracerSentPacketCall) Do(f func(net.Addr, *wire.Header, protocol.ByteCount, []logging.Frame)) *MockTracerSentPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTracerSentPacketCall) DoAndReturn(f func(net.Addr, *wire.Header, protocol.ByteCount, []logging.Frame)) *MockTracerSentPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SentVersionNegotiationPacket mocks base method. -func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { +func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.Version) { m.ctrl.T.Helper() m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) } // SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 any) *MockTracerSentVersionNegotiationPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) + return &MockTracerSentVersionNegotiationPacketCall{Call: call} +} + +// MockTracerSentVersionNegotiationPacketCall wrap *gomock.Call +type MockTracerSentVersionNegotiationPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTracerSentVersionNegotiationPacketCall) Return() *MockTracerSentVersionNegotiationPacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTracerSentVersionNegotiationPacketCall) Do(f func(net.Addr, protocol.ArbitraryLenConnectionID, protocol.ArbitraryLenConnectionID, []protocol.Version)) *MockTracerSentVersionNegotiationPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTracerSentVersionNegotiationPacketCall) DoAndReturn(f func(net.Addr, protocol.ArbitraryLenConnectionID, protocol.ArbitraryLenConnectionID, []protocol.Version)) *MockTracerSentVersionNegotiationPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go index a4a7f4f05..160e3cb7d 100644 --- a/internal/mocks/logging/mockgen.go +++ b/internal/mocks/logging/mockgen.go @@ -9,14 +9,16 @@ import ( "github.com/refraction-networking/uquic/logging" ) -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/tracer.go github.com/refraction-networking/uquic/internal/mocks/logging Tracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package internal -destination internal/tracer.go github.com/quic-go/quic-go/internal/mocks/logging Tracer" type Tracer interface { SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) + Debug(name, msg string) + Close() } -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/refraction-networking/uquic/internal/mocks/logging ConnectionTracer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer" type ConnectionTracer interface { StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) @@ -31,7 +33,7 @@ type ConnectionTracer interface { ReceivedLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, []logging.Frame) ReceivedShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, []logging.Frame) BufferedPacket(logging.PacketType, logging.ByteCount) - DroppedPacket(logging.PacketType, logging.ByteCount, logging.PacketDropReason) + DroppedPacket(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) LostPacket(logging.EncryptionLevel, logging.PacketNumber, logging.PacketLossReason) @@ -45,6 +47,7 @@ type ConnectionTracer interface { LossTimerExpired(logging.TimerType, logging.EncryptionLevel) LossTimerCanceled() ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) + ChoseALPN(protocol string) // Close is called when the connection is closed. Close() Debug(name, msg string) diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 9fc620996..ed638ba2e 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -25,5 +25,11 @@ func NewMockTracer(ctrl *gomock.Controller) (*logging.Tracer, *MockTracer) { DroppedPacket: func(remote net.Addr, typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { t.DroppedPacket(remote, typ, size, reason) }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + Close: func() { + t.Close() + }, }, t } diff --git a/internal/mocks/long_header_opener.go b/internal/mocks/long_header_opener.go index adfcd1f92..c409a0896 100644 --- a/internal/mocks/long_header_opener.go +++ b/internal/mocks/long_header_opener.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener // + // Package mocks is a generated GoMock package. package mocks @@ -47,9 +48,33 @@ func (m *MockLongHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, ar } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *MockLongHeaderOpenerDecodePacketNumberCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) + return &MockLongHeaderOpenerDecodePacketNumberCall{Call: call} +} + +// MockLongHeaderOpenerDecodePacketNumberCall wrap *gomock.Call +type MockLongHeaderOpenerDecodePacketNumberCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLongHeaderOpenerDecodePacketNumberCall) Return(arg0 protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLongHeaderOpenerDecodePacketNumberCall) Do(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLongHeaderOpenerDecodePacketNumberCall) DoAndReturn(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockLongHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DecryptHeader mocks base method. @@ -59,9 +84,33 @@ func (m *MockLongHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []byt } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *MockLongHeaderOpenerDecryptHeaderCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockLongHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) + return &MockLongHeaderOpenerDecryptHeaderCall{Call: call} +} + +// MockLongHeaderOpenerDecryptHeaderCall wrap *gomock.Call +type MockLongHeaderOpenerDecryptHeaderCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLongHeaderOpenerDecryptHeaderCall) Return() *MockLongHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLongHeaderOpenerDecryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockLongHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLongHeaderOpenerDecryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) *MockLongHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Open mocks base method. @@ -74,7 +123,31 @@ func (m *MockLongHeaderOpener) Open(arg0, arg1 []byte, arg2 protocol.PacketNumbe } // Open indicates an expected call of Open. -func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockLongHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3 any) *MockLongHeaderOpenerOpenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockLongHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3) + return &MockLongHeaderOpenerOpenCall{Call: call} +} + +// MockLongHeaderOpenerOpenCall wrap *gomock.Call +type MockLongHeaderOpenerOpenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLongHeaderOpenerOpenCall) Return(arg0 []byte, arg1 error) *MockLongHeaderOpenerOpenCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLongHeaderOpenerOpenCall) Do(f func([]byte, []byte, protocol.PacketNumber, []byte) ([]byte, error)) *MockLongHeaderOpenerOpenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLongHeaderOpenerOpenCall) DoAndReturn(f func([]byte, []byte, protocol.PacketNumber, []byte) ([]byte, error)) *MockLongHeaderOpenerOpenCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 30174b48a..b736631d5 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -2,18 +2,18 @@ package mocks -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/refraction-networking/uquic/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup_tmp.go github.com/refraction-networking/uquic/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/refraction-networking/uquic/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/refraction-networking/uquic/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination connection_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/refraction-networking/uquic/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/quic-go/quic-go/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/internal/mocks/quic/early_conn.go b/internal/mocks/quic/early_conn.go index cfcb27835..cc05b845b 100644 --- a/internal/mocks/quic/early_conn.go +++ b/internal/mocks/quic/early_conn.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/early_conn_tmp.go github.com/refraction-networking/uquic EarlyConnection +// mockgen -typed -build_flags=-tags=gomock -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection // + // Package mockquic is a generated GoMock package. package mockquic @@ -51,9 +52,33 @@ func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, e } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 any) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 any) *MockEarlyConnectionAcceptStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) + return &MockEarlyConnectionAcceptStreamCall{Call: call} +} + +// MockEarlyConnectionAcceptStreamCall wrap *gomock.Call +type MockEarlyConnectionAcceptStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionAcceptStreamCall) Return(arg0 quic.Stream, arg1 error) *MockEarlyConnectionAcceptStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionAcceptStreamCall) Do(f func(context.Context) (quic.Stream, error)) *MockEarlyConnectionAcceptStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionAcceptStreamCall) DoAndReturn(f func(context.Context) (quic.Stream, error)) *MockEarlyConnectionAcceptStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AcceptUniStream mocks base method. @@ -66,9 +91,33 @@ func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.Receiv } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 any) *MockEarlyConnectionAcceptUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) + return &MockEarlyConnectionAcceptUniStreamCall{Call: call} +} + +// MockEarlyConnectionAcceptUniStreamCall wrap *gomock.Call +type MockEarlyConnectionAcceptUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionAcceptUniStreamCall) Return(arg0 quic.ReceiveStream, arg1 error) *MockEarlyConnectionAcceptUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionAcceptUniStreamCall) Do(f func(context.Context) (quic.ReceiveStream, error)) *MockEarlyConnectionAcceptUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionAcceptUniStreamCall) DoAndReturn(f func(context.Context) (quic.ReceiveStream, error)) *MockEarlyConnectionAcceptUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // CloseWithError mocks base method. @@ -80,9 +129,33 @@ func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 any) *MockEarlyConnectionCloseWithErrorCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) + return &MockEarlyConnectionCloseWithErrorCall{Call: call} +} + +// MockEarlyConnectionCloseWithErrorCall wrap *gomock.Call +type MockEarlyConnectionCloseWithErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionCloseWithErrorCall) Return(arg0 error) *MockEarlyConnectionCloseWithErrorCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionCloseWithErrorCall) Do(f func(qerr.ApplicationErrorCode, string) error) *MockEarlyConnectionCloseWithErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionCloseWithErrorCall) DoAndReturn(f func(qerr.ApplicationErrorCode, string) error) *MockEarlyConnectionCloseWithErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ConnectionState mocks base method. @@ -94,9 +167,33 @@ func (m *MockEarlyConnection) ConnectionState() quic.ConnectionState { } // ConnectionState indicates an expected call of ConnectionState. -func (mr *MockEarlyConnectionMockRecorder) ConnectionState() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ConnectionState() *MockEarlyConnectionConnectionStateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlyConnection)(nil).ConnectionState)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlyConnection)(nil).ConnectionState)) + return &MockEarlyConnectionConnectionStateCall{Call: call} +} + +// MockEarlyConnectionConnectionStateCall wrap *gomock.Call +type MockEarlyConnectionConnectionStateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionConnectionStateCall) Return(arg0 quic.ConnectionState) *MockEarlyConnectionConnectionStateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionConnectionStateCall) Do(f func() quic.ConnectionState) *MockEarlyConnectionConnectionStateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionConnectionStateCall) DoAndReturn(f func() quic.ConnectionState) *MockEarlyConnectionConnectionStateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Context mocks base method. @@ -108,9 +205,33 @@ func (m *MockEarlyConnection) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) Context() *MockEarlyConnectionContextCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlyConnection)(nil).Context)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlyConnection)(nil).Context)) + return &MockEarlyConnectionContextCall{Call: call} +} + +// MockEarlyConnectionContextCall wrap *gomock.Call +type MockEarlyConnectionContextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionContextCall) Return(arg0 context.Context) *MockEarlyConnectionContextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionContextCall) Do(f func() context.Context) *MockEarlyConnectionContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionContextCall) DoAndReturn(f func() context.Context) *MockEarlyConnectionContextCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandshakeComplete mocks base method. @@ -122,9 +243,33 @@ func (m *MockEarlyConnection) HandshakeComplete() <-chan struct{} { } // HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockEarlyConnectionMockRecorder) HandshakeComplete() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) HandshakeComplete() *MockEarlyConnectionHandshakeCompleteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlyConnection)(nil).HandshakeComplete)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlyConnection)(nil).HandshakeComplete)) + return &MockEarlyConnectionHandshakeCompleteCall{Call: call} +} + +// MockEarlyConnectionHandshakeCompleteCall wrap *gomock.Call +type MockEarlyConnectionHandshakeCompleteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionHandshakeCompleteCall) Return(arg0 <-chan struct{}) *MockEarlyConnectionHandshakeCompleteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionHandshakeCompleteCall) Do(f func() <-chan struct{}) *MockEarlyConnectionHandshakeCompleteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionHandshakeCompleteCall) DoAndReturn(f func() <-chan struct{}) *MockEarlyConnectionHandshakeCompleteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LocalAddr mocks base method. @@ -136,9 +281,33 @@ func (m *MockEarlyConnection) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockEarlyConnectionMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) LocalAddr() *MockEarlyConnectionLocalAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlyConnection)(nil).LocalAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlyConnection)(nil).LocalAddr)) + return &MockEarlyConnectionLocalAddrCall{Call: call} +} + +// MockEarlyConnectionLocalAddrCall wrap *gomock.Call +type MockEarlyConnectionLocalAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionLocalAddrCall) Return(arg0 net.Addr) *MockEarlyConnectionLocalAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionLocalAddrCall) Do(f func() net.Addr) *MockEarlyConnectionLocalAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionLocalAddrCall) DoAndReturn(f func() net.Addr) *MockEarlyConnectionLocalAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // NextConnection mocks base method. @@ -150,9 +319,33 @@ func (m *MockEarlyConnection) NextConnection() quic.Connection { } // NextConnection indicates an expected call of NextConnection. -func (mr *MockEarlyConnectionMockRecorder) NextConnection() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) NextConnection() *MockEarlyConnectionNextConnectionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) + return &MockEarlyConnectionNextConnectionCall{Call: call} +} + +// MockEarlyConnectionNextConnectionCall wrap *gomock.Call +type MockEarlyConnectionNextConnectionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionNextConnectionCall) Return(arg0 quic.Connection) *MockEarlyConnectionNextConnectionCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionNextConnectionCall) Do(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionNextConnectionCall) DoAndReturn(f func() quic.Connection) *MockEarlyConnectionNextConnectionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStream mocks base method. @@ -165,9 +358,33 @@ func (m *MockEarlyConnection) OpenStream() (quic.Stream, error) { } // OpenStream indicates an expected call of OpenStream. -func (mr *MockEarlyConnectionMockRecorder) OpenStream() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStream() *MockEarlyConnectionOpenStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStream)) + return &MockEarlyConnectionOpenStreamCall{Call: call} +} + +// MockEarlyConnectionOpenStreamCall wrap *gomock.Call +type MockEarlyConnectionOpenStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionOpenStreamCall) Return(arg0 quic.Stream, arg1 error) *MockEarlyConnectionOpenStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionOpenStreamCall) Do(f func() (quic.Stream, error)) *MockEarlyConnectionOpenStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionOpenStreamCall) DoAndReturn(f func() (quic.Stream, error)) *MockEarlyConnectionOpenStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStreamSync mocks base method. @@ -180,9 +397,33 @@ func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 any) *MockEarlyConnectionOpenStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) + return &MockEarlyConnectionOpenStreamSyncCall{Call: call} +} + +// MockEarlyConnectionOpenStreamSyncCall wrap *gomock.Call +type MockEarlyConnectionOpenStreamSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionOpenStreamSyncCall) Return(arg0 quic.Stream, arg1 error) *MockEarlyConnectionOpenStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionOpenStreamSyncCall) Do(f func(context.Context) (quic.Stream, error)) *MockEarlyConnectionOpenStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionOpenStreamSyncCall) DoAndReturn(f func(context.Context) (quic.Stream, error)) *MockEarlyConnectionOpenStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStream mocks base method. @@ -195,9 +436,33 @@ func (m *MockEarlyConnection) OpenUniStream() (quic.SendStream, error) { } // OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStream() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStream() *MockEarlyConnectionOpenUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStream)) + return &MockEarlyConnectionOpenUniStreamCall{Call: call} +} + +// MockEarlyConnectionOpenUniStreamCall wrap *gomock.Call +type MockEarlyConnectionOpenUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionOpenUniStreamCall) Return(arg0 quic.SendStream, arg1 error) *MockEarlyConnectionOpenUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionOpenUniStreamCall) Do(f func() (quic.SendStream, error)) *MockEarlyConnectionOpenUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionOpenUniStreamCall) DoAndReturn(f func() (quic.SendStream, error)) *MockEarlyConnectionOpenUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStreamSync mocks base method. @@ -210,24 +475,72 @@ func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.Send } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 any) *MockEarlyConnectionOpenUniStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) + return &MockEarlyConnectionOpenUniStreamSyncCall{Call: call} +} + +// MockEarlyConnectionOpenUniStreamSyncCall wrap *gomock.Call +type MockEarlyConnectionOpenUniStreamSyncCall struct { + *gomock.Call } -// ReceiveMessage mocks base method. -func (m *MockEarlyConnection) ReceiveMessage(arg0 context.Context) ([]byte, error) { +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionOpenUniStreamSyncCall) Return(arg0 quic.SendStream, arg1 error) *MockEarlyConnectionOpenUniStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionOpenUniStreamSyncCall) Do(f func(context.Context) (quic.SendStream, error)) *MockEarlyConnectionOpenUniStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (quic.SendStream, error)) *MockEarlyConnectionOpenUniStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// ReceiveDatagram mocks base method. +func (m *MockEarlyConnection) ReceiveDatagram(arg0 context.Context) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage", arg0) + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } -// ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockEarlyConnectionMockRecorder) ReceiveDatagram(arg0 any) *MockEarlyConnectionReceiveDatagramCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveDatagram), arg0) + return &MockEarlyConnectionReceiveDatagramCall{Call: call} +} + +// MockEarlyConnectionReceiveDatagramCall wrap *gomock.Call +type MockEarlyConnectionReceiveDatagramCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionReceiveDatagramCall) Return(arg0 []byte, arg1 error) *MockEarlyConnectionReceiveDatagramCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionReceiveDatagramCall) Do(f func(context.Context) ([]byte, error)) *MockEarlyConnectionReceiveDatagramCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionReceiveDatagramCall) DoAndReturn(f func(context.Context) ([]byte, error)) *MockEarlyConnectionReceiveDatagramCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RemoteAddr mocks base method. @@ -239,21 +552,69 @@ func (m *MockEarlyConnection) RemoteAddr() net.Addr { } // RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockEarlyConnectionMockRecorder) RemoteAddr() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) RemoteAddr() *MockEarlyConnectionRemoteAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlyConnection)(nil).RemoteAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlyConnection)(nil).RemoteAddr)) + return &MockEarlyConnectionRemoteAddrCall{Call: call} +} + +// MockEarlyConnectionRemoteAddrCall wrap *gomock.Call +type MockEarlyConnectionRemoteAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionRemoteAddrCall) Return(arg0 net.Addr) *MockEarlyConnectionRemoteAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionRemoteAddrCall) Do(f func() net.Addr) *MockEarlyConnectionRemoteAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockEarlyConnectionRemoteAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } -// SendMessage mocks base method. -func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { +// SendDatagram mocks base method. +func (m *MockEarlyConnection) SendDatagram(arg0 []byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMessage", arg0) + ret := m.ctrl.Call(m, "SendDatagram", arg0) ret0, _ := ret[0].(error) return ret0 } -// SendMessage indicates an expected call of SendMessage. -func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 any) *gomock.Call { +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockEarlyConnectionMockRecorder) SendDatagram(arg0 any) *MockEarlyConnectionSendDatagramCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockEarlyConnection)(nil).SendDatagram), arg0) + return &MockEarlyConnectionSendDatagramCall{Call: call} +} + +// MockEarlyConnectionSendDatagramCall wrap *gomock.Call +type MockEarlyConnectionSendDatagramCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockEarlyConnectionSendDatagramCall) Return(arg0 error) *MockEarlyConnectionSendDatagramCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockEarlyConnectionSendDatagramCall) Do(f func([]byte) error) *MockEarlyConnectionSendDatagramCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockEarlyConnectionSendDatagramCall) DoAndReturn(f func([]byte) error) *MockEarlyConnectionSendDatagramCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/quic/stream.go b/internal/mocks/quic/stream.go index 8ee289520..5c5e7824e 100644 --- a/internal/mocks/quic/stream.go +++ b/internal/mocks/quic/stream.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mockquic -destination quic/stream.go github.com/refraction-networking/uquic Stream +// mockgen -typed -build_flags=-tags=gomock -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream // + // Package mockquic is a generated GoMock package. package mockquic @@ -48,9 +49,33 @@ func (m *MockStream) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamMockRecorder) CancelRead(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelRead(arg0 any) *MockStreamCancelReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStream)(nil).CancelRead), arg0) + return &MockStreamCancelReadCall{Call: call} +} + +// MockStreamCancelReadCall wrap *gomock.Call +type MockStreamCancelReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamCancelReadCall) Return() *MockStreamCancelReadCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamCancelReadCall) Do(f func(qerr.StreamErrorCode)) *MockStreamCancelReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamCancelReadCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockStreamCancelReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // CancelWrite mocks base method. @@ -60,9 +85,33 @@ func (m *MockStream) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamMockRecorder) CancelWrite(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) CancelWrite(arg0 any) *MockStreamCancelWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStream)(nil).CancelWrite), arg0) + return &MockStreamCancelWriteCall{Call: call} +} + +// MockStreamCancelWriteCall wrap *gomock.Call +type MockStreamCancelWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamCancelWriteCall) Return() *MockStreamCancelWriteCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamCancelWriteCall) Do(f func(qerr.StreamErrorCode)) *MockStreamCancelWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamCancelWriteCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockStreamCancelWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -74,9 +123,33 @@ func (m *MockStream) Close() error { } // Close indicates an expected call of Close. -func (mr *MockStreamMockRecorder) Close() *gomock.Call { +func (mr *MockStreamMockRecorder) Close() *MockStreamCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStream)(nil).Close)) + return &MockStreamCloseCall{Call: call} +} + +// MockStreamCloseCall wrap *gomock.Call +type MockStreamCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamCloseCall) Return(arg0 error) *MockStreamCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamCloseCall) Do(f func() error) *MockStreamCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamCloseCall) DoAndReturn(f func() error) *MockStreamCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Context mocks base method. @@ -88,9 +161,33 @@ func (m *MockStream) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockStreamMockRecorder) Context() *gomock.Call { +func (mr *MockStreamMockRecorder) Context() *MockStreamContextCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStream)(nil).Context)) + return &MockStreamContextCall{Call: call} +} + +// MockStreamContextCall wrap *gomock.Call +type MockStreamContextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamContextCall) Return(arg0 context.Context) *MockStreamContextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamContextCall) Do(f func() context.Context) *MockStreamContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamContextCall) DoAndReturn(f func() context.Context) *MockStreamContextCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Read mocks base method. @@ -103,9 +200,33 @@ func (m *MockStream) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamMockRecorder) Read(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) Read(arg0 any) *MockStreamReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) + return &MockStreamReadCall{Call: call} +} + +// MockStreamReadCall wrap *gomock.Call +type MockStreamReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamReadCall) Return(arg0 int, arg1 error) *MockStreamReadCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamReadCall) Do(f func([]byte) (int, error)) *MockStreamReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamReadCall) DoAndReturn(f func([]byte) (int, error)) *MockStreamReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetDeadline mocks base method. @@ -117,9 +238,33 @@ func (m *MockStream) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamMockRecorder) SetDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) SetDeadline(arg0 any) *MockStreamSetDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStream)(nil).SetDeadline), arg0) + return &MockStreamSetDeadlineCall{Call: call} +} + +// MockStreamSetDeadlineCall wrap *gomock.Call +type MockStreamSetDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSetDeadlineCall) Return(arg0 error) *MockStreamSetDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSetDeadlineCall) Do(f func(time.Time) error) *MockStreamSetDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSetDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamSetDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetReadDeadline mocks base method. @@ -131,9 +276,33 @@ func (m *MockStream) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) SetReadDeadline(arg0 any) *MockStreamSetReadDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStream)(nil).SetReadDeadline), arg0) + return &MockStreamSetReadDeadlineCall{Call: call} +} + +// MockStreamSetReadDeadlineCall wrap *gomock.Call +type MockStreamSetReadDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSetReadDeadlineCall) Return(arg0 error) *MockStreamSetReadDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSetReadDeadlineCall) Do(f func(time.Time) error) *MockStreamSetReadDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamSetReadDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetWriteDeadline mocks base method. @@ -145,9 +314,33 @@ func (m *MockStream) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) SetWriteDeadline(arg0 any) *MockStreamSetWriteDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStream)(nil).SetWriteDeadline), arg0) + return &MockStreamSetWriteDeadlineCall{Call: call} +} + +// MockStreamSetWriteDeadlineCall wrap *gomock.Call +type MockStreamSetWriteDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSetWriteDeadlineCall) Return(arg0 error) *MockStreamSetWriteDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSetWriteDeadlineCall) Do(f func(time.Time) error) *MockStreamSetWriteDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamSetWriteDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StreamID mocks base method. @@ -159,9 +352,33 @@ func (m *MockStream) StreamID() protocol.StreamID { } // StreamID indicates an expected call of StreamID. -func (mr *MockStreamMockRecorder) StreamID() *gomock.Call { +func (mr *MockStreamMockRecorder) StreamID() *MockStreamStreamIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStream)(nil).StreamID)) + return &MockStreamStreamIDCall{Call: call} +} + +// MockStreamStreamIDCall wrap *gomock.Call +type MockStreamStreamIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamStreamIDCall) Return(arg0 protocol.StreamID) *MockStreamStreamIDCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamStreamIDCall) Do(f func() protocol.StreamID) *MockStreamStreamIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamStreamIDCall) DoAndReturn(f func() protocol.StreamID) *MockStreamStreamIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Write mocks base method. @@ -174,7 +391,31 @@ func (m *MockStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamMockRecorder) Write(arg0 any) *gomock.Call { +func (mr *MockStreamMockRecorder) Write(arg0 any) *MockStreamWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStream)(nil).Write), arg0) + return &MockStreamWriteCall{Call: call} +} + +// MockStreamWriteCall wrap *gomock.Call +type MockStreamWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamWriteCall) Return(arg0 int, arg1 error) *MockStreamWriteCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamWriteCall) Do(f func([]byte) (int, error)) *MockStreamWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockStreamWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/short_header_opener.go b/internal/mocks/short_header_opener.go index 08aaf8f92..092ec618d 100644 --- a/internal/mocks/short_header_opener.go +++ b/internal/mocks/short_header_opener.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_opener.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderOpener +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener // + // Package mocks is a generated GoMock package. package mocks @@ -48,9 +49,33 @@ func (m *MockShortHeaderOpener) DecodePacketNumber(arg0 protocol.PacketNumber, a } // DecodePacketNumber indicates an expected call of DecodePacketNumber. -func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecodePacketNumber(arg0, arg1 any) *MockShortHeaderOpenerDecodePacketNumberCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodePacketNumber", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecodePacketNumber), arg0, arg1) + return &MockShortHeaderOpenerDecodePacketNumberCall{Call: call} +} + +// MockShortHeaderOpenerDecodePacketNumberCall wrap *gomock.Call +type MockShortHeaderOpenerDecodePacketNumberCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderOpenerDecodePacketNumberCall) Return(arg0 protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderOpenerDecodePacketNumberCall) Do(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderOpenerDecodePacketNumberCall) DoAndReturn(f func(protocol.PacketNumber, protocol.PacketNumberLen) protocol.PacketNumber) *MockShortHeaderOpenerDecodePacketNumberCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DecryptHeader mocks base method. @@ -60,9 +85,33 @@ func (m *MockShortHeaderOpener) DecryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // DecryptHeader indicates an expected call of DecryptHeader. -func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) DecryptHeader(arg0, arg1, arg2 any) *MockShortHeaderOpenerDecryptHeaderCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptHeader", reflect.TypeOf((*MockShortHeaderOpener)(nil).DecryptHeader), arg0, arg1, arg2) + return &MockShortHeaderOpenerDecryptHeaderCall{Call: call} +} + +// MockShortHeaderOpenerDecryptHeaderCall wrap *gomock.Call +type MockShortHeaderOpenerDecryptHeaderCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderOpenerDecryptHeaderCall) Return() *MockShortHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderOpenerDecryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockShortHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderOpenerDecryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) *MockShortHeaderOpenerDecryptHeaderCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Open mocks base method. @@ -75,7 +124,31 @@ func (m *MockShortHeaderOpener) Open(arg0, arg1 []byte, arg2 time.Time, arg3 pro } // Open indicates an expected call of Open. -func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { +func (mr *MockShortHeaderOpenerMockRecorder) Open(arg0, arg1, arg2, arg3, arg4, arg5 any) *MockShortHeaderOpenerOpenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockShortHeaderOpener)(nil).Open), arg0, arg1, arg2, arg3, arg4, arg5) + return &MockShortHeaderOpenerOpenCall{Call: call} +} + +// MockShortHeaderOpenerOpenCall wrap *gomock.Call +type MockShortHeaderOpenerOpenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderOpenerOpenCall) Return(arg0 []byte, arg1 error) *MockShortHeaderOpenerOpenCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderOpenerOpenCall) Do(f func([]byte, []byte, time.Time, protocol.PacketNumber, protocol.KeyPhaseBit, []byte) ([]byte, error)) *MockShortHeaderOpenerOpenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderOpenerOpenCall) DoAndReturn(f func([]byte, []byte, time.Time, protocol.PacketNumber, protocol.KeyPhaseBit, []byte) ([]byte, error)) *MockShortHeaderOpenerOpenCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go index 768543d60..ad49501b7 100644 --- a/internal/mocks/short_header_sealer.go +++ b/internal/mocks/short_header_sealer.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination short_header_sealer.go github.com/refraction-networking/uquic/internal/handshake ShortHeaderSealer +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer // + // Package mocks is a generated GoMock package. package mocks @@ -45,9 +46,33 @@ func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []by } // EncryptHeader indicates an expected call of EncryptHeader. -func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 any) *MockShortHeaderSealerEncryptHeaderCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) + return &MockShortHeaderSealerEncryptHeaderCall{Call: call} +} + +// MockShortHeaderSealerEncryptHeaderCall wrap *gomock.Call +type MockShortHeaderSealerEncryptHeaderCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderSealerEncryptHeaderCall) Return() *MockShortHeaderSealerEncryptHeaderCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderSealerEncryptHeaderCall) Do(f func([]byte, *byte, []byte)) *MockShortHeaderSealerEncryptHeaderCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderSealerEncryptHeaderCall) DoAndReturn(f func([]byte, *byte, []byte)) *MockShortHeaderSealerEncryptHeaderCall { + c.Call = c.Call.DoAndReturn(f) + return c } // KeyPhase mocks base method. @@ -59,9 +84,33 @@ func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhaseBit { } // KeyPhase indicates an expected call of KeyPhase. -func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *MockShortHeaderSealerKeyPhaseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) + return &MockShortHeaderSealerKeyPhaseCall{Call: call} +} + +// MockShortHeaderSealerKeyPhaseCall wrap *gomock.Call +type MockShortHeaderSealerKeyPhaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderSealerKeyPhaseCall) Return(arg0 protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderSealerKeyPhaseCall) Do(f func() protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderSealerKeyPhaseCall) DoAndReturn(f func() protocol.KeyPhaseBit) *MockShortHeaderSealerKeyPhaseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Overhead mocks base method. @@ -73,9 +122,33 @@ func (m *MockShortHeaderSealer) Overhead() int { } // Overhead indicates an expected call of Overhead. -func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) Overhead() *MockShortHeaderSealerOverheadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) + return &MockShortHeaderSealerOverheadCall{Call: call} +} + +// MockShortHeaderSealerOverheadCall wrap *gomock.Call +type MockShortHeaderSealerOverheadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderSealerOverheadCall) Return(arg0 int) *MockShortHeaderSealerOverheadCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderSealerOverheadCall) Do(f func() int) *MockShortHeaderSealerOverheadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderSealerOverheadCall) DoAndReturn(f func() int) *MockShortHeaderSealerOverheadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Seal mocks base method. @@ -87,7 +160,31 @@ func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumb } // Seal indicates an expected call of Seal. -func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 any) *MockShortHeaderSealerSealCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) + return &MockShortHeaderSealerSealCall{Call: call} +} + +// MockShortHeaderSealerSealCall wrap *gomock.Call +type MockShortHeaderSealerSealCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockShortHeaderSealerSealCall) Return(arg0 []byte) *MockShortHeaderSealerSealCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockShortHeaderSealerSealCall) Do(f func([]byte, []byte, protocol.PacketNumber, []byte) []byte) *MockShortHeaderSealerSealCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockShortHeaderSealerSealCall) DoAndReturn(f func([]byte, []byte, protocol.PacketNumber, []byte) []byte) *MockShortHeaderSealerSealCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/stream_flow_controller.go b/internal/mocks/stream_flow_controller.go index 3bacec394..d5c337d55 100644 --- a/internal/mocks/stream_flow_controller.go +++ b/internal/mocks/stream_flow_controller.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocks -destination stream_flow_controller.go github.com/refraction-networking/uquic/internal/flowcontrol StreamFlowController +// mockgen -typed -build_flags=-tags=gomock -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController // + // Package mocks is a generated GoMock package. package mocks @@ -45,9 +46,33 @@ func (m *MockStreamFlowController) Abandon() { } // Abandon indicates an expected call of Abandon. -func (mr *MockStreamFlowControllerMockRecorder) Abandon() *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) Abandon() *MockStreamFlowControllerAbandonCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Abandon", reflect.TypeOf((*MockStreamFlowController)(nil).Abandon)) + return &MockStreamFlowControllerAbandonCall{Call: call} +} + +// MockStreamFlowControllerAbandonCall wrap *gomock.Call +type MockStreamFlowControllerAbandonCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerAbandonCall) Return() *MockStreamFlowControllerAbandonCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerAbandonCall) Do(f func()) *MockStreamFlowControllerAbandonCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerAbandonCall) DoAndReturn(f func()) *MockStreamFlowControllerAbandonCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddBytesRead mocks base method. @@ -57,9 +82,33 @@ func (m *MockStreamFlowController) AddBytesRead(arg0 protocol.ByteCount) { } // AddBytesRead indicates an expected call of AddBytesRead. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 any) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesRead(arg0 any) *MockStreamFlowControllerAddBytesReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesRead", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesRead), arg0) + return &MockStreamFlowControllerAddBytesReadCall{Call: call} +} + +// MockStreamFlowControllerAddBytesReadCall wrap *gomock.Call +type MockStreamFlowControllerAddBytesReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerAddBytesReadCall) Return() *MockStreamFlowControllerAddBytesReadCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerAddBytesReadCall) Do(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerAddBytesReadCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddBytesSent mocks base method. @@ -69,9 +118,33 @@ func (m *MockStreamFlowController) AddBytesSent(arg0 protocol.ByteCount) { } // AddBytesSent indicates an expected call of AddBytesSent. -func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 any) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) AddBytesSent(arg0 any) *MockStreamFlowControllerAddBytesSentCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBytesSent", reflect.TypeOf((*MockStreamFlowController)(nil).AddBytesSent), arg0) + return &MockStreamFlowControllerAddBytesSentCall{Call: call} +} + +// MockStreamFlowControllerAddBytesSentCall wrap *gomock.Call +type MockStreamFlowControllerAddBytesSentCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerAddBytesSentCall) Return() *MockStreamFlowControllerAddBytesSentCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerAddBytesSentCall) Do(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesSentCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerAddBytesSentCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamFlowControllerAddBytesSentCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetWindowUpdate mocks base method. @@ -83,9 +156,33 @@ func (m *MockStreamFlowController) GetWindowUpdate() protocol.ByteCount { } // GetWindowUpdate indicates an expected call of GetWindowUpdate. -func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) GetWindowUpdate() *MockStreamFlowControllerGetWindowUpdateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWindowUpdate", reflect.TypeOf((*MockStreamFlowController)(nil).GetWindowUpdate)) + return &MockStreamFlowControllerGetWindowUpdateCall{Call: call} +} + +// MockStreamFlowControllerGetWindowUpdateCall wrap *gomock.Call +type MockStreamFlowControllerGetWindowUpdateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerGetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerGetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerGetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockStreamFlowControllerGetWindowUpdateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // IsNewlyBlocked mocks base method. @@ -98,9 +195,33 @@ func (m *MockStreamFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) { } // IsNewlyBlocked indicates an expected call of IsNewlyBlocked. -func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) IsNewlyBlocked() *MockStreamFlowControllerIsNewlyBlockedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNewlyBlocked", reflect.TypeOf((*MockStreamFlowController)(nil).IsNewlyBlocked)) + return &MockStreamFlowControllerIsNewlyBlockedCall{Call: call} +} + +// MockStreamFlowControllerIsNewlyBlockedCall wrap *gomock.Call +type MockStreamFlowControllerIsNewlyBlockedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerIsNewlyBlockedCall) Return(arg0 bool, arg1 protocol.ByteCount) *MockStreamFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerIsNewlyBlockedCall) Do(f func() (bool, protocol.ByteCount)) *MockStreamFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerIsNewlyBlockedCall) DoAndReturn(f func() (bool, protocol.ByteCount)) *MockStreamFlowControllerIsNewlyBlockedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SendWindowSize mocks base method. @@ -112,9 +233,33 @@ func (m *MockStreamFlowController) SendWindowSize() protocol.ByteCount { } // SendWindowSize indicates an expected call of SendWindowSize. -func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) SendWindowSize() *MockStreamFlowControllerSendWindowSizeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWindowSize", reflect.TypeOf((*MockStreamFlowController)(nil).SendWindowSize)) + return &MockStreamFlowControllerSendWindowSizeCall{Call: call} +} + +// MockStreamFlowControllerSendWindowSizeCall wrap *gomock.Call +type MockStreamFlowControllerSendWindowSizeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerSendWindowSizeCall) Return(arg0 protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerSendWindowSizeCall) Do(f func() protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerSendWindowSizeCall) DoAndReturn(f func() protocol.ByteCount) *MockStreamFlowControllerSendWindowSizeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdateHighestReceived mocks base method. @@ -126,19 +271,69 @@ func (m *MockStreamFlowController) UpdateHighestReceived(arg0 protocol.ByteCount } // UpdateHighestReceived indicates an expected call of UpdateHighestReceived. -func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 any) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateHighestReceived(arg0, arg1 any) *MockStreamFlowControllerUpdateHighestReceivedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHighestReceived", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateHighestReceived), arg0, arg1) + return &MockStreamFlowControllerUpdateHighestReceivedCall{Call: call} +} + +// MockStreamFlowControllerUpdateHighestReceivedCall wrap *gomock.Call +type MockStreamFlowControllerUpdateHighestReceivedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerUpdateHighestReceivedCall) Return(arg0 error) *MockStreamFlowControllerUpdateHighestReceivedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerUpdateHighestReceivedCall) Do(f func(protocol.ByteCount, bool) error) *MockStreamFlowControllerUpdateHighestReceivedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerUpdateHighestReceivedCall) DoAndReturn(f func(protocol.ByteCount, bool) error) *MockStreamFlowControllerUpdateHighestReceivedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdateSendWindow mocks base method. -func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) { +func (m *MockStreamFlowController) UpdateSendWindow(arg0 protocol.ByteCount) bool { m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSendWindow", arg0) + ret := m.ctrl.Call(m, "UpdateSendWindow", arg0) + ret0, _ := ret[0].(bool) + return ret0 } // UpdateSendWindow indicates an expected call of UpdateSendWindow. -func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *gomock.Call { +func (mr *MockStreamFlowControllerMockRecorder) UpdateSendWindow(arg0 any) *MockStreamFlowControllerUpdateSendWindowCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSendWindow", reflect.TypeOf((*MockStreamFlowController)(nil).UpdateSendWindow), arg0) + return &MockStreamFlowControllerUpdateSendWindowCall{Call: call} +} + +// MockStreamFlowControllerUpdateSendWindowCall wrap *gomock.Call +type MockStreamFlowControllerUpdateSendWindowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamFlowControllerUpdateSendWindowCall) Return(arg0 bool) *MockStreamFlowControllerUpdateSendWindowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamFlowControllerUpdateSendWindowCall) Do(f func(protocol.ByteCount) bool) *MockStreamFlowControllerUpdateSendWindowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamFlowControllerUpdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount) bool) *MockStreamFlowControllerUpdateSendWindowCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/mocks/tls/client_session_cache.go b/internal/mocks/tls/client_session_cache.go index 30483c14f..fe657a033 100644 --- a/internal/mocks/tls/client_session_cache.go +++ b/internal/mocks/tls/client_session_cache.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache +// mockgen -typed -build_flags=-tags=gomock -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache // + // Package mocktls is a generated GoMock package. package mocktls @@ -48,9 +49,33 @@ func (m *MockClientSessionCache) Get(arg0 string) (*tls.ClientSessionState, bool } // Get indicates an expected call of Get. -func (mr *MockClientSessionCacheMockRecorder) Get(arg0 any) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Get(arg0 any) *MockClientSessionCacheGetCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockClientSessionCache)(nil).Get), arg0) + return &MockClientSessionCacheGetCall{Call: call} +} + +// MockClientSessionCacheGetCall wrap *gomock.Call +type MockClientSessionCacheGetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientSessionCacheGetCall) Return(arg0 *tls.ClientSessionState, arg1 bool) *MockClientSessionCacheGetCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientSessionCacheGetCall) Do(f func(string) (*tls.ClientSessionState, bool)) *MockClientSessionCacheGetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientSessionCacheGetCall) DoAndReturn(f func(string) (*tls.ClientSessionState, bool)) *MockClientSessionCacheGetCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Put mocks base method. @@ -60,7 +85,31 @@ func (m *MockClientSessionCache) Put(arg0 string, arg1 *tls.ClientSessionState) } // Put indicates an expected call of Put. -func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 any) *gomock.Call { +func (mr *MockClientSessionCacheMockRecorder) Put(arg0, arg1 any) *MockClientSessionCachePutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockClientSessionCache)(nil).Put), arg0, arg1) + return &MockClientSessionCachePutCall{Call: call} +} + +// MockClientSessionCachePutCall wrap *gomock.Call +type MockClientSessionCachePutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientSessionCachePutCall) Return() *MockClientSessionCachePutCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientSessionCachePutCall) Do(f func(string, *tls.ClientSessionState)) *MockClientSessionCachePutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientSessionCachePutCall) DoAndReturn(f func(string, *tls.ClientSessionState)) *MockClientSessionCachePutCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 3ca68bf83..487cbc06b 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -129,13 +129,6 @@ const MaxPostHandshakeCryptoFrameSize = 1000 // but must ensure that a maximum size ACK frame fits into one packet. const MaxAckFrameSize ByteCount = 1000 -// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). -// The size is chosen such that a DATAGRAM frame fits into a QUIC packet. -const MaxDatagramFrameSize ByteCount = 1200 - -// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) -const DatagramRcvQueueLen = 128 - // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. // It also serves as a limit for the packet history. // If at any point we keep track of more ranges, old ranges are discarded. diff --git a/internal/protocol/perspective.go b/internal/protocol/perspective.go index 43358fecb..5a29d3ce2 100644 --- a/internal/protocol/perspective.go +++ b/internal/protocol/perspective.go @@ -17,9 +17,9 @@ func (p Perspective) Opposite() Perspective { func (p Perspective) String() string { switch p { case PerspectiveServer: - return "Server" + return "server" case PerspectiveClient: - return "Client" + return "client" default: return "invalid perspective" } diff --git a/internal/protocol/perspective_test.go b/internal/protocol/perspective_test.go index 11e29cc96..83a904e03 100644 --- a/internal/protocol/perspective_test.go +++ b/internal/protocol/perspective_test.go @@ -7,8 +7,8 @@ import ( var _ = Describe("Perspective", func() { It("has a string representation", func() { - Expect(PerspectiveClient.String()).To(Equal("Client")) - Expect(PerspectiveServer.String()).To(Equal("Server")) + Expect(PerspectiveClient.String()).To(Equal("client")) + Expect(PerspectiveServer.String()).To(Equal("server")) Expect(Perspective(0).String()).To(Equal("invalid perspective")) }) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index 5c2decbdc..025ade9b4 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -1,14 +1,17 @@ package protocol import ( - "crypto/rand" "encoding/binary" "fmt" "math" + "sync" + "time" + + "golang.org/x/exp/rand" ) -// VersionNumber is a version number as int -type VersionNumber uint32 +// Version is a version number as int +type Version uint32 // gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions const ( @@ -18,22 +21,22 @@ const ( // The version numbers, making grepping easier const ( - VersionUnknown VersionNumber = math.MaxUint32 - versionDraft29 VersionNumber = 0xff00001d // draft-29 used to be a widely deployed version - Version1 VersionNumber = 0x1 - Version2 VersionNumber = 0x6b3343cf + VersionUnknown Version = math.MaxUint32 + versionDraft29 Version = 0xff00001d // draft-29 used to be a widely deployed version + Version1 Version = 0x1 + Version2 Version = 0x6b3343cf ) // SupportedVersions lists the versions that the server supports // must be in sorted descending order -var SupportedVersions = []VersionNumber{Version1, Version2} +var SupportedVersions = []Version{Version1, Version2} // IsValidVersion says if the version is known to quic-go -func IsValidVersion(v VersionNumber) bool { +func IsValidVersion(v Version) bool { return v == Version1 || IsSupportedVersion(SupportedVersions, v) } -func (vn VersionNumber) String() string { +func (vn Version) String() string { //nolint:exhaustive switch vn { case VersionUnknown: @@ -52,16 +55,16 @@ func (vn VersionNumber) String() string { } } -func (vn VersionNumber) isGQUIC() bool { +func (vn Version) isGQUIC() bool { return vn > gquicVersion0 && vn <= maxGquicVersion } -func (vn VersionNumber) toGQUICVersion() int { +func (vn Version) toGQUICVersion() int { return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10) } // IsSupportedVersion returns true if the server supports this version -func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { +func IsSupportedVersion(supported []Version, v Version) bool { for _, t := range supported { if t == v { return true @@ -74,7 +77,7 @@ func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool { // ours is a slice of versions that we support, sorted by our preference (descending) // theirs is a slice of versions offered by the peer. The order does not matter. // The bool returned indicates if a matching version was found. -func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) { +func ChooseSupportedVersion(ours, theirs []Version) (Version, bool) { for _, ourVer := range ours { for _, theirVer := range theirs { if ourVer == theirVer { @@ -85,19 +88,25 @@ func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) return 0, false } -// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a) -func generateReservedVersion() VersionNumber { - b := make([]byte, 4) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa) +var ( + versionNegotiationMx sync.Mutex + versionNegotiationRand = rand.New(rand.NewSource(uint64(time.Now().UnixNano()))) +) + +// generateReservedVersion generates a reserved version (v & 0x0f0f0f0f == 0x0a0a0a0a) +func generateReservedVersion() Version { + var b [4]byte + _, _ = versionNegotiationRand.Read(b[:]) // ignore the error here. Failure to read random data doesn't break anything + return Version((binary.BigEndian.Uint32(b[:]) | 0x0a0a0a0a) & 0xfafafafa) } -// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position -func GetGreasedVersions(supported []VersionNumber) []VersionNumber { - b := make([]byte, 1) - _, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything - randPos := int(b[0]) % (len(supported) + 1) - greased := make([]VersionNumber, len(supported)+1) +// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position. +// It doesn't modify the supported slice. +func GetGreasedVersions(supported []Version) []Version { + versionNegotiationMx.Lock() + defer versionNegotiationMx.Unlock() + randPos := rand.Intn(len(supported) + 1) + greased := make([]Version, len(supported)+1) copy(greased, supported[:randPos]) greased[randPos] = generateReservedVersion() copy(greased[randPos+1:], supported[randPos:]) diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index f1bceefcb..2f8c0b2df 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -6,7 +6,7 @@ import ( ) var _ = Describe("Version", func() { - isReservedVersion := func(v VersionNumber) bool { + isReservedVersion := func(v Version) bool { return v&0x0f0f0f0f == 0x0a0a0a0a } @@ -24,11 +24,11 @@ var _ = Describe("Version", func() { Expect(Version1.String()).To(Equal("v1")) Expect(Version2.String()).To(Equal("v2")) // check with unsupported version numbers from the wiki - Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) - Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) - Expect(VersionNumber(0x51303235).String()).To(Equal("gQUIC 25")) - Expect(VersionNumber(0x51303438).String()).To(Equal("gQUIC 48")) - Expect(VersionNumber(0x01234567).String()).To(Equal("0x1234567")) + Expect(Version(0x51303039).String()).To(Equal("gQUIC 9")) + Expect(Version(0x51303133).String()).To(Equal("gQUIC 13")) + Expect(Version(0x51303235).String()).To(Equal("gQUIC 25")) + Expect(Version(0x51303438).String()).To(Equal("gQUIC 48")) + Expect(Version(0x01234567).String()).To(Equal("0x1234567")) }) It("recognizes supported versions", func() { @@ -39,45 +39,45 @@ var _ = Describe("Version", func() { Context("highest supported version", func() { It("finds the supported version", func() { - supportedVersions := []VersionNumber{1, 2, 3} - other := []VersionNumber{6, 5, 4, 3} + supportedVersions := []Version{1, 2, 3} + other := []Version{6, 5, 4, 3} ver, ok := ChooseSupportedVersion(supportedVersions, other) Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(3))) + Expect(ver).To(Equal(Version(3))) }) It("picks the preferred version", func() { - supportedVersions := []VersionNumber{2, 1, 3} - other := []VersionNumber{3, 6, 1, 8, 2, 10} + supportedVersions := []Version{2, 1, 3} + other := []Version{3, 6, 1, 8, 2, 10} ver, ok := ChooseSupportedVersion(supportedVersions, other) Expect(ok).To(BeTrue()) - Expect(ver).To(Equal(VersionNumber(2))) + Expect(ver).To(Equal(Version(2))) }) It("says when no matching version was found", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{1}, []VersionNumber{2}) + _, ok := ChooseSupportedVersion([]Version{1}, []Version{2}) Expect(ok).To(BeFalse()) }) It("handles empty inputs", func() { - _, ok := ChooseSupportedVersion([]VersionNumber{102, 101}, []VersionNumber{}) + _, ok := ChooseSupportedVersion([]Version{102, 101}, []Version{}) Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{1, 2}) + _, ok = ChooseSupportedVersion([]Version{}, []Version{1, 2}) Expect(ok).To(BeFalse()) - _, ok = ChooseSupportedVersion([]VersionNumber{}, []VersionNumber{}) + _, ok = ChooseSupportedVersion([]Version{}, []Version{}) Expect(ok).To(BeFalse()) }) }) Context("reserved versions", func() { It("adds a greased version if passed an empty slice", func() { - greased := GetGreasedVersions([]VersionNumber{}) + greased := GetGreasedVersions([]Version{}) Expect(greased).To(HaveLen(1)) Expect(isReservedVersion(greased[0])).To(BeTrue()) }) It("creates greased lists of version numbers", func() { - supported := []VersionNumber{10, 18, 29} + supported := []Version{10, 18, 29} for _, v := range supported { Expect(isReservedVersion(v)).To(BeFalse()) } diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index b8e84fc03..00361308e 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -1,9 +1,8 @@ package qerr import ( + "crypto/tls" "fmt" - - "github.com/refraction-networking/uquic/internal/qtls" ) // TransportErrorCode is a QUIC transport error. @@ -40,7 +39,7 @@ func (e TransportErrorCode) Message() string { if !e.IsCryptoError() { return "" } - return qtls.AlertError(e - 0x100).Error() + return tls.AlertError(e - 0x100).Error() } func (e TransportErrorCode) String() string { diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go index c3d8465b9..e88350590 100644 --- a/internal/qerr/errors.go +++ b/internal/qerr/errors.go @@ -101,8 +101,8 @@ func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.Err // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. type VersionNegotiationError struct { - Ours []protocol.VersionNumber - Theirs []protocol.VersionNumber + Ours []protocol.Version + Theirs []protocol.Version } func (e *VersionNegotiationError) Error() string { diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index 7bfbc7345..086d8ee8f 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -108,8 +108,8 @@ var _ = Describe("QUIC Errors", func() { Context("Version Negotiation errors", func() { It("has a string representation", func() { Expect((&VersionNegotiationError{ - Ours: []protocol.VersionNumber{2, 3}, - Theirs: []protocol.VersionNumber{4, 5, 6}, + Ours: []protocol.Version{2, 3}, + Theirs: []protocol.Version{4, 5, 6}, }).Error()).To(Equal("no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])")) }) }) diff --git a/internal/qtls/cipher_suite.go b/internal/qtls/cipher_suite.go index 07891c141..e13cdcfd9 100644 --- a/internal/qtls/cipher_suite.go +++ b/internal/qtls/cipher_suite.go @@ -1,24 +1,12 @@ package qtls import ( - "crypto" - "crypto/cipher" "fmt" "unsafe" tls "github.com/refraction-networking/utls" ) -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - //go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 var cipherSuitesTLS13 []unsafe.Pointer diff --git a/internal/qtls/cipher_suite_test.go b/internal/qtls/cipher_suite_test.go new file mode 100644 index 000000000..c76763f59 --- /dev/null +++ b/internal/qtls/cipher_suite_test.go @@ -0,0 +1,50 @@ +package qtls + +import ( + "fmt" + "net" + + "github.com/refraction-networking/uquic/internal/testdata" + tls "github.com/refraction-networking/utls" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Setting the Cipher Suite", func() { + for _, cs := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256, tls.TLS_AES_256_GCM_SHA384} { + cs := cs + + It(fmt.Sprintf("selects %s", tls.CipherSuiteName(cs)), func() { + reset := SetCipherSuite(cs) + defer reset() + + ln, err := tls.Listen("tcp4", "localhost:0", testdata.GetTLSConfig()) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + _, err = conn.Read(make([]byte, 10)) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.(*tls.Conn).ConnectionState().CipherSuite).To(Equal(cs)) + }() + + conn, err := tls.Dial( + "tcp4", + fmt.Sprintf("localhost:%d", ln.Addr().(*net.TCPAddr).Port), + &tls.Config{RootCAs: testdata.GetRootCA()}, + ) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().CipherSuite).To(Equal(cs)) + Expect(conn.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + } +}) diff --git a/internal/qtls/client_session_cache.go b/internal/qtls/client_session_cache.go index ca8c885ce..52889f82a 100644 --- a/internal/qtls/client_session_cache.go +++ b/internal/qtls/client_session_cache.go @@ -1,18 +1,24 @@ package qtls import ( + "sync" + tls "github.com/refraction-networking/utls" ) type clientSessionCache struct { - getData func() []byte - setData func([]byte) + mx sync.Mutex + getData func(earlyData bool) []byte + setData func(data []byte, earlyData bool) (allowEarlyData bool) wrapped tls.ClientSessionCache } var _ tls.ClientSessionCache = &clientSessionCache{} -func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { +func (c *clientSessionCache) Put(key string, cs *tls.ClientSessionState) { + c.mx.Lock() + defer c.mx.Unlock() + if cs == nil { c.wrapped.Put(key, nil) return @@ -22,7 +28,7 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { c.wrapped.Put(key, cs) return } - state.Extra = append(state.Extra, addExtraPrefix(c.getData())) + state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData))) newCS, err := tls.NewResumptionState(ticket, state) if err != nil { // It's not clear why this would error. Just save the original state. @@ -32,7 +38,10 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { c.wrapped.Put(key, newCS) } -func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { +func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { + c.mx.Lock() + defer c.mx.Unlock() + cs, ok := c.wrapped.Get(key) if !ok || cs == nil { return cs, ok @@ -46,7 +55,10 @@ func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { } // restore QUIC transport parameters and RTT stored in state.Extra if extra := findExtraData(state.Extra); extra != nil { - c.setData(extra) + earlyData := c.setData(extra, state.EarlyData) + if state.EarlyData { + state.EarlyData = earlyData + } } session, err := tls.NewResumptionState(ticket, state) if err != nil { diff --git a/internal/qtls/client_session_cache_test.go b/internal/qtls/client_session_cache_test.go index 7bacebe29..37f123b65 100644 --- a/internal/qtls/client_session_cache_test.go +++ b/internal/qtls/client_session_cache_test.go @@ -39,8 +39,12 @@ var _ = Describe("Client Session Cache", func() { RootCAs: testdata.GetRootCA(), ClientSessionCache: &clientSessionCache{ wrapped: tls.NewLRUClientSessionCache(10), - getData: func() []byte { return []byte("session") }, - setData: func(data []byte) { restored <- data }, + getData: func(bool) []byte { return []byte("session") }, + setData: func(data []byte, earlyData bool) bool { + Expect(earlyData).To(BeFalse()) // running on top of TCP, we can only test non-0-RTT here + restored <- data + return true + }, }, } conn, err := tls.Dial( diff --git a/internal/qtls/go_oldversion.go b/internal/qtls/go_oldversion.go deleted file mode 100644 index 2903587e9..000000000 --- a/internal/qtls/go_oldversion.go +++ /dev/null @@ -1,5 +0,0 @@ -//go:build !go1.20 - -package qtls - -var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/refraction-networking/uquic/wiki/quic-go-and-Go-versions." diff --git a/internal/qtls/utls.go b/internal/qtls/qtls.go similarity index 61% rename from internal/qtls/utls.go rename to internal/qtls/qtls.go index 45f19e11e..0425b8744 100644 --- a/internal/qtls/utls.go +++ b/internal/qtls/qtls.go @@ -9,53 +9,7 @@ import ( "github.com/refraction-networking/uquic/internal/protocol" ) -type ( - QUICConn = tls.QUICConn - UQUICConn = tls.UQUICConn // [UQUIC] - QUICConfig = tls.QUICConfig - QUICEvent = tls.QUICEvent - QUICEventKind = tls.QUICEventKind - QUICEncryptionLevel = tls.QUICEncryptionLevel - QUICSessionTicketOptions = tls.QUICSessionTicketOptions - AlertError = tls.AlertError -) - -const ( - QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial - QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly - QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake - QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication -) - -const ( - QUICNoEvent = tls.QUICNoEvent - QUICSetReadSecret = tls.QUICSetReadSecret - QUICSetWriteSecret = tls.QUICSetWriteSecret - QUICWriteData = tls.QUICWriteData - QUICTransportParameters = tls.QUICTransportParameters - QUICTransportParametersRequired = tls.QUICTransportParametersRequired - QUICRejectedEarlyData = tls.QUICRejectedEarlyData - QUICHandshakeDone = tls.QUICHandshakeDone -) - -func QUICServer(config *QUICConfig) *QUICConn { - return tls.QUICServer(config) -} - -func QUICClient(config *QUICConfig) *QUICConn { - return tls.QUICClient(config) -} - -// [UQUIC] -func UQUICClient(config *QUICConfig, clientHelloSpec *tls.ClientHelloSpec) *UQUICConn { - uqc := tls.UQUICClient(config, tls.HelloCustom) - if err := uqc.ApplyPreset(clientHelloSpec); err != nil { - panic(err) - } - return uqc -} - -func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { +func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) { conf := qconf.TLSConfig // Workaround for https://github.com/golang/go/issues/60506. @@ -107,7 +61,11 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, hand } } -func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) { +func SetupConfigForClient( + qconf *tls.QUICConfig, + getData func(earlyData bool) []byte, + setData func(data []byte, earlyData bool) (allowEarlyData bool), +) { conf := qconf.TLSConfig if conf.ClientSessionCache != nil { origCache := conf.ClientSessionCache @@ -165,14 +123,3 @@ func findExtraData(extras [][]byte) []byte { } return nil } - -type QUICConnOrUQUICConn interface { - *QUICConn | *UQUICConn - SendSessionTicket(opts QUICSessionTicketOptions) error -} - -func SendSessionTicket[C QUICConnOrUQUICConn](c C, allow0RTT bool) error { - return c.SendSessionTicket(tls.QUICSessionTicketOptions{ - EarlyData: allow0RTT, - }) -} diff --git a/internal/qtls/utls_test.go b/internal/qtls/qtls_test.go similarity index 90% rename from internal/qtls/utls_test.go rename to internal/qtls/qtls_test.go index 31dda063d..46994562d 100644 --- a/internal/qtls/utls_test.go +++ b/internal/qtls/qtls_test.go @@ -9,7 +9,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Go 1.21", func() { +var _ = Describe("interface go crypto/tls", func() { It("converts to tls.EncryptionLevel", func() { Expect(ToTLSEncryptionLevel(protocol.EncryptionInitial)).To(Equal(tls.QUICEncryptionLevelInitial)) Expect(ToTLSEncryptionLevel(protocol.EncryptionHandshake)).To(Equal(tls.QUICEncryptionLevelHandshake)) @@ -27,14 +27,14 @@ var _ = Describe("Go 1.21", func() { Context("setting up a tls.Config for the client", func() { It("sets up a session cache if there's one present on the config", func() { csc := tls.NewLRUClientSessionCache(1) - conf := &QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}} + conf := &tls.QUICConfig{TLSConfig: &tls.Config{ClientSessionCache: csc}} SetupConfigForClient(conf, nil, nil) Expect(conf.TLSConfig.ClientSessionCache).ToNot(BeNil()) Expect(conf.TLSConfig.ClientSessionCache).ToNot(Equal(csc)) }) It("doesn't set up a session cache if there's none present on the config", func() { - conf := &QUICConfig{TLSConfig: &tls.Config{}} + conf := &tls.QUICConfig{TLSConfig: &tls.Config{}} SetupConfigForClient(conf, nil, nil) Expect(conf.TLSConfig.ClientSessionCache).To(BeNil()) }) @@ -43,7 +43,7 @@ var _ = Describe("Go 1.21", func() { Context("setting up a tls.Config for the server", func() { It("sets the minimum TLS version to TLS 1.3", func() { orig := &tls.Config{MinVersion: tls.VersionTLS12} - conf := &QUICConfig{TLSConfig: orig} + conf := &tls.QUICConfig{TLSConfig: orig} SetupConfigForServer(conf, false, nil, nil) Expect(conf.TLSConfig.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) // check that the original config wasn't modified diff --git a/internal/utils/minmax.go b/internal/utils/minmax.go index d191f7515..03a9c9a87 100644 --- a/internal/utils/minmax.go +++ b/internal/utils/minmax.go @@ -3,27 +3,11 @@ package utils import ( "math" "time" - - "golang.org/x/exp/constraints" ) // InfDuration is a duration of infinite length const InfDuration = time.Duration(math.MaxInt64) -func Max[T constraints.Ordered](a, b T) T { - if a < b { - return b - } - return a -} - -func Min[T constraints.Ordered](a, b T) T { - if a < b { - return a - } - return b -} - // MinNonZeroDuration return the minimum duration that's not zero. func MinNonZeroDuration(a, b time.Duration) time.Duration { if a == 0 { @@ -32,15 +16,7 @@ func MinNonZeroDuration(a, b time.Duration) time.Duration { if b == 0 { return a } - return Min(a, b) -} - -// AbsDuration returns the absolute value of a time duration -func AbsDuration(d time.Duration) time.Duration { - if d >= 0 { - return d - } - return -d + return min(a, b) } // MinTime returns the earlier time @@ -51,18 +27,6 @@ func MinTime(a, b time.Time) time.Time { return a } -// MinNonZeroTime returns the earlist time that is not time.Time{} -// If both a and b are time.Time{}, it returns time.Time{} -func MinNonZeroTime(a, b time.Time) time.Time { - if a.IsZero() { - return b - } - if b.IsZero() { - return a - } - return MinTime(a, b) -} - // MaxTime returns the later time func MaxTime(a, b time.Time) time.Time { if a.After(b) { diff --git a/internal/utils/minmax_test.go b/internal/utils/minmax_test.go index 163134ece..3d6485025 100644 --- a/internal/utils/minmax_test.go +++ b/internal/utils/minmax_test.go @@ -8,16 +8,6 @@ import ( ) var _ = Describe("Min / Max", func() { - It("returns the maximum", func() { - Expect(Max(5, 7)).To(Equal(7)) - Expect(Max(5.5, 5.7)).To(Equal(5.7)) - }) - - It("returns the minimum", func() { - Expect(Min(5, 7)).To(Equal(5)) - Expect(Min(5.5, 5.7)).To(Equal(5.5)) - }) - It("returns the maximum time", func() { a := time.Now() b := a.Add(time.Second) @@ -40,19 +30,4 @@ var _ = Describe("Min / Max", func() { Expect(MinNonZeroDuration(b, a)).To(Equal(b)) Expect(MinNonZeroDuration(time.Minute, time.Hour)).To(Equal(time.Minute)) }) - - It("returns the minium non-zero time", func() { - a := time.Time{} - b := time.Now() - Expect(MinNonZeroTime(time.Time{}, time.Time{})).To(Equal(time.Time{})) - Expect(MinNonZeroTime(a, b)).To(Equal(b)) - Expect(MinNonZeroTime(b, a)).To(Equal(b)) - Expect(MinNonZeroTime(b, b.Add(time.Second))).To(Equal(b)) - Expect(MinNonZeroTime(b.Add(time.Second), b)).To(Equal(b)) - }) - - It("returns the abs time", func() { - Expect(AbsDuration(time.Microsecond)).To(Equal(time.Microsecond)) - Expect(AbsDuration(-time.Microsecond)).To(Equal(time.Microsecond)) - }) }) diff --git a/internal/utils/ringbuffer/ringbuffer.go b/internal/utils/ringbuffer/ringbuffer.go index eae261f1c..5aa7fb376 100644 --- a/internal/utils/ringbuffer/ringbuffer.go +++ b/internal/utils/ringbuffer/ringbuffer.go @@ -8,7 +8,7 @@ type RingBuffer[T any] struct { full bool } -// Init preallocs a buffer with a certain size. +// Init preallocates a buffer with a certain size. func (r *RingBuffer[T]) Init(size int) { r.ring = make([]T, size) } @@ -62,6 +62,16 @@ func (r *RingBuffer[T]) PopFront() T { return t } +// PeekFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PeekFront() T { + if r.Empty() { + panic("github.com/refraction-networking/uquic/internal/utils/ringbuffer: peek from an empty queue") + } + return r.ring[r.headPos] +} + // Grow the maximum size of the queue. // This method assume the queue is full. func (r *RingBuffer[T]) grow() { diff --git a/internal/utils/ringbuffer/ringbuffer_test.go b/internal/utils/ringbuffer/ringbuffer_test.go index 13241a308..68f1c7cc7 100644 --- a/internal/utils/ringbuffer/ringbuffer_test.go +++ b/internal/utils/ringbuffer/ringbuffer_test.go @@ -6,14 +6,17 @@ import ( ) var _ = Describe("RingBuffer", func() { - It("push and pop", func() { + It("push, peek and pop", func() { r := RingBuffer[int]{} Expect(len(r.ring)).To(Equal(0)) Expect(func() { r.PopFront() }).To(Panic()) r.PushBack(1) r.PushBack(2) r.PushBack(3) + Expect(r.PeekFront()).To(Equal(1)) + Expect(r.PeekFront()).To(Equal(1)) Expect(r.PopFront()).To(Equal(1)) + Expect(r.PeekFront()).To(Equal(2)) Expect(r.PopFront()).To(Equal(2)) r.PushBack(4) r.PushBack(5) @@ -25,7 +28,16 @@ var _ = Describe("RingBuffer", func() { Expect(r.PopFront()).To(Equal(5)) Expect(r.PopFront()).To(Equal(6)) }) - It("clear", func() { + + It("panics when Peek or Pop are called on an empty buffer", func() { + r := RingBuffer[string]{} + Expect(r.Empty()).To(BeTrue()) + Expect(r.Len()).To(BeZero()) + Expect(func() { r.PeekFront() }).To(Panic()) + Expect(func() { r.PopFront() }).To(Panic()) + }) + + It("clearing", func() { r := RingBuffer[int]{} r.Init(2) r.PushBack(1) diff --git a/internal/utils/rtt_stats.go b/internal/utils/rtt_stats.go index 4e867ca76..9b7571bc5 100644 --- a/internal/utils/rtt_stats.go +++ b/internal/utils/rtt_stats.go @@ -55,7 +55,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration { if r.SmoothedRTT() == 0 { return 2 * defaultInitialRTT } - pto := r.SmoothedRTT() + Max(4*r.MeanDeviation(), protocol.TimerGranularity) + pto := r.SmoothedRTT() + max(4*r.MeanDeviation(), protocol.TimerGranularity) if includeMaxAckDelay { pto += r.MaxAckDelay() } @@ -90,7 +90,7 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) { r.smoothedRTT = sample r.meanDeviation = sample / 2 } else { - r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond + r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32((r.smoothedRTT-sample).Abs()/time.Microsecond)) * time.Microsecond r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond } } @@ -126,6 +126,6 @@ func (r *RTTStats) OnConnectionMigration() { // is larger. The mean deviation is increased to the most recent deviation if // it's larger. func (r *RTTStats) ExpireSmoothedMetrics() { - r.meanDeviation = Max(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT)) - r.smoothedRTT = Max(r.smoothedRTT, r.latestRTT) + r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs()) + r.smoothedRTT = max(r.smoothedRTT, r.latestRTT) } diff --git a/internal/wire/ack_frame.go b/internal/wire/ack_frame.go index b6831916b..ba47fe53a 100644 --- a/internal/wire/ack_frame.go +++ b/internal/wire/ack_frame.go @@ -22,7 +22,7 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error { +func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error { ecn := typ == ackECNFrameType la, err := quicvarint.Read(r) @@ -37,7 +37,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen delayTime := time.Duration(delay*1< 0 || f.ECT1 > 0 || f.ECNCE > 0 if hasECN { b = append(b, ackECNFrameType) @@ -143,7 +143,7 @@ func (f *AckFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { } // Length of a written frame -func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount { largestAcked := f.AckRanges[0].Largest numRanges := f.numEncodableAckRanges() diff --git a/internal/wire/connection_close_frame.go b/internal/wire/connection_close_frame.go index 7a0b4db5f..02b221411 100644 --- a/internal/wire/connection_close_frame.go +++ b/internal/wire/connection_close_frame.go @@ -16,7 +16,7 @@ type ConnectionCloseFrame struct { ReasonPhrase string } -func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) { +func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) { f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType} ec, err := quicvarint.Read(r) if err != nil { @@ -53,7 +53,7 @@ func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNu } // Length of a written frame -func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount { length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase)) if !f.IsApplicationError { length += quicvarint.Len(f.FrameType) // for the frame type @@ -61,7 +61,7 @@ func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount return length } -func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *ConnectionCloseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if f.IsApplicationError { b = append(b, applicationCloseFrameType) } else { diff --git a/internal/wire/crypto_frame.go b/internal/wire/crypto_frame.go index c6e082347..316a175c9 100644 --- a/internal/wire/crypto_frame.go +++ b/internal/wire/crypto_frame.go @@ -14,7 +14,7 @@ type CryptoFrame struct { Data []byte } -func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) { +func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) { frame := &CryptoFrame{} offset, err := quicvarint.Read(r) if err != nil { @@ -38,7 +38,7 @@ func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, return frame, nil } -func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, cryptoFrameType) b = quicvarint.Append(b, uint64(f.Offset)) b = quicvarint.Append(b, uint64(len(f.Data))) @@ -47,7 +47,7 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) } // Length of a written frame -func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data)) } @@ -71,7 +71,7 @@ func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) { +func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*CryptoFrame, bool /* was splitting required */) { if f.Length(version) <= maxSize { return nil, false } diff --git a/internal/wire/data_blocked_frame.go b/internal/wire/data_blocked_frame.go index d3eb0fa70..36fc77044 100644 --- a/internal/wire/data_blocked_frame.go +++ b/internal/wire/data_blocked_frame.go @@ -12,7 +12,7 @@ type DataBlockedFrame struct { MaximumData protocol.ByteCount } -func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) { +func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) { offset, err := quicvarint.Read(r) if err != nil { return nil, err @@ -20,12 +20,12 @@ func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBloc return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil } -func (f *DataBlockedFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { +func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) { b = append(b, dataBlockedFrameType) return quicvarint.Append(b, uint64(f.MaximumData)), nil } // Length of a written frame -func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/internal/wire/datagram_frame.go b/internal/wire/datagram_frame.go index be9bd72d7..3970ce74f 100644 --- a/internal/wire/datagram_frame.go +++ b/internal/wire/datagram_frame.go @@ -8,13 +8,19 @@ import ( "github.com/refraction-networking/uquic/quicvarint" ) +// MaxDatagramSize is the maximum size of a DATAGRAM frame (RFC 9221). +// By setting it to a large value, we allow all datagrams that fit into a QUIC packet. +// The value is chosen such that it can still be encoded as a 2 byte varint. +// This is a var and not a const so it can be set in tests. +var MaxDatagramSize protocol.ByteCount = 16383 + // A DatagramFrame is a DATAGRAM frame type DatagramFrame struct { DataLenPresent bool Data []byte } -func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*DatagramFrame, error) { +func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) { f := &DatagramFrame{} f.DataLenPresent = typ&0x1 > 0 @@ -39,7 +45,7 @@ func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) ( return f, nil } -func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { typ := uint8(0x30) if f.DataLenPresent { typ ^= 0b1 @@ -53,7 +59,7 @@ func (f *DatagramFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro } // MaxDataLen returns the maximum data length -func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { +func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount { headerLen := protocol.ByteCount(1) if f.DataLenPresent { // pretend that the data size will be 1 bytes @@ -71,7 +77,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol. } // Length of a written frame -func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount { length := 1 + protocol.ByteCount(len(f.Data)) if f.DataLenPresent { length += quicvarint.Len(uint64(len(f.Data))) diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 31ef69893..20990c316 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -32,7 +32,7 @@ type ExtendedHeader struct { parsedLen protocol.ByteCount } -func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) { +func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.Version) (bool /* reserved bits valid */, error) { startLen := b.Len() // read the (now unencrypted) first byte var err error @@ -51,7 +51,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool return reservedBitsValid, err } -func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) { +func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.Version) (bool /* reserved bits valid */, error) { if err := h.readPacketNumber(b); err != nil { return false, err } @@ -95,7 +95,7 @@ func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error { } // Append appends the Header. -func (h *ExtendedHeader) Append(b []byte, v protocol.VersionNumber) ([]byte, error) { +func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) { if h.DestConnectionID.Len() > protocol.MaxConnIDLen { return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len()) } @@ -162,7 +162,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount { } // GetLength determines the length of the Header. -func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount { +func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount { length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */ if h.Type == protocol.PacketTypeInitial { length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token)) diff --git a/internal/wire/frame_parser.go b/internal/wire/frame_parser.go index c0ad6ca5a..eff001992 100644 --- a/internal/wire/frame_parser.go +++ b/internal/wire/frame_parser.go @@ -36,7 +36,8 @@ const ( handshakeDoneFrameType = 0x1e ) -type frameParser struct { +// The FrameParser parses QUIC frames, one by one. +type FrameParser struct { r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them ackDelayExponent uint8 @@ -47,11 +48,9 @@ type frameParser struct { ackFrame *AckFrame } -var _ FrameParser = &frameParser{} - // NewFrameParser creates a new frame parser. -func NewFrameParser(supportsDatagrams bool) *frameParser { - return &frameParser{ +func NewFrameParser(supportsDatagrams bool) *FrameParser { + return &FrameParser{ r: *bytes.NewReader(nil), supportsDatagrams: supportsDatagrams, ackFrame: &AckFrame{}, @@ -60,7 +59,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser { // ParseNext parses the next frame. // It skips PADDING frames. -func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (int, Frame, error) { +func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) { startLen := len(data) p.r.Reset(data) frame, err := p.parseNext(&p.r, encLevel, v) @@ -69,7 +68,7 @@ func (p *frameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, return n, frame, err } -func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { +func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { for r.Len() != 0 { typ, err := quicvarint.Read(r) if err != nil { @@ -95,7 +94,7 @@ func (p *frameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLev return nil, nil } -func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (Frame, error) { +func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) { var frame Frame var err error if typ&0xf8 == 0x8 { @@ -163,7 +162,7 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. return frame, nil } -func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { +func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool { switch encLevel { case protocol.EncryptionInitial, protocol.EncryptionHandshake: switch f.(type) { @@ -186,6 +185,8 @@ func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL } } -func (p *frameParser) SetAckDelayExponent(exp uint8) { +// SetAckDelayExponent sets the acknowledgment delay exponent (sent in the transport parameters). +// This value is used to scale the ACK Delay field in the ACK frame. +func (p *FrameParser) SetAckDelayExponent(exp uint8) { p.ackDelayExponent = exp } diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 06f3cf0ca..ea390204d 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -14,7 +14,7 @@ var _ = Describe("Frame parsing", func() { var parser FrameParser BeforeEach(func() { - parser = NewFrameParser(true) + parser = *NewFrameParser(true) }) It("returns nil if there's nothing more to read", func() { @@ -315,7 +315,7 @@ var _ = Describe("Frame parsing", func() { }) It("errors when DATAGRAM frames are not supported", func() { - parser = NewFrameParser(false) + parser = *NewFrameParser(false) f := &DatagramFrame{Data: []byte("foobar")} b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/wire/handshake_done_frame.go b/internal/wire/handshake_done_frame.go index 19431808f..6cb27d124 100644 --- a/internal/wire/handshake_done_frame.go +++ b/internal/wire/handshake_done_frame.go @@ -7,11 +7,11 @@ import ( // A HandshakeDoneFrame is a HANDSHAKE_DONE frame type HandshakeDoneFrame struct{} -func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *HandshakeDoneFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return append(b, handshakeDoneFrameType), nil } // Length of a written frame -func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *HandshakeDoneFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } diff --git a/internal/wire/header.go b/internal/wire/header.go index f6025019f..fb4bf3be9 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -85,11 +85,11 @@ func IsLongHeaderPacket(firstByte byte) bool { // ParseVersion parses the QUIC version. // It should only be called for Long Header packets (Short Header packets don't contain a version number). -func ParseVersion(data []byte) (protocol.VersionNumber, error) { +func ParseVersion(data []byte) (protocol.Version, error) { if len(data) < 5 { return 0, io.EOF } - return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil + return protocol.Version(binary.BigEndian.Uint32(data[1:5])), nil } // IsVersionNegotiationPacket says if this is a version negotiation packet @@ -109,7 +109,7 @@ func Is0RTTPacket(b []byte) bool { if !IsLongHeaderPacket(b[0]) { return false } - version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + version := protocol.Version(binary.BigEndian.Uint32(b[1:5])) //nolint:exhaustive // We only need to test QUIC versions that we support. switch version { case protocol.Version1: @@ -128,7 +128,7 @@ type Header struct { typeByte byte Type protocol.PacketType - Version protocol.VersionNumber + Version protocol.Version SrcConnectionID protocol.ConnectionID DestConnectionID protocol.ConnectionID @@ -184,7 +184,7 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { if err != nil { return err } - h.Version = protocol.VersionNumber(v) + h.Version = protocol.Version(v) if h.Version != 0 && h.typeByte&0x40 == 0 { return errors.New("not a QUIC packet") } @@ -278,7 +278,7 @@ func (h *Header) ParsedLen() protocol.ByteCount { // ParseExtended parses the version dependent part of the header. // The Reader has to be set such that it points to the first byte of the header. -func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) { +func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.Version) (*ExtendedHeader, error) { extHdr := h.toExtendedHeader() reservedBitsValid, err := extHdr.parse(b, ver) if err != nil { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index f5b28d54e..2d3283c5e 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -95,7 +95,7 @@ var _ = Describe("Header Parsing", func() { b := []byte{0x80, 0xde, 0xad, 0xbe, 0xef} v, err := ParseVersion(b) Expect(err).ToNot(HaveOccurred()) - Expect(v).To(Equal(protocol.VersionNumber(0xdeadbeef))) + Expect(v).To(Equal(protocol.Version(0xdeadbeef))) }) It("errors with EOF", func() { @@ -230,7 +230,7 @@ var _ = Describe("Header Parsing", func() { } hdr, _, rest, err := ParsePacket(data) Expect(err).To(MatchError(ErrUnsupportedVersion)) - Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) + Expect(hdr.Version).To(Equal(protocol.Version(0xdeadbeef))) Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}))) Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1}))) Expect(rest).To(BeEmpty()) diff --git a/internal/wire/interface.go b/internal/wire/interface.go index ca84d4fa6..5d1ca61a6 100644 --- a/internal/wire/interface.go +++ b/internal/wire/interface.go @@ -6,12 +6,6 @@ import ( // A Frame in QUIC type Frame interface { - Append(b []byte, version protocol.VersionNumber) ([]byte, error) - Length(version protocol.VersionNumber) protocol.ByteCount -} - -// A FrameParser parses QUIC frames, one by one. -type FrameParser interface { - ParseNext([]byte, protocol.EncryptionLevel, protocol.VersionNumber) (int, Frame, error) - SetAckDelayExponent(uint8) + Append(b []byte, version protocol.Version) ([]byte, error) + Length(version protocol.Version) protocol.ByteCount } diff --git a/internal/wire/log.go b/internal/wire/log.go index bf1a8d3bf..7690849e0 100644 --- a/internal/wire/log.go +++ b/internal/wire/log.go @@ -63,7 +63,9 @@ func LogFrame(logger utils.Logger, frame Frame, sent bool) { logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit) } case *NewConnectionIDFrame: - logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken) + logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, RetirePriorTo: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.RetirePriorTo, f.ConnectionID, f.StatelessResetToken) + case *RetireConnectionIDFrame: + logger.Debugf("\t%s &wire.RetireConnectionIDFrame{SequenceNumber: %d}", dir, f.SequenceNumber) case *NewTokenFrame: logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token) default: diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go index ddf9461ee..74a1ed6c3 100644 --- a/internal/wire/log_test.go +++ b/internal/wire/log_test.go @@ -153,10 +153,16 @@ var _ = Describe("Frame logging", func() { It("logs NEW_CONNECTION_ID frames", func() { LogFrame(logger, &NewConnectionIDFrame{ SequenceNumber: 42, + RetirePriorTo: 24, ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, }, false) - Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, RetirePriorTo: 24, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) + }) + + It("logs RETIRE_CONNECTION_ID frames", func() { + LogFrame(logger, &RetireConnectionIDFrame{SequenceNumber: 42}, false) + Expect(buf.String()).To(ContainSubstring("\t<- &wire.RetireConnectionIDFrame{SequenceNumber: 42}")) }) It("logs NEW_TOKEN frames", func() { diff --git a/internal/wire/max_data_frame.go b/internal/wire/max_data_frame.go index 7ddb40a9c..15ff7b0d4 100644 --- a/internal/wire/max_data_frame.go +++ b/internal/wire/max_data_frame.go @@ -13,7 +13,7 @@ type MaxDataFrame struct { } // parseMaxDataFrame parses a MAX_DATA frame -func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) { +func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) { frame := &MaxDataFrame{} byteOffset, err := quicvarint.Read(r) if err != nil { @@ -23,13 +23,13 @@ func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame return frame, nil } -func (f *MaxDataFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, maxDataFrameType) b = quicvarint.Append(b, uint64(f.MaximumData)) return b, nil } // Length of a written frame -func (f *MaxDataFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaximumData)) } diff --git a/internal/wire/max_stream_data_frame.go b/internal/wire/max_stream_data_frame.go index 33a156b70..9f3fbdd41 100644 --- a/internal/wire/max_stream_data_frame.go +++ b/internal/wire/max_stream_data_frame.go @@ -13,7 +13,7 @@ type MaxStreamDataFrame struct { MaximumStreamData protocol.ByteCount } -func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) { +func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) { sid, err := quicvarint.Read(r) if err != nil { return nil, err @@ -29,7 +29,7 @@ func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStr }, nil } -func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([]byte, error) { +func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) { b = append(b, maxStreamDataFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) @@ -37,6 +37,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.VersionNumber) ([ } // Length of a written frame -func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) } diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index b80591bb9..f7d436292 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -14,7 +14,7 @@ type MaxStreamsFrame struct { MaxStreamNum protocol.StreamNum } -func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*MaxStreamsFrame, error) { +func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) { f := &MaxStreamsFrame{} switch typ { case bidiMaxStreamsFrameType: @@ -33,7 +33,7 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) return f, nil } -func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, bidiMaxStreamsFrameType) @@ -45,6 +45,6 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, er } // Length of a written frame -func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.MaxStreamNum)) } diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index 37c9577a3..9de45bbc9 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "errors" "fmt" "io" @@ -17,7 +18,7 @@ type NewConnectionIDFrame struct { StatelessResetToken protocol.StatelessResetToken } -func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) { +func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) { seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -34,6 +35,9 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC if err != nil { return nil, err } + if connIDLen == 0 { + return nil, errors.New("invalid zero-length connection ID") + } connID, err := protocol.ReadConnectionID(r, int(connIDLen)) if err != nil { return nil, err @@ -53,7 +57,7 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC return frame, nil } -func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, newConnectionIDFrameType) b = quicvarint.Append(b, f.SequenceNumber) b = quicvarint.Append(b, f.RetirePriorTo) @@ -68,6 +72,6 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byt } // Length of a written frame -func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16 } diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index f4da83143..86b7df184 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -38,7 +38,15 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) }) - It("errors when the connection ID has an invalid length", func() { + It("errors when the connection ID has a zero-length connection ID", func() { + data := encodeVarInt(42) // sequence number + data = append(data, encodeVarInt(12)...) // retire prior to + data = append(data, 0) // connection ID length + _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) + Expect(err).To(MatchError("invalid zero-length connection ID")) + }) + + It("errors when the connection ID has an invalid length (too long)", func() { data := encodeVarInt(0xdeadbeef) // sequence number data = append(data, encodeVarInt(0xcafe)...) // retire prior to data = append(data, 21) // connection ID length diff --git a/internal/wire/new_token_frame.go b/internal/wire/new_token_frame.go index 34c62b98e..d15d15d01 100644 --- a/internal/wire/new_token_frame.go +++ b/internal/wire/new_token_frame.go @@ -14,7 +14,7 @@ type NewTokenFrame struct { Token []byte } -func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) { +func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) { tokenLen, err := quicvarint.Read(r) if err != nil { return nil, err @@ -32,7 +32,7 @@ func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFra return &NewTokenFrame{Token: token}, nil } -func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, newTokenFrameType) b = quicvarint.Append(b, uint64(len(f.Token))) b = append(b, f.Token...) @@ -40,6 +40,6 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, erro } // Length of a written frame -func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token)) } diff --git a/internal/wire/path_challenge_frame.go b/internal/wire/path_challenge_frame.go index de2e7ad96..3882087de 100644 --- a/internal/wire/path_challenge_frame.go +++ b/internal/wire/path_challenge_frame.go @@ -12,7 +12,7 @@ type PathChallengeFrame struct { Data [8]byte } -func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) { +func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) { frame := &PathChallengeFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -23,13 +23,13 @@ func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathCh return frame, nil } -func (f *PathChallengeFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, pathChallengeFrameType) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame -func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PathChallengeFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } diff --git a/internal/wire/path_response_frame.go b/internal/wire/path_response_frame.go index 9dd69c25f..0e8185633 100644 --- a/internal/wire/path_response_frame.go +++ b/internal/wire/path_response_frame.go @@ -12,7 +12,7 @@ type PathResponseFrame struct { Data [8]byte } -func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) { +func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) { frame := &PathResponseFrame{} if _, err := io.ReadFull(r, frame.Data[:]); err != nil { if err == io.ErrUnexpectedEOF { @@ -23,13 +23,13 @@ func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathRes return frame, nil } -func (f *PathResponseFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, pathResponseFrameType) b = append(b, f.Data[:]...) return b, nil } // Length of a written frame -func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PathResponseFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + 8 } diff --git a/internal/wire/ping_frame.go b/internal/wire/ping_frame.go index a616f7d0a..6af2567d2 100644 --- a/internal/wire/ping_frame.go +++ b/internal/wire/ping_frame.go @@ -7,11 +7,11 @@ import ( // A PingFrame is a PING frame type PingFrame struct{} -func (f *PingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *PingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { return append(b, pingFrameType), nil } // Length of a written frame -func (f *PingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *PingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 } diff --git a/internal/wire/reset_stream_frame.go b/internal/wire/reset_stream_frame.go index bfa272394..87fd09e2c 100644 --- a/internal/wire/reset_stream_frame.go +++ b/internal/wire/reset_stream_frame.go @@ -15,7 +15,7 @@ type ResetStreamFrame struct { FinalSize protocol.ByteCount } -func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) { +func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) { var streamID protocol.StreamID var byteOffset protocol.ByteCount sid, err := quicvarint.Read(r) @@ -40,7 +40,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStr }, nil } -func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, resetStreamFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) @@ -49,6 +49,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, e } // Length of a written frame -func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize)) } diff --git a/internal/wire/retire_connection_id_frame.go b/internal/wire/retire_connection_id_frame.go index 519729378..d36f81e3b 100644 --- a/internal/wire/retire_connection_id_frame.go +++ b/internal/wire/retire_connection_id_frame.go @@ -12,7 +12,7 @@ type RetireConnectionIDFrame struct { SequenceNumber uint64 } -func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) { +func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) { seq, err := quicvarint.Read(r) if err != nil { return nil, err @@ -20,13 +20,13 @@ func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*R return &RetireConnectionIDFrame{SequenceNumber: seq}, nil } -func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, retireConnectionIDFrameType) b = quicvarint.Append(b, f.SequenceNumber) return b, nil } // Length of a written frame -func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount { +func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(f.SequenceNumber) } diff --git a/internal/wire/stop_sending_frame.go b/internal/wire/stop_sending_frame.go index 0467c4d23..d4bd78045 100644 --- a/internal/wire/stop_sending_frame.go +++ b/internal/wire/stop_sending_frame.go @@ -15,7 +15,7 @@ type StopSendingFrame struct { } // parseStopSendingFrame parses a STOP_SENDING frame -func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) { +func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) { streamID, err := quicvarint.Read(r) if err != nil { return nil, err @@ -32,11 +32,11 @@ func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSend } // Length of a written frame -func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) } -func (f *StopSendingFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, stopSendingFrameType) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.ErrorCode)) diff --git a/internal/wire/stream_data_blocked_frame.go b/internal/wire/stream_data_blocked_frame.go index 26f428c04..5bc6324a1 100644 --- a/internal/wire/stream_data_blocked_frame.go +++ b/internal/wire/stream_data_blocked_frame.go @@ -13,7 +13,7 @@ type StreamDataBlockedFrame struct { MaximumStreamData protocol.ByteCount } -func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) { +func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) { sid, err := quicvarint.Read(r) if err != nil { return nil, err @@ -29,7 +29,7 @@ func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*St }, nil } -func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { b = append(b, 0x15) b = quicvarint.Append(b, uint64(f.StreamID)) b = quicvarint.Append(b, uint64(f.MaximumStreamData)) @@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]b } // Length of a written frame -func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData)) } diff --git a/internal/wire/stream_frame.go b/internal/wire/stream_frame.go index bbe3dece5..4176c84ef 100644 --- a/internal/wire/stream_frame.go +++ b/internal/wire/stream_frame.go @@ -20,7 +20,7 @@ type StreamFrame struct { fromPool bool } -func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamFrame, error) { +func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) { hasOffset := typ&0b100 > 0 fin := typ&0b1 > 0 hasDataLen := typ&0b10 > 0 @@ -79,7 +79,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*S } // Write writes a STREAM frame -func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { if len(f.Data) == 0 && !f.Fin { return nil, errors.New("StreamFrame: attempting to write empty frame without FIN") } @@ -108,7 +108,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) } // Length returns the total length of the STREAM frame -func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount { length := 1 + quicvarint.Len(uint64(f.StreamID)) if f.Offset != 0 { length += quicvarint.Len(uint64(f.Offset)) @@ -126,7 +126,7 @@ func (f *StreamFrame) DataLen() protocol.ByteCount { // MaxDataLen returns the maximum data length // If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data). -func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount { +func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount { headerLen := 1 + quicvarint.Len(uint64(f.StreamID)) if f.Offset != 0 { headerLen += quicvarint.Len(uint64(f.Offset)) @@ -151,7 +151,7 @@ func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Ve // The frame might not be split if: // * the size is large enough to fit the whole frame // * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil. -func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) { +func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.Version) (*StreamFrame, bool /* was splitting required */) { if maxSize >= f.Length(version) { return nil, false } diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index 450cb03c5..b64efbe3c 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -14,7 +14,7 @@ type StreamsBlockedFrame struct { StreamLimit protocol.StreamNum } -func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) { +func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) { f := &StreamsBlockedFrame{} switch typ { case bidiStreamBlockedFrameType: @@ -33,7 +33,7 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.VersionNum return f, nil } -func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte, error) { +func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) { switch f.Type { case protocol.StreamTypeBidi: b = append(b, bidiStreamBlockedFrameType) @@ -45,6 +45,6 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.VersionNumber) ([]byte } // Length of a written frame -func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount { +func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount { return 1 + quicvarint.Len(uint64(f.StreamLimit)) } diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index faa6eaf37..0af8d93f9 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" "math" - "net" + "net/netip" "time" "golang.org/x/exp/rand" @@ -425,10 +425,8 @@ var _ = Describe("Transport Parameters", func() { BeforeEach(func() { pa = &PreferredAddress{ - IPv4: net.IPv4(127, 0, 0, 1), - IPv4Port: 42, - IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - IPv6Port: 13, + IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42), + IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13), ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, } @@ -442,10 +440,8 @@ var _ = Describe("Transport Parameters", func() { }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(Succeed()) - Expect(p.PreferredAddress.IPv4.String()).To(Equal(pa.IPv4.String())) - Expect(p.PreferredAddress.IPv4Port).To(Equal(pa.IPv4Port)) - Expect(p.PreferredAddress.IPv6.String()).To(Equal(pa.IPv6.String())) - Expect(p.PreferredAddress.IPv6Port).To(Equal(pa.IPv6Port)) + Expect(p.PreferredAddress.IPv4).To(Equal(pa.IPv4)) + Expect(p.PreferredAddress.IPv6).To(Equal(pa.IPv6)) Expect(p.PreferredAddress.ConnectionID).To(Equal(pa.ConnectionID)) Expect(p.PreferredAddress.StatelessResetToken).To(Equal(pa.StatelessResetToken)) }) @@ -503,7 +499,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2), - MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(int64(protocol.MaxDatagramFrameSize))), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(int64(MaxDatagramSize))), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) b := params.MarshalForSessionTicket(nil) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index f1562b36d..9e821245f 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "sort" "time" @@ -52,10 +52,7 @@ const ( // PreferredAddress is the value encoding in the preferred_address transport parameter type PreferredAddress struct { - IPv4 net.IP - IPv4Port uint16 - IPv6 net.IP - IPv6Port uint16 + IPv4, IPv6 netip.AddrPort ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } @@ -222,26 +219,24 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { remainingLen := r.Len() pa := &PreferredAddress{} - ipv4 := make([]byte, 4) - if _, err := io.ReadFull(r, ipv4); err != nil { + var ipv4 [4]byte + if _, err := io.ReadFull(r, ipv4[:]); err != nil { return err } - pa.IPv4 = net.IP(ipv4) port, err := utils.BigEndian.ReadUint16(r) if err != nil { return err } - pa.IPv4Port = port - ipv6 := make([]byte, 16) - if _, err := io.ReadFull(r, ipv6); err != nil { + pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port) + var ipv6 [16]byte + if _, err := io.ReadFull(r, ipv6[:]); err != nil { return err } - pa.IPv6 = net.IP(ipv6) port, err = utils.BigEndian.ReadUint16(r) if err != nil { return err } - pa.IPv6Port = port + pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port) connIDLen, err := r.ReadByte() if err != nil { return err @@ -298,7 +293,7 @@ func (p *TransportParameters) readNumericTransportParameter( return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount) } case maxIdleTimeoutParameterID: - p.MaxIdleTimeout = utils.Max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) + p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond) case maxUDPPayloadSizeParameterID: if val < 1200 { return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val) @@ -394,13 +389,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { if p.PreferredAddress != nil { b = quicvarint.Append(b, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) - ipv4 := p.PreferredAddress.IPv4 - b = append(b, ipv4[len(ipv4)-4:]...) - b = append(b, []byte{0, 0}...) - binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port) - b = append(b, p.PreferredAddress.IPv6...) - b = append(b, []byte{0, 0}...) - binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port) + ip4 := p.PreferredAddress.IPv4.Addr().As4() + b = append(b, ip4[:]...) + b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv4.Port()) + ip6 := p.PreferredAddress.IPv6.Addr().As16() + b = append(b, ip6[:]...) + b = binary.BigEndian.AppendUint16(b, p.PreferredAddress.IPv6.Port()) b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) b = append(b, p.PreferredAddress.StatelessResetToken[:]...) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 3659d1031..94d266cfd 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -1,17 +1,15 @@ package wire import ( - "bytes" "crypto/rand" "encoding/binary" "errors" "github.com/refraction-networking/uquic/internal/protocol" - "github.com/refraction-networking/uquic/internal/utils" ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { +func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.Version, _ error) { n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { return nil, nil, nil, err @@ -25,32 +23,31 @@ func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenCon //nolint:stylecheck return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } - versions := make([]protocol.VersionNumber, len(b)/4) + versions := make([]protocol.Version, len(b)/4) for i := 0; len(b) > 0; i++ { - versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) + versions[i] = protocol.Version(binary.BigEndian.Uint32(b[:4])) b = b[4:] } return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.Version) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 - buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) - r := make([]byte, 1) - _, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here. + buf := make([]byte, 1+4 /* type byte and version field */, expectedLen) + _, _ = rand.Read(buf[:1]) // ignore the error here. It is not critical to have perfect random here. // Setting the "QUIC bit" (0x40) is not required by the RFC, // but it allows clients to demultiplex QUIC with a long list of other protocols. // See RFC 9443 and https://mailarchive.ietf.org/arch/msg/quic/oR4kxGKY6mjtPC1CZegY1ED4beg/ for details. - buf.WriteByte(r[0] | 0xc0) - utils.BigEndian.WriteUint32(buf, 0) // version 0 - buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID.Bytes()) - buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID.Bytes()) + buf[0] |= 0xc0 + // The next 4 bytes are left at 0 (version number). + buf = append(buf, uint8(destConnID.Len())) + buf = append(buf, destConnID.Bytes()...) + buf = append(buf, uint8(srcConnID.Len())) + buf = append(buf, srcConnID.Bytes()...) for _, v := range greasedVersions { - utils.BigEndian.WriteUint32(buf, uint32(v)) + buf = binary.BigEndian.AppendUint32(buf, uint32(v)) } - return buf.Bytes() + return buf } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index e47d2c83f..8dface533 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -2,6 +2,7 @@ package wire import ( "encoding/binary" + "testing" "golang.org/x/exp/rand" @@ -22,7 +23,7 @@ var _ = Describe("Version Negotiation Packets", func() { It("parses a Version Negotiation packet", func() { srcConnID := randConnID(rand.Intn(255) + 1) destConnID := randConnID(rand.Intn(255) + 1) - versions := []protocol.VersionNumber{0x22334455, 0x33445566} + versions := []protocol.Version{0x22334455, 0x33445566} data := []byte{0x80, 0, 0, 0, 0} data = append(data, uint8(len(destConnID))) data = append(data, destConnID...) @@ -42,7 +43,7 @@ var _ = Describe("Version Negotiation Packets", func() { It("errors if it contains versions of the wrong length", func() { connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455, 0x33445566} + versions := []protocol.Version{0x22334455, 0x33445566} data := ComposeVersionNegotiation(connID, connID, versions) _, _, _, err := ParseVersionNegotiationPacket(data[:len(data)-2]) Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) @@ -50,7 +51,7 @@ var _ = Describe("Version Negotiation Packets", func() { It("errors if the version list is empty", func() { connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{0x22334455} + versions := []protocol.Version{0x22334455} data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] @@ -61,7 +62,7 @@ var _ = Describe("Version Negotiation Packets", func() { It("adds a reserved version", func() { srcConnID := protocol.ArbitraryLenConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - versions := []protocol.VersionNumber{1001, 1003} + versions := []protocol.Version{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) Expect(data[0] & 0x40).ToNot(BeZero()) @@ -77,7 +78,7 @@ var _ = Describe("Version Negotiation Packets", func() { for _, v := range versions { Expect(supportedVersions).To(ContainElement(v)) } - var reservedVersion protocol.VersionNumber + var reservedVersion protocol.Version versionLoop: for _, ver := range supportedVersions { for _, v := range versions { @@ -91,3 +92,13 @@ var _ = Describe("Version Negotiation Packets", func() { Expect(reservedVersion&0x0f0f0f0f == 0x0a0a0a0a).To(BeTrue()) // check that it's a greased version number }) }) + +func BenchmarkComposeVersionNegotiationPacket(b *testing.B) { + b.ReportAllocs() + supportedVersions := []protocol.Version{protocol.Version2, protocol.Version1, 0x1337} + destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0xa, 0xb, 0xc, 0xd} + srcConnID := protocol.ArbitraryLenConnectionID{10, 9, 8, 7, 6, 5, 4, 3, 2, 1} + for i := 0; i < b.N; i++ { + ComposeVersionNegotiation(destConnID, srcConnID, supportedVersions) + } +} diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go index 161b6e7f7..73c1784ac 100644 --- a/internal/wire/wire_suite_test.go +++ b/internal/wire/wire_suite_test.go @@ -20,7 +20,7 @@ func encodeVarInt(i uint64) []byte { return quicvarint.Append(nil, i) } -func appendVersion(data []byte, v protocol.VersionNumber) []byte { +func appendVersion(data []byte, v protocol.Version) []byte { offset := len(data) data = append(data, []byte{0, 0, 0, 0}...) binary.BigEndian.PutUint32(data[offset:], uint32(v)) diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go index e3f322d91..7f54d6cda 100644 --- a/logging/connection_tracer.go +++ b/logging/connection_tracer.go @@ -20,20 +20,21 @@ type ConnectionTracer struct { ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) BufferedPacket func(PacketType, ByteCount) - DroppedPacket func(PacketType, ByteCount, PacketDropReason) + DroppedPacket func(PacketType, PacketNumber, ByteCount, PacketDropReason) UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) AcknowledgedPacket func(EncryptionLevel, PacketNumber) LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) UpdatedCongestionState func(CongestionState) UpdatedPTOCount func(value uint32) UpdatedKeyFromTLS func(EncryptionLevel, Perspective) - UpdatedKey func(generation KeyPhase, remote bool) + UpdatedKey func(keyPhase KeyPhase, remote bool) DroppedEncryptionLevel func(EncryptionLevel) - DroppedKey func(generation KeyPhase) + DroppedKey func(keyPhase KeyPhase) SetLossTimer func(TimerType, EncryptionLevel, time.Time) LossTimerExpired func(TimerType, EncryptionLevel) LossTimerCanceled func() ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) + ChoseALPN func(protocol string) // Close is called when the connection is closed. Close func() Debug func(name, msg string) @@ -139,10 +140,10 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra } } }, - DroppedPacket: func(typ PacketType, size ByteCount, reason PacketDropReason) { + DroppedPacket: func(typ PacketType, pn PacketNumber, size ByteCount, reason PacketDropReason) { for _, t := range tracers { if t.DroppedPacket != nil { - t.DroppedPacket(typ, size, reason) + t.DroppedPacket(typ, pn, size, reason) } } }, @@ -237,6 +238,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra } } }, + ChoseALPN: func(protocol string) { + for _, t := range tracers { + if t.ChoseALPN != nil { + t.ChoseALPN(protocol) + } + } + }, Close: func() { for _, t := range tracers { if t.Close != nil { diff --git a/logging/interface.go b/logging/interface.go index 355bc09aa..928733fe6 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -37,7 +37,7 @@ type ( // The StreamType is the type of the stream (unidirectional or bidirectional). StreamType = protocol.StreamType // The VersionNumber is the QUIC version. - VersionNumber = protocol.VersionNumber + VersionNumber = protocol.Version // The Header is the QUIC packet header, before removing header protection. Header = wire.Header diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index cd87641d4..a1d9906e8 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -64,6 +64,18 @@ var _ = Describe("Tracing", func() { tr2.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) tracer.DroppedPacket(remote, PacketTypeRetry, 1024, PacketDropDuplicate) }) + + It("traces the Debug event", func() { + tr1.EXPECT().Debug("foo", "bar") + tr2.EXPECT().Debug("foo", "bar") + tracer.Debug("foo", "bar") + }) + + It("traces the Close event", func() { + tr1.EXPECT().Close() + tr2.EXPECT().Close() + tracer.Close() + }) }) }) @@ -93,8 +105,8 @@ var _ = Describe("Tracing", func() { It("traces the NegotiatedVersion event", func() { chosen := protocol.Version2 - client := []protocol.VersionNumber{protocol.Version1} - server := []protocol.VersionNumber{13, 37} + client := []protocol.Version{protocol.Version1} + server := []protocol.Version{13, 37} tr1.EXPECT().NegotiatedVersion(chosen, client, server) tr2.EXPECT().NegotiatedVersion(chosen, client, server) tracer.NegotiatedVersion(chosen, client, server) @@ -184,9 +196,9 @@ var _ = Describe("Tracing", func() { }) It("traces the DroppedPacket event", func() { - tr1.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tr2.EXPECT().DroppedPacket(PacketTypeInitial, ByteCount(1337), PacketDropHeaderParseError) - tracer.DroppedPacket(PacketTypeInitial, 1337, PacketDropHeaderParseError) + tr1.EXPECT().DroppedPacket(PacketTypeInitial, PacketNumber(42), ByteCount(1337), PacketDropHeaderParseError) + tr2.EXPECT().DroppedPacket(PacketTypeInitial, PacketNumber(42), ByteCount(1337), PacketDropHeaderParseError) + tracer.DroppedPacket(PacketTypeInitial, 42, 1337, PacketDropHeaderParseError) }) It("traces the UpdatedCongestionState event", func() { diff --git a/logging/tracer.go b/logging/tracer.go index 5918f30f8..edd85dbaa 100644 --- a/logging/tracer.go +++ b/logging/tracer.go @@ -7,6 +7,8 @@ type Tracer struct { SentPacket func(net.Addr, *Header, ByteCount, []Frame) SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) + Debug func(name, msg string) + Close func() } // NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. @@ -39,5 +41,19 @@ func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { } } }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, } } diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go index bc4969d47..95ed688c4 100644 --- a/mock_ack_frame_source_test.go +++ b/mock_ack_frame_source_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource // + // Package quic is a generated GoMock package. package quic @@ -48,7 +49,31 @@ func (m *MockAckFrameSource) GetAckFrame(arg0 protocol.EncryptionLevel, arg1 boo } // GetAckFrame indicates an expected call of GetAckFrame. -func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 any) *gomock.Call { +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame(arg0, arg1 any) *MockAckFrameSourceGetAckFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame), arg0, arg1) + return &MockAckFrameSourceGetAckFrameCall{Call: call} +} + +// MockAckFrameSourceGetAckFrameCall wrap *gomock.Call +type MockAckFrameSourceGetAckFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockAckFrameSourceGetAckFrameCall) Return(arg0 *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockAckFrameSourceGetAckFrameCall) Do(f func(protocol.EncryptionLevel, bool) *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockAckFrameSourceGetAckFrameCall) DoAndReturn(f func(protocol.EncryptionLevel, bool) *wire.AckFrame) *MockAckFrameSourceGetAckFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index 7d3319574..ea74b4e83 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn +// mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn // + // Package quic is a generated GoMock package. package quic @@ -48,7 +49,31 @@ func (m *MockBatchConn) ReadBatch(ms []ipv4.Message, flags int) (int, error) { } // ReadBatch indicates an expected call of ReadBatch. -func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags any) *gomock.Call { +func (mr *MockBatchConnMockRecorder) ReadBatch(ms, flags any) *MockBatchConnReadBatchCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBatch", reflect.TypeOf((*MockBatchConn)(nil).ReadBatch), ms, flags) + return &MockBatchConnReadBatchCall{Call: call} +} + +// MockBatchConnReadBatchCall wrap *gomock.Call +type MockBatchConnReadBatchCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockBatchConnReadBatchCall) Return(arg0 int, arg1 error) *MockBatchConnReadBatchCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockBatchConnReadBatchCall) Do(f func([]ipv4.Message, int) (int, error)) *MockBatchConnReadBatchCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockBatchConnReadBatchCall) DoAndReturn(f func([]ipv4.Message, int) (int, error)) *MockBatchConnReadBatchCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index 45056041b..fe6c0680e 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner // + // Package quic is a generated GoMock package. package quic @@ -47,9 +48,33 @@ func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) boo } // Add indicates an expected call of Add. -func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 any) *MockConnRunnerAddCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) + return &MockConnRunnerAddCall{Call: call} +} + +// MockConnRunnerAddCall wrap *gomock.Call +type MockConnRunnerAddCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerAddCall) Return(arg0 bool) *MockConnRunnerAddCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockConnRunnerAddCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockConnRunnerAddCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddResetToken mocks base method. @@ -59,9 +84,33 @@ func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 p } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 any) *MockConnRunnerAddResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) + return &MockConnRunnerAddResetTokenCall{Call: call} +} + +// MockConnRunnerAddResetTokenCall wrap *gomock.Call +type MockConnRunnerAddResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerAddResetTokenCall) Return() *MockConnRunnerAddResetTokenCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockConnRunnerAddResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockConnRunnerAddResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetStatelessResetToken mocks base method. @@ -73,9 +122,33 @@ func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) prot } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 any) *MockConnRunnerGetStatelessResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) + return &MockConnRunnerGetStatelessResetTokenCall{Call: call} +} + +// MockConnRunnerGetStatelessResetTokenCall wrap *gomock.Call +type MockConnRunnerGetStatelessResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockConnRunnerGetStatelessResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Remove mocks base method. @@ -85,9 +158,33 @@ func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockConnRunnerMockRecorder) Remove(arg0 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Remove(arg0 any) *MockConnRunnerRemoveCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) + return &MockConnRunnerRemoveCall{Call: call} +} + +// MockConnRunnerRemoveCall wrap *gomock.Call +type MockConnRunnerRemoveCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerRemoveCall) Return() *MockConnRunnerRemoveCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerRemoveCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRemoveCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRemoveCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RemoveResetToken mocks base method. @@ -97,21 +194,69 @@ func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 any) *MockConnRunnerRemoveResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) + return &MockConnRunnerRemoveResetTokenCall{Call: call} +} + +// MockConnRunnerRemoveResetTokenCall wrap *gomock.Call +type MockConnRunnerRemoveResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerRemoveResetTokenCall) Return() *MockConnRunnerRemoveResetTokenCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockConnRunnerRemoveResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockConnRunnerRemoveResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReplaceWithClosed mocks base method. -func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { +func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 any) *MockConnRunnerReplaceWithClosedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) + return &MockConnRunnerReplaceWithClosedCall{Call: call} +} + +// MockConnRunnerReplaceWithClosedCall wrap *gomock.Call +type MockConnRunnerReplaceWithClosedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerReplaceWithClosedCall) Return() *MockConnRunnerReplaceWithClosedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte)) *MockConnRunnerReplaceWithClosedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte)) *MockConnRunnerReplaceWithClosedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Retire mocks base method. @@ -121,7 +266,31 @@ func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Retire(arg0 any) *MockConnRunnerRetireCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) + return &MockConnRunnerRetireCall{Call: call} +} + +// MockConnRunnerRetireCall wrap *gomock.Call +type MockConnRunnerRetireCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnRunnerRetireCall) Return() *MockConnRunnerRetireCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnRunnerRetireCall) Do(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnRunnerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockConnRunnerRetireCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index c8a601012..bf976b8d7 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.Encrypt } // HandleMessage indicates an expected call of HandleMessage. -func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 any) *gomock.Call { +func (mr *MockCryptoDataHandlerMockRecorder) HandleMessage(arg0, arg1 any) *MockCryptoDataHandlerHandleMessageCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoDataHandler)(nil).HandleMessage), arg0, arg1) + return &MockCryptoDataHandlerHandleMessageCall{Call: call} +} + +// MockCryptoDataHandlerHandleMessageCall wrap *gomock.Call +type MockCryptoDataHandlerHandleMessageCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoDataHandlerHandleMessageCall) Return(arg0 error) *MockCryptoDataHandlerHandleMessageCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoDataHandlerHandleMessageCall) Do(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoDataHandlerHandleMessageCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoDataHandlerHandleMessageCall) DoAndReturn(f func([]byte, protocol.EncryptionLevel) error) *MockCryptoDataHandlerHandleMessageCall { + c.Call = c.Call.DoAndReturn(f) + return c } // NextEvent mocks base method. @@ -62,7 +87,31 @@ func (m *MockCryptoDataHandler) NextEvent() handshake.Event { } // NextEvent indicates an expected call of NextEvent. -func (mr *MockCryptoDataHandlerMockRecorder) NextEvent() *gomock.Call { +func (mr *MockCryptoDataHandlerMockRecorder) NextEvent() *MockCryptoDataHandlerNextEventCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoDataHandler)(nil).NextEvent)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextEvent", reflect.TypeOf((*MockCryptoDataHandler)(nil).NextEvent)) + return &MockCryptoDataHandlerNextEventCall{Call: call} +} + +// MockCryptoDataHandlerNextEventCall wrap *gomock.Call +type MockCryptoDataHandlerNextEventCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoDataHandlerNextEventCall) Return(arg0 handshake.Event) *MockCryptoDataHandlerNextEventCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoDataHandlerNextEventCall) Do(f func() handshake.Event) *MockCryptoDataHandlerNextEventCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoDataHandlerNextEventCall) DoAndReturn(f func() handshake.Event) *MockCryptoDataHandlerNextEventCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_crypto_stream_test.go b/mock_crypto_stream_test.go index cc01d35a7..09fe6d8a6 100644 --- a/mock_crypto_stream_test.go +++ b/mock_crypto_stream_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockCryptoStream) Finish() error { } // Finish indicates an expected call of Finish. -func (mr *MockCryptoStreamMockRecorder) Finish() *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) Finish() *MockCryptoStreamFinishCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Finish", reflect.TypeOf((*MockCryptoStream)(nil).Finish)) + return &MockCryptoStreamFinishCall{Call: call} +} + +// MockCryptoStreamFinishCall wrap *gomock.Call +type MockCryptoStreamFinishCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamFinishCall) Return(arg0 error) *MockCryptoStreamFinishCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamFinishCall) Do(f func() error) *MockCryptoStreamFinishCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamFinishCall) DoAndReturn(f func() error) *MockCryptoStreamFinishCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetCryptoData mocks base method. @@ -62,9 +87,33 @@ func (m *MockCryptoStream) GetCryptoData() []byte { } // GetCryptoData indicates an expected call of GetCryptoData. -func (mr *MockCryptoStreamMockRecorder) GetCryptoData() *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) GetCryptoData() *MockCryptoStreamGetCryptoDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoData", reflect.TypeOf((*MockCryptoStream)(nil).GetCryptoData)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCryptoData", reflect.TypeOf((*MockCryptoStream)(nil).GetCryptoData)) + return &MockCryptoStreamGetCryptoDataCall{Call: call} +} + +// MockCryptoStreamGetCryptoDataCall wrap *gomock.Call +type MockCryptoStreamGetCryptoDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamGetCryptoDataCall) Return(arg0 []byte) *MockCryptoStreamGetCryptoDataCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamGetCryptoDataCall) Do(f func() []byte) *MockCryptoStreamGetCryptoDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamGetCryptoDataCall) DoAndReturn(f func() []byte) *MockCryptoStreamGetCryptoDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandleCryptoFrame mocks base method. @@ -76,9 +125,33 @@ func (m *MockCryptoStream) HandleCryptoFrame(arg0 *wire.CryptoFrame) error { } // HandleCryptoFrame indicates an expected call of HandleCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 any) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) HandleCryptoFrame(arg0 any) *MockCryptoStreamHandleCryptoFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).HandleCryptoFrame), arg0) + return &MockCryptoStreamHandleCryptoFrameCall{Call: call} +} + +// MockCryptoStreamHandleCryptoFrameCall wrap *gomock.Call +type MockCryptoStreamHandleCryptoFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamHandleCryptoFrameCall) Return(arg0 error) *MockCryptoStreamHandleCryptoFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamHandleCryptoFrameCall) Do(f func(*wire.CryptoFrame) error) *MockCryptoStreamHandleCryptoFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamHandleCryptoFrameCall) DoAndReturn(f func(*wire.CryptoFrame) error) *MockCryptoStreamHandleCryptoFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HasData mocks base method. @@ -90,9 +163,33 @@ func (m *MockCryptoStream) HasData() bool { } // HasData indicates an expected call of HasData. -func (mr *MockCryptoStreamMockRecorder) HasData() *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) HasData() *MockCryptoStreamHasDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockCryptoStream)(nil).HasData)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockCryptoStream)(nil).HasData)) + return &MockCryptoStreamHasDataCall{Call: call} +} + +// MockCryptoStreamHasDataCall wrap *gomock.Call +type MockCryptoStreamHasDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamHasDataCall) Return(arg0 bool) *MockCryptoStreamHasDataCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamHasDataCall) Do(f func() bool) *MockCryptoStreamHasDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamHasDataCall) DoAndReturn(f func() bool) *MockCryptoStreamHasDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PopCryptoFrame mocks base method. @@ -104,9 +201,33 @@ func (m *MockCryptoStream) PopCryptoFrame(arg0 protocol.ByteCount) *wire.CryptoF } // PopCryptoFrame indicates an expected call of PopCryptoFrame. -func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 any) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) PopCryptoFrame(arg0 any) *MockCryptoStreamPopCryptoFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoFrame", reflect.TypeOf((*MockCryptoStream)(nil).PopCryptoFrame), arg0) + return &MockCryptoStreamPopCryptoFrameCall{Call: call} +} + +// MockCryptoStreamPopCryptoFrameCall wrap *gomock.Call +type MockCryptoStreamPopCryptoFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamPopCryptoFrameCall) Return(arg0 *wire.CryptoFrame) *MockCryptoStreamPopCryptoFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamPopCryptoFrameCall) Do(f func(protocol.ByteCount) *wire.CryptoFrame) *MockCryptoStreamPopCryptoFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamPopCryptoFrameCall) DoAndReturn(f func(protocol.ByteCount) *wire.CryptoFrame) *MockCryptoStreamPopCryptoFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Write mocks base method. @@ -119,7 +240,31 @@ func (m *MockCryptoStream) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockCryptoStreamMockRecorder) Write(arg0 any) *gomock.Call { +func (mr *MockCryptoStreamMockRecorder) Write(arg0 any) *MockCryptoStreamWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockCryptoStream)(nil).Write), arg0) + return &MockCryptoStreamWriteCall{Call: call} +} + +// MockCryptoStreamWriteCall wrap *gomock.Call +type MockCryptoStreamWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockCryptoStreamWriteCall) Return(arg0 int, arg1 error) *MockCryptoStreamWriteCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockCryptoStreamWriteCall) Do(f func([]byte) (int, error)) *MockCryptoStreamWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockCryptoStreamWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockCryptoStreamWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_frame_source_test.go b/mock_frame_source_test.go index 5454890f5..ca868145d 100644 --- a/mock_frame_source_test.go +++ b/mock_frame_source_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource // + // Package quic is a generated GoMock package. package quic @@ -40,7 +41,7 @@ func (m *MockFrameSource) EXPECT() *MockFrameSourceMockRecorder { } // AppendControlFrames mocks base method. -func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { +func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 protocol.ByteCount, arg2 protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendControlFrames", arg0, arg1, arg2) ret0, _ := ret[0].([]ackhandler.Frame) @@ -49,13 +50,37 @@ func (m *MockFrameSource) AppendControlFrames(arg0 []ackhandler.Frame, arg1 prot } // AppendControlFrames indicates an expected call of AppendControlFrames. -func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendControlFrames(arg0, arg1, arg2 any) *MockFrameSourceAppendControlFramesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendControlFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendControlFrames), arg0, arg1, arg2) + return &MockFrameSourceAppendControlFramesCall{Call: call} +} + +// MockFrameSourceAppendControlFramesCall wrap *gomock.Call +type MockFrameSourceAppendControlFramesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockFrameSourceAppendControlFramesCall) Return(arg0 []ackhandler.Frame, arg1 protocol.ByteCount) *MockFrameSourceAppendControlFramesCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockFrameSourceAppendControlFramesCall) Do(f func([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)) *MockFrameSourceAppendControlFramesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockFrameSourceAppendControlFramesCall) DoAndReturn(f func([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount)) *MockFrameSourceAppendControlFramesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AppendStreamFrames mocks base method. -func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { +func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 protocol.ByteCount, arg2 protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1, arg2) ret0, _ := ret[0].([]ackhandler.StreamFrame) @@ -64,9 +89,33 @@ func (m *MockFrameSource) AppendStreamFrames(arg0 []ackhandler.StreamFrame, arg1 } // AppendStreamFrames indicates an expected call of AppendStreamFrames. -func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1, arg2 any) *MockFrameSourceAppendStreamFramesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockFrameSource)(nil).AppendStreamFrames), arg0, arg1, arg2) + return &MockFrameSourceAppendStreamFramesCall{Call: call} +} + +// MockFrameSourceAppendStreamFramesCall wrap *gomock.Call +type MockFrameSourceAppendStreamFramesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockFrameSourceAppendStreamFramesCall) Return(arg0 []ackhandler.StreamFrame, arg1 protocol.ByteCount) *MockFrameSourceAppendStreamFramesCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockFrameSourceAppendStreamFramesCall) Do(f func([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)) *MockFrameSourceAppendStreamFramesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockFrameSourceAppendStreamFramesCall) DoAndReturn(f func([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount)) *MockFrameSourceAppendStreamFramesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HasData mocks base method. @@ -78,7 +127,31 @@ func (m *MockFrameSource) HasData() bool { } // HasData indicates an expected call of HasData. -func (mr *MockFrameSourceMockRecorder) HasData() *gomock.Call { +func (mr *MockFrameSourceMockRecorder) HasData() *MockFrameSourceHasDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasData", reflect.TypeOf((*MockFrameSource)(nil).HasData)) + return &MockFrameSourceHasDataCall{Call: call} +} + +// MockFrameSourceHasDataCall wrap *gomock.Call +type MockFrameSourceHasDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockFrameSourceHasDataCall) Return(arg0 bool) *MockFrameSourceHasDataCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockFrameSourceHasDataCall) Do(f func() bool) *MockFrameSourceHasDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockFrameSourceHasDataCall) DoAndReturn(f func() bool) *MockFrameSourceHasDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_mtu_discoverer_test.go b/mock_mtu_discoverer_test.go index 1af111f4f..821a551bd 100644 --- a/mock_mtu_discoverer_test.go +++ b/mock_mtu_discoverer_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer // + // Package quic is a generated GoMock package. package quic @@ -49,9 +50,33 @@ func (m *MockMTUDiscoverer) CurrentSize() protocol.ByteCount { } // CurrentSize indicates an expected call of CurrentSize. -func (mr *MockMTUDiscovererMockRecorder) CurrentSize() *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) CurrentSize() *MockMTUDiscovererCurrentSizeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentSize", reflect.TypeOf((*MockMTUDiscoverer)(nil).CurrentSize)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentSize", reflect.TypeOf((*MockMTUDiscoverer)(nil).CurrentSize)) + return &MockMTUDiscovererCurrentSizeCall{Call: call} +} + +// MockMTUDiscovererCurrentSizeCall wrap *gomock.Call +type MockMTUDiscovererCurrentSizeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMTUDiscovererCurrentSizeCall) Return(arg0 protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMTUDiscovererCurrentSizeCall) Do(f func() protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMTUDiscovererCurrentSizeCall) DoAndReturn(f func() protocol.ByteCount) *MockMTUDiscovererCurrentSizeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetPing mocks base method. @@ -64,9 +89,33 @@ func (m *MockMTUDiscoverer) GetPing() (ackhandler.Frame, protocol.ByteCount) { } // GetPing indicates an expected call of GetPing. -func (mr *MockMTUDiscovererMockRecorder) GetPing() *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) GetPing() *MockMTUDiscovererGetPingCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMTUDiscoverer)(nil).GetPing)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPing", reflect.TypeOf((*MockMTUDiscoverer)(nil).GetPing)) + return &MockMTUDiscovererGetPingCall{Call: call} +} + +// MockMTUDiscovererGetPingCall wrap *gomock.Call +type MockMTUDiscovererGetPingCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMTUDiscovererGetPingCall) Return(arg0 ackhandler.Frame, arg1 protocol.ByteCount) *MockMTUDiscovererGetPingCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMTUDiscovererGetPingCall) Do(f func() (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMTUDiscovererGetPingCall) DoAndReturn(f func() (ackhandler.Frame, protocol.ByteCount)) *MockMTUDiscovererGetPingCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ShouldSendProbe mocks base method. @@ -78,9 +127,33 @@ func (m *MockMTUDiscoverer) ShouldSendProbe(arg0 time.Time) bool { } // ShouldSendProbe indicates an expected call of ShouldSendProbe. -func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 any) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) ShouldSendProbe(arg0 any) *MockMTUDiscovererShouldSendProbeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMTUDiscoverer)(nil).ShouldSendProbe), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldSendProbe", reflect.TypeOf((*MockMTUDiscoverer)(nil).ShouldSendProbe), arg0) + return &MockMTUDiscovererShouldSendProbeCall{Call: call} +} + +// MockMTUDiscovererShouldSendProbeCall wrap *gomock.Call +type MockMTUDiscovererShouldSendProbeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMTUDiscovererShouldSendProbeCall) Return(arg0 bool) *MockMTUDiscovererShouldSendProbeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMTUDiscovererShouldSendProbeCall) Do(f func(time.Time) bool) *MockMTUDiscovererShouldSendProbeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMTUDiscovererShouldSendProbeCall) DoAndReturn(f func(time.Time) bool) *MockMTUDiscovererShouldSendProbeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Start mocks base method. @@ -90,7 +163,31 @@ func (m *MockMTUDiscoverer) Start(arg0 protocol.ByteCount) { } // Start indicates an expected call of Start. -func (mr *MockMTUDiscovererMockRecorder) Start(arg0 any) *gomock.Call { +func (mr *MockMTUDiscovererMockRecorder) Start(arg0 any) *MockMTUDiscovererStartCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockMTUDiscoverer)(nil).Start), arg0) + return &MockMTUDiscovererStartCall{Call: call} +} + +// MockMTUDiscovererStartCall wrap *gomock.Call +type MockMTUDiscovererStartCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMTUDiscovererStartCall) Return() *MockMTUDiscovererStartCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMTUDiscovererStartCall) Do(f func(protocol.ByteCount)) *MockMTUDiscovererStartCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMTUDiscovererStartCall) DoAndReturn(f func(protocol.ByteCount)) *MockMTUDiscovererStartCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_packer_test.go b/mock_packer_test.go index f73f6f5c7..6062f96a2 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer // + // Package quic is a generated GoMock package. package quic @@ -41,7 +42,7 @@ func (m *MockPacker) EXPECT() *MockPackerMockRecorder { } // AppendPacket mocks base method. -func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (shortHeaderPacket, error) { +func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, arg2 protocol.Version) (shortHeaderPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendPacket", arg0, arg1, arg2) ret0, _ := ret[0].(shortHeaderPacket) @@ -50,13 +51,37 @@ func (m *MockPacker) AppendPacket(arg0 *packetBuffer, arg1 protocol.ByteCount, a } // AppendPacket indicates an expected call of AppendPacket. -func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) AppendPacket(arg0, arg1, arg2 any) *MockPackerAppendPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendPacket", reflect.TypeOf((*MockPacker)(nil).AppendPacket), arg0, arg1, arg2) + return &MockPackerAppendPacketCall{Call: call} +} + +// MockPackerAppendPacketCall wrap *gomock.Call +type MockPackerAppendPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerAppendPacketCall) Return(arg0 shortHeaderPacket, arg1 error) *MockPackerAppendPacketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerAppendPacketCall) Do(f func(*packetBuffer, protocol.ByteCount, protocol.Version) (shortHeaderPacket, error)) *MockPackerAppendPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerAppendPacketCall) DoAndReturn(f func(*packetBuffer, protocol.ByteCount, protocol.Version) (shortHeaderPacket, error)) *MockPackerAppendPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // MaybePackProbePacket mocks base method. -func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) @@ -65,13 +90,37 @@ func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel, arg1 pr } // MaybePackProbePacket indicates an expected call of MaybePackProbePacket. -func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0, arg1, arg2 any) *MockPackerMaybePackProbePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackProbePacket", reflect.TypeOf((*MockPacker)(nil).MaybePackProbePacket), arg0, arg1, arg2) + return &MockPackerMaybePackProbePacketCall{Call: call} +} + +// MockPackerMaybePackProbePacketCall wrap *gomock.Call +type MockPackerMaybePackProbePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerMaybePackProbePacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerMaybePackProbePacketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerMaybePackProbePacketCall) Do(f func(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackProbePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerMaybePackProbePacketCall) DoAndReturn(f func(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerMaybePackProbePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PackAckOnlyPacket mocks base method. -func (m *MockPacker) PackAckOnlyPacket(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (m *MockPacker) PackAckOnlyPacket(arg0 protocol.ByteCount, arg1 protocol.Version) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackAckOnlyPacket", arg0, arg1) ret0, _ := ret[0].(shortHeaderPacket) @@ -81,13 +130,37 @@ func (m *MockPacker) PackAckOnlyPacket(arg0 protocol.ByteCount, arg1 protocol.Ve } // PackAckOnlyPacket indicates an expected call of PackAckOnlyPacket. -func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 any) *gomock.Call { +func (mr *MockPackerMockRecorder) PackAckOnlyPacket(arg0, arg1 any) *MockPackerPackAckOnlyPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAckOnlyPacket", reflect.TypeOf((*MockPacker)(nil).PackAckOnlyPacket), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackAckOnlyPacket", reflect.TypeOf((*MockPacker)(nil).PackAckOnlyPacket), arg0, arg1) + return &MockPackerPackAckOnlyPacketCall{Call: call} +} + +// MockPackerPackAckOnlyPacketCall wrap *gomock.Call +type MockPackerPackAckOnlyPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackAckOnlyPacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackAckOnlyPacketCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackAckOnlyPacketCall) Do(f func(protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackAckOnlyPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackAckOnlyPacketCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackAckOnlyPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PackApplicationClose mocks base method. -func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackApplicationClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) @@ -96,13 +169,37 @@ func (m *MockPacker) PackApplicationClose(arg0 *qerr.ApplicationError, arg1 prot } // PackApplicationClose indicates an expected call of PackApplicationClose. -func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) PackApplicationClose(arg0, arg1, arg2 any) *MockPackerPackApplicationCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackApplicationClose", reflect.TypeOf((*MockPacker)(nil).PackApplicationClose), arg0, arg1, arg2) + return &MockPackerPackApplicationCloseCall{Call: call} +} + +// MockPackerPackApplicationCloseCall wrap *gomock.Call +type MockPackerPackApplicationCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackApplicationCloseCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackApplicationCloseCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackApplicationCloseCall) Do(f func(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackApplicationCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackApplicationCloseCall) DoAndReturn(f func(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackApplicationCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PackCoalescedPacket mocks base method. -func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackCoalescedPacket", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) @@ -111,13 +208,37 @@ func (m *MockPacker) PackCoalescedPacket(arg0 bool, arg1 protocol.ByteCount, arg } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0, arg1, arg2 any) *MockPackerPackCoalescedPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0, arg1, arg2) + return &MockPackerPackCoalescedPacketCall{Call: call} +} + +// MockPackerPackCoalescedPacketCall wrap *gomock.Call +type MockPackerPackCoalescedPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackCoalescedPacketCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackCoalescedPacketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackCoalescedPacketCall) Do(f func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackCoalescedPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackCoalescedPacketCall) DoAndReturn(f func(bool, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackCoalescedPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PackConnectionClose mocks base method. -func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (*coalescedPacket, error) { +func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protocol.ByteCount, arg2 protocol.Version) (*coalescedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackConnectionClose", arg0, arg1, arg2) ret0, _ := ret[0].(*coalescedPacket) @@ -126,13 +247,37 @@ func (m *MockPacker) PackConnectionClose(arg0 *qerr.TransportError, arg1 protoco } // PackConnectionClose indicates an expected call of PackConnectionClose. -func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0, arg1, arg2 any) *MockPackerPackConnectionCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0, arg1, arg2) + return &MockPackerPackConnectionCloseCall{Call: call} +} + +// MockPackerPackConnectionCloseCall wrap *gomock.Call +type MockPackerPackConnectionCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackConnectionCloseCall) Return(arg0 *coalescedPacket, arg1 error) *MockPackerPackConnectionCloseCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackConnectionCloseCall) Do(f func(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackConnectionCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackConnectionCloseCall) DoAndReturn(f func(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error)) *MockPackerPackConnectionCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // PackMTUProbePacket mocks base method. -func (m *MockPacker) PackMTUProbePacket(arg0 ackhandler.Frame, arg1 protocol.ByteCount, arg2 protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (m *MockPacker) PackMTUProbePacket(arg0 ackhandler.Frame, arg1 protocol.ByteCount, arg2 protocol.Version) (shortHeaderPacket, *packetBuffer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PackMTUProbePacket", arg0, arg1, arg2) ret0, _ := ret[0].(shortHeaderPacket) @@ -142,9 +287,33 @@ func (m *MockPacker) PackMTUProbePacket(arg0 ackhandler.Frame, arg1 protocol.Byt } // PackMTUProbePacket indicates an expected call of PackMTUProbePacket. -func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPackerMockRecorder) PackMTUProbePacket(arg0, arg1, arg2 any) *MockPackerPackMTUProbePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), arg0, arg1, arg2) + return &MockPackerPackMTUProbePacketCall{Call: call} +} + +// MockPackerPackMTUProbePacketCall wrap *gomock.Call +type MockPackerPackMTUProbePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerPackMTUProbePacketCall) Return(arg0 shortHeaderPacket, arg1 *packetBuffer, arg2 error) *MockPackerPackMTUProbePacketCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerPackMTUProbePacketCall) Do(f func(ackhandler.Frame, protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackMTUProbePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerPackMTUProbePacketCall) DoAndReturn(f func(ackhandler.Frame, protocol.ByteCount, protocol.Version) (shortHeaderPacket, *packetBuffer, error)) *MockPackerPackMTUProbePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetToken mocks base method. @@ -154,7 +323,31 @@ func (m *MockPacker) SetToken(arg0 []byte) { } // SetToken indicates an expected call of SetToken. -func (mr *MockPackerMockRecorder) SetToken(arg0 any) *gomock.Call { +func (mr *MockPackerMockRecorder) SetToken(arg0 any) *MockPackerSetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetToken", reflect.TypeOf((*MockPacker)(nil).SetToken), arg0) + return &MockPackerSetTokenCall{Call: call} +} + +// MockPackerSetTokenCall wrap *gomock.Call +type MockPackerSetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPackerSetTokenCall) Return() *MockPackerSetTokenCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPackerSetTokenCall) Do(f func([]byte)) *MockPackerSetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPackerSetTokenCall) DoAndReturn(f func([]byte)) *MockPackerSetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index ea49f66bc..bad6cb434 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager // + // Package quic is a generated GoMock package. package quic @@ -47,9 +48,33 @@ func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHa } // Add indicates an expected call of Add. -func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 any) *MockPacketHandlerManagerAddCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) + return &MockPacketHandlerManagerAddCall{Call: call} +} + +// MockPacketHandlerManagerAddCall wrap *gomock.Call +type MockPacketHandlerManagerAddCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerAddCall) Return(arg0 bool) *MockPacketHandlerManagerAddCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerAddCall) Do(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerAddCall) DoAndReturn(f func(protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddResetToken mocks base method. @@ -59,13 +84,37 @@ func (m *MockPacketHandlerManager) AddResetToken(arg0 protocol.StatelessResetTok } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 any) *MockPacketHandlerManagerAddResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) + return &MockPacketHandlerManagerAddResetTokenCall{Call: call} +} + +// MockPacketHandlerManagerAddResetTokenCall wrap *gomock.Call +type MockPacketHandlerManagerAddResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerAddResetTokenCall) Return() *MockPacketHandlerManagerAddResetTokenCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerAddResetTokenCall) Do(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerAddResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken, packetHandler)) *MockPacketHandlerManagerAddResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AddWithConnID mocks base method. -func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() (packetHandler, bool)) bool { +func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 packetHandler) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) ret0, _ := ret[0].(bool) @@ -73,9 +122,33 @@ func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionI } // AddWithConnID indicates an expected call of AddWithConnID. -func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 any) *MockPacketHandlerManagerAddWithConnIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) + return &MockPacketHandlerManagerAddWithConnIDCall{Call: call} +} + +// MockPacketHandlerManagerAddWithConnIDCall wrap *gomock.Call +type MockPacketHandlerManagerAddWithConnIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerAddWithConnIDCall) Return(arg0 bool) *MockPacketHandlerManagerAddWithConnIDCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerAddWithConnIDCall) Do(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerAddWithConnIDCall) DoAndReturn(f func(protocol.ConnectionID, protocol.ConnectionID, packetHandler) bool) *MockPacketHandlerManagerAddWithConnIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -85,21 +158,33 @@ func (m *MockPacketHandlerManager) Close(arg0 error) { } // Close indicates an expected call of Close. -func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 any) *MockPacketHandlerManagerCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) + return &MockPacketHandlerManagerCloseCall{Call: call} } -// CloseServer mocks base method. -func (m *MockPacketHandlerManager) CloseServer() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseServer") +// MockPacketHandlerManagerCloseCall wrap *gomock.Call +type MockPacketHandlerManagerCloseCall struct { + *gomock.Call } -// CloseServer indicates an expected call of CloseServer. -func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerCloseCall) Return() *MockPacketHandlerManagerCloseCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerCloseCall) Do(f func(error)) *MockPacketHandlerManagerCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerCloseCall) DoAndReturn(f func(error)) *MockPacketHandlerManagerCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get mocks base method. @@ -112,9 +197,33 @@ func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandle } // Get indicates an expected call of Get. -func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 any) *MockPacketHandlerManagerGetCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) + return &MockPacketHandlerManagerGetCall{Call: call} +} + +// MockPacketHandlerManagerGetCall wrap *gomock.Call +type MockPacketHandlerManagerGetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerGetCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerGetCall) Do(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerGetCall) DoAndReturn(f func(protocol.ConnectionID) (packetHandler, bool)) *MockPacketHandlerManagerGetCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetByResetToken mocks base method. @@ -127,9 +236,33 @@ func (m *MockPacketHandlerManager) GetByResetToken(arg0 protocol.StatelessResetT } // GetByResetToken indicates an expected call of GetByResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetByResetToken(arg0 any) *MockPacketHandlerManagerGetByResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetByResetToken), arg0) + return &MockPacketHandlerManagerGetByResetTokenCall{Call: call} +} + +// MockPacketHandlerManagerGetByResetTokenCall wrap *gomock.Call +type MockPacketHandlerManagerGetByResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerGetByResetTokenCall) Return(arg0 packetHandler, arg1 bool) *MockPacketHandlerManagerGetByResetTokenCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerGetByResetTokenCall) Do(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerGetByResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken) (packetHandler, bool)) *MockPacketHandlerManagerGetByResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetStatelessResetToken mocks base method. @@ -141,9 +274,33 @@ func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.Connecti } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 any) *MockPacketHandlerManagerGetStatelessResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) + return &MockPacketHandlerManagerGetStatelessResetTokenCall{Call: call} +} + +// MockPacketHandlerManagerGetStatelessResetTokenCall wrap *gomock.Call +type MockPacketHandlerManagerGetStatelessResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Return(arg0 protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) Do(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerGetStatelessResetTokenCall) DoAndReturn(f func(protocol.ConnectionID) protocol.StatelessResetToken) *MockPacketHandlerManagerGetStatelessResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Remove mocks base method. @@ -153,9 +310,33 @@ func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { } // Remove indicates an expected call of Remove. -func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 any) *MockPacketHandlerManagerRemoveCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) + return &MockPacketHandlerManagerRemoveCall{Call: call} +} + +// MockPacketHandlerManagerRemoveCall wrap *gomock.Call +type MockPacketHandlerManagerRemoveCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerRemoveCall) Return() *MockPacketHandlerManagerRemoveCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerRemoveCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerRemoveCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRemoveCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RemoveResetToken mocks base method. @@ -165,21 +346,69 @@ func (m *MockPacketHandlerManager) RemoveResetToken(arg0 protocol.StatelessReset } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 any) *MockPacketHandlerManagerRemoveResetTokenCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) + return &MockPacketHandlerManagerRemoveResetTokenCall{Call: call} +} + +// MockPacketHandlerManagerRemoveResetTokenCall wrap *gomock.Call +type MockPacketHandlerManagerRemoveResetTokenCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerRemoveResetTokenCall) Return() *MockPacketHandlerManagerRemoveResetTokenCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerRemoveResetTokenCall) Do(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerRemoveResetTokenCall) DoAndReturn(f func(protocol.StatelessResetToken)) *MockPacketHandlerManagerRemoveResetTokenCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 any) *MockPacketHandlerManagerReplaceWithClosedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) + return &MockPacketHandlerManagerReplaceWithClosedCall{Call: call} +} + +// MockPacketHandlerManagerReplaceWithClosedCall wrap *gomock.Call +type MockPacketHandlerManagerReplaceWithClosedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerReplaceWithClosedCall) Return() *MockPacketHandlerManagerReplaceWithClosedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerReplaceWithClosedCall) Do(f func([]protocol.ConnectionID, []byte)) *MockPacketHandlerManagerReplaceWithClosedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerReplaceWithClosedCall) DoAndReturn(f func([]protocol.ConnectionID, []byte)) *MockPacketHandlerManagerReplaceWithClosedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Retire mocks base method. @@ -189,7 +418,31 @@ func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { } // Retire indicates an expected call of Retire. -func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) Retire(arg0 any) *MockPacketHandlerManagerRetireCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockPacketHandlerManager)(nil).Retire), arg0) + return &MockPacketHandlerManagerRetireCall{Call: call} +} + +// MockPacketHandlerManagerRetireCall wrap *gomock.Call +type MockPacketHandlerManagerRetireCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerManagerRetireCall) Return() *MockPacketHandlerManagerRetireCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerManagerRetireCall) Do(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerManagerRetireCall) DoAndReturn(f func(protocol.ConnectionID)) *MockPacketHandlerManagerRetireCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 852d1f004..14af1fe0b 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -3,16 +3,17 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler // + // Package quic is a generated GoMock package. package quic import ( reflect "reflect" + qerr "github.com/refraction-networking/uquic/internal/qerr" gomock "go.uber.org/mock/gomock" - protocol "github.com/refraction-networking/uquic/internal/protocol" ) // MockPacketHandler is a mock of PacketHandler interface. @@ -38,6 +39,42 @@ func (m *MockPacketHandler) EXPECT() *MockPacketHandlerMockRecorder { return m.recorder } +// closeWithTransportError mocks base method. +func (m *MockPacketHandler) closeWithTransportError(arg0 qerr.TransportErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeWithTransportError", arg0) +} + +// closeWithTransportError indicates an expected call of closeWithTransportError. +func (mr *MockPacketHandlerMockRecorder) closeWithTransportError(arg0 any) *MockPacketHandlercloseWithTransportErrorCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockPacketHandler)(nil).closeWithTransportError), arg0) + return &MockPacketHandlercloseWithTransportErrorCall{Call: call} +} + +// MockPacketHandlercloseWithTransportErrorCall wrap *gomock.Call +type MockPacketHandlercloseWithTransportErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlercloseWithTransportErrorCall) Return() *MockPacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlercloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *MockPacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlercloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *MockPacketHandlercloseWithTransportErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // destroy mocks base method. func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.T.Helper() @@ -45,23 +82,33 @@ func (m *MockPacketHandler) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockPacketHandlerMockRecorder) destroy(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) destroy(arg0 any) *MockPacketHandlerdestroyCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) + return &MockPacketHandlerdestroyCall{Call: call} } -// getPerspective mocks base method. -func (m *MockPacketHandler) getPerspective() protocol.Perspective { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 +// MockPacketHandlerdestroyCall wrap *gomock.Call +type MockPacketHandlerdestroyCall struct { + *gomock.Call } -// getPerspective indicates an expected call of getPerspective. -func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective)) +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerdestroyCall) Return() *MockPacketHandlerdestroyCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerdestroyCall) Do(f func(error)) *MockPacketHandlerdestroyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerdestroyCall) DoAndReturn(f func(error)) *MockPacketHandlerdestroyCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handlePacket mocks base method. @@ -71,19 +118,31 @@ func (m *MockPacketHandler) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 any) *gomock.Call { +func (mr *MockPacketHandlerMockRecorder) handlePacket(arg0 any) *MockPacketHandlerhandlePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockPacketHandler)(nil).handlePacket), arg0) + return &MockPacketHandlerhandlePacketCall{Call: call} } -// shutdown mocks base method. -func (m *MockPacketHandler) shutdown() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "shutdown") +// MockPacketHandlerhandlePacketCall wrap *gomock.Call +type MockPacketHandlerhandlePacketCall struct { + *gomock.Call } -// shutdown indicates an expected call of shutdown. -func (mr *MockPacketHandlerMockRecorder) shutdown() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockPacketHandler)(nil).shutdown)) +// Return rewrite *gomock.Call.Return +func (c *MockPacketHandlerhandlePacketCall) Return() *MockPacketHandlerhandlePacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketHandlerhandlePacketCall) Do(f func(receivedPacket)) *MockPacketHandlerhandlePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketHandlerhandlePacketCall) DoAndReturn(f func(receivedPacket)) *MockPacketHandlerhandlePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go index 6e317e3ed..45d36cbac 100644 --- a/mock_packetconn_test.go +++ b/mock_packetconn_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn +// mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockPacketConn) Close() error { } // Close indicates an expected call of Close. -func (mr *MockPacketConnMockRecorder) Close() *gomock.Call { +func (mr *MockPacketConnMockRecorder) Close() *MockPacketConnCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) + return &MockPacketConnCloseCall{Call: call} +} + +// MockPacketConnCloseCall wrap *gomock.Call +type MockPacketConnCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnCloseCall) Return(arg0 error) *MockPacketConnCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnCloseCall) Do(f func() error) *MockPacketConnCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnCloseCall) DoAndReturn(f func() error) *MockPacketConnCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LocalAddr mocks base method. @@ -62,9 +87,33 @@ func (m *MockPacketConn) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockPacketConnMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockPacketConnMockRecorder) LocalAddr() *MockPacketConnLocalAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) + return &MockPacketConnLocalAddrCall{Call: call} +} + +// MockPacketConnLocalAddrCall wrap *gomock.Call +type MockPacketConnLocalAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnLocalAddrCall) Return(arg0 net.Addr) *MockPacketConnLocalAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnLocalAddrCall) Do(f func() net.Addr) *MockPacketConnLocalAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockPacketConnLocalAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReadFrom mocks base method. @@ -78,9 +127,33 @@ func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { } // ReadFrom indicates an expected call of ReadFrom. -func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 any) *gomock.Call { +func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 any) *MockPacketConnReadFromCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) + return &MockPacketConnReadFromCall{Call: call} +} + +// MockPacketConnReadFromCall wrap *gomock.Call +type MockPacketConnReadFromCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnReadFromCall) Return(arg0 int, arg1 net.Addr, arg2 error) *MockPacketConnReadFromCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnReadFromCall) Do(f func([]byte) (int, net.Addr, error)) *MockPacketConnReadFromCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnReadFromCall) DoAndReturn(f func([]byte) (int, net.Addr, error)) *MockPacketConnReadFromCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetDeadline mocks base method. @@ -92,9 +165,33 @@ func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 any) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 any) *MockPacketConnSetDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) + return &MockPacketConnSetDeadlineCall{Call: call} +} + +// MockPacketConnSetDeadlineCall wrap *gomock.Call +type MockPacketConnSetDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnSetDeadlineCall) Return(arg0 error) *MockPacketConnSetDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnSetDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnSetDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetReadDeadline mocks base method. @@ -106,9 +203,33 @@ func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 any) *MockPacketConnSetReadDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) + return &MockPacketConnSetReadDeadlineCall{Call: call} +} + +// MockPacketConnSetReadDeadlineCall wrap *gomock.Call +type MockPacketConnSetReadDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnSetReadDeadlineCall) Return(arg0 error) *MockPacketConnSetReadDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnSetReadDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetReadDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetReadDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetWriteDeadline mocks base method. @@ -120,9 +241,33 @@ func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { +func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 any) *MockPacketConnSetWriteDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) + return &MockPacketConnSetWriteDeadlineCall{Call: call} +} + +// MockPacketConnSetWriteDeadlineCall wrap *gomock.Call +type MockPacketConnSetWriteDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnSetWriteDeadlineCall) Return(arg0 error) *MockPacketConnSetWriteDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnSetWriteDeadlineCall) Do(f func(time.Time) error) *MockPacketConnSetWriteDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnSetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockPacketConnSetWriteDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // WriteTo mocks base method. @@ -135,7 +280,31 @@ func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { } // WriteTo indicates an expected call of WriteTo. -func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 any) *gomock.Call { +func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 any) *MockPacketConnWriteToCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) + return &MockPacketConnWriteToCall{Call: call} +} + +// MockPacketConnWriteToCall wrap *gomock.Call +type MockPacketConnWriteToCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPacketConnWriteToCall) Return(arg0 int, arg1 error) *MockPacketConnWriteToCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPacketConnWriteToCall) Do(f func([]byte, net.Addr) (int, error)) *MockPacketConnWriteToCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPacketConnWriteToCall) DoAndReturn(f func([]byte, net.Addr) (int, error)) *MockPacketConnWriteToCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index fc7f1efaa..6b4ba0b0d 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn // + // Package quic is a generated GoMock package. package quic @@ -13,9 +14,8 @@ import ( net "net" reflect "reflect" - gomock "go.uber.org/mock/gomock" - protocol "github.com/refraction-networking/uquic/internal/protocol" qerr "github.com/refraction-networking/uquic/internal/qerr" + gomock "go.uber.org/mock/gomock" ) // MockQUICConn is a mock of QUICConn interface. @@ -51,9 +51,33 @@ func (m *MockQUICConn) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptStream(arg0 any) *MockQUICConnAcceptStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptStream), arg0) + return &MockQUICConnAcceptStreamCall{Call: call} +} + +// MockQUICConnAcceptStreamCall wrap *gomock.Call +type MockQUICConnAcceptStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnAcceptStreamCall) Return(arg0 Stream, arg1 error) *MockQUICConnAcceptStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnAcceptStreamCall) Do(f func(context.Context) (Stream, error)) *MockQUICConnAcceptStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnAcceptStreamCall) DoAndReturn(f func(context.Context) (Stream, error)) *MockQUICConnAcceptStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AcceptUniStream mocks base method. @@ -66,9 +90,33 @@ func (m *MockQUICConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, err } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) AcceptUniStream(arg0 any) *MockQUICConnAcceptUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptUniStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQUICConn)(nil).AcceptUniStream), arg0) + return &MockQUICConnAcceptUniStreamCall{Call: call} +} + +// MockQUICConnAcceptUniStreamCall wrap *gomock.Call +type MockQUICConnAcceptUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnAcceptUniStreamCall) Return(arg0 ReceiveStream, arg1 error) *MockQUICConnAcceptUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnAcceptUniStreamCall) Do(f func(context.Context) (ReceiveStream, error)) *MockQUICConnAcceptUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnAcceptUniStreamCall) DoAndReturn(f func(context.Context) (ReceiveStream, error)) *MockQUICConnAcceptUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // CloseWithError mocks base method. @@ -80,9 +128,33 @@ func (m *MockQUICConn) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 strin } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) CloseWithError(arg0, arg1 any) *MockQUICConnCloseWithErrorCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQUICConn)(nil).CloseWithError), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQUICConn)(nil).CloseWithError), arg0, arg1) + return &MockQUICConnCloseWithErrorCall{Call: call} +} + +// MockQUICConnCloseWithErrorCall wrap *gomock.Call +type MockQUICConnCloseWithErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnCloseWithErrorCall) Return(arg0 error) *MockQUICConnCloseWithErrorCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnCloseWithErrorCall) Do(f func(qerr.ApplicationErrorCode, string) error) *MockQUICConnCloseWithErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnCloseWithErrorCall) DoAndReturn(f func(qerr.ApplicationErrorCode, string) error) *MockQUICConnCloseWithErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ConnectionState mocks base method. @@ -94,9 +166,33 @@ func (m *MockQUICConn) ConnectionState() ConnectionState { } // ConnectionState indicates an expected call of ConnectionState. -func (mr *MockQUICConnMockRecorder) ConnectionState() *gomock.Call { +func (mr *MockQUICConnMockRecorder) ConnectionState() *MockQUICConnConnectionStateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQUICConn)(nil).ConnectionState)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQUICConn)(nil).ConnectionState)) + return &MockQUICConnConnectionStateCall{Call: call} +} + +// MockQUICConnConnectionStateCall wrap *gomock.Call +type MockQUICConnConnectionStateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnConnectionStateCall) Return(arg0 ConnectionState) *MockQUICConnConnectionStateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnConnectionStateCall) Do(f func() ConnectionState) *MockQUICConnConnectionStateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnConnectionStateCall) DoAndReturn(f func() ConnectionState) *MockQUICConnConnectionStateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Context mocks base method. @@ -108,23 +204,33 @@ func (m *MockQUICConn) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockQUICConnMockRecorder) Context() *gomock.Call { +func (mr *MockQUICConnMockRecorder) Context() *MockQUICConnContextCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQUICConn)(nil).Context)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQUICConn)(nil).Context)) + return &MockQUICConnContextCall{Call: call} } -// GetVersion mocks base method. -func (m *MockQUICConn) GetVersion() protocol.VersionNumber { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetVersion") - ret0, _ := ret[0].(protocol.VersionNumber) - return ret0 +// MockQUICConnContextCall wrap *gomock.Call +type MockQUICConnContextCall struct { + *gomock.Call } -// GetVersion indicates an expected call of GetVersion. -func (mr *MockQUICConnMockRecorder) GetVersion() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQUICConn)(nil).GetVersion)) +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnContextCall) Return(arg0 context.Context) *MockQUICConnContextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnContextCall) Do(f func() context.Context) *MockQUICConnContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnContextCall) DoAndReturn(f func() context.Context) *MockQUICConnContextCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandshakeComplete mocks base method. @@ -136,9 +242,33 @@ func (m *MockQUICConn) HandshakeComplete() <-chan struct{} { } // HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockQUICConnMockRecorder) HandshakeComplete() *gomock.Call { +func (mr *MockQUICConnMockRecorder) HandshakeComplete() *MockQUICConnHandshakeCompleteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQUICConn)(nil).HandshakeComplete)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQUICConn)(nil).HandshakeComplete)) + return &MockQUICConnHandshakeCompleteCall{Call: call} +} + +// MockQUICConnHandshakeCompleteCall wrap *gomock.Call +type MockQUICConnHandshakeCompleteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnHandshakeCompleteCall) Return(arg0 <-chan struct{}) *MockQUICConnHandshakeCompleteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnHandshakeCompleteCall) Do(f func() <-chan struct{}) *MockQUICConnHandshakeCompleteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnHandshakeCompleteCall) DoAndReturn(f func() <-chan struct{}) *MockQUICConnHandshakeCompleteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LocalAddr mocks base method. @@ -150,9 +280,33 @@ func (m *MockQUICConn) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockQUICConnMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockQUICConnMockRecorder) LocalAddr() *MockQUICConnLocalAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQUICConn)(nil).LocalAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQUICConn)(nil).LocalAddr)) + return &MockQUICConnLocalAddrCall{Call: call} +} + +// MockQUICConnLocalAddrCall wrap *gomock.Call +type MockQUICConnLocalAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnLocalAddrCall) Return(arg0 net.Addr) *MockQUICConnLocalAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnLocalAddrCall) Do(f func() net.Addr) *MockQUICConnLocalAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConnLocalAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // NextConnection mocks base method. @@ -164,9 +318,33 @@ func (m *MockQUICConn) NextConnection() Connection { } // NextConnection indicates an expected call of NextConnection. -func (mr *MockQUICConnMockRecorder) NextConnection() *gomock.Call { +func (mr *MockQUICConnMockRecorder) NextConnection() *MockQUICConnNextConnectionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQUICConn)(nil).NextConnection)) + return &MockQUICConnNextConnectionCall{Call: call} +} + +// MockQUICConnNextConnectionCall wrap *gomock.Call +type MockQUICConnNextConnectionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnNextConnectionCall) Return(arg0 Connection) *MockQUICConnNextConnectionCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnNextConnectionCall) Do(f func() Connection) *MockQUICConnNextConnectionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnNextConnectionCall) DoAndReturn(f func() Connection) *MockQUICConnNextConnectionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStream mocks base method. @@ -179,9 +357,33 @@ func (m *MockQUICConn) OpenStream() (Stream, error) { } // OpenStream indicates an expected call of OpenStream. -func (mr *MockQUICConnMockRecorder) OpenStream() *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenStream() *MockQUICConnOpenStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQUICConn)(nil).OpenStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQUICConn)(nil).OpenStream)) + return &MockQUICConnOpenStreamCall{Call: call} +} + +// MockQUICConnOpenStreamCall wrap *gomock.Call +type MockQUICConnOpenStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnOpenStreamCall) Return(arg0 Stream, arg1 error) *MockQUICConnOpenStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnOpenStreamCall) Do(f func() (Stream, error)) *MockQUICConnOpenStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnOpenStreamCall) DoAndReturn(f func() (Stream, error)) *MockQUICConnOpenStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStreamSync mocks base method. @@ -194,9 +396,33 @@ func (m *MockQUICConn) OpenStreamSync(arg0 context.Context) (Stream, error) { } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenStreamSync(arg0 any) *MockQUICConnOpenStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenStreamSync), arg0) + return &MockQUICConnOpenStreamSyncCall{Call: call} +} + +// MockQUICConnOpenStreamSyncCall wrap *gomock.Call +type MockQUICConnOpenStreamSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnOpenStreamSyncCall) Return(arg0 Stream, arg1 error) *MockQUICConnOpenStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnOpenStreamSyncCall) Do(f func(context.Context) (Stream, error)) *MockQUICConnOpenStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnOpenStreamSyncCall) DoAndReturn(f func(context.Context) (Stream, error)) *MockQUICConnOpenStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStream mocks base method. @@ -209,9 +435,33 @@ func (m *MockQUICConn) OpenUniStream() (SendStream, error) { } // OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockQUICConnMockRecorder) OpenUniStream() *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenUniStream() *MockQUICConnOpenUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStream)) + return &MockQUICConnOpenUniStreamCall{Call: call} +} + +// MockQUICConnOpenUniStreamCall wrap *gomock.Call +type MockQUICConnOpenUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnOpenUniStreamCall) Return(arg0 SendStream, arg1 error) *MockQUICConnOpenUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnOpenUniStreamCall) Do(f func() (SendStream, error)) *MockQUICConnOpenUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnOpenUniStreamCall) DoAndReturn(f func() (SendStream, error)) *MockQUICConnOpenUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStreamSync mocks base method. @@ -224,24 +474,72 @@ func (m *MockQUICConn) OpenUniStreamSync(arg0 context.Context) (SendStream, erro } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) OpenUniStreamSync(arg0 any) *MockQUICConnOpenUniStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQUICConn)(nil).OpenUniStreamSync), arg0) + return &MockQUICConnOpenUniStreamSyncCall{Call: call} +} + +// MockQUICConnOpenUniStreamSyncCall wrap *gomock.Call +type MockQUICConnOpenUniStreamSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnOpenUniStreamSyncCall) Return(arg0 SendStream, arg1 error) *MockQUICConnOpenUniStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c } -// ReceiveMessage mocks base method. -func (m *MockQUICConn) ReceiveMessage(arg0 context.Context) ([]byte, error) { +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnOpenUniStreamSyncCall) Do(f func(context.Context) (SendStream, error)) *MockQUICConnOpenUniStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (SendStream, error)) *MockQUICConnOpenUniStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// ReceiveDatagram mocks base method. +func (m *MockQUICConn) ReceiveDatagram(arg0 context.Context) ([]byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReceiveMessage", arg0) + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) ret0, _ := ret[0].([]byte) ret1, _ := ret[1].(error) return ret0, ret1 } -// ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQUICConnMockRecorder) ReceiveMessage(arg0 any) *gomock.Call { +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockQUICConnMockRecorder) ReceiveDatagram(arg0 any) *MockQUICConnReceiveDatagramCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQUICConn)(nil).ReceiveMessage), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockQUICConn)(nil).ReceiveDatagram), arg0) + return &MockQUICConnReceiveDatagramCall{Call: call} +} + +// MockQUICConnReceiveDatagramCall wrap *gomock.Call +type MockQUICConnReceiveDatagramCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnReceiveDatagramCall) Return(arg0 []byte, arg1 error) *MockQUICConnReceiveDatagramCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnReceiveDatagramCall) Do(f func(context.Context) ([]byte, error)) *MockQUICConnReceiveDatagramCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnReceiveDatagramCall) DoAndReturn(f func(context.Context) ([]byte, error)) *MockQUICConnReceiveDatagramCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RemoteAddr mocks base method. @@ -253,23 +551,107 @@ func (m *MockQUICConn) RemoteAddr() net.Addr { } // RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockQUICConnMockRecorder) RemoteAddr() *gomock.Call { +func (mr *MockQUICConnMockRecorder) RemoteAddr() *MockQUICConnRemoteAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQUICConn)(nil).RemoteAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQUICConn)(nil).RemoteAddr)) + return &MockQUICConnRemoteAddrCall{Call: call} +} + +// MockQUICConnRemoteAddrCall wrap *gomock.Call +type MockQUICConnRemoteAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnRemoteAddrCall) Return(arg0 net.Addr) *MockQUICConnRemoteAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnRemoteAddrCall) Do(f func() net.Addr) *MockQUICConnRemoteAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockQUICConnRemoteAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } -// SendMessage mocks base method. -func (m *MockQUICConn) SendMessage(arg0 []byte) error { +// SendDatagram mocks base method. +func (m *MockQUICConn) SendDatagram(arg0 []byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendMessage", arg0) + ret := m.ctrl.Call(m, "SendDatagram", arg0) ret0, _ := ret[0].(error) return ret0 } -// SendMessage indicates an expected call of SendMessage. -func (mr *MockQUICConnMockRecorder) SendMessage(arg0 any) *gomock.Call { +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockQUICConnMockRecorder) SendDatagram(arg0 any) *MockQUICConnSendDatagramCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockQUICConn)(nil).SendDatagram), arg0) + return &MockQUICConnSendDatagramCall{Call: call} +} + +// MockQUICConnSendDatagramCall wrap *gomock.Call +type MockQUICConnSendDatagramCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnSendDatagramCall) Return(arg0 error) *MockQUICConnSendDatagramCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnSendDatagramCall) Do(f func([]byte) error) *MockQUICConnSendDatagramCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnSendDatagramCall) DoAndReturn(f func([]byte) error) *MockQUICConnSendDatagramCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// closeWithTransportError mocks base method. +func (m *MockQUICConn) closeWithTransportError(arg0 qerr.TransportErrorCode) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "closeWithTransportError", arg0) +} + +// closeWithTransportError indicates an expected call of closeWithTransportError. +func (mr *MockQUICConnMockRecorder) closeWithTransportError(arg0 any) *MockQUICConncloseWithTransportErrorCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQUICConn)(nil).SendMessage), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithTransportError", reflect.TypeOf((*MockQUICConn)(nil).closeWithTransportError), arg0) + return &MockQUICConncloseWithTransportErrorCall{Call: call} +} + +// MockQUICConncloseWithTransportErrorCall wrap *gomock.Call +type MockQUICConncloseWithTransportErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConncloseWithTransportErrorCall) Return() *MockQUICConncloseWithTransportErrorCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConncloseWithTransportErrorCall) Do(f func(qerr.TransportErrorCode)) *MockQUICConncloseWithTransportErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConncloseWithTransportErrorCall) DoAndReturn(f func(qerr.TransportErrorCode)) *MockQUICConncloseWithTransportErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c } // destroy mocks base method. @@ -279,9 +661,33 @@ func (m *MockQUICConn) destroy(arg0 error) { } // destroy indicates an expected call of destroy. -func (mr *MockQUICConnMockRecorder) destroy(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) destroy(arg0 any) *MockQUICConndestroyCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQUICConn)(nil).destroy), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQUICConn)(nil).destroy), arg0) + return &MockQUICConndestroyCall{Call: call} +} + +// MockQUICConndestroyCall wrap *gomock.Call +type MockQUICConndestroyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConndestroyCall) Return() *MockQUICConndestroyCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConndestroyCall) Do(f func(error)) *MockQUICConndestroyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConndestroyCall) DoAndReturn(f func(error)) *MockQUICConndestroyCall { + c.Call = c.Call.DoAndReturn(f) + return c } // earlyConnReady mocks base method. @@ -293,23 +699,33 @@ func (m *MockQUICConn) earlyConnReady() <-chan struct{} { } // earlyConnReady indicates an expected call of earlyConnReady. -func (mr *MockQUICConnMockRecorder) earlyConnReady() *gomock.Call { +func (mr *MockQUICConnMockRecorder) earlyConnReady() *MockQUICConnearlyConnReadyCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlyConnReady", reflect.TypeOf((*MockQUICConn)(nil).earlyConnReady)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlyConnReady", reflect.TypeOf((*MockQUICConn)(nil).earlyConnReady)) + return &MockQUICConnearlyConnReadyCall{Call: call} } -// getPerspective mocks base method. -func (m *MockQUICConn) getPerspective() protocol.Perspective { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 +// MockQUICConnearlyConnReadyCall wrap *gomock.Call +type MockQUICConnearlyConnReadyCall struct { + *gomock.Call } -// getPerspective indicates an expected call of getPerspective. -func (mr *MockQUICConnMockRecorder) getPerspective() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQUICConn)(nil).getPerspective)) +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnearlyConnReadyCall) Return(arg0 <-chan struct{}) *MockQUICConnearlyConnReadyCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnearlyConnReadyCall) Do(f func() <-chan struct{}) *MockQUICConnearlyConnReadyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnearlyConnReadyCall) DoAndReturn(f func() <-chan struct{}) *MockQUICConnearlyConnReadyCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handlePacket mocks base method. @@ -319,9 +735,33 @@ func (m *MockQUICConn) handlePacket(arg0 receivedPacket) { } // handlePacket indicates an expected call of handlePacket. -func (mr *MockQUICConnMockRecorder) handlePacket(arg0 any) *gomock.Call { +func (mr *MockQUICConnMockRecorder) handlePacket(arg0 any) *MockQUICConnhandlePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQUICConn)(nil).handlePacket), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQUICConn)(nil).handlePacket), arg0) + return &MockQUICConnhandlePacketCall{Call: call} +} + +// MockQUICConnhandlePacketCall wrap *gomock.Call +type MockQUICConnhandlePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnhandlePacketCall) Return() *MockQUICConnhandlePacketCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnhandlePacketCall) Do(f func(receivedPacket)) *MockQUICConnhandlePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnhandlePacketCall) DoAndReturn(f func(receivedPacket)) *MockQUICConnhandlePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // run mocks base method. @@ -333,19 +773,31 @@ func (m *MockQUICConn) run() error { } // run indicates an expected call of run. -func (mr *MockQUICConnMockRecorder) run() *gomock.Call { +func (mr *MockQUICConnMockRecorder) run() *MockQUICConnrunCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQUICConn)(nil).run)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQUICConn)(nil).run)) + return &MockQUICConnrunCall{Call: call} } -// shutdown mocks base method. -func (m *MockQUICConn) shutdown() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "shutdown") +// MockQUICConnrunCall wrap *gomock.Call +type MockQUICConnrunCall struct { + *gomock.Call } -// shutdown indicates an expected call of shutdown. -func (mr *MockQUICConnMockRecorder) shutdown() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockQUICConn)(nil).shutdown)) +// Return rewrite *gomock.Call.Return +func (c *MockQUICConnrunCall) Return(arg0 error) *MockQUICConnrunCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockQUICConnrunCall) Do(f func() error) *MockQUICConnrunCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockQUICConnrunCall) DoAndReturn(f func() error) *MockQUICConnrunCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_raw_conn_test.go b/mock_raw_conn_test.go index f7862c62c..1c18491ff 100644 --- a/mock_raw_conn_test.go +++ b/mock_raw_conn_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_raw_conn_test.go github.com/refraction-networking/uquic RawConn +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn // + // Package quic is a generated GoMock package. package quic @@ -49,9 +50,33 @@ func (m *MockRawConn) Close() error { } // Close indicates an expected call of Close. -func (mr *MockRawConnMockRecorder) Close() *gomock.Call { +func (mr *MockRawConnMockRecorder) Close() *MockRawConnCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRawConn)(nil).Close)) + return &MockRawConnCloseCall{Call: call} +} + +// MockRawConnCloseCall wrap *gomock.Call +type MockRawConnCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConnCloseCall) Return(arg0 error) *MockRawConnCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConnCloseCall) Do(f func() error) *MockRawConnCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConnCloseCall) DoAndReturn(f func() error) *MockRawConnCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LocalAddr mocks base method. @@ -63,9 +88,33 @@ func (m *MockRawConn) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockRawConnMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockRawConnMockRecorder) LocalAddr() *MockRawConnLocalAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockRawConn)(nil).LocalAddr)) + return &MockRawConnLocalAddrCall{Call: call} +} + +// MockRawConnLocalAddrCall wrap *gomock.Call +type MockRawConnLocalAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConnLocalAddrCall) Return(arg0 net.Addr) *MockRawConnLocalAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConnLocalAddrCall) Do(f func() net.Addr) *MockRawConnLocalAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockRawConnLocalAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReadPacket mocks base method. @@ -78,9 +127,33 @@ func (m *MockRawConn) ReadPacket() (receivedPacket, error) { } // ReadPacket indicates an expected call of ReadPacket. -func (mr *MockRawConnMockRecorder) ReadPacket() *gomock.Call { +func (mr *MockRawConnMockRecorder) ReadPacket() *MockRawConnReadPacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadPacket", reflect.TypeOf((*MockRawConn)(nil).ReadPacket)) + return &MockRawConnReadPacketCall{Call: call} +} + +// MockRawConnReadPacketCall wrap *gomock.Call +type MockRawConnReadPacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConnReadPacketCall) Return(arg0 receivedPacket, arg1 error) *MockRawConnReadPacketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConnReadPacketCall) Do(f func() (receivedPacket, error)) *MockRawConnReadPacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConnReadPacketCall) DoAndReturn(f func() (receivedPacket, error)) *MockRawConnReadPacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetReadDeadline mocks base method. @@ -92,9 +165,33 @@ func (m *MockRawConn) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { +func (mr *MockRawConnMockRecorder) SetReadDeadline(arg0 any) *MockRawConnSetReadDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockRawConn)(nil).SetReadDeadline), arg0) + return &MockRawConnSetReadDeadlineCall{Call: call} +} + +// MockRawConnSetReadDeadlineCall wrap *gomock.Call +type MockRawConnSetReadDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConnSetReadDeadlineCall) Return(arg0 error) *MockRawConnSetReadDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConnSetReadDeadlineCall) Do(f func(time.Time) error) *MockRawConnSetReadDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConnSetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockRawConnSetReadDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // WritePacket mocks base method. @@ -107,9 +204,33 @@ func (m *MockRawConn) WritePacket(arg0 []byte, arg1 net.Addr, arg2 []byte, arg3 } // WritePacket indicates an expected call of WritePacket. -func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockRawConnMockRecorder) WritePacket(arg0, arg1, arg2, arg3, arg4 any) *MockRawConnWritePacketCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WritePacket", reflect.TypeOf((*MockRawConn)(nil).WritePacket), arg0, arg1, arg2, arg3, arg4) + return &MockRawConnWritePacketCall{Call: call} +} + +// MockRawConnWritePacketCall wrap *gomock.Call +type MockRawConnWritePacketCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConnWritePacketCall) Return(arg0 int, arg1 error) *MockRawConnWritePacketCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConnWritePacketCall) Do(f func([]byte, net.Addr, []byte, uint16, protocol.ECN) (int, error)) *MockRawConnWritePacketCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConnWritePacketCall) DoAndReturn(f func([]byte, net.Addr, []byte, uint16, protocol.ECN) (int, error)) *MockRawConnWritePacketCall { + c.Call = c.Call.DoAndReturn(f) + return c } // capabilities mocks base method. @@ -121,7 +242,31 @@ func (m *MockRawConn) capabilities() connCapabilities { } // capabilities indicates an expected call of capabilities. -func (mr *MockRawConnMockRecorder) capabilities() *gomock.Call { +func (mr *MockRawConnMockRecorder) capabilities() *MockRawConncapabilitiesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockRawConn)(nil).capabilities)) + return &MockRawConncapabilitiesCall{Call: call} +} + +// MockRawConncapabilitiesCall wrap *gomock.Call +type MockRawConncapabilitiesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRawConncapabilitiesCall) Return(arg0 connCapabilities) *MockRawConncapabilitiesCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRawConncapabilitiesCall) Do(f func() connCapabilities) *MockRawConncapabilitiesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRawConncapabilitiesCall) DoAndReturn(f func() connCapabilities) *MockRawConncapabilitiesCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_receive_stream_internal_test.go b/mock_receive_stream_internal_test.go index 48b9063c8..3644251e2 100644 --- a/mock_receive_stream_internal_test.go +++ b/mock_receive_stream_internal_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockReceiveStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) CancelRead(arg0 any) *MockReceiveStreamICancelReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockReceiveStreamI)(nil).CancelRead), arg0) + return &MockReceiveStreamICancelReadCall{Call: call} +} + +// MockReceiveStreamICancelReadCall wrap *gomock.Call +type MockReceiveStreamICancelReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamICancelReadCall) Return() *MockReceiveStreamICancelReadCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamICancelReadCall) Do(f func(qerr.StreamErrorCode)) *MockReceiveStreamICancelReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamICancelReadCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockReceiveStreamICancelReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Read mocks base method. @@ -63,9 +88,33 @@ func (m *MockReceiveStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockReceiveStreamIMockRecorder) Read(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) Read(arg0 any) *MockReceiveStreamIReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReceiveStreamI)(nil).Read), arg0) + return &MockReceiveStreamIReadCall{Call: call} +} + +// MockReceiveStreamIReadCall wrap *gomock.Call +type MockReceiveStreamIReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIReadCall) Return(arg0 int, arg1 error) *MockReceiveStreamIReadCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIReadCall) Do(f func([]byte) (int, error)) *MockReceiveStreamIReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIReadCall) DoAndReturn(f func([]byte) (int, error)) *MockReceiveStreamIReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetReadDeadline mocks base method. @@ -77,9 +126,33 @@ func (m *MockReceiveStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) SetReadDeadline(arg0 any) *MockReceiveStreamISetReadDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockReceiveStreamI)(nil).SetReadDeadline), arg0) + return &MockReceiveStreamISetReadDeadlineCall{Call: call} +} + +// MockReceiveStreamISetReadDeadlineCall wrap *gomock.Call +type MockReceiveStreamISetReadDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamISetReadDeadlineCall) Return(arg0 error) *MockReceiveStreamISetReadDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamISetReadDeadlineCall) Do(f func(time.Time) error) *MockReceiveStreamISetReadDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamISetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockReceiveStreamISetReadDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StreamID mocks base method. @@ -91,9 +164,33 @@ func (m *MockReceiveStreamI) StreamID() protocol.StreamID { } // StreamID indicates an expected call of StreamID. -func (mr *MockReceiveStreamIMockRecorder) StreamID() *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) StreamID() *MockReceiveStreamIStreamIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockReceiveStreamI)(nil).StreamID)) + return &MockReceiveStreamIStreamIDCall{Call: call} +} + +// MockReceiveStreamIStreamIDCall wrap *gomock.Call +type MockReceiveStreamIStreamIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIStreamIDCall) Return(arg0 protocol.StreamID) *MockReceiveStreamIStreamIDCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIStreamIDCall) Do(f func() protocol.StreamID) *MockReceiveStreamIStreamIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIStreamIDCall) DoAndReturn(f func() protocol.StreamID) *MockReceiveStreamIStreamIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // closeForShutdown mocks base method. @@ -103,9 +200,33 @@ func (m *MockReceiveStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) closeForShutdown(arg0 any) *MockReceiveStreamIcloseForShutdownCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockReceiveStreamI)(nil).closeForShutdown), arg0) + return &MockReceiveStreamIcloseForShutdownCall{Call: call} +} + +// MockReceiveStreamIcloseForShutdownCall wrap *gomock.Call +type MockReceiveStreamIcloseForShutdownCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIcloseForShutdownCall) Return() *MockReceiveStreamIcloseForShutdownCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIcloseForShutdownCall) Do(f func(error)) *MockReceiveStreamIcloseForShutdownCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *MockReceiveStreamIcloseForShutdownCall { + c.Call = c.Call.DoAndReturn(f) + return c } // getWindowUpdate mocks base method. @@ -117,9 +238,33 @@ func (m *MockReceiveStreamI) getWindowUpdate() protocol.ByteCount { } // getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) getWindowUpdate() *MockReceiveStreamIgetWindowUpdateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockReceiveStreamI)(nil).getWindowUpdate)) + return &MockReceiveStreamIgetWindowUpdateCall{Call: call} +} + +// MockReceiveStreamIgetWindowUpdateCall wrap *gomock.Call +type MockReceiveStreamIgetWindowUpdateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIgetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIgetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIgetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockReceiveStreamIgetWindowUpdateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleResetStreamFrame mocks base method. @@ -131,9 +276,33 @@ func (m *MockReceiveStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleResetStreamFrame(arg0 any) *MockReceiveStreamIhandleResetStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleResetStreamFrame), arg0) + return &MockReceiveStreamIhandleResetStreamFrameCall{Call: call} +} + +// MockReceiveStreamIhandleResetStreamFrameCall wrap *gomock.Call +type MockReceiveStreamIhandleResetStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIhandleResetStreamFrameCall) Return(arg0 error) *MockReceiveStreamIhandleResetStreamFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIhandleResetStreamFrameCall) Do(f func(*wire.ResetStreamFrame) error) *MockReceiveStreamIhandleResetStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIhandleResetStreamFrameCall) DoAndReturn(f func(*wire.ResetStreamFrame) error) *MockReceiveStreamIhandleResetStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleStreamFrame mocks base method. @@ -145,7 +314,31 @@ func (m *MockReceiveStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { +func (mr *MockReceiveStreamIMockRecorder) handleStreamFrame(arg0 any) *MockReceiveStreamIhandleStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockReceiveStreamI)(nil).handleStreamFrame), arg0) + return &MockReceiveStreamIhandleStreamFrameCall{Call: call} +} + +// MockReceiveStreamIhandleStreamFrameCall wrap *gomock.Call +type MockReceiveStreamIhandleStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockReceiveStreamIhandleStreamFrameCall) Return(arg0 error) *MockReceiveStreamIhandleStreamFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockReceiveStreamIhandleStreamFrameCall) Do(f func(*wire.StreamFrame) error) *MockReceiveStreamIhandleStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockReceiveStreamIhandleStreamFrameCall) DoAndReturn(f func(*wire.StreamFrame) error) *MockReceiveStreamIhandleStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index 6a691b3be..59c466848 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockSealingManager) Get0RTTSealer() (handshake.LongHeaderSealer, error) } // Get0RTTSealer indicates an expected call of Get0RTTSealer. -func (mr *MockSealingManagerMockRecorder) Get0RTTSealer() *gomock.Call { +func (mr *MockSealingManagerMockRecorder) Get0RTTSealer() *MockSealingManagerGet0RTTSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get0RTTSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get0RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get0RTTSealer)) + return &MockSealingManagerGet0RTTSealerCall{Call: call} +} + +// MockSealingManagerGet0RTTSealerCall wrap *gomock.Call +type MockSealingManagerGet0RTTSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSealingManagerGet0RTTSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGet0RTTSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSealingManagerGet0RTTSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGet0RTTSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSealingManagerGet0RTTSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGet0RTTSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Get1RTTSealer mocks base method. @@ -63,9 +88,33 @@ func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error } // Get1RTTSealer indicates an expected call of Get1RTTSealer. -func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *gomock.Call { +func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *MockSealingManagerGet1RTTSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get1RTTSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get1RTTSealer", reflect.TypeOf((*MockSealingManager)(nil).Get1RTTSealer)) + return &MockSealingManagerGet1RTTSealerCall{Call: call} +} + +// MockSealingManagerGet1RTTSealerCall wrap *gomock.Call +type MockSealingManagerGet1RTTSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSealingManagerGet1RTTSealerCall) Return(arg0 handshake.ShortHeaderSealer, arg1 error) *MockSealingManagerGet1RTTSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSealingManagerGet1RTTSealerCall) Do(f func() (handshake.ShortHeaderSealer, error)) *MockSealingManagerGet1RTTSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSealingManagerGet1RTTSealerCall) DoAndReturn(f func() (handshake.ShortHeaderSealer, error)) *MockSealingManagerGet1RTTSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetHandshakeSealer mocks base method. @@ -78,9 +127,33 @@ func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, e } // GetHandshakeSealer indicates an expected call of GetHandshakeSealer. -func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *gomock.Call { +func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *MockSealingManagerGetHandshakeSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockSealingManager)(nil).GetHandshakeSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeSealer", reflect.TypeOf((*MockSealingManager)(nil).GetHandshakeSealer)) + return &MockSealingManagerGetHandshakeSealerCall{Call: call} +} + +// MockSealingManagerGetHandshakeSealerCall wrap *gomock.Call +type MockSealingManagerGetHandshakeSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSealingManagerGetHandshakeSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGetHandshakeSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSealingManagerGetHandshakeSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetHandshakeSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSealingManagerGetHandshakeSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetHandshakeSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetInitialSealer mocks base method. @@ -93,7 +166,31 @@ func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, err } // GetInitialSealer indicates an expected call of GetInitialSealer. -func (mr *MockSealingManagerMockRecorder) GetInitialSealer() *gomock.Call { +func (mr *MockSealingManagerMockRecorder) GetInitialSealer() *MockSealingManagerGetInitialSealerCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockSealingManager)(nil).GetInitialSealer)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInitialSealer", reflect.TypeOf((*MockSealingManager)(nil).GetInitialSealer)) + return &MockSealingManagerGetInitialSealerCall{Call: call} +} + +// MockSealingManagerGetInitialSealerCall wrap *gomock.Call +type MockSealingManagerGetInitialSealerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSealingManagerGetInitialSealerCall) Return(arg0 handshake.LongHeaderSealer, arg1 error) *MockSealingManagerGetInitialSealerCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSealingManagerGetInitialSealerCall) Do(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetInitialSealerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSealingManagerGetInitialSealerCall) DoAndReturn(f func() (handshake.LongHeaderSealer, error)) *MockSealingManagerGetInitialSealerCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_send_conn_test.go b/mock_send_conn_test.go index 92fe43a96..f903e5151 100644 --- a/mock_send_conn_test.go +++ b/mock_send_conn_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockSendConn) Close() error { } // Close indicates an expected call of Close. -func (mr *MockSendConnMockRecorder) Close() *gomock.Call { +func (mr *MockSendConnMockRecorder) Close() *MockSendConnCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendConn)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendConn)(nil).Close)) + return &MockSendConnCloseCall{Call: call} +} + +// MockSendConnCloseCall wrap *gomock.Call +type MockSendConnCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnCloseCall) Return(arg0 error) *MockSendConnCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnCloseCall) Do(f func() error) *MockSendConnCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnCloseCall) DoAndReturn(f func() error) *MockSendConnCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // LocalAddr mocks base method. @@ -62,9 +87,33 @@ func (m *MockSendConn) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockSendConnMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockSendConnMockRecorder) LocalAddr() *MockSendConnLocalAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSendConn)(nil).LocalAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockSendConn)(nil).LocalAddr)) + return &MockSendConnLocalAddrCall{Call: call} +} + +// MockSendConnLocalAddrCall wrap *gomock.Call +type MockSendConnLocalAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnLocalAddrCall) Return(arg0 net.Addr) *MockSendConnLocalAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnLocalAddrCall) Do(f func() net.Addr) *MockSendConnLocalAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnLocalAddrCall) DoAndReturn(f func() net.Addr) *MockSendConnLocalAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RemoteAddr mocks base method. @@ -76,9 +125,33 @@ func (m *MockSendConn) RemoteAddr() net.Addr { } // RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockSendConnMockRecorder) RemoteAddr() *gomock.Call { +func (mr *MockSendConnMockRecorder) RemoteAddr() *MockSendConnRemoteAddrCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSendConn)(nil).RemoteAddr)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockSendConn)(nil).RemoteAddr)) + return &MockSendConnRemoteAddrCall{Call: call} +} + +// MockSendConnRemoteAddrCall wrap *gomock.Call +type MockSendConnRemoteAddrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnRemoteAddrCall) Return(arg0 net.Addr) *MockSendConnRemoteAddrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnRemoteAddrCall) Do(f func() net.Addr) *MockSendConnRemoteAddrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnRemoteAddrCall) DoAndReturn(f func() net.Addr) *MockSendConnRemoteAddrCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Write mocks base method. @@ -90,9 +163,33 @@ func (m *MockSendConn) Write(arg0 []byte, arg1 uint16, arg2 protocol.ECN) error } // Write indicates an expected call of Write. -func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockSendConnMockRecorder) Write(arg0, arg1, arg2 any) *MockSendConnWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendConn)(nil).Write), arg0, arg1, arg2) + return &MockSendConnWriteCall{Call: call} +} + +// MockSendConnWriteCall wrap *gomock.Call +type MockSendConnWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConnWriteCall) Return(arg0 error) *MockSendConnWriteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConnWriteCall) Do(f func([]byte, uint16, protocol.ECN) error) *MockSendConnWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConnWriteCall) DoAndReturn(f func([]byte, uint16, protocol.ECN) error) *MockSendConnWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // capabilities mocks base method. @@ -104,7 +201,31 @@ func (m *MockSendConn) capabilities() connCapabilities { } // capabilities indicates an expected call of capabilities. -func (mr *MockSendConnMockRecorder) capabilities() *gomock.Call { +func (mr *MockSendConnMockRecorder) capabilities() *MockSendConncapabilitiesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockSendConn)(nil).capabilities)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "capabilities", reflect.TypeOf((*MockSendConn)(nil).capabilities)) + return &MockSendConncapabilitiesCall{Call: call} +} + +// MockSendConncapabilitiesCall wrap *gomock.Call +type MockSendConncapabilitiesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendConncapabilitiesCall) Return(arg0 connCapabilities) *MockSendConncapabilitiesCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendConncapabilitiesCall) Do(f func() connCapabilities) *MockSendConncapabilitiesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendConncapabilitiesCall) DoAndReturn(f func() connCapabilities) *MockSendConncapabilitiesCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_send_stream_internal_test.go b/mock_send_stream_internal_test.go index 6b878c943..4815a9d46 100644 --- a/mock_send_stream_internal_test.go +++ b/mock_send_stream_internal_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI // + // Package quic is a generated GoMock package. package quic @@ -50,9 +51,33 @@ func (m *MockSendStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) CancelWrite(arg0 any) *MockSendStreamICancelWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockSendStreamI)(nil).CancelWrite), arg0) + return &MockSendStreamICancelWriteCall{Call: call} +} + +// MockSendStreamICancelWriteCall wrap *gomock.Call +type MockSendStreamICancelWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamICancelWriteCall) Return() *MockSendStreamICancelWriteCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamICancelWriteCall) Do(f func(qerr.StreamErrorCode)) *MockSendStreamICancelWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamICancelWriteCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockSendStreamICancelWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -64,9 +89,33 @@ func (m *MockSendStreamI) Close() error { } // Close indicates an expected call of Close. -func (mr *MockSendStreamIMockRecorder) Close() *gomock.Call { +func (mr *MockSendStreamIMockRecorder) Close() *MockSendStreamICloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSendStreamI)(nil).Close)) + return &MockSendStreamICloseCall{Call: call} +} + +// MockSendStreamICloseCall wrap *gomock.Call +type MockSendStreamICloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamICloseCall) Return(arg0 error) *MockSendStreamICloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamICloseCall) Do(f func() error) *MockSendStreamICloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamICloseCall) DoAndReturn(f func() error) *MockSendStreamICloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Context mocks base method. @@ -78,9 +127,33 @@ func (m *MockSendStreamI) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockSendStreamIMockRecorder) Context() *gomock.Call { +func (mr *MockSendStreamIMockRecorder) Context() *MockSendStreamIContextCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockSendStreamI)(nil).Context)) + return &MockSendStreamIContextCall{Call: call} +} + +// MockSendStreamIContextCall wrap *gomock.Call +type MockSendStreamIContextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIContextCall) Return(arg0 context.Context) *MockSendStreamIContextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIContextCall) Do(f func() context.Context) *MockSendStreamIContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIContextCall) DoAndReturn(f func() context.Context) *MockSendStreamIContextCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetWriteDeadline mocks base method. @@ -92,9 +165,33 @@ func (m *MockSendStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) SetWriteDeadline(arg0 any) *MockSendStreamISetWriteDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockSendStreamI)(nil).SetWriteDeadline), arg0) + return &MockSendStreamISetWriteDeadlineCall{Call: call} +} + +// MockSendStreamISetWriteDeadlineCall wrap *gomock.Call +type MockSendStreamISetWriteDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamISetWriteDeadlineCall) Return(arg0 error) *MockSendStreamISetWriteDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamISetWriteDeadlineCall) Do(f func(time.Time) error) *MockSendStreamISetWriteDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamISetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockSendStreamISetWriteDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StreamID mocks base method. @@ -106,9 +203,33 @@ func (m *MockSendStreamI) StreamID() protocol.StreamID { } // StreamID indicates an expected call of StreamID. -func (mr *MockSendStreamIMockRecorder) StreamID() *gomock.Call { +func (mr *MockSendStreamIMockRecorder) StreamID() *MockSendStreamIStreamIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockSendStreamI)(nil).StreamID)) + return &MockSendStreamIStreamIDCall{Call: call} +} + +// MockSendStreamIStreamIDCall wrap *gomock.Call +type MockSendStreamIStreamIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIStreamIDCall) Return(arg0 protocol.StreamID) *MockSendStreamIStreamIDCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIStreamIDCall) Do(f func() protocol.StreamID) *MockSendStreamIStreamIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIStreamIDCall) DoAndReturn(f func() protocol.StreamID) *MockSendStreamIStreamIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Write mocks base method. @@ -121,9 +242,33 @@ func (m *MockSendStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockSendStreamIMockRecorder) Write(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) Write(arg0 any) *MockSendStreamIWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockSendStreamI)(nil).Write), arg0) + return &MockSendStreamIWriteCall{Call: call} +} + +// MockSendStreamIWriteCall wrap *gomock.Call +type MockSendStreamIWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIWriteCall) Return(arg0 int, arg1 error) *MockSendStreamIWriteCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIWriteCall) Do(f func([]byte) (int, error)) *MockSendStreamIWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockSendStreamIWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // closeForShutdown mocks base method. @@ -133,9 +278,33 @@ func (m *MockSendStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) closeForShutdown(arg0 any) *MockSendStreamIcloseForShutdownCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockSendStreamI)(nil).closeForShutdown), arg0) + return &MockSendStreamIcloseForShutdownCall{Call: call} +} + +// MockSendStreamIcloseForShutdownCall wrap *gomock.Call +type MockSendStreamIcloseForShutdownCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIcloseForShutdownCall) Return() *MockSendStreamIcloseForShutdownCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIcloseForShutdownCall) Do(f func(error)) *MockSendStreamIcloseForShutdownCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *MockSendStreamIcloseForShutdownCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleStopSendingFrame mocks base method. @@ -145,9 +314,33 @@ func (m *MockSendStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) handleStopSendingFrame(arg0 any) *MockSendStreamIhandleStopSendingFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockSendStreamI)(nil).handleStopSendingFrame), arg0) + return &MockSendStreamIhandleStopSendingFrameCall{Call: call} +} + +// MockSendStreamIhandleStopSendingFrameCall wrap *gomock.Call +type MockSendStreamIhandleStopSendingFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIhandleStopSendingFrameCall) Return() *MockSendStreamIhandleStopSendingFrameCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIhandleStopSendingFrameCall) Do(f func(*wire.StopSendingFrame)) *MockSendStreamIhandleStopSendingFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIhandleStopSendingFrameCall) DoAndReturn(f func(*wire.StopSendingFrame)) *MockSendStreamIhandleStopSendingFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // hasData mocks base method. @@ -159,13 +352,37 @@ func (m *MockSendStreamI) hasData() bool { } // hasData indicates an expected call of hasData. -func (mr *MockSendStreamIMockRecorder) hasData() *gomock.Call { +func (mr *MockSendStreamIMockRecorder) hasData() *MockSendStreamIhasDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockSendStreamI)(nil).hasData)) + return &MockSendStreamIhasDataCall{Call: call} +} + +// MockSendStreamIhasDataCall wrap *gomock.Call +type MockSendStreamIhasDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIhasDataCall) Return(arg0 bool) *MockSendStreamIhasDataCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIhasDataCall) Do(f func() bool) *MockSendStreamIhasDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIhasDataCall) DoAndReturn(f func() bool) *MockSendStreamIhasDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } // popStreamFrame mocks base method. -func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { +func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Version) (ackhandler.StreamFrame, bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1) ret0, _ := ret[0].(ackhandler.StreamFrame) @@ -175,9 +392,33 @@ func (m *MockSendStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol. } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *MockSendStreamIpopStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockSendStreamI)(nil).popStreamFrame), arg0, arg1) + return &MockSendStreamIpopStreamFrameCall{Call: call} +} + +// MockSendStreamIpopStreamFrameCall wrap *gomock.Call +type MockSendStreamIpopStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIpopStreamFrameCall) Return(arg0 ackhandler.StreamFrame, arg1, arg2 bool) *MockSendStreamIpopStreamFrameCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIpopStreamFrameCall) Do(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, bool, bool)) *MockSendStreamIpopStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIpopStreamFrameCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, bool, bool)) *MockSendStreamIpopStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // updateSendWindow mocks base method. @@ -187,7 +428,31 @@ func (m *MockSendStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { +func (mr *MockSendStreamIMockRecorder) updateSendWindow(arg0 any) *MockSendStreamIupdateSendWindowCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockSendStreamI)(nil).updateSendWindow), arg0) + return &MockSendStreamIupdateSendWindowCall{Call: call} +} + +// MockSendStreamIupdateSendWindowCall wrap *gomock.Call +type MockSendStreamIupdateSendWindowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSendStreamIupdateSendWindowCall) Return() *MockSendStreamIupdateSendWindowCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSendStreamIupdateSendWindowCall) Do(f func(protocol.ByteCount)) *MockSendStreamIupdateSendWindowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSendStreamIupdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount)) *MockSendStreamIupdateSendWindowCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_sender_test.go b/mock_sender_test.go index 8e9291b56..bea8b8213 100644 --- a/mock_sender_test.go +++ b/mock_sender_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender // + // Package quic is a generated GoMock package. package quic @@ -47,9 +48,33 @@ func (m *MockSender) Available() <-chan struct{} { } // Available indicates an expected call of Available. -func (mr *MockSenderMockRecorder) Available() *gomock.Call { +func (mr *MockSenderMockRecorder) Available() *MockSenderAvailableCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockSender)(nil).Available)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockSender)(nil).Available)) + return &MockSenderAvailableCall{Call: call} +} + +// MockSenderAvailableCall wrap *gomock.Call +type MockSenderAvailableCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSenderAvailableCall) Return(arg0 <-chan struct{}) *MockSenderAvailableCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSenderAvailableCall) Do(f func() <-chan struct{}) *MockSenderAvailableCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSenderAvailableCall) DoAndReturn(f func() <-chan struct{}) *MockSenderAvailableCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -59,9 +84,33 @@ func (m *MockSender) Close() { } // Close indicates an expected call of Close. -func (mr *MockSenderMockRecorder) Close() *gomock.Call { +func (mr *MockSenderMockRecorder) Close() *MockSenderCloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSender)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSender)(nil).Close)) + return &MockSenderCloseCall{Call: call} +} + +// MockSenderCloseCall wrap *gomock.Call +type MockSenderCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSenderCloseCall) Return() *MockSenderCloseCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSenderCloseCall) Do(f func()) *MockSenderCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSenderCloseCall) DoAndReturn(f func()) *MockSenderCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Run mocks base method. @@ -73,9 +122,33 @@ func (m *MockSender) Run() error { } // Run indicates an expected call of Run. -func (mr *MockSenderMockRecorder) Run() *gomock.Call { +func (mr *MockSenderMockRecorder) Run() *MockSenderRunCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSender)(nil).Run)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSender)(nil).Run)) + return &MockSenderRunCall{Call: call} +} + +// MockSenderRunCall wrap *gomock.Call +type MockSenderRunCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSenderRunCall) Return(arg0 error) *MockSenderRunCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSenderRunCall) Do(f func() error) *MockSenderRunCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSenderRunCall) DoAndReturn(f func() error) *MockSenderRunCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Send mocks base method. @@ -85,9 +158,33 @@ func (m *MockSender) Send(arg0 *packetBuffer, arg1 uint16, arg2 protocol.ECN) { } // Send indicates an expected call of Send. -func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockSenderMockRecorder) Send(arg0, arg1, arg2 any) *MockSenderSendCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockSender)(nil).Send), arg0, arg1, arg2) + return &MockSenderSendCall{Call: call} +} + +// MockSenderSendCall wrap *gomock.Call +type MockSenderSendCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSenderSendCall) Return() *MockSenderSendCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSenderSendCall) Do(f func(*packetBuffer, uint16, protocol.ECN)) *MockSenderSendCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSenderSendCall) DoAndReturn(f func(*packetBuffer, uint16, protocol.ECN)) *MockSenderSendCall { + c.Call = c.Call.DoAndReturn(f) + return c } // WouldBlock mocks base method. @@ -99,7 +196,31 @@ func (m *MockSender) WouldBlock() bool { } // WouldBlock indicates an expected call of WouldBlock. -func (mr *MockSenderMockRecorder) WouldBlock() *gomock.Call { +func (mr *MockSenderMockRecorder) WouldBlock() *MockSenderWouldBlockCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WouldBlock", reflect.TypeOf((*MockSender)(nil).WouldBlock)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WouldBlock", reflect.TypeOf((*MockSender)(nil).WouldBlock)) + return &MockSenderWouldBlockCall{Call: call} +} + +// MockSenderWouldBlockCall wrap *gomock.Call +type MockSenderWouldBlockCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSenderWouldBlockCall) Return(arg0 bool) *MockSenderWouldBlockCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSenderWouldBlockCall) Do(f func() bool) *MockSenderWouldBlockCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSenderWouldBlockCall) DoAndReturn(f func() bool) *MockSenderWouldBlockCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 21a873c98..06785a940 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter // + // Package quic is a generated GoMock package. package quic @@ -48,9 +49,33 @@ func (m *MockStreamGetter) GetOrOpenReceiveStream(arg0 protocol.StreamID) (recei } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamGetterGetOrOpenReceiveStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenReceiveStream), arg0) + return &MockStreamGetterGetOrOpenReceiveStreamCall{Call: call} +} + +// MockStreamGetterGetOrOpenReceiveStreamCall wrap *gomock.Call +type MockStreamGetterGetOrOpenReceiveStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Return(arg0 receiveStreamI, arg1 error) *MockStreamGetterGetOrOpenReceiveStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamGetterGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamGetterGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamGetterGetOrOpenReceiveStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetOrOpenSendStream mocks base method. @@ -63,7 +88,31 @@ func (m *MockStreamGetter) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStre } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { +func (mr *MockStreamGetterMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamGetterGetOrOpenSendStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamGetter)(nil).GetOrOpenSendStream), arg0) + return &MockStreamGetterGetOrOpenSendStreamCall{Call: call} +} + +// MockStreamGetterGetOrOpenSendStreamCall wrap *gomock.Call +type MockStreamGetterGetOrOpenSendStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamGetterGetOrOpenSendStreamCall) Return(arg0 sendStreamI, arg1 error) *MockStreamGetterGetOrOpenSendStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamGetterGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamGetterGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamGetterGetOrOpenSendStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_stream_internal_test.go b/mock_stream_internal_test.go index 3c482fb23..5ed36ec93 100644 --- a/mock_stream_internal_test.go +++ b/mock_stream_internal_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI // + // Package quic is a generated GoMock package. package quic @@ -50,9 +51,33 @@ func (m *MockStreamI) CancelRead(arg0 qerr.StreamErrorCode) { } // CancelRead indicates an expected call of CancelRead. -func (mr *MockStreamIMockRecorder) CancelRead(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelRead(arg0 any) *MockStreamICancelReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelRead", reflect.TypeOf((*MockStreamI)(nil).CancelRead), arg0) + return &MockStreamICancelReadCall{Call: call} +} + +// MockStreamICancelReadCall wrap *gomock.Call +type MockStreamICancelReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamICancelReadCall) Return() *MockStreamICancelReadCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamICancelReadCall) Do(f func(qerr.StreamErrorCode)) *MockStreamICancelReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamICancelReadCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockStreamICancelReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // CancelWrite mocks base method. @@ -62,9 +87,33 @@ func (m *MockStreamI) CancelWrite(arg0 qerr.StreamErrorCode) { } // CancelWrite indicates an expected call of CancelWrite. -func (mr *MockStreamIMockRecorder) CancelWrite(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) CancelWrite(arg0 any) *MockStreamICancelWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelWrite", reflect.TypeOf((*MockStreamI)(nil).CancelWrite), arg0) + return &MockStreamICancelWriteCall{Call: call} +} + +// MockStreamICancelWriteCall wrap *gomock.Call +type MockStreamICancelWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamICancelWriteCall) Return() *MockStreamICancelWriteCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamICancelWriteCall) Do(f func(qerr.StreamErrorCode)) *MockStreamICancelWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamICancelWriteCall) DoAndReturn(f func(qerr.StreamErrorCode)) *MockStreamICancelWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Close mocks base method. @@ -76,9 +125,33 @@ func (m *MockStreamI) Close() error { } // Close indicates an expected call of Close. -func (mr *MockStreamIMockRecorder) Close() *gomock.Call { +func (mr *MockStreamIMockRecorder) Close() *MockStreamICloseCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStreamI)(nil).Close)) + return &MockStreamICloseCall{Call: call} +} + +// MockStreamICloseCall wrap *gomock.Call +type MockStreamICloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamICloseCall) Return(arg0 error) *MockStreamICloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamICloseCall) Do(f func() error) *MockStreamICloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamICloseCall) DoAndReturn(f func() error) *MockStreamICloseCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Context mocks base method. @@ -90,9 +163,33 @@ func (m *MockStreamI) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockStreamIMockRecorder) Context() *gomock.Call { +func (mr *MockStreamIMockRecorder) Context() *MockStreamIContextCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockStreamI)(nil).Context)) + return &MockStreamIContextCall{Call: call} +} + +// MockStreamIContextCall wrap *gomock.Call +type MockStreamIContextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIContextCall) Return(arg0 context.Context) *MockStreamIContextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIContextCall) Do(f func() context.Context) *MockStreamIContextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIContextCall) DoAndReturn(f func() context.Context) *MockStreamIContextCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Read mocks base method. @@ -105,9 +202,33 @@ func (m *MockStreamI) Read(arg0 []byte) (int, error) { } // Read indicates an expected call of Read. -func (mr *MockStreamIMockRecorder) Read(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) Read(arg0 any) *MockStreamIReadCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStreamI)(nil).Read), arg0) + return &MockStreamIReadCall{Call: call} +} + +// MockStreamIReadCall wrap *gomock.Call +type MockStreamIReadCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIReadCall) Return(arg0 int, arg1 error) *MockStreamIReadCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIReadCall) Do(f func([]byte) (int, error)) *MockStreamIReadCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIReadCall) DoAndReturn(f func([]byte) (int, error)) *MockStreamIReadCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetDeadline mocks base method. @@ -119,9 +240,33 @@ func (m *MockStreamI) SetDeadline(arg0 time.Time) error { } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockStreamIMockRecorder) SetDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetDeadline(arg0 any) *MockStreamISetDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockStreamI)(nil).SetDeadline), arg0) + return &MockStreamISetDeadlineCall{Call: call} +} + +// MockStreamISetDeadlineCall wrap *gomock.Call +type MockStreamISetDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamISetDeadlineCall) Return(arg0 error) *MockStreamISetDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamISetDeadlineCall) Do(f func(time.Time) error) *MockStreamISetDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamISetDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamISetDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetReadDeadline mocks base method. @@ -133,9 +278,33 @@ func (m *MockStreamI) SetReadDeadline(arg0 time.Time) error { } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetReadDeadline(arg0 any) *MockStreamISetReadDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockStreamI)(nil).SetReadDeadline), arg0) + return &MockStreamISetReadDeadlineCall{Call: call} +} + +// MockStreamISetReadDeadlineCall wrap *gomock.Call +type MockStreamISetReadDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamISetReadDeadlineCall) Return(arg0 error) *MockStreamISetReadDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamISetReadDeadlineCall) Do(f func(time.Time) error) *MockStreamISetReadDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamISetReadDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamISetReadDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // SetWriteDeadline mocks base method. @@ -147,9 +316,33 @@ func (m *MockStreamI) SetWriteDeadline(arg0 time.Time) error { } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) SetWriteDeadline(arg0 any) *MockStreamISetWriteDeadlineCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockStreamI)(nil).SetWriteDeadline), arg0) + return &MockStreamISetWriteDeadlineCall{Call: call} +} + +// MockStreamISetWriteDeadlineCall wrap *gomock.Call +type MockStreamISetWriteDeadlineCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamISetWriteDeadlineCall) Return(arg0 error) *MockStreamISetWriteDeadlineCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamISetWriteDeadlineCall) Do(f func(time.Time) error) *MockStreamISetWriteDeadlineCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamISetWriteDeadlineCall) DoAndReturn(f func(time.Time) error) *MockStreamISetWriteDeadlineCall { + c.Call = c.Call.DoAndReturn(f) + return c } // StreamID mocks base method. @@ -161,9 +354,33 @@ func (m *MockStreamI) StreamID() protocol.StreamID { } // StreamID indicates an expected call of StreamID. -func (mr *MockStreamIMockRecorder) StreamID() *gomock.Call { +func (mr *MockStreamIMockRecorder) StreamID() *MockStreamIStreamIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamID", reflect.TypeOf((*MockStreamI)(nil).StreamID)) + return &MockStreamIStreamIDCall{Call: call} +} + +// MockStreamIStreamIDCall wrap *gomock.Call +type MockStreamIStreamIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIStreamIDCall) Return(arg0 protocol.StreamID) *MockStreamIStreamIDCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIStreamIDCall) Do(f func() protocol.StreamID) *MockStreamIStreamIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIStreamIDCall) DoAndReturn(f func() protocol.StreamID) *MockStreamIStreamIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Write mocks base method. @@ -176,9 +393,33 @@ func (m *MockStreamI) Write(arg0 []byte) (int, error) { } // Write indicates an expected call of Write. -func (mr *MockStreamIMockRecorder) Write(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) Write(arg0 any) *MockStreamIWriteCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockStreamI)(nil).Write), arg0) + return &MockStreamIWriteCall{Call: call} +} + +// MockStreamIWriteCall wrap *gomock.Call +type MockStreamIWriteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIWriteCall) Return(arg0 int, arg1 error) *MockStreamIWriteCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIWriteCall) Do(f func([]byte) (int, error)) *MockStreamIWriteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIWriteCall) DoAndReturn(f func([]byte) (int, error)) *MockStreamIWriteCall { + c.Call = c.Call.DoAndReturn(f) + return c } // closeForShutdown mocks base method. @@ -188,9 +429,33 @@ func (m *MockStreamI) closeForShutdown(arg0 error) { } // closeForShutdown indicates an expected call of closeForShutdown. -func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) closeForShutdown(arg0 any) *MockStreamIcloseForShutdownCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeForShutdown", reflect.TypeOf((*MockStreamI)(nil).closeForShutdown), arg0) + return &MockStreamIcloseForShutdownCall{Call: call} +} + +// MockStreamIcloseForShutdownCall wrap *gomock.Call +type MockStreamIcloseForShutdownCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIcloseForShutdownCall) Return() *MockStreamIcloseForShutdownCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIcloseForShutdownCall) Do(f func(error)) *MockStreamIcloseForShutdownCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIcloseForShutdownCall) DoAndReturn(f func(error)) *MockStreamIcloseForShutdownCall { + c.Call = c.Call.DoAndReturn(f) + return c } // getWindowUpdate mocks base method. @@ -202,9 +467,33 @@ func (m *MockStreamI) getWindowUpdate() protocol.ByteCount { } // getWindowUpdate indicates an expected call of getWindowUpdate. -func (mr *MockStreamIMockRecorder) getWindowUpdate() *gomock.Call { +func (mr *MockStreamIMockRecorder) getWindowUpdate() *MockStreamIgetWindowUpdateCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getWindowUpdate", reflect.TypeOf((*MockStreamI)(nil).getWindowUpdate)) + return &MockStreamIgetWindowUpdateCall{Call: call} +} + +// MockStreamIgetWindowUpdateCall wrap *gomock.Call +type MockStreamIgetWindowUpdateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIgetWindowUpdateCall) Return(arg0 protocol.ByteCount) *MockStreamIgetWindowUpdateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIgetWindowUpdateCall) Do(f func() protocol.ByteCount) *MockStreamIgetWindowUpdateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIgetWindowUpdateCall) DoAndReturn(f func() protocol.ByteCount) *MockStreamIgetWindowUpdateCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleResetStreamFrame mocks base method. @@ -216,9 +505,33 @@ func (m *MockStreamI) handleResetStreamFrame(arg0 *wire.ResetStreamFrame) error } // handleResetStreamFrame indicates an expected call of handleResetStreamFrame. -func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleResetStreamFrame(arg0 any) *MockStreamIhandleResetStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleResetStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleResetStreamFrame), arg0) + return &MockStreamIhandleResetStreamFrameCall{Call: call} +} + +// MockStreamIhandleResetStreamFrameCall wrap *gomock.Call +type MockStreamIhandleResetStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIhandleResetStreamFrameCall) Return(arg0 error) *MockStreamIhandleResetStreamFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIhandleResetStreamFrameCall) Do(f func(*wire.ResetStreamFrame) error) *MockStreamIhandleResetStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIhandleResetStreamFrameCall) DoAndReturn(f func(*wire.ResetStreamFrame) error) *MockStreamIhandleResetStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleStopSendingFrame mocks base method. @@ -228,9 +541,33 @@ func (m *MockStreamI) handleStopSendingFrame(arg0 *wire.StopSendingFrame) { } // handleStopSendingFrame indicates an expected call of handleStopSendingFrame. -func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStopSendingFrame(arg0 any) *MockStreamIhandleStopSendingFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStopSendingFrame", reflect.TypeOf((*MockStreamI)(nil).handleStopSendingFrame), arg0) + return &MockStreamIhandleStopSendingFrameCall{Call: call} +} + +// MockStreamIhandleStopSendingFrameCall wrap *gomock.Call +type MockStreamIhandleStopSendingFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIhandleStopSendingFrameCall) Return() *MockStreamIhandleStopSendingFrameCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIhandleStopSendingFrameCall) Do(f func(*wire.StopSendingFrame)) *MockStreamIhandleStopSendingFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIhandleStopSendingFrameCall) DoAndReturn(f func(*wire.StopSendingFrame)) *MockStreamIhandleStopSendingFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // handleStreamFrame mocks base method. @@ -242,9 +579,33 @@ func (m *MockStreamI) handleStreamFrame(arg0 *wire.StreamFrame) error { } // handleStreamFrame indicates an expected call of handleStreamFrame. -func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) handleStreamFrame(arg0 any) *MockStreamIhandleStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handleStreamFrame", reflect.TypeOf((*MockStreamI)(nil).handleStreamFrame), arg0) + return &MockStreamIhandleStreamFrameCall{Call: call} +} + +// MockStreamIhandleStreamFrameCall wrap *gomock.Call +type MockStreamIhandleStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIhandleStreamFrameCall) Return(arg0 error) *MockStreamIhandleStreamFrameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIhandleStreamFrameCall) Do(f func(*wire.StreamFrame) error) *MockStreamIhandleStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIhandleStreamFrameCall) DoAndReturn(f func(*wire.StreamFrame) error) *MockStreamIhandleStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // hasData mocks base method. @@ -256,13 +617,37 @@ func (m *MockStreamI) hasData() bool { } // hasData indicates an expected call of hasData. -func (mr *MockStreamIMockRecorder) hasData() *gomock.Call { +func (mr *MockStreamIMockRecorder) hasData() *MockStreamIhasDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockStreamI)(nil).hasData)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "hasData", reflect.TypeOf((*MockStreamI)(nil).hasData)) + return &MockStreamIhasDataCall{Call: call} +} + +// MockStreamIhasDataCall wrap *gomock.Call +type MockStreamIhasDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIhasDataCall) Return(arg0 bool) *MockStreamIhasDataCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIhasDataCall) Do(f func() bool) *MockStreamIhasDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIhasDataCall) DoAndReturn(f func() bool) *MockStreamIhasDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } // popStreamFrame mocks base method. -func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) { +func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Version) (ackhandler.StreamFrame, bool, bool) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "popStreamFrame", arg0, arg1) ret0, _ := ret[0].(ackhandler.StreamFrame) @@ -272,9 +657,33 @@ func (m *MockStreamI) popStreamFrame(arg0 protocol.ByteCount, arg1 protocol.Vers } // popStreamFrame indicates an expected call of popStreamFrame. -func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) popStreamFrame(arg0, arg1 any) *MockStreamIpopStreamFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "popStreamFrame", reflect.TypeOf((*MockStreamI)(nil).popStreamFrame), arg0, arg1) + return &MockStreamIpopStreamFrameCall{Call: call} +} + +// MockStreamIpopStreamFrameCall wrap *gomock.Call +type MockStreamIpopStreamFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIpopStreamFrameCall) Return(arg0 ackhandler.StreamFrame, arg1, arg2 bool) *MockStreamIpopStreamFrameCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIpopStreamFrameCall) Do(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, bool, bool)) *MockStreamIpopStreamFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIpopStreamFrameCall) DoAndReturn(f func(protocol.ByteCount, protocol.Version) (ackhandler.StreamFrame, bool, bool)) *MockStreamIpopStreamFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // updateSendWindow mocks base method. @@ -284,7 +693,31 @@ func (m *MockStreamI) updateSendWindow(arg0 protocol.ByteCount) { } // updateSendWindow indicates an expected call of updateSendWindow. -func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 any) *gomock.Call { +func (mr *MockStreamIMockRecorder) updateSendWindow(arg0 any) *MockStreamIupdateSendWindowCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "updateSendWindow", reflect.TypeOf((*MockStreamI)(nil).updateSendWindow), arg0) + return &MockStreamIupdateSendWindowCall{Call: call} +} + +// MockStreamIupdateSendWindowCall wrap *gomock.Call +type MockStreamIupdateSendWindowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamIupdateSendWindowCall) Return() *MockStreamIupdateSendWindowCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamIupdateSendWindowCall) Do(f func(protocol.ByteCount)) *MockStreamIupdateSendWindowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamIupdateSendWindowCall) DoAndReturn(f func(protocol.ByteCount)) *MockStreamIupdateSendWindowCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 3d5e18de2..e707fe8a6 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager // + // Package quic is a generated GoMock package. package quic @@ -50,9 +51,33 @@ func (m *MockStreamManager) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptStream(arg0 any) *MockStreamManagerAcceptStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptStream), arg0) + return &MockStreamManagerAcceptStreamCall{Call: call} +} + +// MockStreamManagerAcceptStreamCall wrap *gomock.Call +type MockStreamManagerAcceptStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerAcceptStreamCall) Return(arg0 Stream, arg1 error) *MockStreamManagerAcceptStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerAcceptStreamCall) Do(f func(context.Context) (Stream, error)) *MockStreamManagerAcceptStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerAcceptStreamCall) DoAndReturn(f func(context.Context) (Stream, error)) *MockStreamManagerAcceptStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // AcceptUniStream mocks base method. @@ -65,9 +90,33 @@ func (m *MockStreamManager) AcceptUniStream(arg0 context.Context) (ReceiveStream } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) AcceptUniStream(arg0 any) *MockStreamManagerAcceptUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockStreamManager)(nil).AcceptUniStream), arg0) + return &MockStreamManagerAcceptUniStreamCall{Call: call} +} + +// MockStreamManagerAcceptUniStreamCall wrap *gomock.Call +type MockStreamManagerAcceptUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerAcceptUniStreamCall) Return(arg0 ReceiveStream, arg1 error) *MockStreamManagerAcceptUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerAcceptUniStreamCall) Do(f func(context.Context) (ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerAcceptUniStreamCall) DoAndReturn(f func(context.Context) (ReceiveStream, error)) *MockStreamManagerAcceptUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // CloseWithError mocks base method. @@ -77,9 +126,33 @@ func (m *MockStreamManager) CloseWithError(arg0 error) { } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) CloseWithError(arg0 any) *MockStreamManagerCloseWithErrorCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockStreamManager)(nil).CloseWithError), arg0) + return &MockStreamManagerCloseWithErrorCall{Call: call} +} + +// MockStreamManagerCloseWithErrorCall wrap *gomock.Call +type MockStreamManagerCloseWithErrorCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerCloseWithErrorCall) Return() *MockStreamManagerCloseWithErrorCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerCloseWithErrorCall) Do(f func(error)) *MockStreamManagerCloseWithErrorCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerCloseWithErrorCall) DoAndReturn(f func(error)) *MockStreamManagerCloseWithErrorCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DeleteStream mocks base method. @@ -91,9 +164,33 @@ func (m *MockStreamManager) DeleteStream(arg0 protocol.StreamID) error { } // DeleteStream indicates an expected call of DeleteStream. -func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) DeleteStream(arg0 any) *MockStreamManagerDeleteStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStream", reflect.TypeOf((*MockStreamManager)(nil).DeleteStream), arg0) + return &MockStreamManagerDeleteStreamCall{Call: call} +} + +// MockStreamManagerDeleteStreamCall wrap *gomock.Call +type MockStreamManagerDeleteStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerDeleteStreamCall) Return(arg0 error) *MockStreamManagerDeleteStreamCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerDeleteStreamCall) Do(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerDeleteStreamCall) DoAndReturn(f func(protocol.StreamID) error) *MockStreamManagerDeleteStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetOrOpenReceiveStream mocks base method. @@ -106,9 +203,33 @@ func (m *MockStreamManager) GetOrOpenReceiveStream(arg0 protocol.StreamID) (rece } // GetOrOpenReceiveStream indicates an expected call of GetOrOpenReceiveStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenReceiveStream(arg0 any) *MockStreamManagerGetOrOpenReceiveStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenReceiveStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenReceiveStream), arg0) + return &MockStreamManagerGetOrOpenReceiveStreamCall{Call: call} +} + +// MockStreamManagerGetOrOpenReceiveStreamCall wrap *gomock.Call +type MockStreamManagerGetOrOpenReceiveStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Return(arg0 receiveStreamI, arg1 error) *MockStreamManagerGetOrOpenReceiveStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerGetOrOpenReceiveStreamCall) Do(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamManagerGetOrOpenReceiveStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerGetOrOpenReceiveStreamCall) DoAndReturn(f func(protocol.StreamID) (receiveStreamI, error)) *MockStreamManagerGetOrOpenReceiveStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetOrOpenSendStream mocks base method. @@ -121,9 +242,33 @@ func (m *MockStreamManager) GetOrOpenSendStream(arg0 protocol.StreamID) (sendStr } // GetOrOpenSendStream indicates an expected call of GetOrOpenSendStream. -func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) GetOrOpenSendStream(arg0 any) *MockStreamManagerGetOrOpenSendStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrOpenSendStream", reflect.TypeOf((*MockStreamManager)(nil).GetOrOpenSendStream), arg0) + return &MockStreamManagerGetOrOpenSendStreamCall{Call: call} +} + +// MockStreamManagerGetOrOpenSendStreamCall wrap *gomock.Call +type MockStreamManagerGetOrOpenSendStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerGetOrOpenSendStreamCall) Return(arg0 sendStreamI, arg1 error) *MockStreamManagerGetOrOpenSendStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerGetOrOpenSendStreamCall) Do(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamManagerGetOrOpenSendStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerGetOrOpenSendStreamCall) DoAndReturn(f func(protocol.StreamID) (sendStreamI, error)) *MockStreamManagerGetOrOpenSendStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // HandleMaxStreamsFrame mocks base method. @@ -133,9 +278,33 @@ func (m *MockStreamManager) HandleMaxStreamsFrame(arg0 *wire.MaxStreamsFrame) { } // HandleMaxStreamsFrame indicates an expected call of HandleMaxStreamsFrame. -func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) HandleMaxStreamsFrame(arg0 any) *MockStreamManagerHandleMaxStreamsFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMaxStreamsFrame", reflect.TypeOf((*MockStreamManager)(nil).HandleMaxStreamsFrame), arg0) + return &MockStreamManagerHandleMaxStreamsFrameCall{Call: call} +} + +// MockStreamManagerHandleMaxStreamsFrameCall wrap *gomock.Call +type MockStreamManagerHandleMaxStreamsFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerHandleMaxStreamsFrameCall) Return() *MockStreamManagerHandleMaxStreamsFrameCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerHandleMaxStreamsFrameCall) Do(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerHandleMaxStreamsFrameCall) DoAndReturn(f func(*wire.MaxStreamsFrame)) *MockStreamManagerHandleMaxStreamsFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStream mocks base method. @@ -148,9 +317,33 @@ func (m *MockStreamManager) OpenStream() (Stream, error) { } // OpenStream indicates an expected call of OpenStream. -func (mr *MockStreamManagerMockRecorder) OpenStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenStream() *MockStreamManagerOpenStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockStreamManager)(nil).OpenStream)) + return &MockStreamManagerOpenStreamCall{Call: call} +} + +// MockStreamManagerOpenStreamCall wrap *gomock.Call +type MockStreamManagerOpenStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerOpenStreamCall) Return(arg0 Stream, arg1 error) *MockStreamManagerOpenStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerOpenStreamCall) Do(f func() (Stream, error)) *MockStreamManagerOpenStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerOpenStreamCall) DoAndReturn(f func() (Stream, error)) *MockStreamManagerOpenStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenStreamSync mocks base method. @@ -163,9 +356,33 @@ func (m *MockStreamManager) OpenStreamSync(arg0 context.Context) (Stream, error) } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenStreamSync(arg0 any) *MockStreamManagerOpenStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenStreamSync), arg0) + return &MockStreamManagerOpenStreamSyncCall{Call: call} +} + +// MockStreamManagerOpenStreamSyncCall wrap *gomock.Call +type MockStreamManagerOpenStreamSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerOpenStreamSyncCall) Return(arg0 Stream, arg1 error) *MockStreamManagerOpenStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerOpenStreamSyncCall) Do(f func(context.Context) (Stream, error)) *MockStreamManagerOpenStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerOpenStreamSyncCall) DoAndReturn(f func(context.Context) (Stream, error)) *MockStreamManagerOpenStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStream mocks base method. @@ -178,9 +395,33 @@ func (m *MockStreamManager) OpenUniStream() (SendStream, error) { } // OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockStreamManagerMockRecorder) OpenUniStream() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenUniStream() *MockStreamManagerOpenUniStreamCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStream)) + return &MockStreamManagerOpenUniStreamCall{Call: call} +} + +// MockStreamManagerOpenUniStreamCall wrap *gomock.Call +type MockStreamManagerOpenUniStreamCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerOpenUniStreamCall) Return(arg0 SendStream, arg1 error) *MockStreamManagerOpenUniStreamCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerOpenUniStreamCall) Do(f func() (SendStream, error)) *MockStreamManagerOpenUniStreamCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerOpenUniStreamCall) DoAndReturn(f func() (SendStream, error)) *MockStreamManagerOpenUniStreamCall { + c.Call = c.Call.DoAndReturn(f) + return c } // OpenUniStreamSync mocks base method. @@ -193,9 +434,33 @@ func (m *MockStreamManager) OpenUniStreamSync(arg0 context.Context) (SendStream, } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) OpenUniStreamSync(arg0 any) *MockStreamManagerOpenUniStreamSyncCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockStreamManager)(nil).OpenUniStreamSync), arg0) + return &MockStreamManagerOpenUniStreamSyncCall{Call: call} +} + +// MockStreamManagerOpenUniStreamSyncCall wrap *gomock.Call +type MockStreamManagerOpenUniStreamSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerOpenUniStreamSyncCall) Return(arg0 SendStream, arg1 error) *MockStreamManagerOpenUniStreamSyncCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerOpenUniStreamSyncCall) Do(f func(context.Context) (SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerOpenUniStreamSyncCall) DoAndReturn(f func(context.Context) (SendStream, error)) *MockStreamManagerOpenUniStreamSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ResetFor0RTT mocks base method. @@ -205,9 +470,33 @@ func (m *MockStreamManager) ResetFor0RTT() { } // ResetFor0RTT indicates an expected call of ResetFor0RTT. -func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) ResetFor0RTT() *MockStreamManagerResetFor0RTTCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResetFor0RTT", reflect.TypeOf((*MockStreamManager)(nil).ResetFor0RTT)) + return &MockStreamManagerResetFor0RTTCall{Call: call} +} + +// MockStreamManagerResetFor0RTTCall wrap *gomock.Call +type MockStreamManagerResetFor0RTTCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerResetFor0RTTCall) Return() *MockStreamManagerResetFor0RTTCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerResetFor0RTTCall) Do(f func()) *MockStreamManagerResetFor0RTTCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerResetFor0RTTCall) DoAndReturn(f func()) *MockStreamManagerResetFor0RTTCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdateLimits mocks base method. @@ -217,9 +506,33 @@ func (m *MockStreamManager) UpdateLimits(arg0 *wire.TransportParameters) { } // UpdateLimits indicates an expected call of UpdateLimits. -func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *gomock.Call { +func (mr *MockStreamManagerMockRecorder) UpdateLimits(arg0 any) *MockStreamManagerUpdateLimitsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLimits", reflect.TypeOf((*MockStreamManager)(nil).UpdateLimits), arg0) + return &MockStreamManagerUpdateLimitsCall{Call: call} +} + +// MockStreamManagerUpdateLimitsCall wrap *gomock.Call +type MockStreamManagerUpdateLimitsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerUpdateLimitsCall) Return() *MockStreamManagerUpdateLimitsCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerUpdateLimitsCall) Do(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerUpdateLimitsCall) DoAndReturn(f func(*wire.TransportParameters)) *MockStreamManagerUpdateLimitsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UseResetMaps mocks base method. @@ -229,7 +542,31 @@ func (m *MockStreamManager) UseResetMaps() { } // UseResetMaps indicates an expected call of UseResetMaps. -func (mr *MockStreamManagerMockRecorder) UseResetMaps() *gomock.Call { +func (mr *MockStreamManagerMockRecorder) UseResetMaps() *MockStreamManagerUseResetMapsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseResetMaps", reflect.TypeOf((*MockStreamManager)(nil).UseResetMaps)) + return &MockStreamManagerUseResetMapsCall{Call: call} +} + +// MockStreamManagerUseResetMapsCall wrap *gomock.Call +type MockStreamManagerUseResetMapsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamManagerUseResetMapsCall) Return() *MockStreamManagerUseResetMapsCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamManagerUseResetMapsCall) Do(f func()) *MockStreamManagerUseResetMapsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamManagerUseResetMapsCall) DoAndReturn(f func()) *MockStreamManagerUseResetMapsCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_stream_sender_test.go b/mock_stream_sender_test.go index f3165076a..8c89cf8b5 100644 --- a/mock_stream_sender_test.go +++ b/mock_stream_sender_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender // + // Package quic is a generated GoMock package. package quic @@ -46,9 +47,33 @@ func (m *MockStreamSender) onHasStreamData(arg0 protocol.StreamID) { } // onHasStreamData indicates an expected call of onHasStreamData. -func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onHasStreamData(arg0 any) *MockStreamSenderonHasStreamDataCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHasStreamData", reflect.TypeOf((*MockStreamSender)(nil).onHasStreamData), arg0) + return &MockStreamSenderonHasStreamDataCall{Call: call} +} + +// MockStreamSenderonHasStreamDataCall wrap *gomock.Call +type MockStreamSenderonHasStreamDataCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSenderonHasStreamDataCall) Return() *MockStreamSenderonHasStreamDataCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSenderonHasStreamDataCall) Do(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSenderonHasStreamDataCall) DoAndReturn(f func(protocol.StreamID)) *MockStreamSenderonHasStreamDataCall { + c.Call = c.Call.DoAndReturn(f) + return c } // onStreamCompleted mocks base method. @@ -58,9 +83,33 @@ func (m *MockStreamSender) onStreamCompleted(arg0 protocol.StreamID) { } // onStreamCompleted indicates an expected call of onStreamCompleted. -func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 any) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) onStreamCompleted(arg0 any) *MockStreamSenderonStreamCompletedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onStreamCompleted", reflect.TypeOf((*MockStreamSender)(nil).onStreamCompleted), arg0) + return &MockStreamSenderonStreamCompletedCall{Call: call} +} + +// MockStreamSenderonStreamCompletedCall wrap *gomock.Call +type MockStreamSenderonStreamCompletedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSenderonStreamCompletedCall) Return() *MockStreamSenderonStreamCompletedCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSenderonStreamCompletedCall) Do(f func(protocol.StreamID)) *MockStreamSenderonStreamCompletedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSenderonStreamCompletedCall) DoAndReturn(f func(protocol.StreamID)) *MockStreamSenderonStreamCompletedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // queueControlFrame mocks base method. @@ -70,7 +119,31 @@ func (m *MockStreamSender) queueControlFrame(arg0 wire.Frame) { } // queueControlFrame indicates an expected call of queueControlFrame. -func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 any) *gomock.Call { +func (mr *MockStreamSenderMockRecorder) queueControlFrame(arg0 any) *MockStreamSenderqueueControlFrameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "queueControlFrame", reflect.TypeOf((*MockStreamSender)(nil).queueControlFrame), arg0) + return &MockStreamSenderqueueControlFrameCall{Call: call} +} + +// MockStreamSenderqueueControlFrameCall wrap *gomock.Call +type MockStreamSenderqueueControlFrameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStreamSenderqueueControlFrameCall) Return() *MockStreamSenderqueueControlFrameCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStreamSenderqueueControlFrameCall) Do(f func(wire.Frame)) *MockStreamSenderqueueControlFrameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStreamSenderqueueControlFrameCall) DoAndReturn(f func(wire.Frame)) *MockStreamSenderqueueControlFrameCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_token_store_test.go b/mock_token_store_test.go index 3d921b2ea..47c3afb4e 100644 --- a/mock_token_store_test.go +++ b/mock_token_store_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore +// mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore // + // Package quic is a generated GoMock package. package quic @@ -46,9 +47,33 @@ func (m *MockTokenStore) Pop(arg0 string) *ClientToken { } // Pop indicates an expected call of Pop. -func (mr *MockTokenStoreMockRecorder) Pop(arg0 any) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Pop(arg0 any) *MockTokenStorePopCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pop", reflect.TypeOf((*MockTokenStore)(nil).Pop), arg0) + return &MockTokenStorePopCall{Call: call} +} + +// MockTokenStorePopCall wrap *gomock.Call +type MockTokenStorePopCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTokenStorePopCall) Return(arg0 *ClientToken) *MockTokenStorePopCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTokenStorePopCall) Do(f func(string) *ClientToken) *MockTokenStorePopCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTokenStorePopCall) DoAndReturn(f func(string) *ClientToken) *MockTokenStorePopCall { + c.Call = c.Call.DoAndReturn(f) + return c } // Put mocks base method. @@ -58,7 +83,31 @@ func (m *MockTokenStore) Put(arg0 string, arg1 *ClientToken) { } // Put indicates an expected call of Put. -func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 any) *gomock.Call { +func (mr *MockTokenStoreMockRecorder) Put(arg0, arg1 any) *MockTokenStorePutCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockTokenStore)(nil).Put), arg0, arg1) + return &MockTokenStorePutCall{Call: call} +} + +// MockTokenStorePutCall wrap *gomock.Call +type MockTokenStorePutCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTokenStorePutCall) Return() *MockTokenStorePutCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTokenStorePutCall) Do(f func(string, *ClientToken)) *MockTokenStorePutCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTokenStorePutCall) DoAndReturn(f func(string, *ClientToken)) *MockTokenStorePutCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 56cbe69cc..e9c0586cb 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -3,8 +3,9 @@ // // Generated by this command: // -// mockgen -build_flags=-tags=gomock -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker +// mockgen -typed -build_flags=-tags=gomock -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker // + // Package quic is a generated GoMock package. package quic @@ -41,7 +42,7 @@ func (m *MockUnpacker) EXPECT() *MockUnpackerMockRecorder { } // UnpackLongHeader mocks base method. -func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 []byte, arg3 protocol.VersionNumber) (*unpackedPacket, error) { +func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 []byte, arg3 protocol.Version) (*unpackedPacket, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UnpackLongHeader", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*unpackedPacket) @@ -50,9 +51,33 @@ func (m *MockUnpacker) UnpackLongHeader(arg0 *wire.Header, arg1 time.Time, arg2 } // UnpackLongHeader indicates an expected call of UnpackLongHeader. -func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackLongHeader(arg0, arg1, arg2, arg3 any) *MockUnpackerUnpackLongHeaderCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1, arg2, arg3) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackLongHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackLongHeader), arg0, arg1, arg2, arg3) + return &MockUnpackerUnpackLongHeaderCall{Call: call} +} + +// MockUnpackerUnpackLongHeaderCall wrap *gomock.Call +type MockUnpackerUnpackLongHeaderCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockUnpackerUnpackLongHeaderCall) Return(arg0 *unpackedPacket, arg1 error) *MockUnpackerUnpackLongHeaderCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockUnpackerUnpackLongHeaderCall) Do(f func(*wire.Header, time.Time, []byte, protocol.Version) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockUnpackerUnpackLongHeaderCall) DoAndReturn(f func(*wire.Header, time.Time, []byte, protocol.Version) (*unpackedPacket, error)) *MockUnpackerUnpackLongHeaderCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UnpackShortHeader mocks base method. @@ -68,7 +93,31 @@ func (m *MockUnpacker) UnpackShortHeader(arg0 time.Time, arg1 []byte) (protocol. } // UnpackShortHeader indicates an expected call of UnpackShortHeader. -func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 any) *gomock.Call { +func (mr *MockUnpackerMockRecorder) UnpackShortHeader(arg0, arg1 any) *MockUnpackerUnpackShortHeaderCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackShortHeader", reflect.TypeOf((*MockUnpacker)(nil).UnpackShortHeader), arg0, arg1) + return &MockUnpackerUnpackShortHeaderCall{Call: call} +} + +// MockUnpackerUnpackShortHeaderCall wrap *gomock.Call +type MockUnpackerUnpackShortHeaderCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockUnpackerUnpackShortHeaderCall) Return(arg0 protocol.PacketNumber, arg1 protocol.PacketNumberLen, arg2 protocol.KeyPhaseBit, arg3 []byte, arg4 error) *MockUnpackerUnpackShortHeaderCall { + c.Call = c.Call.Return(arg0, arg1, arg2, arg3, arg4) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockUnpackerUnpackShortHeaderCall) Do(f func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)) *MockUnpackerUnpackShortHeaderCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockUnpackerUnpackShortHeaderCall) DoAndReturn(f func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)) *MockUnpackerUnpackShortHeaderCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/mockgen.go b/mockgen.go index b6267e5c3..81cc4a5ef 100644 --- a/mockgen.go +++ b/mockgen.go @@ -2,73 +2,75 @@ package quic -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_conn_test.go github.com/refraction-networking/uquic SendConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" type SendConn = sendConn -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_raw_conn_test.go github.com/refraction-networking/uquic RawConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" type RawConn = rawConn -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sender_test.go github.com/refraction-networking/uquic Sender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" type Sender = sender -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_internal_test.go github.com/refraction-networking/uquic StreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI" type StreamI = streamI -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_stream_test.go github.com/refraction-networking/uquic CryptoStream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream" type CryptoStream = cryptoStream -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_receive_stream_internal_test.go github.com/refraction-networking/uquic ReceiveStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI" type ReceiveStreamI = receiveStreamI -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_send_stream_internal_test.go github.com/refraction-networking/uquic SendStreamI" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI" type SendStreamI = sendStreamI -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_getter_test.go github.com/refraction-networking/uquic StreamGetter" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter" type StreamGetter = streamGetter -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_sender_test.go github.com/refraction-networking/uquic StreamSender" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" type StreamSender = streamSender -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_crypto_data_handler_test.go github.com/refraction-networking/uquic CryptoDataHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" type CryptoDataHandler = cryptoDataHandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_frame_source_test.go github.com/refraction-networking/uquic FrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" type FrameSource = frameSource -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_ack_frame_source_test.go github.com/refraction-networking/uquic AckFrameSource" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" type AckFrameSource = ackFrameSource -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_stream_manager_test.go github.com/refraction-networking/uquic StreamManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" type StreamManager = streamManager -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_sealing_manager_test.go github.com/refraction-networking/uquic SealingManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" type SealingManager = sealingManager -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_unpacker_test.go github.com/refraction-networking/uquic Unpacker" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" type Unpacker = unpacker -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packer_test.go github.com/refraction-networking/uquic Packer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" type Packer = packer -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_mtu_discoverer_test.go github.com/refraction-networking/uquic MTUDiscoverer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" type MTUDiscoverer = mtuDiscoverer -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_conn_runner_test.go github.com/refraction-networking/uquic ConnRunner" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" type ConnRunner = connRunner -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_quic_conn_test.go github.com/refraction-networking/uquic QUICConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn" type QUICConn = quicConn -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_test.go github.com/refraction-networking/uquic PacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" type PacketHandler = packetHandler -//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/refraction-networking/uquic -destination mock_packet_handler_manager_test.go github.com/refraction-networking/uquic PacketHandlerManager" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" + +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // -//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_token_store_test.go github.com/refraction-networking/uquic TokenStore" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/refraction-networking/uquic -self_package github.com/refraction-networking/uquic -destination mock_packetconn_test.go net PacketConn" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore" +//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" diff --git a/mtu_discoverer_test.go b/mtu_discoverer_test.go index d0819ef10..59f9d5fb6 100644 --- a/mtu_discoverer_test.go +++ b/mtu_discoverer_test.go @@ -88,12 +88,12 @@ var _ = Describe("MTU Discoverer", func() { const rep = 3000 var maxDiff protocol.ByteCount for i := 0; i < rep; i++ { - max := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 + maxMTU := protocol.ByteCount(rand.Intn(int(3000-startMTU))) + startMTU + 1 currentMTU := startMTU d := newMTUDiscoverer(rttStats, startMTU, func(s protocol.ByteCount) { currentMTU = s }) - d.Start(max) + d.Start(maxMTU) now := time.Now() - realMTU := protocol.ByteCount(rand.Intn(int(max-startMTU))) + startMTU + realMTU := protocol.ByteCount(rand.Intn(int(maxMTU-startMTU))) + startMTU t := now.Add(mtuProbeDelay * rtt) var count int for d.ShouldSendProbe(t) { @@ -112,7 +112,7 @@ var _ = Describe("MTU Discoverer", func() { } diff := realMTU - currentMTU Expect(diff).To(BeNumerically(">=", 0)) - maxDiff = utils.Max(maxDiff, diff) + maxDiff = max(maxDiff, diff) } Expect(maxDiff).To(BeEquivalentTo(maxMTUDiff)) }) diff --git a/oss-fuzz.sh b/oss-fuzz.sh index e711d94ab..75f875b58 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -3,12 +3,12 @@ # Install Go manually, since oss-fuzz ships with an outdated Go version. # See https://github.com/google/oss-fuzz/pull/10643. export CXX="${CXX} -lresolv" # required by Go 1.20 -wget https://go.dev/dl/go1.20.5.linux-amd64.tar.gz \ +wget https://go.dev/dl/go1.22.0.linux-amd64.tar.gz \ && mkdir temp-go \ && rm -rf /root/.go/* \ - && tar -C temp-go/ -xzf go1.20.5.linux-amd64.tar.gz \ + && tar -C temp-go/ -xzf go1.22.0.linux-amd64.tar.gz \ && mv temp-go/go/* /root/.go/ \ - && rm -rf temp-go go1.20.5.linux-amd64.tar.gz + && rm -rf temp-go go1.22.0.linux-amd64.tar.gz ( # fuzz qpack diff --git a/packet_handler_map.go b/packet_handler_map.go index 9c4ebf876..4bbc119ce 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -129,7 +129,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, handler packetHandler) bool { h.mutex.Lock() defer h.mutex.Unlock() @@ -137,12 +137,8 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } - conn, ok := fn() - if !ok { - return false - } - h.handlers[clientDestConnID] = conn - h.handlers[newConnID] = conn + h.handlers[clientDestConnID] = handler + h.handlers[newConnID] = handler h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } @@ -168,18 +164,17 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { // Depending on which side closed the connection, we need to: // * remote close: absorb delayed packets // * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost -func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, connClosePacket []byte) { var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( func(addr net.Addr, info packetInfo) { h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, - pers, h.logger, ) } else { - handler = newClosedRemoteConn(pers) + handler = newClosedRemoteConn() } h.mutex.Lock() @@ -191,7 +186,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() - handler.shutdown() for _, id := range ids { delete(h.handlers, id) } @@ -220,23 +214,6 @@ func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) ( return handler, ok } -func (h *packetHandlerMap) CloseServer() { - h.mutex.Lock() - var wg sync.WaitGroup - for _, handler := range h.handlers { - if handler.getPerspective() == protocol.PerspectiveServer { - wg.Add(1) - go func(handler packetHandler) { - // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - handler.shutdown() - wg.Done() - }(handler) - } - } - h.mutex.Unlock() - wg.Wait() -} - func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 78b15d9c9..51eca7546 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -59,18 +59,12 @@ var _ = Describe("Packet Handler Map", func() { It("adds newly to-be-constructed handlers", func() { m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) - var called bool connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) - Expect(m.AddWithConnID(connID1, connID2, func() (packetHandler, bool) { - called = true - return NewMockPacketHandler(mockCtrl), true - })).To(BeTrue()) - Expect(called).To(BeTrue()) - Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), func() (packetHandler, bool) { - Fail("didn't expect the constructor to be executed") - return nil, false - })).To(BeFalse()) + h := NewMockPacketHandler(mockCtrl) + Expect(m.AddWithConnID(connID1, connID2, h)).To(BeTrue()) + // collision of the destination connection ID, this handler should not be added + Expect(m.AddWithConnID(connID1, protocol.ParseConnectionID([]byte{1, 2, 3}), nil)).To(BeFalse()) }) It("adds, gets and removes reset tokens", func() { @@ -124,7 +118,7 @@ var _ = Describe("Packet Handler Map", func() { handler := NewMockPacketHandler(mockCtrl) connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(m.Add(connID, handler)).To(BeTrue()) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, []byte("foobar")) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, []byte("foobar")) h, ok := m.Get(connID) Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler)) @@ -147,7 +141,7 @@ var _ = Describe("Packet Handler Map", func() { handler := NewMockPacketHandler(mockCtrl) connID := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(m.Add(connID, handler)).To(BeTrue()) - m.ReplaceWithClosed([]protocol.ConnectionID{connID}, protocol.PerspectiveClient, nil) + m.ReplaceWithClosed([]protocol.ConnectionID{connID}, nil) h, ok := m.Get(connID) Expect(ok).To(BeTrue()) Expect(h).ToNot(Equal(handler)) @@ -159,23 +153,6 @@ var _ = Describe("Packet Handler Map", func() { Eventually(func() bool { _, ok := m.Get(connID); return ok }).Should(BeFalse()) }) - It("closes the server", func() { - m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) - for i := 0; i < 10; i++ { - conn := NewMockPacketHandler(mockCtrl) - if i%2 == 0 { - conn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - } else { - conn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - conn.EXPECT().shutdown() - } - b := make([]byte, 12) - rand.Read(b) - m.Add(protocol.ParseConnectionID(b), conn) - } - m.CloseServer() - }) - It("closes", func() { m := newPacketHandlerMap(nil, nil, utils.DefaultLogger) testErr := errors.New("shutdown") diff --git a/packet_packer.go b/packet_packer.go index d1b548898..64d5ff3bd 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -17,13 +17,13 @@ import ( var errNothingToPack = errors.New("nothing to pack") type packer interface { - PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) - PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) + PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) + AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.Version) (*coalescedPacket, error) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) SetToken([]byte) } @@ -105,8 +105,8 @@ type sealingManager interface { type frameSource interface { HasData() bool - AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) - AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) } type ackFrameSource interface { @@ -169,7 +169,7 @@ func newPacketPacker( } // PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { var reason string // don't send details of crypto errors if !e.ErrorCode.IsCryptoError() { @@ -179,7 +179,7 @@ func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize } // PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) } @@ -189,7 +189,7 @@ func (p *packetPacker) packConnectionClose( frameType uint64, reason string, maxPacketSize protocol.ByteCount, - v protocol.VersionNumber, + v protocol.Version, ) (*coalescedPacket, error) { var sealers [4]sealer var hdrs [3]*wire.ExtendedHeader @@ -292,7 +292,7 @@ func (p *packetPacker) packConnectionClose( // longHeaderPacketLength calculates the length of a serialized long header packet. // It takes into account that packets that have a tiny payload need to be padded, // such that len(payload) + packet number len >= 4 + AEAD overhead -func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount { +func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.Version) protocol.ByteCount { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(hdr.PacketNumberLen) if pl.length < 4-pnLen { @@ -327,7 +327,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -441,7 +441,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { buf := getPacketBuffer() packet, err := p.appendPacket(buf, true, maxPacketSize, v) return packet, buf, err @@ -449,11 +449,11 @@ func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v pro // AppendPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { +func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { return p.appendPacket(buf, false, maxPacketSize, v) } -func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { +func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.Version) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return shortHeaderPacket{}, err @@ -470,7 +470,7 @@ func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSi return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) } -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.Version) (*wire.ExtendedHeader, payload) { if onlyAck { if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { return p.getLongHeader(encLevel, v), payload{ @@ -542,7 +542,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en return hdr, pl } -func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { +func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.Version) (*wire.ExtendedHeader, payload) { if p.perspective != protocol.PerspectiveClient { return nil, payload{} } @@ -552,12 +552,12 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v) } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v) } -func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) // check if we have anything to send @@ -580,7 +580,7 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, return pl } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.Version) payload { if onlyAck { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { return payload{ack: ack, length: ack.Length(v)} @@ -588,12 +588,11 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return payload{} } - pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)} - hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() var hasAck bool + var pl payload if ackAllowed { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { pl.ack = ack @@ -605,11 +604,17 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if p.datagramQueue != nil { if f := p.datagramQueue.Peek(); f != nil { size := f.Length(v) - if size <= maxFrameSize-pl.length { + if size <= maxFrameSize-pl.length { // DATAGRAM frame fits pl.frames = append(pl.frames, ackhandler.Frame{Frame: f}) pl.length += size p.datagramQueue.Pop() + } else if !hasAck { + // The DATAGRAM frame doesn't fit, and the packet doesn't contain an ACK. + // Discard this frame. There's no point in retrying this in the next packet, + // as it's unlikely that the available packet size will increase. + p.datagramQueue.Pop() } + // If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later. } } @@ -639,7 +644,13 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc pl.length += lengthAdded // add handlers for the control frames that were added for i := startLen; i < len(pl.frames); i++ { - pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() + switch pl.frames[i].Frame.(type) { + case *wire.PathChallengeFrame, *wire.PathResponseFrame: + // Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet. + // PATH_CHALLENGE and PATH_RESPONSE are never retransmitted. + default: + pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() + } } pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v) @@ -648,7 +659,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc return pl } -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.Version) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -714,7 +725,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m return packet, nil } -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.Version) (shortHeaderPacket, *packetBuffer, error) { pl := payload{ frames: []ackhandler.Frame{ping}, length: ping.Frame.Length(v), @@ -732,7 +743,7 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B return packet, buffer, err } -func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.Version) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) hdr := &wire.ExtendedHeader{ PacketNumber: pn, @@ -755,7 +766,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protoc return hdr } -func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { +func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.Version) (*longHeaderPacket, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) if pl.length < 4-pnLen { @@ -801,7 +812,7 @@ func (p *packetPacker) appendShortHeaderPacket( padding, maxPacketSize protocol.ByteCount, sealer sealer, isMTUProbePacket bool, - v protocol.VersionNumber, + v protocol.Version, ) (shortHeaderPacket, error) { var paddingLen protocol.ByteCount if pl.length < 4-protocol.ByteCount(pnLen) { @@ -847,7 +858,7 @@ func (p *packetPacker) appendShortHeaderPacket( // appendPacketPayload serializes the payload of a packet into the raw byte slice. // It modifies the order of payload.frames. -func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { +func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.Version) ([]byte, error) { payloadOffset := len(raw) if pl.ack != nil { var err error diff --git a/packet_packer_test.go b/packet_packer_test.go index 13b20d63a..4a2c3d0e5 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -65,7 +65,7 @@ var _ = Describe("Packet packer", func() { } expectAppendStreamFrames := func(frames ...ackhandler.StreamFrame) { - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.StreamFrame, _ protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.StreamFrame, _ protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { var length protocol.ByteCount for _, f := range frames { length += f.Frame.Length(v) @@ -75,7 +75,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames := func(frames ...ackhandler.Frame) { - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(fs []ackhandler.Frame, _ protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount for _, f := range frames { length += f.Frame.Length(v) @@ -301,12 +301,12 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption0RTT).Return(protocol.PacketNumber(0x42)) cf := ackhandler.Frame{Frame: &wire.MaxDataFrame{MaximumData: 0x1337}} framer.EXPECT().HasData().Return(true) - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount, v protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { Expect(frames).To(BeEmpty()) return append(frames, cf), cf.Frame.Length(v) }) // TODO: check sizes - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.StreamFrame, _ protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(frames []ackhandler.StreamFrame, _ protocol.ByteCount, _ protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { return frames, 0 }) p, err := packer.PackCoalescedPacket(false, maxPacketSize, protocol.Version1) @@ -516,7 +516,6 @@ var _ = Describe("Packet packer", func() { buffer.Data = append(buffer.Data, []byte("foobar")...) p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) b, err := f.Append(nil, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(BeEmpty()) @@ -535,7 +534,6 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) Expect(p.Ack).To(Equal(ack)) }) @@ -553,7 +551,6 @@ var _ = Describe("Packet packer", func() { expectAppendStreamFrames() buffer := getPacketBuffer() p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(2)) for i, f := range p.Frames { @@ -562,6 +559,36 @@ var _ = Describe("Packet packer", func() { Expect(buffer.Len()).ToNot(BeZero()) }) + It("packs PATH_CHALLENGE and PATH_RESPONSE frames", func() { + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + framer.EXPECT().HasData().Return(true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) + frames := []ackhandler.Frame{ + {Frame: &wire.PathChallengeFrame{}}, + {Frame: &wire.PathResponseFrame{}}, + {Frame: &wire.DataBlockedFrame{}}, + } + expectAppendControlFrames(frames...) + expectAppendStreamFrames() + buffer := getPacketBuffer() + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) + Expect(err).ToNot(HaveOccurred()) + Expect(p.Frames).To(HaveLen(3)) + for i, f := range p.Frames { + Expect(f).To(BeAssignableToTypeOf(frames[i])) + switch f.Frame.(type) { + case *wire.PathChallengeFrame, *wire.PathResponseFrame: + // This means that the frame won't be retransmitted. + Expect(f.Handler).To(BeNil()) + default: + Expect(f.Handler).ToNot(BeNil()) + } + } + Expect(buffer.Len()).ToNot(BeZero()) + }) + It("packs DATAGRAM frames", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) @@ -575,7 +602,7 @@ var _ = Describe("Packet packer", func() { go func() { defer GinkgoRecover() defer close(done) - datagramQueue.AddAndWait(f) + datagramQueue.Add(f) }() // make sure the DATAGRAM has actually been queued time.Sleep(scaleDuration(20 * time.Millisecond)) @@ -583,7 +610,6 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() buffer := getPacketBuffer() p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(HaveLen(1)) Expect(p.Frames[0].Frame).To(Equal(f)) @@ -604,7 +630,7 @@ var _ = Describe("Packet packer", func() { go func() { defer GinkgoRecover() defer close(done) - datagramQueue.AddAndWait(f) + datagramQueue.Add(f) }() // make sure the DATAGRAM has actually been queued time.Sleep(scaleDuration(20 * time.Millisecond)) @@ -612,15 +638,42 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() buffer := getPacketBuffer() p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) Expect(p.Frames).To(BeEmpty()) Expect(buffer.Data).ToNot(BeEmpty()) + Expect(datagramQueue.Peek()).To(Equal(f)) // make sure the frame is still there datagramQueue.CloseWithError(nil) Eventually(done).Should(BeClosed()) }) + It("discards a DATAGRAM frame if it doesn't fit into a packet that doesn't contain an ACK", func() { + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + f := &wire.DatagramFrame{ + DataLenPresent: true, + Data: make([]byte, maxPacketSize+10), // won't fit + } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + datagramQueue.Add(f) + }() + // make sure the DATAGRAM has actually been queued + time.Sleep(scaleDuration(20 * time.Millisecond)) + + framer.EXPECT().HasData() + buffer := getPacketBuffer() + p, err := packer.AppendPacket(buffer, maxPacketSize, protocol.Version1) + Expect(err).To(MatchError(errNothingToPack)) + Expect(p.Frames).To(BeEmpty()) + Expect(p.Ack).To(BeNil()) + Expect(datagramQueue.Peek()).To(BeNil()) + Eventually(done).Should(BeClosed()) + }) + It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) @@ -628,11 +681,11 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) var maxSize protocol.ByteCount gomock.InOrder( - framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { + framer.EXPECT().AppendControlFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []ackhandler.Frame, maxLen protocol.ByteCount, _ protocol.Version) ([]ackhandler.Frame, protocol.ByteCount) { maxSize = maxLen return fs, 444 }), - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(fs []ackhandler.StreamFrame, maxLen protocol.ByteCount, _ protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).Do(func(fs []ackhandler.StreamFrame, maxLen protocol.ByteCount, _ protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { Expect(maxLen).To(Equal(maxSize - 444)) return fs, 0 }), @@ -746,7 +799,6 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames(ackhandler.StreamFrame{Frame: f1}, ackhandler.StreamFrame{Frame: f2}, ackhandler.StreamFrame{Frame: f3}) p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Frames).To(BeEmpty()) Expect(p.StreamFrames).To(HaveLen(3)) @@ -766,7 +818,6 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) Expect(p.Frames).To(BeEmpty()) @@ -783,7 +834,6 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) var hasPing bool for _, f := range p.Frames { @@ -802,7 +852,6 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() p, err = packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) - Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.Ack).ToNot(BeNil()) Expect(p.Frames).To(BeEmpty()) @@ -852,7 +901,6 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) p, err := packer.AppendPacket(getPacketBuffer(), maxPacketSize, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) Expect(p.Frames).ToNot(ContainElement(&wire.PingFrame{})) }) }) @@ -1443,7 +1491,7 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) framer.EXPECT().HasData().Return(true) expectAppendControlFrames() - framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []ackhandler.StreamFrame, maxSize protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(func(fs []ackhandler.StreamFrame, maxSize protocol.ByteCount, v protocol.Version) ([]ackhandler.StreamFrame, protocol.ByteCount) { sf, split := f.MaybeSplitOffFrame(maxSize, v) Expect(split).To(BeTrue()) return append(fs, ackhandler.StreamFrame{Frame: sf}), sf.Length(v) diff --git a/packet_unpacker.go b/packet_unpacker.go index c9d1c0eb3..3d92c618c 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -53,7 +53,7 @@ func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetU // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. // If any other error occurred when parsing the header, the error is of type headerParseError. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. -func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) { +func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader var decrypted []byte @@ -125,7 +125,7 @@ func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (prot return pn, pnLen, kp, decrypted, nil } -func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) { +func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) { extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker @@ -187,7 +187,7 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. -func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { extHdr, err := unpackLongHeader(hd, hdr, data, v) if err != nil && err != wire.ErrInvalidReservedBits { return nil, &headerParseError{err: err} @@ -195,7 +195,7 @@ func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, return extHdr, err } -func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) { +func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { r := bytes.NewReader(data) hdrLen := hdr.ParsedLen() diff --git a/qlog/qlog.go b/qlog/connection_tracer.go similarity index 75% rename from qlog/qlog.go rename to qlog/connection_tracer.go index e94da3e20..6b7076bab 100644 --- a/qlog/qlog.go +++ b/qlog/connection_tracer.go @@ -1,13 +1,8 @@ package qlog import ( - "bytes" - "fmt" "io" - "log" "net" - "runtime/debug" - "sync" "time" "github.com/refraction-networking/uquic/internal/protocol" @@ -18,62 +13,28 @@ import ( "github.com/francoispqt/gojay" ) -// Setting of this only works when quic-go is used as a library. -// When building a binary from this repository, the version can be set using the following go build flag: -// -ldflags="-X github.com/refraction-networking/uquic/qlog.quicGoVersion=foobar" -var quicGoVersion = "(devel)" - -func init() { - if quicGoVersion != "(devel)" { // variable set by ldflags - return - } - info, ok := debug.ReadBuildInfo() - if !ok { // no build info available. This happens when quic-go is not used as a library. - return - } - for _, d := range info.Deps { - if d.Path == "github.com/refraction-networking/uquic" { - quicGoVersion = d.Version - if d.Replace != nil { - if len(d.Replace.Version) > 0 { - quicGoVersion = d.Version - } else { - quicGoVersion += " (replaced)" - } - } - break - } - } -} - -const eventChanSize = 50 - type connectionTracer struct { - mutex sync.Mutex - - w io.WriteCloser - odcid protocol.ConnectionID - perspective protocol.Perspective - referenceTime time.Time - - events chan event - encodeErr error - runStopped chan struct{} - + w writer lastMetrics *metrics + + perspective logging.Perspective } // NewConnectionTracer creates a new tracer to record a qlog for a connection. -func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) *logging.ConnectionTracer { +func NewConnectionTracer(w io.WriteCloser, p logging.Perspective, odcid protocol.ConnectionID) *logging.ConnectionTracer { + tr := &trace{ + VantagePoint: vantagePoint{Type: p.String()}, + CommonFields: commonFields{ + ODCID: &odcid, + GroupID: &odcid, + ReferenceTime: time.Now(), + }, + } t := connectionTracer{ - w: w, - perspective: p, - odcid: odcid, - runStopped: make(chan struct{}), - events: make(chan event, eventChanSize), - referenceTime: time.Now(), + w: *newWriter(w, tr), + perspective: p, } - go t.run() + go t.w.Run() return &logging.ConnectionTracer{ StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { t.StartedConnection(local, remote, srcConnID, destConnID) @@ -106,8 +67,8 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco BufferedPacket: func(pt logging.PacketType, size protocol.ByteCount) { t.BufferedPacket(pt, size) }, - DroppedPacket: func(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { - t.DroppedPacket(pt, size, reason) + DroppedPacket: func(pt logging.PacketType, pn logging.PacketNumber, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(pt, pn, size, reason) }, UpdatedMetrics: func(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) @@ -124,14 +85,14 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco UpdatedKeyFromTLS: func(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { t.UpdatedKeyFromTLS(encLevel, pers) }, - UpdatedKey: func(generation protocol.KeyPhase, remote bool) { - t.UpdatedKey(generation, remote) + UpdatedKey: func(keyPhase protocol.KeyPhase, remote bool) { + t.UpdatedKey(keyPhase, remote) }, DroppedEncryptionLevel: func(encLevel protocol.EncryptionLevel) { t.DroppedEncryptionLevel(encLevel) }, - DroppedKey: func(generation protocol.KeyPhase) { - t.DroppedKey(generation) + DroppedKey: func(keyPhase protocol.KeyPhase) { + t.DroppedKey(keyPhase) }, SetLossTimer: func(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { t.SetLossTimer(tt, encLevel, timeout) @@ -145,6 +106,9 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { t.ECNStateUpdated(state, trigger) }, + ChoseALPN: func(protocol string) { + t.recordEvent(time.Now(), eventALPNInformation{chosenALPN: protocol}) + }, Debug: func(name, msg string) { t.Debug(name, msg) }, @@ -154,65 +118,12 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco } } -func (t *connectionTracer) run() { - defer close(t.runStopped) - buf := &bytes.Buffer{} - enc := gojay.NewEncoder(buf) - tl := &topLevel{ - trace: trace{ - VantagePoint: vantagePoint{Type: t.perspective}, - CommonFields: commonFields{ - ODCID: t.odcid, - GroupID: t.odcid, - ReferenceTime: t.referenceTime, - }, - }, - } - if err := enc.Encode(tl); err != nil { - panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) - } - if err := buf.WriteByte('\n'); err != nil { - panic(fmt.Sprintf("qlog encoding into a bytes.Buffer failed: %s", err)) - } - if _, err := t.w.Write(buf.Bytes()); err != nil { - t.encodeErr = err - } - enc = gojay.NewEncoder(t.w) - for ev := range t.events { - if t.encodeErr != nil { // if encoding failed, just continue draining the event channel - continue - } - if err := enc.Encode(ev); err != nil { - t.encodeErr = err - continue - } - if _, err := t.w.Write([]byte{'\n'}); err != nil { - t.encodeErr = err - } - } +func (t *connectionTracer) recordEvent(eventTime time.Time, details eventDetails) { + t.w.RecordEvent(eventTime, details) } func (t *connectionTracer) Close() { - if err := t.export(); err != nil { - log.Printf("exporting qlog failed: %s\n", err) - } -} - -// export writes a qlog. -func (t *connectionTracer) export() error { - close(t.events) - <-t.runStopped - if t.encodeErr != nil { - return t.encodeErr - } - return t.w.Close() -} - -func (t *connectionTracer) recordEvent(eventTime time.Time, details eventDetails) { - t.events <- event{ - RelativeTime: eventTime.Sub(t.referenceTime), - eventDetails: details, - } + t.w.Close() } func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID protocol.ConnectionID) { @@ -225,14 +136,12 @@ func (t *connectionTracer) StartedConnection(local, remote net.Addr, srcConnID, if !ok { return } - t.mutex.Lock() t.recordEvent(time.Now(), &eventConnectionStarted{ SrcAddr: localAddr, DestAddr: remoteAddr, SrcConnectionID: srcConnID, DestConnectionID: destConnID, }) - t.mutex.Unlock() } func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, client, server []logging.VersionNumber) { @@ -249,19 +158,15 @@ func (t *connectionTracer) NegotiatedVersion(chosen logging.VersionNumber, clien serverVersions[i] = versionNumber(v) } } - t.mutex.Lock() t.recordEvent(time.Now(), &eventVersionNegotiated{ clientVersions: clientVersions, serverVersions: serverVersions, chosenVersion: versionNumber(chosen), }) - t.mutex.Unlock() } func (t *connectionTracer) ClosedConnection(e error) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventConnectionClosed{e: e}) - t.mutex.Unlock() } func (t *connectionTracer) SentTransportParameters(tp *wire.TransportParameters) { @@ -276,9 +181,7 @@ func (t *connectionTracer) RestoredTransportParameters(tp *wire.TransportParamet ev := t.toTransportParameters(tp) ev.Restore = true - t.mutex.Lock() t.recordEvent(time.Now(), ev) - t.mutex.Unlock() } func (t *connectionTracer) recordTransportParameters(sentBy protocol.Perspective, tp *wire.TransportParameters) { @@ -289,9 +192,7 @@ func (t *connectionTracer) recordTransportParameters(sentBy protocol.Perspective } ev.SentBy = sentBy - t.mutex.Lock() t.recordEvent(time.Now(), ev) - t.mutex.Unlock() } func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) *eventTransportParameters { @@ -299,9 +200,7 @@ func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) * if tp.PreferredAddress != nil { pa = &preferredAddress{ IPv4: tp.PreferredAddress.IPv4, - PortV4: tp.PreferredAddress.IPv4Port, IPv6: tp.PreferredAddress.IPv6, - PortV6: tp.PreferredAddress.IPv6Port, ConnectionID: tp.PreferredAddress.ConnectionID, StatelessResetToken: tp.PreferredAddress.StatelessResetToken, } @@ -366,7 +265,6 @@ func (t *connectionTracer) sentPacket( for _, f := range frames { fs = append(fs, frame{Frame: f}) } - t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketSent{ Header: hdr, Length: size, @@ -374,7 +272,6 @@ func (t *connectionTracer) sentPacket( ECN: ecn, Frames: fs, }) - t.mutex.Unlock() } func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { @@ -383,7 +280,6 @@ func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, fs[i] = frame{Frame: f} } header := *transformLongHeader(hdr) - t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, Length: size, @@ -391,7 +287,6 @@ func (t *connectionTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, ECN: ecn, Frames: fs, }) - t.mutex.Unlock() } func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { @@ -400,7 +295,6 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, s fs[i] = frame{Frame: f} } header := *transformShortHeader(hdr) - t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, Length: size, @@ -408,15 +302,12 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, s ECN: ecn, Frames: fs, }) - t.mutex.Unlock() } func (t *connectionTracer) ReceivedRetry(hdr *wire.Header) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventRetryReceived{ Header: *transformHeader(hdr), }) - t.mutex.Unlock() } func (t *connectionTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { @@ -424,7 +315,6 @@ func (t *connectionTracer) ReceivedVersionNegotiationPacket(dest, src logging.Ar for i, v := range versions { ver[i] = versionNumber(v) } - t.mutex.Lock() t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ Header: packetHeaderVersionNegotiation{ SrcConnectionID: src, @@ -432,26 +322,22 @@ func (t *connectionTracer) ReceivedVersionNegotiationPacket(dest, src logging.Ar }, SupportedVersions: ver, }) - t.mutex.Unlock() } func (t *connectionTracer) BufferedPacket(pt logging.PacketType, size protocol.ByteCount) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketBuffered{ PacketType: pt, PacketSize: size, }) - t.mutex.Unlock() } -func (t *connectionTracer) DroppedPacket(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { - t.mutex.Lock() +func (t *connectionTracer) DroppedPacket(pt logging.PacketType, pn logging.PacketNumber, size protocol.ByteCount, reason logging.PacketDropReason) { t.recordEvent(time.Now(), &eventPacketDropped{ - PacketType: pt, - PacketSize: size, - Trigger: packetDropReason(reason), + PacketType: pt, + PacketNumber: pn, + PacketSize: size, + Trigger: packetDropReason(reason), }) - t.mutex.Unlock() } func (t *connectionTracer) UpdatedMetrics(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { @@ -464,46 +350,36 @@ func (t *connectionTracer) UpdatedMetrics(rttStats *utils.RTTStats, cwnd, bytesI BytesInFlight: bytesInFlight, PacketsInFlight: packetsInFlight, } - t.mutex.Lock() t.recordEvent(time.Now(), &eventMetricsUpdated{ Last: t.lastMetrics, Current: m, }) t.lastMetrics = m - t.mutex.Unlock() } func (t *connectionTracer) AcknowledgedPacket(protocol.EncryptionLevel, protocol.PacketNumber) {} func (t *connectionTracer) LostPacket(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketLost{ PacketType: getPacketTypeFromEncryptionLevel(encLevel), PacketNumber: pn, Trigger: packetLossReason(lossReason), }) - t.mutex.Unlock() } func (t *connectionTracer) UpdatedCongestionState(state logging.CongestionState) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventCongestionStateUpdated{state: congestionState(state)}) - t.mutex.Unlock() } func (t *connectionTracer) UpdatedPTOCount(value uint32) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventUpdatedPTO{Value: value}) - t.mutex.Unlock() } func (t *connectionTracer) UpdatedKeyFromTLS(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventKeyUpdated{ Trigger: keyUpdateTLS, KeyType: encLevelToKeyType(encLevel, pers), }) - t.mutex.Unlock() } func (t *connectionTracer) UpdatedKey(generation protocol.KeyPhase, remote bool) { @@ -511,23 +387,20 @@ func (t *connectionTracer) UpdatedKey(generation protocol.KeyPhase, remote bool) if remote { trigger = keyUpdateRemote } - t.mutex.Lock() now := time.Now() t.recordEvent(now, &eventKeyUpdated{ - Trigger: trigger, - KeyType: keyTypeClient1RTT, - Generation: generation, + Trigger: trigger, + KeyType: keyTypeClient1RTT, + KeyPhase: generation, }) t.recordEvent(now, &eventKeyUpdated{ - Trigger: trigger, - KeyType: keyTypeServer1RTT, - Generation: generation, + Trigger: trigger, + KeyType: keyTypeServer1RTT, + KeyPhase: generation, }) - t.mutex.Unlock() } func (t *connectionTracer) DroppedEncryptionLevel(encLevel protocol.EncryptionLevel) { - t.mutex.Lock() now := time.Now() if encLevel == protocol.Encryption0RTT { t.recordEvent(now, &eventKeyDiscarded{KeyType: encLevelToKeyType(encLevel, t.perspective)}) @@ -535,60 +408,47 @@ func (t *connectionTracer) DroppedEncryptionLevel(encLevel protocol.EncryptionLe t.recordEvent(now, &eventKeyDiscarded{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveServer)}) t.recordEvent(now, &eventKeyDiscarded{KeyType: encLevelToKeyType(encLevel, protocol.PerspectiveClient)}) } - t.mutex.Unlock() } func (t *connectionTracer) DroppedKey(generation protocol.KeyPhase) { - t.mutex.Lock() now := time.Now() t.recordEvent(now, &eventKeyDiscarded{ - KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer), - Generation: generation, + KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveServer), + KeyPhase: generation, }) t.recordEvent(now, &eventKeyDiscarded{ - KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient), - Generation: generation, + KeyType: encLevelToKeyType(protocol.Encryption1RTT, protocol.PerspectiveClient), + KeyPhase: generation, }) - t.mutex.Unlock() } func (t *connectionTracer) SetLossTimer(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { - t.mutex.Lock() now := time.Now() t.recordEvent(now, &eventLossTimerSet{ TimerType: timerType(tt), EncLevel: encLevel, Delta: timeout.Sub(now), }) - t.mutex.Unlock() } func (t *connectionTracer) LossTimerExpired(tt logging.TimerType, encLevel protocol.EncryptionLevel) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventLossTimerExpired{ TimerType: timerType(tt), EncLevel: encLevel, }) - t.mutex.Unlock() } func (t *connectionTracer) LossTimerCanceled() { - t.mutex.Lock() t.recordEvent(time.Now(), &eventLossTimerCanceled{}) - t.mutex.Unlock() } func (t *connectionTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventECNStateUpdated{state: state, trigger: trigger}) - t.mutex.Unlock() } func (t *connectionTracer) Debug(name, msg string) { - t.mutex.Lock() t.recordEvent(time.Now(), &eventGeneric{ name: name, msg: msg, }) - t.mutex.Unlock() } diff --git a/qlog/connection_tracer_test.go b/qlog/connection_tracer_test.go new file mode 100644 index 000000000..bddeb9237 --- /dev/null +++ b/qlog/connection_tracer_test.go @@ -0,0 +1,917 @@ +package qlog + +import ( + "bytes" + "encoding/json" + "io" + "net" + "net/netip" + "time" + + quic "github.com/refraction-networking/uquic" + "github.com/refraction-networking/uquic/internal/protocol" + "github.com/refraction-networking/uquic/internal/qerr" + "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +type nopWriteCloserImpl struct{ io.Writer } + +func (nopWriteCloserImpl) Close() error { return nil } + +func nopWriteCloser(w io.Writer) io.WriteCloser { + return &nopWriteCloserImpl{Writer: w} +} + +type entry struct { + Time time.Time + Name string + Event map[string]interface{} +} + +func exportAndParse(buf *bytes.Buffer) []entry { + m := make(map[string]interface{}) + line, err := buf.ReadBytes('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(json.Unmarshal(line, &m)).To(Succeed()) + Expect(m).To(HaveKey("trace")) + var entries []entry + trace := m["trace"].(map[string]interface{}) + Expect(trace).To(HaveKey("common_fields")) + commonFields := trace["common_fields"].(map[string]interface{}) + Expect(commonFields).To(HaveKey("reference_time")) + referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) + Expect(trace).ToNot(HaveKey("events")) + + for buf.Len() > 0 { + line, err := buf.ReadBytes('\n') + Expect(err).ToNot(HaveOccurred()) + ev := make(map[string]interface{}) + Expect(json.Unmarshal(line, &ev)).To(Succeed()) + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKey("time")) + Expect(ev).To(HaveKey("name")) + Expect(ev).To(HaveKey("data")) + entries = append(entries, entry{ + Time: referenceTime.Add(time.Duration(ev["time"].(float64)*1e6) * time.Nanosecond), + Name: ev["name"].(string), + Event: ev["data"].(map[string]interface{}), + }) + } + return entries +} + +func exportAndParseSingle(buf *bytes.Buffer) entry { + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(1)) + return entries[0] +} + +var _ = Describe("Tracing", func() { + var ( + tracer *logging.ConnectionTracer + buf *bytes.Buffer + ) + + BeforeEach(func() { + buf = &bytes.Buffer{} + tracer = NewConnectionTracer( + nopWriteCloser(buf), + logging.PerspectiveServer, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + ) + }) + + It("exports a trace that has the right metadata", func() { + tracer.Close() + + m := make(map[string]interface{}) + Expect(json.Unmarshal(buf.Bytes(), &m)).To(Succeed()) + Expect(m).To(HaveKeyWithValue("qlog_version", "draft-02")) + Expect(m).To(HaveKey("title")) + Expect(m).To(HaveKey("trace")) + trace := m["trace"].(map[string]interface{}) + Expect(trace).To(HaveKey(("common_fields"))) + commonFields := trace["common_fields"].(map[string]interface{}) + Expect(commonFields).To(HaveKeyWithValue("ODCID", "deadbeef")) + Expect(commonFields).To(HaveKeyWithValue("group_id", "deadbeef")) + Expect(commonFields).To(HaveKey("reference_time")) + referenceTime := time.Unix(0, int64(commonFields["reference_time"].(float64)*1e6)) + Expect(referenceTime).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(commonFields).To(HaveKeyWithValue("time_format", "relative")) + Expect(trace).To(HaveKey("vantage_point")) + vantagePoint := trace["vantage_point"].(map[string]interface{}) + Expect(vantagePoint).To(HaveKeyWithValue("type", "server")) + }) + + Context("Events", func() { + It("records connection starts", func() { + tracer.StartedConnection( + &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 42}, + &net.UDPAddr{IP: net.IPv4(192, 168, 12, 34), Port: 24}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + protocol.ParseConnectionID([]byte{5, 6, 7, 8}), + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_started")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("ip_version", "ipv4")) + Expect(ev).To(HaveKeyWithValue("src_ip", "192.168.13.37")) + Expect(ev).To(HaveKeyWithValue("src_port", float64(42))) + Expect(ev).To(HaveKeyWithValue("dst_ip", "192.168.12.34")) + Expect(ev).To(HaveKeyWithValue("dst_port", float64(24))) + Expect(ev).To(HaveKeyWithValue("src_cid", "01020304")) + Expect(ev).To(HaveKeyWithValue("dst_cid", "05060708")) + }) + + It("records the version, if no version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, nil, nil) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + }) + + It("records the version, if version negotiation happened", func() { + tracer.NegotiatedVersion(0x1337, []logging.VersionNumber{1, 2, 3}, []logging.VersionNumber{4, 5, 6}) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:version_information")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("chosen_version", "1337")) + Expect(ev).To(HaveKey("client_versions")) + Expect(ev["client_versions"].([]interface{})).To(Equal([]interface{}{"1", "2", "3"})) + Expect(ev).To(HaveKey("server_versions")) + Expect(ev["server_versions"].([]interface{})).To(Equal([]interface{}{"4", "5", "6"})) + }) + + It("records idle timeouts", func() { + tracer.ClosedConnection(&quic.IdleTimeoutError{}) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("trigger", "idle_timeout")) + }) + + It("records handshake timeouts", func() { + tracer.ClosedConnection(&quic.HandshakeTimeoutError{}) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("trigger", "handshake_timeout")) + }) + + It("records a received stateless reset packet", func() { + tracer.ClosedConnection(&quic.StatelessResetError{ + Token: protocol.StatelessResetToken{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).To(HaveKeyWithValue("trigger", "stateless_reset")) + Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "00112233445566778899aabbccddeeff")) + }) + + It("records connection closing due to version negotiation failure", func() { + tracer.ClosedConnection(&quic.VersionNegotiationError{}) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("trigger", "version_mismatch")) + }) + + It("records application errors", func() { + tracer.ClosedConnection(&quic.ApplicationError{ + Remote: true, + ErrorCode: 1337, + ErrorMessage: "foobar", + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).To(HaveKeyWithValue("application_code", float64(1337))) + Expect(ev).To(HaveKeyWithValue("reason", "foobar")) + }) + + It("records transport errors", func() { + tracer.ClosedConnection(&quic.TransportError{ + ErrorCode: qerr.AEADLimitReached, + ErrorMessage: "foobar", + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:connection_closed")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("connection_code", "aead_limit_reached")) + Expect(ev).To(HaveKeyWithValue("reason", "foobar")) + }) + + It("records sent transport parameters", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) + tracer.SentTransportParameters(&logging.TransportParameters{ + InitialMaxStreamDataBidiLocal: 1000, + InitialMaxStreamDataBidiRemote: 2000, + InitialMaxStreamDataUni: 3000, + InitialMaxData: 4000, + MaxBidiStreamNum: 10, + MaxUniStreamNum: 20, + MaxAckDelay: 123 * time.Millisecond, + AckDelayExponent: 12, + DisableActiveMigration: true, + MaxUDPPayloadSize: 1234, + MaxIdleTimeout: 321 * time.Millisecond, + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + RetrySourceConnectionID: &rcid, + ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: protocol.InvalidByteCount, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKeyWithValue("original_destination_connection_id", "deadc0de")) + Expect(ev).To(HaveKeyWithValue("initial_source_connection_id", "deadbeef")) + Expect(ev).To(HaveKeyWithValue("retry_source_connection_id", "decafbad")) + Expect(ev).To(HaveKeyWithValue("stateless_reset_token", "112233445566778899aabbccddeeff00")) + Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(321))) + Expect(ev).To(HaveKeyWithValue("max_udp_payload_size", float64(1234))) + Expect(ev).To(HaveKeyWithValue("ack_delay_exponent", float64(12))) + Expect(ev).To(HaveKeyWithValue("active_connection_id_limit", float64(7))) + Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(4000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(1000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(2000))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(3000))) + Expect(ev).To(HaveKeyWithValue("initial_max_streams_bidi", float64(10))) + Expect(ev).To(HaveKeyWithValue("initial_max_streams_uni", float64(20))) + Expect(ev).ToNot(HaveKey("preferred_address")) + Expect(ev).ToNot(HaveKey("max_datagram_frame_size")) + }) + + It("records the server's transport parameters, without a stateless reset token", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + ActiveConnectionIDLimit: 7, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("stateless_reset_token")) + }) + + It("records transport parameters without retry_source_connection_id", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).ToNot(HaveKey("retry_source_connection_id")) + }) + + It("records transport parameters with a preferred address", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + PreferredAddress: &logging.PreferredAddress{ + IPv4: netip.AddrPortFrom(netip.AddrFrom4([4]byte{12, 34, 56, 78}), 123), + IPv6: netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 456), + ConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), + StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, + }, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "local")) + Expect(ev).To(HaveKey("preferred_address")) + pa := ev["preferred_address"].(map[string]interface{}) + Expect(pa).To(HaveKeyWithValue("ip_v4", "12.34.56.78")) + Expect(pa).To(HaveKeyWithValue("port_v4", float64(123))) + Expect(pa).To(HaveKeyWithValue("ip_v6", "102:304:506:708:90a:b0c:d0e:f10")) + Expect(pa).To(HaveKeyWithValue("port_v6", float64(456))) + Expect(pa).To(HaveKeyWithValue("connection_id", "0807060504030201")) + Expect(pa).To(HaveKeyWithValue("stateless_reset_token", "0f0e0d0c0b0a09080706050403020100")) + }) + + It("records transport parameters that enable the datagram extension", func() { + tracer.SentTransportParameters(&logging.TransportParameters{ + MaxDatagramFrameSize: 1337, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("max_datagram_frame_size", float64(1337))) + }) + + It("records received transport parameters", func() { + tracer.ReceivedTransportParameters(&logging.TransportParameters{}) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_set")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("owner", "remote")) + Expect(ev).ToNot(HaveKey("original_destination_connection_id")) + }) + + It("records restored transport parameters", func() { + tracer.RestoredTransportParameters(&logging.TransportParameters{ + InitialMaxStreamDataBidiLocal: 100, + InitialMaxStreamDataBidiRemote: 200, + InitialMaxStreamDataUni: 300, + InitialMaxData: 400, + MaxIdleTimeout: 123 * time.Millisecond, + }) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:parameters_restored")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("owner")) + Expect(ev).ToNot(HaveKey("original_destination_connection_id")) + Expect(ev).ToNot(HaveKey("stateless_reset_token")) + Expect(ev).ToNot(HaveKey("retry_source_connection_id")) + Expect(ev).ToNot(HaveKey("initial_source_connection_id")) + Expect(ev).To(HaveKeyWithValue("max_idle_timeout", float64(123))) + Expect(ev).To(HaveKeyWithValue("initial_max_data", float64(400))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_local", float64(100))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_bidi_remote", float64(200))) + Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(300))) + }) + + It("records a sent long header packet, without an ACK", func() { + tracer.SentLongHeaderPacket( + &logging.ExtendedHeader{ + Header: logging.Header{ + Type: protocol.PacketTypeHandshake, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), + Length: 1337, + Version: protocol.Version1, + }, + PacketNumber: 1337, + }, + 987, + logging.ECNCE, + nil, + []logging.Frame{ + &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, + &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, + }, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_sent")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(987))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(1337))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) + Expect(ev).To(HaveKey("frames")) + Expect(ev).To(HaveKeyWithValue("ecn", "CE")) + frames := ev["frames"].([]interface{}) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_stream_data")) + Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "stream")) + }) + + It("records a sent short header packet, without an ACK", func() { + tracer.SentShortHeaderPacket( + &logging.ShortHeader{ + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + PacketNumber: 1337, + }, + 123, + logging.ECNUnsupported, + &logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}}, + []logging.Frame{&logging.MaxDataFrame{MaximumData: 987}}, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + ev := entry.Event + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(123))) + Expect(raw).ToNot(HaveKey("payload_length")) + Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("ecn")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(ev).To(HaveKey("frames")) + frames := ev["frames"].([]interface{}) + Expect(frames).To(HaveLen(2)) + Expect(frames[0].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "ack")) + Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "max_data")) + }) + + It("records a received Long Header packet", func() { + tracer.ReceivedLongHeaderPacket( + &logging.ExtendedHeader{ + Header: logging.Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Length: 1234, + Version: protocol.Version1, + }, + PacketNumber: 1337, + }, + 789, + logging.ECT0, + []logging.Frame{ + &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, + &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, + }, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(789))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(1234))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(0)")) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "initial")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(hdr).To(HaveKeyWithValue("scid", "04030201")) + Expect(hdr).To(HaveKey("token")) + token := hdr["token"].(map[string]interface{}) + Expect(token).To(HaveKeyWithValue("data", "deadbeef")) + Expect(ev).To(HaveKey("frames")) + Expect(ev["frames"].([]interface{})).To(HaveLen(2)) + }) + + It("records a received Short Header packet", func() { + shdr := &logging.ShortHeader{ + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen3, + KeyPhase: protocol.KeyPhaseZero, + } + tracer.ReceivedShortHeaderPacket( + shdr, + 789, + logging.ECT1, + []logging.Frame{ + &logging.MaxStreamDataFrame{StreamID: 42, MaximumStreamData: 987}, + &logging.StreamFrame{StreamID: 123, Offset: 1234, Length: 6, Fin: true}, + }, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + raw := ev["raw"].(map[string]interface{}) + Expect(raw).To(HaveKeyWithValue("length", float64(789))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3)))) + Expect(ev).To(HaveKeyWithValue("ecn", "ECT(1)")) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(1337))) + Expect(hdr).To(HaveKeyWithValue("key_phase_bit", "0")) + Expect(ev).To(HaveKey("frames")) + Expect(ev["frames"].([]interface{})).To(HaveLen(2)) + }) + + It("records a received Retry packet", func() { + tracer.ReceivedRetry( + &logging.Header{ + Type: protocol.PacketTypeRetry, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), + Token: []byte{0xde, 0xad, 0xbe, 0xef}, + Version: protocol.Version1, + }, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).ToNot(HaveKey("raw")) + Expect(ev).To(HaveKey("header")) + header := ev["header"].(map[string]interface{}) + Expect(header).To(HaveKeyWithValue("packet_type", "retry")) + Expect(header).ToNot(HaveKey("packet_number")) + Expect(header).To(HaveKey("version")) + Expect(header).To(HaveKey("dcid")) + Expect(header).To(HaveKey("scid")) + Expect(header).To(HaveKey("token")) + token := header["token"].(map[string]interface{}) + Expect(token).To(HaveKeyWithValue("data", "deadbeef")) + Expect(ev).ToNot(HaveKey("frames")) + }) + + It("records a received Version Negotiation packet", func() { + tracer.ReceivedVersionNegotiationPacket( + protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, + []protocol.Version{0xdeadbeef, 0xdecafbad}, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_received")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + Expect(ev).ToNot(HaveKey("frames")) + Expect(ev).To(HaveKey("supported_versions")) + Expect(ev["supported_versions"].([]interface{})).To(Equal([]interface{}{"deadbeef", "decafbad"})) + header := ev["header"] + Expect(header).To(HaveKeyWithValue("packet_type", "version_negotiation")) + Expect(header).ToNot(HaveKey("packet_number")) + Expect(header).ToNot(HaveKey("version")) + Expect(header).To(HaveKeyWithValue("dcid", "0102030405060708")) + Expect(header).To(HaveKeyWithValue("scid", "04030201")) + }) + + It("records buffered packets", func() { + tracer.BufferedPacket(logging.PacketTypeHandshake, 1337) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_buffered")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(1)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(ev).To(HaveKey("raw")) + Expect(ev["raw"].(map[string]interface{})).To(HaveKeyWithValue("length", float64(1337))) + Expect(ev).To(HaveKeyWithValue("trigger", "keys_unavailable")) + }) + + It("records dropped packets", func() { + tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, 1337, logging.PacketDropPayloadDecryptError) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_dropped")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + Expect(ev["raw"].(map[string]interface{})).To(HaveKeyWithValue("length", float64(1337))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(1)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "retry")) + Expect(ev).To(HaveKeyWithValue("trigger", "payload_decrypt_error")) + }) + + It("records dropped packets with a packet number", func() { + tracer.DroppedPacket(logging.PacketTypeHandshake, 42, 1337, logging.PacketDropDuplicate) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:packet_dropped")) + ev := entry.Event + Expect(ev).To(HaveKey("raw")) + Expect(ev["raw"].(map[string]interface{})).To(HaveKeyWithValue("length", float64(1337))) + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(2)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(42))) + Expect(ev).To(HaveKeyWithValue("trigger", "duplicate")) + }) + + It("records metrics updates", func() { + now := time.Now() + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(15*time.Millisecond, 0, now) + rttStats.UpdateRTT(20*time.Millisecond, 0, now) + rttStats.UpdateRTT(25*time.Millisecond, 0, now) + Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) + Expect(rttStats.SmoothedRTT()).To(And( + BeNumerically(">", 15*time.Millisecond), + BeNumerically("<", 25*time.Millisecond), + )) + Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) + tracer.UpdatedMetrics( + rttStats, + 4321, + 1234, + 42, + ) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:metrics_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("min_rtt", float64(15))) + Expect(ev).To(HaveKeyWithValue("latest_rtt", float64(25))) + Expect(ev).To(HaveKey("smoothed_rtt")) + Expect(time.Duration(ev["smoothed_rtt"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.SmoothedRTT(), time.Millisecond)) + Expect(ev).To(HaveKey("rtt_variance")) + Expect(time.Duration(ev["rtt_variance"].(float64)) * time.Millisecond).To(BeNumerically("~", rttStats.MeanDeviation(), time.Millisecond)) + Expect(ev).To(HaveKeyWithValue("congestion_window", float64(4321))) + Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(1234))) + Expect(ev).To(HaveKeyWithValue("packets_in_flight", float64(42))) + }) + + It("only logs the diff between two metrics updates", func() { + now := time.Now() + rttStats := utils.NewRTTStats() + rttStats.UpdateRTT(15*time.Millisecond, 0, now) + rttStats.UpdateRTT(20*time.Millisecond, 0, now) + rttStats.UpdateRTT(25*time.Millisecond, 0, now) + Expect(rttStats.MinRTT()).To(Equal(15 * time.Millisecond)) + + rttStats2 := utils.NewRTTStats() + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + rttStats2.UpdateRTT(15*time.Millisecond, 0, now) + Expect(rttStats2.MinRTT()).To(Equal(15 * time.Millisecond)) + + Expect(rttStats.LatestRTT()).To(Equal(25 * time.Millisecond)) + tracer.UpdatedMetrics( + rttStats, + 4321, + 1234, + 42, + ) + tracer.UpdatedMetrics( + rttStats2, + 4321, + 12345, // changed + 42, + ) + tracer.Close() + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(2)) + Expect(entries[0].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entries[0].Name).To(Equal("recovery:metrics_updated")) + Expect(entries[0].Event).To(HaveLen(7)) + Expect(entries[1].Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entries[1].Name).To(Equal("recovery:metrics_updated")) + ev := entries[1].Event + Expect(ev).ToNot(HaveKey("min_rtt")) + Expect(ev).ToNot(HaveKey("congestion_window")) + Expect(ev).ToNot(HaveKey("packets_in_flight")) + Expect(ev).To(HaveKeyWithValue("bytes_in_flight", float64(12345))) + Expect(ev).To(HaveKeyWithValue("smoothed_rtt", float64(15))) + }) + + It("records lost packets", func() { + tracer.LostPacket(protocol.EncryptionHandshake, 42, logging.PacketLossReorderingThreshold) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:packet_lost")) + ev := entry.Event + Expect(ev).To(HaveKey("header")) + hdr := ev["header"].(map[string]interface{}) + Expect(hdr).To(HaveLen(2)) + Expect(hdr).To(HaveKeyWithValue("packet_type", "handshake")) + Expect(hdr).To(HaveKeyWithValue("packet_number", float64(42))) + Expect(ev).To(HaveKeyWithValue("trigger", "reordering_threshold")) + }) + + It("records congestion state updates", func() { + tracer.UpdatedCongestionState(logging.CongestionStateCongestionAvoidance) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:congestion_state_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("new", "congestion_avoidance")) + }) + + It("records PTO changes", func() { + tracer.UpdatedPTOCount(42) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:metrics_updated")) + Expect(entry.Event).To(HaveKeyWithValue("pto_count", float64(42))) + }) + + It("records TLS key updates", func() { + tracer.UpdatedKeyFromTLS(protocol.EncryptionHandshake, protocol.PerspectiveClient) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_type", "client_handshake_secret")) + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).ToNot(HaveKey("key_phase")) + Expect(ev).ToNot(HaveKey("old")) + Expect(ev).ToNot(HaveKey("new")) + }) + + It("records TLS key updates, for 1-RTT keys", func() { + tracer.UpdatedKeyFromTLS(protocol.Encryption1RTT, protocol.PerspectiveServer) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_type", "server_1rtt_secret")) + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKeyWithValue("key_phase", float64(0))) + Expect(ev).ToNot(HaveKey("old")) + Expect(ev).ToNot(HaveKey("new")) + }) + + It("records QUIC key updates", func() { + tracer.UpdatedKey(1337, true) + tracer.Close() + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_updated")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_phase", float64(1337))) + Expect(ev).To(HaveKeyWithValue("trigger", "remote_update")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_1rtt_secret")) + Expect(keyTypes).To(ContainElement("client_1rtt_secret")) + }) + + It("records dropped encryption levels", func() { + tracer.DroppedEncryptionLevel(protocol.EncryptionInitial) + tracer.Close() + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_discarded")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_initial_secret")) + Expect(keyTypes).To(ContainElement("client_initial_secret")) + }) + + It("records dropped 0-RTT keys", func() { + tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) + tracer.Close() + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(1)) + entry := entries[0] + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_discarded")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("trigger", "tls")) + Expect(ev).To(HaveKeyWithValue("key_type", "server_0rtt_secret")) + }) + + It("records dropped keys", func() { + tracer.DroppedKey(42) + tracer.Close() + entries := exportAndParse(buf) + Expect(entries).To(HaveLen(2)) + var keyTypes []string + for _, entry := range entries { + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("security:key_discarded")) + ev := entry.Event + Expect(ev).To(HaveKeyWithValue("key_phase", float64(42))) + Expect(ev).ToNot(HaveKey("trigger")) + Expect(ev).To(HaveKey("key_type")) + keyTypes = append(keyTypes, ev["key_type"].(string)) + } + Expect(keyTypes).To(ContainElement("server_1rtt_secret")) + Expect(keyTypes).To(ContainElement("client_1rtt_secret")) + }) + + It("records when the timer is set", func() { + timeout := time.Now().Add(137 * time.Millisecond) + tracer.SetLossTimer(logging.TimerTypePTO, protocol.EncryptionHandshake, timeout) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(4)) + Expect(ev).To(HaveKeyWithValue("event_type", "set")) + Expect(ev).To(HaveKeyWithValue("timer_type", "pto")) + Expect(ev).To(HaveKeyWithValue("packet_number_space", "handshake")) + Expect(ev).To(HaveKey("delta")) + delta := time.Duration(ev["delta"].(float64)*1e6) * time.Nanosecond + Expect(entry.Time.Add(delta)).To(BeTemporally("~", timeout, scaleDuration(10*time.Microsecond))) + }) + + It("records when the loss timer expires", func() { + tracer.LossTimerExpired(logging.TimerTypeACK, protocol.Encryption1RTT) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(3)) + Expect(ev).To(HaveKeyWithValue("event_type", "expired")) + Expect(ev).To(HaveKeyWithValue("timer_type", "ack")) + Expect(ev).To(HaveKeyWithValue("packet_number_space", "application_data")) + }) + + It("records when the timer is canceled", func() { + tracer.LossTimerCanceled() + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:loss_timer_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("event_type", "cancelled")) + }) + + It("records an ECN state transition, without a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("new", "unknown")) + }) + + It("records an ECN state transition, with a trigger", func() { + tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("recovery:ecn_state_updated")) + ev := entry.Event + Expect(ev).To(HaveLen(2)) + Expect(ev).To(HaveKeyWithValue("new", "failed")) + Expect(ev).To(HaveKeyWithValue("trigger", "ACK doesn't contain ECN marks")) + }) + + It("records a generic event", func() { + tracer.Debug("foo", "bar") + tracer.Close() + entry := exportAndParseSingle(buf) + Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) + Expect(entry.Name).To(Equal("transport:foo")) + ev := entry.Event + Expect(ev).To(HaveLen(1)) + Expect(ev).To(HaveKeyWithValue("details", "bar")) + }) + }) +}) diff --git a/qlog/event.go b/qlog/event.go index ad90824dc..5a0e26086 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "net/netip" "time" quic "github.com/refraction-networking/uquic" @@ -232,6 +233,20 @@ func (e eventVersionNegotiationReceived) MarshalJSONObject(enc *gojay.Encoder) { enc.ArrayKey("supported_versions", versions(e.SupportedVersions)) } +type eventVersionNegotiationSent struct { + Header packetHeaderVersionNegotiation + SupportedVersions []versionNumber +} + +func (e eventVersionNegotiationSent) Category() category { return categoryTransport } +func (e eventVersionNegotiationSent) Name() string { return "packet_sent" } +func (e eventVersionNegotiationSent) IsNil() bool { return false } + +func (e eventVersionNegotiationSent) MarshalJSONObject(enc *gojay.Encoder) { + enc.ObjectKey("header", e.Header) + enc.ArrayKey("supported_versions", versions(e.SupportedVersions)) +} + type eventPacketBuffered struct { PacketType logging.PacketType PacketSize protocol.ByteCount @@ -243,15 +258,16 @@ func (e eventPacketBuffered) IsNil() bool { return false } func (e eventPacketBuffered) MarshalJSONObject(enc *gojay.Encoder) { //nolint:gosimple - enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) + enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType, PacketNumber: protocol.InvalidPacketNumber}) enc.ObjectKey("raw", rawInfo{Length: e.PacketSize}) enc.StringKey("trigger", "keys_unavailable") } type eventPacketDropped struct { - PacketType logging.PacketType - PacketSize protocol.ByteCount - Trigger packetDropReason + PacketType logging.PacketType + PacketSize protocol.ByteCount + PacketNumber logging.PacketNumber + Trigger packetDropReason } func (e eventPacketDropped) Category() category { return categoryTransport } @@ -259,7 +275,10 @@ func (e eventPacketDropped) Name() string { return "packet_dropped" } func (e eventPacketDropped) IsNil() bool { return false } func (e eventPacketDropped) MarshalJSONObject(enc *gojay.Encoder) { - enc.ObjectKey("header", packetHeaderWithType{PacketType: e.PacketType}) + enc.ObjectKey("header", packetHeaderWithType{ + PacketType: e.PacketType, + PacketNumber: e.PacketNumber, + }) enc.ObjectKey("raw", rawInfo{Length: e.PacketSize}) enc.StringKey("trigger", e.Trigger.String()) } @@ -340,9 +359,9 @@ func (e eventPacketLost) MarshalJSONObject(enc *gojay.Encoder) { } type eventKeyUpdated struct { - Trigger keyUpdateTrigger - KeyType keyType - Generation protocol.KeyPhase + Trigger keyUpdateTrigger + KeyType keyType + KeyPhase protocol.KeyPhase // we don't log the keys here, so we don't need `old` and `new`. } @@ -354,13 +373,13 @@ func (e eventKeyUpdated) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("trigger", e.Trigger.String()) enc.StringKey("key_type", e.KeyType.String()) if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { - enc.Uint64Key("generation", uint64(e.Generation)) + enc.Uint64Key("key_phase", uint64(e.KeyPhase)) } } type eventKeyDiscarded struct { - KeyType keyType - Generation protocol.KeyPhase + KeyType keyType + KeyPhase protocol.KeyPhase } func (e eventKeyDiscarded) Category() category { return categorySecurity } @@ -373,7 +392,7 @@ func (e eventKeyDiscarded) MarshalJSONObject(enc *gojay.Encoder) { } enc.StringKey("key_type", e.KeyType.String()) if e.KeyType == keyTypeClient1RTT || e.KeyType == keyTypeServer1RTT { - enc.Uint64Key("generation", uint64(e.Generation)) + enc.Uint64Key("key_phase", uint64(e.KeyPhase)) } } @@ -452,8 +471,7 @@ func (e eventTransportParameters) MarshalJSONObject(enc *gojay.Encoder) { } type preferredAddress struct { - IPv4, IPv6 net.IP - PortV4, PortV6 uint16 + IPv4, IPv6 netip.AddrPort ConnectionID protocol.ConnectionID StatelessResetToken protocol.StatelessResetToken } @@ -462,10 +480,10 @@ var _ gojay.MarshalerJSONObject = &preferredAddress{} func (a preferredAddress) IsNil() bool { return false } func (a preferredAddress) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("ip_v4", a.IPv4.String()) - enc.Uint16Key("port_v4", a.PortV4) - enc.StringKey("ip_v6", a.IPv6.String()) - enc.Uint16Key("port_v6", a.PortV6) + enc.StringKey("ip_v4", a.IPv4.Addr().String()) + enc.Uint16Key("port_v4", a.IPv4.Port()) + enc.StringKey("ip_v6", a.IPv6.Addr().String()) + enc.Uint16Key("port_v6", a.IPv6.Port()) enc.StringKey("connection_id", a.ConnectionID.String()) enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", a.StatelessResetToken)) } @@ -550,3 +568,15 @@ func (e eventGeneric) IsNil() bool { return false } func (e eventGeneric) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("details", e.msg) } + +type eventALPNInformation struct { + chosenALPN string +} + +func (e eventALPNInformation) Category() category { return categoryTransport } +func (e eventALPNInformation) Name() string { return "alpn_information" } +func (e eventALPNInformation) IsNil() bool { return false } + +func (e eventALPNInformation) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("chosen_alpn", e.chosenALPN) +} diff --git a/qlog/packet_header.go b/qlog/packet_header.go index 0fbdb6e37..f2e0e31a1 100644 --- a/qlog/packet_header.go +++ b/qlog/packet_header.go @@ -110,14 +110,18 @@ func (h packetHeaderVersionNegotiation) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("dcid", h.DestConnectionID.String()) } -// a minimal header that only outputs the packet type +// a minimal header that only outputs the packet type, and potentially a packet number type packetHeaderWithType struct { - PacketType logging.PacketType + PacketType logging.PacketType + PacketNumber logging.PacketNumber } func (h packetHeaderWithType) IsNil() bool { return false } func (h packetHeaderWithType) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("packet_type", packetType(h.PacketType).String()) + if h.PacketNumber != protocol.InvalidPacketNumber { + enc.Int64Key("packet_number", int64(h.PacketNumber)) + } } // a minimal header that only outputs the packet type diff --git a/qlog/packet_header_test.go b/qlog/packet_header_test.go index 1c4cdb2aa..e51e3229e 100644 --- a/qlog/packet_header_test.go +++ b/qlog/packet_header_test.go @@ -39,7 +39,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ Type: protocol.PacketTypeInitial, Length: 123, - Version: protocol.VersionNumber(0xdecafbad), + Version: protocol.Version(0xdecafbad), }, }, map[string]interface{}{ @@ -59,7 +59,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ Type: protocol.PacketTypeInitial, Length: 123, - Version: protocol.VersionNumber(0xdecafbad), + Version: protocol.Version(0xdecafbad), Token: []byte{0xde, 0xad, 0xbe, 0xef}, }, }, @@ -80,7 +80,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ Type: protocol.PacketTypeRetry, SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}), - Version: protocol.VersionNumber(0xdecafbad), + Version: protocol.Version(0xdecafbad), Token: []byte{0xde, 0xad, 0xbe, 0xef}, }, }, @@ -101,7 +101,7 @@ var _ = Describe("Packet Header", func() { PacketNumber: 0, Header: wire.Header{ Type: protocol.PacketTypeHandshake, - Version: protocol.VersionNumber(0xdecafbad), + Version: protocol.Version(0xdecafbad), }, }, map[string]interface{}{ @@ -121,7 +121,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ Type: protocol.PacketTypeHandshake, SrcConnectionID: protocol.ParseConnectionID([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}), - Version: protocol.VersionNumber(0xdecafbad), + Version: protocol.Version(0xdecafbad), }, }, map[string]interface{}{ diff --git a/qlog/qlog_dir.go b/qlog/qlog_dir.go new file mode 100644 index 000000000..39a0938a4 --- /dev/null +++ b/qlog/qlog_dir.go @@ -0,0 +1,49 @@ +package qlog + +import ( + "bufio" + "context" + "fmt" + "log" + "os" + "strings" + + "github.com/refraction-networking/uquic/internal/utils" + "github.com/refraction-networking/uquic/logging" +) + +// DefaultTracer creates a qlog file in the qlog directory specified by the QLOGDIR environment variable. +// File names are _.qlog. +// Returns nil if QLOGDIR is not set. +func DefaultTracer(_ context.Context, p logging.Perspective, connID logging.ConnectionID) *logging.ConnectionTracer { + var label string + switch p { + case logging.PerspectiveClient: + label = "client" + case logging.PerspectiveServer: + label = "server" + } + return qlogDirTracer(p, connID, label) +} + +// qlogDirTracer creates a qlog file in the qlog directory specified by the QLOGDIR environment variable. +// File names are _