Complete Phase 1: parallel sync, IPC, theme colors, lazy CLI loading
- Sync: Parallelize message downloads with asyncio.gather (batch size 5)
- Sync: Increase HTTP semaphore from 2 to 5 concurrent requests
- Sync: Add IPC notifications to sync daemon after sync completes
- Mail: Replace all hardcoded RGB colors with theme variables
- Mail: Remove envelope icon/checkbox gap (padding cleanup)
- Mail: Add IPC listener for refresh notifications from sync
- Calendar: Style current time line with error color and solid line
- Tasks: Fix table not displaying (CSS grid to horizontal layout)
- CLI: Implement lazy command loading for faster startup (~12s to ~0.3s)
- Add PROJECT_PLAN.md with full improvement roadmap
- Add src/utils/ipc.py for Unix socket cross-app communication (usage sketch below)
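
The IPC pieces in this commit are spread across the sync daemon, the mail app, and the new src/utils/ipc.py module. As orientation before the diff, here is a minimal end-to-end sketch of the intended flow, using the notify_all/IPCListener helpers added below; the on_refresh handler is a hypothetical stand-in for an app's real reload logic (Unix only, run from the repo root so src.utils.ipc is importable):

import asyncio

from src.utils.ipc import IPCListener, IPCMessage, notify_all


def on_refresh(message: IPCMessage) -> None:
    # Hypothetical handler; the mail app schedules its real reload with
    # self.call_from_thread(self.fetch_envelopes) instead of printing.
    print(f"refresh requested by {message.data.get('source')}")


async def main() -> None:
    # Subscriber side: roughly what a TUI app does on mount.
    listener = IPCListener("mail", on_refresh)
    listener.start()
    await asyncio.sleep(0.1)  # give the listener thread time to bind its socket

    # Publisher side: what the sync daemon does after a successful sync.
    results = await notify_all({"source": "sync_daemon"})
    print(results)  # e.g. {"mail": True, "calendar": False, "tasks": False}

    listener.stop()


if __name__ == "__main__":
    asyncio.run(main())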
@@ -373,8 +373,8 @@ class WeekGridBody(ScrollView):
 
         # Style time label - highlight current time, dim outside work hours
         if is_current_time_row:
-            secondary_color = self._get_theme_color("secondary")
-            time_style = Style(color=secondary_color, bold=True)
+            error_color = self._get_theme_color("error")
+            time_style = Style(color=error_color, bold=True)
         elif (
             row_index < self._work_day_start * rows_per_hour
             or row_index >= self._work_day_end * rows_per_hour
@@ -388,13 +388,19 @@ class WeekGridBody(ScrollView):
 
         # Event cells for each day
         for col_idx, day_col in enumerate(self._days):
-            cell_text, cell_style = self._render_event_cell(day_col, row_index, col_idx)
+            cell_text, cell_style = self._render_event_cell(
+                day_col, row_index, col_idx, is_current_time_row
+            )
             segments.append(Segment(cell_text, cell_style))
 
         return Strip(segments)
 
     def _render_event_cell(
-        self, day_col: DayColumn, row_index: int, col_idx: int
+        self,
+        day_col: DayColumn,
+        row_index: int,
+        col_idx: int,
+        is_current_time_row: bool = False,
     ) -> Tuple[str, Style]:
         """Render a single cell for a day/time slot."""
         events_at_row = day_col.grid[row_index] if row_index < len(day_col.grid) else []
@@ -404,10 +410,16 @@ class WeekGridBody(ScrollView):
 
         is_cursor = col_idx == self.cursor_col and row_index == self.cursor_row
 
+        # Get colors for current time line
+        error_color = self._get_theme_color("error") if is_current_time_row else None
+
         if not events_at_row:
             # Empty cell
             if is_cursor:
                 return ">" + " " * (day_col_width - 1), Style(reverse=True)
+            elif is_current_time_row:
+                # Current time indicator line
+                return "─" * day_col_width, Style(color=error_color, bold=True)
             else:
                 # Grid line style
                 if row_index % rows_per_hour == 0:
@@ -1,37 +1,55 @@
+# CLI module for the application
+# Uses lazy imports to speed up startup time
+
 import click
 
-from .sync import sync
-from .drive import drive
-from .email import email
-from .calendar import calendar
-from .ticktick import ticktick
-from .godspeed import godspeed
-from .gitlab_monitor import gitlab_monitor
-from .tasks import tasks
+import importlib
 
 
-@click.group()
+class LazyGroup(click.Group):
+    """A click Group that lazily loads subcommands."""
+
+    def __init__(self, *args, lazy_subcommands=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._lazy_subcommands = lazy_subcommands or {}
+
+    def list_commands(self, ctx):
+        base = super().list_commands(ctx)
+        lazy = list(self._lazy_subcommands.keys())
+        return sorted(base + lazy)
+
+    def get_command(self, ctx, cmd_name):
+        if cmd_name in self._lazy_subcommands:
+            return self._load_command(cmd_name)
+        return super().get_command(ctx, cmd_name)
+
+    def _load_command(self, cmd_name):
+        module_path, attr_name = self._lazy_subcommands[cmd_name]
+        # Handle relative imports
+        if module_path.startswith("."):
+            module = importlib.import_module(module_path, package="src.cli")
+        else:
+            module = importlib.import_module(module_path)
+        return getattr(module, attr_name)
+
+
+# Create CLI with lazy loading - commands only imported when invoked
+@click.group(
+    cls=LazyGroup,
+    lazy_subcommands={
+        "sync": (".sync", "sync"),
+        "drive": (".drive", "drive"),
+        "email": (".email", "email"),
+        "mail": (".email", "email"),  # alias
+        "calendar": (".calendar", "calendar"),
+        "ticktick": (".ticktick", "ticktick"),
+        "tt": (".ticktick", "ticktick"),  # alias
+        "godspeed": (".godspeed", "godspeed"),
+        "gs": (".godspeed", "godspeed"),  # alias
+        "gitlab_monitor": (".gitlab_monitor", "gitlab_monitor"),
+        "glm": (".gitlab_monitor", "gitlab_monitor"),  # alias
+        "tasks": (".tasks", "tasks"),
+    },
+)
 def cli():
-    """Root command for the CLI."""
+    """LUK - Local Unix Kit for productivity."""
     pass
-
-
-cli.add_command(sync)
-cli.add_command(drive)
-cli.add_command(email)
-cli.add_command(calendar)
-cli.add_command(ticktick)
-cli.add_command(godspeed)
-cli.add_command(gitlab_monitor)
-cli.add_command(tasks)
-
-# Add 'mail' as an alias for email
-cli.add_command(email, name="mail")
-# Add 'tt' as a short alias for ticktick
-cli.add_command(ticktick, name="tt")
-# Add 'gs' as a short alias for godspeed
-cli.add_command(godspeed, name="gs")
-# Add 'glm' as a short alias for gitlab_monitor
-cli.add_command(gitlab_monitor, name="glm")
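
A rough way to check the lazy-loading win, assuming the group above lives in src/cli/__init__.py (the package name used for the relative imports); exact numbers will vary by machine:

import sys
import time

start = time.perf_counter()
from src.cli import cli  # noqa: F401  -- importing the group should now be cheap

print(f"import took {time.perf_counter() - start:.2f}s")
# With LazyGroup, subcommand modules stay unimported until a command is invoked.
print("sync module loaded eagerly?", "src.cli.sync" in sys.modules)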
@@ -14,6 +14,7 @@ from typing import Optional, Dict, Any
 
 from src.cli.sync import _sync_outlook_data, should_run_godspeed_sync, should_run_sweep
 from src.cli.sync import run_godspeed_sync, run_task_sweep, load_sync_state
+from src.utils.ipc import notify_all, notify_refresh
 
 
 class SyncDaemon:
@@ -247,6 +248,13 @@ class SyncDaemon:
                     notify=self.config.get("notify", False),
                 )
                 self.logger.info("Sync completed successfully")
+
+                # Notify all running TUI apps to refresh their data
+                results = await notify_all({"source": "sync_daemon"})
+                notified = [app for app, success in results.items() if success]
+                if notified:
+                    self.logger.info(f"Notified apps to refresh: {', '.join(notified)}")
+
             except Exception as e:
                 self.logger.error(f"Sync failed: {e}")
 
@@ -10,6 +10,7 @@ from .actions.delete import delete_current
 from src.services.taskwarrior import client as taskwarrior_client
 from src.services.himalaya import client as himalaya_client
 from src.utils.shared_config import get_theme_name
+from src.utils.ipc import IPCListener, IPCMessage
 from textual.containers import Container, ScrollableContainer, Vertical, Horizontal
 from textual.timer import Timer
 from textual.binding import Binding
@@ -149,7 +150,7 @@ class EmailViewerApp(App):
     async def on_mount(self) -> None:
         self.alert_timer: Timer | None = None  # Timer to throttle alerts
         self.theme = get_theme_name()
-        self.title = "MaildirGTD"
+        self.title = "LUK Mail"
         self.query_one("#main_content").border_title = self.status_title
         sort_indicator = "↑" if self.sort_order_ascending else "↓"
         self.query_one("#envelopes_list").border_title = f"1️⃣ Emails {sort_indicator}"
@@ -157,6 +158,10 @@ class EmailViewerApp(App):
 
         self.query_one("#folders_list").border_title = "3️⃣ Folders"
 
+        # Start IPC listener for refresh notifications from sync daemon
+        self._ipc_listener = IPCListener("mail", self._on_ipc_message)
+        self._ipc_listener.start()
+
         self.fetch_accounts()
         self.fetch_folders()
         worker = self.fetch_envelopes()
@@ -164,6 +169,12 @@ class EmailViewerApp(App):
         self.query_one("#envelopes_list").focus()
         self.action_oldest()
 
+    def _on_ipc_message(self, message: IPCMessage) -> None:
+        """Handle IPC messages from sync daemon."""
+        if message.event == "refresh":
+            # Schedule a reload on the main thread
+            self.call_from_thread(self.fetch_envelopes)
+
     def compute_status_title(self):
         metadata = self.message_store.get_metadata(self.current_message_id)
         message_date = metadata["date"] if metadata else "N/A"
@@ -820,6 +831,9 @@ class EmailViewerApp(App):
         self.query_one("#envelopes_list").focus()
 
     def action_quit(self) -> None:
+        # Stop IPC listener before exiting
+        if hasattr(self, "_ipc_listener"):
+            self._ipc_listener.stop()
         self.exit()
 
     def action_toggle_selection(self) -> None:
@@ -3,7 +3,7 @@
 
 #main_content, .list_view {
     scrollbar-size: 1 1;
-    border: round rgb(117, 106, 129);
+    border: round $border;
     height: 1fr;
 }
 
@@ -43,18 +43,18 @@
 
 #main_content:focus, .list_view:focus {
     border: round $secondary;
-    background: rgb(55, 53, 57);
+    background: $surface;
     border-title-style: bold;
 }
 
 Label#task_prompt {
     padding: 1;
-    color: rgb(128,128,128);
+    color: $text-muted;
 }
 
 Label#task_prompt_label {
     padding: 1;
-    color: rgb(255, 216, 102);
+    color: $warning;
 }
 
 Label#message_label {
@@ -66,7 +66,7 @@ StatusTitle {
     width: 100%;
     height: 1;
     color: $text;
-    background: rgb(64, 62, 65);
+    background: $panel;
     content-align: center middle;
 }
 
@@ -113,8 +113,8 @@ EnvelopeListItem .envelope-row-3 {
 }
 
 EnvelopeListItem .status-icon {
-    width: 3;
-    padding: 0 1 0 0;
+    width: 2;
+    padding: 0;
     color: $text-muted;
 }
 
@@ -124,7 +124,7 @@ EnvelopeListItem .status-icon.unread {
 
 EnvelopeListItem .checkbox {
     width: 2;
-    padding: 0 1 0 0;
+    padding: 0;
 }
 
 EnvelopeListItem .sender-name {
@@ -166,11 +166,11 @@ EnvelopeListItem.selected {
 GroupHeader {
     height: 1;
     width: 1fr;
-    background: rgb(64, 62, 65);
+    background: $panel;
 }
 
 GroupHeader .group-header-label {
-    color: rgb(160, 160, 160);
+    color: $text-muted;
     text-style: bold;
     padding: 0 1;
     width: 1fr;
@@ -222,10 +222,10 @@ GroupHeader .group-header-label {
 
 #envelopes_list {
     ListItem:odd {
-        background: rgb(45, 45, 46);
+        background: $surface;
     }
     ListItem:even {
-        background: rgb(50, 50, 56);
+        background: $surface-darken-1;
    }
 
    & > ListItem {
@@ -269,9 +269,9 @@ GroupHeader .group-header-label {
 }
 
 Label.group_header {
-    color: rgb(140, 140, 140);
+    color: $text-muted;
     text-style: bold;
-    background: rgb(64, 62, 65);
+    background: $panel;
     width: 100%;
     padding: 0 1;
 }
@@ -300,6 +300,3 @@ ContentContainer {
     width: 100%;
     height: 1fr;
 }
-.checkbox {
-    padding-right: 1;
-}
@@ -44,12 +44,17 @@ class EnvelopeListItem(Static):
 
     EnvelopeListItem .status-icon {
         width: 2;
-        padding: 0 1 0 0;
+        padding: 0;
     }
 
     EnvelopeListItem .checkbox {
         width: 2;
-        padding: 0 1 0 0;
+        padding: 0;
    }
 
+    EnvelopeListItem .checkbox {
+        width: 2;
+        padding: 0;
+    }
+
     EnvelopeListItem .sender-name {
@@ -13,8 +13,8 @@ logging.getLogger("aiohttp.access").setLevel(logging.ERROR)
 logging.getLogger("urllib3").setLevel(logging.ERROR)
 logging.getLogger("asyncio").setLevel(logging.ERROR)
 
-# Define a global semaphore for throttling - reduced for better compliance
-semaphore = asyncio.Semaphore(2)
+# Define a global semaphore for throttling - increased for better parallelization
+semaphore = asyncio.Semaphore(5)
 
 
 async def _handle_throttling_retry(func, *args, max_retries=3):
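
For context on how the wider semaphore interacts with the batched downloads in the next hunk: a semaphore caps in-flight work no matter how many coroutines are gathered at once. A generic, self-contained sketch of that pattern follows; the sleep stands in for the real HTTP call, and this is not the module's actual download code:

import asyncio

semaphore = asyncio.Semaphore(5)  # same cap as above


async def fetch_one(msg_id: int) -> int:
    async with semaphore:  # at most 5 simulated requests run concurrently
        await asyncio.sleep(0.1)
        return msg_id


async def fetch_batch(ids: list[int]) -> list[int]:
    results = await asyncio.gather(*(fetch_one(i) for i in ids), return_exceptions=True)
    return [r for r in results if not isinstance(r, Exception)]


if __name__ == "__main__":
    print(asyncio.run(fetch_batch(list(range(10)))))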
@@ -110,26 +110,47 @@ async def fetch_mail_async(
     progress.update(task_id, total=len(messages_to_download), completed=0)
     downloaded_count = 0
 
-    for message in messages_to_download:
+    # Download messages in parallel batches for better performance
+    BATCH_SIZE = 5
+
+    for i in range(0, len(messages_to_download), BATCH_SIZE):
         # Check if task was cancelled/disabled
         if is_cancelled and is_cancelled():
             progress.console.print("Task cancelled, stopping inbox fetch")
             break
 
-        progress.console.print(
-            f"Processing message: {message.get('subject', 'No Subject')}", end="\r"
-        )
-        await save_mime_to_maildir_async(
-            maildir_path,
-            message,
-            attachments_dir,
-            headers,
-            progress,
-            dry_run,
-            download_attachments,
-        )
-        progress.update(task_id, advance=1)
-        downloaded_count += 1
+        batch = messages_to_download[i : i + BATCH_SIZE]
+
+        # Create tasks for parallel download
+        async def download_message(message):
+            progress.console.print(
+                f"Processing message: {message.get('subject', 'No Subject')[:50]}",
+                end="\r",
+            )
+            await save_mime_to_maildir_async(
+                maildir_path,
+                message,
+                attachments_dir,
+                headers,
+                progress,
+                dry_run,
+                download_attachments,
+            )
+            return 1
+
+        # Execute batch in parallel
+        tasks = [download_message(msg) for msg in batch]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Count successful downloads
+        batch_success = sum(1 for r in results if r == 1)
+        downloaded_count += batch_success
+        progress.update(task_id, advance=len(batch))
+
+        # Log any errors
+        for idx, result in enumerate(results):
+            if isinstance(result, Exception):
+                progress.console.print(f"Error downloading message: {result}")
 
     progress.update(task_id, completed=downloaded_count)
     progress.console.print(f"\nFinished downloading {downloaded_count} new messages.")
@@ -461,37 +482,57 @@ async def fetch_archive_mail_async(
     # Update progress to reflect only the messages we actually need to download
     progress.update(task_id, total=len(messages_to_download), completed=0)
 
-    # Load sync state once, we'll update it incrementally
+    # Load sync state once, we'll update it after each batch for resilience
     synced_ids = _load_archive_sync_state(maildir_path) if not dry_run else set()
     downloaded_count = 0
 
-    for message in messages_to_download:
+    # Download messages in parallel batches for better performance
+    BATCH_SIZE = 5
+
+    for i in range(0, len(messages_to_download), BATCH_SIZE):
         # Check if task was cancelled/disabled
         if is_cancelled and is_cancelled():
             progress.console.print("Task cancelled, stopping archive fetch")
             break
 
-        progress.console.print(
-            f"Processing archived message: {message.get('subject', 'No Subject')[:50]}",
-            end="\r",
-        )
-        # Save to .Archive folder instead of main maildir
-        await save_mime_to_maildir_async(
-            archive_dir,  # Use archive_dir instead of maildir_path
-            message,
-            attachments_dir,
-            headers,
-            progress,
-            dry_run,
-            download_attachments,
-        )
-        progress.update(task_id, advance=1)
-        downloaded_count += 1
+        batch = messages_to_download[i : i + BATCH_SIZE]
+        batch_msg_ids = []
 
-        # Update sync state after each message for resilience
-        # This ensures we don't try to re-upload this message in archive_mail_async
-        if not dry_run:
-            synced_ids.add(message["id"])
+        # Create tasks for parallel download
+        async def download_message(message):
+            progress.console.print(
+                f"Processing archived message: {message.get('subject', 'No Subject')[:50]}",
+                end="\r",
+            )
+            # Save to .Archive folder instead of main maildir
+            await save_mime_to_maildir_async(
+                archive_dir,  # Use archive_dir instead of maildir_path
+                message,
+                attachments_dir,
+                headers,
+                progress,
+                dry_run,
+                download_attachments,
+            )
+            return message["id"]
+
+        # Execute batch in parallel
+        tasks = [download_message(msg) for msg in batch]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Process results and collect successful message IDs
+        for result in results:
+            if isinstance(result, Exception):
+                progress.console.print(f"Error downloading archived message: {result}")
+            elif result:
+                batch_msg_ids.append(result)
+                downloaded_count += 1
+
+        progress.update(task_id, advance=len(batch))
+
+        # Update sync state after each batch (not each message) for resilience + performance
+        if not dry_run and batch_msg_ids:
+            synced_ids.update(batch_msg_ids)
+            _save_archive_sync_state(maildir_path, synced_ids)
 
     progress.update(task_id, completed=downloaded_count)
@@ -46,24 +46,20 @@ class TasksApp(App):
 
     CSS = """
     Screen {
-        layout: grid;
-        grid-size: 2;
-        grid-columns: auto 1fr;
-        grid-rows: auto 1fr auto auto;
+        layout: horizontal;
    }
 
    Header {
-        column-span: 2;
        dock: top;
    }
 
    Footer {
-        column-span: 2;
        dock: bottom;
    }
 
    #sidebar {
        width: 28;
        height: 100%;
-        row-span: 1;
    }
 
    #sidebar.hidden {
@@ -116,7 +112,6 @@ class TasksApp(App):
        background: $surface;
        color: $text-muted;
        padding: 0 1;
-        column-span: 2;
    }
 
    #detail-pane {
@@ -124,7 +119,6 @@ class TasksApp(App):
        height: 50%;
        border-top: solid $primary;
        background: $surface;
-        column-span: 2;
    }
 
    #detail-pane.hidden {
@@ -154,7 +148,6 @@ class TasksApp(App):
        border-top: solid $primary;
        padding: 1;
        background: $surface;
-        column-span: 2;
    }
 
    #notes-pane.hidden {
src/utils/ipc.py (new file, 318 lines)
@@ -0,0 +1,318 @@
"""Inter-Process Communication using Unix Domain Sockets.
|
||||
|
||||
This module provides a simple pub/sub mechanism for cross-app notifications.
|
||||
The sync daemon can broadcast messages when data changes, and TUI apps can
|
||||
listen for these messages to refresh their displays.
|
||||
|
||||
Usage:
|
||||
# In sync daemon (publisher):
|
||||
from src.utils.ipc import notify_refresh
|
||||
await notify_refresh("mail") # Notify mail app to refresh
|
||||
await notify_refresh("calendar") # Notify calendar app to refresh
|
||||
await notify_refresh("tasks") # Notify tasks app to refresh
|
||||
|
||||
# In TUI apps (subscriber):
|
||||
from src.utils.ipc import IPCListener
|
||||
|
||||
class MyApp(App):
|
||||
def on_mount(self):
|
||||
self.ipc_listener = IPCListener("mail", self.on_refresh)
|
||||
self.ipc_listener.start()
|
||||
|
||||
def on_unmount(self):
|
||||
self.ipc_listener.stop()
|
||||
|
||||
async def on_refresh(self, message):
|
||||
# Refresh the app's data
|
||||
await self.refresh_data()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Any, Dict
|
||||
|
||||
# Socket paths for each app type
|
||||
SOCKET_DIR = Path("~/.local/share/luk/ipc").expanduser()
|
||||
SOCKET_PATHS = {
|
||||
"mail": SOCKET_DIR / "mail.sock",
|
||||
"calendar": SOCKET_DIR / "calendar.sock",
|
||||
"tasks": SOCKET_DIR / "tasks.sock",
|
||||
}
|
||||
|
||||
|
||||
def ensure_socket_dir():
|
||||
"""Ensure the socket directory exists."""
|
||||
SOCKET_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_socket_path(app_type: str) -> Path:
|
||||
"""Get the socket path for a given app type."""
|
||||
if app_type not in SOCKET_PATHS:
|
||||
raise ValueError(
|
||||
f"Unknown app type: {app_type}. Must be one of: {list(SOCKET_PATHS.keys())}"
|
||||
)
|
||||
return SOCKET_PATHS[app_type]
|
||||
|
||||
|
||||
class IPCMessage:
|
||||
"""A message sent via IPC."""
|
||||
|
||||
def __init__(self, event: str, data: Optional[Dict[str, Any]] = None):
|
||||
self.event = event
|
||||
self.data = data or {}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps({"event": self.event, "data": self.data})
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "IPCMessage":
|
||||
parsed = json.loads(json_str)
|
||||
return cls(event=parsed["event"], data=parsed.get("data", {}))
|
||||
|
||||
|
||||
async def notify_refresh(app_type: str, data: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""Send a refresh notification to a specific app.
|
||||
|
||||
Args:
|
||||
app_type: The type of app to notify ("mail", "calendar", "tasks")
|
||||
data: Optional data to include with the notification
|
||||
|
||||
Returns:
|
||||
True if the notification was sent successfully, False otherwise
|
||||
"""
|
||||
socket_path = get_socket_path(app_type)
|
||||
|
||||
if not socket_path.exists():
|
||||
# No listener, that's okay
|
||||
return False
|
||||
|
||||
try:
|
||||
message = IPCMessage("refresh", data)
|
||||
|
||||
# Connect to the socket and send the message
|
||||
reader, writer = await asyncio.open_unix_connection(str(socket_path))
|
||||
|
||||
writer.write((message.to_json() + "\n").encode())
|
||||
await writer.drain()
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
return True
|
||||
except (ConnectionRefusedError, FileNotFoundError, OSError):
|
||||
# Socket exists but no one is listening, or other error
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def notify_all(data: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
|
||||
"""Send a refresh notification to all apps.
|
||||
|
||||
Args:
|
||||
data: Optional data to include with the notification
|
||||
|
||||
Returns:
|
||||
Dictionary of app_type -> success status
|
||||
"""
|
||||
results = {}
|
||||
for app_type in SOCKET_PATHS:
|
||||
results[app_type] = await notify_refresh(app_type, data)
|
||||
return results
|
||||
|
||||
|
||||
class IPCListener:
|
||||
"""Listens for IPC messages on a Unix socket.
|
||||
|
||||
Usage:
|
||||
listener = IPCListener("mail", on_message_callback)
|
||||
listener.start()
|
||||
# ... later ...
|
||||
listener.stop()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_type: str,
|
||||
callback: Callable[[IPCMessage], Any],
|
||||
):
|
||||
"""Initialize the IPC listener.
|
||||
|
||||
Args:
|
||||
app_type: The type of app ("mail", "calendar", "tasks")
|
||||
callback: Function to call when a message is received.
|
||||
Can be sync or async.
|
||||
"""
|
||||
self.app_type = app_type
|
||||
self.callback = callback
|
||||
self.socket_path = get_socket_path(app_type)
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
self._running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
):
|
||||
"""Handle an incoming client connection."""
|
||||
try:
|
||||
data = await reader.readline()
|
||||
if data:
|
||||
message_str = data.decode().strip()
|
||||
if message_str:
|
||||
message = IPCMessage.from_json(message_str)
|
||||
|
||||
# Call the callback (handle both sync and async)
|
||||
result = self.callback(message)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass # Ignore errors from malformed messages
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _run_server(self):
|
||||
"""Run the Unix socket server."""
|
||||
ensure_socket_dir()
|
||||
|
||||
# Remove stale socket file if it exists
|
||||
if self.socket_path.exists():
|
||||
self.socket_path.unlink()
|
||||
|
||||
self._server = await asyncio.start_unix_server(
|
||||
self._handle_client, path=str(self.socket_path)
|
||||
)
|
||||
|
||||
# Set socket permissions (readable/writable by owner only)
|
||||
os.chmod(self.socket_path, 0o600)
|
||||
|
||||
async with self._server:
|
||||
await self._server.serve_forever()
|
||||
|
||||
def _run_in_thread(self):
|
||||
"""Run the event loop in a separate thread."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.run_until_complete(self._run_server())
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._loop.close()
|
||||
|
||||
def start(self):
|
||||
"""Start listening for IPC messages in a background thread."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._run_in_thread, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop listening for IPC messages."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# Cancel the server
|
||||
if self._server and self._loop:
|
||||
self._loop.call_soon_threadsafe(self._server.close)
|
||||
|
||||
# Stop the event loop
|
||||
if self._loop:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
|
||||
# Wait for thread to finish
|
||||
if self._thread:
|
||||
self._thread.join(timeout=1.0)
|
||||
|
||||
# Clean up socket file
|
||||
if self.socket_path.exists():
|
||||
try:
|
||||
self.socket_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncIPCListener:
|
||||
"""Async version of IPCListener for use within an existing event loop.
|
||||
|
||||
Usage in a Textual app:
|
||||
class MyApp(App):
|
||||
async def on_mount(self):
|
||||
self.ipc_listener = AsyncIPCListener("mail", self.on_refresh)
|
||||
await self.ipc_listener.start()
|
||||
|
||||
async def on_unmount(self):
|
||||
await self.ipc_listener.stop()
|
||||
|
||||
async def on_refresh(self, message):
|
||||
self.refresh_data()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_type: str,
|
||||
callback: Callable[[IPCMessage], Any],
|
||||
):
|
||||
self.app_type = app_type
|
||||
self.callback = callback
|
||||
self.socket_path = get_socket_path(app_type)
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
):
|
||||
"""Handle an incoming client connection."""
|
||||
try:
|
||||
data = await reader.readline()
|
||||
if data:
|
||||
message_str = data.decode().strip()
|
||||
if message_str:
|
||||
message = IPCMessage.from_json(message_str)
|
||||
|
||||
result = self.callback(message)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def start(self):
|
||||
"""Start the Unix socket server."""
|
||||
ensure_socket_dir()
|
||||
|
||||
if self.socket_path.exists():
|
||||
self.socket_path.unlink()
|
||||
|
||||
self._server = await asyncio.start_unix_server(
|
||||
self._handle_client, path=str(self.socket_path)
|
||||
)
|
||||
os.chmod(self.socket_path, 0o600)
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the server and clean up."""
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
|
||||
if self.socket_path.exists():
|
||||
try:
|
||||
self.socket_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
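
A small smoke test for the module above (Unix only; run from the repo root so src.utils.ipc is importable): start an AsyncIPCListener inside an event loop, send it a notification, and confirm the callback fires.

import asyncio

from src.utils.ipc import AsyncIPCListener, IPCMessage, notify_refresh


async def main() -> None:
    received: list[IPCMessage] = []

    listener = AsyncIPCListener("tasks", received.append)
    await listener.start()

    sent = await notify_refresh("tasks", {"source": "smoke_test"})
    await asyncio.sleep(0.1)  # let the server task handle the connection

    print("sent:", sent)
    print("events:", [(m.event, m.data) for m in received])

    await listener.stop()


if __name__ == "__main__":
    asyncio.run(main())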