Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RFC compatible behaviour with session ticket #525

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def _clientSendClientHello(self, settings, session, srpUsername,
clientHello = ClientHello()
clientHello.create(sent_version, getRandomBytes(32),
session.sessionID, wireCipherSuites,
certificateTypes,
certificateTypes,
session.srpUsername,
reqTack, nextProtos is not None,
session.serverName,
Expand All @@ -836,9 +836,9 @@ def _clientSendClientHello(self, settings, session, srpUsername,
clientHello = ClientHello()
clientHello.create(sent_version, getRandomBytes(32),
session_id, wireCipherSuites,
certificateTypes,
certificateTypes,
srpUsername,
reqTack, nextProtos is not None,
reqTack, nextProtos is not None,
serverName,
extensions=extensions)

Expand Down Expand Up @@ -1083,7 +1083,7 @@ def _clientGetServerHello(self, settings, session, clientHello):
AlertDescription.illegal_parameter,
"Server responded with incorrect compression method"):
yield result
if serverHello.tackExt:
if serverHello.tackExt:
if not clientHello.tack:
for result in self._sendError(\
AlertDescription.illegal_parameter,
Expand Down Expand Up @@ -1605,7 +1605,7 @@ def _clientSelectNextProto(self, nextProtos, serverHello):
#
# !!! We assume the client may have specified nextProtos as a list of
# strings so we convert them to bytearrays (it's awkward to require
# the user to specify a list of bytearrays or "bytes", and in
# the user to specify a list of bytearrays or "bytes", and in
# Python 2.6 bytes() is just an alias for str() anyways...
if nextProtos is not None and serverHello.next_protos is not None:
for p in nextProtos:
Expand All @@ -1617,7 +1617,7 @@ def _clientSelectNextProto(self, nextProtos, serverHello):
# the client SHOULD select the first protocol it supports.
return bytearray(nextProtos[0])
return None

def _clientResume(self, session, serverHello, clientRandom,
nextProto, settings):

Expand Down Expand Up @@ -1855,8 +1855,8 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom,
cipherSuite,
clientRandom,
serverRandom)
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
cipherImplementations)

#Exchange ChangeCipherSpec and Finished messages
Expand Down Expand Up @@ -1967,13 +1967,13 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None):
if tackpyLoaded:
if not tack_ext:
tack_ext = cert_chain.getTackExt()

# If there's a TACK (whether via TLS or TACK Cert), check that it
# matches the cert chain
# matches the cert chain
if tack_ext and tack_ext.tacks:
for tack in tack_ext.tacks:
if not cert_chain.checkTack(tack):
for result in self._sendError(
for result in self._sendError(
AlertDescription.illegal_parameter,
"Other party's TACK doesn't match their public key"):
yield result
Expand All @@ -1989,7 +1989,7 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None):
def handshakeServer(self, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None,
reqCAs = None,
reqCAs=None,
tacks=None, activationFlags=0,
nextProtos=None, anon=False, alpn=None, sni=None):
"""Perform a handshake in the role of server.
Expand Down Expand Up @@ -2090,7 +2090,7 @@ def handshakeServer(self, verifierDB=None,
def handshakeServerAsync(self, verifierDB=None,
certChain=None, privateKey=None, reqCert=False,
sessionCache=None, settings=None, checker=None,
reqCAs=None,
reqCAs=None,
tacks=None, activationFlags=0,
nextProtos=None, anon=False, alpn=None, sni=None
):
Expand All @@ -2108,19 +2108,19 @@ def handshakeServerAsync(self, verifierDB=None,
handshaker = self._handshakeServerAsyncHelper(\
verifierDB=verifierDB, cert_chain=certChain,
privateKey=privateKey, reqCert=reqCert,
sessionCache=sessionCache, settings=settings,
reqCAs=reqCAs,
tacks=tacks, activationFlags=activationFlags,
sessionCache=sessionCache, settings=settings,
reqCAs=reqCAs,
tacks=tacks, activationFlags=activationFlags,
nextProtos=nextProtos, anon=anon, alpn=alpn, sni=sni)
for result in self._handshakeWrapperAsync(handshaker, checker):
yield result


def _handshakeServerAsyncHelper(self, verifierDB,
cert_chain, privateKey, reqCert, sessionCache,
settings, reqCAs,
tacks, activationFlags,
nextProtos, anon, alpn, sni):
cert_chain, privateKey, reqCert,
sessionCache, settings, reqCAs, tacks,
activationFlags, nextProtos, anon, alpn,
sni):

self._handshakeStart(client=False)

Expand All @@ -2136,7 +2136,7 @@ def _handshakeServerAsyncHelper(self, verifierDB,
if privateKey and not cert_chain:
raise ValueError("Caller passed a privateKey but no cert_chain")
if reqCAs and not reqCert:
raise ValueError("Caller passed reqCAs but not reqCert")
raise ValueError("Caller passed reqCAs but not reqCert")
if cert_chain and not isinstance(cert_chain, X509CertChain):
raise ValueError("Unrecognized certificate type")
if activationFlags and not tacks:
Expand All @@ -2153,16 +2153,16 @@ def _handshakeServerAsyncHelper(self, verifierDB,

# OK Start exchanging messages
# ******************************

# Handle ClientHello and resumption
for result in self._serverGetClientHello(settings, privateKey,
cert_chain,
verifierDB, sessionCache,
anon, alpn, sni):
if result in (0,1): yield result
elif result == None:
self._handshakeDone(resumed=True)
return # Handshake was resumed, we're done
self._handshakeDone(resumed=True)
return # Handshake was resumed, we're done
else: break
(clientHello, version, cipherSuite, sig_scheme, privateKey,
cert_chain) = result
Expand Down Expand Up @@ -2191,7 +2191,7 @@ def _handshakeServerAsyncHelper(self, verifierDB,
sessionID = getRandomBytes(32)
else:
sessionID = bytearray(0)

if not clientHello.supports_npn:
nextProtos = None

Expand Down Expand Up @@ -2292,7 +2292,9 @@ def _handshakeServerAsyncHelper(self, verifierDB,
# send a new ticket in a NewSessionTicket message
send_session_ticket = False
session_ticket = clientHello.getExtension(ExtensionType.session_ticket)
if session_ticket and len(session_ticket.ticket) == 0:
enable_ticket = settings.ticket_count > 0 and settings.ticketKeys
if session_ticket and len(session_ticket.ticket) == 0 \
and enable_ticket:
send_session_ticket = True
extensions.append(SessionTicketExtension().create(
bytearray(0)))
Expand Down Expand Up @@ -3695,9 +3697,9 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
yield result

#Calculate pending connection states
self._calcPendingStates(session.cipherSuite,
self._calcPendingStates(session.cipherSuite,
session.masterSecret,
clientHello.random,
clientHello.random,
serverHello.random,
settings.cipherImplementations)

Expand Down Expand Up @@ -4422,7 +4424,7 @@ def _serverFinished(self, premasterSecret, clientRandom, serverRandom,
self.session.masterSecret = masterSecret

#Calculate pending connection states
self._calcPendingStates(cipherSuite, masterSecret,
self._calcPendingStates(cipherSuite, masterSecret,
clientRandom, serverRandom,
cipherImplementations)

Expand Down Expand Up @@ -4534,7 +4536,7 @@ def _getFinished(self, masterSecret, cipherSuite=None,
#Switch to pending read state
self._changeReadState()

#Server Finish - Are we waiting for a next protocol echo?
#Server Finish - Are we waiting for a next protocol echo?
if expect_next_protocol:
for result in self._getMsg(ContentType.handshake, HandshakeType.next_protocol):
if result in (0,1):
Expand Down
Loading