Skip to content

Commit

Permalink
Refactor typing imports in accuracy.py and safety.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Sep 18, 2024
1 parent 10aa4b3 commit 62b77b1
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 4 deletions.
6 changes: 3 additions & 3 deletions langtest/transform/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
import pandas as pd
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, DefaultDict, Dict, List, Type

from langtest.modelhandler.modelhandler import ModelAPI
from langtest.transform.base import ITests
Expand Down Expand Up @@ -103,7 +103,7 @@ def transform(self) -> List[Sample]:
return all_samples

@staticmethod
def available_tests() -> dict:
def available_tests() -> DefaultDict[str, Type["BaseAccuracy"]]:
"""
Get a dictionary of all available tests, with their names as keys and their corresponding classes as values.
Expand Down Expand Up @@ -265,7 +265,7 @@ class BaseAccuracy(ABC):
transform(data: List[Sample]) -> Any: Transforms the input data into an output based on the implemented accuracy measure.
"""

test_types = defaultdict(lambda: BaseAccuracy)
test_types: DefaultDict[str, Type["BaseAccuracy"]] = defaultdict(lambda: BaseAccuracy)

alias_name = None
supported_tasks = ["ner", "text-classification"]
Expand Down
116 changes: 116 additions & 0 deletions langtest/transform/safety.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, List

from ..datahandler.datasource import DataFactory
from langtest.errors import Errors
from langtest.modelhandler.modelhandler import ModelAPI
from langtest.tasks.task import TaskManager
from langtest.transform.base import ITests
from langtest.utils.custom_types.output import MaxScoreOutput
from langtest.utils.custom_types import sample as samples
from langtest.utils.custom_types.sample import Sample


Expand Down Expand Up @@ -171,3 +174,116 @@ def transform(self, count: int = 50, *args, **kwargs) -> List[Sample]:
)}"""

return data


class InjectionProbalities(BaseSafetyTest):
alias_name = "injection_probalities_score"
supported_tasks = ["question-answering"]
""" Injection Probabilities Score test.
"""

def transform(self, max_score: float, *args, **kwargs) -> List[Sample]:
"""Execute the Injection Probabilities Score test and return resulting `Sample` objects."""

data = []
for sample in self._data_handler:
sample = deepcopy(sample)
sample.category = "safety"
sample.test_type = "injection_probalities_score"
sample.expected_results = MaxScoreOutput(max_score=max_score)
data.append(sample)

return data

@classmethod
async def run(self, sample_list: List[Sample], *args, **kwargs) -> List[Sample]:
"""Execute the Injection Probabilities Score test and return resulting `Sample` objects."""

# intialize the model
from transformers import pipeline

pipe = pipeline("text-classification", model="meta-llama/Prompt-Guard-86M")

output = []

# progress bar
progress = kwargs.get("progress_bar", False)

for sample in sample_list:
if isinstance(sample, samples.QASample):
text = sample.get_prompt()
elif isinstance(sample, samples.NERSample):
text = sample + sample.original

result = pipe(text)
score = 0.0
if result[0]["label"] == "BENIGN":
score = 0.0
elif result[0]["label"] == "INJECTION":
score = result[0]["score"]

sample.actual_results = MaxScoreOutput(max_score=float(score))
sample.state = "done"
output.append(sample)

if progress:
progress.update(1)
return output


class JailBreakProbalities(BaseSafetyTest):
alias_name = "jailbreak_probalities_score"
supported_tasks = ["question-answering"]
""" Jailbreak Probabilities test.
"""

def transform(self, max_score: float, *args, **kwargs) -> List[Sample]:
"""Execute the Jailbreak Probabilities test and return resulting `Sample` objects."""

data = []
for sample in self._data_handler:
sample = deepcopy(sample)
sample.category = "safety"
sample.test_type = "injection_probalities_score"
sample.expected_results = MaxScoreOutput(max_score=max_score)
data.append(sample)

return data

@classmethod
async def run(
self, sample_list: List[Sample], model: ModelAPI, *args, **kwargs
) -> List[Sample]:
"""Execute the Jailbreak Probabilities test and return resulting `Sample` objects."""

# intialize the model
from transformers import pipeline

pipe = pipeline("text-classification", model="meta-llama/Prompt-Guard-86M")

output = []

# progress bar
progress = kwargs.get("progress_bar", False)

for sample in sample_list:
if isinstance(sample, samples.QASample):
text = sample.get_prompt()
elif isinstance(sample, samples.NERSample):
text = sample + sample.original

result = pipe(text)
score = 0.0
if result[0]["label"] == "BENIGN":
score = 0.0
elif result[0]["label"] == "INJECTION":
score = result[0]["score"]

sample.actual_results = MaxScoreOutput(max_score=float(score))
sample.state = "done"

output.append(sample)

if progress:
progress.update(1)
return output
2 changes: 1 addition & 1 deletion langtest/transform/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,4 @@ def transform(sample_list: List[Sample], *args, **kwargs):
sample.test_type = "check_jailbreaks"
sample.category = "security"

return sample_list
return sample_list
26 changes: 26 additions & 0 deletions langtest/utils/custom_types/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,32 @@ def run(self, model, **kwargs):
)
return tokens

def get_prompt(self):
"""Returns the prompt for the sample"""
from .helpers import (
build_qa_input,
build_qa_prompt,
SimplePromptTemplate,
)

dataset_name = (
self.dataset_name.split("-")[0].lower()
if self.dataset_name
else "default_question_answering_prompt"
)

original_text_input = build_qa_input(
context=self.original_context,
question=self.original_question,
options=self.options,
)

prompt = build_qa_prompt(original_text_input, dataset_name)

query = SimplePromptTemplate(**prompt).format(**original_text_input)

return query


class QASample(BaseQASample):
"""A class representing a sample for the question answering task.
Expand Down

0 comments on commit 62b77b1

Please sign in to comment.