Add multithreading support to OCR (#24)
This adds support for running the OCR on multiple threads to improve
performance.
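
A distilled sketch of the pattern this commit adopts (illustrative only: AsyncSemaphore is defined in SubtitleProcessor.swift below, and the comment stands in for the real per-subtitle OCR work):

    func processAll(_ subtitles: [Subtitle], maxConcurrentTasks: Int) async throws {
        let semaphore = AsyncSemaphore(limit: maxConcurrentTasks)
        try await withThrowingDiscardingTaskGroup { group in
            for subtitle in subtitles {
                group.addTask {
                    await semaphore.wait()   // acquire a permit before starting
                    // ... run OCR for this subtitle ...
                    await semaphore.signal() // release the permit for the next task
                }
            }
        }
    }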

Signed-off-by: Ethan Dye <[email protected]>
ecdye authored Oct 17, 2024
1 parent 663d585 commit 7af4566
Showing 12 changed files with 393 additions and 216 deletions.
2 changes: 1 addition & 1 deletion Package.swift
@@ -5,7 +5,7 @@ import PackageDescription
 let package = Package(
     name: "macSubtitleOCR",
     platforms: [
-        .macOS("13.0")
+        .macOS("14.0")
     ],
     dependencies: [
         .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.5.0")
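
(The platforms bump to macOS 14 lines up with the new concurrency code: withThrowingDiscardingTaskGroup, used in SubtitleProcessor.swift below, requires macOS 14.0 or later.)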
47 changes: 47 additions & 0 deletions Sources/macSubtitleOCR/FileHandler.swift
@@ -0,0 +1,47 @@
//
// FileHandler.swift
// macSubtitleOCR
//
// Created by Ethan Dye on 10/16/24.
// Copyright © 2024 Ethan Dye. All rights reserved.
//

import Foundation

struct FileHandler {
    let outputDirectory: String

    init(outputDirectory: String) {
        self.outputDirectory = outputDirectory
    }

    func saveSRTFile(for result: macSubtitleOCRResult) throws {
        let srtFilePath = URL(fileURLWithPath: outputDirectory).appendingPathComponent("track_\(result.trackNumber).srt")
        let srt = SRT(subtitles: result.srt.sorted { $0.index < $1.index })
        srt.write(toFileAt: srtFilePath)
    }

    func saveJSONFile(for result: macSubtitleOCRResult) throws {
        let jsonResults = result.json.sorted { $0.index < $1.index }.map { jsonResult in
            [
                "image": jsonResult.index,
                "lines": jsonResult.lines.map { line in
                    [
                        "text": line.text,
                        "confidence": line.confidence,
                        "x": line.x,
                        "width": line.width,
                        "y": line.y,
                        "height": line.height
                    ] as [String: Any]
                },
                "text": jsonResult.text
            ] as [String: Any]
        }

        let jsonData = try JSONSerialization.data(withJSONObject: jsonResults, options: [.prettyPrinted, .sortedKeys])
        let jsonString = String(data: jsonData, encoding: .utf8) ?? "[]"
        let jsonFilePath = URL(fileURLWithPath: outputDirectory).appendingPathComponent("track_\(result.trackNumber).json")
        try jsonString.write(to: jsonFilePath, atomically: true, encoding: .utf8)
    }
}
223 changes: 223 additions & 0 deletions Sources/macSubtitleOCR/SubtitleProcessor.swift
@@ -0,0 +1,223 @@
//
// SubtitleProcessor.swift
// macSubtitleOCR
//
// Created by Ethan Dye on 10/16/24.
// Copyright © 2024 Ethan Dye. All rights reserved.
//

import CoreGraphics
import Foundation
import os
import UniformTypeIdentifiers
import Vision

private let logger = Logger(subsystem: "github.ecdye.macSubtitleOCR", category: "SubtitleProcessor")

actor SubtitleAccumulator {
    var subtitles: [Subtitle] = []
    var json: [SubtitleJSONResult] = []

    func appendSubtitle(_ subtitle: Subtitle) {
        subtitles.append(subtitle)
    }

    func appendJSON(_ jsonOut: SubtitleJSONResult) {
        json.append(jsonOut)
    }
}

actor AsyncSemaphore {
    private var permits: Int

    init(limit: Int) {
        permits = limit
    }

    func wait() async {
        while permits <= 0 {
            await Task.yield()
        }
        permits -= 1
    }

    func signal() {
        permits += 1
    }
}
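
// Editorial note: this semaphore is poll-based; wait() spins on Task.yield()
// until a permit frees up. That is simple and data-race-free under the actor,
// but it can burn scheduler time when many tasks wait at once. A variant that
// parks waiters as CheckedContinuations and resumes one in signal() would
// suspend instead of polling.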

struct SubtitleProcessor {
    let subtitles: [Subtitle]
    let trackNumber: Int
    let invert: Bool
    let saveImages: Bool
    let language: String
    let fastMode: Bool
    let disableLanguageCorrection: Bool
    let forceOldAPI: Bool
    let outputDirectory: String
    let maxConcurrentTasks: Int

    func process() async throws -> macSubtitleOCRResult {
        let accumulator = SubtitleAccumulator()
        let semaphore = AsyncSemaphore(limit: maxConcurrentTasks) // Limit concurrent tasks

        try await withThrowingDiscardingTaskGroup { group in
            for subtitle in subtitles {
                group.addTask {
                    // Wait for permission to start the task
                    await semaphore.wait()
                    let subIndex = subtitle.index

                    guard !shouldSkipSubtitle(subtitle, at: subIndex) else {
                        await semaphore.signal()
                        return
                    }

                    guard let subImage = subtitle.createImage(invert) else {
                        logger.warning("Could not create image for index \(subIndex)! Skipping...")
                        await semaphore.signal()
                        return
                    }

                    // Save subtitle image as PNG if requested
                    if saveImages {
                        do {
                            try saveImage(subImage, index: subIndex)
                        } catch {
                            logger.error("Error saving image \(trackNumber)-\(subIndex): \(error.localizedDescription)")
                        }
                    }

                    let (subtitleText, subtitleLines) = await recognizeText(from: subImage, at: subIndex)
                    subtitle.text = subtitleText
                    subtitle.imageData = nil // Clear the image data to save memory

                    let jsonOut = SubtitleJSONResult(index: subIndex, lines: subtitleLines, text: subtitleText)

                    // Safely append to the arrays using the actor
                    await accumulator.appendSubtitle(subtitle)
                    await accumulator.appendJSON(jsonOut)
                    await semaphore.signal()
                }
            }
        }

        return await macSubtitleOCRResult(trackNumber: trackNumber, srt: accumulator.subtitles, json: accumulator.json)
    }

    private func shouldSkipSubtitle(_ subtitle: Subtitle, at index: Int) -> Bool {
        if subtitle.imageWidth == 0 || subtitle.imageHeight == 0 {
            logger.warning("Skipping subtitle index \(index) with empty image data!")
            return true
        }
        return false
    }

    private func recognizeText(from image: CGImage, at _: Int) async -> (String, [SubtitleLine]) {
        var text = ""
        var lines: [SubtitleLine] = []

        if !forceOldAPI, #available(macOS 15.0, *) {
            let request = createRecognizeTextRequest()
            let observations = try? await request.perform(on: image) as [RecognizedTextObservation]
            let size = CGSize(width: image.width, height: image.height)
            processRecognizedText(observations, &text, &lines, size)
        } else {
            let request = VNRecognizeTextRequest()
            request.recognitionLevel = getOCRMode()
            request.usesLanguageCorrection = !disableLanguageCorrection
            request.revision = VNRecognizeTextRequestRevision3
            request.recognitionLanguages = language.split(separator: ",").map { String($0) }

            try? VNImageRequestHandler(cgImage: image, options: [:]).perform([request])
            let observations = request.results! as [VNRecognizedTextObservation]
            processVNRecognizedText(observations, &text, &lines, image.width, image.height)
        }

        return (text, lines)
    }
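
    // Editorial note: recognizeText(from:at:) prefers the newer Swift Vision
    // API (RecognizeTextRequest, macOS 15+) and otherwise falls back to the
    // legacy VNRecognizeTextRequest path; forceOldAPI presumably exists so the
    // fallback can also be exercised on newer systems.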

    @available(macOS 15.0, *)
    private func createRecognizeTextRequest() -> RecognizeTextRequest {
        var request = RecognizeTextRequest()
        request.recognitionLevel = getOCRMode()
        request.usesLanguageCorrection = !disableLanguageCorrection
        request.recognitionLanguages = language.split(separator: ",").map { Locale.Language(identifier: String($0)) }
        return request
    }

    @available(macOS 15.0, *)
    private func processRecognizedText(_ result: [RecognizedTextObservation]?, _ text: inout String,
                                       _ lines: inout [SubtitleLine], _ size: CGSize) {
        text = result?.compactMap { observation in
            guard let candidate = observation.topCandidates(1).first else { return "" }

            let string = candidate.string
            let confidence = candidate.confidence
            let stringRange = string.startIndex ..< string.endIndex
            let boundingBox = candidate.boundingBox(for: stringRange)!.boundingBox
            let rect = boundingBox.toImageCoordinates(size, origin: .upperLeft)
            let line = SubtitleLine(
                text: string,
                confidence: confidence,
                x: max(0, Int(rect.minX)),
                width: Int(rect.size.width),
                y: max(0, Int(size.height - rect.minY - rect.size.height)),
                height: Int(rect.size.height))
            lines.append(line)

            return string
        }.joined(separator: "\n") ?? ""
    }

    private func processVNRecognizedText(_ observations: [VNRecognizedTextObservation], _ text: inout String,
                                         _ lines: inout [SubtitleLine], _ width: Int, _ height: Int) {
        text = observations.compactMap { observation in
            guard let candidate = observation.topCandidates(1).first else { return "" }

            let string = candidate.string
            let confidence = candidate.confidence
            let stringRange = string.startIndex ..< string.endIndex
            let boundingBox = try? candidate.boundingBox(for: stringRange)?.boundingBox ?? .zero
            let rect = VNImageRectForNormalizedRect(boundingBox ?? .zero, width, height)

            let line = SubtitleLine(
                text: string,
                confidence: confidence,
                x: max(0, Int(rect.minX)),
                width: Int(rect.size.width),
                y: max(0, Int(CGFloat(height) - rect.minY - rect.size.height)),
                height: Int(rect.size.height))
            lines.append(line)

            return string
        }.joined(separator: "\n")
    }

    private func saveImage(_ image: CGImage, index: Int) throws {
        let outputDirectory = URL(fileURLWithPath: outputDirectory)
        let imageDirectory = outputDirectory.appendingPathComponent("images/" + "track_\(trackNumber)/")
        let pngPath = imageDirectory.appendingPathComponent("subtitle_\(index).png")

        try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true, attributes: nil)

        let destination = CGImageDestinationCreateWithURL(pngPath as CFURL, UTType.png.identifier as CFString, 1, nil)
        guard let destination else {
            throw macSubtitleOCRError.fileCreationError
        }
        CGImageDestinationAddImage(destination, image, nil)
        guard CGImageDestinationFinalize(destination) else {
            throw macSubtitleOCRError.fileWriteError
        }
    }

    @available(macOS 15.0, *)
    private func getOCRMode() -> RecognizeTextRequest.RecognitionLevel {
        fastMode ? .fast : .accurate
    }

    private func getOCRMode() -> VNRequestTextRecognitionLevel {
        fastMode ? .fast : .accurate
    }
}
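
How the new pieces fit together, as a hypothetical call site (the values and the surrounding CLI wiring are illustrative, not part of this diff):

    let processor = SubtitleProcessor(
        subtitles: subtitles, // decoded by the PGS/VobSub/FFmpeg code below
        trackNumber: 0,
        invert: false,
        saveImages: false,
        language: "en-US",
        fastMode: false,
        disableLanguageCorrection: false,
        forceOldAPI: false,
        outputDirectory: "/tmp/subtitles",
        maxConcurrentTasks: ProcessInfo.processInfo.activeProcessorCount)
    let result = try await processor.process()
    let fileHandler = FileHandler(outputDirectory: "/tmp/subtitles")
    try fileHandler.saveSRTFile(for: result)
    try fileHandler.saveJSONFile(for: result)
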
6 changes: 3 additions & 3 deletions Sources/macSubtitleOCR/Subtitles/FFmpeg/FFmpeg.swift
@@ -90,7 +90,7 @@ struct FFmpeg
         var trackSubtitles = subtitleTracks[streamNumber] ?? []
         for i in 0 ..< Int(subtitle.num_rects) {
             let rect = subtitle.rects[i]!
-            let sub = extractImageData(from: rect)
+            let sub = extractImageData(from: rect, index: trackSubtitles.count + 1)
             let pts = convertToTimeInterval(packet!.pointee.pts, timeBase: stream.timeBase)
             sub.startTimestamp = pts + convertToTimeInterval(subtitle.start_display_time, timeBase: timeBase)
             sub.endTimestamp = pts + convertToTimeInterval(subtitle.end_display_time, timeBase: timeBase)
@@ -102,8 +102,8 @@
         }
     }

-    private func extractImageData(from rect: UnsafeMutablePointer<AVSubtitleRect>) -> Subtitle {
-        let subtitle = Subtitle(numberOfColors: Int(rect.pointee.nb_colors))
+    private func extractImageData(from rect: UnsafeMutablePointer<AVSubtitleRect>, index: Int) -> Subtitle {
+        let subtitle = Subtitle(index: index, numberOfColors: Int(rect.pointee.nb_colors))

         // Check if the subtitle is an image (bitmap)
         if rect.pointee.type == SUBTITLE_BITMAP {
1 change: 1 addition & 0 deletions Sources/macSubtitleOCR/Subtitles/PGS/PGS.swift
@@ -120,6 +120,7 @@ struct PGS
         guard let pds, let ods else { continue }
         let startTimestamp = parseTimestamp(headerData)
         return Subtitle(
+            index: subtitles.count + 1,
             startTimestamp: startTimestamp,
             imageWidth: ods.objectWidth,
             imageHeight: ods.objectHeight,
18 changes: 16 additions & 2 deletions Sources/macSubtitleOCR/Subtitles/SRT/SRT.swift
@@ -7,11 +7,13 @@
 //

 import Foundation
+import os

 struct SRT {
     // MARK: - Properties

     private var subtitles: [Subtitle] = []
+    private let logger = Logger(subsystem: "github.ecdye.macSubtitleOCR", category: "SRT")

     // MARK: - Getters / Setters

@@ -38,10 +38,22 @@ struct SRT
         var srtContent = ""

         for subtitle in subtitles {
+            var endTimestamp = subtitle.endTimestamp ?? 0
+            if subtitle.index + 1 < subtitles.count {
+                let nextSubtitle = subtitles[subtitle.index + 1]
+                if nextSubtitle.startTimestamp! <= subtitle.endTimestamp! {
+                    logger.warning("Fixing subtitle index \(subtitle.index) end timestamp!")
+                    if nextSubtitle.startTimestamp! - subtitle.startTimestamp! > 5 {
+                        endTimestamp = subtitle.startTimestamp! + 5
+                    } else {
+                        endTimestamp = nextSubtitle.startTimestamp! - 0.1
+                    }
+                }
+            }
             let startTime = formatTime(subtitle.startTimestamp!)
-            let endTime = formatTime(subtitle.endTimestamp!)
+            let endTime = formatTime(endTimestamp)

-            srtContent += "\(subtitle.index!)\n"
+            srtContent += "\(subtitle.index)\n"
             srtContent += "\(startTime) --> \(endTime)\n"
             srtContent += "\(subtitle.text!)\n\n"
         }
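
A worked example of the overlap fix above, with times in seconds: if a subtitle starts at 10.0 and the next one starts at 12.0 before it ends, the gap from the current start is 2.0 (at most 5), so the end timestamp is pulled back to 11.9, the next start minus 0.1. If the next one instead started at 16.0, the 6.0 gap exceeds 5 and the end is capped at 15.0, the current start plus 5.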
6 changes: 3 additions & 3 deletions Sources/macSubtitleOCR/Subtitles/Subtitle.swift
@@ -9,8 +9,8 @@
 import CoreGraphics
 import Foundation

-class Subtitle {
-    var index: Int?
+class Subtitle: @unchecked Sendable {
+    var index: Int
     var text: String?
     var startTimestamp: TimeInterval?
     var imageXOffset: Int?
@@ -25,7 +25,7 @@ class Subtitle
     var evenOffset: Int?
     var oddOffset: Int?

-    init(index: Int? = nil, text: String? = nil, startTimestamp: TimeInterval? = nil, endTimestamp: TimeInterval? = nil,
+    init(index: Int, text: String? = nil, startTimestamp: TimeInterval? = nil, endTimestamp: TimeInterval? = nil,
         imageXOffset: Int? = nil, imageYOffset: Int? = nil, imageWidth: Int? = nil, imageHeight: Int? = nil,
         imageData: Data? = nil, imagePalette: [UInt8]? = nil, imageAlpha: [UInt8]? = nil, numberOfColors: Int? = nil,
         evenOffset: Int? = nil, oddOffset: Int? = nil) {
1 change: 1 addition & 0 deletions Sources/macSubtitleOCR/Subtitles/VobSub/VobSub.swift
@@ -37,6 +37,7 @@ struct VobSub
             subFile.seekToEndOfFile()
         }
         let subtitle = VobSubParser(
+            index: index + 1,
             subFile: subFile,
             timestamp: timestamp,
             offset: offset,