Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Release resources after use #133

Merged
merged 11 commits into from
Dec 22, 2023
9 changes: 4 additions & 5 deletions erniebot/src/erniebot/backends/aistudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -85,23 +85,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
15 changes: 9 additions & 6 deletions erniebot/src/erniebot/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ class EBBackend(object):

def __init__(self, config_dict: ConfigDictType) -> None:
super().__init__()

self.api_type = self.API_TYPE
self.base_url = config_dict.get("api_base_url", None) or self.BASE_URL

self._cfg = config_dict
self._client = EBClient(self.handle_response, proxy=self._cfg.get("proxy", None))
self._client = EBClient(
self.base_url,
session=self._cfg.get("requests_session", None),
asession=self._cfg.get("aiohttp_session", None),
response_handler=self.handle_response,
proxy=self._cfg.get("proxy", None),
)

def handle_response(self, resp: EBResponse) -> EBResponse:
raise NotImplementedError
Expand All @@ -41,6 +45,7 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
Expand All @@ -53,12 +58,10 @@ async def arequest(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
raise NotImplementedError

def _get_url(self, path: str) -> str:
return f"{self.base_url}{path}"
18 changes: 8 additions & 10 deletions erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -71,7 +71,6 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)
except (errors.TokenExpiredError, errors.InvalidTokenError):
logging.warning(
Expand All @@ -88,23 +87,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down Expand Up @@ -160,15 +158,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -189,15 +187,15 @@ async def arequest(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
9 changes: 4 additions & 5 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def request(
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, Iterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand All @@ -59,23 +59,22 @@ def request(
headers=headers,
files=files,
request_timeout=request_timeout,
base_url=self.base_url,
)

async def arequest(
self,
method: str,
path: str,
stream: bool,
*,
params: Optional[ParamsType] = None,
headers: Optional[HeadersType] = None,
files: Optional[FilesType] = None,
request_timeout: Optional[float] = None,
) -> Union[EBResponse, AsyncIterator[EBResponse]]:
url = self._get_url(path)
url, headers, data = self._client.prepare_request(
method,
url,
path,
supplied_headers=headers,
params=params,
files=files,
Expand Down
23 changes: 20 additions & 3 deletions erniebot/src/erniebot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ def init_global_config() -> None:
# Miscellaneous settings
# Proxy to use
cfg.add_item(URLItem(key="proxy", env_key="EB_PROXY"))
# requests session
cfg.add_item(AnyObjectItem(key="requests_session"))
# aiohttp session
cfg.add_item(AnyObjectItem(key="aiohttp_session"))


class _BaseConfig(object):
class _Config(object):
def __init__(self, cfg_dict: Optional[Dict[str, "_ConfigItem"]] = None) -> None:
super().__init__()
self._cfg_dict: Dict[str, "_ConfigItem"] = cfg_dict if cfg_dict is not None else dict()
Expand All @@ -75,7 +79,7 @@ def set_value(self, key: str, value: Any) -> None:
cfg.value = value


class GlobalConfig(_BaseConfig, metaclass=Singleton):
class GlobalConfig(_Config, metaclass=Singleton):
def create_dict(self, **overrides: Any) -> ConfigDictType:
dict_: ConfigDictType = {}
for key, cfg in self._cfg_dict.items():
Expand Down Expand Up @@ -153,7 +157,8 @@ def __init__(
def factory(self, env_val: str) -> Any:
if self.ensure_integer:
return int(env_val)
return float(env_val)
else:
return float(env_val)

def _validate(self, val: Any) -> None:
if not isinstance(val, int if self.ensure_integer else numbers.Real):
Expand Down Expand Up @@ -191,3 +196,15 @@ def _validate(self, val: Any) -> None:
res = re.match(pat, val)
if res is None:
raise ValueError(f"Invalid URL: {val}")


class AnyObjectItem(_ConfigItem):
def __init__(self, key: str, default: Any = None) -> None:
super().__init__(key=key, env_key=None, default=default)

def factory(self, env_val: str) -> Any:
raise AssertionError

def _validate(self, val: Any) -> None:
# Any object is valid
pass
15 changes: 0 additions & 15 deletions erniebot/src/erniebot/functions/__init__.py

This file was deleted.

Loading