diff --git a/arc/__init__.py b/arc/__init__.py index 125400e..d975dd6 100644 --- a/arc/__init__.py +++ b/arc/__init__.py @@ -11,18 +11,18 @@ from alluka import Client as Injector from alluka import inject -from .client import Client, GatewayClient, RESTClient +from arc import abc, command + +from .abc import Option +from .client import Client, GatewayClient, GatewayContext, GatewayPlugin, RESTClient, RESTContext, RESTPlugin from .command import ( AttachmentParams, BoolParams, - CallableCommandBase, - CallableCommandProto, ChannelParams, FloatParams, IntParams, MentionableParams, MessageCommand, - Option, RoleParams, SlashCommand, SlashGroup, @@ -36,12 +36,12 @@ slash_subcommand, user_command, ) -from .context import AutocompleteData, AutodeferMode, Context +from .context import AutocompleteData, AutodeferMode, Context, InteractionResponse from .errors import ArcError, AutocompleteError, CommandInvokeError from .events import ArcEvent, CommandErrorEvent from .extension import loader, unloader from .internal.about import __author__, __author_email__, __license__, __maintainer__, __url__, __version__ -from .plugin import GatewayPlugin, Plugin, RESTPlugin +from .plugin import GatewayPluginBase, PluginBase, RESTPluginBase __all__ = ( "__version__", @@ -54,8 +54,6 @@ "inject", "Injector", "AutocompleteData", - "CallableCommandProto", - "CallableCommandBase", "Option", "Context", "Context", @@ -84,13 +82,20 @@ "ArcError", "AutocompleteError", "CommandInvokeError", - "Plugin", - "RESTPlugin", - "GatewayPlugin", + "PluginBase", + "RESTPluginBase", + "GatewayPluginBase", "loader", "unloader", "ArcEvent", "CommandErrorEvent", + "InteractionResponse", + "GatewayContext", + "RESTContext", + "RESTPlugin", + "GatewayPlugin", + "abc", + "command", ) # MIT License diff --git a/arc/abc/__init__.py b/arc/abc/__init__.py index 2ccf6c7..cb053f6 100644 --- a/arc/abc/__init__.py +++ b/arc/abc/__init__.py @@ -1,3 +1,19 @@ +from .client import Client +from .command import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto from .error_handler import HasErrorHandler +from .option import CommandOptionBase, Option, OptionBase, OptionParams, OptionWithChoices, OptionWithChoicesParams -__all__ = ("HasErrorHandler",) +__all__ = ( + "HasErrorHandler", + "CommandBase", + "CommandProto", + "CallableCommandProto", + "CallableCommandBase", + "Option", + "OptionBase", + "CommandOptionBase", + "OptionParams", + "OptionWithChoices", + "OptionWithChoicesParams", + "Client", +) diff --git a/arc/abc/client.py b/arc/abc/client.py new file mode 100644 index 0000000..b888878 --- /dev/null +++ b/arc/abc/client.py @@ -0,0 +1,706 @@ +from __future__ import annotations + +import abc +import asyncio +import functools +import importlib +import inspect +import logging +import pathlib +import sys +import traceback +import typing as t +from contextlib import suppress + +import alluka +import hikari + +from arc.command.message import MessageCommand +from arc.command.slash import SlashCommand, SlashGroup, SlashSubCommand, SlashSubGroup +from arc.command.user import UserCommand +from arc.context import AutodeferMode, Context +from arc.errors import ExtensionLoadError, ExtensionUnloadError +from arc.internal.sync import _sync_commands +from arc.internal.types import AppT, BuilderT, ResponseBuilderT +from arc.plugin import PluginBase + +if t.TYPE_CHECKING: + import typing_extensions as te + + from arc.abc.command import CommandBase + from arc.command import SlashCommandLike + +__all__ = ("Client",) + + +T = t.TypeVar("T") +P = t.ParamSpec("P") + +logger = logging.getLogger(__name__) + + +class Client(t.Generic[AppT], abc.ABC): + """The abstract base class for an `arc` client. + See [`GatewayClient`][arc.client.GatewayClient] and [`RESTClient`][arc.client.RESTClient] for implementations. + + Parameters + ---------- + app : AppT + The application this client is for. + default_enabled_guilds : t.Sequence[hikari.Snowflake] | None, optional + The guilds that slash commands will be registered in by default, by default None + autosync : bool, optional + Whether to automatically sync commands on startup, by default True + """ + + __slots__: t.Sequence[str] = ( + "_app", + "_default_enabled_guilds", + "_application", + "_slash_commands", + "_message_commands", + "_user_commands", + "_injector", + "_autosync", + "_plugins", + "_loaded_extensions", + ) + + def __init__( + self, app: AppT, *, default_enabled_guilds: t.Sequence[hikari.Snowflake] | None = None, autosync: bool = True + ) -> None: + self._app = app + self._default_enabled_guilds = default_enabled_guilds + self._application: hikari.Application | None = None + self._slash_commands: dict[str, SlashCommandLike[te.Self]] = {} + self._message_commands: dict[str, MessageCommand[te.Self]] = {} + self._user_commands: dict[str, UserCommand[te.Self]] = {} + self._injector: alluka.Client = alluka.Client() + self._plugins: dict[str, PluginBase[te.Self]] = {} + self._loaded_extensions: list[str] = [] + self._autosync = autosync + + @property + @abc.abstractmethod + def is_rest(self) -> bool: + """Whether the app is a rest client or a gateway client. + + This controls the client response flow, if True, `Client.handle_command_interaction` and `Client.handle_autocomplete_interaction` + will return interaction response builders to be sent back to Discord, otherwise they will return None. + """ + + @property + def app(self) -> AppT: + """The application this client is for.""" + return self._app + + @property + def rest(self) -> hikari.api.RESTClient: + """The REST client of the underyling app.""" + return self.app.rest + + @property + def application(self) -> hikari.Application | None: + """The application object for this client. This is fetched on startup.""" + return self._application + + @property + def injector(self) -> alluka.Client: + """The injector for this client.""" + return self._injector + + @property + def default_enabled_guilds(self) -> t.Sequence[hikari.Snowflake] | None: + """The guilds that slash commands will be registered in by default.""" + return self._default_enabled_guilds + + @property + def commands(self) -> t.Mapping[hikari.CommandType, t.Mapping[str, CommandBase[te.Self, t.Any]]]: + """All commands added to this client, categorized by command type.""" + return { + hikari.CommandType.SLASH: self.slash_commands, + hikari.CommandType.MESSAGE: self._message_commands, + hikari.CommandType.USER: self._user_commands, + } + + @property + def slash_commands(self) -> t.Mapping[str, SlashCommandLike[te.Self]]: + """The slash commands added to this client. This only includes top-level commands and groups.""" + return self._slash_commands + + @property + def message_commands(self) -> t.Mapping[str, MessageCommand[te.Self]]: + """The message commands added to this client.""" + return self._message_commands + + @property + def user_commands(self) -> t.Mapping[str, UserCommand[te.Self]]: + """The user commands added to this client.""" + return self._user_commands + + @property + def plugins(self) -> t.Mapping[str, PluginBase[te.Self]]: + """The plugins added to this client.""" + return self._plugins + + def _add_command(self, command: CommandBase[te.Self, t.Any]) -> None: + """Add a command to this client. Called by include hooks.""" + if isinstance(command, (SlashCommand, SlashGroup)): + self._add_slash_command(command) + elif isinstance(command, MessageCommand): + self._add_message_command(command) + elif isinstance(command, UserCommand): + self._add_user_command(command) + + def _remove_command(self, command: CommandBase[te.Self, t.Any]) -> None: + """Remove a command from this client. Called by remove hooks.""" + if isinstance(command, (SlashCommand, SlashGroup)): + self._slash_commands.pop(command.name, None) + elif isinstance(command, MessageCommand): + self._message_commands.pop(command.name, None) + elif isinstance(command, UserCommand): + self._user_commands.pop(command.name, None) + + def _add_slash_command(self, command: SlashCommandLike[te.Self]) -> None: + """Add a slash command to this client.""" + if self.slash_commands.get(command.name) is not None: + logger.warning( + f"Shadowing already registered slash command '{command.name}'. Did you define multiple commands/groups with the same name?" + ) + + self._slash_commands[command.name] = command + + def _add_message_command(self, command: MessageCommand[te.Self]) -> None: + """Add a message command to this client.""" + if self._message_commands.get(command.name) is not None: + logger.warning( + f"Shadowing already registered message command '{command.name}'. Did you define multiple commands with the same name?" + ) + + self._message_commands[command.name] = command + + def _add_user_command(self, command: UserCommand[te.Self]) -> None: + """Add a user command to this client.""" + if self._user_commands.get(command.name) is not None: + logger.warning( + f"Shadowing already registered user command '{command.name}'. Did you define multiple commands with the same name?" + ) + + self._user_commands[command.name] = command + + async def _on_startup(self) -> None: + """Called when the client is starting up. + Fetches application, syncs commands, calls user-defined startup. + """ + self._application = await self.app.rest.fetch_application() + logger.debug(f"Fetched application: '{self.application}'") + if self._autosync: + await _sync_commands(self) + await self.on_startup() + + async def on_startup(self) -> None: + """Called when the client is starting up. + Override for custom startup logic. + """ + + async def _on_shutdown(self) -> None: + """Called when the client is shutting down. + Reserved for internal shutdown logic. + """ + await self.on_shutdown() + + async def on_shutdown(self) -> None: + """Called when the client is shutting down. + Override for custom shutdown logic. + """ + + async def _on_error(self, ctx: Context[te.Self], exception: Exception) -> None: + await self.on_error(ctx, exception) + + async def on_error(self, context: Context[te.Self], exception: Exception) -> None: + """Called when an error occurs in a command callback and all other error handlers have failed. + + Parameters + ---------- + context : Context[te.Self] + The context of the command. + exception : Exception + The exception that was raised. + """ + print(f"Unhandled error in command '{context.command.name}' callback: {exception}", file=sys.stderr) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) + with suppress(Exception): + await context.respond("❌ Something went wrong. Please contact the bot developer.") + + async def on_command_interaction(self, interaction: hikari.CommandInteraction) -> ResponseBuilderT | None: + """Should be called when a command interaction is sent by Discord. + + Parameters + ---------- + interaction : hikari.CommandInteraction + The interaction that was created. + + Returns + ------- + ResponseBuilderT | None + The response builder to send back to Discord, if using a REST client. + """ + command = None + + if interaction.command_type is hikari.CommandType.SLASH: + command = self.slash_commands.get(interaction.command_name) + elif interaction.command_type is hikari.CommandType.MESSAGE: + command = self.message_commands.get(interaction.command_name) + elif interaction.command_type is hikari.CommandType.USER: + command = self.user_commands.get(interaction.command_name) + + if command is None: + logger.warning(f"Received interaction for unknown command '{interaction.command_name}'.") + return + + fut = await command.invoke(interaction) + + if fut is None: + return + + try: + return await asyncio.wait_for(fut, timeout=3.0) + except asyncio.TimeoutError: + logger.warning( + f"Timed out waiting for response from command: '{interaction.command_name} ({interaction.command_type})'" + f" Did you forget to respond?" + ) + + async def on_autocomplete_interaction( + self, interaction: hikari.AutocompleteInteraction + ) -> hikari.api.InteractionAutocompleteBuilder | None: + """Should be called when an autocomplete interaction is sent by Discord. + + Parameters + ---------- + interaction : hikari.AutocompleteInteraction + The interaction that was created. + + Returns + ------- + hikari.api.InteractionAutocompleteBuilder | None + The autocomplete builder to send back to Discord, if using a REST client. + """ + command = self.slash_commands.get(interaction.command_name) + + if command is None: + logger.warning(f"Received autocomplete interaction for unknown command '{interaction.command_name}'.") + return + + return await command._on_autocomplete(interaction) + + def include(self, command: CommandBase[te.Self, BuilderT]) -> CommandBase[te.Self, BuilderT]: + """First-order decorator to add a command to this client. + + Parameters + ---------- + command : CommandBase[te.Self, BuilderT] + The command to add. + """ + if command.plugin is not None: + raise ValueError(f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'.") + + if command.name in self.commands[command.command_type]: + raise ValueError(f"Command '{command.name}' is already registered with this client.") + + command._client_include_hook(self) + return command + + def include_slash_group( + self, + name: str, + description: str = "No description provided.", + *, + guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED, + autodefer: bool | AutodeferMode = True, + is_dm_enabled: bool = True, + default_permissions: hikari.UndefinedOr[hikari.Permissions] = hikari.UNDEFINED, + name_localizations: dict[hikari.Locale, str] | None = None, + description_localizations: dict[hikari.Locale, str] | None = None, + is_nsfw: bool = False, + ) -> SlashGroup[te.Self]: + """Add a new slash command group to this client. + + Parameters + ---------- + name : str + The name of the slash command group. + description : str + The description of the slash command group. + guilds : hikari.UndefinedOr[t.Sequence[hikari.Snowflake]], optional + The guilds to register the slash command group in, by default hikari.UNDEFINED + autodefer : bool | AutodeferMode, optional + If True, all commands in this group will automatically defer if it is taking longer than 2 seconds to respond. + This can be overridden on a per-subcommand basis. + is_dm_enabled : bool, optional + Whether the slash command group is enabled in DMs, by default True + default_permissions : hikari.UndefinedOr[hikari.Permissions], optional + The default permissions for the slash command group, by default hikari.UNDEFINED + name_localizations : dict[hikari.Locale, str], optional + The name of the slash command group in different locales, by default None + description_localizations : dict[hikari.Locale, str], optional + The description of the slash command group in different locales, by default None + is_nsfw : bool, optional + Whether the slash command group is only usable in NSFW channels, by default False + + Returns + ------- + SlashGroup[te.Self] + The slash command group that was created. + + Usage + ----- + ```py + group = client.include_slash_group("Group", "A group of commands.") + + @group.include + @arc.slash_subcommand(name="Command", description="A command.") + async def cmd(ctx: arc.Context[arc.GatewayClient]) -> None: + await ctx.respond("Hello!") + ``` + """ + children: dict[str, SlashSubCommand[te.Self] | SlashSubGroup[te.Self]] = {} + + group = SlashGroup( + name=name, + description=description, + children=children, + guilds=guilds, + autodefer=AutodeferMode(autodefer), + is_dm_enabled=is_dm_enabled, + default_permissions=default_permissions, + name_localizations=name_localizations or {}, + description_localizations=description_localizations or {}, + is_nsfw=is_nsfw, + ) + group._client_include_hook(self) + return group + + def add_plugin(self, plugin: PluginBase[te.Self]) -> None: + """Add a plugin to this client. + + Parameters + ---------- + plugin : Plugin[te.Self] + The plugin to add. + """ + plugin._client_include_hook(self) + + def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None: + """Remove a plugin from this client. + + Parameters + ---------- + plugin : str | Plugin[te.Self] + The plugin or name of the plugin to remove. + + Raises + ------ + ValueError + If there is no plugin with the given name. + """ + if isinstance(plugin, PluginBase): + if plugin not in self.plugins.values(): + raise ValueError(f"Plugin '{plugin.name}' is not registered with this client.") + return plugin._client_remove_hook() + + pg = self.plugins.get(plugin) + + if pg is None: + raise ValueError(f"Plugin '{plugin}' is not registered with this client.") + + pg._client_remove_hook() + + def load_extension(self, path: str) -> te.Self: + """Load a python module with path `path` as an extension. + This will import the module, and call it's loader function. + + Parameters + ---------- + path : str + The path to the module to load. + + Returns + ------- + te.Self + The client for chaining calls. + + Raises + ------ + ValueError + If the module does not have a loader. + + Usage + ----- + ```py + client = arc.GatewayClient(...) + client.load_extension("extension") + + # In extension.py + + plugin = arc.GatewayPlugin[arc.GatewayClient]("test_plugin") + + @arc.loader + def loader(client: arc.GatewayClient) -> None: + client.add_plugin(plugin) + ``` + + See Also + -------- + - [`@arc.loader`][arc.extension.loader] + - [`Client.load_extensions_from`][arc.client.Client.load_extensions_from] + - [`Client.unload_extension`][arc.client.Client.unload_extension] + """ + parents = path.split(".") + name = parents.pop() + + pkg = ".".join(parents) + + if pkg: + name = "." + name + + module = importlib.import_module(path, package=pkg) + + loader = getattr(module, "__arc_extension_loader__", None) + + if loader is None: + raise ValueError(f"Module '{path}' does not have a loader.") + + self._loaded_extensions.append(path) + loader(self) + logger.info(f"Loaded extension: '{path}'") + + return self + + def load_extensions_from(self, dir_path: str | pathlib.Path, recursive: bool = False) -> te.Self: + """Load all python modules in a directory as extensions. + This will import the modules, and call their loader functions. + + Parameters + ---------- + dir_path : str + The path to the directory to load extensions from. + recursive : bool, optional + Whether to load extensions from subdirectories, by default False + + Returns + ------- + te.Self + The client for chaining calls. + + Raises + ------ + ExtensionLoadError + If `dir_path` does not exist or is not a directory. + ExtensionLoadError + If a module does not have a loader defined. + """ + if isinstance(dir_path, str): + dir_path = pathlib.Path(dir_path) + + try: + dir_path.resolve().relative_to(pathlib.Path.cwd()) + except ValueError: + raise ExtensionLoadError("dir_path must be relative to the current working directory.") + + if not dir_path.is_dir(): + raise ExtensionLoadError("dir_path must exist and be a directory.") + + globfunc = dir_path.rglob if recursive else dir_path.glob + loaded = 0 + + for file in globfunc(r"**/[!_]*.py"): + module_path = ".".join(file.as_posix()[:-3].split("/")) + self.load_extension(module_path) + loaded += 1 + + if loaded == 0: + logger.warning(f"No extensions were found at '{dir_path}'.") + + return self + + def unload_extension(self, path: str) -> te.Self: + """Unload a python module with path `path` as an extension. + + Parameters + ---------- + path : str + The path to the module to unload. + + Returns + ------- + te.Self + The client for chaining calls. + + Raises + ------ + ExtensionUnloadError + If the module does not have an unloader or is not loaded. + """ + parents = path.split(".") + name = parents.pop() + + pkg = ".".join(parents) + + if pkg: + name = "." + name + + if path not in self._loaded_extensions: + raise ExtensionUnloadError(f"Extension '{path}' is not loaded.") + + module = importlib.import_module(path, package=pkg) + + unloader = getattr(module, "__arc_extension_unloader__", None) + + if unloader is None: + raise ExtensionUnloadError(f"Module '{path}' does not have an unloader.") + + unloader(self) + self._loaded_extensions.remove(path) + + if module.__name__ in sys.modules: + del sys.modules[module.__name__] + + return self + + def set_type_dependency(self, type_: t.Type[T], instance: T) -> None: + """Set a type dependency for this client. This can then be injected into all arc callbacks. + + Parameters + ---------- + type_ : t.Type[T] + The type of the dependency. + instance : T + The instance of the dependency. + + Usage + ----- + + ```py + class MyDependency: + def __init__(self, value: str): + self.value = value + + client.set_type_dependency(MyDependency, MyDependency("Hello!")) + + @client.include + @arc.slash_command("cmd", "A command.") + async def cmd(ctx: arc.Context[arc.GatewayClient], dep: MyDependency = arc.inject()) -> None: + await ctx.respond(dep.value) + ``` + + See Also + -------- + - [`Client.get_type_dependency`][arc.client.Client.get_type_dependency] + A method to get dependencies for the client. + + - [`Client.inject_dependencies`][arc.client.Client.inject_dependencies] + A decorator to inject dependencies into arbitrary functions. + """ + self._injector.set_type_dependency(type_, instance) + + def get_type_dependency(self, type_: t.Type[T]) -> hikari.UndefinedOr[T]: + """Get a type dependency for this client. + + Parameters + ---------- + type_ : t.Type[T] + The type of the dependency. + + Returns + ------- + hikari.UndefinedOr[T] + The instance of the dependency, if it exists. + """ + return self._injector.get_type_dependency(type_, default=hikari.UNDEFINED) + + def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]: + """First order decorator to inject dependencies into the decorated function. + + !!! note + Command callbacks are automatically injected with dependencies, + thus this decorator is not needed for them. + + Usage + ----- + ```py + class MyDependency: + def __init__(self, value: str): + self.value = value + + client.set_type_dependency(MyDependency, MyDependency("Hello!")) + + @client.inject_dependencies + def my_func(dep: MyDependency = arc.inject()) -> None: + print(dep.value) + + my_func() # Prints "Hello!" + ``` + + See Also + -------- + - [`Client.set_type_dependency`][arc.client.Client.set_type_dependency] + A method to set dependencies for the client. + """ + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T: + return await self.injector.call_with_async_di(func, *args, **kwargs) + + return decorator_async # pyright: ignore reportGeneralTypeIssues + else: + + @functools.wraps(func) + def decorator(*args: P.args, **kwargs: P.kwargs) -> T: + return self.injector.call_with_di(func, *args, **kwargs) + + return decorator + + async def resync_commands(self) -> None: + """Synchronize the commands registered in this client with Discord. + + !!! warning + Calling this is expensive, and should only be done when absolutely necessary. + The client automatically syncs commands on startup, unless the `autosync` parameter + is set to `False` when creating the client. + + Raises + ------ + RuntimeError + If `Client.application` is `None`. + This usually only happens if `Client.resync_commands` is called before `Client.on_startup`. + """ + await _sync_commands(self) + + async def purge_all_commands(self, guild: hikari.SnowflakeishOr[hikari.PartialGuild] | None = None) -> None: + """Purge all commands registered on Discord. This can be used to clean up commands. + + Parameters + ---------- + guild : hikari.SnowflakeishOr[hikari.PartialGuild] | None, optional + The guild to purge commands from, by default None + If a `guild` is not provided, this will purge global commands. + + !!! warning + This will remove all commands registered on Discord, **including commands not registered by this client**. + + Raises + ------ + RuntimeError + If `Client.application` is `None`. + This usually only happens if `Client.purge_all_commands` is called before `Client.on_startup`. + """ + if self.application is None: + raise RuntimeError(f"Cannot purge commands before '{type(self).__name__}.application' is fetched.") + + if guild is not None: + guild_id = hikari.Snowflake(guild) + await self.rest.set_application_commands(self.application, [], guild_id) + else: + await self.rest.set_application_commands(self.application, []) diff --git a/arc/command/base.py b/arc/abc/command.py similarity index 87% rename from arc/command/base.py rename to arc/abc/command.py index 3e0d9a2..bd0d635 100644 --- a/arc/command/base.py +++ b/arc/abc/command.py @@ -7,15 +7,13 @@ import attr import hikari -from ..abc import HasErrorHandler -from ..context import AutodeferMode -from ..internal.types import BuilderT, ClientT, CommandCallbackT, ResponseBuilderT +from arc.abc.error_handler import HasErrorHandler +from arc.context import AutodeferMode +from arc.internal.types import BuilderT, ClientT, CommandCallbackT, ResponseBuilderT if t.TYPE_CHECKING: from ..context import Context - from ..plugin import Plugin - -__all__ = ("CommandProto", "CallableCommandProto", "CommandBase", "CallableCommandBase") + from ..plugin import PluginBase class CommandProto(t.Protocol): @@ -111,7 +109,7 @@ class CommandBase(HasErrorHandler[ClientT], t.Generic[ClientT, BuilderT]): _client: ClientT | None = attr.field(init=False, default=None) """The client that is handling this command.""" - _plugin: Plugin[ClientT] | None = attr.field(init=False, default=None) + _plugin: PluginBase[ClientT] | None = attr.field(init=False, default=None) """The plugin that this command belongs to, if any.""" guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED @@ -156,7 +154,7 @@ def client(self) -> ClientT: return self._client @property - def plugin(self) -> Plugin[ClientT] | None: + def plugin(self) -> PluginBase[ClientT] | None: """The plugin that this command belongs to, if any.""" return self._plugin @@ -267,7 +265,7 @@ def _client_include_hook(self, client: ClientT) -> None: self._client = client self.client._add_command(self) - def _plugin_include_hook(self, plugin: Plugin[ClientT]) -> None: + def _plugin_include_hook(self, plugin: PluginBase[ClientT]) -> None: """Called when the plugin requests the command be added to it.""" self._plugin = plugin self._plugin._add_command(self) @@ -304,26 +302,3 @@ async def invoke( self._invoke_task = asyncio.create_task(self._handle_callback(self, ctx, *args, **kwargs)) if self.client.is_rest: return ctx._resp_builder - - -# MIT License -# -# Copyright (c) 2023-present hypergonial -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. diff --git a/arc/abc/error_handler.py b/arc/abc/error_handler.py index 4008e37..17e6ed1 100644 --- a/arc/abc/error_handler.py +++ b/arc/abc/error_handler.py @@ -5,7 +5,7 @@ import attr -from ..internal.types import ClientT, ErrorHandlerCallbackT +from arc.internal.types import ClientT, ErrorHandlerCallbackT if t.TYPE_CHECKING: from ..context import Context diff --git a/arc/command/option/base.py b/arc/abc/option.py similarity index 97% rename from arc/command/option/base.py rename to arc/abc/option.py index 45524e2..a04af13 100644 --- a/arc/command/option/base.py +++ b/arc/abc/option.py @@ -7,7 +7,7 @@ import attr import hikari -from ...internal.types import AutocompleteCallbackT, ChoiceT, ClientT, ParamsT +from arc.internal.types import AutocompleteCallbackT, ChoiceT, ClientT, ParamsT if t.TYPE_CHECKING: import typing_extensions as te @@ -24,6 +24,12 @@ ```py arc.Option[type, arc.TypeParams(...)] ``` + +So for example, to create an `int` option, you would do: + +```py +arc.Option[int, arc.IntParams(...)] +``` """ diff --git a/arc/client.py b/arc/client.py index d0cc147..2390472 100644 --- a/arc/client.py +++ b/arc/client.py @@ -1,34 +1,22 @@ from __future__ import annotations -import abc -import asyncio -import functools -import importlib -import inspect import logging -import pathlib -import sys -import traceback import typing as t -from contextlib import suppress -import alluka import hikari -from .command import MessageCommand, SlashCommand, SlashGroup, SlashSubCommand, SlashSubGroup, UserCommand -from .context import AutodeferMode, Context -from .errors import ExtensionLoadError, ExtensionUnloadError, NoResponseIssuedError -from .events import CommandErrorEvent -from .internal.sync import _sync_commands -from .internal.types import AppT, BuilderT, EventCallbackT, EventT, ResponseBuilderT -from .plugin import Plugin +from arc.abc.client import Client +from arc.context import Context +from arc.errors import NoResponseIssuedError +from arc.events import CommandErrorEvent +from arc.plugin import GatewayPluginBase, RESTPluginBase if t.TYPE_CHECKING: import typing_extensions as te - from .command import CommandBase, SlashCommandLike + from .internal.types import EventCallbackT, EventT, ResponseBuilderT -__all__ = ("Client", "GatewayClient", "RESTClient") +__all__ = ("GatewayClient", "RESTClient") T = t.TypeVar("T") @@ -37,680 +25,13 @@ logger = logging.getLogger(__name__) -class Client(t.Generic[AppT], abc.ABC): - """A base class for an `arc` client. - See [`GatewayClient`][arc.client.GatewayClient] and [`RESTClient`][arc.client.RESTClient] for implementations. - - Parameters - ---------- - app : AppT - The application this client is for. - default_enabled_guilds : t.Sequence[hikari.Snowflake] | None, optional - The guilds that slash commands will be registered in by default, by default None - autosync : bool, optional - Whether to automatically sync commands on startup, by default True - """ - - __slots__: t.Sequence[str] = ( - "_app", - "_default_enabled_guilds", - "_application", - "_slash_commands", - "_message_commands", - "_user_commands", - "_injector", - "_autosync", - "_plugins", - "_loaded_extensions", - ) - - def __init__( - self, app: AppT, *, default_enabled_guilds: t.Sequence[hikari.Snowflake] | None = None, autosync: bool = True - ) -> None: - self._app = app - self._default_enabled_guilds = default_enabled_guilds - self._application: hikari.Application | None = None - self._slash_commands: dict[str, SlashCommandLike[te.Self]] = {} - self._message_commands: dict[str, MessageCommand[te.Self]] = {} - self._user_commands: dict[str, UserCommand[te.Self]] = {} - self._injector: alluka.Client = alluka.Client() - self._plugins: dict[str, Plugin[te.Self]] = {} - self._loaded_extensions: list[str] = [] - self._autosync = autosync - - @property - @abc.abstractmethod - def is_rest(self) -> bool: - """Whether the app is a rest client or a gateway client. - - This controls the client response flow, if True, `Client.handle_command_interaction` and `Client.handle_autocomplete_interaction` - will return interaction response builders to be sent back to Discord, otherwise they will return None. - """ - - @property - def app(self) -> AppT: - """The application this client is for.""" - return self._app - - @property - def rest(self) -> hikari.api.RESTClient: - """The REST client of the underyling app.""" - return self.app.rest - - @property - def application(self) -> hikari.Application | None: - """The application object for this client. This is fetched on startup.""" - return self._application - - @property - def injector(self) -> alluka.Client: - """The injector for this client.""" - return self._injector - - @property - def default_enabled_guilds(self) -> t.Sequence[hikari.Snowflake] | None: - """The guilds that slash commands will be registered in by default.""" - return self._default_enabled_guilds - - @property - def commands(self) -> t.Mapping[hikari.CommandType, t.Mapping[str, CommandBase[te.Self, t.Any]]]: - """All commands added to this client, categorized by command type.""" - return { - hikari.CommandType.SLASH: self.slash_commands, - hikari.CommandType.MESSAGE: self._message_commands, - hikari.CommandType.USER: self._user_commands, - } - - @property - def slash_commands(self) -> t.Mapping[str, SlashCommandLike[te.Self]]: - """The slash commands added to this client. This only includes top-level commands and groups.""" - return self._slash_commands - - @property - def message_commands(self) -> t.Mapping[str, MessageCommand[te.Self]]: - """The message commands added to this client.""" - return self._message_commands - - @property - def user_commands(self) -> t.Mapping[str, UserCommand[te.Self]]: - """The user commands added to this client.""" - return self._user_commands - - @property - def plugins(self) -> t.Mapping[str, Plugin[te.Self]]: - """The plugins added to this client.""" - return self._plugins - - def _add_command(self, command: CommandBase[te.Self, t.Any]) -> None: - """Add a command to this client. Called by include hooks.""" - if isinstance(command, (SlashCommand, SlashGroup)): - self._add_slash_command(command) - elif isinstance(command, MessageCommand): - self._add_message_command(command) - elif isinstance(command, UserCommand): - self._add_user_command(command) - - def _remove_command(self, command: CommandBase[te.Self, t.Any]) -> None: - """Remove a command from this client. Called by remove hooks.""" - if isinstance(command, (SlashCommand, SlashGroup)): - self._slash_commands.pop(command.name, None) - elif isinstance(command, MessageCommand): - self._message_commands.pop(command.name, None) - elif isinstance(command, UserCommand): - self._user_commands.pop(command.name, None) - - def _add_slash_command(self, command: SlashCommandLike[te.Self]) -> None: - """Add a slash command to this client.""" - if self.slash_commands.get(command.name) is not None: - logger.warning( - f"Shadowing already registered slash command '{command.name}'. Did you define multiple commands/groups with the same name?" - ) - - self._slash_commands[command.name] = command - - def _add_message_command(self, command: MessageCommand[te.Self]) -> None: - """Add a message command to this client.""" - if self._message_commands.get(command.name) is not None: - logger.warning( - f"Shadowing already registered message command '{command.name}'. Did you define multiple commands with the same name?" - ) - - self._message_commands[command.name] = command - - def _add_user_command(self, command: UserCommand[te.Self]) -> None: - """Add a user command to this client.""" - if self._user_commands.get(command.name) is not None: - logger.warning( - f"Shadowing already registered user command '{command.name}'. Did you define multiple commands with the same name?" - ) - - self._user_commands[command.name] = command - - async def _on_startup(self) -> None: - """Called when the client is starting up. - Fetches application, syncs commands, calls user-defined startup. - """ - self._application = await self.app.rest.fetch_application() - logger.debug(f"Fetched application: '{self.application}'") - if self._autosync: - await _sync_commands(self) - await self.on_startup() - - async def on_startup(self) -> None: - """Called when the client is starting up. - Override for custom startup logic. - """ - - async def _on_shutdown(self) -> None: - """Called when the client is shutting down. - Reserved for internal shutdown logic. - """ - await self.on_shutdown() - - async def on_shutdown(self) -> None: - """Called when the client is shutting down. - Override for custom shutdown logic. - """ - - async def _on_error(self, ctx: Context[te.Self], exception: Exception) -> None: - await self.on_error(ctx, exception) - - async def on_error(self, context: Context[te.Self], exception: Exception) -> None: - """Called when an error occurs in a command callback and all other error handlers have failed. - - Parameters - ---------- - context : Context[te.Self] - The context of the command. - exception : Exception - The exception that was raised. - """ - print(f"Unhandled error in command '{context.command.name}' callback: {exception}", file=sys.stderr) - traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) - with suppress(Exception): - await context.respond("❌ Something went wrong. Please contact the bot developer.") - - async def on_command_interaction(self, interaction: hikari.CommandInteraction) -> ResponseBuilderT | None: - """Should be called when a command interaction is sent by Discord. - - Parameters - ---------- - interaction : hikari.CommandInteraction - The interaction that was created. - - Returns - ------- - ResponseBuilderT | None - The response builder to send back to Discord, if using a REST client. - """ - command = None - - if interaction.command_type is hikari.CommandType.SLASH: - command = self.slash_commands.get(interaction.command_name) - elif interaction.command_type is hikari.CommandType.MESSAGE: - command = self.message_commands.get(interaction.command_name) - elif interaction.command_type is hikari.CommandType.USER: - command = self.user_commands.get(interaction.command_name) - - if command is None: - logger.warning(f"Received interaction for unknown command '{interaction.command_name}'.") - return - - fut = await command.invoke(interaction) - - if fut is None: - return - - try: - return await asyncio.wait_for(fut, timeout=3.0) - except asyncio.TimeoutError: - logger.warning( - f"Timed out waiting for response from command: '{interaction.command_name} ({interaction.command_type})'" - f" Did you forget to respond?" - ) - - async def on_autocomplete_interaction( - self, interaction: hikari.AutocompleteInteraction - ) -> hikari.api.InteractionAutocompleteBuilder | None: - """Should be called when an autocomplete interaction is sent by Discord. - - Parameters - ---------- - interaction : hikari.AutocompleteInteraction - The interaction that was created. - - Returns - ------- - hikari.api.InteractionAutocompleteBuilder | None - The autocomplete builder to send back to Discord, if using a REST client. - """ - command = self.slash_commands.get(interaction.command_name) - - if command is None: - logger.warning(f"Received autocomplete interaction for unknown command '{interaction.command_name}'.") - return - - return await command._on_autocomplete(interaction) - - def include(self, command: CommandBase[te.Self, BuilderT]) -> CommandBase[te.Self, BuilderT]: - """First-order decorator to add a command to this client. - - Parameters - ---------- - command : CommandBase[te.Self, BuilderT] - The command to add. - """ - if command.plugin is not None: - raise ValueError(f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'.") - - if command.name in self.commands[command.command_type]: - raise ValueError(f"Command '{command.name}' is already registered with this client.") - - command._client_include_hook(self) - return command - - def include_slash_group( - self, - name: str, - description: str = "No description provided.", - *, - guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED, - autodefer: bool | AutodeferMode = True, - is_dm_enabled: bool = True, - default_permissions: hikari.UndefinedOr[hikari.Permissions] = hikari.UNDEFINED, - name_localizations: dict[hikari.Locale, str] | None = None, - description_localizations: dict[hikari.Locale, str] | None = None, - is_nsfw: bool = False, - ) -> SlashGroup[te.Self]: - """Add a new slash command group to this client. - - Parameters - ---------- - name : str - The name of the slash command group. - description : str - The description of the slash command group. - guilds : hikari.UndefinedOr[t.Sequence[hikari.Snowflake]], optional - The guilds to register the slash command group in, by default hikari.UNDEFINED - autodefer : bool | AutodeferMode, optional - If True, all commands in this group will automatically defer if it is taking longer than 2 seconds to respond. - This can be overridden on a per-subcommand basis. - is_dm_enabled : bool, optional - Whether the slash command group is enabled in DMs, by default True - default_permissions : hikari.UndefinedOr[hikari.Permissions], optional - The default permissions for the slash command group, by default hikari.UNDEFINED - name_localizations : dict[hikari.Locale, str], optional - The name of the slash command group in different locales, by default None - description_localizations : dict[hikari.Locale, str], optional - The description of the slash command group in different locales, by default None - is_nsfw : bool, optional - Whether the slash command group is only usable in NSFW channels, by default False - - Returns - ------- - SlashGroup[te.Self] - The slash command group that was created. - - Usage - ----- - ```py - group = client.include_slash_group("Group", "A group of commands.") - - @group.include - @arc.slash_subcommand(name="Command", description="A command.") - async def cmd(ctx: arc.Context[arc.GatewayClient]) -> None: - await ctx.respond("Hello!") - ``` - """ - children: dict[str, SlashSubCommand[te.Self] | SlashSubGroup[te.Self]] = {} - - group = SlashGroup( - name=name, - description=description, - children=children, - guilds=guilds, - autodefer=AutodeferMode(autodefer), - is_dm_enabled=is_dm_enabled, - default_permissions=default_permissions, - name_localizations=name_localizations or {}, - description_localizations=description_localizations or {}, - is_nsfw=is_nsfw, - ) - group._client_include_hook(self) - return group - - def add_plugin(self, plugin: Plugin[te.Self]) -> None: - """Add a plugin to this client. - - Parameters - ---------- - plugin : Plugin[te.Self] - The plugin to add. - """ - plugin._client_include_hook(self) - - def remove_plugin(self, plugin: str | Plugin[te.Self]) -> None: - """Remove a plugin from this client. - - Parameters - ---------- - plugin : str | Plugin[te.Self] - The plugin or name of the plugin to remove. - - Raises - ------ - ValueError - If there is no plugin with the given name. - """ - if isinstance(plugin, Plugin): - if plugin not in self.plugins.values(): - raise ValueError(f"Plugin '{plugin.name}' is not registered with this client.") - return plugin._client_remove_hook() - - pg = self.plugins.get(plugin) - - if pg is None: - raise ValueError(f"Plugin '{plugin}' is not registered with this client.") - - pg._client_remove_hook() - - def load_extension(self, path: str) -> te.Self: - """Load a python module with path `path` as an extension. - This will import the module, and call it's loader function. - - Parameters - ---------- - path : str - The path to the module to load. - - Returns - ------- - te.Self - The client for chaining calls. - - Raises - ------ - ValueError - If the module does not have a loader. - - Usage - ----- - ```py - client = arc.GatewayClient(...) - client.load_extension("extension") - - # In extension.py - - plugin = arc.GatewayPlugin[arc.GatewayClient]("test_plugin") - - @arc.loader - def loader(client: arc.GatewayClient) -> None: - client.add_plugin(plugin) - ``` - - See Also - -------- - - [`@arc.loader`][arc.extension.loader] - - [`Client.load_extensions_from`][arc.client.Client.load_extensions_from] - - [`Client.unload_extension`][arc.client.Client.unload_extension] - """ - parents = path.split(".") - name = parents.pop() - - pkg = ".".join(parents) - - if pkg: - name = "." + name - - module = importlib.import_module(path, package=pkg) - - loader = getattr(module, "__arc_extension_loader__", None) - - if loader is None: - raise ValueError(f"Module '{path}' does not have a loader.") - - self._loaded_extensions.append(path) - loader(self) - logger.info(f"Loaded extension: '{path}'") - - return self - - def load_extensions_from(self, dir_path: str | pathlib.Path, recursive: bool = False) -> te.Self: - """Load all python modules in a directory as extensions. - This will import the modules, and call their loader functions. - - Parameters - ---------- - dir_path : str - The path to the directory to load extensions from. - recursive : bool, optional - Whether to load extensions from subdirectories, by default False - - Returns - ------- - te.Self - The client for chaining calls. - - Raises - ------ - ExtensionLoadError - If `dir_path` does not exist or is not a directory. - ExtensionLoadError - If a module does not have a loader defined. - """ - if isinstance(dir_path, str): - dir_path = pathlib.Path(dir_path) - - try: - dir_path.resolve().relative_to(pathlib.Path.cwd()) - except ValueError: - raise ExtensionLoadError("dir_path must be relative to the current working directory.") - - if not dir_path.is_dir(): - raise ExtensionLoadError("dir_path must exist and be a directory.") - - globfunc = dir_path.rglob if recursive else dir_path.glob - loaded = 0 - - for file in globfunc(r"**/[!_]*.py"): - module_path = ".".join(file.as_posix()[:-3].split("/")) - self.load_extension(module_path) - loaded += 1 - - if loaded == 0: - logger.warning(f"No extensions were found at '{dir_path}'.") - - return self - - def unload_extension(self, path: str) -> te.Self: - """Unload a python module with path `path` as an extension. - - Parameters - ---------- - path : str - The path to the module to unload. - - Returns - ------- - te.Self - The client for chaining calls. - - Raises - ------ - ExtensionUnloadError - If the module does not have an unloader or is not loaded. - """ - parents = path.split(".") - name = parents.pop() - - pkg = ".".join(parents) - - if pkg: - name = "." + name - - if path not in self._loaded_extensions: - raise ExtensionUnloadError(f"Extension '{path}' is not loaded.") - - module = importlib.import_module(path, package=pkg) - - unloader = getattr(module, "__arc_extension_unloader__", None) - - if unloader is None: - raise ExtensionUnloadError(f"Module '{path}' does not have an unloader.") - - unloader(self) - self._loaded_extensions.remove(path) - - if module.__name__ in sys.modules: - del sys.modules[module.__name__] - - return self - - def set_type_dependency(self, type_: t.Type[T], instance: T) -> None: - """Set a type dependency for this client. This can then be injected into all arc callbacks. - - Parameters - ---------- - type_ : t.Type[T] - The type of the dependency. - instance : T - The instance of the dependency. - - Usage - ----- - - ```py - class MyDependency: - def __init__(self, value: str): - self.value = value - - client.set_type_dependency(MyDependency, MyDependency("Hello!")) - - @client.include - @arc.slash_command("cmd", "A command.") - async def cmd(ctx: arc.Context[arc.GatewayClient], dep: MyDependency = arc.inject()) -> None: - await ctx.respond(dep.value) - ``` - - See Also - -------- - - [`Client.get_type_dependency`][arc.client.Client.get_type_dependency] - A method to get dependencies for the client. - - - [`Client.inject_dependencies`][arc.client.Client.inject_dependencies] - A decorator to inject dependencies into arbitrary functions. - """ - self._injector.set_type_dependency(type_, instance) - - def get_type_dependency(self, type_: t.Type[T]) -> hikari.UndefinedOr[T]: - """Get a type dependency for this client. - - Parameters - ---------- - type_ : t.Type[T] - The type of the dependency. - - Returns - ------- - hikari.UndefinedOr[T] - The instance of the dependency, if it exists. - """ - return self._injector.get_type_dependency(type_, default=hikari.UNDEFINED) - - def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]: - """First order decorator to inject dependencies into the decorated function. - - !!! note - Command callbacks are automatically injected with dependencies, - thus this decorator is not needed for them. - - Usage - ----- - ```py - class MyDependency: - def __init__(self, value: str): - self.value = value - - client.set_type_dependency(MyDependency, MyDependency("Hello!")) - - @client.inject_dependencies - def my_func(dep: MyDependency = arc.inject()) -> None: - print(dep.value) - - my_func() # Prints "Hello!" - ``` - - See Also - -------- - - [`Client.set_type_dependency`][arc.client.Client.set_type_dependency] - A method to set dependencies for the client. - """ - if inspect.iscoroutinefunction(func): - - @functools.wraps(func) - async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T: - return await self.injector.call_with_async_di(func, *args, **kwargs) - - return decorator_async # pyright: ignore reportGeneralTypeIssues - else: - - @functools.wraps(func) - def decorator(*args: P.args, **kwargs: P.kwargs) -> T: - return self.injector.call_with_di(func, *args, **kwargs) - - return decorator - - async def resync_commands(self) -> None: - """Synchronize the commands registered in this client with Discord. - - !!! warning - Calling this is expensive, and should only be done when absolutely necessary. - The client automatically syncs commands on startup, unless the `autosync` parameter - is set to `False` when creating the client. - - Raises - ------ - RuntimeError - If `Client.application` is `None`. - This usually only happens if `Client.resync_commands` is called before `Client.on_startup`. - """ - await _sync_commands(self) - - async def purge_all_commands(self, guild: hikari.SnowflakeishOr[hikari.PartialGuild] | None = None) -> None: - """Purge all commands registered on Discord. This can be used to clean up commands. - - Parameters - ---------- - guild : hikari.SnowflakeishOr[hikari.PartialGuild] | None, optional - The guild to purge commands from, by default None - If a `guild` is not provided, this will purge global commands. - - !!! warning - This will remove all commands registered on Discord, **including commands not registered by this client**. - - Raises - ------ - RuntimeError - If `Client.application` is `None`. - This usually only happens if `Client.purge_all_commands` is called before `Client.on_startup`. - """ - if self.application is None: - raise RuntimeError(f"Cannot purge commands before '{type(self).__name__}.application' is fetched.") - - if guild is not None: - guild_id = hikari.Snowflake(guild) - await self.rest.set_application_commands(self.application, [], guild_id) - else: - await self.rest.set_application_commands(self.application, []) - - class GatewayClient(Client[hikari.GatewayBotAware]): - """A base class for an arc client with `hikari.GatewayBotAware` support. + """The default implementation for an arc client with `hikari.GatewayBotAware` support. If you want to use a `hikari.RESTBotAware`, use `RESTClient` instead. Parameters ---------- - app : hikari.GatewayBot + app : hikari.GatewayBotAware The application this client is for. default_enabled_guilds : t.Sequence[hikari.Snowflake] | None, optional The guilds that slash commands will be registered in by default, by default None @@ -797,12 +118,12 @@ def listen(self, *event_types: t.Type[EventT]) -> t.Callable[[EventCallbackT[Eve class RESTClient(Client[hikari.RESTBotAware]): - """A base class for an arc client with `hikari.RESTBotAware` support. + """The default implementation for an arc client with `hikari.RESTBotAware` support. If you want to use `hikari.GatewayBotAware`, use `GatewayClient` instead. Parameters ---------- - app : hikari.RESTBot + app : hikari.RESTBotAware The application this client is for. default_enabled_guilds : t.Sequence[hikari.Snowflake] | None, optional The guilds that slash commands will be registered in by default, by default None @@ -895,6 +216,18 @@ async def _on_restbot_autocomplete_interaction_create( return builder +GatewayContext = Context[GatewayClient] +"""A context using the default gateway client implementation. An alias for [`arc.Context[arc.GatewayClient]`][arc.context.base.Context].""" + +RESTContext = Context[RESTClient] +"""A context using the default REST client implementation. An alias for [`arc.Context[arc.RESTClient]`][arc.context.base.Context].""" + +RESTPlugin = RESTPluginBase[RESTClient] +"""A plugin using the default REST client implementation. An alias for [`arc.RESTPluginBase[arc.RESTClient]`][arc.plugin.RESTPluginBase].""" + +GatewayPlugin = GatewayPluginBase[GatewayClient] +"""An alias for [`arc.GatewayPluginBase[arc.GatewayClient]`][arc.plugin.GatewayPluginBase].""" + # MIT License # # Copyright (c) 2023-present hypergonial diff --git a/arc/command/__init__.py b/arc/command/__init__.py index bf3541f..44811c8 100644 --- a/arc/command/__init__.py +++ b/arc/command/__init__.py @@ -1,4 +1,3 @@ -from .base import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto from .message import MessageCommand, message_command from .option import ( AttachmentOption, @@ -7,17 +6,12 @@ BoolParams, ChannelOption, ChannelParams, - CommandOptionBase, FloatOption, FloatParams, IntOption, IntParams, MentionableOption, MentionableParams, - Option, - OptionParams, - OptionWithChoices, - OptionWithChoicesParams, RoleOption, RoleParams, StrOption, @@ -37,10 +31,6 @@ from .user import UserCommand, user_command __all__ = ( - "CommandBase", - "CommandProto", - "CallableCommandBase", - "CallableCommandProto", "SlashCommand", "SlashCommandLike", "SlashGroup", @@ -48,15 +38,10 @@ "SlashSubGroup", "slash_command", "slash_subcommand", - "CommandOptionBase", - "Option", "BoolOption", "BoolParams", "IntOption", "StrOption", - "OptionParams", - "OptionWithChoices", - "OptionWithChoicesParams", "IntParams", "StrParams", "FloatOption", diff --git a/arc/command/message.py b/arc/command/message.py index 821a59a..e6e9e41 100644 --- a/arc/command/message.py +++ b/arc/command/message.py @@ -5,15 +5,15 @@ import attr import hikari -from ..context import Context -from ..errors import CommandInvokeError -from ..internal.types import ClientT, MessageContextCallbackT, ResponseBuilderT -from .base import AutodeferMode, CallableCommandBase +from arc.abc.command import CallableCommandBase +from arc.context import AutodeferMode, Context +from arc.errors import CommandInvokeError +from arc.internal.types import ClientT, MessageContextCallbackT, ResponseBuilderT if t.TYPE_CHECKING: import asyncio - from .base import CallableCommandProto + from ..abc import CallableCommandProto __all__ = ("MessageCommand", "message_command") diff --git a/arc/command/option/__init__.py b/arc/command/option/__init__.py index 6f347c5..a5e07c9 100644 --- a/arc/command/option/__init__.py +++ b/arc/command/option/__init__.py @@ -1,13 +1,4 @@ from .attachment import AttachmentOption, AttachmentParams -from .base import ( - AutocompleteCallbackT, - CommandOptionBase, - Option, - OptionBase, - OptionParams, - OptionWithChoices, - OptionWithChoicesParams, -) from .bool import BoolOption, BoolParams from .channel import ChannelOption, ChannelParams from .float import FloatOption, FloatParams @@ -18,15 +9,8 @@ from .user import UserOption, UserParams __all__ = ( - "Option", - "CommandOptionBase", - "OptionBase", - "OptionWithChoices", - "AutocompleteCallbackT", "BoolOption", "BoolParams", - "OptionParams", - "OptionWithChoicesParams", "IntOption", "IntParams", "StrOption", diff --git a/arc/command/option/attachment.py b/arc/command/option/attachment.py index 6744396..aba37e3 100644 --- a/arc/command/option/attachment.py +++ b/arc/command/option/attachment.py @@ -5,12 +5,14 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("AttachmentOption", "AttachmentParams") + class AttachmentParams(OptionParams[hikari.Attachment]): """The parameters for an attachment option. diff --git a/arc/command/option/bool.py b/arc/command/option/bool.py index 5540ba6..2d2c09c 100644 --- a/arc/command/option/bool.py +++ b/arc/command/option/bool.py @@ -5,12 +5,14 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("BoolOption", "BoolParams") + class BoolParams(OptionParams[bool]): """The parameters for a bool option. diff --git a/arc/command/option/channel.py b/arc/command/option/channel.py index 06fdea6..3ff809f 100644 --- a/arc/command/option/channel.py +++ b/arc/command/option/channel.py @@ -5,12 +5,14 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("ChannelOption", "ChannelParams") + class ChannelParams(OptionParams[hikari.PartialChannel]): """The parameters for a channel option. diff --git a/arc/command/option/float.py b/arc/command/option/float.py index 3e6e980..80663e4 100644 --- a/arc/command/option/float.py +++ b/arc/command/option/float.py @@ -5,14 +5,16 @@ import attr import hikari -from ...internal.types import ClientT -from .base import OptionWithChoices, OptionWithChoicesParams +from arc.abc.option import OptionWithChoices, OptionWithChoicesParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te from ...internal.types import AutocompleteCallbackT +__all__ = ("FloatOption", "FloatParams") + class FloatParams(OptionWithChoicesParams[float, ClientT]): """The parameters for a float option. diff --git a/arc/command/option/int.py b/arc/command/option/int.py index d885e75..4913eef 100644 --- a/arc/command/option/int.py +++ b/arc/command/option/int.py @@ -5,14 +5,16 @@ import attr import hikari -from ...internal.types import ClientT -from .base import OptionWithChoices, OptionWithChoicesParams +from arc.abc.option import OptionWithChoices, OptionWithChoicesParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te from ...internal.types import AutocompleteCallbackT +__all__ = ("IntOption", "IntParams") + class IntParams(OptionWithChoicesParams[int, ClientT]): """The parameters for an int option. diff --git a/arc/command/option/mentionable.py b/arc/command/option/mentionable.py index a38412f..fd8fb33 100644 --- a/arc/command/option/mentionable.py +++ b/arc/command/option/mentionable.py @@ -5,12 +5,14 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("MentionableOption", "MentionableParams") + class MentionableParams(OptionParams[hikari.Role | hikari.User]): """The parameters for a mentionable option. diff --git a/arc/command/option/role.py b/arc/command/option/role.py index d367c57..dbbaba7 100644 --- a/arc/command/option/role.py +++ b/arc/command/option/role.py @@ -5,12 +5,14 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("RoleOption", "RoleParams") + class RoleParams(OptionParams[hikari.Role]): """The parameters for a user option. diff --git a/arc/command/option/str.py b/arc/command/option/str.py index 056c148..170cf01 100644 --- a/arc/command/option/str.py +++ b/arc/command/option/str.py @@ -5,13 +5,16 @@ import attr import hikari -from ...internal.types import ClientT -from .base import OptionWithChoices, OptionWithChoicesParams +from arc.abc.option import OptionWithChoices, OptionWithChoicesParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te - from ...internal.types import AutocompleteCallbackT + from arc.internal.types import AutocompleteCallbackT + + +__all__ = ("StrOption", "StrParams") class StrParams(OptionWithChoicesParams[str, ClientT]): diff --git a/arc/command/option/user.py b/arc/command/option/user.py index c6e38e9..42001e7 100644 --- a/arc/command/option/user.py +++ b/arc/command/option/user.py @@ -5,13 +5,16 @@ import attr import hikari -from ...internal.types import ClientT -from .base import CommandOptionBase, OptionParams +from arc.abc.option import CommandOptionBase, OptionParams +from arc.internal.types import ClientT if t.TYPE_CHECKING: import typing_extensions as te +__all__ = ("UserOption", "UserParams") + + class UserParams(OptionParams[hikari.User]): """The parameters for a user option. diff --git a/arc/command/slash.py b/arc/command/slash.py index c2a65aa..148bad8 100644 --- a/arc/command/slash.py +++ b/arc/command/slash.py @@ -6,20 +6,20 @@ import attr import hikari -from ..abc import HasErrorHandler -from ..context import AutocompleteData, Context -from ..errors import AutocompleteError, CommandInvokeError -from ..internal.sigparse import parse_function_signature -from ..internal.types import ClientT, CommandCallbackT, ResponseBuilderT, SlashCommandLike -from .base import AutodeferMode, CallableCommandBase, CommandBase -from .option import OptionBase, OptionWithChoices +from arc.abc.command import CallableCommandBase, CommandBase +from arc.abc.error_handler import HasErrorHandler +from arc.abc.option import OptionBase, OptionWithChoices +from arc.context import AutocompleteData, AutodeferMode, Context +from arc.errors import AutocompleteError, CommandInvokeError +from arc.internal.sigparse import parse_function_signature +from arc.internal.types import ClientT, CommandCallbackT, ResponseBuilderT, SlashCommandLike if t.TYPE_CHECKING: from asyncio.futures import Future - from ..plugin import Plugin - from .base import CallableCommandProto - from .option import CommandOptionBase + from arc.abc.command import CallableCommandProto + from arc.abc.option import CommandOptionBase + from arc.plugin import PluginBase __all__ = ( "SlashCommandLike", @@ -418,7 +418,7 @@ class SlashSubGroup(OptionBase[ClientT], HasErrorHandler[ClientT]): If undefined, then it will be inherited from the parent. """ - _plugin: Plugin[ClientT] | None = attr.field(default=None, init=False) + _plugin: PluginBase[ClientT] | None = attr.field(default=None, init=False) @property def option_type(self) -> hikari.OptionType: @@ -444,7 +444,7 @@ def client(self) -> ClientT: return self.parent.client @property - def plugin(self) -> Plugin[ClientT] | None: + def plugin(self) -> PluginBase[ClientT] | None: """The plugin that includes this subgroup.""" return self._plugin @@ -542,7 +542,7 @@ def client(self) -> ClientT: return self.root.client @property - def plugin(self) -> Plugin[ClientT] | None: + def plugin(self) -> PluginBase[ClientT] | None: """The plugin that includes this subcommand.""" return self.root.plugin diff --git a/arc/command/user.py b/arc/command/user.py index 80ca79d..8f80e6b 100644 --- a/arc/command/user.py +++ b/arc/command/user.py @@ -5,16 +5,16 @@ import attr import hikari -from ..context import Context -from ..errors import CommandInvokeError -from ..internal.types import ClientT, ResponseBuilderT -from .base import AutodeferMode, CallableCommandBase +from arc.abc.command import CallableCommandBase +from arc.context import AutodeferMode, Context +from arc.errors import CommandInvokeError +from arc.internal.types import ClientT, ResponseBuilderT if t.TYPE_CHECKING: import asyncio - from ..internal.types import UserContextCallbackT - from .base import CallableCommandProto + from arc.abc.command import CallableCommandProto + from arc.internal.types import UserContextCallbackT __all__ = ("UserCommand", "user_command") diff --git a/arc/context/autocomplete.py b/arc/context/autocomplete.py index e725d97..04dddf0 100644 --- a/arc/context/autocomplete.py +++ b/arc/context/autocomplete.py @@ -5,12 +5,12 @@ import attr -from ..internal.types import ChoiceT, ClientT +from arc.internal.types import ChoiceT, ClientT if t.TYPE_CHECKING: import hikari - from ..command import CommandProto + from arc.abc.command import CommandProto __all__ = ("AutocompleteData",) diff --git a/arc/context/base.py b/arc/context/base.py index 2032c57..017dc4d 100644 --- a/arc/context/base.py +++ b/arc/context/base.py @@ -10,11 +10,11 @@ import attr import hikari -from ..errors import NoResponseIssuedError, ResponseAlreadyIssuedError -from ..internal.types import ClientT, ResponseBuilderT +from arc.errors import NoResponseIssuedError, ResponseAlreadyIssuedError +from arc.internal.types import ClientT, ResponseBuilderT if t.TYPE_CHECKING: - from ..command import CallableCommandProto + from arc.abc.command import CallableCommandProto __all__ = ("Context", "InteractionResponse", "AutodeferMode") diff --git a/arc/events.py b/arc/events.py index 280d35c..8840223 100644 --- a/arc/events.py +++ b/arc/events.py @@ -4,10 +4,10 @@ import hikari -from .internal.types import GatewayClientT +from arc.internal.types import GatewayClientT if t.TYPE_CHECKING: - from .context import Context + from arc.context import Context __all__ = ("ArcEvent", "CommandErrorEvent") diff --git a/arc/extension.py b/arc/extension.py index 3d56171..fb65e10 100644 --- a/arc/extension.py +++ b/arc/extension.py @@ -4,7 +4,7 @@ import typing as t if t.TYPE_CHECKING: - from .internal.types import ClientT + from arc.internal.types import ClientT def loader(callback: t.Callable[[ClientT], None]) -> t.Callable[[ClientT], None]: diff --git a/arc/internal/deprecation.py b/arc/internal/deprecation.py index 1ccd5b6..7b6662d 100644 --- a/arc/internal/deprecation.py +++ b/arc/internal/deprecation.py @@ -3,7 +3,7 @@ import typing as t from warnings import warn -from .version import CURRENT_VERSION, Version +from arc.internal.version import CURRENT_VERSION, Version __all__ = ("warn_deprecate",) diff --git a/arc/internal/sigparse.py b/arc/internal/sigparse.py index 0e27dc0..82807d6 100644 --- a/arc/internal/sigparse.py +++ b/arc/internal/sigparse.py @@ -6,7 +6,8 @@ import hikari -from ..command.option import ( +from arc.abc.option import OptionParams +from arc.command.option import ( AttachmentOption, AttachmentParams, BoolOption, @@ -19,7 +20,6 @@ IntParams, MentionableOption, MentionableParams, - OptionParams, RoleOption, RoleParams, StrOption, @@ -29,9 +29,9 @@ ) if t.TYPE_CHECKING: - from ..command.option import CommandOptionBase - from ..context import Context - from .types import ClientT + from arc.abc.option import CommandOptionBase + from arc.context import Context + from arc.internal.types import ClientT __all__ = ("parse_function_signature",) diff --git a/arc/internal/sync.py b/arc/internal/sync.py index ca728ee..ea21b5d 100644 --- a/arc/internal/sync.py +++ b/arc/internal/sync.py @@ -9,9 +9,9 @@ import hikari if t.TYPE_CHECKING: - from ..client import Client - from ..command import CommandBase - from .types import AppT + from arc.abc.client import Client + from arc.abc.command import CommandBase + from arc.internal.types import AppT __all__ = ("_sync_commands",) diff --git a/arc/internal/types.py b/arc/internal/types.py index 5c160bf..f3394fc 100644 --- a/arc/internal/types.py +++ b/arc/internal/types.py @@ -5,9 +5,10 @@ if t.TYPE_CHECKING: import hikari - from ..client import Client, GatewayClient, RESTClient - from ..command import OptionParams, SlashCommand, SlashGroup - from ..context import AutocompleteData, Context + from arc.abc import Client, OptionParams + from arc.client import GatewayClient, RESTClient + from arc.command import SlashCommand, SlashGroup + from arc.context import AutocompleteData, Context # Generics diff --git a/arc/internal/version.py b/arc/internal/version.py index 93a85b6..27425c6 100644 --- a/arc/internal/version.py +++ b/arc/internal/version.py @@ -3,7 +3,7 @@ import typing as t from functools import total_ordering -from .about import __version__ +from arc.internal.about import __version__ if t.TYPE_CHECKING: import typing_extensions as te diff --git a/arc/plugin.py b/arc/plugin.py index 7a2aa87..1c1921e 100644 --- a/arc/plugin.py +++ b/arc/plugin.py @@ -8,22 +8,23 @@ import hikari -from .abc import HasErrorHandler -from .command import MessageCommand, SlashCommand, SlashGroup, UserCommand -from .context import AutodeferMode, Context -from .internal.types import BuilderT, ClientT, EventCallbackT, GatewayClientT, RESTClientT, SlashCommandLike +from arc.abc.error_handler import HasErrorHandler +from arc.command import MessageCommand, SlashCommand, SlashGroup, UserCommand +from arc.context import AutodeferMode, Context +from arc.internal.types import BuilderT, ClientT, EventCallbackT, GatewayClientT, RESTClientT, SlashCommandLike if t.TYPE_CHECKING: - from .command import CommandBase, SlashSubCommand, SlashSubGroup + from arc.abc.command import CommandBase + from arc.command import SlashSubCommand, SlashSubGroup -__all__ = ("Plugin", "RESTPlugin", "GatewayPlugin") +__all__ = ("PluginBase", "RESTPluginBase", "GatewayPluginBase") P = t.ParamSpec("P") T = t.TypeVar("T") -class Plugin(HasErrorHandler[ClientT], t.Generic[ClientT]): - """A base class for plugins. +class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]): + """An abstract base class for plugins. Parameters ---------- @@ -261,7 +262,7 @@ def decorator(*args: P.args, **kwargs: P.kwargs) -> T: return decorator -class RESTPlugin(Plugin[RESTClientT]): +class RESTPluginBase(PluginBase[RESTClientT]): """The default implementation of a REST plugin. Parameters @@ -275,7 +276,7 @@ def is_rest(self) -> bool: return True -class GatewayPlugin(Plugin[GatewayClientT]): +class GatewayPluginBase(PluginBase[GatewayClientT]): """The default implementation of a gateway plugin. Parameters diff --git a/docs/api_reference/abc/client.md b/docs/api_reference/abc/client.md new file mode 100644 index 0000000..17cd2e2 --- /dev/null +++ b/docs/api_reference/abc/client.md @@ -0,0 +1,8 @@ +--- +title: Client ABC +description: Abstract Base Classes API reference +--- + +# Client ABC + +::: arc.abc.client diff --git a/docs/api_reference/abc/command.md b/docs/api_reference/abc/command.md new file mode 100644 index 0000000..6c723df --- /dev/null +++ b/docs/api_reference/abc/command.md @@ -0,0 +1,8 @@ +--- +title: Command ABCs +description: Abstract Base Classes API reference +--- + +# Command ABCs + +::: arc.abc.command diff --git a/docs/api_reference/abc/error_handler.md b/docs/api_reference/abc/error_handler.md new file mode 100644 index 0000000..dd07dbb --- /dev/null +++ b/docs/api_reference/abc/error_handler.md @@ -0,0 +1,8 @@ +--- +title: Error Handling ABCs +description: Abstract Base Classes API reference +--- + +# Error Handling ABCs + +::: arc.abc.error_handler diff --git a/docs/api_reference/abc/index.md b/docs/api_reference/abc/index.md index 9f4ffcb..e6c9edf 100644 --- a/docs/api_reference/abc/index.md +++ b/docs/api_reference/abc/index.md @@ -5,6 +5,4 @@ description: Abstract Base Classes API reference # ABC -This is where you can find all the Abstract Base Classes defined by arc. You generally shouldn't need to use these classes directly, however most objects in arc derive from these, so they may be useful to you in some way. - -::: arc.abc +This is where you can find all the Abstract Base Classes defined by arc. You generally shouldn't need to use/inherit from these classes directly, however most objects in arc derive from these, so they may be useful to you in some way. diff --git a/docs/api_reference/abc/option.md b/docs/api_reference/abc/option.md new file mode 100644 index 0000000..2d2cf92 --- /dev/null +++ b/docs/api_reference/abc/option.md @@ -0,0 +1,8 @@ +--- +title: Option ABCs +description: Abstract Base Classes API reference +--- + +# Option ABCs + +::: arc.abc.option diff --git a/docs/api_reference/command/index.md b/docs/api_reference/command/index.md index cbc61af..a9be8e5 100644 --- a/docs/api_reference/command/index.md +++ b/docs/api_reference/command/index.md @@ -5,4 +5,4 @@ description: Command API reference # Command -::: arc.command.base +Here you can find all the concrete implementations of commands inside arc. For the abstract base classes see [here](../abc/index.md). diff --git a/docs/changelog.md b/docs/changelog.md index 6d543f9..a001f35 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -13,7 +13,13 @@ Here you can find all the changelogs for `hikari-arc`. - **Breaking:** Rename `Context.edit_response()` to `Context.edit_initial_response()`. This is to make the purpose of the function clearer. - **Breaking:** Remove `arc.Injected[T]` typehint alias. Use `arc.inject()` instead. This is to avoid confusion between the two. +- **Breaking:** Rename `GatewayPlugin` to `GatewayPluginBase` and `RESTPlugin` to `RESTPluginBase`. +- Add `GatewayContext` aliasing `Context[GatewayClient]` +- Add `RESTContext` aliasing `Context[RESTClient]` +- Add `GatewayPlugin` aliasing `GatewayPluginBase[GatewayClient]` +- Add `RESTPlugin` aliasing `RESTPlugin[RESTClient]` - Add support for passing mappings to `choices=` when specifying option params. +- Move `ABC`s used internally under `arc.abc`. - Improve handling missing responses via REST by adding `NoResponseIssuedError`. - Fix `@plugin.inject_dependencies` failing when located outside of the main module. diff --git a/docs/getting_started.md b/docs/getting_started.md index 38d967f..c495b93 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -49,7 +49,7 @@ If successful, it should output basic information about the library. @client.include @arc.slash_command(name="hi", description="Say hi to someone!") async def hi_slash( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, user: arc.Option[hikari.User, arc.UserParams(description="The user to say hi to.")] ) -> None: await ctx.respond(f"Hey {user.mention}!") @@ -71,7 +71,7 @@ If successful, it should output basic information about the library. @client.include @arc.slash_command(name="hi", description="Say hi to someone!") async def hi_slash( - ctx: arc.Context[arc.RESTClient], + ctx: arc.RESTContext, user: arc.Option[hikari.User, arc.UserParams(description="The user to say hi to.")] ) -> None: await ctx.respond(f"Hey {user.mention}!") diff --git a/docs/guides/dependency_injection.md b/docs/guides/dependency_injection.md index 7ff7d63..899c526 100644 --- a/docs/guides/dependency_injection.md +++ b/docs/guides/dependency_injection.md @@ -68,8 +68,7 @@ In the above example, we asked `arc` that every time we ask for a dependency of @arc.slash_command("increment", "Increment a counter!") # We inject a dependency of type 'MyDatabase' here. async def increment( - ctx: arc.Context[arc.GatewayClient], - db: MyDatabase = arc.inject() + ctx: arc.GatewayContext, db: MyDatabase = arc.inject() ) -> None: db.value += 1 await ctx.respond(f"Counter is at: `{db.value}`") @@ -83,8 +82,7 @@ In the above example, we asked `arc` that every time we ask for a dependency of @arc.slash_command("increment", "Increment a counter!") # We inject a dependency of type 'MyDatabase' here. async def increment( - ctx: arc.Context[arc.RESTClient], - db: MyDatabase = arc.inject() + ctx: arc.RESTContext, db: MyDatabase = arc.inject() ) -> None: db.value += 1 await ctx.respond(f"Counter is at: `{db.value}`") @@ -160,8 +158,7 @@ Let's say our app has two configurations, a "testing mode" where we want our "da # We inject 'Database' here, the caller doesn't know which # implementation it will get! async def fetch_data( - ctx: arc.Context[arc.GatewayClient], - db: Database = arc.inject() + ctx: arc.GatewayContext, db: Database = arc.inject() ) -> None: data = await db.fetch_data() await ctx.respond(f"Data is: `{data}`") @@ -184,8 +181,7 @@ Let's say our app has two configurations, a "testing mode" where we want our "da # We inject 'Database' here, the caller doesn't know which # implementation it will get! async def fetch_data( - ctx: arc.Context[arc.RESTClient], - db: Database = arc.inject() + ctx: arc.RESTContext, db: Database = arc.inject() ) -> None: data = await db.fetch_data() await ctx.respond(f"Data is: `{data}`") diff --git a/docs/guides/interactions.md b/docs/guides/interactions.md index fc5f0b9..6ec6c5a 100644 --- a/docs/guides/interactions.md +++ b/docs/guides/interactions.md @@ -87,7 +87,7 @@ await context.respond("I'm secret!", flags=hikari.MessageFlag.EPHEMERAL) @client.include @arc.slash_command("name", "description") - async def takes_time(context: arc.Context[arc.GatewayClient]) -> None: + async def takes_time(context: arc.GatewayContext) -> None: await asyncio.sleep(10) # Simulate something taking a long time await context.respond("Finished!") ``` @@ -102,7 +102,7 @@ await context.respond("I'm secret!", flags=hikari.MessageFlag.EPHEMERAL) @client.include @arc.slash_command("name", "description") - async def takes_time(context: arc.Context[arc.RESTClient]) -> None: + async def takes_time(context: arc.RESTContext) -> None: await asyncio.sleep(10) # Simulate something taking a long time await context.respond("Finished!") ``` @@ -125,7 +125,7 @@ This can be passed to the command decorator's `autodefer=` keyword argument: @client.include @arc.slash_command("name", "description", autodefer=AutodeferMode.EPHEMERAL) - async def takes_time(context: arc.Context[arc.GatewayClient]) -> None: + async def takes_time(context: arc.GatewayContext) -> None: await asyncio.sleep(10) await context.respond("Finished!") # This will now be an ephemeral response! ``` @@ -140,7 +140,7 @@ This can be passed to the command decorator's `autodefer=` keyword argument: @client.include @arc.slash_command("name", "description", autodefer=AutodeferMode.EPHEMERAL) - async def takes_time(context: arc.Context[arc.RESTClient]) -> None: + async def takes_time(context: arc.RESTContext) -> None: await asyncio.sleep(10) await context.respond("Finished!") # This will now be an ephemeral response! ``` diff --git a/mkdocs.yml b/mkdocs.yml index f13be4f..1da88b7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -77,7 +77,12 @@ nav: - api_reference/errors.md - api_reference/events.md - api_reference/plugin.md - - api_reference/abc/index.md + - ABC: + - api_reference/abc/index.md + - api_reference/abc/client.md + - api_reference/abc/command.md + - api_reference/abc/option.md + - api_reference/abc/error_handler.md - Changelog: changelog.md @@ -121,7 +126,7 @@ plugins: annotations_path: source docstring_section_style: spacy docstring_style: numpy - inherited_members: false + inherited_members: true merge_init_into_class: true separate_signature: true show_signature_annotations: true diff --git a/tests/test_context_command.py b/tests/test_context_command.py index f98ed18..71727ac 100644 --- a/tests/test_context_command.py +++ b/tests/test_context_command.py @@ -8,13 +8,13 @@ @client.include @arc.user_command(name="Say Hi") -async def ping_user(ctx: arc.Context[arc.GatewayClient], user: hikari.User) -> None: +async def ping_user(ctx: arc.GatewayContext, user: hikari.User) -> None: await ctx.respond(f"Hi {user}!") @client.include @arc.message_command(name="Say Hi") -async def ping_message(ctx: arc.Context[arc.GatewayClient], message: hikari.Message) -> None: +async def ping_message(ctx: arc.GatewayContext, message: hikari.Message) -> None: await ctx.respond(f"Hi {message.author}!") diff --git a/tests/test_extension/extension.py b/tests/test_extension/extension.py index a100e6b..eb6c2ee 100644 --- a/tests/test_extension/extension.py +++ b/tests/test_extension/extension.py @@ -2,13 +2,13 @@ import arc -plugin = arc.GatewayPlugin[arc.GatewayClient]("test_plugin") +plugin = arc.GatewayPlugin("test_plugin") @plugin.include @arc.slash_command("foo", "Test description") async def foo_command( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[float | None, arc.FloatParams(description="foo", max=50.0)], b: arc.Option[hikari.GuildChannel | None, arc.ChannelParams(description="bar")], ) -> None: @@ -18,7 +18,7 @@ async def foo_command( @plugin.include @arc.slash_command("bar", "Test description") async def bar_command( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[hikari.Role | hikari.User, arc.MentionableParams(description="foo")], b: arc.Option[hikari.Attachment | None, arc.AttachmentParams(description="bar")], ) -> None: @@ -28,7 +28,7 @@ async def bar_command( @plugin.include @arc.slash_command("baz", "Test description") async def baz_command( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], ) -> None: @@ -43,7 +43,7 @@ async def baz_command( @group.include @arc.slash_subcommand("test_subcommand", "My subcommand description") async def my_subcommand( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], ) -> None: diff --git a/tests/test_sigparse.py b/tests/test_sigparse.py index f1855bd..05bd088 100644 --- a/tests/test_sigparse.py +++ b/tests/test_sigparse.py @@ -1,5 +1,3 @@ -import typing as t - import hikari import pytest @@ -8,7 +6,7 @@ async def correct_command( - ctx: arc.Context[t.Any], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], c: arc.Option[float | None, arc.FloatParams(description="baz", max=50.0)], @@ -93,7 +91,7 @@ def test_correct_command() -> None: async def wrong_params_type( - ctx: arc.Context[t.Any], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.IntParams(description="bar", max=50)], ) -> None: @@ -109,7 +107,7 @@ class WrongType: pass -async def wrong_opt_type(ctx: arc.Context[t.Any], a: arc.Option[WrongType, arc.IntParams(description="foo")]) -> None: +async def wrong_opt_type(ctx: arc.GatewayContext, a: arc.Option[WrongType, arc.IntParams(description="foo")]) -> None: pass diff --git a/tests/test_slash.py b/tests/test_slash.py index 6760bb2..4ad7d40 100644 --- a/tests/test_slash.py +++ b/tests/test_slash.py @@ -9,7 +9,7 @@ @client.include @arc.slash_command("test", default_permissions=hikari.Permissions.ADMINISTRATOR) async def my_command( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], c: arc.Option[float | None, arc.FloatParams(description="baz", max=50.0)], @@ -31,7 +31,7 @@ async def my_command( @group.include @arc.slash_subcommand("test_subcommand", "My subcommand description") async def my_subcommand( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], ) -> None: @@ -41,7 +41,7 @@ async def my_subcommand( @subgroup.include @arc.slash_subcommand("test_subsubcommand", "My subsubcommand description") async def my_subsubcommand( - ctx: arc.Context[arc.GatewayClient], + ctx: arc.GatewayContext, a: arc.Option[int, arc.IntParams(description="foo", min=10)], b: arc.Option[str, arc.StrParams(description="bar", min_length=100)], ) -> None: