Skip to content

Commit

Permalink
Deal with linter feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
joerivanruth committed Mar 6, 2024
1 parent f210d24 commit c1b36d2
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 35 deletions.
4 changes: 2 additions & 2 deletions pymonetdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#
# Copyright 1997 - July 2008 CWI, August 2008 - 2016 MonetDB B.V.

from typing import Optional, Union
from typing import Optional
from pymonetdb import sql
from pymonetdb import mapi
from pymonetdb import exceptions
Expand Down Expand Up @@ -41,7 +41,7 @@
'TimeTzFromTicks', 'TimestampTzFromTicks']


def connect(
def connect( # noqa C901
database: str,
hostname: Optional[str] = None,
port: Optional[int] = None,
Expand Down
2 changes: 0 additions & 2 deletions pymonetdb/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import logging
from pymonetdb import mapi
from pymonetdb.exceptions import OperationalError, InterfaceError
from pymonetdb.target import Target, looks_like_url


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,7 +72,6 @@ class Control:
stop, lock, unlock, destroy your databases and request status information.
"""


def __init__(self, hostname=None, port=None, passphrase=None, **kwargs):

# override some settings
Expand Down
16 changes: 8 additions & 8 deletions pymonetdb/mapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import hashlib
import ssl
import typing
from typing import Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import parse_qsl, urlparse
from typing import Callable, List, Optional, Tuple, Union

from pymonetdb.exceptions import OperationalError, DatabaseError, \
ProgrammingError, NotSupportedError, IntegrityError
Expand Down Expand Up @@ -274,8 +273,8 @@ def prime_or_wrap_connection(self): # noqa: C901
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3
ssl_context.set_alpn_protocols(["mapi/9"])
if target.clientkey:
certfile=target.clientcert if target.clientcert else target.clientkey
keyfile=target.clientkey
certfile = target.clientcert if target.clientcert else target.clientkey
keyfile = target.clientkey
ssl_context.load_cert_chain(certfile, keyfile)
if 'host' in disabled_checks:
ssl_context.check_hostname = False
Expand Down Expand Up @@ -714,6 +713,7 @@ class HandshakeOption:
value (not converted to an integer) as a parameter.
Field `sent` can be used to keep track of whether the option has been sent.
"""

def __init__(self, level, name, fallback, value):
self.level = level
self.name = name
Expand All @@ -722,10 +722,10 @@ def __init__(self, level, name, fallback, value):
self.sent = False


def construct_target_from_args(database: str, username: str, password: str, language: str, # noqa: C901
hostname: Optional[str] = None, port: Optional[int] = None, unix_socket: Optional[str]=None,
connect_timeout: Optional[Union[float,int]]=None,
**kwargs):
def construct_target_from_args(database: Optional[str], username: str, password: str, language: str, # noqa: C901
hostname: Optional[str] = None, port: Optional[int] = None, unix_socket: Optional[str] = None,
connect_timeout: Optional[Union[float, int]] = None,
**kwargs):
"""Construct a Target from the other args"""

target = Target()
Expand Down
1 change: 0 additions & 1 deletion pymonetdb/sql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from datetime import datetime, timedelta, timezone
import logging
import platform
from typing import List

from pymonetdb.exceptions import DatabaseError
Expand Down
16 changes: 9 additions & 7 deletions pymonetdb/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


import re
from typing import Callable, Union
from typing import Any, Callable, Union
from urllib.parse import parse_qsl, urlparse, quote as urlquote


Expand Down Expand Up @@ -89,7 +89,7 @@ class urlparam:
"""Decorator to create getter/setter for url parameter on a Target instance"""

field: str
parser: Callable[[Union[str, any]], any]
parser: Callable[[Union[str, Any]], Any]

def __init__(self, name, typ, doc):
self.field = name
Expand Down Expand Up @@ -171,7 +171,9 @@ def clone(self):
'rows beyond this limit are retrieved on demand, <1 means unlimited')
maxprefetch = urlparam('maxprefetch', 'integer', 'specific to pymonetdb')
connect_timeout = urlparam('connect_timeout', 'integer', 'abort if connect takes longer than this')
dangerous_tls_nocheck = urlparam('dangerous_tls_nocheck', 'bool', 'comma separated certificate checks to skip, host: do not verify host, cert: do not verify certificate chain')
dangerous_tls_nocheck = urlparam(
'dangerous_tls_nocheck', 'bool',
'comma separated certificate checks to skip, host: do not verify host, cert: do not verify certificate chain')

# alias
fetchsize = replysize
Expand Down Expand Up @@ -228,7 +230,7 @@ def _set_core_defaults(self):
self.port = _DEFAULTS['port']
self.database = ''

def _parse_monetdb_url(self, url):
def _parse_monetdb_url(self, url): # noqa C901
parsed = urlparse(url, allow_fragments=True)

if parsed.scheme == 'monetdb':
Expand Down Expand Up @@ -278,7 +280,7 @@ def _parse_monetdb_url(self, url):
"key {key!r} is not allowed in the query parameters")
self.set(key, value)

def _parse_mapi_monetdb_url(self, url):
def _parse_mapi_monetdb_url(self, url): # noqa C901
# mapi urls have no percent encoding at all
parsed = urlparse(url[5:])
if parsed.scheme != 'monetdb':
Expand Down Expand Up @@ -326,7 +328,7 @@ def _parse_mapi_monetdb_url(self, url):
# unknown parameters are ignored
pass

def _parse_mapi_merovingian_url(self, url):
def _parse_mapi_merovingian_url(self, url): # noqa C901
# mapi urls have no percent encoding at all
parsed = urlparse(url[5:])
if parsed.scheme != 'merovingian':
Expand Down Expand Up @@ -373,7 +375,7 @@ def _parse_mapi_merovingian_url(self, url):
# unknown parameters are ignored
pass

def validate(self):
def validate(self): # noqa C901
# 1. The parameters have the types listed in the table in [Section
# Parameters](#parameters).
#
Expand Down
2 changes: 1 addition & 1 deletion tests/test_resultset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def connect_with_args(self, **kw_args) -> pymonetdb.Connection:
args = dict()
args.update(test_args)
args.update(kw_args)
conn = pymonetdb.connect(**args)
conn = pymonetdb.connect(**args) # type: ignore
except AttributeError:
self.fail("No connect method found in pymonetdb module")
self.to_close.append(conn)
Expand Down
33 changes: 19 additions & 14 deletions tests/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import os
import re
import sys
from typing import List, Tuple
from typing import List, Optional, Tuple
from unittest import TestCase
import unittest

Expand Down Expand Up @@ -41,8 +40,8 @@ def read_lines(f, filename: str, start_line=0) -> List[Line]:


def split_tests(lines: List[Line]) -> List[Tuple[str, List[Line]]]:
tests = []
cur = None
tests: List[Tuple[str, List[Line]]] = []
cur: Optional[List[Line]] = None
header = None
count = 0
location = None
Expand Down Expand Up @@ -102,7 +101,7 @@ def run_test(self, test):
e.add_note(f"At {line.location}")
raise

def apply_line(self, target: Target, line: Line):
def apply_line(self, target: Target, line: Line): # noqa C901
if not line:
return

Expand Down Expand Up @@ -155,15 +154,7 @@ def apply_set(self, target: Target, key, value):

def apply_expect(self, target: Target, key, expected_value):
if key == 'valid':
should_succeed = parse_bool(expected_value)
try:
target.validate()
if not should_succeed:
self.fail("Expected valid=false")
except ValueError as e:
if should_succeed:
self.fail(f"Expected valid=true, got error {e}")
return
return self.apply_expect_valid(target, key, expected_value)

if key in VIRTUAL:
target.validate()
Expand All @@ -173,6 +164,9 @@ def apply_expect(self, target: Target, key, expected_value):
else:
actual_value = target.get(key)

self.verify_expected_value(key, expected_value, actual_value)

def verify_expected_value(self, key, expected_value, actual_value):
if isinstance(actual_value, bool):
expected_value = parse_bool(expected_value)
elif isinstance(actual_value, int):
Expand All @@ -184,6 +178,17 @@ def apply_expect(self, target: Target, key, expected_value):
if actual_value != expected_value:
self.fail(f"Expected {key}={expected_value!r}, found {actual_value!r}")

def apply_expect_valid(self, target, key, expected_value):
should_succeed = parse_bool(expected_value)
try:
target.validate()
if not should_succeed:
self.fail("Expected valid=false")
except ValueError as e:
if should_succeed:
self.fail(f"Expected valid=true, got error {e}")
return


# Magic alert!
# Read tests.md and generate test cases programmatically!
Expand Down
1 change: 1 addition & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import logging
logging.basicConfig(level=logging.DEBUG)


class TestTLS(TestCase):
_name: Optional[str]
_cache: Dict[str, str]
Expand Down

0 comments on commit c1b36d2

Please sign in to comment.