Skip to content

Commit

Permalink
[Android] Improve SslStream PAL buffer resizing (dotnet#104726)
Browse files Browse the repository at this point in the history
* Update PAL sslstream on Android to adjust buffer sizes based on Wrap results

* Add test

* Remove extra test

* Use direct buffers
  • Loading branch information
simonrozsival authored Jul 19, 2024
1 parent 823cd67 commit a3fd095
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,6 @@ jmethodID g_ByteBufferGet;
jmethodID g_ByteBufferLimit;
jmethodID g_ByteBufferPosition;
jmethodID g_ByteBufferPutBuffer;
jmethodID g_ByteBufferPutByteArray;
jmethodID g_ByteBufferPutByteArrayWithLength;
jmethodID g_ByteBufferRemaining;

// javax/net/ssl/SSLContext
Expand All @@ -477,6 +475,7 @@ jclass g_SSLEngineResult;
jmethodID g_SSLEngineResultGetStatus;
jmethodID g_SSLEngineResultGetHandshakeStatus;
bool g_SSLEngineResultStatusLegacyOrder;
jmethodID g_SSLEngineResultBytesConsumed;

// javax/crypto/KeyAgreement
jclass g_KeyAgreementClass;
Expand Down Expand Up @@ -1073,8 +1072,6 @@ jint AndroidCryptoNative_InitLibraryOnLoad (JavaVM *vm, void *reserved)
g_ByteBufferLimit = GetMethod(env, false, g_ByteBuffer, "limit", "()I");
g_ByteBufferPosition = GetMethod(env, false, g_ByteBuffer, "position", "()I");
g_ByteBufferPutBuffer = GetMethod(env, false, g_ByteBuffer, "put", "(Ljava/nio/ByteBuffer;)Ljava/nio/ByteBuffer;");
g_ByteBufferPutByteArray = GetMethod(env, false, g_ByteBuffer, "put", "([B)Ljava/nio/ByteBuffer;");
g_ByteBufferPutByteArrayWithLength = GetMethod(env, false, g_ByteBuffer, "put", "([BII)Ljava/nio/ByteBuffer;");
g_ByteBufferRemaining = GetMethod(env, false, g_ByteBuffer, "remaining", "()I");

g_SSLContext = GetClassGRef(env, "javax/net/ssl/SSLContext");
Expand All @@ -1095,6 +1092,7 @@ jint AndroidCryptoNative_InitLibraryOnLoad (JavaVM *vm, void *reserved)
g_SSLEngineResult = GetClassGRef(env, "javax/net/ssl/SSLEngineResult");
g_SSLEngineResultGetStatus = GetMethod(env, false, g_SSLEngineResult, "getStatus", "()Ljavax/net/ssl/SSLEngineResult$Status;");
g_SSLEngineResultGetHandshakeStatus = GetMethod(env, false, g_SSLEngineResult, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;");
g_SSLEngineResultBytesConsumed = GetMethod(env, false, g_SSLEngineResult, "bytesConsumed", "()I");
g_SSLEngineResultStatusLegacyOrder = android_get_device_api_level() < 24;

g_KeyAgreementClass = GetClassGRef(env, "javax/crypto/KeyAgreement");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,6 @@ extern jmethodID g_ByteBufferGet;
extern jmethodID g_ByteBufferLimit;
extern jmethodID g_ByteBufferPosition;
extern jmethodID g_ByteBufferPutBuffer;
extern jmethodID g_ByteBufferPutByteArray;
extern jmethodID g_ByteBufferPutByteArrayWithLength;
extern jmethodID g_ByteBufferRemaining;

// javax/net/ssl/SSLContext
Expand All @@ -491,6 +489,7 @@ extern jclass g_SSLEngineResult;
extern jmethodID g_SSLEngineResultGetStatus;
extern jmethodID g_SSLEngineResultGetHandshakeStatus;
extern bool g_SSLEngineResultStatusLegacyOrder;
extern jmethodID g_SSLEngineResultBytesConsumed;

// javax/crypto/KeyAgreement
extern jclass g_KeyAgreementClass;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct ApplicationProtocolData_t
ARGS_NON_NULL(1) static uint16_t* AllocateString(JNIEnv* env, jstring source);

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream* sslStream);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed);
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus);

ARGS_NON_NULL_ALL static int GetHandshakeStatus(JNIEnv* env, SSLStream* sslStream)
Expand Down Expand Up @@ -112,15 +112,15 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslSt
{
// Call wrap to clear any remaining data before closing
int unused;
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused);
PAL_SSLStreamStatus ret = DoWrap(env, sslStream, &unused, &unused);

// sslEngine.closeOutbound();
(*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineCloseOutbound);
if (ret != SSLStreamStatus_OK)
return ret;

// Flush any remaining data (e.g. sending close notification)
return DoWrap(env, sslStream, &unused);
return DoWrap(env, sslStream, &unused, &unused);
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslStream)
Expand Down Expand Up @@ -172,10 +172,14 @@ ARGS_NON_NULL_ALL static jobject ExpandBuffer(JNIEnv* env, jobject oldBuffer, in

ARGS_NON_NULL_ALL static jobject EnsureRemaining(JNIEnv* env, jobject oldBuffer, int32_t newRemaining)
{
IGNORE_RETURN((*env)->CallObjectMethod(env, oldBuffer, g_ByteBufferCompact));
int32_t oldPosition = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferPosition);
int32_t oldRemaining = (*env)->CallIntMethod(env, oldBuffer, g_ByteBufferRemaining);
if (oldRemaining < newRemaining)
{
return ExpandBuffer(env, oldBuffer, oldRemaining + newRemaining);
// After compacting the oldBuffer, the oldPosition is equal to the number of bytes in the buffer at the moment
// we need to change the capacity to the oldPosition + newRemaining
return ExpandBuffer(env, oldBuffer, oldPosition + newRemaining);
}
else
{
Expand Down Expand Up @@ -204,22 +208,19 @@ static int MapLegacySSLEngineResultStatus(int legacyStatus)
}
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus)
ARGS_NON_NULL_ALL static PAL_SSLStreamStatus WrapAndProcessResult(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed, bool* repeat)
{
// appOutBuffer.flip();
// SSLEngineResult result = sslEngine.wrap(appOutBuffer, netOutBuffer);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip));
jobject result = (*env)->CallObjectMethod(
env, sslStream->sslEngine, g_SSLEngineWrap, sslStream->appOutBuffer, sslStream->netOutBuffer);
if (CheckJNIExceptions(env))
return SSLStreamStatus_Error;

// appOutBuffer.compact();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));

// handshakeStatus = result.getHandshakeStatus();
// bytesConsumed = result.bytesConsumed();
// SSLEngineResult.Status status = result.getStatus();
*handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetHandshakeStatus));
*bytesConsumed = (*env)->CallIntMethod(env, result, g_SSLEngineResultBytesConsumed);
int status = GetEnumAsInt(env, (*env)->CallObjectMethod(env, result, g_SSLEngineResultGetStatus));
(*env)->DeleteLocalRef(env, result);

Expand All @@ -242,11 +243,10 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS
}
case STATUS__BUFFER_OVERFLOW:
{
// Expand buffer
// int newCapacity = sslSession.getPacketBufferSize() + netOutBuffer.remaining();
int32_t newCapacity = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize) +
(*env)->CallIntMethod(env, sslStream->netOutBuffer, g_ByteBufferRemaining);
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, newCapacity);
// Expand buffer and repeat the wrap
int32_t packetBufferSize = (*env)->CallIntMethod(env, sslStream->sslSession, g_SSLSessionGetPacketBufferSize);
sslStream->netOutBuffer = ExpandBuffer(env, sslStream->netOutBuffer, packetBufferSize);
*repeat = true;
return SSLStreamStatus_OK;
}
default:
Expand All @@ -257,32 +257,60 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslS
}
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoWrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus, int* bytesConsumed)
{
// appOutBuffer.flip();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferFlip));

bool repeat = false;
PAL_SSLStreamStatus status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat);

if (repeat)
{
repeat = false;
status = WrapAndProcessResult(env, sslStream, handshakeStatus, bytesConsumed, &repeat);

if (repeat)
{
LOG_ERROR("Unexpected repeat in DoWrap");
return SSLStreamStatus_Error;
}
}

// appOutBuffer.compact();
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));

return status;
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* sslStream, int* handshakeStatus)
{
// if (netInBuffer.position() == 0)
// {
// byte[] tmp = new byte[netInBuffer.limit()];
// int count = streamReader(tmp, 0, tmp.length);
// netInBuffer.put(tmp, 0, count);
// int netInBufferLimit = netInBuffer.limit();
// ByteBuffer tmp = ByteBuffer.allocateDirect(netInBufferLimit);
// int count = streamReader(tmp, 0, netInBufferLimit);
// netInBuffer.put(tmp);
// }
if ((*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferPosition) == 0)
{
int netInBufferLimit = (*env)->CallIntMethod(env, sslStream->netInBuffer, g_ByteBufferLimit);
jbyteArray tmp = make_java_byte_array(env, netInBufferLimit);
uint8_t* tmpNative = (uint8_t*)xmalloc((size_t)netInBufferLimit);
int count = netInBufferLimit;
// todo assert streamReader != 0 ?
PAL_SSLStreamStatus status = sslStream->streamReader(sslStream->managedContextHandle, tmpNative, &count);
if (status != SSLStreamStatus_OK)
{
free(tmpNative);
(*env)->DeleteLocalRef(env, tmp);
return status;
}

(*env)->SetByteArrayRegion(env, tmp, 0, count, (jbyte*)(tmpNative));
jobject tmp = (*env)->NewDirectByteBuffer(env, tmpNative, count);
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);

IGNORE_RETURN(
(*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPutByteArrayWithLength, tmp, 0, count));
(*env)->CallObjectMethod(env, sslStream->netInBuffer, g_ByteBufferPutBuffer, tmp));
cleanup:
free(tmpNative);
(*env)->DeleteLocalRef(env, tmp);
}
Expand Down Expand Up @@ -350,13 +378,14 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoHandshake(JNIEnv* env, SSLStream*
PAL_SSLStreamStatus status = SSLStreamStatus_OK;
int handshakeStatus = GetHandshakeStatus(env, sslStream);
assert(handshakeStatus >= 0);
int bytesConsumed;

while (IsHandshaking(handshakeStatus) && status == SSLStreamStatus_OK)
{
switch (handshakeStatus)
{
case HANDSHAKE_STATUS__NEED_WRAP:
status = DoWrap(env, sslStream, &handshakeStatus);
status = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed);
break;
case HANDSHAKE_STATUS__NEED_UNWRAP:
status = DoUnwrap(env, sslStream, &handshakeStatus);
Expand Down Expand Up @@ -858,26 +887,24 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin
JNIEnv* env = GetJNIEnv();
PAL_SSLStreamStatus ret = SSLStreamStatus_Error;

// int remaining = appOutBuffer.remaining();
// int arraySize = length > remaining ? remaining : length;
// byte[] data = new byte[arraySize];
int32_t remaining = (*env)->CallIntMethod(env, sslStream->appOutBuffer, g_ByteBufferRemaining);
int32_t arraySize = length > remaining ? remaining : length;
jbyteArray data = make_java_byte_array(env, arraySize);
// ByteBuffer bufferByteBuffer = ...;
jobject bufferByteBuffer = (*env)->NewDirectByteBuffer(env, buffer, length);
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);

// appOutBuffer.compact();
// appOutBuffer = EnsureRemaining(appOutBuffer, length);
// appOutBuffer.put(bufferByteBuffer);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferCompact));
sslStream->appOutBuffer = EnsureRemaining(env, sslStream->appOutBuffer, length);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutBuffer, bufferByteBuffer));
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);

int32_t written = 0;
while (written < length)
{
int32_t toWrite = length - written > arraySize ? arraySize : length - written;
(*env)->SetByteArrayRegion(env, data, 0, toWrite, (jbyte*)(buffer + written));

// appOutBuffer.put(data, 0, toWrite);
IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->appOutBuffer, g_ByteBufferPutByteArrayWithLength, data, 0, toWrite));
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);
written += toWrite;

int handshakeStatus;
ret = DoWrap(env, sslStream, &handshakeStatus);
int bytesConsumed;
ret = DoWrap(env, sslStream, &handshakeStatus, &bytesConsumed);
if (ret != SSLStreamStatus_OK)
{
goto cleanup;
Expand All @@ -887,10 +914,12 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamWrite(SSLStream* sslStream, uin
ret = SSLStreamStatus_Renegotiate;
goto cleanup;
}

written += bytesConsumed;
}

cleanup:
(*env)->DeleteLocalRef(env, data);
(*env)->DeleteLocalRef(env, bufferByteBuffer);
return ret;
}

Expand Down

0 comments on commit a3fd095

Please sign in to comment.