Merge pull request #434 from KShivendu/feat/openai-dbpedia-1M
feat: Add new dataset with OpenAI embeddings for 1M DBpedia entities
erikbern authored Jul 5, 2023
2 parents f8236b8 + cf116aa commit 9bcf775
Showing 2 changed files with 22 additions and 0 deletions.
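
For context, the new dataset slices plug into the existing benchmark flow; a minimal usage sketch (assuming the repository's usual run.py entry point and its --dataset flag, which are not part of this diff):

python run.py --dataset dbpedia-openai-100k-angular

As with the other entries in DATASETS, the HDF5 file should be built on first use if it is not already present.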
21 changes: 21 additions & 0 deletions ann_benchmarks/datasets.py
@@ -475,6 +475,22 @@ def movielens10m(out_fn):
def movielens20m(out_fn):
movielens("ml-20m.zip", "ml-20m/ratings.csv", out_fn, ",", True)

def dbpedia_entities_openai_1M(out_fn, n=None):
    from sklearn.model_selection import train_test_split
    from datasets import load_dataset
    import numpy as np

    # Pull the precomputed 1536-dimensional OpenAI embeddings for 1M DBpedia
    # entities from the Hugging Face Hub.
    data = load_dataset("KShivendu/dbpedia-entities-openai-1M", split="train")
    if n is not None and n >= 100_000:
        # Keep only the first n entities when a smaller slice is requested.
        data = data.select(range(n))

    # Stack the per-row embedding lists into a single (num_rows, 1536) matrix.
    embeddings = data.to_pandas()['openai'].to_numpy()
    embeddings = np.vstack(embeddings).reshape((-1, 1536))

    # Hold out 10,000 vectors as queries; the rest become the index set.
    X_train, X_test = train_test_split(embeddings, test_size=10_000, random_state=42)

    write_output(X_train, X_test, out_fn, "angular")


DATASETS = {
"deep-image-96-angular": deep_image,
@@ -505,3 +521,8 @@ def movielens20m(out_fn):
"movielens10m-jaccard": movielens10m,
"movielens20m-jaccard": movielens20m,
}

DATASETS.update({
    # Bind n as a default argument so each lambda keeps its own slice size
    # instead of closing over the final loop value (1,000,000).
    f"dbpedia-openai-{n//1000}k-angular": lambda out_fn, n=n: dbpedia_entities_openai_1M(out_fn, n)
    for n in range(100_000, 1_100_000, 100_000)
})
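
A note on the n=n default argument above (a standalone illustration, not part of the diff): Python closures capture the loop variable by reference, so without that binding every generated lambda would call the loader with the final value of n.

# Standalone sketch of late-binding closures; names here are illustrative only.
sizes = range(100_000, 400_000, 100_000)

late = {f"{n//1000}k": (lambda: n) for n in sizes}       # every lambda sees the last n
bound = {f"{n//1000}k": (lambda n=n: n) for n in sizes}  # each lambda keeps its own n

print(late["100k"]())   # 300000
print(bound["100k"]())  # 100000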
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,3 +8,4 @@ psutil==5.9.4
scikit-learn==1.2.1
jinja2==3.1.2
pytest==7.2.2
datasets==2.12.0
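
Once a slice has been written, a quick sanity check of the output is possible; a minimal sketch, assuming write_output produces the usual ann-benchmarks HDF5 layout with 'train' and 'test' datasets and a 'distance' attribute (the filename below is illustrative):

import h5py

# Filename is illustrative; the harness passes the actual out_fn.
with h5py.File("dbpedia-openai-100k-angular.hdf5", "r") as f:
    print(f["train"].shape)         # expected (90000, 1536) for the 100k slice
    print(f["test"].shape)          # expected (10000, 1536)
    print(f.attrs.get("distance"))  # expected "angular"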
