-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
77 lines (65 loc) · 2.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import box
import timeit
import yaml
import argparse
from dotenv import find_dotenv, load_dotenv
import atexit
import os
import readline
from src.utils import setup_dbqa
# Load variables from the .env file located by find_dotenv() into the process
# environment before anything below reads them.
load_dotenv(dotenv_path=find_dotenv())

# Parse config/config.yml (resolved against the current working directory) with
# yaml.safe_load and wrap the result in a Box so settings are read as
# attributes, e.g. cfg.TIMING in the query loop.
with open('config/config.yml', mode='r', encoding='utf8') as config_file:
    cfg = box.Box(yaml.safe_load(config_file))
def _read_query(prompt):
    """Prompt the user and return their input with surrounding whitespace stripped.

    End-of-input (Ctrl-D) or an interrupt (Ctrl-C) at the prompt is mapped to
    the quit command (\\q) so the session ends cleanly instead of with a
    traceback.
    """
    try:
        return input(prompt).strip()
    except (EOFError, KeyboardInterrupt):
        print()  # move past the unfinished prompt line
        return '\\q'


def _print_source_documents(source_docs):
    """Print each retrieved source document, numbered from 1.

    Each element must expose ``page_content`` and a ``metadata`` mapping; the
    ``source`` and ``page`` metadata keys are printed only when present.
    """
    for number, doc in enumerate(source_docs, start=1):
        print('')
        print('=' * 50)
        print(f'\nSource Document {number}\n')
        print(f'Source Text: {doc.page_content}')
        if 'source' in doc.metadata:
            print(f'Document Name: {doc.metadata["source"]}')
        if 'page' in doc.metadata:
            print(f'Page Number: {doc.metadata["page"]}\n')
        print('=' * 50)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('input',
                        nargs='?',
                        type=str,
                        default='',
                        help='Enter the query to pass into the LLM')
    args = parser.parse_args()

    # Setup DBQA
    dbqa = setup_dbqa()

    # A query given on the command line is answered once and the script exits;
    # otherwise run an interactive loop with readline history persisted in
    # ~/.docqa_history.
    single_shot = bool(args.input)
    if single_shot:
        query = args.input.strip()
    else:
        histfile = os.path.join(os.path.expanduser("~"), ".docqa_history")
        # Limit the saved history file to 1000 lines. Set this outside the try
        # so it also applies on the first run, when no history file exists yet.
        readline.set_history_length(1000)
        try:
            readline.read_history_file(histfile)
        except FileNotFoundError:
            pass  # first run: nothing to restore
        atexit.register(readline.write_history_file, histfile)
        query = _read_query('\nEnter the question: ')

    # Commands: '\q' quits, '\timing' toggles printing of the retrieval time.
    while query != '\\q':
        if query == '\\timing':
            cfg.TIMING = not cfg.TIMING
            print('Timing is {}'.format('on' if cfg.TIMING else 'off'))
        elif query:  # blank input is ignored instead of being sent to the LLM
            start = timeit.default_timer()
            response = dbqa({'query': query})
            end = timeit.default_timer()

            # Process source documents
            _print_source_documents(response.get('source_documents', []))

            print(f'\nQuestion: {query}\n')
            print(f'\nAnswer: {response["result"]}\n')
            if cfg.TIMING:
                print('=' * 20)
                print(f"Time to retrieve response: {end - start}")

            if not single_shot:
                print('=' * 80)

        # Single-shot mode never prompts, even after a '\timing' command.
        if single_shot:
            break
        query = _read_query('\nEnter the query: ')