Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix invalid type hinting of initial_channels in WS Connection (Python…
Browse files Browse the repository at this point in the history
junah201 committed Jul 3, 2024

Verified

This commit was signed with the committer’s verified signature.
junah201 Jun-Ah 준아
1 parent 8e91cec commit fe807b3
Showing 3 changed files with 47 additions and 26 deletions.
2 changes: 1 addition & 1 deletion twitchio/client.py
Original file line number Diff line number Diff line change
@@ -76,7 +76,7 @@ def __init__(
token: str,
*,
client_secret: str = None,
initial_channels: Union[list, tuple, Callable] = None,
initial_channels: Union[List[str], Tuple[str, ...], Callable[[], Union[List[str], Tuple[str, ...]]], None] = None,
loop: asyncio.AbstractEventLoop = None,
heartbeat: Optional[float] = 30.0,
retain_cache: Optional[bool] = True,
4 changes: 2 additions & 2 deletions twitchio/ext/commands/bot.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@
import types
import warnings
from functools import partial
from typing import Callable, Optional, Union, Coroutine, Dict, List, TYPE_CHECKING, Mapping, Awaitable
from typing import Callable, Optional, Union, Coroutine, Dict, List, Tuple, TYPE_CHECKING, Mapping, Awaitable

from twitchio.client import Client
from twitchio.http import TwitchHTTP
@@ -53,7 +53,7 @@ def __init__(
*,
prefix: Union[str, list, tuple, set, Callable, Coroutine],
client_secret: str = None,
initial_channels: Union[list, tuple, Callable] = None,
initial_channels: Union[List[str], Tuple[str, ...], Callable[[], Union[List[str], Tuple[str, ...]]], None] = None,
heartbeat: Optional[float] = 30.0,
retain_cache: Optional[bool] = True,
**kwargs,
67 changes: 44 additions & 23 deletions twitchio/websocket.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@
import time
import traceback
from functools import partial
from typing import Union, Optional, List, TYPE_CHECKING
from typing import Union, Callable, Optional, List, Tuple, TYPE_CHECKING

import aiohttp

@@ -57,7 +57,7 @@ def __init__(
client: "Client",
token: str = None,
modes: tuple = None,
initial_channels: List[str] = None,
initial_channels: Union[List[str], Tuple[str, ...], Callable[[], Union[List[str], Tuple[str, ...]]], None] = None,
retain_cache: Optional[bool] = True,
):
self._loop = loop
@@ -115,7 +115,8 @@ def __init__(
async def _task_cleanup(self):
while True:
# keep all undone tasks
self._background_tasks = list(filter(lambda task: not task.done(), self._background_tasks))
self._background_tasks = list(
filter(lambda task: not task.done(), self._background_tasks))

# cleanup tasks every 30 seconds
await asyncio.sleep(30)
@@ -134,13 +135,15 @@ async def _connect(self):
if self._keeper:
self._keeper.cancel() # Stop our current keep alive.
if self.is_alive:
await self._websocket.close() # If for some reason we are in a weird state, close it before retrying.
# If for some reason we are in a weird state, close it before retrying.
await self._websocket.close()
if not self._client._http.nick:
try:
data = await self._client._http.validate(token=self._token)
except AuthenticationError:
await self._client._http.session.close()
self._client._closing.set() # clean up and error out (this is called to avoid calling Client.close in start()
# clean up and error out (this is called to avoid calling Client.close in start()
self._client._closing.set()
raise
self.nick = data["login"]
self.user_id = int(data["user_id"])
@@ -152,18 +155,21 @@ async def _connect(self):
self._websocket = await session.ws_connect(url=HOST, heartbeat=self._heartbeat)
except Exception as e:
retry = self._backoff.delay()
log.error(f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.")
log.error(
f"Websocket connection failure: {e}:: Attempting reconnect in {retry} seconds.")

await asyncio.sleep(retry)
return await self._connect()

await self.authenticate(self._initial_channels)

self._reconnect_requested = False
self._keeper = asyncio.create_task(self._keep_alive()) # Create our keep alive.
# Create our keep alive.
self._keeper = asyncio.create_task(self._keep_alive())

if not self._task_cleaner or self._task_cleaner.done():
self._task_cleaner = asyncio.create_task(self._task_cleanup()) # Create our task cleaner.
# Create our task cleaner.
self._task_cleaner = asyncio.create_task(self._task_cleanup())

self._ws_ready_event.set()

@@ -182,28 +188,33 @@ async def _keep_alive(self):
data = msg.data
if data:
log.debug(f" < {data}")
self.dispatch("raw_data", data) # Dispatch our event_raw_data event...
# Dispatch our event_raw_data event...
self.dispatch("raw_data", data)

events = data.split("\r\n")
for event in events:
if not event:
continue
task = asyncio.create_task(self._process_data(event))
task.add_done_callback(partial(self._task_callback, event)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, event))
self._background_tasks.append(task)

self._background_tasks.append(asyncio.create_task(self._connect()))

def _task_callback(self, data, task):
exc = task.exception()

if isinstance(exc, AuthenticationError): # Check if we failed to log in...
log.error("Authentication error. Please check your credentials and try again.")
# Check if we failed to log in...
if isinstance(exc, AuthenticationError):
log.error(
"Authentication error. Please check your credentials and try again.")
self._close()
elif exc:
# event_error task need to be shielded to avoid cancelling in self._close() function
# we need ensure, that the event will print its traceback
shielded_task = asyncio.shield(asyncio.create_task(self.event_error(exc, data)))
shielded_task = asyncio.shield(
asyncio.create_task(self.event_error(exc, data)))
self._background_tasks.append(shielded_task)

async def send(self, message: str):
@@ -218,7 +229,8 @@ async def send(self, message: str):
dummy = f"> :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n"

task = asyncio.create_task(self._process_data(dummy))
task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, dummy))
self._background_tasks.append(task)
await self._websocket.send_str(message + "\r\n")

@@ -233,7 +245,8 @@ async def reply(self, msg_id: str, message: str):

dummy = f"> @reply-parent-msg-id={msg_id} :{self.nick}!{self.nick}@{self.nick}.tmi.twitch.tv PRIVMSG(ECHO) #{channel} {content}\r\n"
task = asyncio.create_task(self._process_data(dummy))
task.add_done_callback(partial(self._task_callback, dummy)) # Process our raw data
# Process our raw data
task.add_done_callback(partial(self._task_callback, dummy))
self._background_tasks.append(task)
await self._websocket.send_str(f"@reply-parent-msg-id={msg_id} {message} \r\n")

@@ -258,7 +271,8 @@ async def authenticate(self, channels: Union[list, tuple]):
await self.send(f"NICK {self.nick}\r\n")

for cap in self.modes:
await self.send(f"CAP REQ :twitch.tv/{cap}") # Ideally no one should overwrite defaults...
# Ideally no one should overwrite defaults...
await self.send(f"CAP REQ :twitch.tv/{cap}")
if not channels and not self._initial_channels:
return
channels = channels or self._initial_channels
@@ -305,10 +319,12 @@ async def join_channels(self, *channels: str):
channel_count = len(channels)
if channel_count > 20:
timeout = self._assign_timeout(channel_count)
chunks = [channels[i : i + 20] for i in range(0, len(channels), 20)]
chunks = [channels[i: i + 20]
for i in range(0, len(channels), 20)]
for chunk in chunks:
for channel in chunk:
task = asyncio.create_task(self._join_channel(channel, timeout))
task = asyncio.create_task(
self._join_channel(channel, timeout))
self._background_tasks.append(task)

await asyncio.sleep(11)
@@ -322,7 +338,8 @@ async def _join_channel(self, entry: str, timeout: int):
await self.send(f"JOIN #{channel}\r\n")

self._join_pending[channel] = fut = self._loop.create_future()
self._background_tasks.append(asyncio.create_task(self._join_future_handle(fut, channel, timeout)))
self._background_tasks.append(asyncio.create_task(
self._join_future_handle(fut, channel, timeout)))

async def _join_future_handle(self, fut: asyncio.Future, channel: str, timeout: int):
try:
@@ -498,7 +515,8 @@ async def _usernotice(self, parsed):
channel = Channel(name=parsed["channel"], websocket=self)
rawData = parsed["groups"][0]
tags = dict(x.split("=", 1) for x in rawData.split(";"))
tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[0].strip()
tags["user-type"] = tags["user-type"].split(":tmi.twitch.tv")[
0].strip()

self.dispatch("raw_usernotice", channel, tags)

@@ -552,7 +570,8 @@ def _cache_add(self, parsed: dict):

if parsed["batches"]:
for u in parsed["batches"]:
user = PartialChatter(name=u, bot=self._client, websocket=self, channel=channel_)
user = PartialChatter(
name=u, bot=self._client, websocket=self, channel=channel_)
self._cache[channel].add(user)
else:
name = parsed["user"] or parsed["nick"]
@@ -570,7 +589,8 @@ async def _mode(self, parsed): # TODO
pass

async def _reconnect(self, parsed):
log.debug("ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.")
log.debug(
"ACTION: RECONNECT:: Twitch has gracefully closed the connection and will reconnect.")
self._reconnect_requested = True
self._keeper.cancel()
self._loop.create_task(self._connect())
@@ -582,7 +602,8 @@ def dispatch(self, event: str, *args, **kwargs):
self._client.run_event(event, *args, **kwargs)

async def event_error(self, error: Exception, data: str = None):
traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr)
traceback.print_exception(
type(error), error, error.__traceback__, file=sys.stderr)

def _fetch_futures(self):
return [

0 comments on commit fe807b3

Please sign in to comment.