diff --git a/.gitignore b/.gitignore index 68e07e7..a9555bc 100644 --- a/.gitignore +++ b/.gitignore @@ -134,5 +134,7 @@ target/ # Pycharm venv +*.DS_Store + # VS Code -.vscode \ No newline at end of file +.vscode diff --git a/cyaron/compare.py b/cyaron/compare.py index 72a94a3..7363a9b 100644 --- a/cyaron/compare.py +++ b/cyaron/compare.py @@ -18,7 +18,7 @@ def __init__(self, name, mismatch): self.mismatch = mismatch def __str__(self): - return 'In program: \'{}\'. {}'.format(self.name,self.mismatch) + return "In program: '{}'. {}".format(self.name, self.mismatch) class Compare: @@ -38,7 +38,7 @@ def __process_file(file): file.output_file.seek(0) return file.output_filename, file.output_file.read() else: - with open(file, "r", newline='\n') as f: + with open(file, "r", newline="\n") as f: return file, f.read() @staticmethod @@ -51,26 +51,43 @@ def __normal_max_workers(workers): @classmethod def output(cls, *files, **kwargs): - kwargs = unpack_kwargs('output', kwargs, ('std', ('grader', DEFAULT_GRADER), ('max_workers', -1), - ('job_pool', None), ('stop_on_incorrect', None))) - std = kwargs['std'] - grader = kwargs['grader'] - max_workers = kwargs['max_workers'] - job_pool = kwargs['job_pool'] - if kwargs['stop_on_incorrect'] is not None: + kwargs = unpack_kwargs( + "output", + kwargs, + ( + "std", + ("grader", DEFAULT_GRADER), + ("max_workers", -1), + ("job_pool", None), + ("stop_on_incorrect", None), + ), + ) + std = kwargs["std"] + grader = kwargs["grader"] + max_workers = kwargs["max_workers"] + job_pool = kwargs["job_pool"] + if kwargs["stop_on_incorrect"] is not None: log.warn("parameter stop_on_incorrect is deprecated and has no effect.") if (max_workers is None or max_workers >= 0) and job_pool is None: max_workers = cls.__normal_max_workers(max_workers) try: from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=max_workers) as job_pool: - return cls.output(*files, std=std, grader=grader, max_workers=max_workers, job_pool=job_pool) + return cls.output( + *files, + std=std, + grader=grader, + max_workers=max_workers, + job_pool=job_pool + ) except ImportError: pass def get_std(): return cls.__process_file(std)[1] + if job_pool is not None: std = job_pool.submit(get_std).result() else: @@ -87,61 +104,118 @@ def do(file): @classmethod def program(cls, *programs, **kwargs): - kwargs = unpack_kwargs('program', kwargs, ('input', ('std', None), ('std_program', None), - ('grader', DEFAULT_GRADER), ('max_workers', -1), - ('job_pool', None), ('stop_on_incorrect', None))) - input = kwargs['input'] - std = kwargs['std'] - std_program = kwargs['std_program'] - grader = kwargs['grader'] - max_workers = kwargs['max_workers'] - job_pool = kwargs['job_pool'] - if kwargs['stop_on_incorrect'] is not None: + kwargs = unpack_kwargs( + "program", + kwargs, + ( + "input", + ("std", None), + ("std_program", None), + ("grader", DEFAULT_GRADER), + ("max_workers", -1), + ("job_pool", None), + ("stop_on_incorrect", None), + ), + ) + input = kwargs["input"] + std = kwargs["std"] + std_program = kwargs["std_program"] + grader = kwargs["grader"] + max_workers = kwargs["max_workers"] + job_pool = kwargs["job_pool"] + if kwargs["stop_on_incorrect"] is not None: log.warn("parameter stop_on_incorrect is deprecated and has no effect.") if (max_workers is None or max_workers >= 0) and job_pool is None: max_workers = cls.__normal_max_workers(max_workers) try: from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=max_workers) as job_pool: - return cls.program(*programs, input=input, std=std, std_program=std_program, grader=grader, max_workers=max_workers, job_pool=job_pool) + return cls.program( + *programs, + input=input, + std=std, + std_program=std_program, + grader=grader, + max_workers=max_workers, + job_pool=job_pool + ) except ImportError: pass if not isinstance(input, IO): - raise TypeError("expect {}, got {}".format(type(IO).__name__, type(input).__name__)) + raise TypeError( + "expect {}, got {}".format(type(IO).__name__, type(input).__name__) + ) input.flush_buffer() input.input_file.seek(0) if std_program is not None: + def get_std(): - with open(os.dup(input.input_file.fileno()), 'r', newline='\n') as input_file: - content = make_unicode(subprocess.check_output(std_program, shell=(not list_like(std_program)), stdin=input.input_file, universal_newlines=True)) + with open( + os.dup(input.input_file.fileno()), "r", newline="\n" + ) as input_file: + content = make_unicode( + subprocess.check_output( + std_program, + shell=(not list_like(std_program)), + stdin=input.input_file, + universal_newlines=True, + ) + ) input_file.seek(0) return content + if job_pool is not None: std = job_pool.submit(get_std).result() else: std = get_std() elif std is not None: + def get_std(): return cls.__process_file(std)[1] + if job_pool is not None: std = job_pool.submit(get_std).result() else: std = get_std() else: - raise TypeError('program() missing 1 required non-None keyword-only argument: \'std\' or \'std_program\'') + raise TypeError( + "program() missing 1 required non-None keyword-only argument: 'std' or 'std_program'" + ) def do(program_name): timeout = None - if list_like(program_name) and len(program_name) == 2 and int_like(program_name[-1]): + if ( + list_like(program_name) + and len(program_name) == 2 + and int_like(program_name[-1]) + ): program_name, timeout = program_name - with open(os.dup(input.input_file.fileno()), 'r', newline='\n') as input_file: + with open( + os.dup(input.input_file.fileno()), "r", newline="\n" + ) as input_file: if timeout is None: - content = make_unicode(subprocess.check_output(program_name, shell=(not list_like(program_name)), stdin=input_file, universal_newlines=True)) + content = make_unicode( + subprocess.check_output( + program_name, + shell=(not list_like(program_name)), + stdin=input_file, + universal_newlines=True, + ) + ) else: - content = make_unicode(subprocess.check_output(program_name, shell=(not list_like(program_name)), stdin=input_file, universal_newlines=True, timeout=timeout)) + content = make_unicode( + subprocess.check_output( + program_name, + shell=(not list_like(program_name)), + stdin=input_file, + universal_newlines=True, + timeout=timeout, + ) + ) input_file.seek(0) cls.__compare_two(program_name, content, std, grader) diff --git a/cyaron/io.py b/cyaron/io.py index a21cada..126ce90 100644 --- a/cyaron/io.py +++ b/cyaron/io.py @@ -84,9 +84,9 @@ def __init__(self, self.__escape_format(output_suffix)) self.input_filename, self.output_filename = None, None self.__input_temp, self.__output_temp = False, False - self.__init_file(input_file, data_id, 'i') + self.__init_file(input_file, data_id, "i") if not disable_output: - self.__init_file(output_file, data_id, 'o') + self.__init_file(output_file, data_id, "o") else: self.output_file = None self.__closed = False @@ -96,7 +96,7 @@ def __init_file(self, f: Union[IOBase, str, int, None], data_id: Union[int, None], file_type: str): if isinstance(f, IOBase): # consider ``f`` as a file object - if file_type == 'i': + if file_type == "i": self.input_file = f else: self.output_file = f @@ -108,14 +108,14 @@ def __init_file(self, f: Union[IOBase, str, int, None], # consider wanna temp file fd, self.input_filename = tempfile.mkstemp() self.__init_file(fd, data_id, file_type) - if file_type == 'i': + if file_type == "i": self.__input_temp = True else: self.__output_temp = True else: # consider ``f`` as filename template - filename = f.format(data_id or '') - if file_type == 'i': + filename = f.format(data_id or "") + if file_type == "i": self.input_filename = filename else: self.output_filename = filename @@ -125,7 +125,7 @@ def __init_file(self, f: Union[IOBase, str, int, None], def __escape_format(self, st: str): """replace "{}" to "{{}}" """ - return re.sub(r'\{', '{{', re.sub(r'\}', '}}', st)) + return re.sub(r"\{", "{{", re.sub(r"\}", "}}", st)) def __del_files(self): """delete files""" @@ -207,21 +207,35 @@ def input_writeln(self, *args, **kwargs): args.append("\n") self.input_write(*args, **kwargs) - def output_gen(self, shell_cmd): + def output_gen(self, shell_cmd, time_limit=None): """ Run the command `shell_cmd` (usually the std program) and send it the input file as stdin. Write its output to the output file. Args: shell_cmd: the command to run, usually the std program. + time_limit: the time limit (seconds) of the command to run. + None means infinity. Defaults to None. """ self.flush_buffer() origin_pos = self.input_file.tell() self.input_file.seek(0) - subprocess.check_call(shell_cmd, - shell=True, - stdin=self.input_file, - stdout=self.output_file, - universal_newlines=True) + if time_limit is not None: + subprocess.check_call( + shell_cmd, + shell=True, + timeout=time_limit, + stdin=self.input_file, + stdout=self.output_file, + universal_newlines=True, + ) + else: + subprocess.check_call( + shell_cmd, + shell=True, + stdin=self.input_file, + stdout=self.output_file, + universal_newlines=True, + ) self.input_file.seek(origin_pos) log.debug(self.output_filename, " done") diff --git a/cyaron/tests/io_test.py b/cyaron/tests/io_test.py index 1da50e4..b83dc03 100644 --- a/cyaron/tests/io_test.py +++ b/cyaron/tests/io_test.py @@ -2,6 +2,7 @@ import os import shutil import tempfile +import subprocess from cyaron import IO from cyaron.output_capture import captured_output @@ -26,7 +27,12 @@ def test_create_files_simple(self): def test_create_files_prefix_id(self): with captured_output() as (out, err): - IO(file_prefix="test_prefix", data_id=233, input_suffix=".inp", output_suffix=".ans") + IO( + file_prefix="test_prefix", + data_id=233, + input_suffix=".inp", + output_suffix=".ans", + ) self.assertTrue(os.path.exists("test_prefix233.inp")) self.assertTrue(os.path.exists("test_prefix233.ans")) @@ -50,8 +56,8 @@ def test_write_stuff(self): input = f.read() with open("test_write.out") as f: output = f.read() - self.assertEqual(input.split(), ['1', '2', '3', '4', '5', '6', '7', '8', '9']) - self.assertEqual(output.split(), ['9', '8', '7', '6', '5', '4', '3', '2', '1']) + self.assertEqual(input.split(), ["1", "2", "3", "4", "5", "6", "7", "8", "9"]) + self.assertEqual(output.split(), ["9", "8", "7", "6", "5", "4", "3", "2", "1"]) self.assertEqual(input.count("\n"), 2) self.assertEqual(output.count("\n"), 2) @@ -64,15 +70,44 @@ def test_output_gen(self): output = f.read() self.assertEqual(output.strip("\n"), "233") + def test_output_gen_time_limit_exceeded(self): + time_limit_exceeded = False + with captured_output() as (out, err): + with open("long_time.py", "w") as f: + f.write("import time\ntime.sleep(10)\nprint(1)") + + try: + with IO("test_gen.in", "test_gen.out") as test: + test.output_gen("python long_time.py", time_limit=1) + except subprocess.TimeoutExpired: + time_limit_exceeded = True + self.assertEqual(time_limit_exceeded, True) + + def test_output_gen_time_limit_not_exceeded(self): + time_limit_exceeded = False + with captured_output() as (out, err): + with open("short_time.py", "w") as f: + f.write("import time\ntime.sleep(0.2)\nprint(1)") + + try: + with IO("test_gen.in", "test_gen.out") as test: + test.output_gen("python short_time.py", time_limit=1) + except subprocess.TimeoutExpired: + time_limit_exceeded = True + with open("test_gen.out") as f: + output = f.read() + self.assertEqual(output.strip("\n"), "1") + self.assertEqual(time_limit_exceeded, False) + def test_init_overload(self): - with IO(file_prefix='data{', data_id=5) as test: - self.assertEqual(test.input_filename, 'data{5.in') - self.assertEqual(test.output_filename, 'data{5.out') - with IO('data{}.in', 'data{}.out', 5) as test: - self.assertEqual(test.input_filename, 'data5.in') - self.assertEqual(test.output_filename, 'data5.out') - with open('data5.in', 'w+') as fin: - with open('data5.out', 'w+') as fout: + with IO(file_prefix="data{", data_id=5) as test: + self.assertEqual(test.input_filename, "data{5.in") + self.assertEqual(test.output_filename, "data{5.out") + with IO("data{}.in", "data{}.out", 5) as test: + self.assertEqual(test.input_filename, "data5.in") + self.assertEqual(test.output_filename, "data5.out") + with open("data5.in", "w+") as fin: + with open("data5.out", "w+") as fout: with IO(fin, fout) as test: self.assertEqual(test.input_file, fin) self.assertEqual(test.output_file, fout)