diff --git a/Driver.cpp b/Driver.cpp index 4c3ceec..4b4f0bc 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -269,6 +269,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_NEW_KEY_V2: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnPeerNewKeyV2(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + case OVPN_IOCTL_SWAP_KEYS: kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSwapKeys(device); diff --git a/Driver.h b/Driver.h index 01df6c9..bf221ce 100644 --- a/Driver.h +++ b/Driver.h @@ -94,7 +94,6 @@ struct OVPN_DEVICE { _Guarded_by_(SpinLock) RTL_GENERIC_TABLE Peers; - SIZE_T CryptoOverhead; }; typedef OVPN_DEVICE * POVPN_DEVICE; diff --git a/PropertySheet.props b/PropertySheet.props index 61bdfe9..6056166 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,7 +3,7 @@ 2 - 3 + 4 0 diff --git a/bufferpool.h b/bufferpool.h index fcde323..1ed0721 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -74,7 +74,6 @@ struct OVPN_RX_BUFFER UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; -_Must_inspect_result_ UCHAR* OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/crypto.cpp b/crypto.cpp index 75005f4..0f915fa 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -48,11 +48,14 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId) OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone; _Use_decl_annotations_ -NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); - if (len < NONE_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4; + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } @@ -66,10 +69,11 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); UNREFERENCED_PARAMETER(len); + UNREFERENCED_PARAMETER(cryptoOptions); // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0); @@ -121,74 +125,116 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL static NTSTATUS -OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { /* AEAD Nonce : [Packet ID] [HMAC keying material] - [4 bytes ] [8 bytes ] + [4/8 bytes] [8/4 bytes ] [AEAD nonce total : 12 bytes ] TLS wire protocol : + Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID. + [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext] - [4 bytes ] [4 bytes ] [16 bytes ] + [4 bytes ] [4/8 bytes] [16 bytes ] + [AEAD additional data(AD) ] + + With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext: + + [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag] + [4 bytes ] [4/8 bytes] [16 bytes ] [AEAD additional data(AD) ] */ NTSTATUS status = STATUS_SUCCESS; - if (len < AEAD_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + + SIZE_T cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } - UCHAR nonce[OVPN_PKTID_LEN + OVPN_NONCE_TAIL_LEN]; + UCHAR nonce[12]; if (encrypt) { // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId); op = RtlUlongByteSwap(op); - *(UINT32*)(bufOut) = op; + *reinterpret_cast(bufOut) = op; - // calculate pktid - UINT32 pktid; - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid)); - ULONG pktidNetwork = RtlUlongByteSwap(pktid); + if (pktId64bit) + { + // calculate pktid + UINT64 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true)); + ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid); + + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 8); + RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4); + + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } + else + { + // calculate pktid + UINT32 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false)); + ULONG pktidNetwork = RtlUlongByteSwap(pktid); - // calculate nonce, which is pktid + nonce_tail - RtlCopyMemory(nonce, &pktidNetwork, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, keySlot->EncNonceTail, OVPN_NONCE_TAIL_LEN); + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 4); + RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8); - // prepend with pktid - *(UINT32*)(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } } else { - RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, &keySlot->DecNonceTail, sizeof(keySlot->DecNonceTail)); + ULONG64 pktId; + + RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4); + RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8); + if (pktId64bit) + { + pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce)); + } + else + { + pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce))); + } - UINT32 pktId = RtlUlongByteSwap(*(UINT32*)nonce); status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId); if (!NT_SUCCESS(status)) { - LOG_ERROR("Invalid pktId", TraceLoggingUInt32(pktId, "pktId")); + LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId")); return STATUS_DATA_ERROR; } } + // we prepended buf with crypto overhead + len -= cryptoOverhead; + + BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo; BCRYPT_INIT_AUTH_MODE_INFO(authInfo); authInfo.pbNonce = nonce; authInfo.cbNonce = sizeof(nonce); - authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; + authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0); authInfo.cbTag = AEAD_AUTH_TAG_LEN; authInfo.pbAuthData = (encrypt ? bufOut : bufIn); - authInfo.cbAuthData = OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; - - bufOut += AEAD_CRYPTO_OVERHEAD; - bufIn += AEAD_CRYPTO_OVERHEAD; + authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4); - len -= AEAD_CRYPTO_OVERHEAD; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + bufOut += payloadOffset; + bufIn += payloadOffset; // non-chaining mode ULONG bytesDone = 0; @@ -205,27 +251,29 @@ OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut); + return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions); } OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf); + return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions); } _Use_decl_annotations_ NTSTATUS -OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, BCRYPT_ALG_HANDLE algHandle) +OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2, BCRYPT_ALG_HANDLE algHandle) { OvpnCryptoKeySlot* keySlot = NULL; NTSTATUS status = STATUS_SUCCESS; + POVPN_CRYPTO_DATA cryptoData = &cryptoDataV2->V1; + if (cryptoData->KeySlot == OVPN_KEY_SLOT::OVPN_KEY_SLOT_PRIMARY) { keySlot = &cryptoContext->Primary; } @@ -237,6 +285,15 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData, return STATUS_INVALID_DEVICE_REQUEST; } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID; + } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END; + } + if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) { // destroy previous keys if (keySlot->EncKey) { diff --git a/crypto.h b/crypto.h index ee35cba..3782d42 100644 --- a/crypto.h +++ b/crypto.h @@ -29,14 +29,8 @@ #include "uapi\ovpn-dco.h" #include "socket.h" -#define AEAD_CRYPTO_OVERHEAD 24 // 4 + 4 + 16 data_v2 + pktid + auth_tag -#define NONE_CRYPTO_OVERHEAD 8 // 4 + 4 data_v2 + pktid -#define OVPN_PKTID_LEN 4 -#define OVPN_NONCE_TAIL_LEN 8 #define OVPN_DATA_V2_LEN 4 #define AEAD_AUTH_TAG_LEN 16 -#define AES_BLOCK_SIZE 16 -#define AES_GCM_NONCE_LEN 12 // packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte #define OVPN_OP_DATA_V2 9 @@ -63,7 +57,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len); +OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT; _Function_class_(OVPN_CRYPTO_DECRYPT) @@ -71,7 +65,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut); +OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT; struct OvpnCryptoContext @@ -82,7 +76,7 @@ struct OvpnCryptoContext POVPN_CRYPTO_ENCRYPT Encrypt; POVPN_CRYPTO_DECRYPT Decrypt; - SIZE_T CryptoOverhead; + INT32 CryptoOptions; }; _Must_inspect_result_ @@ -99,7 +93,7 @@ OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext); _Must_inspect_result_ NTSTATUS -OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle); +OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData, _In_opt_ BCRYPT_ALG_HANDLE algHandle); _Must_inspect_result_ OvpnCryptoKeySlot* @@ -119,4 +113,4 @@ static inline UCHAR OvpnCryptoOpcodeExtract(UCHAR op) { return op >> OVPN_OPCODE_SHIFT; -} \ No newline at end of file +} diff --git a/peer.cpp b/peer.cpp index 3314fee..a2fa688 100644 --- a/peer.cpp +++ b/peer.cpp @@ -303,6 +303,31 @@ OvpnPeerStartVPN(POVPN_DEVICE device) return status; } +static NTSTATUS +OvpnPeerGetAlgHandle(POVPN_DEVICE device, OVPN_CIPHER_ALG cipherAlg, BCRYPT_ALG_HANDLE& algHandle) +{ + NTSTATUS status = STATUS_SUCCESS; + + switch (cipherAlg) { + case OVPN_CIPHER_ALG_AES_GCM: + algHandle = device->AesAlgHandle; + break; + + case OVPN_CIPHER_ALG_CHACHA20_POLY1305: + algHandle = device->ChachaAlgHandle; + if (algHandle == NULL) { + LOG_ERROR("CHACHA20-POLY1305 is not available"); + status = STATUS_INVALID_DEVICE_REQUEST; + } + break; + + default: + break; + } + + return status; +} + _Use_decl_annotations_ NTSTATUS OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) @@ -311,44 +336,63 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) NTSTATUS status = STATUS_SUCCESS; + POVPN_CRYPTO_DATA cryptoData = NULL; + OVPN_CRYPTO_DATA_V2 cryptoDataV2{}; + if (!OvpnHasPeers(device)) { LOG_ERROR("Peer not added"); status = STATUS_INVALID_DEVICE_REQUEST; goto done; } - POVPN_CRYPTO_DATA cryptoData = NULL; - GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); BCRYPT_ALG_HANDLE algHandle = NULL; - switch (cryptoData->CipherAlg) { - case OVPN_CIPHER_ALG_AES_GCM: - algHandle = device->AesAlgHandle; - device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; - break; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoData->CipherAlg, algHandle)); - case OVPN_CIPHER_ALG_CHACHA20_POLY1305: - algHandle = device->ChachaAlgHandle; - if (algHandle == NULL) { - LOG_ERROR("CHACHA20-POLY1305 is not available"); - status = STATUS_INVALID_DEVICE_REQUEST; - goto done; - } - device->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + status = STATUS_OBJECTID_NOT_FOUND; + goto done; + } - default: - device->CryptoOverhead = NONE_CRYPTO_OVERHEAD; - break; + RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, &cryptoDataV2, algHandle)); + +done: + LOG_EXIT(); + + return status; +} + +_Use_decl_annotations_ +NTSTATUS +OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + NTSTATUS status = STATUS_SUCCESS; + + POVPN_CRYPTO_DATA_V2 cryptoDataV2 = NULL; + + if (!OvpnHasPeers(device)) { + LOG_ERROR("Peer not added"); + status = STATUS_INVALID_DEVICE_REQUEST; + goto done; } + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoDataV2, nullptr)); + + BCRYPT_ALG_HANDLE algHandle = NULL; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPeerGetAlgHandle(device, cryptoDataV2->V1.CipherAlg, algHandle)); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); if (peer == NULL) { status = STATUS_OBJECTID_NOT_FOUND; goto done; } - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoData, algHandle)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&peer->CryptoContext, cryptoDataV2, algHandle)); done: LOG_EXIT(); diff --git a/peer.h b/peer.h index e1eb5d4..fc780f8 100644 --- a/peer.h +++ b/peer.h @@ -85,6 +85,11 @@ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerNewKey(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_Requires_exclusive_lock_held_(device->SpinLock) +NTSTATUS +OvpnPeerNewKeyV2(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS diff --git a/pktid.cpp b/pktid.cpp index b7f365d..4d94ee3 100644 --- a/pktid.cpp +++ b/pktid.cpp @@ -28,24 +28,29 @@ #define PKTID_WRAP_WARN 0xf0000000ULL _Use_decl_annotations_ -NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, UINT32* pktId) +NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit) { ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum); - *pktId = (UINT32)seqNum; - if (seqNum < PKTID_WRAP_WARN) { - return STATUS_SUCCESS; - } - else { - LOG_ERROR("Pktid wrapped"); - return STATUS_INTEGER_OVERFLOW; - } + if (pktId64bit) { + *static_cast(pktId) = seqNum; + } + else + { + *static_cast(pktId) = static_cast(seqNum); + if (seqNum >= PKTID_WRAP_WARN) { + LOG_ERROR("Pktid wrapped"); + return STATUS_INTEGER_OVERFLOW; + } + } + + return STATUS_SUCCESS; } #define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement()) _Use_decl_annotations_ -NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) +NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT64 pktId) { LARGE_INTEGER now; KeQueryTickCount(&now); @@ -69,16 +74,16 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else if (pktId > pr->Id) { /* ID jumped forward by more than one */ - UINT32 delta = pktId - pr->Id; + const auto delta = pktId - pr->Id; if (delta < REPLAY_WINDOW_SIZE) { pr->Base = REPLAY_INDEX(pr->Base, -(INT32)delta); pr->History[pr->Base / 8] |= (1 << (pr->Base % 8)); - pr->Extent += delta; + pr->Extent += static_cast(delta); if (pr->Extent > REPLAY_WINDOW_SIZE) pr->Extent = REPLAY_WINDOW_SIZE; - for (UINT32 i = 1; i < delta; ++i) { - unsigned int newb = REPLAY_INDEX(pr->Base, i); + for (auto i = 1; i < delta; ++i) { + const auto newb = REPLAY_INDEX(pr->Base, i); pr->History[newb / 8] &= ~BIT(newb % 8); } @@ -93,10 +98,8 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else { /* ID backtrack */ - UINT32 delta = pr->Id - pktId; + const auto delta = pr->Id - pktId; - if (delta > pr->MaxBacktrack) - pr->MaxBacktrack = delta; if (delta < pr->Extent) { if (pktId > pr->IdFloor) { UINT32 ri = REPLAY_INDEX(pr->Base, delta); diff --git a/pktid.h b/pktid.h index b0d2325..dcc4be8 100644 --- a/pktid.h +++ b/pktid.h @@ -50,17 +50,17 @@ struct OvpnPktidRecv LARGE_INTEGER Expire; /* highest sequence number received */ - UINT32 Id; + UINT64 Id; /* we will only accept backtrack IDs > id_floor */ - UINT32 IdFloor; - UINT32 MaxBacktrack; + UINT64 IdFloor; }; /* Get the next packet ID for xmit */ -NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ UINT32* pktId); +NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit); + /* Packet replay detection. * Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1. */ -NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT32 pktId); +NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT64 pktId); diff --git a/rxqueue.cpp b/rxqueue.cpp index ce7d71b..d7687c3 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -27,6 +27,7 @@ #include "driver.h" #include "bufferpool.h" +#include "peer.h" #include "rxqueue.h" #include "netringiterator.h" #include "trace.h" @@ -101,6 +102,16 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); + OvpnPeerContext* peer = OvpnGetFirstPeer(&device->Peers); + if (peer == NULL) { + LOG_WARN("No peer"); + return; + } + + BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -115,7 +126,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoOverhead, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index 1b3ca42..2fcf2d7 100644 --- a/socket.cpp +++ b/socket.cpp @@ -178,9 +178,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - if (peer->CryptoContext.Decrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + + if (cryptoContext->Decrypt) { UCHAR keyId = OvpnCryptoKeyIdExtract(op); - OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&peer->CryptoContext, keyId); + OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId); if (!keySlot) { status = STATUS_INVALID_DEVICE_STATE; @@ -188,8 +190,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } else { // decrypt into plaintext buffer - status = peer->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data); - buffer->Len = len - device->CryptoOverhead; + status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + buffer->Len = len - cryptoOverhead; } } else { @@ -206,20 +211,23 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(peer->Timer); // points to the beginning of plaintext - UCHAR* buf = buffer->Data + device->CryptoOverhead; + BOOLEAN pktId64bit = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = peer->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + UCHAR* plaintext = buffer->Data + payloadOffset; // ping packet? - if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(buf, buffer->Len)) { - OvpnMssDoIPv4(buf, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(buf, buffer->Len)) { - OvpnMssDoIPv6(buf, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(plaintext, buffer->Len)) { + OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { + OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 6ce4c95..95eeaa1 100644 --- a/timer.cpp +++ b/timer.cpp @@ -73,12 +73,16 @@ static VOID OvpnTimerXmit(WDFTIMER timer) OvpnPeerContext* peer = timerCtx->Peer; KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - if (peer->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + if (cryptoContext->Encrypt) { // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoOverhead); + BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); // in-place encrypt, always with primary key - status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; diff --git a/txqueue.cpp b/txqueue.cpp index 7f49038..feec4c5 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -93,12 +93,21 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); - if (peer->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &peer->CryptoContext; + + if (cryptoContext->Encrypt) { + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoOverhead); + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); + if (aeadTagEnd) + { + OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + } // in-place encrypt, always with primary key - status = peer->CryptoContext.Encrypt(&peer->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; @@ -152,7 +161,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_TXQUEUE queue = OvpnGetTxQueueContext(netPacketQueue); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - bool packetSent = false; + BOOLEAN packetSent = false; KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index ea2a733..9e437f0 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -94,6 +94,14 @@ typedef struct _OVPN_CRYPTO_DATA { int PeerId; } OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA; +#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1) +#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2) + +typedef struct _OVPN_CRYPTO_DATA_V2 { + OVPN_CRYPTO_DATA V1; + UINT32 CryptoOptions; +} OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2; + typedef struct _OVPN_SET_PEER { LONG KeepaliveInterval; LONG KeepaliveTimeout; @@ -114,3 +122,4 @@ typedef struct _OVPN_VERSION { #define OVPN_IOCTL_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 6, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS)