diff --git a/Driver.cpp b/Driver.cpp index 803fa62..5a6b48f 100644 --- a/Driver.cpp +++ b/Driver.cpp @@ -268,6 +268,12 @@ OvpnEvtIoDeviceControl(WDFQUEUE queue, WDFREQUEST request, size_t outputBufferLe ExReleaseSpinLockExclusive(&device->SpinLock, kirql); break; + case OVPN_IOCTL_NEW_KEY_V2: + kirql = ExAcquireSpinLockExclusive(&device->SpinLock); + status = OvpnPeerNewKeyV2(device, request); + ExReleaseSpinLockExclusive(&device->SpinLock, kirql); + break; + case OVPN_IOCTL_SWAP_KEYS: kirql = ExAcquireSpinLockExclusive(&device->SpinLock); status = OvpnPeerSwapKeys(device); diff --git a/PropertySheet.props b/PropertySheet.props index f839915..76c4a62 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -3,7 +3,7 @@ 1 - 3 + 4 0 diff --git a/bufferpool.h b/bufferpool.h index fcde323..1ed0721 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -74,7 +74,6 @@ struct OVPN_RX_BUFFER UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; -_Must_inspect_result_ UCHAR* OvpnTxBufferPut(_In_ OVPN_TX_BUFFER* work, SIZE_T len); diff --git a/crypto.cpp b/crypto.cpp index b6a881a..18f837b 100644 --- a/crypto.cpp +++ b/crypto.cpp @@ -48,11 +48,14 @@ OvpnProtoOp32Compose(UINT opcode, UINT keyId, UINT opPeerId) OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptNone; _Use_decl_annotations_ -NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +NTSTATUS OvpnCryptoDecryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); - if (len < NONE_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN cryptoOverhead = OVPN_DATA_V2_LEN + pktId64bit ? 8 : 4; + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } @@ -66,10 +69,11 @@ OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptNone; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptNone(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { UNREFERENCED_PARAMETER(keySlot); UNREFERENCED_PARAMETER(len); + UNREFERENCED_PARAMETER(cryptoOptions); // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, 0, 0); @@ -121,74 +125,116 @@ OvpnCryptoUninitAlgHandles(_In_ BCRYPT_ALG_HANDLE aesAlgHandle, BCRYPT_ALG_HANDL static NTSTATUS -OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoAEADDoWork(BOOLEAN encrypt, OvpnCryptoKeySlot* keySlot, UCHAR *bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { /* AEAD Nonce : [Packet ID] [HMAC keying material] - [4 bytes ] [8 bytes ] + [4/8 bytes] [8/4 bytes ] [AEAD nonce total : 12 bytes ] TLS wire protocol : + Packet ID is 8 bytes long with CRYPTO_OPTIONS_64BIT_PKTID. + [DATA_V2 opcode] [Packet ID] [AEAD Auth tag] [ciphertext] - [4 bytes ] [4 bytes ] [16 bytes ] + [4 bytes ] [4/8 bytes] [16 bytes ] + [AEAD additional data(AD) ] + + With CRYPTO_OPTIONS_AEAD_TAG_END AEAD Auth tag is placed after ciphertext: + + [DATA_V2 opcode] [Packet ID] [ciphertext] [AEAD Auth tag] + [4 bytes ] [4/8 bytes] [16 bytes ] [AEAD additional data(AD) ] */ NTSTATUS status = STATUS_SUCCESS; - if (len < AEAD_CRYPTO_OVERHEAD) { + BOOLEAN pktId64bit = cryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + + auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + + if (len < cryptoOverhead) { LOG_WARN("Packet too short", TraceLoggingValue(len, "len")); return STATUS_DATA_ERROR; } - UCHAR nonce[OVPN_PKTID_LEN + OVPN_NONCE_TAIL_LEN]; + UCHAR nonce[12]; if (encrypt) { // prepend with opcode, key-id and peer-id UINT32 op = OvpnProtoOp32Compose(OVPN_OP_DATA_V2, keySlot->KeyId, keySlot->PeerId); op = RtlUlongByteSwap(op); - *(UINT32*)(bufOut) = op; + *reinterpret_cast(bufOut) = op; - // calculate pktid - UINT32 pktid; - GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid)); - ULONG pktidNetwork = RtlUlongByteSwap(pktid); + if (pktId64bit) + { + // calculate pktid + UINT64 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, true)); + ULONG64 pktidNetwork = RtlUlonglongByteSwap(pktid); + + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 8); + RtlCopyMemory(nonce + 8, keySlot->EncNonceTail, 4); - // calculate nonce, which is pktid + nonce_tail - RtlCopyMemory(nonce, &pktidNetwork, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, keySlot->EncNonceTail, OVPN_NONCE_TAIL_LEN); + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } + else + { + // calculate pktid + UINT32 pktid; + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnPktidXmitNext(&keySlot->PktidXmit, &pktid, false)); + ULONG pktidNetwork = RtlUlongByteSwap(pktid); - // prepend with pktid - *(UINT32*)(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + // calculate nonce, which is pktid + nonce_tail + RtlCopyMemory(nonce, &pktidNetwork, 4); + RtlCopyMemory(nonce + 4, keySlot->EncNonceTail, 8); + + // prepend with pktid + *reinterpret_cast(bufOut + OVPN_DATA_V2_LEN) = pktidNetwork; + } } else { - RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, OVPN_PKTID_LEN); - RtlCopyMemory(nonce + OVPN_PKTID_LEN, &keySlot->DecNonceTail, sizeof(keySlot->DecNonceTail)); + ULONG64 pktId; + + RtlCopyMemory(nonce, bufIn + OVPN_DATA_V2_LEN, pktId64bit ? 8 : 4); + RtlCopyMemory(nonce + (pktId64bit ? 8 : 4), &keySlot->DecNonceTail, pktId64bit ? 4 : 8); + if (pktId64bit) + { + pktId = RtlUlonglongByteSwap(*reinterpret_cast(nonce)); + } + else + { + pktId = static_cast(RtlUlongByteSwap(*reinterpret_cast(nonce))); + } - UINT32 pktId = RtlUlongByteSwap(*(UINT32*)nonce); status = OvpnPktidRecvVerify(&keySlot->PktidRecv, pktId); if (!NT_SUCCESS(status)) { - LOG_ERROR("Invalid pktId", TraceLoggingUInt32(pktId, "pktId")); + LOG_ERROR("Invalid pktId", TraceLoggingUInt64(pktId, "pktId")); return STATUS_DATA_ERROR; } } + // we prepended buf with crypto overhead + len -= cryptoOverhead; + + BOOLEAN aeadTagEnd = cryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO authInfo; BCRYPT_INIT_AUTH_MODE_INFO(authInfo); authInfo.pbNonce = nonce; authInfo.cbNonce = sizeof(nonce); - authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; + authInfo.pbTag = (encrypt ? bufOut : bufIn) + OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? len : 0); authInfo.cbTag = AEAD_AUTH_TAG_LEN; authInfo.pbAuthData = (encrypt ? bufOut : bufIn); - authInfo.cbAuthData = OVPN_DATA_V2_LEN + OVPN_PKTID_LEN; - - bufOut += AEAD_CRYPTO_OVERHEAD; - bufIn += AEAD_CRYPTO_OVERHEAD; + authInfo.cbAuthData = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4); - len -= AEAD_CRYPTO_OVERHEAD; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + bufOut += payloadOffset; + bufIn += payloadOffset; // non-chaining mode ULONG bytesDone = 0; @@ -205,27 +251,29 @@ OVPN_CRYPTO_DECRYPT OvpnCryptoDecryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut) +OvpnCryptoDecryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* bufIn, SIZE_T len, UCHAR* bufOut, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut); + return OvpnCryptoAEADDoWork(FALSE, keySlot, bufIn, len, bufOut, cryptoOptions); } OVPN_CRYPTO_ENCRYPT OvpnCryptoEncryptAEAD; _Use_decl_annotations_ NTSTATUS -OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len) +OvpnCryptoEncryptAEAD(OvpnCryptoKeySlot* keySlot, UCHAR* buf, SIZE_T len, INT32 cryptoOptions) { - return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf); + return OvpnCryptoAEADDoWork(TRUE, keySlot, buf, len, buf, cryptoOptions); } _Use_decl_annotations_ NTSTATUS -OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData) +OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA_V2 cryptoDataV2) { OvpnCryptoKeySlot* keySlot = NULL; NTSTATUS status = STATUS_SUCCESS; + POVPN_CRYPTO_DATA cryptoData = &cryptoDataV2->V1; + if (cryptoData->KeySlot == OVPN_KEY_SLOT::OVPN_KEY_SLOT_PRIMARY) { keySlot = &cryptoContext->Primary; } @@ -237,6 +285,15 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData) return STATUS_INVALID_DEVICE_REQUEST; } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_64BIT_PKTID; + } + if (cryptoDataV2->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END) + { + cryptoContext->CryptoOptions |= CRYPTO_OPTIONS_AEAD_TAG_END; + } + if ((cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM) || (cryptoData->CipherAlg == OVPN_CIPHER_ALG_CHACHA20_POLY1305)) { // destroy previous keys if (keySlot->EncKey) { @@ -284,8 +341,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData) keySlot->KeyId = cryptoData->KeyId; keySlot->PeerId = cryptoData->PeerId; - cryptoContext->CryptoOverhead = AEAD_CRYPTO_OVERHEAD; - LOG_INFO("New key", TraceLoggingValue(cryptoData->CipherAlg == OVPN_CIPHER_ALG_AES_GCM ? "aes-gcm" : "chacha20-poly1305", "alg"), TraceLoggingValue(cryptoData->KeyId, "KeyId"), TraceLoggingValue(cryptoData->KeyId, "PeerId")); } @@ -293,8 +348,6 @@ OvpnCryptoNewKey(OvpnCryptoContext* cryptoContext, POVPN_CRYPTO_DATA cryptoData) cryptoContext->Encrypt = OvpnCryptoEncryptNone; cryptoContext->Decrypt = OvpnCryptoDecryptNone; - cryptoContext->CryptoOverhead = NONE_CRYPTO_OVERHEAD; - LOG_INFO("Using cipher none"); } else { diff --git a/crypto.h b/crypto.h index dfbc688..def22d7 100644 --- a/crypto.h +++ b/crypto.h @@ -28,14 +28,8 @@ #include "uapi\ovpn-dco.h" #include "socket.h" -#define AEAD_CRYPTO_OVERHEAD 24 // 4 + 4 + 16 data_v2 + pktid + auth_tag -#define NONE_CRYPTO_OVERHEAD 8 // 4 + 4 data_v2 + pktid -#define OVPN_PKTID_LEN 4 -#define OVPN_NONCE_TAIL_LEN 8 #define OVPN_DATA_V2_LEN 4 #define AEAD_AUTH_TAG_LEN 16 -#define AES_BLOCK_SIZE 16 -#define AES_GCM_NONCE_LEN 12 // packet opcode (high 5 bits) and key-id (low 3 bits) are combined in one byte #define OVPN_OP_DATA_V2 9 @@ -62,7 +56,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len); +OVPN_CRYPTO_ENCRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* buf, _In_ SIZE_T len, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_ENCRYPT* POVPN_CRYPTO_ENCRYPT; _Function_class_(OVPN_CRYPTO_DECRYPT) @@ -70,7 +64,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ typedef NTSTATUS -OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut); +OVPN_CRYPTO_DECRYPT(_In_ OvpnCryptoKeySlot* keySlot, _In_ UCHAR* bufIn, _In_ SIZE_T len, _In_ UCHAR* bufOut, _In_ INT32 CryptoOptions); typedef OVPN_CRYPTO_DECRYPT* POVPN_CRYPTO_DECRYPT; struct OvpnCryptoContext @@ -84,7 +78,7 @@ struct OvpnCryptoContext POVPN_CRYPTO_ENCRYPT Encrypt; POVPN_CRYPTO_DECRYPT Decrypt; - SIZE_T CryptoOverhead; + INT32 CryptoOptions; }; _Must_inspect_result_ @@ -101,7 +95,7 @@ OvpnCryptoUninit(_In_ OvpnCryptoContext* cryptoContext); _Must_inspect_result_ NTSTATUS -OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA cryptoData); +OvpnCryptoNewKey(_In_ OvpnCryptoContext* cryptoContext, _In_ POVPN_CRYPTO_DATA_V2 cryptoData); _Must_inspect_result_ OvpnCryptoKeySlot* diff --git a/peer.cpp b/peer.cpp index e84612c..d8e0748 100644 --- a/peer.cpp +++ b/peer.cpp @@ -262,9 +262,35 @@ OvpnPeerNewKey(POVPN_DEVICE device, WDFREQUEST request) } POVPN_CRYPTO_DATA cryptoData = NULL; + OVPN_CRYPTO_DATA_V2 cryptoDataV2{}; NTSTATUS status; GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA), (PVOID*)&cryptoData, nullptr)); + + RtlCopyMemory(&cryptoDataV2.V1, cryptoData, sizeof(OVPN_CRYPTO_DATA)); + GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&device->CryptoContext, &cryptoDataV2)); + +done: + LOG_EXIT(); + + return status; +} + +_Use_decl_annotations_ +NTSTATUS +OvpnPeerNewKeyV2(POVPN_DEVICE device, WDFREQUEST request) +{ + LOG_ENTER(); + + if (InterlockedCompareExchange(&device->UserspacePid, 0, 0) == 0) { + LOG_ERROR("Peer not added"); + return STATUS_INVALID_DEVICE_REQUEST; + } + + POVPN_CRYPTO_DATA_V2 cryptoData = NULL; + NTSTATUS status; + + GOTO_IF_NOT_NT_SUCCESS(done, status, WdfRequestRetrieveInputBuffer(request, sizeof(OVPN_CRYPTO_DATA_V2), (PVOID*)&cryptoData, nullptr)); GOTO_IF_NOT_NT_SUCCESS(done, status, OvpnCryptoNewKey(&device->CryptoContext, cryptoData)); done: diff --git a/peer.h b/peer.h index 4eecca6..8906b0c 100644 --- a/peer.h +++ b/peer.h @@ -57,6 +57,11 @@ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS OvpnPeerNewKey(_In_ POVPN_DEVICE device, WDFREQUEST request); +_Must_inspect_result_ +_Requires_exclusive_lock_held_(device->SpinLock) +NTSTATUS +OvpnPeerNewKeyV2(_In_ POVPN_DEVICE device, WDFREQUEST request); + _Must_inspect_result_ _Requires_exclusive_lock_held_(device->SpinLock) NTSTATUS diff --git a/pktid.cpp b/pktid.cpp index b7f365d..4d94ee3 100644 --- a/pktid.cpp +++ b/pktid.cpp @@ -28,24 +28,29 @@ #define PKTID_WRAP_WARN 0xf0000000ULL _Use_decl_annotations_ -NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, UINT32* pktId) +NTSTATUS OvpnPktidXmitNext(OvpnPktidXmit* px, VOID* pktId, BOOLEAN pktId64bit) { ULONG64 seqNum = InterlockedIncrementNoFence64(&px->SeqNum); - *pktId = (UINT32)seqNum; - if (seqNum < PKTID_WRAP_WARN) { - return STATUS_SUCCESS; - } - else { - LOG_ERROR("Pktid wrapped"); - return STATUS_INTEGER_OVERFLOW; - } + if (pktId64bit) { + *static_cast(pktId) = seqNum; + } + else + { + *static_cast(pktId) = static_cast(seqNum); + if (seqNum >= PKTID_WRAP_WARN) { + LOG_ERROR("Pktid wrapped"); + return STATUS_INTEGER_OVERFLOW; + } + } + + return STATUS_SUCCESS; } #define PKTID_RECV_EXPIRE ((30 * WDF_TIMEOUT_TO_SEC) / KeQueryTimeIncrement()) _Use_decl_annotations_ -NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) +NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT64 pktId) { LARGE_INTEGER now; KeQueryTickCount(&now); @@ -69,16 +74,16 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else if (pktId > pr->Id) { /* ID jumped forward by more than one */ - UINT32 delta = pktId - pr->Id; + const auto delta = pktId - pr->Id; if (delta < REPLAY_WINDOW_SIZE) { pr->Base = REPLAY_INDEX(pr->Base, -(INT32)delta); pr->History[pr->Base / 8] |= (1 << (pr->Base % 8)); - pr->Extent += delta; + pr->Extent += static_cast(delta); if (pr->Extent > REPLAY_WINDOW_SIZE) pr->Extent = REPLAY_WINDOW_SIZE; - for (UINT32 i = 1; i < delta; ++i) { - unsigned int newb = REPLAY_INDEX(pr->Base, i); + for (auto i = 1; i < delta; ++i) { + const auto newb = REPLAY_INDEX(pr->Base, i); pr->History[newb / 8] &= ~BIT(newb % 8); } @@ -93,10 +98,8 @@ NTSTATUS OvpnPktidRecvVerify(OvpnPktidRecv* pr, UINT32 pktId) } else { /* ID backtrack */ - UINT32 delta = pr->Id - pktId; + const auto delta = pr->Id - pktId; - if (delta > pr->MaxBacktrack) - pr->MaxBacktrack = delta; if (delta < pr->Extent) { if (pktId > pr->IdFloor) { UINT32 ri = REPLAY_INDEX(pr->Base, delta); diff --git a/pktid.h b/pktid.h index b0d2325..dcc4be8 100644 --- a/pktid.h +++ b/pktid.h @@ -50,17 +50,17 @@ struct OvpnPktidRecv LARGE_INTEGER Expire; /* highest sequence number received */ - UINT32 Id; + UINT64 Id; /* we will only accept backtrack IDs > id_floor */ - UINT32 IdFloor; - UINT32 MaxBacktrack; + UINT64 IdFloor; }; /* Get the next packet ID for xmit */ -NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ UINT32* pktId); +NTSTATUS OvpnPktidXmitNext(_In_ OvpnPktidXmit* px, _Out_ VOID* pktId, BOOLEAN pktId64bit); + /* Packet replay detection. * Allows ID backtrack of up to REPLAY_WINDOW_SIZE - 1. */ -NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT32 pktId); +NTSTATUS OvpnPktidRecvVerify(_In_ OvpnPktidRecv* pid, UINT64 pktId); diff --git a/rxqueue.cpp b/rxqueue.cpp index bc0f8c0..794fa89 100644 --- a/rxqueue.cpp +++ b/rxqueue.cpp @@ -101,6 +101,10 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_RXQUEUE queue = OvpnGetRxQueueContext(netPacketQueue); OVPN_DEVICE* device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); + BOOLEAN pktId64bit = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + NET_RING_FRAGMENT_ITERATOR fi = NetRingGetAllFragments(queue->Rings); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); while (NetFragmentIteratorHasAny(&fi)) { @@ -115,7 +119,7 @@ OvpnEvtRxQueueAdvance(NETPACKETQUEUE netPacketQueue) fragment->ValidLength = buffer->Len; fragment->Offset = 0; NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress(&queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); - RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + device->CryptoContext.CryptoOverhead, buffer->Len); + RtlCopyMemory(virtualAddr->VirtualAddress, buffer->Data + payloadOffset, buffer->Len); InterlockedExchangeAddNoFence64(&device->Stats.TunBytesReceived, buffer->Len); diff --git a/socket.cpp b/socket.cpp index 765152f..a41c624 100644 --- a/socket.cpp +++ b/socket.cpp @@ -169,9 +169,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ return; } - if (device->CryptoContext.Decrypt) { + OvpnCryptoContext* cryptoContext = &device->CryptoContext; + + if (cryptoContext->Decrypt) { UCHAR keyId = OvpnCryptoKeyIdExtract(op); - OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(&device->CryptoContext, keyId); + OvpnCryptoKeySlot* keySlot = OvpnCryptoKeySlotFromKeyId(cryptoContext, keyId); if (!keySlot) { status = STATUS_INVALID_DEVICE_STATE; @@ -179,8 +181,11 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ } else { // decrypt into plaintext buffer - status = device->CryptoContext.Decrypt(keySlot, cipherTextBuf, len, buffer->Data); - buffer->Len = len - device->CryptoContext.CryptoOverhead; + status = cryptoContext->Decrypt(keySlot, cipherTextBuf, len, buffer->Data, cryptoContext->CryptoOptions); + + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + auto cryptoOverhead = OVPN_DATA_V2_LEN + AEAD_AUTH_TAG_LEN + (pktId64bit ? 8 : 4); + buffer->Len = len - cryptoOverhead; } } else { @@ -197,20 +202,23 @@ VOID OvpnSocketDataPacketReceived(_In_ POVPN_DEVICE device, UCHAR op, _In_reads_ OvpnTimerResetRecv(device->Timer); // points to the beginning of plaintext - UCHAR* buf = buffer->Data + device->CryptoContext.CryptoOverhead; + BOOLEAN pktId64bit = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = device->CryptoContext.CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto payloadOffset = OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN); + UCHAR* plaintext = buffer->Data + payloadOffset; // ping packet? - if (OvpnTimerIsKeepaliveMessage(buf, buffer->Len)) { + if (OvpnTimerIsKeepaliveMessage(plaintext, buffer->Len)) { LOG_INFO("Ping received"); // no need to inject ping packet into OS, return buffer to the pool OvpnRxBufferPoolPut(buffer); } else { - if (OvpnMssIsIPv4(buf, buffer->Len)) { - OvpnMssDoIPv4(buf, buffer->Len, device->MSS); - } else if (OvpnMssIsIPv6(buf, buffer->Len)) { - OvpnMssDoIPv6(buf, buffer->Len, device->MSS); + if (OvpnMssIsIPv4(plaintext, buffer->Len)) { + OvpnMssDoIPv4(plaintext, buffer->Len, device->MSS); + } else if (OvpnMssIsIPv6(plaintext, buffer->Len)) { + OvpnMssDoIPv6(plaintext, buffer->Len, device->MSS); } // enqueue plaintext buffer, it will be dequeued by NetAdapter RX datapath diff --git a/timer.cpp b/timer.cpp index 215659a..2f5020a 100644 --- a/timer.cpp +++ b/timer.cpp @@ -63,12 +63,19 @@ static VOID OvpnTimerXmit(WDFTIMER timer) RtlCopyMemory(OvpnTxBufferPut(buffer, sizeof(OvpnKeepaliveMessage)), OvpnKeepaliveMessage, sizeof(OvpnKeepaliveMessage)); KIRQL kiqrl = ExAcquireSpinLockShared(&device->SpinLock); - if (device->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &device->CryptoContext; + if (cryptoContext->Encrypt) { // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead); + BOOLEAN pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + BOOLEAN aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); + if (aeadTagEnd) { + OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + } // in-place encrypt, always with primary key - status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; diff --git a/txqueue.cpp b/txqueue.cpp index d3887dd..d097cd7 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -83,12 +83,21 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET InterlockedExchangeAddNoFence64(&device->Stats.TunBytesSent, buffer->Len); - if (device->CryptoContext.Encrypt) { + OvpnCryptoContext* cryptoContext = &device->CryptoContext; + + if (cryptoContext->Encrypt) { + auto aeadTagEnd = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_AEAD_TAG_END; + auto pktId64bit = cryptoContext->CryptoOptions & CRYPTO_OPTIONS_64BIT_PKTID; + // make space to crypto overhead - OvpnTxBufferPush(buffer, device->CryptoContext.CryptoOverhead); + OvpnTxBufferPush(buffer, OVPN_DATA_V2_LEN + (pktId64bit ? 8 : 4) + (aeadTagEnd ? 0 : AEAD_AUTH_TAG_LEN)); + if (aeadTagEnd) + { + OvpnTxBufferPut(buffer, AEAD_AUTH_TAG_LEN); + } // in-place encrypt, always with primary key - status = device->CryptoContext.Encrypt(&device->CryptoContext.Primary, buffer->Data, buffer->Len); + status = cryptoContext->Encrypt(&cryptoContext->Primary, buffer->Data, buffer->Len, cryptoContext->CryptoOptions); } else { status = STATUS_INVALID_DEVICE_STATE; @@ -140,7 +149,7 @@ OvpnEvtTxQueueAdvance(NETPACKETQUEUE netPacketQueue) POVPN_TXQUEUE queue = OvpnGetTxQueueContext(netPacketQueue); NET_RING_PACKET_ITERATOR pi = NetRingGetAllPackets(queue->Rings); POVPN_DEVICE device = OvpnGetDeviceContext(queue->Adapter->WdfDevice); - bool packetSent = false; + BOOLEAN packetSent = false; KIRQL kirql = ExAcquireSpinLockShared(&device->SpinLock); diff --git a/uapi/ovpn-dco.h b/uapi/ovpn-dco.h index ea2a733..9e437f0 100644 --- a/uapi/ovpn-dco.h +++ b/uapi/ovpn-dco.h @@ -94,6 +94,14 @@ typedef struct _OVPN_CRYPTO_DATA { int PeerId; } OVPN_CRYPTO_DATA, * POVPN_CRYPTO_DATA; +#define CRYPTO_OPTIONS_AEAD_TAG_END (1<<1) +#define CRYPTO_OPTIONS_64BIT_PKTID (1<<2) + +typedef struct _OVPN_CRYPTO_DATA_V2 { + OVPN_CRYPTO_DATA V1; + UINT32 CryptoOptions; +} OVPN_CRYPTO_DATA_V2, * POVPN_CRYPTO_DATA_V2; + typedef struct _OVPN_SET_PEER { LONG KeepaliveInterval; LONG KeepaliveTimeout; @@ -114,3 +122,4 @@ typedef struct _OVPN_VERSION { #define OVPN_IOCTL_START_VPN CTL_CODE(FILE_DEVICE_UNKNOWN, 6, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_DEL_PEER CTL_CODE(FILE_DEVICE_UNKNOWN, 7, METHOD_BUFFERED, FILE_ANY_ACCESS) #define OVPN_IOCTL_GET_VERSION CTL_CODE(FILE_DEVICE_UNKNOWN, 8, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define OVPN_IOCTL_NEW_KEY_V2 CTL_CODE(FILE_DEVICE_UNKNOWN, 9, METHOD_BUFFERED, FILE_ANY_ACCESS)