Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to optimize proxy_func #507

Merged
merged 4 commits into from
May 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions wgpu/backends/wgpu_native/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import sys
import ctypes
from queue import deque

from ._ffi import ffi, lib
from ..._diagnostics import DiagnosticsBase
Expand Down Expand Up @@ -237,29 +238,38 @@ class ErrorHandler:

def __init__(self, logger):
self._logger = logger
self._proxy_stack = []
self._proxy_stack = deque()
self._proxy_messages = {}
self._error_message_counts = {}

def capture(self, func):
"""Send incoming error messages to the given func instead of logging them."""
self._proxy_stack.append(func)
def capture(self, name):
"""Capture incoming error messages instead of logging them directly."""
self._proxy_stack.append(name)

def release(self, func):
"""Release the given func."""
f = self._proxy_stack.pop(-1)
if f is not func:
def release(self, name):
"""Release the given name, returning the last captured error."""
n = self._proxy_stack.pop()
if n is not name:
messages = [m for _, m in self._proxy_message.values()]
self._proxy_messages.clear()
self._proxy_stack.clear()
self._logger.warning("ErrorHandler capture/release out of sync")
self._logger.error("ErrorHandler capture/release out of sync")
for message in messages:
self.log_error(message)
return self._proxy_messages.pop(name, None)

def handle_error(self, error_type: str, message: str):
"""Handle an error message."""
if self._proxy_stack:
self._proxy_stack[-1](error_type, message)
proxy_name = self._proxy_stack[-1]
if proxy_name in self._proxy_messages:
self.log_error(self._proxy_messages[proxy_name][1])
self._proxy_messages[proxy_name] = error_type, message
else:
self.log_error(message)

def log_error(self, message):
"""Hanle an error message by logging it, bypassing any capturing."""
"""Handle an error message by logging it, bypassing any capturing."""
# Get count for this message. Use a hash that does not use the
# digits in the message, because of id's getting renewed on
# each draw.
Expand All @@ -283,7 +293,6 @@ class SafeLibCalls:

def __init__(self, lib, error_handler):
self._error_handler = error_handler
self._error_message = None
self._make_function_copies(lib)

def _make_function_copies(self, lib):
Expand All @@ -293,27 +302,20 @@ def _make_function_copies(self, lib):
if callable(ob):
setattr(self, name, self._make_proxy_func(name, ob))

def _handle_error(self, error_type, message):
# If we already had an error, we log the earlier one now
if self._error_message:
self._error_handler.log_error(self._error_message[1])
# Store new error
self._error_message = (error_type, message)

def _make_proxy_func(self, name, ob):
error_handler = self._error_handler

def proxy_func(*args):
# Make the call, with error capturing on
handle_error = self._handle_error
self._error_handler.capture(handle_error)
error_handler.capture(name)
try:
result = ob(*args)
finally:
self._error_handler.release(handle_error)
error_type_msg = error_handler.release(name)

# Handle the error.
if self._error_message:
error_type, message = self._error_message
self._error_message = None
if error_type_msg is not None:
error_type, message = error_type_msg
cls = ERROR_TYPES.get(error_type, GPUError)
wgpu_error = cls(message)
# The line below will be the bottom line in the traceback,
Expand Down
Loading