Skip to content

Commit

Permalink
contiguous chat completion
Browse files Browse the repository at this point in the history
  • Loading branch information
sbordt committed Apr 12, 2024
1 parent e9bc2bb commit 153ef75
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 14 deletions.
105 changes: 96 additions & 9 deletions tabmemcheck/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,26 +215,113 @@ def row_completion(
####################################################################################


def build_contiguous_query(
text: str, prefix_length: int, suffix_length: int, few_shot: int, rng
):
query_length = (prefix_length + suffix_length) * (1 + few_shot)
# the length of the string must be at least (prefix_length + suffix_length) * (1 + few_shot)
assert (
len(text) >= query_length
), "The provided string is too short for the specified prefix and suffix lengths."
# choose a random sub-string of length query_length
idx = rng.integers(low=0, high=len(text) - query_length)
s_query = text[idx : idx + query_length]
# construct few-shot examples
few_shot_examples = []
for i_fs in range(few_shot):
offset = (prefix_length + suffix_length) * i_fs
few_shot_examples.append(
(
[s_query[offset : offset + prefix_length]],
[
s_query[
offset + prefix_length : offset + prefix_length + suffix_length
]
],
)
)
# prefix and suffix
prefix = s_query[
query_length - prefix_length - suffix_length : query_length - suffix_length
]
suffix = s_query[query_length - suffix_length :]
return few_shot_examples, prefix, suffix


def chat_completion(
llm: LLM_Interface,
strings: list[str],
system_prompt: str = "You are a helpful assistant that complets the user's input.",
few_shot=5,
strings: str | list[str],
system_prompt: str = "You are a helpful assistant.",
prefix_length: int = None,
suffix_length: int = None,
few_shot=5, # integer, or list [str, ..., str] or [[str,..,str], ..., [str,..,str]]
contiguous=False,
num_queries=10,
print_levenshtein=False,
out_file=None,
rng=None,
):
"""Basic completion with a chat model and a list of strings."""
# randomly split the strings into prefixes and suffixes, then use prefix_suffix_chat_completion
"""General-purpose chat completion."""
if rng is None:
rng = np.random.default_rng()
if isinstance(strings, str):
strings = [strings]

def prefix_suffix_split(s):
if prefix_length is not None:
return s[:prefix_length], s[prefix_length:]
else: # randomly split the string into prefix and suffix
idx = rng.integers(low=int(len(s) / 3), high=int(2 * len(s) / 3))
return s[:idx], s[idx:]

if contiguous:
# few-shot has to be an integer
assert isinstance(
few_shot, int
), "For contiguous chat completion, few_shot must be an integer."
# both prefix_length and suffix_length have to be specified
assert (
prefix_length is not None and suffix_length is not None
), "For contiguous chat completion, both prefix_length and suffix_length have to be specified."
prefixes, suffixes, responses = [], [], []
for _ in range(num_queries):
# select a random string and build the query
few_shot_examples, prefix, suffix = build_contiguous_query(
rng.choice(strings), prefix_length, suffix_length, few_shot, rng
)
# send query
prefix, suffix, response = prefix_suffix_chat_completion(
llm,
[prefix],
[suffix],
system_prompt,
few_shot=few_shot_examples,
num_queries=1,
print_levenshtein=print_levenshtein,
out_file=out_file,
rng=rng,
)
prefixes.append(prefix)
suffixes.append(suffix)
responses.append(response)
return prefixes, suffixes, responses

# non-contiguous
prefixes = []
suffixes = []
for s in strings:
idx = rng.integers(low=int(len(s) / 3), high=int(2 * len(s) / 3))
prefixes.append(s[:idx])
suffixes.append(s[idx:])
for s_query in strings: # fixed prefix length specified by the user
prefix, suffix = prefix_suffix_split(s_query)
prefixes.append(prefix)
suffixes.append(suffix)
# few shot list
if isinstance(few_shot, list):
if len(few_shot) > 0:
if isinstance(few_shot[0], list): # list of lists
few_shot = [[prefix_suffix_split(s) for s in fs] for fs in few_shot]
few_shot = [([x[0] for x in fs], [x[1] for x in fs]) for fs in few_shot]
else: # list of strings
few_shot = [prefix_suffix_split(s) for s in few_shot]
few_shot = [([fs[0]], [fs[1]]) for fs in few_shot]
return prefix_suffix_chat_completion(
llm,
prefixes,
Expand Down
19 changes: 14 additions & 5 deletions tabmemcheck/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ class OpenAILLM(LLM_Interface):
client: OpenAI = None
model: str = None

def __init__(self, client, model=None):
def __init__(self, client, model, chat_mode=None):
super().__init__()
self.client = client
self.model = model
# auto-detect chat models
if "gpt-3.5" in model or "gpt-4" in model:
self.chat_mode = True
if chat_mode is not None:
self.chat_mode = chat_mode

@retry(
retry=retry_if_not_exception_type(openai.BadRequestError),
Expand Down Expand Up @@ -153,17 +155,20 @@ def chat_completion(self, messages, temperature, max_tokens):
)
# we return the completion string or "" if there is an invalid response/query
try:
response = response.choices[0].message.content
response_content = response.choices[0].message.content
except:
print(f"Invalid response {response}")
response = ""
return response
response_content = ""
if response_content is None:
print(f"Invalid response {response}")
response_content = ""
return response_content

def __repr__(self) -> str:
return f"{self.model}"


def openai_setup(model: str, azure: bool = False):
def openai_setup(model: str, azure: bool = False, *args, **kwargs):
"""Setup an OpenAI language model.
:param model: The name of the model (e.g. "gpt-3.5-turbo-0613").
Expand Down Expand Up @@ -197,6 +202,8 @@ def openai_setup(model: str, azure: bool = False):
if "AZURE_OPENAI_VERSION" in os.environ
else None
),
*args,
**kwargs,
)
else: # openai api
client = OpenAI(
Expand All @@ -206,6 +213,8 @@ def openai_setup(model: str, azure: bool = False):
organization=(
os.environ["OPENAI_API_ORG"] if "OPENAI_API_ORG" in os.environ else None
),
*args,
**kwargs,
)

# the llm
Expand Down

0 comments on commit 153ef75

Please sign in to comment.