Skip to content

Commit

Permalink
修正: API 関数 docstring を FastAPI 型に変更 (#1123)
Browse files Browse the repository at this point in the history
* refactor: API関数docstringをFastAPI型に変更

* fix: schema スナップショット

* refactor: 引数 docstring の FastAPI型化

* refactor: API 返り値 docstring の path op 引数化

* refactor: lint

* fix: FastAPI Query-Path 取り違え

* fix: OpenAPI schema 更新
  • Loading branch information
tarepan authored Mar 22, 2024
1 parent c75b657 commit d81ba0f
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 152 deletions.
220 changes: 92 additions & 128 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import soundfile
import uvicorn
from fastapi import Depends, FastAPI, Form, HTTPException, Query, Request, Response
from fastapi import Body, Depends, FastAPI, Form, HTTPException
from fastapi import Path as FAPath
from fastapi import Query, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -749,15 +751,15 @@ def connect_waves(waves: list[str]) -> FileResponse:
background=BackgroundTask(delete_file, f.name),
)

@app.get("/presets", response_model=list[Preset], tags=["その他"])
@app.get(
"/presets",
response_model=list[Preset],
response_description="プリセットのリスト",
tags=["その他"],
)
def get_presets() -> list[Preset]:
"""
エンジンが保持しているプリセットの設定を返します
Returns
-------
presets: list[Preset]
プリセットのリスト
"""
try:
presets = preset_manager.load_presets()
Expand All @@ -768,23 +770,20 @@ def get_presets() -> list[Preset]:
@app.post(
"/add_preset",
response_model=int,
response_description="追加したプリセットのプリセットID",
tags=["その他"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def add_preset(preset: Preset) -> int:
def add_preset(
preset: Annotated[
Preset,
Body(
description="新しいプリセット。プリセットIDが既存のものと重複している場合は、新規のプリセットIDが採番されます。"
),
]
) -> int:
"""
新しいプリセットを追加します
Parameters
-------
preset: Preset
新しいプリセット。
プリセットIDが既存のものと重複している場合は、新規のプリセットIDが採番されます。
Returns
-------
id: int
追加したプリセットのプリセットID
"""
try:
id = preset_manager.add_preset(preset)
Expand All @@ -795,23 +794,20 @@ def add_preset(preset: Preset) -> int:
@app.post(
"/update_preset",
response_model=int,
response_description="更新したプリセットのプリセットID",
tags=["その他"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def update_preset(preset: Preset) -> int:
def update_preset(
preset: Annotated[
Preset,
Body(
description="更新するプリセット。プリセットIDが更新対象と一致している必要があります。"
),
]
) -> int:
"""
既存のプリセットを更新します
Parameters
-------
preset: Preset
更新するプリセット。
プリセットIDが更新対象と一致している必要があります。
Returns
-------
id: int
更新したプリセットのプリセットID
"""
try:
id = preset_manager.update_preset(preset)
Expand All @@ -825,15 +821,11 @@ def update_preset(preset: Preset) -> int:
tags=["その他"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def delete_preset(id: int) -> Response:
def delete_preset(
id: Annotated[int, Query(description="削除するプリセットのプリセットID")]
) -> Response:
"""
既存のプリセットを削除します
Parameters
-------
id: int
削除するプリセットのプリセットID
"""
try:
preset_manager.delete_preset(id)
Expand Down Expand Up @@ -996,15 +988,12 @@ def singer_info(
@app.get(
"/downloadable_libraries",
response_model=list[DownloadableLibraryInfo],
response_description="ダウンロード可能な音声ライブラリの情報リスト",
tags=["音声ライブラリ管理"],
)
def downloadable_libraries() -> list[DownloadableLibraryInfo]:
"""
ダウンロード可能な音声ライブラリの情報を返します。
Returns
-------
ret_data: list[DownloadableLibrary]
"""
if not engine_manifest_data.supported_features.manage_library:
raise HTTPException(
Expand All @@ -1015,15 +1004,12 @@ def downloadable_libraries() -> list[DownloadableLibraryInfo]:
@app.get(
"/installed_libraries",
response_model=dict[str, InstalledLibraryInfo],
response_description="インストールした音声ライブラリの情報",
tags=["音声ライブラリ管理"],
)
def installed_libraries() -> dict[str, InstalledLibraryInfo]:
"""
インストールした音声ライブラリの情報を返します。
Returns
-------
ret_data: dict[str, InstalledLibrary]
"""
if not engine_manifest_data.supported_features.manage_library:
raise HTTPException(
Expand All @@ -1038,17 +1024,12 @@ def installed_libraries() -> dict[str, InstalledLibraryInfo]:
dependencies=[Depends(check_disabled_mutable_api)],
)
async def install_library(
library_uuid: str,
library_uuid: Annotated[str, FAPath(description="音声ライブラリのID")],
request: Request,
) -> Response:
"""
音声ライブラリをインストールします。
音声ライブラリのZIPファイルをリクエストボディとして送信してください。
Parameters
----------
library_uuid: str
音声ライブラリのID
"""
if not engine_manifest_data.supported_features.manage_library:
raise HTTPException(
Expand All @@ -1067,14 +1048,11 @@ async def install_library(
tags=["音声ライブラリ管理"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def uninstall_library(library_uuid: str) -> Response:
def uninstall_library(
library_uuid: Annotated[str, FAPath(description="音声ライブラリのID")]
) -> Response:
"""
音声ライブラリをアンインストールします。
Parameters
----------
library_uuid: str
音声ライブラリのID
"""
if not engine_manifest_data.supported_features.manage_library:
raise HTTPException(
Expand Down Expand Up @@ -1112,17 +1090,15 @@ def is_initialized_speaker(
return core.is_initialized_style_id_synthesis(style_id)

@app.get(
"/user_dict", response_model=dict[str, UserDictWord], tags=["ユーザー辞書"]
"/user_dict",
response_model=dict[str, UserDictWord],
response_description="単語のUUIDとその詳細",
tags=["ユーザー辞書"],
)
def get_user_dict_words() -> dict[str, UserDictWord]:
"""
ユーザー辞書に登録されている単語の一覧を返します。
単語の表層形(surface)は正規化済みの物を返します。
Returns
-------
dict[str, UserDictWord]
単語のUUIDとその詳細
"""
try:
return read_dict()
Expand All @@ -1139,29 +1115,28 @@ def get_user_dict_words() -> dict[str, UserDictWord]:
dependencies=[Depends(check_disabled_mutable_api)],
)
def add_user_dict_word(
surface: str,
pronunciation: str,
accent_type: int,
word_type: WordTypes | None = None,
priority: Annotated[int | None, Query(ge=MIN_PRIORITY, le=MAX_PRIORITY)] = None,
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨",
),
] = None,
) -> Response:
"""
ユーザー辞書に言葉を追加します。
Parameters
----------
surface : str
言葉の表層形
pronunciation: str
言葉の発音(カタカナ)
accent_type: int
アクセント型(音が下がる場所を指す)
word_type: WordTypes, optional
PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか
priority: int, optional
単語の優先度(0から10までの整数)
数字が大きいほど優先度が高くなる
1から9までの値を指定することを推奨
"""
try:
word_uuid = apply_word(
Expand Down Expand Up @@ -1189,32 +1164,29 @@ def add_user_dict_word(
dependencies=[Depends(check_disabled_mutable_api)],
)
def rewrite_user_dict_word(
surface: str,
pronunciation: str,
accent_type: int,
word_uuid: str,
word_type: WordTypes | None = None,
priority: Annotated[int | None, Query(ge=MIN_PRIORITY, le=MAX_PRIORITY)] = None,
surface: Annotated[str, Query(description="言葉の表層形")],
pronunciation: Annotated[str, Query(description="言葉の発音(カタカナ)")],
accent_type: Annotated[
int, Query(description="アクセント型(音が下がる場所を指す)")
],
word_uuid: Annotated[str, FAPath(description="更新する言葉のUUID")],
word_type: Annotated[
WordTypes | None,
Query(
description="PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか"
),
] = None,
priority: Annotated[
int | None,
Query(
ge=MIN_PRIORITY,
le=MAX_PRIORITY,
description="単語の優先度(0から10までの整数)。数字が大きいほど優先度が高くなる。1から9までの値を指定することを推奨。",
),
] = None,
) -> Response:
"""
ユーザー辞書に登録されている言葉を更新します。
Parameters
----------
surface : str
言葉の表層形
pronunciation: str
言葉の発音(カタカナ)
accent_type: int
アクセント型(音が下がる場所を指す)
word_uuid: str
更新する言葉のUUID
word_type: WordTypes, optional
PROPER_NOUN(固有名詞)、COMMON_NOUN(普通名詞)、VERB(動詞)、ADJECTIVE(形容詞)、SUFFIX(語尾)のいずれか
priority: int, optional
単語の優先度(0から10までの整数)
数字が大きいほど優先度が高くなる
1から9までの値を指定することを推奨
"""
try:
rewrite_word(
Expand Down Expand Up @@ -1244,14 +1216,11 @@ def rewrite_user_dict_word(
tags=["ユーザー辞書"],
dependencies=[Depends(check_disabled_mutable_api)],
)
def delete_user_dict_word(word_uuid: str) -> Response:
def delete_user_dict_word(
word_uuid: Annotated[str, FAPath(description="削除する言葉のUUID")]
) -> Response:
"""
ユーザー辞書に登録されている言葉を削除します。
Parameters
----------
word_uuid: str
削除する言葉のUUID
"""
try:
delete_word(word_uuid=word_uuid)
Expand All @@ -1271,18 +1240,16 @@ def delete_user_dict_word(word_uuid: str) -> Response:
dependencies=[Depends(check_disabled_mutable_api)],
)
def import_user_dict_words(
import_dict_data: dict[str, UserDictWord],
override: bool,
import_dict_data: Annotated[
dict[str, UserDictWord],
Body(description="インポートするユーザー辞書のデータ"),
],
override: Annotated[
bool, Query(description="重複したエントリがあった場合、上書きするかどうか")
],
) -> Response:
"""
他のユーザー辞書をインポートします。
Parameters
----------
import_dict_data: dict[str, UserDictWord]
インポートするユーザー辞書のデータ
override: bool
重複したエントリがあった場合、上書きするかどうか
"""
try:
import_user_dict(dict_data=import_dict_data, override=override)
Expand Down Expand Up @@ -1321,15 +1288,12 @@ def engine_manifest() -> EngineManifest:
}
},
)
def validate_kana(text: str) -> bool:
def validate_kana(
text: Annotated[str, Query(description="判定する対象の文字列")]
) -> bool:
"""
テキストがAquesTalk 風記法に従っているかどうかを判定します。
従っていない場合はエラーが返ります。
Parameters
----------
text: str
判定する対象の文字列
"""
try:
parse_kana(text)
Expand Down
Loading

0 comments on commit d81ba0f

Please sign in to comment.