Skip to content

Commit

Permalink
Fix typing for Python 3.8 (#22)
Browse files Browse the repository at this point in the history
* Fix typing for Python 3.8

* Fix set typing

* Add changelog and bump version
  • Loading branch information
tadejsv authored Oct 2, 2022
1 parent 18330db commit 4188cf4
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
strategy:
matrix:
operating-system: [ubuntu-latest, macos-latest]
python-version: ["3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10"]
fail-fast: false
steps:
- name: Checkout
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).


## [0.1.7] - 2022-10-02

## Fixed

* Fix typing for compatibility with Python 3.8 ([#22](https://github.com/tadejsv/EvalDeT/pull/22))

## [0.1.6] - 2022-09-29

## Changed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ keywords = ["evaluation", "tracking", "object detection", "computer vision"]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Development Status :: 4 - Beta",
Expand Down
2 changes: 1 addition & 1 deletion src/evaldet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.6"
__version__ = "0.1.7"

from .metrics import MOTMetrics # noqa: F401
from .tracks import Tracks # noqa: F401
4 changes: 2 additions & 2 deletions src/evaldet/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class MOTMetrics(CLEARMOTMetrics, IDMetrics, HOTAMetrics):
efficiently share pre-computed IoU distances.
"""

_ious: list[np.ndarray]
_ious_dict: dict[int, int]
_ious: t.List[np.ndarray]
_ious_dict: t.Dict[int, int]

def __init__(
self, clearmot_dist_threshold: float = 0.5, id_dist_threshold: float = 0.5
Expand Down
4 changes: 2 additions & 2 deletions src/evaldet/mot_metrics/clearmot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def _calculate_clearmot_metrics(
ground_truths = 0

matches_dist = []
matching: dict[int, int] = {}
matching: t.Dict[int, int] = {}

# This is the persistent matching dictionary, used to check for mismatches
# when a previously matched hypothesis is re-matched with a ground truth
matching_persist: dict[int, int] = {}
matching_persist: t.Dict[int, int] = {}

for frame in all_frames:
if frame not in ground_truth:
Expand Down
6 changes: 3 additions & 3 deletions src/evaldet/mot_metrics/hota.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class HOTAResults(t.TypedDict):


def _create_coo_array(
vals_list: dict[tuple[int, int], int], shape: tuple[int, int]
vals_list: t.Dict[t.Tuple[int, int], int], shape: t.Tuple[int, int]
) -> sparse.coo_array:
row_inds = np.array(tuple(x[0] for x in vals_list.keys()))
col_inds = np.array(tuple(x[1] for x in vals_list.keys()))
Expand Down Expand Up @@ -54,7 +54,7 @@ def _calculate_hota_metrics(
TPA_max = np.zeros((len(alphas), n_gt, n_hyp), dtype=np.int32)
FPA_max = np.tile(np.tile(hyps_counts, (n_gt, 1)), (len(alphas), 1, 1))
FNA_max = np.tile(np.tile(gts_counts, (n_hyp, 1)).T, (len(alphas), 1, 1))
TPA_max_vals: list[dict[tuple[int, int], int]] = [
TPA_max_vals: t.List[t.Dict[t.Tuple[int, int], int]] = [
co.defaultdict(int) for _ in range(len(alphas))
]

Expand Down Expand Up @@ -85,7 +85,7 @@ def _calculate_hota_metrics(
A_max = TPA_max / (FNA_max + FPA_max - TPA_max)

# Do the actual matching
TPA_vals: list[dict[tuple[int, int], int]] = [
TPA_vals: t.List[t.Dict[t.Tuple[int, int], int]] = [
co.defaultdict(int) for _ in range(len(alphas))
]
for frame in sorted(set(ground_truth.frames).intersection(hypotheses.frames)):
Expand Down
45 changes: 23 additions & 22 deletions src/evaldet/tracks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import typing as t
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any, NamedTuple, Optional, Union
Expand Down Expand Up @@ -43,7 +44,7 @@ class Tracks:
def from_csv(
cls,
csv_file: Union[str, Path],
fieldnames: list[str],
fieldnames: t.List[str],
zero_indexed: bool = True,
) -> "Tracks":
"""Get detections from a CSV file.
Expand Down Expand Up @@ -72,7 +73,7 @@ def from_csv(
tracks = cls()
with open(csv_file, newline="") as file:
csv_reader = csv.DictReader(file, fieldnames=fieldnames, dialect="unix")
frames: dict[int, Any] = {}
frames: t.Dict[int, Any] = {}

for line_num, line in enumerate(csv_reader):
try:
Expand Down Expand Up @@ -212,7 +213,7 @@ def from_ua_detrac(
cls,
file_path: Union[Path, str],
classes_attr_name: Optional[str] = None,
classes_list: Optional[list[str]] = None,
classes_list: Optional[t.List[str]] = None,
) -> "Tracks":
"""Creates a Tracks object from detections file in the UA-DETRAC XML format.
Expand Down Expand Up @@ -282,9 +283,9 @@ def from_ua_detrac(
tracks_f = frame.find("target_list").findall("target") # type: ignore

current_frame = int(frame.attrib["num"])
detections: list[list[float]] = []
classes: list[int] = []
ids: list[int] = []
detections: t.List[t.List[float]] = []
classes: t.List[int] = []
ids: t.List[int] = []

for track in tracks_f:
# Get track attributes
Expand Down Expand Up @@ -315,7 +316,7 @@ def from_ua_detrac(
def from_cvat_video(
cls,
file_path: Union[Path, str],
classes_list: list[str],
classes_list: t.List[str],
) -> "Tracks":
"""Creates a Tracks object from detections file in the CVAT for Video XML
format.
Expand Down Expand Up @@ -366,7 +367,7 @@ def from_cvat_video(
root = xml_tree.getroot()
tracks = cls()

frames: dict[int, Any] = {}
frames: t.Dict[int, Any] = {}
tracks_cvat = root.findall("track")
for track_cvat in tracks_cvat:
track_id = int(track_cvat.attrib[_ID_KEY])
Expand Down Expand Up @@ -438,19 +439,19 @@ def _add_to_tracks_accumulator(frames: dict, new_obj: dict) -> None:

def __init__(self) -> None:

self._frame_nums: set[int] = set()
self._detections: dict[int, np.ndarray] = dict()
self._ids: dict[int, np.ndarray] = dict()
self._classes: dict[int, np.ndarray] = dict()
self._confs: dict[int, np.ndarray] = dict()
self._frame_nums: t.Set[int] = set()
self._detections: t.Dict[int, np.ndarray] = dict()
self._ids: t.Dict[int, np.ndarray] = dict()
self._classes: t.Dict[int, np.ndarray] = dict()
self._confs: t.Dict[int, np.ndarray] = dict()

def add_frame(
self,
frame_num: int,
ids: Union[list[int], np.ndarray],
ids: Union[t.List[int], np.ndarray],
detections: np.ndarray,
classes: Optional[Union[list[int], np.ndarray]] = None,
confs: Optional[Union[list[float], np.ndarray]] = None,
classes: Optional[Union[t.List[int], np.ndarray]] = None,
confs: Optional[Union[t.List[float], np.ndarray]] = None,
) -> None:
"""Add a frame to the collection. Can overwrite existing frame.
Expand Down Expand Up @@ -562,7 +563,7 @@ def filter_frame(self, frame_num: int, filter: np.ndarray) -> None:
if frame.classes is not None:
self._classes[frame_num] = frame.classes[filter]

def filter_by_class(self, classes: list[int]) -> None:
def filter_by_class(self, classes: t.List[int]) -> None:
"""Filter all frames by classes
This will keep the detections with class label corresponding to one of the
Expand Down Expand Up @@ -604,32 +605,32 @@ def filter_by_conf(self, lower_bound: float) -> None:
self.filter_frame(frame, filter_conf)

@property
def all_classes(self) -> set[int]:
def all_classes(self) -> t.Set[int]:
"""Get a set of all classes in the collection."""
classes: set[int] = set()
classes: t.Set[int] = set()
for frame in self._frame_nums:
if frame in self._classes:
classes.update(self._classes[frame])

return classes

@property
def ids_count(self) -> dict[int, int]:
def ids_count(self) -> t.Dict[int, int]:
"""Get the number of frames that each id is present in.
Returns:
A dictionary where keys are the track ids, and values
are the numbers of frames they appear in.
"""
ids_count: dict[int, int] = dict()
ids_count: t.Dict[int, int] = dict()
for frame in self._frame_nums:
for _id in self._ids[frame]:
ids_count[_id] = ids_count.get(_id, 0) + 1

return ids_count

@property
def frames(self) -> set[int]:
def frames(self) -> t.Set[int]:
"""Get an ordered list of all frame numbers in the collection."""
return self._frame_nums.copy()

Expand Down

0 comments on commit 4188cf4

Please sign in to comment.