Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add multithreading support to OCR #24

Merged
merged 6 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import PackageDescription
let package = Package(
name: "macSubtitleOCR",
platforms: [
.macOS("13.0")
.macOS("14.0")
],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.5.0")
Expand Down
47 changes: 47 additions & 0 deletions Sources/macSubtitleOCR/FileHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//
// FileHandler.swift
// macSubtitleOCR
//
// Created by Ethan Dye on 10/16/24.
// Copyright © 2024 Ethan Dye. All rights reserved.
//

import Foundation

/// Persists OCR results to disk, producing one SRT and/or JSON file per track.
struct FileHandler {
    /// Directory into which all output files are written.
    let outputDirectory: String

    init(outputDirectory: String) {
        self.outputDirectory = outputDirectory
    }

    /// Writes the subtitles of `result` as `track_<n>.srt`, ordered by subtitle index.
    func saveSRTFile(for result: macSubtitleOCRResult) throws {
        let destination = outputURL(named: "track_\(result.trackNumber).srt")
        let orderedSubtitles = result.srt.sorted { $0.index < $1.index }
        SRT(subtitles: orderedSubtitles).write(toFileAt: destination)
    }

    /// Writes the OCR line data of `result` as pretty-printed JSON named `track_<n>.json`.
    /// - Throws: Any serialization or file-write error.
    func saveJSONFile(for result: macSubtitleOCRResult) throws {
        let payload: [[String: Any]] = result.json.sorted { $0.index < $1.index }.map { entry in
            let lineDicts: [[String: Any]] = entry.lines.map { line in
                [
                    "text": line.text,
                    "confidence": line.confidence,
                    "x": line.x,
                    "width": line.width,
                    "y": line.y,
                    "height": line.height
                ]
            }
            return [
                "image": entry.index,
                "lines": lineDicts,
                "text": entry.text
            ]
        }

        let data = try JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted, .sortedKeys])
        // Fall back to an empty JSON array if UTF-8 decoding somehow fails.
        let jsonString = String(data: data, encoding: .utf8) ?? "[]"
        try jsonString.write(to: outputURL(named: "track_\(result.trackNumber).json"), atomically: true, encoding: .utf8)
    }

    /// Builds a file URL for `fileName` inside the configured output directory.
    private func outputURL(named fileName: String) -> URL {
        URL(fileURLWithPath: outputDirectory).appendingPathComponent(fileName)
    }
}
223 changes: 223 additions & 0 deletions Sources/macSubtitleOCR/SubtitleProcessor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
//
// SubtitleProcessor.swift
// macSubtitleOCR
//
// Created by Ethan Dye on 10/16/24.
// Copyright © 2024 Ethan Dye. All rights reserved.
//

import CoreGraphics
import Foundation
import os
import UniformTypeIdentifiers
import Vision

// File-private logger shared by the OCR processing types in this file.
private let logger = Logger(subsystem: "github.ecdye.macSubtitleOCR", category: "SubtitleProcessor")

/// Serializes appends from concurrent OCR tasks so the result arrays
/// are never mutated from more than one task at a time.
actor SubtitleAccumulator {
    /// Recognized subtitles collected so far (unordered; callers sort by index).
    var subtitles: [Subtitle] = []
    /// Per-subtitle JSON results collected so far (unordered).
    var json: [SubtitleJSONResult] = []

    /// Records one finished subtitle.
    func appendSubtitle(_ subtitle: Subtitle) {
        subtitles.append(subtitle)
    }

    /// Records the JSON output for one finished subtitle.
    func appendJSON(_ jsonOut: SubtitleJSONResult) {
        json.append(jsonOut)
    }
}

/// A counting semaphore for Swift concurrency, used to bound the number of
/// simultaneously running OCR tasks.
///
/// The previous implementation busy-waited (`while permits <= 0 { await
/// Task.yield() }`), which burns CPU whenever tasks are blocked. This version
/// suspends waiters on continuations and resumes them in FIFO order from
/// `signal()`, so blocked tasks consume no CPU.
actor AsyncSemaphore {
    /// Number of permits currently available.
    private var permits: Int
    /// Tasks suspended in `wait()`, resumed oldest-first by `signal()`.
    private var waiters: [CheckedContinuation<Void, Never>] = []

    /// - Parameter limit: Maximum number of concurrent permit holders.
    init(limit: Int) {
        permits = limit
    }

    /// Acquires a permit, suspending (without spinning) until one is available.
    func wait() async {
        if permits > 0 {
            permits -= 1
            return
        }
        await withCheckedContinuation { continuation in
            waiters.append(continuation)
        }
    }

    /// Releases a permit, handing it directly to the oldest waiter if any.
    func signal() {
        if waiters.isEmpty {
            permits += 1
        } else {
            // Transfer the permit straight to a waiter; the count stays at 0.
            waiters.removeFirst().resume()
        }
    }
}

/// Performs OCR on a track's subtitle images using Vision, with bounded
/// concurrency, and accumulates both the SRT text and per-line JSON metadata.
struct SubtitleProcessor {
    let subtitles: [Subtitle]
    let trackNumber: Int
    let invert: Bool          // invert image colors before OCR
    let saveImages: Bool      // also dump each subtitle image as PNG
    let language: String      // comma-separated recognition language list
    let fastMode: Bool        // trade accuracy for speed in Vision
    let disableLanguageCorrection: Bool
    let forceOldAPI: Bool     // force the legacy VNRecognizeTextRequest path
    let outputDirectory: String
    let maxConcurrentTasks: Int // upper bound on simultaneous OCR tasks

    /// Runs OCR over every subtitle concurrently and returns the combined result.
    ///
    /// Tasks are throttled by an `AsyncSemaphore`; finished results are appended
    /// through an actor so the arrays are mutated safely. Completion order is
    /// nondeterministic — consumers sort by `index` before writing output.
    func process() async throws -> macSubtitleOCRResult {
        let accumulator = SubtitleAccumulator()
        let semaphore = AsyncSemaphore(limit: maxConcurrentTasks) // Limit concurrent tasks

        try await withThrowingDiscardingTaskGroup { group in
            for subtitle in subtitles {
                group.addTask {
                    // Wait for permission to start the task
                    await semaphore.wait()
                    let subIndex = subtitle.index

                    guard !shouldSkipSubtitle(subtitle, at: subIndex) else {
                        await semaphore.signal()
                        return
                    }

                    guard let subImage = subtitle.createImage(invert) else {
                        logger.warning("Could not create image for index \(subIndex)! Skipping...")
                        await semaphore.signal()
                        return
                    }

                    // Save subtitle image as PNG if requested; a save failure is
                    // logged but does not abort OCR for this subtitle.
                    if saveImages {
                        do {
                            try saveImage(subImage, index: subIndex)
                        } catch {
                            logger.error("Error saving image \(trackNumber)-\(subIndex): \(error.localizedDescription)")
                        }
                    }

                    let (subtitleText, subtitleLines) = await recognizeText(from: subImage, at: subIndex)
                    subtitle.text = subtitleText
                    subtitle.imageData = nil // Clear the image data to save memory

                    let jsonOut = SubtitleJSONResult(index: subIndex, lines: subtitleLines, text: subtitleText)

                    // Safely append to the arrays using the actor
                    await accumulator.appendSubtitle(subtitle)
                    await accumulator.appendJSON(jsonOut)
                    await semaphore.signal()
                }
            }
        }

        return await macSubtitleOCRResult(trackNumber: trackNumber, srt: accumulator.subtitles, json: accumulator.json)
    }

    // MARK: - Helpers

    /// Returns `true` when a subtitle carries no image data and must be skipped.
    private func shouldSkipSubtitle(_ subtitle: Subtitle, at index: Int) -> Bool {
        if subtitle.imageWidth == 0 || subtitle.imageHeight == 0 {
            logger.warning("Skipping subtitle index \(index) with empty image data!")
            return true
        }
        return false
    }

    /// Runs text recognition on `image`, preferring the modern Vision API
    /// (macOS 15+) unless `forceOldAPI` is set.
    /// - Returns: The recognized text (lines joined by "\n") and per-line
    ///   geometry/confidence records.
    private func recognizeText(from image: CGImage, at _: Int) async -> (String, [SubtitleLine]) {
        var text = ""
        var lines: [SubtitleLine] = []

        if !forceOldAPI, #available(macOS 15.0, *) {
            let request = createRecognizeTextRequest()
            let observations = try? await request.perform(on: image) as [RecognizedTextObservation]
            let size = CGSize(width: image.width, height: image.height)
            processRecognizedText(observations, &text, &lines, size)
        } else {
            let request = VNRecognizeTextRequest()
            request.recognitionLevel = getOCRMode()
            request.usesLanguageCorrection = !disableLanguageCorrection
            request.revision = VNRecognizeTextRequestRevision3
            request.recognitionLanguages = language.split(separator: ",").map { String($0) }

            try? VNImageRequestHandler(cgImage: image, options: [:]).perform([request])
            // `results` is nil when perform(_:) above failed (the error was
            // swallowed by `try?`); fall back to an empty observation list
            // instead of force-unwrapping, which previously crashed.
            let observations = request.results ?? []
            processVNRecognizedText(observations, &text, &lines, image.width, image.height)
        }

        return (text, lines)
    }

    /// Builds a configured modern Vision request from the processor's settings.
    @available(macOS 15.0, *)
    private func createRecognizeTextRequest() -> RecognizeTextRequest {
        var request = RecognizeTextRequest()
        request.recognitionLevel = getOCRMode()
        request.usesLanguageCorrection = !disableLanguageCorrection
        request.recognitionLanguages = language.split(separator: ",").map { Locale.Language(identifier: String($0)) }
        return request
    }

    /// Converts modern-API observations into joined text plus `SubtitleLine`s.
    /// Coordinates are flipped from Vision's space into top-left image space.
    @available(macOS 15.0, *)
    private func processRecognizedText(_ result: [RecognizedTextObservation]?, _ text: inout String,
                                       _ lines: inout [SubtitleLine], _ size: CGSize) {
        text = result?.compactMap { observation in
            guard let candidate = observation.topCandidates(1).first else { return "" }

            let string = candidate.string
            let confidence = candidate.confidence
            let stringRange = string.startIndex ..< string.endIndex
            // Keep the text but skip line geometry when Vision cannot produce a
            // bounding box, instead of force-unwrapping (previously crashed).
            guard let boundingBox = candidate.boundingBox(for: stringRange)?.boundingBox else { return string }
            let rect = boundingBox.toImageCoordinates(size, origin: .upperLeft)
            let line = SubtitleLine(
                text: string,
                confidence: confidence,
                x: max(0, Int(rect.minX)),
                width: Int(rect.size.width),
                y: max(0, Int(size.height - rect.minY - rect.size.height)),
                height: Int(rect.size.height))
            lines.append(line)

            return string
        }.joined(separator: "\n") ?? ""
    }

    /// Converts legacy-API observations into joined text plus `SubtitleLine`s.
    private func processVNRecognizedText(_ observations: [VNRecognizedTextObservation], _ text: inout String,
                                         _ lines: inout [SubtitleLine], _ width: Int, _ height: Int) {
        text = observations.compactMap { observation in
            guard let candidate = observation.topCandidates(1).first else { return "" }

            let string = candidate.string
            let confidence = candidate.confidence
            let stringRange = string.startIndex ..< string.endIndex
            let boundingBox = try? candidate.boundingBox(for: stringRange)?.boundingBox ?? .zero
            let rect = VNImageRectForNormalizedRect(boundingBox ?? .zero, width, height)

            let line = SubtitleLine(
                text: string,
                confidence: confidence,
                x: max(0, Int(rect.minX)),
                width: Int(rect.size.width),
                y: max(0, Int(CGFloat(height) - rect.minY - rect.size.height)),
                height: Int(rect.size.height))
            lines.append(line)

            return string
        }.joined(separator: "\n")
    }

    /// Writes `image` as `images/track_<track>/subtitle_<index>.png` inside the
    /// output directory, creating intermediate directories as needed.
    /// - Throws: `macSubtitleOCRError.fileCreationError` / `.fileWriteError`.
    private func saveImage(_ image: CGImage, index: Int) throws {
        let outputDirectory = URL(fileURLWithPath: outputDirectory)
        let imageDirectory = outputDirectory.appendingPathComponent("images/" + "track_\(trackNumber)/")
        let pngPath = imageDirectory.appendingPathComponent("subtitle_\(index).png")

        try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true, attributes: nil)

        let destination = CGImageDestinationCreateWithURL(pngPath as CFURL, UTType.png.identifier as CFString, 1, nil)
        guard let destination else {
            throw macSubtitleOCRError.fileCreationError
        }
        CGImageDestinationAddImage(destination, image, nil)
        guard CGImageDestinationFinalize(destination) else {
            throw macSubtitleOCRError.fileWriteError
        }
    }

    /// Maps `fastMode` onto the modern Vision recognition level.
    @available(macOS 15.0, *)
    private func getOCRMode() -> RecognizeTextRequest.RecognitionLevel {
        fastMode ? .fast : .accurate
    }

    /// Maps `fastMode` onto the legacy Vision recognition level.
    private func getOCRMode() -> VNRequestTextRecognitionLevel {
        fastMode ? .fast : .accurate
    }
}
6 changes: 3 additions & 3 deletions Sources/macSubtitleOCR/Subtitles/FFmpeg/FFmpeg.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct FFmpeg {
var trackSubtitles = subtitleTracks[streamNumber] ?? []
for i in 0 ..< Int(subtitle.num_rects) {
let rect = subtitle.rects[i]!
let sub = extractImageData(from: rect)
let sub = extractImageData(from: rect, index: trackSubtitles.count + 1)
let pts = convertToTimeInterval(packet!.pointee.pts, timeBase: stream.timeBase)
sub.startTimestamp = pts + convertToTimeInterval(subtitle.start_display_time, timeBase: timeBase)
sub.endTimestamp = pts + convertToTimeInterval(subtitle.end_display_time, timeBase: timeBase)
Expand All @@ -102,8 +102,8 @@ struct FFmpeg {
}
}

private func extractImageData(from rect: UnsafeMutablePointer<AVSubtitleRect>) -> Subtitle {
let subtitle = Subtitle(numberOfColors: Int(rect.pointee.nb_colors))
private func extractImageData(from rect: UnsafeMutablePointer<AVSubtitleRect>, index: Int) -> Subtitle {
let subtitle = Subtitle(index: index, numberOfColors: Int(rect.pointee.nb_colors))

// Check if the subtitle is an image (bitmap)
if rect.pointee.type == SUBTITLE_BITMAP {
Expand Down
3 changes: 2 additions & 1 deletion Sources/macSubtitleOCR/Subtitles/PGS/PGS.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct PGS {
// MARK: - Properties

private(set) var subtitles = [Subtitle]()
private let logger: Logger = .init(subsystem: "github.ecdye.macSubtitleOCR", category: "PGS")
private let logger = Logger(subsystem: "github.ecdye.macSubtitleOCR", category: "PGS")
private var data: Data
private let pgsHeaderLength = 13

Expand Down Expand Up @@ -120,6 +120,7 @@ struct PGS {
guard let pds, let ods else { continue }
let startTimestamp = parseTimestamp(headerData)
return Subtitle(
index: subtitles.count + 1,
startTimestamp: startTimestamp,
imageWidth: ods.objectWidth,
imageHeight: ods.objectHeight,
Expand Down
18 changes: 16 additions & 2 deletions Sources/macSubtitleOCR/Subtitles/SRT/SRT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
//

import Foundation
import os

struct SRT {
// MARK: - Properties

private var subtitles: [Subtitle] = []
private let logger = Logger(subsystem: "github.ecdye.macSubtitleOCR", category: "SRT")

// MARK: - Getters / Setters

Expand All @@ -38,10 +40,22 @@ struct SRT {
var srtContent = ""

for subtitle in subtitles {
var endTimestamp = subtitle.endTimestamp ?? 0
if subtitle.index + 1 < subtitles.count {
let nextSubtitle = subtitles[subtitle.index + 1]
if nextSubtitle.startTimestamp! <= subtitle.endTimestamp! {
logger.warning("Fixing subtitle index \(subtitle.index) end timestamp!")
if nextSubtitle.startTimestamp! - subtitle.startTimestamp! > 5 {
endTimestamp = subtitle.startTimestamp! + 5
} else {
endTimestamp = nextSubtitle.startTimestamp! - 0.1
}
}
}
let startTime = formatTime(subtitle.startTimestamp!)
let endTime = formatTime(subtitle.endTimestamp!)
let endTime = formatTime(endTimestamp)

srtContent += "\(subtitle.index!)\n"
srtContent += "\(subtitle.index)\n"
srtContent += "\(startTime) --> \(endTime)\n"
srtContent += "\(subtitle.text!)\n\n"
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/macSubtitleOCR/Subtitles/Subtitle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import CoreGraphics
import Foundation

class Subtitle {
var index: Int?
class Subtitle: @unchecked Sendable {
var index: Int
var text: String?
var startTimestamp: TimeInterval?
var imageXOffset: Int?
Expand All @@ -25,7 +25,7 @@ class Subtitle {
var evenOffset: Int?
var oddOffset: Int?

init(index: Int? = nil, text: String? = nil, startTimestamp: TimeInterval? = nil, endTimestamp: TimeInterval? = nil,
init(index: Int, text: String? = nil, startTimestamp: TimeInterval? = nil, endTimestamp: TimeInterval? = nil,
imageXOffset: Int? = nil, imageYOffset: Int? = nil, imageWidth: Int? = nil, imageHeight: Int? = nil,
imageData: Data? = nil, imagePalette: [UInt8]? = nil, imageAlpha: [UInt8]? = nil, numberOfColors: Int? = nil,
evenOffset: Int? = nil, oddOffset: Int? = nil) {
Expand Down
1 change: 1 addition & 0 deletions Sources/macSubtitleOCR/Subtitles/VobSub/VobSub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct VobSub {
subFile.seekToEndOfFile()
}
let subtitle = VobSubParser(
index: index + 1,
subFile: subFile,
timestamp: timestamp,
offset: offset,
Expand Down
Loading
Loading