Skip to content

Commit

Permalink
chore: multithreading for I/O requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananya2001-an committed Jun 9, 2023
1 parent bff9884 commit a3c6417
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ gtfs/__version__.py

# Generated / downloaded files
*.zip
*.p
*.pkl

# virtualenv
.venv/
Expand Down
Binary file added feeds/AlbanyNy.pkl
Binary file not shown.
Binary file added feeds/Berlin.pkl
Binary file not shown.
82 changes: 59 additions & 23 deletions gtfs/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
"""Command line interface for fetching GTFS."""
import os
import threading
from typing import Optional

import typer
Expand Down Expand Up @@ -37,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]:
return Bbox(min_x, min_y, max_x, max_y)


def check_sources(sources: str) -> Optional[str]:
"""Check if the sources are valid."""
if sources is None:
return None
sources = sources.split(",")
for source in sources:
if not any(src.__name__.lower() == source.lower() for src in feed_sources):
raise typer.BadParameter(f"{source} is not a valid feed source!")

return ",".join(sources)


@app.command()
def list_feeds(
bbox: Annotated[
Expand Down Expand Up @@ -86,10 +99,8 @@ def list_feeds(
["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1
)


filtered_srcs = ""


for src in feed_sources:
feed_bbox: Bbox = src.bbox
if bbox is not None and predicate == "contains":
Expand All @@ -101,7 +112,6 @@ def list_feeds(
):
continue


filtered_srcs += src.__name__ + ", "

if pretty is True:
Expand All @@ -116,23 +126,22 @@ def list_feeds(

print(src.url)


if pretty is True:
print("\n" + pretty_output.get_string())

if typer.confirm("Do you want to fetch feeds from these sources?"):
fetch_feeds(sources=filtered_srcs[:-1])



@app.command()
def fetch_feeds(
sources: Annotated[
Optional[str],
typer.Option(
"--sources",
"-src",
help="pass value as a string separated by commas like this: src1,src2,src3",
help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...",
callback=check_sources,
),
] = None,
search: Annotated[
Expand All @@ -148,19 +157,27 @@ def fetch_feeds(
typer.Option(
"--output-dir",
"-o",
help="the directory where the downloaded feeds will be saved",
help="the directory where the downloaded feeds will be saved, default is feeds",
),
] = "feeds",
concurrency: Annotated[
Optional[int],
typer.Option(
"--concurrency",
"-c",
help="the number of concurrent downloads, default is 4",
),
] = 4,
) -> None:
"""Fetch feeds from sources.
:param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available.
:param search: Search for feeds based on a string.
:param output_dir: The directory where the downloaded feeds will be saved; default is feeds.
:param concurrency: The number of concurrent downloads; default is 4.
"""
# statuses = {} # collect the statuses for all the files


if not sources:
if not search:
# fetch all feeds
Expand All @@ -173,28 +190,47 @@ def fetch_feeds(
if search.lower() in src.__name__.lower() or search.lower() in src.url.lower()
]
else:
# fetch feeds based on sources
sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()]
if search:
raise typer.BadParameter("Please pass either sources or search, not both at the same time!")
else:
sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()]

output_dir_path = os.path.join(os.getcwd(), output_dir)
if not os.path.exists(output_dir_path):
os.makedirs(output_dir_path)

LOG.info(f"Going to fetch feeds from sources: {sources}")

for src in sources:
LOG.debug(f"Going to start fetch for {src}...")
try:
if issubclass(src, FeedSource):
inst = src()
inst.ddir = output_dir_path
inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl")
inst.fetch()
# statuses.update(inst.status)
else:
LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.")
except AttributeError:
LOG.error(f"Skipping feed {src}, which could not be found.")
threads = []

def thread_worker():
while True:
try:
src = sources.pop(0)
except IndexError:
break

LOG.debug(f"Going to start fetch for {src}...")
try:
if issubclass(src, FeedSource):
inst = src()
inst.ddir = output_dir_path
inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl")
inst.fetch()
# statuses.update(inst.status)
else:
LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.")
except AttributeError:
LOG.error(f"Skipping feed {src}, which could not be found.")

for _ in range(concurrency):
thread = threading.Thread(target=thread_worker)
thread.start()
threads.append(thread)

# Wait for all threads to complete
for thread in threads:
thread.join()

# ptable = ColorTable(
# [
Expand Down
2 changes: 1 addition & 1 deletion gtfs/feed_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo
LOG.debug("No last-modified header set")
posted_date = datetime.utcnow().strftime(TIMECHECK_FMT)
self.set_posted_date(feed_file, posted_date)
LOG.info("Download completed successfully.")
LOG.info(f"Download completed successfully for {feed_file}.")
return True
else:
self.set_error(feed_file, "Download failed")
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,23 @@ def test_contains_predicate(self, runner):
def test_pretty(self, runner):
result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n")
assert result.exit_code == 0


class TestFetchFeedsCommand:
def test_help(self, runner):
result = runner.invoke(app, ["fetch-feeds", "--help"])
assert result.exit_code == 0
assert "Fetch feeds from sources." in result.stdout

def test_bad_args(self, runner):
result = runner.invoke(app, ["fetch-feeds", "-src", "berlin", "-s", "cdta"])
assert result.exit_code == 2
assert "Please pass either sources or search" in result.stdout

def test_fetch_with_sources(self, runner):
result = runner.invoke(app, ["fetch-feeds", "-src", "berlin"])
assert result.exit_code == 0

def test_fetch_with_search(self, runner):
result = runner.invoke(app, ["fetch-feeds", "-s", "cdta"])
assert result.exit_code == 0

0 comments on commit a3c6417

Please sign in to comment.