From a3c6417c013131820522e3f2f19267ad4b1b833d Mon Sep 17 00:00:00 2001 From: Ananya Nayak Date: Fri, 9 Jun 2023 11:35:59 +0530 Subject: [PATCH] chore: multithreading for I/O requests --- .gitignore | 2 +- feeds/AlbanyNy.pkl | Bin 0 -> 75 bytes feeds/Berlin.pkl | Bin 0 -> 73 bytes gtfs/__main__.py | 82 +++++++++++++++++++++++++++++++------------- gtfs/feed_source.py | 2 +- tests/test_cli.py | 20 +++++++++++ 6 files changed, 81 insertions(+), 25 deletions(-) create mode 100644 feeds/AlbanyNy.pkl create mode 100644 feeds/Berlin.pkl diff --git a/.gitignore b/.gitignore index 675ca8c..b0441f3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ gtfs/__version__.py # Generated / downloaded files *.zip -*.p +*.pkl # virtualenv .venv/ diff --git a/feeds/AlbanyNy.pkl b/feeds/AlbanyNy.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d23c74e2495b9c1a9411f14955081550ea73a95d GIT binary patch literal 75 zcmZo*nd-m*0kuQ&V&lj13ih6Dt*r X42+Bw49%@ffXG0>-8W=Paj_l%VAU3z literal 0 HcmV?d00001 diff --git a/feeds/Berlin.pkl b/feeds/Berlin.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9b1ea0b5cd464b2592138608f3e78f7245ae8296 GIT binary patch literal 73 zcmZo*nQF%X0ku$mtW1opj13jseM6=c7wZ84FT@q? literal 0 HcmV?d00001 diff --git a/gtfs/__main__.py b/gtfs/__main__.py index 8e4e23a..ad16df3 100755 --- a/gtfs/__main__.py +++ b/gtfs/__main__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Command line interface for fetching GTFS.""" import os +import threading from typing import Optional import typer @@ -37,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]: return Bbox(min_x, min_y, max_x, max_y) +def check_sources(sources: str) -> Optional[str]: + """Check if the sources are valid.""" + if sources is None: + return None + sources = sources.split(",") + for source in sources: + if not any(src.__name__.lower() == source.lower() for src in feed_sources): + raise typer.BadParameter(f"{source} is not a valid feed source!") + + return ",".join(sources) + + @app.command() def list_feeds( bbox: Annotated[ @@ -86,10 +99,8 @@ def list_feeds( ["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1 ) - filtered_srcs = "" - for src in feed_sources: feed_bbox: Bbox = src.bbox if bbox is not None and predicate == "contains": @@ -101,7 +112,6 @@ def list_feeds( ): continue - filtered_srcs += src.__name__ + ", " if pretty is True: @@ -116,7 +126,6 @@ def list_feeds( print(src.url) - if pretty is True: print("\n" + pretty_output.get_string()) @@ -124,7 +133,6 @@ def list_feeds( fetch_feeds(sources=filtered_srcs[:-1]) - @app.command() def fetch_feeds( sources: Annotated[ @@ -132,7 +140,8 @@ def fetch_feeds( typer.Option( "--sources", "-src", - help="pass value as a string separated by commas like this: src1,src2,src3", + help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...", + callback=check_sources, ), ] = None, search: Annotated[ @@ -148,19 +157,27 @@ def fetch_feeds( typer.Option( "--output-dir", "-o", - help="the directory where the downloaded feeds will be saved", + help="the directory where the downloaded feeds will be saved, default is feeds", ), ] = "feeds", + concurrency: Annotated[ + Optional[int], + typer.Option( + "--concurrency", + "-c", + help="the number of concurrent downloads, default is 4", + ), + ] = 4, ) -> None: """Fetch feeds from sources. :param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available. :param search: Search for feeds based on a string. :param output_dir: The directory where the downloaded feeds will be saved; default is feeds. + :param concurrency: The number of concurrent downloads; default is 4. """ # statuses = {} # collect the statuses for all the files - if not sources: if not search: # fetch all feeds @@ -173,8 +190,10 @@ def fetch_feeds( if search.lower() in src.__name__.lower() or search.lower() in src.url.lower() ] else: - # fetch feeds based on sources - sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] + if search: + raise typer.BadParameter("Please pass either sources or search, not both at the same time!") + else: + sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()] output_dir_path = os.path.join(os.getcwd(), output_dir) if not os.path.exists(output_dir_path): @@ -182,19 +201,36 @@ def fetch_feeds( LOG.info(f"Going to fetch feeds from sources: {sources}") - for src in sources: - LOG.debug(f"Going to start fetch for {src}...") - try: - if issubclass(src, FeedSource): - inst = src() - inst.ddir = output_dir_path - inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") - inst.fetch() - # statuses.update(inst.status) - else: - LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") - except AttributeError: - LOG.error(f"Skipping feed {src}, which could not be found.") + threads = [] + + def thread_worker(): + while True: + try: + src = sources.pop(0) + except IndexError: + break + + LOG.debug(f"Going to start fetch for {src}...") + try: + if issubclass(src, FeedSource): + inst = src() + inst.ddir = output_dir_path + inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl") + inst.fetch() + # statuses.update(inst.status) + else: + LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.") + except AttributeError: + LOG.error(f"Skipping feed {src}, which could not be found.") + + for _ in range(concurrency): + thread = threading.Thread(target=thread_worker) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() # ptable = ColorTable( # [ diff --git a/gtfs/feed_source.py b/gtfs/feed_source.py index 6c9840c..e6f1479 100644 --- a/gtfs/feed_source.py +++ b/gtfs/feed_source.py @@ -144,7 +144,7 @@ def download_feed(self, feed_file: str, url: str, do_stream: bool = True) -> boo LOG.debug("No last-modified header set") posted_date = datetime.utcnow().strftime(TIMECHECK_FMT) self.set_posted_date(feed_file, posted_date) - LOG.info("Download completed successfully.") + LOG.info(f"Download completed successfully for {feed_file}.") return True else: self.set_error(feed_file, "Download failed") diff --git a/tests/test_cli.py b/tests/test_cli.py index 3d02fb2..3428b1b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -49,3 +49,23 @@ def test_contains_predicate(self, runner): def test_pretty(self, runner): result = runner.invoke(app, ["list-feeds", "-pt"], input="N\n") assert result.exit_code == 0 + + +class TestFetchFeedsCommand: + def test_help(self, runner): + result = runner.invoke(app, ["fetch-feeds", "--help"]) + assert result.exit_code == 0 + assert "Fetch feeds from sources." in result.stdout + + def test_bad_args(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-src", "berlin", "-s", "cdta"]) + assert result.exit_code == 2 + assert "Please pass either sources or search" in result.stdout + + def test_fetch_with_sources(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-src", "berlin"]) + assert result.exit_code == 0 + + def test_fetch_with_search(self, runner): + result = runner.invoke(app, ["fetch-feeds", "-s", "cdta"]) + assert result.exit_code == 0