diff --git a/PropertySheet.props b/PropertySheet.props index d0c1521..a1bb21c 100644 --- a/PropertySheet.props +++ b/PropertySheet.props @@ -4,7 +4,7 @@ 0 9 - 1 + 2 diff --git a/adapter.cpp b/adapter.cpp index e8ece3b..59ff874 100644 --- a/adapter.cpp +++ b/adapter.cpp @@ -87,7 +87,7 @@ OvpnAdapterSetLinkLayerCapabilities(_In_ POVPN_ADAPTER adapter) maxRcvLinkSpeed); NetAdapterSetLinkLayerCapabilities(adapter->NetAdapter, &linkLayerCapabilities); - NetAdapterSetLinkLayerMtuSize(adapter->NetAdapter, 0xFFFF); + NetAdapterSetLinkLayerMtuSize(adapter->NetAdapter, OVPN_DCO_MTU_MAX); } _Use_decl_annotations_ diff --git a/adapter.h b/adapter.h index adef4c8..dfaa0fc 100644 --- a/adapter.h +++ b/adapter.h @@ -25,6 +25,8 @@ #include #include +#define OVPN_DCO_MTU_MAX 1500 + // Context for NETADAPTER struct OVPN_ADAPTER { diff --git a/bufferpool.cpp b/bufferpool.cpp index 5b85661..e8b4357 100644 --- a/bufferpool.cpp +++ b/bufferpool.cpp @@ -23,10 +23,11 @@ #include +#include "adapter.h" #include "bufferpool.h" #include "trace.h" -#define OVPN_BUFFER_HEADROOM 256 +#define OVPN_BUFFER_HEADROOM 26 // we prepend TCP packet size (2 bytes) and crypto overhead (24 bytes) struct OVPN_BUFFER_POOL_IMPL { @@ -124,7 +125,7 @@ _Use_decl_annotations_ NTSTATUS OvpnTxBufferPoolCreate(OVPN_TX_BUFFER_POOL* handle, VOID* ctx) { - return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_TX_BUFFER) + OVPN_SOCKET_PACKET_BUFFER_SIZE, "tx", ctx); + return OvpnBufferPoolCreate((OVPN_BUFFER_POOL*)handle, sizeof(OVPN_TX_BUFFER) + OVPN_DCO_MTU_MAX + OVPN_BUFFER_HEADROOM, "tx", ctx); } VOID* @@ -160,7 +161,7 @@ OvpnTxBufferPoolGet(OVPN_TX_BUFFER_POOL handle, OVPN_TX_BUFFER** buffer) if (*buffer == NULL) return STATUS_INSUFFICIENT_RESOURCES; - (*buffer)->Mdl = IoAllocateMdl(*buffer, sizeof(OVPN_TX_BUFFER) + OVPN_SOCKET_PACKET_BUFFER_SIZE, FALSE, FALSE, NULL); + (*buffer)->Mdl = IoAllocateMdl(*buffer, ((OVPN_BUFFER_POOL_IMPL*)handle)->ItemSize, FALSE, FALSE, NULL); MmBuildMdlForNonPagedPool((*buffer)->Mdl); (*buffer)->Pool = handle; diff --git a/bufferpool.h b/bufferpool.h index 14b749e..fcde323 100644 --- a/bufferpool.h +++ b/bufferpool.h @@ -26,7 +26,7 @@ #include #include -#define OVPN_SOCKET_PACKET_BUFFER_SIZE 2048 +#define OVPN_SOCKET_RX_PACKET_BUFFER_SIZE 2048 DECLARE_HANDLE(OVPN_BUFFER_POOL); DECLARE_HANDLE(OVPN_TX_BUFFER_POOL); @@ -71,7 +71,7 @@ struct OVPN_RX_BUFFER OVPN_RX_BUFFER_POOL Pool; - UCHAR Data[OVPN_SOCKET_PACKET_BUFFER_SIZE]; + UCHAR Data[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; _Must_inspect_result_ diff --git a/socket.cpp b/socket.cpp index 3a58533..3b342f0 100644 --- a/socket.cpp +++ b/socket.cpp @@ -274,9 +274,9 @@ OvpnSocketUdpReceiveFromEvent(_In_ PVOID socketContext, ULONG flags, _In_opt_ PW SIZE_T bytesCopied = 0; SIZE_T bytesRemained = dataIndication->Buffer.Length; - if (bytesRemained > OVPN_SOCKET_PACKET_BUFFER_SIZE) { + if (bytesRemained > OVPN_SOCKET_RX_PACKET_BUFFER_SIZE) { LOG_ERROR("UDP datagram of size is larged than buffer size ", TraceLoggingValue(bytesRemained, "size"), - TraceLoggingValue(OVPN_SOCKET_PACKET_BUFFER_SIZE, "buf")); + TraceLoggingValue(OVPN_SOCKET_RX_PACKET_BUFFER_SIZE, "buf")); RtlZeroMemory(&device->Socket.UdpState, sizeof(OvpnSocketUdpState)); return STATUS_SUCCESS; } @@ -361,9 +361,9 @@ OvpnSocketTcpReceiveEvent(_In_opt_ PVOID socketContext, _In_ ULONG flags, _In_op // header fully read? if (tcpState->BytesRead == sizeof(tcpState->LenBuf)) { USHORT len = RtlUshortByteSwap(*(USHORT*)tcpState->LenBuf); - if ((len == 0) || (len > OVPN_SOCKET_PACKET_BUFFER_SIZE)) { + if ((len == 0) || (len > OVPN_SOCKET_RX_PACKET_BUFFER_SIZE)) { LOG_ERROR("TCP is 0 or larger than ", TraceLoggingValue(len, "payload size"), - TraceLoggingValue(OVPN_SOCKET_PACKET_BUFFER_SIZE, "buffer size")); + TraceLoggingValue(OVPN_SOCKET_RX_PACKET_BUFFER_SIZE, "buffer size")); RtlZeroMemory(tcpState, sizeof(OvpnSocketTcpState)); return STATUS_SUCCESS; } diff --git a/socket.h b/socket.h index 38e04a6..e94247c 100644 --- a/socket.h +++ b/socket.h @@ -37,14 +37,14 @@ struct OvpnSocketTcpState USHORT BytesRead; // packet buffer if packet is scattered across MDLs - UCHAR PacketBuf[OVPN_SOCKET_PACKET_BUFFER_SIZE]; + UCHAR PacketBuf[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; struct OvpnSocketUdpState { // packet buffer if datagram scattered across MDLs // this seems to only happen in unlikely case when datagram is fragmented - UCHAR PacketBuf[OVPN_SOCKET_PACKET_BUFFER_SIZE]; + UCHAR PacketBuf[OVPN_SOCKET_RX_PACKET_BUFFER_SIZE]; }; struct OvpnSocket diff --git a/txqueue.cpp b/txqueue.cpp index e8c634d..711b555 100644 --- a/txqueue.cpp +++ b/txqueue.cpp @@ -55,6 +55,17 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET while (NetFragmentIteratorHasAny(&fi)) { // get fragment payload NET_FRAGMENT* fragment = NetFragmentIteratorGetFragment(&fi); + + if ((buffer->Len + fragment->ValidLength) > OVPN_DCO_MTU_MAX) { + LOG_WARN("Packet max length exceeded, dropping", + TraceLoggingValue(buffer->Len, "currentLen"), + TraceLoggingValue(fragment->ValidLength, "lenToAdd"), + TraceLoggingValue(OVPN_DCO_MTU_MAX - buffer->Len, "spaceLeft")); + OvpnTxBufferPoolPut(buffer); + status = STATUS_INVALID_BUFFER_SIZE; + goto out; + } + NET_FRAGMENT_VIRTUAL_ADDRESS* virtualAddr = NetExtensionGetFragmentVirtualAddress( &queue->VirtualAddressExtension, NetFragmentIteratorGetIndex(&fi)); @@ -111,6 +122,7 @@ OvpnTxProcessPacket(_In_ POVPN_DEVICE device, _In_ POVPN_TXQUEUE queue, _In_ NET OvpnTxBufferPoolPut(buffer); } +out: // update fragment ring's BeginIndex to indicate that we've processes all fragments NET_PACKET* packet = NetPacketIteratorGetPacket(pi); NET_RING* const fragmentRing = NetRingCollectionGetFragmentRing(fi.Iterator.Rings);