-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
134 lines (106 loc) · 3.62 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# A collection of convenience functions to interface with the database
import datetime
from interactions.api.models.misc import Snowflake
from loguru import logger
from typing import Optional, Tuple
from peewee import fn, JOIN
from db import Survey, Option, Vote
def create_survey(
    message_id: str,
    message_url: str,
    author: str,
    question: str,
    options: list,
    expires: Optional[datetime.datetime],
    realm: str = "unknown",
    votes_hidden: bool = False,
    vote_limit: Optional[int] = None
) -> Survey:
    """Create a Survey row plus one Option row per entry in ``options``.

    The survey and all of its options are inserted inside a single database
    transaction. If inserting any option fails, the whole survey is rolled
    back instead of leaving a survey with a partial option list behind.

    Args:
        message_id: Stored as the survey's ``message_id``.
        message_url: Stored as the survey's ``message_url``.
        author: Stored as the survey's ``author``.
        question: Stored as the survey's ``question``.
        options: Option descriptors. Each must expose ``text``,
            ``text_emoji_name``, ``button_text``, ``button_emoji_name``,
            ``button_emoji_id`` and ``color`` attributes. An option's
            position in this list becomes its stored ``idx``.
        expires: Stored as the survey's ``expires`` (may be ``None``).
        realm: Stored as the survey's ``realm``.
        votes_hidden: Stored as the survey's ``votes_hidden`` flag.
        vote_limit: Stored as the survey's ``vote_limit`` (may be ``None``).

    Returns:
        The newly created ``Survey`` instance.
    """
    logger.info(realm)
    # ``_meta.database`` is the database the peewee model is bound to;
    # ``atomic()`` makes the survey + options insert all-or-nothing.
    with Survey._meta.database.atomic():
        survey = Survey.create(
            message_id=message_id,
            message_url=message_url,
            author=author,
            question=question,
            vote_limit=vote_limit,
            expires=expires,
            realm=realm,
            votes_hidden=votes_hidden
        )
        for i, option in enumerate(options):
            Option.create(
                survey=survey,
                idx=i,
                text=option.text,
                text_emoji_name=option.text_emoji_name,
                button_text=option.button_text,
                button_emoji_name=option.button_emoji_name,
                button_emoji_id=option.button_emoji_id,
                color=option.color
            )
    return survey
def get_option_by_index(survey: Survey, option_idx: int) -> Optional[Option]:
    """Return the option of ``survey`` whose stored ``idx`` is ``option_idx``.

    Args:
        survey: The survey whose options are searched.
        option_idx: The option's position as stored in ``Option.idx``.

    Returns:
        The matching ``Option`` instance, or ``None`` if the survey has no
        option with that index.
    """
    # ``.first()`` executes the query and yields a model instance (or None).
    # Returning the bare SelectQuery handed callers a query object instead
    # of the Option the signature promises, e.g. as a foreign-key value.
    return Option.select().where(
        Option.survey == survey,
        Option.idx == option_idx
    ).first()
def cast_vote(
    voter_id: str,
    survey: Survey,
    option_idx: int
) -> None:
    """Record ``voter_id``'s vote for option ``option_idx`` of ``survey``.

    Current behaviour: every existing vote this voter has in ``survey`` is
    deleted, then a single new vote for ``option_idx`` is created. Each
    voter therefore holds at most one vote per survey, and
    ``survey.vote_limit`` is not consulted.

    Both steps run in one transaction. If creating the new vote fails, the
    delete is rolled back and the voter keeps their previous vote(s).

    TODO: the intended rules, not yet implemented, are:
      * vote limit 1: voting for a different option replaces the existing
        vote; voting for the same option again removes it.
      * vote limit > 1: voting for an option already voted for removes that
        vote; voting for a new option adds a vote only while the voter's
        vote count is below ``vote_limit``.

    Args:
        voter_id: Identifier stored in ``Vote.voter``.
        survey: The survey being voted on.
        option_idx: Index of the chosen option (``Option.idx``).
    """
    with Vote._meta.database.atomic():
        Vote.delete().where(
            Vote.voter == voter_id,
            Vote.survey == survey
        ).execute()
        Vote.create(
            voter=voter_id,
            survey=survey,
            option=get_option_by_index(survey, option_idx),
            option_idx=option_idx
        )
def get_survey_by_message_id(message_id: str) -> Optional[Survey]:
    """Look up the survey whose ``message_id`` equals ``message_id``.

    Args:
        message_id: Value matched against ``Survey.message_id``.

    Returns:
        The first matching ``Survey``, or ``None`` when no survey has that
        message ID (``.first()`` yields ``None`` on an empty result).
    """
    logger.info(f"Retrieving message with id: {message_id}")
    return Survey.select().where(Survey.message_id == message_id).first()
def get_option_counts_for_survey(survey: Survey) -> Tuple[dict[int, int], int]:
    """Tally the votes cast for each option of ``survey``.

    Option is LEFT OUTER joined to Vote and ``COUNT(Vote.id)`` is taken per
    option index. Options that received no votes therefore still appear,
    with a count of 0.

    Args:
        survey: The survey whose options are counted.

    Returns:
        A pair ``(counts, total)``. ``counts`` maps each option index to its
        vote count, and ``total`` is the sum of all those counts.
    """
    per_option_query = (
        Option
        .select(Option.idx, fn.COUNT(Vote.id).alias('count'))
        .join(Vote, JOIN.LEFT_OUTER)
        .where(Option.survey == survey)
        .group_by(Option.idx)
    )
    # Execute once and reuse the rows for both the mapping and the total.
    rows = list(per_option_query)
    counts_by_idx: dict[int, int] = {row.idx: row.count for row in rows}
    return counts_by_idx, sum(row.count for row in rows)
def update_survey_message_info(
        survey: Survey,
        message_id: str,
        message_url: str):
    """Store a new ``message_id`` and ``message_url`` on ``survey``'s row.

    The UPDATE matches the row by ``Survey.id == survey.id``. The in-memory
    ``survey`` object itself is not modified.

    Returns:
        Whatever executing the UPDATE query returns. For peewee this is
        presumably the number of rows modified. The value is also logged.
    """
    update_query = Survey.update(
        message_id=message_id,
        message_url=message_url,
    ).where(Survey.id == survey.id)
    outcome = update_query.execute()
    logger.info(outcome)
    return outcome
def expire_survey(survey: Survey, member_id: Snowflake):
    """Set ``survey``'s expiry into the past and record ``member_id``.

    Writes two columns on the survey's row (matched by ``Survey.id``):
    ``expires`` becomes one second before ``datetime.datetime.now()``, and
    ``locked_by`` becomes ``str(member_id)``.

    NOTE(review): ``now()`` is naive local time; confirm that matches how
    ``expires`` is compared elsewhere.

    Returns:
        Whatever executing the UPDATE query returns. The value is also logged.
    """
    one_second = datetime.timedelta(seconds=1)
    changes = {
        "expires": datetime.datetime.now() - one_second,
        "locked_by": str(member_id),
    }
    outcome = (
        Survey.update(**changes)
        .where(Survey.id == survey.id)
        .execute()
    )
    logger.info(outcome)
    return outcome