This commit is contained in:
2025-11-21 20:19:23 -05:00
parent 2c9fa9a14c
commit ee89f394bd
3 changed files with 2489 additions and 200 deletions

330
bot.py
View File

@@ -389,109 +389,35 @@ async def _get_active_voice_client(guild: Optional[discord.Guild]) -> Optional[d
return voice_client
async def connect_voice_with_retry(channel: discord.abc.Connectable, *, retries: int = 2, delay: float = 3.0) -> discord.VoiceClient:
async def connect_voice_with_retry(channel: discord.abc.Connectable) -> discord.VoiceClient:
"""
Robust voice connect helper that mitigates transient 4006 invalid session errors
by aggressively clearing stale session_id to force fresh handshakes.
Standard, simplified voice connection helper.
Uses standard Discord library methods without custom retry loops to avoid state conflicts.
"""
last_exc: Optional[Exception] = None
guild: Optional[discord.Guild] = getattr(channel, "guild", None)
if guild is None:
raise RuntimeError("Voice channel without guild cannot establish a connection.")
# Always start with a clean slate to avoid 4006 stale-session loops
await _force_cleanup_voice_client(guild)
await asyncio.sleep(0.5)
for attempt in range(1, retries + 1):
existing_vc: Optional[discord.VoiceClient] = await _get_active_voice_client(guild)
if existing_vc and existing_vc.channel == channel and existing_vc.is_connected():
logger.info("Re-using existing voice connection to %s (attempt %d/%d)", channel, attempt, retries)
return existing_vc
try:
logger.info("Attempting voice connect to %s (attempt %d/%d)", channel, attempt, retries)
# WORKAROUND for persistent 4006: Manually create VoiceClient and patch it to prevent session resume
from discord import VoiceClient as VoiceClientClass
state = guild._state
key_id = guild.id
# Remove any existing voice client to force fresh connection
state._remove_voice_client(key_id)
# Create new voice client
voice_client = VoiceClientClass(state._get_client(), channel)
# Force clear any stale session_id that might be cached
if hasattr(voice_client, 'session_id'):
voice_client.session_id = None
# Register it
state._add_voice_client(key_id, voice_client)
# Now connect with no reconnect to avoid resume attempts
try:
await voice_client.connect(timeout=25.0, reconnect=False)
except Exception:
# If connection fails, clean up the registered client
state._remove_voice_client(key_id)
raise
vc = voice_client
# Ensure we are self-undeafened so we can receive audio
try:
await channel.guild.change_voice_state(channel=channel, self_mute=False, self_deaf=False)
except Exception as change_err:
logger.debug("Unable to explicitly set voice state: %s", change_err)
if not vc.is_connected():
logger.warning("Voice connect returned but client not connected on attempt %d; cleaning up (guild %s)", attempt, guild.id)
await _force_cleanup_voice_client(guild)
raise ConnectionClosed(None, shard_id=None, code=4006)
logger.info("Voice connect succeeded on attempt %d (%s)", attempt, channel)
return vc
except ConnectionClosed as e:
last_exc = e
logger.warning(
"Voice connect failed with ConnectionClosed(code=%s, reason=%s) on attempt %d/%d",
getattr(e, "code", None),
getattr(e, "reason", None),
attempt,
retries,
)
await _force_cleanup_voice_client(guild)
except ClientException as e:
last_exc = e
logger.warning(
"Voice connect failed with ClientException('%s') on attempt %d/%d",
e,
attempt,
retries,
)
if "Already connected to a voice channel" in str(e):
existing_vc = await _get_active_voice_client(guild)
if existing_vc and existing_vc.channel == channel and existing_vc.is_connected():
logger.info("Detected active voice session despite ClientException; re-using existing client on guild %s", guild.id)
return existing_vc
await _force_cleanup_voice_client(guild)
except Exception as e:
last_exc = e
logger.exception("Voice connect raised %s on attempt %d/%d", e, attempt, retries)
await _force_cleanup_voice_client(guild)
# Reset voice state and wait before retrying
try:
await guild.change_voice_state(channel=None, self_mute=False, self_deaf=False)
# 1. Cleanup existing client if present
try:
old_vc = getattr(guild, "voice_client", None)
if old_vc:
if old_vc.channel == channel and old_vc.is_connected():
return old_vc
await old_vc.disconnect(force=True)
await asyncio.sleep(0.5)
except Exception as reset_err:
logger.debug("Failed to reset guild voice state: %s", reset_err)
if attempt < retries:
wait_time = delay * (1.5 ** (attempt - 1))
logger.info("Waiting %.1fs before retry %d (guild %s)", wait_time, attempt + 1, guild.id)
await asyncio.sleep(wait_time)
except Exception as e:
logger.debug("Error cleaning up old voice client: %s", e)
assert last_exc is not None
raise last_exc
# 2. Connect using standard library method
# Note: reconnect=True is the default and correct behavior for handling
# transient session errors (like 4006) internally by the library.
try:
voice_client = await channel.connect(timeout=20.0, reconnect=True)
return voice_client
except Exception as e:
logger.warning("Standard connect failed: %s", e)
raise
@dataclass
class QueueItem:
@@ -499,119 +425,123 @@ class QueueItem:
source_factory: Callable[[], discord.AudioSource]
announce: Optional[str] = None
if HAS_SINKS:
class HotwordStreamSink(sinks.Sink):
def __init__(
self,
state: "GuildAudioState",
text_channel: discord.abc.Messageable,
loop: asyncio.AbstractEventLoop,
min_chunk_seconds: float = 1.0,
window_seconds: float = 4.5,
inactivity_seconds: float = 1.0,
):
super().__init__()
self.state = state
self.text_channel = text_channel
self.loop = loop
self.closed = False
self.buffers: defaultdict[int, bytearray] = defaultdict(bytearray)
self.last_activity: defaultdict[int, float] = defaultdict(lambda: 0.0)
self.processing_users: set[int] = set()
self.pending_tasks: dict[int, concurrent.futures.Future] = {}
self.min_chunk_bytes = int(max(PCM_BYTES_PER_SECOND * min_chunk_seconds, PCM_BYTES_PER_SECOND * 0.5))
self.window_bytes = int(PCM_BYTES_PER_SECOND * window_seconds)
self.inactivity_seconds = inactivity_seconds
class HotwordStreamSink(sinks.Sink):
def __init__(
self,
state: "GuildAudioState",
text_channel: discord.abc.Messageable,
loop: asyncio.AbstractEventLoop,
min_chunk_seconds: float = 1.0,
window_seconds: float = 4.5,
inactivity_seconds: float = 1.0,
):
super().__init__()
self.state = state
self.text_channel = text_channel
self.loop = loop
self.closed = False
self.buffers: defaultdict[int, bytearray] = defaultdict(bytearray)
self.last_activity: defaultdict[int, float] = defaultdict(lambda: 0.0)
self.processing_users: set[int] = set()
self.pending_tasks: dict[int, concurrent.futures.Future] = {}
self.min_chunk_bytes = int(max(PCM_BYTES_PER_SECOND * min_chunk_seconds, PCM_BYTES_PER_SECOND * 0.5))
self.window_bytes = int(PCM_BYTES_PER_SECOND * window_seconds)
self.inactivity_seconds = inactivity_seconds
def close(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
self.buffers.clear()
self.processing_users.clear()
def update_text_channel(self, channel: discord.abc.Messageable):
self.text_channel = channel
def cleanup(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
return super().cleanup()
@sinks.Filters.container
def write(self, data, user):
if self.closed or user is None:
return
try:
user_id = int(user)
except Exception:
return
buffer = self.buffers[user_id]
buffer.extend(data)
if len(buffer) > self.window_bytes:
del buffer[: len(buffer) - int(self.window_bytes)]
now = time.perf_counter()
self.last_activity[user_id] = now
if len(buffer) < self.min_chunk_bytes:
return
existing = self.pending_tasks.get(user_id)
if existing and not existing.done():
existing.cancel()
self.pending_tasks.pop(user_id, None)
async def delayed_dispatch(uid: int, expected_time: float):
try:
await asyncio.sleep(self.inactivity_seconds)
if self.closed:
return
last = self.last_activity.get(uid, 0.0)
if abs(last - expected_time) > 1e-6:
return
buffer = self.buffers.get(uid)
if not buffer or len(buffer) < self.min_chunk_bytes:
return
if uid in self.processing_users:
return
self.processing_users.add(uid)
chunk = bytes(buffer)
buffer.clear()
def close(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
await self.state.handle_hotword_buffer(uid, chunk, self.text_channel)
finally:
self.processing_users.discard(uid)
except asyncio.CancelledError:
return
finally:
self.pending_tasks.pop(uid, None)
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
self.buffers.clear()
self.processing_users.clear()
future = asyncio.run_coroutine_threadsafe(delayed_dispatch(user_id, now), self.loop)
def update_text_channel(self, channel: discord.abc.Messageable):
self.text_channel = channel
def _done_callback(fut, uid=user_id):
if fut.cancelled():
def cleanup(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
return super().cleanup()
@sinks.Filters.container
def write(self, data, user):
if self.closed or user is None:
return
try:
fut.result()
except asyncio.CancelledError:
user_id = int(user)
except Exception:
return
except Exception as exc:
logger.exception("Hotword delayed dispatch failed for user %s: %s", uid, exc)
finally:
self.pending_tasks.pop(uid, None)
future.add_done_callback(_done_callback)
self.pending_tasks[user_id] = future
buffer = self.buffers[user_id]
buffer.extend(data)
if len(buffer) > self.window_bytes:
del buffer[: len(buffer) - int(self.window_bytes)]
now = time.perf_counter()
self.last_activity[user_id] = now
if len(buffer) < self.min_chunk_bytes:
return
existing = self.pending_tasks.get(user_id)
if existing and not existing.done():
existing.cancel()
self.pending_tasks.pop(user_id, None)
async def delayed_dispatch(uid: int, expected_time: float):
try:
await asyncio.sleep(self.inactivity_seconds)
if self.closed:
return
last = self.last_activity.get(uid, 0.0)
if abs(last - expected_time) > 1e-6:
return
buffer = self.buffers.get(uid)
if not buffer or len(buffer) < self.min_chunk_bytes:
return
if uid in self.processing_users:
return
self.processing_users.add(uid)
chunk = bytes(buffer)
buffer.clear()
try:
await self.state.handle_hotword_buffer(uid, chunk, self.text_channel)
finally:
self.processing_users.discard(uid)
except asyncio.CancelledError:
return
finally:
self.pending_tasks.pop(uid, None)
future = asyncio.run_coroutine_threadsafe(delayed_dispatch(user_id, now), self.loop)
def _done_callback(fut, uid=user_id):
if fut.cancelled():
return
try:
fut.result()
except asyncio.CancelledError:
return
except Exception as exc:
logger.exception("Hotword delayed dispatch failed for user %s: %s", uid, exc)
finally:
self.pending_tasks.pop(uid, None)
future.add_done_callback(_done_callback)
self.pending_tasks[user_id] = future
else:
class HotwordStreamSink: # type: ignore
def __init__(self, *args, **kwargs):
pass
@dataclass