Skip to content

Commit

Permalink
Formatting, and replace precondition with default value to prevent cr…
Browse files Browse the repository at this point in the history
…ashing
  • Loading branch information
ZachNagengast committed Jun 4, 2024
1 parent 6f02580 commit 9bec547
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
// Created by Zach Nagengast on 4/7/23.
//

import Foundation
import Accelerate
import Foundation

/// A struct implementing the `DistanceMetricProtocol` using the dot product.
///
Expand All @@ -21,19 +21,20 @@ public struct DotProduct: DistanceMetricProtocol {
return sortedScores(scores: scores, topK: resultsCount)
}


public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
precondition(firstEmbedding.count == secondEmbedding.count, "Embeddings must have the same length")

var dotProduct: Float = 0

// Calculate dot product using Accelerate
vDSP_dotpr(firstEmbedding, 1, secondEmbedding, 1, &dotProduct, vDSP_Length(firstEmbedding.count))

return dotProduct
}

public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
if firstEmbedding.count != secondEmbedding.count {
print("Embeddings must have the same length")
return -Float.greatestFiniteMagnitude
}

var dotProduct: Float = 0

// Calculate dot product using Accelerate
vDSP_dotpr(firstEmbedding, 1, secondEmbedding, 1, &dotProduct, vDSP_Length(firstEmbedding.count))

return dotProduct
}
}

/// A struct implementing the `DistanceMetricProtocol` using cosine similarity.
Expand All @@ -49,28 +50,29 @@ public struct CosineSimilarity: DistanceMetricProtocol {
return sortedScores(scores: scores, topK: resultsCount)
}


public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
precondition(firstEmbedding.count == secondEmbedding.count, "Embeddings must have the same length")

var dotProduct: Float = 0
var firstMagnitude: Float = 0
var secondMagnitude: Float = 0

// Calculate dot product and magnitudes using Accelerate
vDSP_dotpr(firstEmbedding, 1, secondEmbedding, 1, &dotProduct, vDSP_Length(firstEmbedding.count))
vDSP_svesq(firstEmbedding, 1, &firstMagnitude, vDSP_Length(firstEmbedding.count))
vDSP_svesq(secondEmbedding, 1, &secondMagnitude, vDSP_Length(secondEmbedding.count))

// Take square root of magnitudes
firstMagnitude = sqrt(firstMagnitude)
secondMagnitude = sqrt(secondMagnitude)

// Return cosine similarity
return dotProduct / (firstMagnitude * secondMagnitude)
}

public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
if firstEmbedding.count != secondEmbedding.count {
print("Embeddings must have the same length")
return -1
}

var dotProduct: Float = 0
var firstMagnitude: Float = 0
var secondMagnitude: Float = 0

// Calculate dot product and magnitudes using Accelerate
vDSP_dotpr(firstEmbedding, 1, secondEmbedding, 1, &dotProduct, vDSP_Length(firstEmbedding.count))
vDSP_svesq(firstEmbedding, 1, &firstMagnitude, vDSP_Length(firstEmbedding.count))
vDSP_svesq(secondEmbedding, 1, &secondMagnitude, vDSP_Length(secondEmbedding.count))

// Take square root of magnitudes
firstMagnitude = sqrt(firstMagnitude)
secondMagnitude = sqrt(secondMagnitude)

// Return cosine similarity
return dotProduct / (firstMagnitude * secondMagnitude)
}
}

/// A struct implementing the `DistanceMetricProtocol` using Euclidean distance.
Expand All @@ -79,26 +81,28 @@ public struct CosineSimilarity: DistanceMetricProtocol {
///
/// - Note: Use this metric when the magnitudes of the embeddings are significant in your use case, and the embeddings are distributed in a Euclidean space.
public struct EuclideanDistance: DistanceMetricProtocol {

public init() {}

public func findNearest(for queryEmbedding: [Float], in neighborEmbeddings: [[Float]], resultsCount: Int) -> [(Float, Int)] {
let distances = neighborEmbeddings.map { distance(between: queryEmbedding, and: $0) }
return sortedDistances(distances: distances, topK: resultsCount)
}

public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
precondition(firstEmbedding.count == secondEmbedding.count, "Embeddings must have the same length")

var distance: Float = 0

// Calculate squared differences and sum them using Accelerate
vDSP_distancesq(firstEmbedding, 1, secondEmbedding, 1, &distance, vDSP_Length(firstEmbedding.count))

// Return the square root of the summed squared differences
return sqrt(distance)
}

public func distance(between firstEmbedding: [Float], and secondEmbedding: [Float]) -> Float {
// Ensure the embeddings have the same length
if firstEmbedding.count != secondEmbedding.count {
print("Embeddings must have the same length")
return Float.greatestFiniteMagnitude
}

var distance: Float = 0

// Calculate squared differences and sum them using Accelerate
vDSP_distancesq(firstEmbedding, 1, secondEmbedding, 1, &distance, vDSP_Length(firstEmbedding.count))

// Return the square root of the summed squared differences
return sqrt(distance)
}
}

// MARK: - Helpers
Expand Down Expand Up @@ -133,7 +137,7 @@ public func sortedScores(scores: [Float], topK: Int) -> [(Float, Int)] {
/// - Parameters:
/// - distances: An array of Float values representing distances.
/// - topK: The number of top distances to return.
///
///
/// - Returns: An array of tuples containing the top K distances and their corresponding indices.
public func sortedDistances(distances: [Float], topK: Int) -> [(Float, Int)] {
// Combine indices & distances
Expand Down
61 changes: 28 additions & 33 deletions Sources/SimilaritySearchKit/Core/Index/SimilarityIndex.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
//

import Foundation
import OSLog

// MARK: - Type Aliases

Expand All @@ -19,20 +18,18 @@ public typealias VectorStoreType = SimilarityIndex.VectorStoreType

@available(macOS 11.0, iOS 15.0, *)
public class SimilarityIndex: Identifiable, Hashable {

public static func == (lhs: SimilarityIndex, rhs: SimilarityIndex) -> Bool {
return lhs.id == rhs.id
}

public func hash(into hasher: inout Hasher) {
hasher.combine(id)
}

// MARK: - Properties

/// A unique identifier
public var id: UUID = UUID()

/// Unique identifier for this index instance
public var id: UUID = .init()
public static func == (lhs: SimilarityIndex, rhs: SimilarityIndex) -> Bool {
return lhs.id == rhs.id
}

public func hash(into hasher: inout Hasher) {
hasher.combine(id)
}

/// The items stored in the index.
public var indexItems: [IndexItem] = []

Expand Down Expand Up @@ -161,7 +158,7 @@ public class SimilarityIndex: Identifiable, Hashable {
var indexIds: [String] = []
var indexEmbeddings: [[Float]] = []

indexItems.forEach { item in
for item in indexItems {
indexIds.append(item.id)
indexEmbeddings.append(item.embedding)
}
Expand Down Expand Up @@ -220,18 +217,18 @@ public class SimilarityIndex: Identifiable, Hashable {
// MARK: - CRUD

@available(macOS 11.0, iOS 15.0, *)
extension SimilarityIndex {
public extension SimilarityIndex {
// MARK: Create

// Add an item with optional pre-computed embedding
public func addItem(id: String, text: String, metadata: [String: String], embedding: [Float]? = nil) async {
/// Add an item with optional pre-computed embedding
func addItem(id: String, text: String, metadata: [String: String], embedding: [Float]? = nil) async {
let embeddingResult = await getEmbedding(for: text, embedding: embedding)

let item = IndexItem(id: id, text: text, embedding: embeddingResult, metadata: metadata)
indexItems.append(item)
}

public func addItems(ids: [String], texts: [String], metadata: [[String: String]], embeddings: [[Float]?]? = nil, onProgress: ((String) -> Void)? = nil) async {
func addItems(ids: [String], texts: [String], metadata: [[String: String]], embeddings: [[Float]?]? = nil, onProgress: ((String) -> Void)? = nil) async {
// Check if all input arrays have the same length
guard ids.count == texts.count, texts.count == metadata.count else {
fatalError("Input arrays must have the same length.")
Expand All @@ -258,7 +255,7 @@ extension SimilarityIndex {
}
}

public func addItems(_ items: [IndexItem], completion: (() -> Void)? = nil) {
func addItems(_ items: [IndexItem], completion: (() -> Void)? = nil) {
Task {
for item in items {
await self.addItem(id: item.id, text: item.text, metadata: item.metadata, embedding: item.embedding)
Expand All @@ -269,17 +266,17 @@ extension SimilarityIndex {

// MARK: Read

public func getItem(id: String) -> IndexItem? {
func getItem(id: String) -> IndexItem? {
return indexItems.first { $0.id == id }
}

public func sample(_ count: Int) -> [IndexItem]? {
func sample(_ count: Int) -> [IndexItem]? {
return Array(indexItems.prefix(upTo: count))
}

// MARK: Update

public func updateItem(id: String, text: String? = nil, embedding: [Float]? = nil, metadata: [String: String]? = nil) {
func updateItem(id: String, text: String? = nil, embedding: [Float]? = nil, metadata: [String: String]? = nil) {
// Check if the provided embedding has the correct dimension
if let embedding = embedding, embedding.count != dimension {
print("Dimension mismatch, expected \(dimension), saw \(embedding.count)")
Expand All @@ -306,21 +303,20 @@ extension SimilarityIndex {

// MARK: Delete

public func removeItem(id: String) {
func removeItem(id: String) {
indexItems.removeAll { $0.id == id }
}

public func removeAll() {
func removeAll() {
indexItems.removeAll()
}
}

// MARK: - Persistence

@available(macOS 13.0, iOS 16.0, *)
extension SimilarityIndex {

public func saveIndex(toDirectory path: URL? = nil, name: String? = nil) throws -> URL {
public extension SimilarityIndex {
func saveIndex(toDirectory path: URL? = nil, name: String? = nil) throws -> URL {
let indexName = name ?? self.indexName
let basePath: URL

Expand All @@ -333,13 +329,12 @@ extension SimilarityIndex {

let savedVectorStore = try vectorStore.saveIndex(items: indexItems, to: basePath, as: indexName)

let bundleId: String = Bundle.main.bundleIdentifier ?? "com.similarity-search-kit.logger"
let logger: Logger = Logger(subsystem: bundleId, category: "similarityIndexSave")
print("Saved \(indexItems.count) index items to \(savedVectorStore.absoluteString)")

return savedVectorStore
}

public func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? {
func loadIndex(fromDirectory path: URL? = nil, name: String? = nil) throws -> [IndexItem]? {
if let indexPath = try getIndexPath(fromDirectory: path, name: name) {
indexItems = try vectorStore.loadIndex(from: indexPath)
return indexItems
Expand All @@ -355,7 +350,7 @@ extension SimilarityIndex {
/// - name: optional name
///
/// - Returns: an optional URL
public func getIndexPath(fromDirectory path: URL? = nil, name: String? = nil) throws -> URL? {
func getIndexPath(fromDirectory path: URL? = nil, name: String? = nil) throws -> URL? {
let indexName = name ?? self.indexName
let basePath: URL

Expand All @@ -382,7 +377,7 @@ extension SimilarityIndex {
return appSpecificDirectory
}

public func estimatedSizeInBytes() -> Int {
func estimatedSizeInBytes() -> Int {
var totalSize = 0

for item in indexItems {
Expand All @@ -397,7 +392,7 @@ extension SimilarityIndex {
let embeddingSize = item.embedding.count * floatSize

// Calculate the size of 'metadata' property
let metadataSize = item.metadata.reduce(0) { (size, keyValue) -> Int in
let metadataSize = item.metadata.reduce(0) { size, keyValue -> Int in
let keySize = keyValue.key.utf8.count
let valueSize = keyValue.value.utf8.count
return size + keySize + valueSize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import Foundation

public class JsonStore: VectorStoreProtocol {

public func saveIndex(items: [IndexItem], to url: URL, as name: String) throws -> URL {
let encoder = JSONEncoder()
let data = try encoder.encode(items)
Expand Down

0 comments on commit 9bec547

Please sign in to comment.