-
Notifications
You must be signed in to change notification settings - Fork 0
/
aes.py
125 lines (88 loc) · 3.1 KB
/
aes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from Crypto.Cipher import AES
from binascii import hexlify
def xor_bytes(a, b):
    """Return the bytewise XOR of *a* and *b*.

    The result is only as long as the shorter input (``zip`` truncation);
    CTR mode relies on this to XOR a message with a longer keystream.
    """
    return bytes(left ^ right for left, right in zip(a, b))
class BadPaddingException(Exception):
    """Raised by ``unpad`` when the trailing padding bytes are malformed."""
def aes_encrypt_block(key, block):
    """Encrypt a single 16-byte *block* with raw AES under *key*.

    A fresh ECB-mode cipher object is used, so no chaining state carries
    over between calls; block-mode logic lives in the callers.
    """
    assert len(block) == 16
    cipher = AES.new(key, AES.MODE_ECB)
    ciphertext_block = cipher.encrypt(block)
    return ciphertext_block
def aes_decrypt_block(key, block):
    """Decrypt a single 16-byte *block* with raw AES under *key*.

    Inverse of ``aes_encrypt_block``; a fresh ECB-mode cipher object is
    created per call, so no state is shared between blocks.
    """
    assert len(block) == 16
    cipher = AES.new(key, AES.MODE_ECB)
    plaintext_block = cipher.decrypt(block)
    return plaintext_block
def ecb_encrypt(key, plaintext):
    """Pad *plaintext* with ``pad`` and AES-encrypt each 16-byte block
    independently (ECB mode). Returns the concatenated ciphertext as bytes.
    """
    padded_blocks = get_blocks(pad(plaintext))
    return b"".join(aes_encrypt_block(key, chunk) for chunk in padded_blocks)
def ecb_decrypt(key, ciphertext):
    """AES-decrypt each 16-byte block of *ciphertext* independently (ECB mode),
    then strip the padding with ``unpad``.

    Raises BadPaddingException (via ``unpad``) if the padding is invalid.
    """
    ciphertext_blocks = get_blocks(ciphertext)
    decrypted = b"".join(aes_decrypt_block(key, chunk) for chunk in ciphertext_blocks)
    return unpad(decrypted)
def get_blocks(bytes_, blocksize=16):
    """Split *bytes_* into consecutive slices of *blocksize* bytes.

    The final slice is shorter when the length is not a multiple of
    *blocksize*; an empty input yields an empty list.
    """
    offsets = range(0, len(bytes_), blocksize)
    return [bytes_[start:start + blocksize] for start in offsets]
def cbc_encrypt(key, iv, plaintext):
    """Encrypt *plaintext* with AES in CBC mode.

    The plaintext is padded with ``pad``. Each block is XORed with the
    previous ciphertext block (the 16-byte *iv* for the first block) before
    encryption. Returns the concatenated ciphertext as bytes.
    """
    assert len(iv) == 16
    chain = iv
    encrypted_blocks = []
    for block in get_blocks(pad(plaintext)):
        # The ciphertext of this block becomes the chaining value for the next.
        chain = aes_encrypt_block(key, xor_bytes(chain, block))
        encrypted_blocks.append(chain)
    return b"".join(encrypted_blocks)
def cbc_decrypt(key, iv, ciphertext, *, remove_padding=True):
    """Decrypt AES-CBC *ciphertext* produced with the 16-byte *iv*.

    Each ciphertext block is AES-decrypted and XORed with the previous
    ciphertext block (the *iv* for the first block).

    Args:
        key: AES key.
        iv: 16-byte initialization vector.
        ciphertext: Ciphertext whose length is a multiple of 16 bytes.
        remove_padding: When True (the default), strip the padding with
            ``unpad``. Pass False to get the raw decrypted bytes, including
            any padding.

    Returns:
        The decrypted plaintext as ``bytes``.

    Raises:
        BadPaddingException: If *remove_padding* is True and the padding is
            invalid.
    """
    assert len(iv) == 16
    plaintext = bytearray()
    previous = iv
    for block in get_blocks(ciphertext):
        plaintext.extend(xor_bytes(previous, aes_decrypt_block(key, block)))
        previous = block
    plaintext = bytes(plaintext)
    # The original tested the truthiness of the ``unpad`` function itself,
    # which is always true; use an explicit flag instead.
    return unpad(plaintext) if remove_padding else plaintext
def ctr_encrypt(key, nonce, plaintext):
    """Encrypt (or, identically, decrypt) *plaintext* with AES in CTR mode.

    Each counter block is the 8-byte little-endian *nonce* followed by an
    8-byte little-endian block counter starting at 0. Enough counter blocks
    are encrypted to cover the input. The resulting keystream is XORed with
    *plaintext*, and ``xor_bytes`` truncates it to the plaintext length.

    Args:
        key: AES key.
        nonce: Non-negative integer that fits in 64 bits.
        plaintext: Data of any length; no padding is applied.

    Returns:
        The output as ``bytes``, the same length as *plaintext*.
    """
    nonce_bytes = nonce.to_bytes(8, byteorder='little')
    block_count = -(-len(plaintext) // 16)  # ceiling division
    # Join once instead of repeated ``bytes +=``, which is quadratic.
    keystream = b"".join(
        aes_encrypt_block(key, nonce_bytes + counter.to_bytes(8, byteorder='little'))
        for counter in range(block_count)
    )
    return xor_bytes(plaintext, keystream)
def ctr_decrypt(key, nonce, ciphertext):
    """Decrypt AES-CTR *ciphertext*.

    CTR XORs the data with the same keystream in both directions, so this
    simply delegates to ``ctr_encrypt``.
    """
    recovered = ctr_encrypt(key, nonce, ciphertext)
    return recovered
def pad(bytes_, block_size=16):
    """Append PKCS#7-style padding so the length is a multiple of *block_size*.

    Between 1 and *block_size* bytes are always added, each equal to the
    number of bytes added. An already-aligned input gains a full extra block.
    """
    fill = block_size - (len(bytes_) % block_size)
    return bytes_ + bytes((fill,) * fill)
def unpad(bytes_, block_size=16):
    """Strip PKCS#7-style padding (as produced by ``pad``) from *bytes_*.

    Args:
        bytes_: Padded data.
        block_size: Largest pad length to accept.

    Returns:
        *bytes_* without its trailing padding.

    Raises:
        BadPaddingException: If the input is empty, the pad byte is 0 or
            exceeds *block_size* or the data length, or the trailing bytes
            are not all equal to the pad value.
    """
    # Empty input previously raised IndexError; report it as bad padding.
    if not bytes_:
        raise BadPaddingException()
    pad_val = bytes_[-1]
    # Reject a 0 pad explicitly instead of relying on ``bytes_[-0:]`` being
    # the whole buffer.
    if pad_val == 0 or pad_val > block_size or pad_val > len(bytes_):
        raise BadPaddingException()
    # Named ``padding`` to avoid shadowing the module-level ``pad`` function.
    padding = bytes_[-pad_val:]
    if padding != bytes([pad_val] * pad_val):
        raise BadPaddingException()
    return bytes_[:-pad_val]
if __name__ == "__main__":
    # Sanity check: padding an empty message yields one full block of 0x10 bytes.
    assert pad(bytes()) == bytes([16] * 16)

    key = "YELLOW SUBMARINE".encode('ascii')
    msg = "this is a test please work".encode('ascii')
    iv = b"\x00" * 16
    nonce = 0

    print(key)
    print(iv)

    # Round-trip the message through each mode: label, ciphertext hex, decryption.
    demos = (
        ("ECB mode",
         lambda data: ecb_encrypt(key, data),
         lambda data: ecb_decrypt(key, data)),
        ("CBC mode",
         lambda data: cbc_encrypt(key, iv, data),
         lambda data: cbc_decrypt(key, iv, data)),
        ("CTR mode",
         lambda data: ctr_encrypt(key, nonce, data),
         lambda data: ctr_decrypt(key, nonce, data)),
    )
    for label, encrypt, decrypt in demos:
        print(label)
        encoded = encrypt(msg)
        print(hexlify(encoded))
        print(decrypt(encoded))