Skip to content

Commit

Permalink
Make RFC compatible behaviour with session ticket
Browse files Browse the repository at this point in the history
Prevent the server from echoing the session_ticket extension when either
ticket_count or ticketKeys are not set in settings.
#524
  • Loading branch information
GeorgePantelakis committed Aug 12, 2024
1 parent 4d2c6b8 commit 667da27
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def _clientSendClientHello(self, settings, session, srpUsername,
clientHello = ClientHello()
clientHello.create(sent_version, getRandomBytes(32),
session.sessionID, wireCipherSuites,
certificateTypes,
certificateTypes,
session.srpUsername,
reqTack, nextProtos is not None,
session.serverName,
Expand All @@ -836,9 +836,9 @@ def _clientSendClientHello(self, settings, session, srpUsername,
clientHello = ClientHello()
clientHello.create(sent_version, getRandomBytes(32),
session_id, wireCipherSuites,
certificateTypes,
certificateTypes,
srpUsername,
reqTack, nextProtos is not None,
reqTack, nextProtos is not None,
serverName,
extensions=extensions)

Expand Down Expand Up @@ -1083,7 +1083,7 @@ def _clientGetServerHello(self, settings, session, clientHello):
AlertDescription.illegal_parameter,
"Server responded with incorrect compression method"):
yield result
if serverHello.tackExt:
if serverHello.tackExt:
if not clientHello.tack:
for result in self._sendError(\
AlertDescription.illegal_parameter,
Expand Down Expand Up @@ -1605,7 +1605,7 @@ def _clientSelectNextProto(self, nextProtos, serverHello):
#
# !!! We assume the client may have specified nextProtos as a list of
# strings so we convert them to bytearrays (it's awkward to require
# the user to specify a list of bytearrays or "bytes", and in
# the user to specify a list of bytearrays or "bytes", and in
# Python 2.6 bytes() is just an alias for str() anyways...
if nextProtos is not None and serverHello.next_protos is not None:
for p in nextProtos:
Expand All @@ -1617,7 +1617,7 @@ def _clientSelectNextProto(self, nextProtos, serverHello):
# the client SHOULD select the first protocol it supports.
return bytearray(nextProtos[0])
return None

def _clientResume(self, session, serverHello, clientRandom,
nextProto, settings):

Expand Down Expand Up @@ -1855,8 +1855,8 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom,
cipherSuite,
clientRandom,
serverRandom)
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
cipherImplementations)

#Exchange ChangeCipherSpec and Finished messages
Expand Down Expand Up @@ -1967,13 +1967,13 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None):
if tackpyLoaded:
if not tack_ext:
tack_ext = cert_chain.getTackExt()

# If there's a TACK (whether via TLS or TACK Cert), check that it
# matches the cert chain
# matches the cert chain
if tack_ext and tack_ext.tacks:
for tack in tack_ext.tacks:
if not cert_chain.checkTack(tack):
for result in self._sendError(
for result in self._sendError(
AlertDescription.illegal_parameter,
"Other party's TACK doesn't match their public key"):
yield result
Expand All @@ -1989,7 +1989,7 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None):
def handshakeServer(self, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None,
reqCAs = None,
reqCAs=None,
tacks=None, activationFlags=0,
nextProtos=None, anon=False, alpn=None, sni=None):
"""Perform a handshake in the role of server.
Expand Down Expand Up @@ -2090,7 +2090,7 @@ def handshakeServer(self, verifierDB=None,
def handshakeServerAsync(self, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None,
reqCAs=None,
reqCAs=None,
tacks=None, activationFlags=0,
nextProtos=None, anon=False, alpn=None, sni=None
):
Expand All @@ -2108,19 +2108,19 @@ def handshakeServerAsync(self, verifierDB=None,
handshaker = self._handshakeServerAsyncHelper(\
verifierDB=verifierDB, cert_chain=certChain,
privateKey=privateKey, reqCert=reqCert,
sessionCache=sessionCache, settings=settings,
reqCAs=reqCAs,
tacks=tacks, activationFlags=activationFlags,
sessionCache=sessionCache, settings=settings,
reqCAs=reqCAs,
tacks=tacks, activationFlags=activationFlags,
nextProtos=nextProtos, anon=anon, alpn=alpn, sni=sni)
for result in self._handshakeWrapperAsync(handshaker, checker):
yield result


def _handshakeServerAsyncHelper(self, verifierDB,
cert_chain, privateKey, reqCert, sessionCache,
settings, reqCAs,
tacks, activationFlags,
nextProtos, anon, alpn, sni):
cert_chain, privateKey, reqCert,
sessionCache, settings, reqCAs, tacks,
activationFlags, nextProtos, anon, alpn,
sni):

self._handshakeStart(client=False)

Expand All @@ -2136,7 +2136,7 @@ def _handshakeServerAsyncHelper(self, verifierDB,
if privateKey and not cert_chain:
raise ValueError("Caller passed a privateKey but no cert_chain")
if reqCAs and not reqCert:
raise ValueError("Caller passed reqCAs but not reqCert")
raise ValueError("Caller passed reqCAs but not reqCert")
if cert_chain and not isinstance(cert_chain, X509CertChain):
raise ValueError("Unrecognized certificate type")
if activationFlags and not tacks:
Expand All @@ -2153,16 +2153,16 @@ def _handshakeServerAsyncHelper(self, verifierDB,

# OK Start exchanging messages
# ******************************

# Handle ClientHello and resumption
for result in self._serverGetClientHello(settings, privateKey,
cert_chain,
verifierDB, sessionCache,
anon, alpn, sni):
if result in (0,1): yield result
elif result == None:
self._handshakeDone(resumed=True)
return # Handshake was resumed, we're done
self._handshakeDone(resumed=True)
return # Handshake was resumed, we're done
else: break
(clientHello, version, cipherSuite, sig_scheme, privateKey,
cert_chain) = result
Expand Down Expand Up @@ -2191,7 +2191,7 @@ def _handshakeServerAsyncHelper(self, verifierDB,
sessionID = getRandomBytes(32)
else:
sessionID = bytearray(0)

if not clientHello.supports_npn:
nextProtos = None

Expand Down Expand Up @@ -2292,7 +2292,8 @@ def _handshakeServerAsyncHelper(self, verifierDB,
# send a new ticket in a NewSessionTicket message
send_session_ticket = False
session_ticket = clientHello.getExtension(ExtensionType.session_ticket)
if session_ticket and len(session_ticket.ticket) == 0:
enable_ticket = settings.ticket_count > 0 and settings.ticketKeys
if session_ticket and session_ticket.ticket and enable_ticket:
send_session_ticket = True
extensions.append(SessionTicketExtension().create(
bytearray(0)))
Expand Down Expand Up @@ -3695,9 +3696,9 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
yield result

#Calculate pending connection states
self._calcPendingStates(session.cipherSuite,
self._calcPendingStates(session.cipherSuite,
session.masterSecret,
clientHello.random,
clientHello.random,
serverHello.random,
settings.cipherImplementations)

Expand Down Expand Up @@ -4422,7 +4423,7 @@ def _serverFinished(self, premasterSecret, clientRandom, serverRandom,
self.session.masterSecret = masterSecret

#Calculate pending connection states
self._calcPendingStates(cipherSuite, masterSecret,
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
cipherImplementations)

Expand Down Expand Up @@ -4534,7 +4535,7 @@ def _getFinished(self, masterSecret, cipherSuite=None,
#Switch to pending read state
self._changeReadState()

#Server Finish - Are we waiting for a next protocol echo?
#Server Finish - Are we waiting for a next protocol echo?
if expect_next_protocol:
for result in self._getMsg(ContentType.handshake, HandshakeType.next_protocol):
if result in (0,1):
Expand Down

0 comments on commit 667da27

Please sign in to comment.