Skip to content

Commit

Permalink
Chore: Make release 1.0.41
Browse files Browse the repository at this point in the history
  • Loading branch information
martinroberson authored and Vanden Bon, David V [GBM Public] committed Sep 29, 2023
1 parent d0437a4 commit e313d80
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 25 deletions.
47 changes: 47 additions & 0 deletions gs_quant/api/api_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Copyright 2023 Goldman Sachs.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""
from typing import Callable, Optional

from gs_quant.session import GsSession


class ApiWithCustomSession:
__SESSION_SUPPLIER: Optional[Callable[[], GsSession]] = None

@classmethod
def set_session_provider(cls, session_supplier: Callable[[], GsSession]):
"""
To allow session context override specific to this API, set a factory/supplier.
Default is GsSession.current.
:param session_supplier: callable which returns a GsSession
"""
cls.__SESSION_SUPPLIER = session_supplier

@classmethod
def set_session(cls, session: GsSession):
"""
To allow session context override specific to this API, set a session directly.
Default is GsSession.current.
:param session: a GsSession
"""
cls.__SESSION_SUPPLIER = None if session is None else lambda: session

@classmethod
def get_session(cls) -> GsSession:
if cls.__SESSION_SUPPLIER:
return cls.__SESSION_SUPPLIER()
else:
return GsSession.current
14 changes: 7 additions & 7 deletions gs_quant/api/gs/risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from gs_quant.api.risk import RiskApi
from gs_quant.risk import RiskRequest
from gs_quant.session import GsSession
from gs_quant.target.risk import OptimizationRequest
from gs_quant.tracing import Tracer

Expand Down Expand Up @@ -63,7 +62,7 @@ def calc(cls, request: RiskRequest) -> Iterable:
def _exec(cls, request: Union[RiskRequest, Iterable[RiskRequest]]) -> Union[Iterable, dict]:
use_msgpack = cls.USE_MSGPACK and not isinstance(request, RiskRequest)
headers = {'Content-Type': 'application/x-msgpack'} if use_msgpack else {}
result, request_id = GsSession.current._post(cls.__url(request),
result, request_id = cls.get_session()._post(cls.__url(request),
request,
request_headers=headers,
timeout=181,
Expand Down Expand Up @@ -122,7 +121,7 @@ async def __get_results_poll(cls, responses: asyncio.Queue, results: asyncio.Que
# ... poll for completed requests ...

try:
calc_results = GsSession.current._post('/risk/calculate/results/bulk', list(pending_requests.keys()))
calc_results = cls.get_session()._post('/risk/calculate/results/bulk', list(pending_requests.keys()))

# ... enqueue the request and result for the listener to handle ...
for result in calc_results:
Expand Down Expand Up @@ -236,10 +235,11 @@ async def handle_websocket():

try:
ws_url = '/risk/calculate/results/subscribe'
async with GsSession.current._connect_websocket(ws_url) as ws:
async with cls.get_session()._connect_websocket(ws_url) as ws:
if span:
Tracer.get_instance().scope_manager.activate(span, finish_on_close=False)
with Tracer(f'wss:/{ws_url}'):
with Tracer(f'wss:/{ws_url}') as scope:
scope.span.set_tag('wss.host', ws.request_headers.get('host'))
error = await handle_websocket()
else:
error = await handle_websocket()
Expand All @@ -265,7 +265,7 @@ async def handle_websocket():
@classmethod
def create_pretrade_execution_optimization(cls, request: OptimizationRequest) -> str:
try:
response = GsSession.current._post(r'/risk/execution/pretrade', request)
response = cls.get_session()._post(r'/risk/execution/pretrade', request)
_logger.info('New optimization is created with id: {}'.format(response.get("optimizationId")))
return response
except Exception as e:
Expand All @@ -285,7 +285,7 @@ def get_pretrade_execution_optimization(cls, optimization_id: str, max_attempts:
time.sleep(math.pow(2, attempts))
_logger.error('Retrying (attempt {} of {})'.format(attempts, max_attempts))
try:
results = GsSession.current._get(url)
results = cls.get_session()._get(url)
if results.get('status') == 'Running':
attempts += 1
else:
Expand Down
20 changes: 4 additions & 16 deletions gs_quant/api/risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from abc import ABCMeta, abstractmethod
from concurrent.futures import TimeoutError
from threading import Thread
from typing import Iterable, Optional, Union, Tuple, Callable
from typing import Iterable, Optional, Union, Tuple

from opentracing import Span
from tqdm import tqdm

from gs_quant.api.api_session import ApiWithCustomSession
from gs_quant.base import RiskKey, Sentinel
from gs_quant.risk import ErrorValue, RiskRequest
from gs_quant.risk.result_handlers import result_handlers
Expand All @@ -35,9 +36,8 @@
_logger = logging.getLogger(__name__)


class RiskApi(metaclass=ABCMeta):
class RiskApi(ApiWithCustomSession, metaclass=ABCMeta):
__SHUTDOWN_SENTINEL = Sentinel('QueueListenerShutdown')
__SESSION_SUPPLIER: Optional[Callable[[], GsSession]] = None

@classmethod
@abstractmethod
Expand Down Expand Up @@ -185,10 +185,7 @@ def num_risk_keys(request: RiskRequest):
results_handler = None

# determine session to use
if cls.__SESSION_SUPPLIER:
session = cls.__SESSION_SUPPLIER()
else:
session = GsSession.current
session = cls.get_session()

# The requests library (which we use for dispatching) is not async, so we need a thread for concurrency
Thread(daemon=True,
Expand Down Expand Up @@ -321,12 +318,3 @@ def _handle_results(cls, request: RiskRequest, results: Union[Iterable, Exceptio
formatted_results[(risk_key, position.instrument)] = result

return formatted_results

@classmethod
def set_session_provider(cls, session_supplier: Callable[[], GsSession]):
"""
To allow session context override specific to this API, set a factory/supplier.
Default is GsSession.current.
:param session_supplier: callable which returns a GsSession
"""
cls.__SESSION_SUPPLIER = session_supplier
3 changes: 1 addition & 2 deletions gs_quant/risk/result_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
sort_values, MQVSValidatorDefnsWithInfo, MQVSValidatorDefn

_logger = logging.getLogger(__name__)
__scalar_risk_measures = ('EqDelta', 'EqGamma', 'EqVega')


def __dataframe_handler(result: Iterable, mappings: tuple, risk_key: RiskKey, request_id: Optional[str] = None) \
Expand Down Expand Up @@ -187,7 +186,7 @@ def risk_vector_handler(result: dict, risk_key: RiskKey, _instrument: Instrument
request_id: Optional[str] = None) -> DataFrameWithInfo:
assets = result['asset']
# Handle equity risk measures which are really scalars
if len(assets) == 1 and risk_key.risk_measure.name in __scalar_risk_measures:
if len(assets) == 1 and risk_key.risk_measure.name.startswith('Eq'):
return FloatWithInfo(risk_key, assets[0], request_id=request_id)

for points, value in zip(result['points'], assets):
Expand Down

0 comments on commit e313d80

Please sign in to comment.