This commit is contained in:
2025-12-05 23:27:43 -05:00
parent 009584f497
commit 44901a44b7
8 changed files with 4847 additions and 196 deletions

450
bot.py
View File

@@ -15,20 +15,18 @@ from typing import Callable, Deque, Optional, Tuple
import discord
import numpy as np
import pyttsx3
from TTS.api import TTS as CoquiTTS
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor
from discord import Intents
from discord.errors import ClientException, ConnectionClosed
from discord.ext import voice_recv
from dotenv import load_dotenv
from yt_dlp import YoutubeDL
from stt import transcribe_file
try:
from discord import sinks # Available in discord.py >=2.0 and py-cord
HAS_SINKS = True
except Exception:
HAS_SINKS = False
HAS_VOICE_RECV = True
load_dotenv()
@@ -65,6 +63,7 @@ if not _have_file:
# Tweak library log levels
logging.getLogger("discord").setLevel(logging.INFO)
logging.getLogger("aiohttp").setLevel(logging.INFO)
logging.getLogger("discord.ext.voice_recv.opus").setLevel(logging.ERROR) # Suppress packet loss warnings
logger = logging.getLogger("basharbot")
@@ -115,6 +114,16 @@ COMMAND_ALIASES = {
"next": "skip",
}
# Verbal responses (Arabic-friendly)
VERBAL_RESPONSES = {
"join": "نعم، أنا هنا", # "Yes, I am here"
"leave": "مع السلامة", # "Goodbye"
"play": "حسناً", # "Okay"
"skip": "التالي", # "Next"
"stop": "توقف", # "Stop"
"unknown": "ماذا تريد؟", # "What do you want?"
}
PCM_SAMPLE_RATE = 48000
PCM_CHANNELS = 2
PCM_SAMPLE_WIDTH = 2 # bytes per sample
@@ -123,6 +132,8 @@ TRANSCRIPT_LOG_ENABLED = os.getenv("TRANSCRIPT_LOG_ENABLED", "true").lower() in
TRANSCRIPT_LOG_PATH = os.getenv("TRANSCRIPT_LOG_PATH", "transcript.log")
GOODBOY_USER_ID = int(os.getenv("GOODBOY_USER_ID", "94578724413902848"))
GOODBOY_AUDIO_PATH = os.path.join(os.getcwd(), "goodboy.ogg")
USE_ARABIC_TTS = os.getenv("USE_ARABIC_TTS", "true").lower() in {"1", "true", "yes", "on"}
ARABIC_TTS_MODEL = os.getenv("ARABIC_TTS_MODEL", "tts_models/ar/cv/vits")
def _display_name(user: object) -> str:
@@ -157,6 +168,39 @@ def is_probably_english_sentence(text: str) -> bool:
return bool(_ENGLISH_SENTENCE_RE.match(text))
async def speak_response(voice_client: Optional[discord.VoiceClient], response_key: str) -> None:
"""Speak a verbal response in Arabic if enabled and connected to voice."""
if not voice_client or not voice_client.is_connected():
return
if not USE_ARABIC_TTS:
return
response_text = VERBAL_RESPONSES.get(response_key)
if not response_text:
return
try:
with tempfile.TemporaryDirectory() as tmpdir:
tts_path = os.path.join(tmpdir, "response.wav")
await synthesize_tts_to_wav(response_text, tts_path, use_arabic=True)
if voice_client.is_playing():
# Wait a bit if already playing
await asyncio.sleep(0.5)
source = discord.FFmpegPCMAudio(tts_path, **FFMPEG_OPTIONS)
fut = asyncio.get_running_loop().create_future()
def after_playback(_):
if not fut.done():
fut.set_result(True)
voice_client.play(source, after=after_playback)
await fut
except Exception as e:
logger.debug("Failed to speak response: %s", e)
async def announce_listening_roster(channel, voice_channel: Optional[discord.VoiceChannel]):
if channel is None or voice_channel is None:
return
@@ -279,6 +323,7 @@ def make_tts_engine() -> pyttsx3.Engine:
_tts_engine_singleton: Optional[pyttsx3.Engine] = None
_arabic_tts_singleton: Optional[CoquiTTS] = None
_tts_executor: Optional[ThreadPoolExecutor] = None
@@ -289,6 +334,21 @@ def get_tts_engine() -> pyttsx3.Engine:
return _tts_engine_singleton
def get_arabic_tts() -> Optional[CoquiTTS]:
global _arabic_tts_singleton
if not USE_ARABIC_TTS:
return None
if _arabic_tts_singleton is None:
try:
logger.info("Loading Arabic TTS model: %s", ARABIC_TTS_MODEL)
_arabic_tts_singleton = CoquiTTS(model_name=ARABIC_TTS_MODEL, progress_bar=False, gpu=False)
logger.info("Arabic TTS model loaded successfully")
except Exception as e:
logger.error("Failed to load Arabic TTS model: %s", e)
return None
return _arabic_tts_singleton
def get_tts_executor() -> ThreadPoolExecutor:
global _tts_executor
if _tts_executor is None:
@@ -296,9 +356,22 @@ def get_tts_executor() -> ThreadPoolExecutor:
return _tts_executor
async def synthesize_tts_to_wav(text: str, wav_path: str) -> str:
"""Generate TTS to a WAV file using pyttsx3 in a background thread."""
async def synthesize_tts_to_wav(text: str, wav_path: str, use_arabic: bool = False) -> str:
"""Generate TTS to a WAV file using Coqui TTS (Arabic) or pyttsx3 (English)."""
loop = asyncio.get_running_loop()
if use_arabic and USE_ARABIC_TTS:
arabic_tts = get_arabic_tts()
if arabic_tts:
def _save_arabic():
logger.debug("Synthesizing Arabic TTS to %s: %s", wav_path, (text if len(text) < 120 else text[:117] + "..."))
arabic_tts.tts_to_file(text=text, file_path=wav_path)
await loop.run_in_executor(get_tts_executor(), _save_arabic)
logger.debug("Arabic TTS synthesis complete: %s", wav_path)
return wav_path
# Fallback to pyttsx3
engine = get_tts_engine()
def _save():
@@ -389,35 +462,31 @@ async def _get_active_voice_client(guild: Optional[discord.Guild]) -> Optional[d
return voice_client
async def connect_voice_with_retry(channel: discord.abc.Connectable) -> discord.VoiceClient:
async def connect_voice_with_retry(channel: discord.abc.Connectable) -> voice_recv.VoiceRecvClient:
"""
Standard, simplified voice connection helper.
Uses standard Discord library methods without custom retry loops to avoid state conflicts.
Connect using VoiceRecvClient to enable voice receiving.
"""
guild: Optional[discord.Guild] = getattr(channel, "guild", None)
if guild is None:
raise RuntimeError("Voice channel without guild cannot establish a connection.")
# 1. Cleanup existing client if present
try:
old_vc = getattr(guild, "voice_client", None)
if old_vc:
if old_vc.channel == channel and old_vc.is_connected():
return old_vc
# Cleanup existing client if present
old_vc = getattr(guild, "voice_client", None)
if old_vc:
if old_vc.channel == channel and old_vc.is_connected():
logger.debug("Already connected to target channel")
return old_vc
try:
await old_vc.disconnect(force=True)
await asyncio.sleep(0.5)
except Exception as e:
logger.debug("Error cleaning up old voice client: %s", e)
await asyncio.sleep(1.0) # Give Discord time to clean up
except Exception as e:
logger.debug("Error cleaning up old voice client: %s", e)
# 2. Connect using standard library method
# Note: reconnect=True is the default and correct behavior for handling
# transient session errors (like 4006) internally by the library.
try:
voice_client = await channel.connect(timeout=20.0, reconnect=True)
return voice_client
except Exception as e:
logger.warning("Standard connect failed: %s", e)
raise
# Connect with VoiceRecvClient to enable receiving
logger.info("Connecting to voice channel: %s", getattr(channel, "name", "?"))
voice_client = await channel.connect(cls=voice_recv.VoiceRecvClient, timeout=30.0, reconnect=True)
logger.info("Successfully connected to voice with VoiceRecvClient")
return voice_client
@dataclass
class QueueItem:
@@ -425,123 +494,116 @@ class QueueItem:
source_factory: Callable[[], discord.AudioSource]
announce: Optional[str] = None
if HAS_SINKS:
class HotwordStreamSink(sinks.Sink):
def __init__(
self,
state: "GuildAudioState",
text_channel: discord.abc.Messageable,
loop: asyncio.AbstractEventLoop,
min_chunk_seconds: float = 1.0,
window_seconds: float = 4.5,
inactivity_seconds: float = 1.0,
):
super().__init__()
self.state = state
self.text_channel = text_channel
self.loop = loop
self.closed = False
self.buffers: defaultdict[int, bytearray] = defaultdict(bytearray)
self.last_activity: defaultdict[int, float] = defaultdict(lambda: 0.0)
self.processing_users: set[int] = set()
self.pending_tasks: dict[int, concurrent.futures.Future] = {}
self.min_chunk_bytes = int(max(PCM_BYTES_PER_SECOND * min_chunk_seconds, PCM_BYTES_PER_SECOND * 0.5))
self.window_bytes = int(PCM_BYTES_PER_SECOND * window_seconds)
self.inactivity_seconds = inactivity_seconds
class HotwordStreamSink(voice_recv.AudioSink):
def __init__(
self,
state: "GuildAudioState",
text_channel: discord.abc.Messageable,
loop: asyncio.AbstractEventLoop,
min_chunk_seconds: float = 1.0,
window_seconds: float = 4.5,
inactivity_seconds: float = 1.0,
):
super().__init__()
self.state = state
self.text_channel = text_channel
self.loop = loop
self.closed = False
self.buffers: defaultdict[int, bytearray] = defaultdict(bytearray)
self.last_activity: defaultdict[int, float] = defaultdict(lambda: 0.0)
self.processing_users: set[int] = set()
self.pending_tasks: dict[int, concurrent.futures.Future] = {}
self.min_chunk_bytes = int(max(PCM_BYTES_PER_SECOND * min_chunk_seconds, PCM_BYTES_PER_SECOND * 0.5))
self.window_bytes = int(PCM_BYTES_PER_SECOND * window_seconds)
self.inactivity_seconds = inactivity_seconds
def close(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
def wants_opus(self) -> bool:
# We want decoded PCM, not Opus packets
return False
def close(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
self.buffers.clear()
self.processing_users.clear()
def update_text_channel(self, channel: discord.abc.Messageable):
self.text_channel = channel
def cleanup(self):
self.close()
def write(self, user: discord.User, data: voice_recv.VoiceData):
if self.closed or user is None:
return
# Get PCM data from VoiceData
pcm_data = data.pcm
if not pcm_data:
return
user_id = user.id
buffer = self.buffers[user_id]
buffer.extend(pcm_data)
if len(buffer) > self.window_bytes:
del buffer[: len(buffer) - int(self.window_bytes)]
now = time.perf_counter()
self.last_activity[user_id] = now
if len(buffer) < self.min_chunk_bytes:
return
existing = self.pending_tasks.get(user_id)
if existing and not existing.done():
existing.cancel()
self.pending_tasks.pop(user_id, None)
async def delayed_dispatch(uid: int, expected_time: float):
try:
await asyncio.sleep(self.inactivity_seconds)
if self.closed:
return
last = self.last_activity.get(uid, 0.0)
if abs(last - expected_time) > 1e-6:
return
buffer = self.buffers.get(uid)
if not buffer or len(buffer) < self.min_chunk_bytes:
return
if uid in self.processing_users:
return
self.processing_users.add(uid)
chunk = bytes(buffer)
buffer.clear()
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
self.buffers.clear()
self.processing_users.clear()
await self.state.handle_hotword_buffer(uid, chunk, self.text_channel)
finally:
self.processing_users.discard(uid)
except asyncio.CancelledError:
return
finally:
self.pending_tasks.pop(uid, None)
def update_text_channel(self, channel: discord.abc.Messageable):
self.text_channel = channel
future = asyncio.run_coroutine_threadsafe(delayed_dispatch(user_id, now), self.loop)
def cleanup(self):
self.closed = True
for fut in list(self.pending_tasks.values()):
try:
fut.cancel()
except Exception:
pass
self.pending_tasks.clear()
return super().cleanup()
@sinks.Filters.container
def write(self, data, user):
if self.closed or user is None:
def _done_callback(fut, uid=user_id):
if fut.cancelled():
return
try:
user_id = int(user)
except Exception:
fut.result()
except asyncio.CancelledError:
return
except Exception as exc:
logger.exception("Hotword delayed dispatch failed for user %s: %s", uid, exc)
finally:
self.pending_tasks.pop(uid, None)
buffer = self.buffers[user_id]
buffer.extend(data)
if len(buffer) > self.window_bytes:
del buffer[: len(buffer) - int(self.window_bytes)]
now = time.perf_counter()
self.last_activity[user_id] = now
if len(buffer) < self.min_chunk_bytes:
return
existing = self.pending_tasks.get(user_id)
if existing and not existing.done():
existing.cancel()
self.pending_tasks.pop(user_id, None)
async def delayed_dispatch(uid: int, expected_time: float):
try:
await asyncio.sleep(self.inactivity_seconds)
if self.closed:
return
last = self.last_activity.get(uid, 0.0)
if abs(last - expected_time) > 1e-6:
return
buffer = self.buffers.get(uid)
if not buffer or len(buffer) < self.min_chunk_bytes:
return
if uid in self.processing_users:
return
self.processing_users.add(uid)
chunk = bytes(buffer)
buffer.clear()
try:
await self.state.handle_hotword_buffer(uid, chunk, self.text_channel)
finally:
self.processing_users.discard(uid)
except asyncio.CancelledError:
return
finally:
self.pending_tasks.pop(uid, None)
future = asyncio.run_coroutine_threadsafe(delayed_dispatch(user_id, now), self.loop)
def _done_callback(fut, uid=user_id):
if fut.cancelled():
return
try:
fut.result()
except asyncio.CancelledError:
return
except Exception as exc:
logger.exception("Hotword delayed dispatch failed for user %s: %s", uid, exc)
finally:
self.pending_tasks.pop(uid, None)
future.add_done_callback(_done_callback)
self.pending_tasks[user_id] = future
else:
class HotwordStreamSink: # type: ignore
def __init__(self, *args, **kwargs):
pass
future.add_done_callback(_done_callback)
self.pending_tasks[user_id] = future
@dataclass
@@ -679,8 +741,8 @@ class GuildAudioState:
if not HOTWORD_ENABLED:
logger.debug("Hotword listening disabled by environment (guild %s)", self.guild_id)
return
if not HAS_SINKS:
logger.warning("Hotword listening requested but sinks are unavailable on this stack.")
if not HAS_VOICE_RECV:
logger.warning("Hotword listening requested but voice_recv is unavailable.")
try:
await text_channel.send("Live hotword listening is unavailable on this install. Send a voice message instead.")
except Exception:
@@ -689,6 +751,10 @@ class GuildAudioState:
if not self.voice_client or not self.voice_client.is_connected():
logger.debug("Cannot start listener without an active voice client (guild %s)", self.guild_id)
return
if not isinstance(self.voice_client, voice_recv.VoiceRecvClient):
logger.warning("Voice client is not VoiceRecvClient, cannot listen (guild %s)", self.guild_id)
return
self.listen_enabled = True
self.last_transcripts.clear()
if self.hotword_sink and not self.hotword_sink.closed:
@@ -696,10 +762,10 @@ class GuildAudioState:
logger.debug("Hotword listener already running (guild %s)", self.guild_id)
return
# If another recording is running, stop it first
if getattr(self.voice_client, "recording", False):
# If already listening, stop first
if self.voice_client.is_listening():
try:
self.voice_client.stop_recording()
self.voice_client.stop_listening()
except Exception:
pass
@@ -708,10 +774,7 @@ class GuildAudioState:
self.hotword_sink = sink
logger.info("Starting continuous hotword listener (guild %s)", self.guild_id)
async def _finished_callback(sink_obj, *_):
await self._on_sink_finished(sink_obj)
self.voice_client.start_recording(sink, _finished_callback)
self.voice_client.listen(sink)
channel = getattr(self.voice_client, "channel", None)
if channel:
@@ -737,11 +800,12 @@ class GuildAudioState:
sink = self.hotword_sink
if sink:
sink.close()
if self.voice_client and getattr(self.voice_client, "recording", False):
try:
self.voice_client.stop_recording()
except Exception:
pass
if self.voice_client and isinstance(self.voice_client, voice_recv.VoiceRecvClient):
if self.voice_client.is_listening():
try:
self.voice_client.stop_listening()
except Exception:
pass
self.hotword_sink = None
async def handle_hotword_buffer(self, user_id: int, pcm_bytes: bytes, text_channel: discord.abc.Messageable):
@@ -833,51 +897,39 @@ def get_state_for_guild(guild_id: int) -> GuildAudioState:
async def connect_to_author_channel(message: discord.Message) -> Optional[discord.VoiceClient]:
if not isinstance(message.author, discord.Member):
return None
logger.debug("Connect requested by %s in guild %s", message.author, getattr(message.guild, "id", "?"))
voice_state = message.author.voice
if not voice_state or not voice_state.channel:
logger.info("Author not in a voice channel; cannot join (guild %s)", getattr(message.guild, "id", "?"))
await message.channel.send("Join a voice channel first, then say 'hey bashar join'.")
return None
channel = voice_state.channel
vc = await _get_active_voice_client(message.guild)
guild = message.guild
# Check if already connected to the right channel
vc = guild.voice_client
if vc and vc.channel == channel and vc.is_connected():
logger.debug("Already connected to requested channel: %s (guild %s)", channel, getattr(message.guild, "id", "?"))
logger.debug("Already connected to target channel")
return vc
if vc:
try:
logger.info("Moving voice client to channel: %s (guild %s)", channel, getattr(message.guild, "id", "?"))
# Move or reconnect
try:
if vc and vc.is_connected():
logger.info("Moving to channel: %s", channel.name)
await vc.move_to(channel)
await announce_listening_roster(message.channel, channel)
return vc
except Exception as e:
logger.warning("Move failed; reconnecting fresh (guild %s): %s", getattr(message.guild, "id", "?"), e)
try:
await vc.disconnect(force=True)
except Exception:
pass
try:
else:
if vc:
await vc.disconnect(force=True)
await asyncio.sleep(1.0)
vc = await connect_voice_with_retry(channel)
await announce_listening_roster(message.channel, channel)
except Exception as e:
logger.exception("Voice connect retries exhausted (guild %s): %s", getattr(message.guild, "id", "?"), e)
await message.channel.send("I couldn't join the voice channel (error 4006). Try again in a few seconds.")
return None
else:
logger.info("Connecting to voice channel: %s (guild %s)", channel, getattr(message.guild, "id", "?"))
try:
vc = await connect_voice_with_retry(channel)
await announce_listening_roster(message.channel, channel)
except Exception as e:
logger.exception("Voice connect retries exhausted (guild %s): %s", getattr(message.guild, "id", "?"), e)
await message.channel.send("I couldn't join the voice channel (error 4006). Try again in a few seconds.")
return None
if vc and vc.is_connected():
logger.info("Connected to voice: %s (guild %s)", vc.channel, getattr(message.guild, "id", "?"))
else:
logger.error("Voice connect returned but not connected (guild %s)", getattr(message.guild, "id", "?"))
return vc
await announce_listening_roster(message.channel, channel)
return vc
except Exception as e:
logger.exception("Failed to connect to voice: %s", e)
await message.channel.send("Couldn't join voice channel. Try again in a moment.")
return None
def make_ffmpeg_source(url: str) -> discord.AudioSource:
@@ -967,10 +1019,10 @@ async def on_ready():
ensure_ffmpeg_available()
ensure_opus_loaded()
logger.info("Startup checks OK")
if HOTWORD_ENABLED and HAS_SINKS:
logger.info("Hotword listening: ENABLED (sinks available and HOTWORD_ENABLED=True)")
elif HOTWORD_ENABLED and not HAS_SINKS:
logger.info("Hotword listening: DISABLED (HOTWORD_ENABLED=True but sinks unavailable)")
if HOTWORD_ENABLED and HAS_VOICE_RECV:
logger.info("Hotword listening: ENABLED (voice_recv available and HOTWORD_ENABLED=True)")
elif HOTWORD_ENABLED and not HAS_VOICE_RECV:
logger.info("Hotword listening: DISABLED (HOTWORD_ENABLED=True but voice_recv unavailable)")
else:
logger.info("Hotword listening: DISABLED (HOTWORD_ENABLED unset/false)")
@@ -1062,6 +1114,7 @@ async def on_message(message: discord.Message):
if vc:
state = get_state_for_guild(message.guild.id)
state.voice_client = vc
await speak_response(vc, "join")
await message.channel.send("Joined your voice channel. Say 'hey bashar play <song>' here.")
logger.info("Joined voice channel for guild %s", message.guild.id)
# Auto-start hotword listener
@@ -1073,6 +1126,8 @@ async def on_message(message: discord.Message):
state = get_state_for_guild(message.guild.id)
await state.stop_listening()
if state.voice_client and state.voice_client.is_connected():
await speak_response(state.voice_client, "leave")
await asyncio.sleep(1.0) # Give time for goodbye to play
await message.channel.send("Leaving voice channel.")
logger.info("Disconnecting from voice (guild %s)", message.guild.id)
await state.voice_client.disconnect(force=True)
@@ -1089,16 +1144,20 @@ async def on_message(message: discord.Message):
if action == "skip":
state = get_state_for_guild(message.guild.id)
state.skip_current()
await speak_response(state.voice_client, "skip")
await message.channel.send("Skipped the current track.")
return
if action == "stop":
state = get_state_for_guild(message.guild.id)
state.stop_all()
await speak_response(state.voice_client, "stop")
await message.channel.send("Stopped playback and cleared the queue.")
return
# Unknown
state = get_state_for_guild(message.guild.id)
await speak_response(state.voice_client, "unknown")
await message.channel.send("Commands: 'hey bashar join', 'hey bashar play <song>', 'hey bashar skip', 'hey bashar stop', 'hey bashar leave'.")
logger.debug("Sent help for unknown command")
@@ -1176,6 +1235,7 @@ async def route_transcribed_command_from_member(guild: discord.Guild, member: di
await text_channel.send("I couldn't join the voice channel (error 4006). Try again in a few seconds.")
return
state.voice_client = vc
await speak_response(vc, "join")
await text_channel.send("Joined your voice channel. Say 'hey bashar play <song>' here.")
# Start listening if not already
await state.start_listening(text_channel)
@@ -1184,6 +1244,8 @@ async def route_transcribed_command_from_member(guild: discord.Guild, member: di
state = get_state_for_guild(guild.id)
await state.stop_listening()
if state.voice_client and state.voice_client.is_connected():
await speak_response(state.voice_client, "leave")
await asyncio.sleep(1.0)
await text_channel.send("Leaving voice channel.")
await state.voice_client.disconnect(force=True)
return
@@ -1191,18 +1253,24 @@ async def route_transcribed_command_from_member(guild: discord.Guild, member: di
if not args:
await text_channel.send("Say 'hey bashar play <search terms>'.")
return
state = get_state_for_guild(guild.id)
await speak_response(state.voice_client, "play")
await handle_play_for_member(guild, member, text_channel, args)
return
if action == "skip":
state = get_state_for_guild(guild.id)
state.skip_current()
await speak_response(state.voice_client, "skip")
await text_channel.send("Skipped the current track.")
return
if action == "stop":
state = get_state_for_guild(guild.id)
state.stop_all()
await speak_response(state.voice_client, "stop")
await text_channel.send("Stopped playback and cleared the queue.")
return
state = get_state_for_guild(guild.id)
await speak_response(state.voice_client, "unknown")
await text_channel.send("Commands: 'hey bashar join', 'hey bashar play <song>', 'hey bashar skip', 'hey bashar stop', 'hey bashar leave'.")
@client.event