Skip to content

Commit

Permalink
Further modernization effort
Browse files Browse the repository at this point in the history
- Add support for List, ListPaginated and ListAll
- Switch to ReadOnlyMemory<float> to avoid copying embeddings in SemanticKernel
- Use UnsafeAccessor for all cases of zero-copy RepeatedField<T> construction and stop referencing reflection on .NET8+
- Switch from archived repo .proto to up-to-date upstream one from pinecone-dotnet-client
  • Loading branch information
neon-sunset committed Sep 19, 2024
1 parent d462352 commit d125d26
Show file tree
Hide file tree
Showing 14 changed files with 357 additions and 132 deletions.
5 changes: 4 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
version: 2
updates:
- package-ecosystem: "nuget" # See documentation for possible values
directory: "/" # Location of package manifests
directory: |
src
example
test
schedule:
interval: "weekly"
groups:
Expand Down
6 changes: 3 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "deps/pinecone-client"]
path = deps/pinecone-client
url = https://github.com/pinecone-io/pinecone-client
[submodule "deps/pinecone-dotnet-client"]
path = deps/pinecone-dotnet-client
url = https://github.com/pinecone-io/pinecone-dotnet-client
1 change: 0 additions & 1 deletion deps/pinecone-client
Submodule pinecone-client deleted from bf1ffa
1 change: 1 addition & 0 deletions deps/pinecone-dotnet-client
Submodule pinecone-dotnet-client added at 664894
190 changes: 117 additions & 73 deletions src/Grpc/Converters.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Google.Protobuf.Collections;
using Google.Protobuf.WellKnownTypes;

Expand All @@ -19,24 +20,21 @@ public static Struct ToProtoStruct(this MetadataMap source)
return protoStruct;
}

public static Value ToProtoValue(this MetadataValue source)
public static Value ToProtoValue(this MetadataValue source) => source.Inner switch
{
return source.Inner switch
{
// This is terrible but such is life
null => Value.ForNull(),
double num => Value.ForNumber(num),
string str => Value.ForString(str),
bool boolean => Value.ForBool(boolean),
MetadataMap nested => Value.ForStruct(nested.ToProtoStruct()),
IEnumerable<MetadataValue> list => Value.ForList(list.Select(v => v.ToProtoValue()).ToArray()),
_ => ThrowHelpers.ArgumentException<Value>($"Unsupported metadata type: {source.Inner!.GetType()}")
};
}
// This is terrible but such is life
null => Value.ForNull(),
double num => Value.ForNumber(num),
string str => Value.ForString(str),
bool boolean => Value.ForBool(boolean),
MetadataMap nested => Value.ForStruct(nested.ToProtoStruct()),
IEnumerable<MetadataValue> list => Value.ForList(list.Select(v => v.ToProtoValue()).ToArray()),
_ => ThrowHelpers.ArgumentException<Value>($"Unsupported metadata type: {source.Inner!.GetType()}")
};

public static global::Vector ToProtoVector(this Vector source)
public static Vector ToProtoVector(this Pinecone.Vector source)
{
var protoVector = new global::Vector
var protoVector = new Vector
{
Id = source.Id,
SparseValues = source.SparseValues?.ToProtoSparseValues(),
Expand Down Expand Up @@ -69,32 +67,29 @@ public static SparseValues ToProtoSparseValues(this SparseVector source)
TotalVectorCount = source.TotalVectorCount
};

public static Vector ToPublicType(this global::Vector source)
public static Pinecone.Vector ToPublicType(this Vector source) => new()
{
return new Vector
{
Id = source.Id,
Values = source.Values.AsArray(),
SparseValues = source.SparseValues?.Indices.Count > 0
? new SparseVector
{
Indices = source.SparseValues.Indices.AsArray(),
Values = source.SparseValues.Values.AsArray()
}
: null,
Metadata = source.Metadata?.Fields.ToPublicType()
};
}
Id = source.Id,
Values = source.Values.AsMemory(),
SparseValues = source.SparseValues?.Indices.Count > 0
? new SparseVector
{
Indices = source.SparseValues.Indices.AsMemory(),
Values = source.SparseValues.Values.AsMemory()
}
: null,
Metadata = source.Metadata?.Fields.ToPublicType()
};

public static ScoredVector ToPublicType(this global::ScoredVector source) => new()
public static Pinecone.ScoredVector ToPublicType(this ScoredVector source) => new()
{
Id = source.Id,
Score = source.Score,
Values = source.Values.AsArray(),
Values = source.Values.AsMemory(),
SparseValues = source.SparseValues?.Indices.Count > 0 ? new()
{
Indices = source.SparseValues.Indices.AsArray(),
Values = source.SparseValues.Values.AsArray()
Indices = source.SparseValues.Indices.AsMemory(),
Values = source.SparseValues.Values.AsMemory()
} : null,
Metadata = source.Metadata?.Fields.ToPublicType()
};
Expand Down Expand Up @@ -123,63 +118,119 @@ Value.KindOneofCase.None or
_ => ThrowHelpers.ArgumentException<MetadataValue>($"Unsupported metadata type: {source.KindCase}")
};
}

#if NET8_0_OR_GREATER
// These have to be duplicated because unsafe accessor does not support generics in .NET 8.
// This approach is, however, very useful as we completely bypass referencing reflection for NAOT.
/// <summary>
/// Exposes the first <c>Count</c> elements of the field's backing array as read-only memory
/// without copying them.
/// </summary>
/// <param name="source">Repeated field whose backing array is viewed.</param>
/// <returns>A view over the backing array, starting at index 0 and spanning <c>source.Count</c> elements.</returns>
public static ReadOnlyMemory<float> AsMemory(this RepeatedField<float> source)
    => ArrayRef(source).AsMemory(0, source.Count);

public static T[] AsArray<T>(this RepeatedField<T> source) where T : unmanaged
public static void OverwriteWith(this RepeatedField<float> target, ReadOnlyMemory<float>? source)
{
var buffer = FieldAccessors<T>.GetArray(source);
if (buffer.Length != source.Count)
if (source is null or { IsEmpty: true }) return;

float[] array;
int count;
if (MemoryMarshal.TryGetArray(source.Value, out var segment)
&& segment.Offset is 0)
{
buffer = buffer.AsSpan(0, source.Count).ToArray();
array = segment.Array!;
count = segment.Count;
}
else
{
array = source.Value.ToArray();
count = array.Length;
}

return buffer;
ArrayRef(target) = array;
CountRef(target) = count;
}

public static void OverwriteWith<T>(this RepeatedField<T> target, T[]? source) where T : unmanaged
// Accessor for the field named "array" on RepeatedField<float>; the returned ref lets callers
// both read and replace the backing array. Bound via UnsafeAccessor instead of FieldInfo.
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "array")]
static extern ref float[] ArrayRef(RepeatedField<float> instance);

// Accessor for the field named "count" on RepeatedField<float>; the returned ref lets callers
// overwrite the element count after swapping the backing array.
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "count")]
static extern ref int CountRef(RepeatedField<float> instance);

public static ReadOnlyMemory<uint> AsMemory(this RepeatedField<uint> source)
{
if (source is null) return;
return ArrayRef(source).AsMemory(0, source.Count);
}

FieldAccessors<T>.SetArray(target, source);
FieldAccessors<T>.SetCount(target, source.Length);
/// <summary>
/// Replaces the backing array and element count of <paramref name="target"/> with the contents
/// of <paramref name="source"/>.
/// </summary>
/// <param name="target">Repeated field whose storage is replaced.</param>
/// <param name="source">
/// Values to install. When <c>null</c> or empty the target is left untouched.
/// </param>
public static void OverwriteWith(this RepeatedField<uint> target, ReadOnlyMemory<uint>? source)
{
    if (source is null || source.Value.IsEmpty)
    {
        return;
    }

    var memory = source.Value;

    uint[] backing;
    int length;
    if (!MemoryMarshal.TryGetArray(memory, out var segment) || segment.Offset != 0)
    {
        // Not array-backed, or the slice does not start at index 0: copy into a fresh array.
        backing = memory.ToArray();
        length = backing.Length;
    }
    else
    {
        // Array-backed and starting at index 0: adopt the caller's array directly (no copy).
        backing = segment.Array!;
        length = segment.Count;
    }

    ArrayRef(target) = backing;
    CountRef(target) = length;
}

// Accessor for the field named "array" on RepeatedField<uint>; the returned ref lets callers
// both read and replace the backing array. Bound via UnsafeAccessor instead of FieldInfo.
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "array")]
static extern ref uint[] ArrayRef(RepeatedField<uint> instance);

// Accessor for the field named "count" on RepeatedField<uint>; the returned ref lets callers
// overwrite the element count after swapping the backing array.
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "count")]
static extern ref int CountRef(RepeatedField<uint> instance);
#else
/// <summary>
/// Exposes the first <c>Count</c> elements of the field's backing array as read-only memory
/// without copying them. The array is obtained through <c>FieldAccessors&lt;T&gt;</c>.
/// </summary>
/// <typeparam name="T">Unmanaged element type of the repeated field.</typeparam>
/// <param name="source">Repeated field whose backing array is viewed.</param>
/// <returns>A view over the backing array, starting at index 0 and spanning <c>source.Count</c> elements.</returns>
public static ReadOnlyMemory<T> AsMemory<T>(this RepeatedField<T> source)
    where T : unmanaged
{
    var backing = FieldAccessors<T>.GetArray(source);
    var length = source.Count;
    return backing.AsMemory(0, length);
}

/// <summary>
/// Replaces the backing array and element count of <paramref name="target"/> with the contents
/// of <paramref name="source"/>, writing both through <c>FieldAccessors&lt;T&gt;</c>.
/// </summary>
/// <typeparam name="T">Unmanaged element type of the repeated field.</typeparam>
/// <param name="target">Repeated field whose storage is replaced.</param>
/// <param name="source">
/// Values to install. When <c>null</c> or empty the target is left untouched.
/// </param>
public static void OverwriteWith<T>(this RepeatedField<T> target, ReadOnlyMemory<T>? source)
    where T : unmanaged
{
    if (source is null || source.Value.IsEmpty)
    {
        return;
    }

    var memory = source.Value;

    T[] backing;
    int length;
    if (!MemoryMarshal.TryGetArray(memory, out var segment) || segment.Offset != 0)
    {
        // Not array-backed, or the slice does not start at index 0: copy into a fresh array.
        backing = memory.ToArray();
        length = backing.Length;
    }
    else
    {
        // Array-backed and starting at index 0: adopt the caller's array directly (no copy).
        backing = segment.Array!;
        length = segment.Count;
    }

    FieldAccessors<T>.SetArray(target, backing);
    FieldAccessors<T>.SetCount(target, length);
}

private static class FieldAccessors<T> where T : unmanaged
{
public static T[] GetArray(RepeatedField<T> instance)
{
#if NET8_0_OR_GREATER
if (instance is RepeatedField<float> floatSeq)
{
return (T[])(object)ArrayRef(floatSeq);
}
#endif

return (T[])ArrayField.GetValue(instance)!;
}

public static void SetArray(RepeatedField<T> instance, T[] value)
{
#if NET8_0_OR_GREATER
if (instance is RepeatedField<float> floatSeq)
{
ArrayRef(floatSeq) = (float[])(object)value;
return;
}
#endif

ArrayField.SetValue(instance, value);
}

public static void SetCount(RepeatedField<T> instance, int value)
{
#if NET8_0_OR_GREATER
if (instance is RepeatedField<float> floatSeq)
{
CountRef(floatSeq) = value;
return;
}
#endif

CountField.SetValue(instance, value);
}

Expand All @@ -189,12 +240,5 @@ public static void SetCount(RepeatedField<T> instance, int value)
static readonly FieldInfo CountField = typeof(RepeatedField<T>)
.GetField("count", BindingFlags.NonPublic | BindingFlags.Instance) ?? throw new NullReferenceException();
}

#if NET8_0_OR_GREATER
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "array")]
static extern ref float[] ArrayRef(RepeatedField<float> instance);

[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "count")]
static extern ref int CountRef(RepeatedField<float> instance);
#endif
}
40 changes: 32 additions & 8 deletions src/Grpc/GrpcTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ public async Task<IndexStats> DescribeStats(MetadataMap? filter = null, Cancella
return (await call.ConfigureAwait(false)).ToPublicType();
}

public async Task<ScoredVector[]> Query(
public async Task<Pinecone.ScoredVector[]> Query(
string? id,
float[]? values,
ReadOnlyMemory<float>? values,
SparseVector? sparseValues,
uint topK,
MetadataMap? filter,
Expand Down Expand Up @@ -84,7 +84,7 @@ public async Task<ScoredVector[]> Query(
var response = await call.ConfigureAwait(false);

var matches = response.Matches;
var vectors = new ScoredVector[response.Matches.Count];
var vectors = new Pinecone.ScoredVector[response.Matches.Count];
for (var i = 0; i < matches.Count; i++)
{
vectors[i] = matches[i].ToPublicType();
Expand All @@ -93,7 +93,7 @@ public async Task<ScoredVector[]> Query(
return vectors;
}

public async Task<uint> Upsert(IEnumerable<Vector> vectors, string? indexNamespace = null, CancellationToken ct = default)
public async Task<uint> Upsert(IEnumerable<Pinecone.Vector> vectors, string? indexNamespace = null, CancellationToken ct = default)
{
var request = new UpsertRequest { Namespace = indexNamespace ?? "" };
request.Vectors.AddRange(vectors.Select(v => v.ToProtoVector()));
Expand All @@ -103,7 +103,7 @@ public async Task<uint> Upsert(IEnumerable<Vector> vectors, string? indexNamespa
return (await call.ConfigureAwait(false)).UpsertedCount;
}

public Task Update(Vector vector, string? indexNamespace = null, CancellationToken ct = default) => Update(
public Task Update(Pinecone.Vector vector, string? indexNamespace = null, CancellationToken ct = default) => Update(
vector.Id,
vector.Values,
vector.SparseValues,
Expand All @@ -113,12 +113,12 @@ public Task Update(Vector vector, string? indexNamespace = null, CancellationTok

public async Task Update(
string id,
float[]? values = null,
ReadOnlyMemory<float>? values = null,
SparseVector? sparseValues = null,
MetadataMap? metadata = null,
string? indexNamespace = null,
CancellationToken ct = default)
{
{
if (values is null && sparseValues is null && metadata is null)
{
ThrowHelpers.ArgumentException(
Expand All @@ -138,7 +138,31 @@ public async Task Update(
_ = await call.ConfigureAwait(false);
}

public async Task<Dictionary<string, Vector>> Fetch(
/// <summary>
/// Requests a single page of vector IDs from the index.
/// </summary>
/// <param name="prefix">ID prefix to filter by; <c>null</c> is sent as an empty string.</param>
/// <param name="limit">Page size; <c>null</c> is sent as <c>0</c>.</param>
/// <param name="paginationToken">Token from a previous page; <c>null</c> is sent as an empty string.</param>
/// <param name="indexNamespace">Namespace to list from; <c>null</c> is sent as an empty string.</param>
/// <param name="ct">Token used to cancel the gRPC call.</param>
/// <returns>
/// The IDs on this page, the token for the next page (<c>null</c> when the response carries no
/// pagination message), and the read units reported by the response (<c>0</c> when the response
/// carries no usage message).
/// </returns>
public async Task<(string[] VectorIds, string? PaginationToken, uint ReadUnits)> List(
    string? prefix,
    uint? limit,
    string? paginationToken,
    string? indexNamespace = null,
    CancellationToken ct = default)
{
    var request = new ListRequest
    {
        Prefix = prefix ?? "",
        Limit = limit ?? 0,
        PaginationToken = paginationToken ?? "",
        Namespace = indexNamespace ?? ""
    };

    using var call = Grpc.ListAsync(request, Metadata, cancellationToken: ct);
    var response = await call.ConfigureAwait(false);

    // Message-typed protobuf fields are null when absent from the response, so both
    // Pagination and Usage must be dereferenced null-safely to avoid a NullReferenceException.
    return (
        response.Vectors.Select(v => v.Id).ToArray(),
        response.Pagination?.Next,
        response.Usage?.ReadUnits ?? 0u);
}

public async Task<Dictionary<string, Pinecone.Vector>> Fetch(
IEnumerable<string> ids, string? indexNamespace = null, CancellationToken ct = default)
{
var request = new FetchRequest
Expand Down
Loading

0 comments on commit d125d26

Please sign in to comment.