Skip to content

Commit

Permalink
Fix comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
mariosasko committed Apr 1, 2021
1 parent ba3a9f8 commit 04e1c25
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions docs/source/scripts/check_notebooks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import copy
import multiprocess
import os
import shutil
import string
Expand All @@ -9,6 +8,7 @@
from functools import partial
from pathlib import Path

import multiprocess
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

Expand All @@ -20,7 +20,7 @@
TRANS_TABLE = str.maketrans(dict.fromkeys(string.whitespace))


def inject_sst():
def inject_shared_download():
delim = "&" if os.name == "nt" else ";"
subprocess.call(
delim.join([INSTALL_SOURCE_VERSION_COMMAND[4:], INSTALL_SST_COMMAND]),
Expand Down Expand Up @@ -108,7 +108,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
new_cell_stdout_ = new_cell_stdout

if ignore_whitespace:
original_cell = original_cell_stdout.translate(TRANS_TABLE)
original_cell_stdout = original_cell_stdout.translate(TRANS_TABLE)
new_cell_stdout = new_cell_stdout.translate(TRANS_TABLE)
else:
if new_cell_stdout[-1] == "\n" and original_cell_stdout[-1] != "\n":
Expand Down Expand Up @@ -150,8 +150,8 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
report = check_notebook_output(notebook_path, env=args.env, ignore_whitespace=args.ignore_whitespace)
reports.append(report)
else:
# inject the SST dataset to prevent parallel download
inject_sst()
# predownload datasets/vectorizers to prevent parallel download
inject_shared_download()
with multiprocess.Pool(num_proc) as pool:
reports = pool.map(partial(check_notebook_output, env=args.env, ignore_whitespace=args.ignore_whitespace), notebook_paths)

Expand All @@ -168,7 +168,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
for i, original_output, new_output in report),
" " * 4,
)
for notebook, report in reports
for notebook, report in reports if len(report) > 0
])
raise Exception(
"❌❌ Mismatches found in the outputs of the notebooks:\n\n" + reports_str
Expand Down

0 comments on commit 04e1c25

Please sign in to comment.