-
Notifications
You must be signed in to change notification settings - Fork 2
/
state_update_db.py
142 lines (127 loc) · 4.5 KB
/
state_update_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# state_update_db.py - updates (or generates) embeddings
# for reports and incidents
# and stores them in the database.
#
# Requires: - the Longformer submodule is downloaded,
# - environment variable MONGODB_CONNECTION_STRING
# set to read the database.
#
# Assumes: - being executed from project root directory
from ast import literal_eval
from hashlib import sha1
from os import environ, path

from pandas import DataFrame, array, concat, read_csv
from pymongo import MongoClient
from torch import no_grad, tensor
from transformers import LongformerModel, LongformerTokenizer
# Connection string is required (raises KeyError at startup if unset,
# per the header comment's stated prerequisite).
MONGODB_URI = environ['MONGODB_CONNECTION_STRING']
# Local Longformer checkout (the submodule mentioned in the header);
# path is relative, so the script must run from the project root.
MODEL_PATH = path.join('inference', 'model')
# Get the Longformer tokenizer and model
# local_files_only=True: never hit the Hugging Face hub -- the weights
# must already be on disk at MODEL_PATH.
tokenizer = LongformerTokenizer.from_pretrained(
    MODEL_PATH,
    local_files_only=True,
    model_max_length=2000
)
model = LongformerModel.from_pretrained(
    MODEL_PATH,
    local_files_only=True
)
# Process the text of one report and return the CLS token
def cls_token(text):
    """Tokenize *text* and return the Longformer [CLS] token embedding.

    Returns the hidden state of the first token of the first (only)
    sequence in the batch -- a 1-D tensor of length hidden_size.
    """
    inp = tokenizer(text,
                    padding='longest',
                    truncation='longest_first',
                    return_tensors='pt')
    # Inference only: no_grad() stops autograd from recording the
    # forward pass, so activations are freed immediately instead of
    # being held for a backward pass that never happens.  The caller
    # detaches the result anyway, so the returned values are unchanged.
    with no_grad():
        return model(**inp).last_hidden_state[0][0]
# Aggregate the incident_id and text
# of all reports from Mongo for each incident
client = MongoClient(MONGODB_URI)
db = client['aiidprod']
# Aggregation over the `incidents` collection:
#   1. keep only incident_id, the report_number list, and any existing
#      incident-level embedding (drop _id);
#   2. replace the report_number list with the joined report documents
#      from `reports` (text + report_number + any existing per-report
#      embedding), matched on report_number;
#   3. process incidents in ascending incident_id order.
pipeline = [
    { '$project': {
            '_id': False,
            'incident_id': True,
            'reports': True,
            'embedding': True
        }
    },
    { '$lookup': {
            'from': 'reports',
            'localField': 'reports',
            'foreignField': 'report_number',
            'pipeline': [
                { '$project': {
                        '_id': False,
                        'text': True,
                        'report_number': True,
                        'embedding': True
                    }
                }
            ],
            'as': 'reports'
        }
    },
    {
        '$sort': { 'incident_id' : 1 }
    }
]
# Update the state using the results.  For each incident:
#   1. refresh the embedding of every report whose text changed (tracked
#      by a SHA-1 of the text stored alongside the vector);
#   2. if the incident-level embedding is missing or was computed from a
#      different set of reports, upload a fresh mean of all report vectors.
for incident in db.incidents.aggregate(pipeline):
    try:
        print('checking', incident['incident_id']) # DEBUG
        reports = incident['reports']
        # An incident whose $lookup matched no reports has nothing to
        # embed; previously this raised IndexError into the except below.
        if not reports:
            continue
        # Update the embeddings of new reports
        # and reports where the text has changed
        for report in reports:
            print('checking report', report['report_number'])
            text_hash = sha1(report['text'].encode('utf-8')).hexdigest()
            stored = report.get('embedding')
            if not (stored and stored['from_text_hash'] == text_hash):
                print('Updating embedding')
                token = cls_token(report['text'])
                report['embedding'] = {
                    'vector': token.detach().tolist(),
                    'from_text_hash': text_hash
                }
                db.reports.update_one(
                    { 'report_number' : report['report_number'] },
                    { '$set': { 'embedding': report['embedding'] } }
                )
        # Upload a new incident embedding when none exists or when the
        # incident's report list differs from the one the stored
        # embedding was computed from (use .get: old documents may lack
        # 'from_reports').
        report_ids = [r['report_number'] for r in reports]
        incident_embedding = incident.get('embedding')
        if (
            incident_embedding is None or
            report_ids != incident_embedding.get('from_reports')
        ):
            # Mean of every report's CLS vector.  Computed here,
            # unconditionally: the original only recomputed the mean
            # when a report embedding had just changed, so an incident
            # whose report *list* changed (but whose report texts did
            # not) was uploaded with just the first report's vector.
            mean = (
                tensor([r['embedding']['vector'] for r in reports])
                .mean(dim=0)
                .tolist()
            )
            print('uploading embedding for incident', incident['incident_id'])
            db.incidents.update_one(
                { 'incident_id' : incident['incident_id'] },
                { '$set': {
                        'embedding': {
                            'vector': mean,
                            'from_reports': report_ids
                        }
                    }
                }
            )
    except Exception as e:
        # Best-effort batch job: report the failure and move on to the
        # next incident rather than aborting the whole run.
        print(e)