diff --git a/src/bouncer/network.py b/src/bouncer/network.py index f5d18b1..acbffcd 100644 --- a/src/bouncer/network.py +++ b/src/bouncer/network.py @@ -156,6 +156,7 @@ class Network: self._reconnect_attempt: int = 0 self._running: bool = False self._read_task: asyncio.Task[None] | None = None + self._reconnect_task: asyncio.Task[None] | None = None self._probation_task: asyncio.Task[None] | None = None # Transient nick used during registration/probation self._connect_nick: str = "" @@ -187,10 +188,9 @@ class Network: async def stop(self) -> None: """Disconnect and stop reconnection.""" self._running = False - if self._read_task and not self._read_task.done(): - self._read_task.cancel() - if self._probation_task and not self._probation_task.done(): - self._probation_task.cancel() + for task in (self._read_task, self._reconnect_task, self._probation_task): + if task and not task.done(): + task.cancel() await self._disconnect() async def send(self, msg: IRCMessage) -> None: @@ -242,7 +242,7 @@ class Network: log.exception("[%s] connection failed", self.cfg.name) self.state = State.DISCONNECTED if self._running: - await self._schedule_reconnect() + self._schedule_reconnect() async def _disconnect(self) -> None: """Close the connection.""" @@ -259,7 +259,11 @@ class Network: self._reader = None self._writer = None - async def _schedule_reconnect(self) -> None: + def _schedule_reconnect(self) -> None: + """Schedule a reconnect after exponential backoff.""" + self._reconnect_task = asyncio.create_task(self._reconnect_wait()) + + async def _reconnect_wait(self) -> None: """Wait with exponential backoff, then reconnect.""" delay = BACKOFF_STEPS[min(self._reconnect_attempt, len(BACKOFF_STEPS) - 1)] self._reconnect_attempt += 1 @@ -267,7 +271,10 @@ class Network: "[%s] reconnecting in %ds (attempt %d)", self.cfg.name, delay, self._reconnect_attempt, ) - await asyncio.sleep(delay) + try: + await asyncio.sleep(delay) + except asyncio.CancelledError: + return if self._running: await self._connect() @@ -300,7 +307,7 @@ class Network: finally: await self._disconnect() if self._running: - await self._schedule_reconnect() + self._schedule_reconnect() async def _enter_probation(self) -> None: """Start probation period after registration. Survive = ready."""