Skip to content

Commit

Permalink
Improved handling of connections in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
linusseelinger committed Feb 22, 2023
1 parent 6bccf42 commit 805b9ad
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions umbridge/um.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
import asyncio
from concurrent.futures import ThreadPoolExecutor
import os

class Model(object):

Expand Down Expand Up @@ -44,6 +45,8 @@ def __init__(self, url, name):
if (name not in supported_models(url)):
raise Exception(f'Model {name} not supported by server! Supported models are: {supported_models(url)}')

self.__pid = -1

input = {}
input["name"] = name
response = requests.post(f"{self.url}/ModelInfo", json=input).json()
Expand All @@ -52,18 +55,25 @@ def __init__(self, url, name):
self.__supports_apply_jacobian = response["support"].get("ApplyJacobian", False)
self.__supports_apply_hessian = response["support"].get("ApplyHessian", False)

# This makes sure that the session is recreated if the process is forked (in particular when using multiprocessing)
def get_process_local_session(self):
if self.__pid != os.getpid():
self.__pid = os.getpid()
self.__session = requests.Session()
return self.__session

def get_input_sizes(self, config={}):
input = {}
input["name"] = self.name
input["config"] = config
response = requests.post(f"{self.url}/InputSizes", json=input).json()
response = self.get_process_local_session().post(f"{self.url}/InputSizes", json=input).json()
return response["inputSizes"]

def get_output_sizes(self, config={}):
input = {}
input["name"] = self.name
input["config"] = config
response = requests.post(f"{self.url}/OutputSizes", json=input).json()
response = self.get_process_local_session().post(f"{self.url}/OutputSizes", json=input).json()
return response["outputSizes"]

def supports_evaluate(self):
Expand Down Expand Up @@ -93,7 +103,7 @@ def __call__(self, parameters, config={}):
inputParams["name"] = self.name
inputParams["input"] = parameters
inputParams["config"] = config
response = requests.post(f"{self.url}/Evaluate", json=inputParams).json()
response = self.get_process_local_session().post(f"{self.url}/Evaluate", json=inputParams).json()

if "error" in response:
raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
Expand All @@ -111,7 +121,7 @@ def gradient(self, out_wrt, in_wrt, parameters, sens, config={}):
inputParams["input"] = parameters
inputParams["sens"] = sens
inputParams["config"] = config
response = requests.post(f"{self.url}/Gradient", json=inputParams).json()
response = self.get_process_local_session().post(f"{self.url}/Gradient", json=inputParams).json()

if "error" in response:
raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
Expand All @@ -129,7 +139,7 @@ def apply_jacobian(self, out_wrt, in_wrt, parameters, vec, config={}):
inputParams["input"] = parameters
inputParams["vec"] = vec
inputParams["config"] = config
response = requests.post(f"{self.url}/ApplyJacobian", json=inputParams).json()
response = self.get_process_local_session().post(f"{self.url}/ApplyJacobian", json=inputParams).json()

if "error" in response:
raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
Expand All @@ -149,7 +159,7 @@ def apply_hessian(self, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config
inputParams["sens"] = sens
inputParams["vec"] = vec
inputParams["config"] = config
response = requests.post(f"{self.url}/ApplyHessian", json=inputParams).json()
response = self.get_process_local_session().post(f"{self.url}/ApplyHessian", json=inputParams).json()

if "error" in response:
raise Exception(f'Model returned error of type {response["error"]["type"]}: {response["error"]["message"]}')
Expand Down

0 comments on commit 805b9ad

Please sign in to comment.