populate_database.py
import argparse
import os
import shutil
from typing import List

from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
from langchain_community.vectorstores import Chroma

from get_embedding_function import get_embedding_function

CHROMA_PATH = "chroma_db"
DATA_PATH = "data"


def load_documents() -> List[Document]:
    """
    Load PDF documents from the DATA_PATH directory.

    Returns:
        List[Document]: A list of loaded documents.
    """
    document_loader = PyPDFDirectoryLoader(DATA_PATH)
    return document_loader.load()


def split_documents(documents: List[Document]) -> List[Document]:
    """
    Split the documents into smaller chunks.

    Args:
        documents (List[Document]): The list of documents to split.

    Returns:
        List[Document]: A list of document chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=2000,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    return text_splitter.split_documents(documents)


def calculate_chunk_ids(chunks: List[Document]) -> List[Document]:
    """
    Calculate and add unique IDs to document chunks.

    Args:
        chunks (List[Document]): The list of document chunks.

    Returns:
        List[Document]: The list of document chunks with added IDs.
    """
    last_page_id = None
    current_chunk_index = 0
    for chunk in chunks:
        source = chunk.metadata.get("source")
        page = chunk.metadata.get("page")
        current_page_id = f"{source}:{page}"
        if current_page_id == last_page_id:
            current_chunk_index += 1
        else:
            current_chunk_index = 0
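        # Each ID has the form "<source>:<page>:<chunk index>", so chunks from the
        # same page are disambiguated by their position on that page.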
        chunk_id = f"{current_page_id}:{current_chunk_index}"
        last_page_id = current_page_id
        chunk.metadata["id"] = chunk_id
    return chunks


def add_to_chroma(chunks: List[Document]):
    """
    Add new document chunks to the Chroma vector store.

    Args:
        chunks (List[Document]): The list of document chunks to add.
    """
    db = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function()
    )
    chunks_with_ids = calculate_chunk_ids(chunks)
    existing_items = db.get()
    existing_ids = set(existing_items["ids"])
    print(f"Number of existing documents in DB: {len(existing_ids)}")
    new_chunks = [chunk for chunk in chunks_with_ids if chunk.metadata["id"] not in existing_ids]
    if new_chunks:
        print(f"Adding new documents: {len(new_chunks)}")
        new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
        db.add_documents(new_chunks, ids=new_chunk_ids)
    else:
        print("No new documents to add.")


def clear_database():
    """
    Clear the existing Chroma database.
    """
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)
        print("Database cleared.")


def main():
    """
    Main function to handle document processing and database population.
    """
    parser = argparse.ArgumentParser(description="Populate or reset the Chroma database with Harry Potter PDF documents.")
    parser.add_argument("--reset", action="store_true", help="Reset the database before populating")
    args = parser.parse_args()

    if args.reset:
        print("Resetting the database...")
        clear_database()

    print("Loading documents...")
    documents = load_documents()
    print("Splitting documents...")
    chunks = split_documents(documents)
    print("Adding documents to Chroma...")
    add_to_chroma(chunks)
    print("Database population complete.")


if __name__ == "__main__":
    main()
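
Example usage, assuming the source PDFs sit in a local data/ directory next to the script:

    python populate_database.py --reset

The --reset flag deletes the existing chroma_db directory before indexing; without it, the script only adds chunks whose IDs are not already in the store.

The script also imports get_embedding_function from a sibling module that is not shown on this page. A minimal sketch of what that module could look like, assuming an Ollama embedding model is available locally (the repository's actual implementation may use a different provider or model):

    # get_embedding_function.py -- hypothetical sketch, not the repository's actual module
    from langchain_community.embeddings import OllamaEmbeddings

    def get_embedding_function():
        # Any LangChain embedding object works here; the same function must be
        # used for both indexing and later querying so the vectors are comparable.
        return OllamaEmbeddings(model="nomic-embed-text")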