Skip to content

Commit

Permalink
corrections
Browse files Browse the repository at this point in the history
tests for EC point format

check version before check extension
  • Loading branch information
gstarovo committed May 10, 2024
1 parent 076956c commit a505851
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 52 deletions.
5 changes: 0 additions & 5 deletions scripts/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,6 @@ def clientCmd(argv):
if cipherlist:
settings.cipherNames = [item for cipher in cipherlist
for item in cipher.split(',')]
# CHANGED
settings.ec_point_formats = []
try:
start = time_stamp()
if username and password:
Expand Down Expand Up @@ -570,9 +568,6 @@ def serverCmd(argv):
if cipherlist:
settings.cipherNames = [item for cipher in cipherlist
for item in cipher.split(',')]
# CHANGED

settings.ec_point_formats = [2, 0]

class MySimpleEchoHandler(BaseRequestHandler):
def handle(self):
Expand Down
103 changes: 90 additions & 13 deletions tests/tlstest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python

# Authors:
# Authors:
# Trevor Perrin
# Kees Bos - Added tests for XML-RPC
# Dimitris Moraitis - Anon ciphersuites
Expand Down Expand Up @@ -48,26 +48,26 @@

try:
from tack.structures.Tack import Tack

except ImportError:
pass

def printUsage(s=None):
if m2cryptoLoaded:
crypto = "M2Crypto/OpenSSL"
else:
crypto = "Python crypto"
crypto = "Python crypto"
if s:
print("ERROR: %s" % s)
print("""\ntls.py version %s (using %s)
print("""\ntls.py version %s (using %s)
Commands:
server HOST:PORT DIRECTORY
client HOST:PORT DIRECTORY
""" % (__version__, crypto))
sys.exit(-1)


def testConnClient(conn):
b1 = os.urandom(1)
Expand All @@ -92,9 +92,9 @@ def testConnClient(conn):
assert r1000 == b1000

def clientTestCmd(argv):

address = argv[0]
dir = argv[1]
dir = argv[1]

#Split address into hostname/port tuple
address = address.split(":")
Expand Down Expand Up @@ -235,7 +235,7 @@ def connect():
settings.minVersion = (3,0)
settings.maxVersion = (3,0)
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
testConnClient(connection)
assert(isinstance(connection.session.serverCertChain, X509CertChain))
connection.close()

Expand Down Expand Up @@ -309,7 +309,45 @@ def connect():
settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"]
connection.handshakeClientCert(settings=settings)
testConnClient(connection)
assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_char2
assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime
connection.close()

test_no += 1

print("Test {0} - client uncompressed - error, TLSv1.2".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 3)
settings.maxVersion = (3, 3)
settings.ec_point_formats = [ECPointFormat.uncompressed]
settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"]
try:
connection.handshakeClientCert(settings=settings)
assert False
except TLSIllegalParameterException as e:
assert "No common EC point format" in str(e)
except TLSAbruptCloseError as e:
pass
connection.close()

test_no += 1

print("Test {0} - client comppressed char2 - error, TLSv1.2".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 3)
settings.maxVersion = (3, 3)
settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_char2]
settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"]
try:
connection.handshakeClientCert(settings=settings)
assert False
except ValueError as e:
assert "Unknown EC point format provided: [2]" in str(e)
except TLSAbruptCloseError as e:
pass
connection.close()

test_no += 1
Expand Down Expand Up @@ -2190,7 +2228,7 @@ def connect():

test_no += 1

print("Test {0} server uncompressed ec format - uncompressed, TLSv1.2".format(test_no))
print("Test {0} - server uncompressed ec format - uncompressed, TLSv1.2".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
Expand All @@ -2206,7 +2244,7 @@ def connect():

test_no += 1

print("Test {0} server compressed ec format - compressed, TLSv1.2".format(test_no))
print("Test {0} - server compressed ec format - compressed, TLSv1.2".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
Expand All @@ -2216,7 +2254,46 @@ def connect():
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
testConnServer(connection)
assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_char2
assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime
connection.close()

test_no +=1

print("Test {0} - server compressed ec format - error, TLSv1.2".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 1)
settings.maxVersion = (3, 3)
settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_prime]
settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"]
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSIllegalParameterException as e:
assert "No common EC point format" in str(e)
except TLSAbruptCloseError as e:
pass
connection.close()

test_no +=1

print("Test {0} - client compressed char2 - error, TLSv1.2".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 1)
settings.maxVersion = (3, 3)
settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"]
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except ValueError as e:
assert "Unknown EC point format provided: [2]" in str(e)
except TLSAbruptCloseError as e:
pass
connection.close()

test_no +=1
Expand Down Expand Up @@ -2505,7 +2582,7 @@ def connect():
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
tacks=[tackUnrelated], settings=settings)
assert False
except TLSRemoteAlert as alert:
except TLSLocalAlert as alert:
if alert.description != AlertDescription.illegal_parameter:
raise
else:
Expand Down
9 changes: 4 additions & 5 deletions tlslite/handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@
TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm",
"aes128ccm_8", "aes256ccm", "aes256ccm_8"]
PSK_MODES = ["psk_dhe_ke", "psk_ke"]
EC_POINT_FORMATS = [ECPointFormat.ansiX962_compressed_char2,
ECPointFormat.ansiX962_compressed_prime,
EC_POINT_FORMATS = [ECPointFormat.ansiX962_compressed_prime,
ECPointFormat.uncompressed]


Expand Down Expand Up @@ -358,7 +357,7 @@ class HandshakeSettings(object):
influences selected cipher suites.
:vartype ec_point_formats: list
:ivat ec_point_formats: Enabeled point format extension for
:ivar ec_point_formats: Enabled point format extension for
elliptic curves.
"""

Expand Down Expand Up @@ -606,11 +605,11 @@ def _sanityCheckExtensions(other):
if other.record_size_limit is not None and \
not 64 <= other.record_size_limit <= 2**14 + 1:
raise ValueError("record_size_limit cannot exceed 2**14+1 bytes")

bad_ec_ext = [i for i in other.ec_point_formats if
i not in EC_POINT_FORMATS]
if bad_ec_ext:
raise ValueError("Unknown ec point format extension: "
raise ValueError("Unknown EC point format provided: "
"{0}".format(bad_ec_ext))

HandshakeSettings._sanityCheckEMSExtension(other)
Expand Down
18 changes: 12 additions & 6 deletions tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,10 +709,11 @@ def makeServerKeyExchange(self, sigHash=None):
ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
for ext in ext_c.formats:
if ext in ext_s.formats:
ext_negotiated = ext
break
try:
ext_negotiated = next((i for i in ext_c.formats \
if i in ext_s.formats))
except StopIteration:
raise TLSIllegalParameterException("No common EC point format")

ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated)

Expand All @@ -739,6 +740,8 @@ def processClientKeyExchange(self, clientKeyExchange):
ext_supported = [
ext for ext in ext_c.formats if ext in ext_s.formats
]
if not ext_supported:
raise TLSIllegalParameterException("No common EC point format")
return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported)

def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
Expand All @@ -762,8 +765,11 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
ext_supported = [i for i in ext_c.formats if i in ext_s.formats]
ext_negotiated = ext_supported[0]
try:
ext_supported = [i for i in ext_c.formats if i in ext_s.formats]
ext_negotiated = ext_supported[0]
except IndexError:
raise TLSIllegalParameterException("No common EC point format")

self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated)
return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported)
Expand Down
3 changes: 1 addition & 2 deletions tlslite/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ class Session(object):
from the server
:vartype ec_point_format: int
:ivar ec_point_format: used ec point extension format;
created for testing
:ivar ec_point_format: used EC point format for the ECDH key exchange;
"""

def __init__(self):
Expand Down
43 changes: 28 additions & 15 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,12 +655,21 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams,
alpnExt = serverHello.getExtension(ExtensionType.alpn)
if alpnExt:
alpnProto = alpnExt.protocol_names[0]
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = serverHello.getExtension(ExtensionType.ec_point_formats)

ext_ec_point = ECPointFormat.uncompressed
if ext_c and ext_s:
ext_ec_point = [i for i in ext_c.formats \
if i in ext_s.formats][0]
if self.version < (3, 4):
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
try:
ext_ec_point = next((i for i in ext_c.formats \
if i in ext_s.formats))

except StopIteration as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result

# Create the session object which is used for resumptions
self.session = Session()
Expand Down Expand Up @@ -771,9 +780,6 @@ def _clientSendClientHello(self, settings, session, srpUsername,
if settings.ec_point_formats:
extensions.append(ECPointFormatsExtension().\
create(settings.ec_point_formats))
else:
extensions.append(ECPointFormatsExtension().\
create(list([ECPointFormat.uncompressed])))
# Advertise FFDHE groups if we have DHE ciphers
if next((cipher for cipher in cipherSuites
if cipher in CipherSuite.dhAllSuites), None) is not None:
Expand Down Expand Up @@ -2282,9 +2288,6 @@ def _handshakeServerAsyncHelper(self, verifierDB,
if settings.ec_point_formats:
extensions.append(ECPointFormatsExtension().
create(settings.ec_point_formats))
else:
extensions.append(ECPointFormatsExtension().\
create(list([ECPointFormat.uncompressed])))

# if client sent Heartbeat extension
if clientHello.getExtension(ExtensionType.heartbeat):
Expand Down Expand Up @@ -2425,11 +2428,21 @@ def _handshakeServerAsyncHelper(self, verifierDB,
srpUsername = clientHello.srp_username.decode("utf-8")
if clientHello.server_name:
serverName = clientHello.server_name.decode("utf-8")
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = serverHello.getExtension(ExtensionType.ec_point_formats)

ext_ec_point = ECPointFormat.uncompressed
if ext_c and ext_s:
ext_ec_point = [i for i in ext_c.formats if i in ext_s.formats][0]
if version < (3, 4):
ext_c = clientHello.getExtension(ExtensionType.ec_point_formats)
ext_s = serverHello.getExtension(ExtensionType.ec_point_formats)
if ext_c and ext_s:
try:
ext_ec_point = next((i for i in ext_c.formats \
if i in ext_s.formats))

except StopIteration as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result

# We'll update the session master secret once it is calculated
# in _serverFinished
Expand Down
11 changes: 5 additions & 6 deletions unit_tests/test_tlslite_keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,22 @@
from tlslite.handshakesettings import HandshakeSettings
from tlslite.messages import ServerHello, ClientHello, ServerKeyExchange,\
CertificateRequest, ClientKeyExchange
from tlslite.constants import CipherSuite, CertificateType, AlertDescription, \
from tlslite.constants import CipherSuite, CertificateType, \
HashAlgorithm, SignatureAlgorithm, GroupName, ECCurveType, \
SignatureScheme, ECPointFormat
from tlslite.errors import TLSLocalAlert, TLSIllegalParameterException, \
SignatureScheme
from tlslite.errors import TLSIllegalParameterException, \
TLSDecryptionFailed, TLSInsufficientSecurity, TLSUnknownPSKIdentity, \
TLSInternalError, TLSDecodeError
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from tlslite.utils.keyfactory import parsePEMKey
from tlslite.utils.codec import Parser, Writer
from tlslite.utils.codec import Parser
from tlslite.utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
numberToByteArray, isPrime, numBytes
from tlslite.mathtls import makeX, makeU, makeK, goodGroupParameters
from tlslite.handshakehashes import HandshakeHashes
from tlslite import VerifierDB
from tlslite.extensions import SupportedGroupsExtension, SNIExtension, \
ECPointFormatsExtension
from tlslite.extensions import SupportedGroupsExtension, SNIExtension
from tlslite.utils.ecc import getCurveByName, getPointByteSize
from tlslite.utils.compat import a2b_hex
import ecdsa
Expand Down

0 comments on commit a505851

Please sign in to comment.