Skip to content

Commit

Permalink
[Fix] Release resources after use (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic authored Dec 22, 2023
1 parent 5d7fc25 commit 7fba798
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 180 deletions.
9 changes: 4 additions & 5 deletions erniebot/src/erniebot/backends/aistudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -85,23 +85,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
15 changes: 9 additions & 6 deletions erniebot/src/erniebot/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ class EBBackend(object):

def __init__(self, config_dict: ConfigDictType) -> None:
super().__init__()

self.api_type = self.API_TYPE
self.base_url = config_dict.get("api_base_url", None) or self.BASE_URL

self._cfg = config_dict
self._client = EBClient(self.handle_response, proxy=self._cfg.get("proxy", None))
self._client = EBClient(
self.base_url,
session=self._cfg.get("requests_session", None),
asession=self._cfg.get("aiohttp_session", None),
response_handler=self.handle_response,
proxy=self._cfg.get("proxy", None),
)

def handle_response(self, resp: EBResponse) -> EBResponse:
raise NotImplementedError
Expand All @@ -41,6 +45,7 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
Expand All @@ -53,12 +58,10 @@ async def arequest(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
raise NotImplementedError

def _get_url(self, path: str) -> str:
return f"{self.base_url}{path}"
18 changes: 8 additions & 10 deletions erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -71,7 +71,6 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)
except (errors.TokenExpiredError, errors.InvalidTokenError):
logging.warning(
Expand All @@ -88,23 +87,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down Expand Up @@ -160,15 +158,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -189,15 +187,15 @@ async def arequest(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
9 changes: 4 additions & 5 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -59,23 +59,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
23 changes: 20 additions & 3 deletions erniebot/src/erniebot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ def init_global_config() -> None:
# Miscellaneous settings
# Proxy to use
cfg.add_item(URLItem(key="proxy", env_key="EB_PROXY"))
# requests session
cfg.add_item(AnyObjectItem(key="requests_session"))
# aiohttp session
cfg.add_item(AnyObjectItem(key="aiohttp_session"))


class _BaseConfig(object):
class _Config(object):
def __init__(self, cfg_dict: Optional[Dict[str, "_ConfigItem"]] = None) -> None:
super().__init__()
self._cfg_dict: Dict[str, "_ConfigItem"] = cfg_dict if cfg_dict is not None else dict()
Expand All @@ -75,7 +79,7 @@ def set_value(self, key: str, value: Any) -> None:
cfg.value = value


class GlobalConfig(_BaseConfig, metaclass=Singleton):
class GlobalConfig(_Config, metaclass=Singleton):
def create_dict(self, **overrides: Any) -> ConfigDictType:
dict_: ConfigDictType = {}
for key, cfg in self._cfg_dict.items():
Expand Down Expand Up @@ -153,7 +157,8 @@ def __init__(
def factory(self, env_val: str) -> Any:
if self.ensure_integer:
return int(env_val)
return float(env_val)
else:
return float(env_val)

def _validate(self, val: Any) -> None:
if not isinstance(val, int if self.ensure_integer else numbers.Real):
Expand Down Expand Up @@ -191,3 +196,15 @@ def _validate(self, val: Any) -> None:
res = re.match(pat, val)
if res is None:
raise ValueError(f"Invalid URL: {val}")


class AnyObjectItem(_ConfigItem):
def __init__(self, key: str, default: Any = None) -> None:
super().__init__(key=key, env_key=None, default=default)

def factory(self, env_val: str) -> Any:
raise AssertionError

def _validate(self, val: Any) -> None:
# Any object is valid
pass
15 changes: 0 additions & 15 deletions erniebot/src/erniebot/functions/__init__.py

This file was deleted.

Loading

0 comments on commit 7fba798

Please sign in to comment.