Skip to content

Commit

Permalink
Merge pull request #114 from AlfredChester:master
Browse files Browse the repository at this point in the history
New feature: 为 IO.output_gen 方法增加 time_limit 参数
  • Loading branch information
Mr-Python-in-China authored Oct 2, 2024
2 parents 8027e5a + 40e7247 commit aa1b6e7
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 54 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,7 @@ target/
# Pycharm
venv

*.DS_Store

# VS Code
.vscode
.vscode
132 changes: 103 additions & 29 deletions cyaron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, name, mismatch):
self.mismatch = mismatch

def __str__(self):
return 'In program: \'{}\'. {}'.format(self.name,self.mismatch)
return "In program: '{}'. {}".format(self.name, self.mismatch)


class Compare:
Expand All @@ -38,7 +38,7 @@ def __process_file(file):
file.output_file.seek(0)
return file.output_filename, file.output_file.read()
else:
with open(file, "r", newline='\n') as f:
with open(file, "r", newline="\n") as f:
return file, f.read()

@staticmethod
Expand All @@ -51,26 +51,43 @@ def __normal_max_workers(workers):

@classmethod
def output(cls, *files, **kwargs):
kwargs = unpack_kwargs('output', kwargs, ('std', ('grader', DEFAULT_GRADER), ('max_workers', -1),
('job_pool', None), ('stop_on_incorrect', None)))
std = kwargs['std']
grader = kwargs['grader']
max_workers = kwargs['max_workers']
job_pool = kwargs['job_pool']
if kwargs['stop_on_incorrect'] is not None:
kwargs = unpack_kwargs(
"output",
kwargs,
(
"std",
("grader", DEFAULT_GRADER),
("max_workers", -1),
("job_pool", None),
("stop_on_incorrect", None),
),
)
std = kwargs["std"]
grader = kwargs["grader"]
max_workers = kwargs["max_workers"]
job_pool = kwargs["job_pool"]
if kwargs["stop_on_incorrect"] is not None:
log.warn("parameter stop_on_incorrect is deprecated and has no effect.")

if (max_workers is None or max_workers >= 0) and job_pool is None:
max_workers = cls.__normal_max_workers(max_workers)
try:
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=max_workers) as job_pool:
return cls.output(*files, std=std, grader=grader, max_workers=max_workers, job_pool=job_pool)
return cls.output(
*files,
std=std,
grader=grader,
max_workers=max_workers,
job_pool=job_pool
)
except ImportError:
pass

def get_std():
return cls.__process_file(std)[1]

if job_pool is not None:
std = job_pool.submit(get_std).result()
else:
Expand All @@ -87,61 +104,118 @@ def do(file):

@classmethod
def program(cls, *programs, **kwargs):
kwargs = unpack_kwargs('program', kwargs, ('input', ('std', None), ('std_program', None),
('grader', DEFAULT_GRADER), ('max_workers', -1),
('job_pool', None), ('stop_on_incorrect', None)))
input = kwargs['input']
std = kwargs['std']
std_program = kwargs['std_program']
grader = kwargs['grader']
max_workers = kwargs['max_workers']
job_pool = kwargs['job_pool']
if kwargs['stop_on_incorrect'] is not None:
kwargs = unpack_kwargs(
"program",
kwargs,
(
"input",
("std", None),
("std_program", None),
("grader", DEFAULT_GRADER),
("max_workers", -1),
("job_pool", None),
("stop_on_incorrect", None),
),
)
input = kwargs["input"]
std = kwargs["std"]
std_program = kwargs["std_program"]
grader = kwargs["grader"]
max_workers = kwargs["max_workers"]
job_pool = kwargs["job_pool"]
if kwargs["stop_on_incorrect"] is not None:
log.warn("parameter stop_on_incorrect is deprecated and has no effect.")

if (max_workers is None or max_workers >= 0) and job_pool is None:
max_workers = cls.__normal_max_workers(max_workers)
try:
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=max_workers) as job_pool:
return cls.program(*programs, input=input, std=std, std_program=std_program, grader=grader, max_workers=max_workers, job_pool=job_pool)
return cls.program(
*programs,
input=input,
std=std,
std_program=std_program,
grader=grader,
max_workers=max_workers,
job_pool=job_pool
)
except ImportError:
pass

if not isinstance(input, IO):
raise TypeError("expect {}, got {}".format(type(IO).__name__, type(input).__name__))
raise TypeError(
"expect {}, got {}".format(type(IO).__name__, type(input).__name__)
)
input.flush_buffer()
input.input_file.seek(0)

if std_program is not None:

def get_std():
with open(os.dup(input.input_file.fileno()), 'r', newline='\n') as input_file:
content = make_unicode(subprocess.check_output(std_program, shell=(not list_like(std_program)), stdin=input.input_file, universal_newlines=True))
with open(
os.dup(input.input_file.fileno()), "r", newline="\n"
) as input_file:
content = make_unicode(
subprocess.check_output(
std_program,
shell=(not list_like(std_program)),
stdin=input.input_file,
universal_newlines=True,
)
)
input_file.seek(0)
return content

if job_pool is not None:
std = job_pool.submit(get_std).result()
else:
std = get_std()
elif std is not None:

def get_std():
return cls.__process_file(std)[1]

if job_pool is not None:
std = job_pool.submit(get_std).result()
else:
std = get_std()
else:
raise TypeError('program() missing 1 required non-None keyword-only argument: \'std\' or \'std_program\'')
raise TypeError(
"program() missing 1 required non-None keyword-only argument: 'std' or 'std_program'"
)

def do(program_name):
timeout = None
if list_like(program_name) and len(program_name) == 2 and int_like(program_name[-1]):
if (
list_like(program_name)
and len(program_name) == 2
and int_like(program_name[-1])
):
program_name, timeout = program_name
with open(os.dup(input.input_file.fileno()), 'r', newline='\n') as input_file:
with open(
os.dup(input.input_file.fileno()), "r", newline="\n"
) as input_file:
if timeout is None:
content = make_unicode(subprocess.check_output(program_name, shell=(not list_like(program_name)), stdin=input_file, universal_newlines=True))
content = make_unicode(
subprocess.check_output(
program_name,
shell=(not list_like(program_name)),
stdin=input_file,
universal_newlines=True,
)
)
else:
content = make_unicode(subprocess.check_output(program_name, shell=(not list_like(program_name)), stdin=input_file, universal_newlines=True, timeout=timeout))
content = make_unicode(
subprocess.check_output(
program_name,
shell=(not list_like(program_name)),
stdin=input_file,
universal_newlines=True,
timeout=timeout,
)
)
input_file.seek(0)
cls.__compare_two(program_name, content, std, grader)

Expand Down
40 changes: 27 additions & 13 deletions cyaron/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def __init__(self,
self.__escape_format(output_suffix))
self.input_filename, self.output_filename = None, None
self.__input_temp, self.__output_temp = False, False
self.__init_file(input_file, data_id, 'i')
self.__init_file(input_file, data_id, "i")
if not disable_output:
self.__init_file(output_file, data_id, 'o')
self.__init_file(output_file, data_id, "o")
else:
self.output_file = None
self.__closed = False
Expand All @@ -96,7 +96,7 @@ def __init_file(self, f: Union[IOBase, str, int, None],
data_id: Union[int, None], file_type: str):
if isinstance(f, IOBase):
# consider ``f`` as a file object
if file_type == 'i':
if file_type == "i":
self.input_file = f
else:
self.output_file = f
Expand All @@ -108,14 +108,14 @@ def __init_file(self, f: Union[IOBase, str, int, None],
# consider wanna temp file
fd, self.input_filename = tempfile.mkstemp()
self.__init_file(fd, data_id, file_type)
if file_type == 'i':
if file_type == "i":
self.__input_temp = True
else:
self.__output_temp = True
else:
# consider ``f`` as filename template
filename = f.format(data_id or '')
if file_type == 'i':
filename = f.format(data_id or "")
if file_type == "i":
self.input_filename = filename
else:
self.output_filename = filename
Expand All @@ -125,7 +125,7 @@ def __init_file(self, f: Union[IOBase, str, int, None],

def __escape_format(self, st: str):
"""replace "{}" to "{{}}" """
return re.sub(r'\{', '{{', re.sub(r'\}', '}}', st))
return re.sub(r"\{", "{{", re.sub(r"\}", "}}", st))

def __del_files(self):
"""delete files"""
Expand Down Expand Up @@ -207,21 +207,35 @@ def input_writeln(self, *args, **kwargs):
args.append("\n")
self.input_write(*args, **kwargs)

def output_gen(self, shell_cmd):
def output_gen(self, shell_cmd, time_limit=None):
"""
Run the command `shell_cmd` (usually the std program) and send it the input file as stdin.
Write its output to the output file.
Args:
shell_cmd: the command to run, usually the std program.
time_limit: the time limit (seconds) of the command to run.
None means infinity. Defaults to None.
"""
self.flush_buffer()
origin_pos = self.input_file.tell()
self.input_file.seek(0)
subprocess.check_call(shell_cmd,
shell=True,
stdin=self.input_file,
stdout=self.output_file,
universal_newlines=True)
if time_limit is not None:
subprocess.check_call(
shell_cmd,
shell=True,
timeout=time_limit,
stdin=self.input_file,
stdout=self.output_file,
universal_newlines=True,
)
else:
subprocess.check_call(
shell_cmd,
shell=True,
stdin=self.input_file,
stdout=self.output_file,
universal_newlines=True,
)
self.input_file.seek(origin_pos)

log.debug(self.output_filename, " done")
Expand Down
57 changes: 46 additions & 11 deletions cyaron/tests/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
import tempfile
import subprocess
from cyaron import IO
from cyaron.output_capture import captured_output

Expand All @@ -26,7 +27,12 @@ def test_create_files_simple(self):

def test_create_files_prefix_id(self):
with captured_output() as (out, err):
IO(file_prefix="test_prefix", data_id=233, input_suffix=".inp", output_suffix=".ans")
IO(
file_prefix="test_prefix",
data_id=233,
input_suffix=".inp",
output_suffix=".ans",
)
self.assertTrue(os.path.exists("test_prefix233.inp"))
self.assertTrue(os.path.exists("test_prefix233.ans"))

Expand All @@ -50,8 +56,8 @@ def test_write_stuff(self):
input = f.read()
with open("test_write.out") as f:
output = f.read()
self.assertEqual(input.split(), ['1', '2', '3', '4', '5', '6', '7', '8', '9'])
self.assertEqual(output.split(), ['9', '8', '7', '6', '5', '4', '3', '2', '1'])
self.assertEqual(input.split(), ["1", "2", "3", "4", "5", "6", "7", "8", "9"])
self.assertEqual(output.split(), ["9", "8", "7", "6", "5", "4", "3", "2", "1"])
self.assertEqual(input.count("\n"), 2)
self.assertEqual(output.count("\n"), 2)

Expand All @@ -64,15 +70,44 @@ def test_output_gen(self):
output = f.read()
self.assertEqual(output.strip("\n"), "233")

def test_output_gen_time_limit_exceeded(self):
time_limit_exceeded = False
with captured_output() as (out, err):
with open("long_time.py", "w") as f:
f.write("import time\ntime.sleep(10)\nprint(1)")

try:
with IO("test_gen.in", "test_gen.out") as test:
test.output_gen("python long_time.py", time_limit=1)
except subprocess.TimeoutExpired:
time_limit_exceeded = True
self.assertEqual(time_limit_exceeded, True)

def test_output_gen_time_limit_not_exceeded(self):
time_limit_exceeded = False
with captured_output() as (out, err):
with open("short_time.py", "w") as f:
f.write("import time\ntime.sleep(0.2)\nprint(1)")

try:
with IO("test_gen.in", "test_gen.out") as test:
test.output_gen("python short_time.py", time_limit=1)
except subprocess.TimeoutExpired:
time_limit_exceeded = True
with open("test_gen.out") as f:
output = f.read()
self.assertEqual(output.strip("\n"), "1")
self.assertEqual(time_limit_exceeded, False)

def test_init_overload(self):
with IO(file_prefix='data{', data_id=5) as test:
self.assertEqual(test.input_filename, 'data{5.in')
self.assertEqual(test.output_filename, 'data{5.out')
with IO('data{}.in', 'data{}.out', 5) as test:
self.assertEqual(test.input_filename, 'data5.in')
self.assertEqual(test.output_filename, 'data5.out')
with open('data5.in', 'w+') as fin:
with open('data5.out', 'w+') as fout:
with IO(file_prefix="data{", data_id=5) as test:
self.assertEqual(test.input_filename, "data{5.in")
self.assertEqual(test.output_filename, "data{5.out")
with IO("data{}.in", "data{}.out", 5) as test:
self.assertEqual(test.input_filename, "data5.in")
self.assertEqual(test.output_filename, "data5.out")
with open("data5.in", "w+") as fin:
with open("data5.out", "w+") as fout:
with IO(fin, fout) as test:
self.assertEqual(test.input_file, fin)
self.assertEqual(test.output_file, fout)

0 comments on commit aa1b6e7

Please sign in to comment.