Full test infrastructure, code audit fixes, and real E2E integration tests
Test infrastructure: - Django standalone test runner (pytest-django, test settings, EmailUser model) - React unit tests via Vitest with jsdom, jest compat layer, path aliases - Playwright E2E tests using generated hooks in a real Chromium browser - Docker Compose test backend (Django + Redis) for integration testing - Desktop integration test app (PyWebView + Django + uvicorn) - Makefile with test/test-django/test-react/test-integration targets Library bugs found and fixed: - hasJWT truthiness: undefined !== null was true, skipping session init - process.env crash: CSR client referenced process.env in non-Node browsers - baseUrl not forwarded: DjareaProvider didn't pass baseUrl to CSR client - Relative URL handling: new URL() failed with relative base paths - call() race condition: HTTP requests fired before CSRF cookie was set - Session init await: added sessionRef promise so call() waits for session - path_prefix on schema export: both export commands failed with URL reverse - NullBooleanField removed: referenced field doesn't exist in Django 5.0+ - lru_cache on JWT settings: get_settings() now cached as intended - Channel message routing: broadcasts now include channel name and params - httpFunctionCall: fixed URL and request body format Generator fixes: - Removed 1,100 lines of REST/OpenAPI client generation (not part of Djarea) - Generator now works for djarea-only projects without django-ninja REST APIs - Generated DjangoContext now includes ChannelProvider when channels exist - Fixed env var passthrough for schema export commands - Deduplicated fetch logic into single runDjangoCommand helper Test quality: - Fixed 33 tautological Django tests with real assertions - Found hidden bug: benchmark functions were never registered - Found hidden bug: unicode lookalike test used plain ASCII - Deleted worthless React unit tests (duplicates, shape checks, Zod-tests-Zod) - Replaced jsdom integration tests with Playwright browser tests Example apps: - example/: Integration test backend with 33 server functions, 5 forms, 4 channels covering auth variations, contexts, class-based ServerFunction, error codes, DjareaFormMixin, formsets, and JWT - desktop/: PyWebView desktop app with file system access, SQLite CRUD, system introspection, and 39 real HTTP integration tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
32
.gitea/workflows/publish-django.yaml
Normal file
32
.gitea/workflows/publish-django.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
name: Publish Django package to PyPI
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'django/v*'
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: django
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install build tools
|
||||
run: pip install build twine
|
||||
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
|
||||
- name: Publish to Gitea PyPI
|
||||
env:
|
||||
TWINE_REPOSITORY_URL: ${{ gitea.server_url }}/api/packages/${{ gitea.repository_owner }}/pypi
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PUBLISH_TOKEN }}
|
||||
run: twine upload dist/*
|
||||
36
.gitea/workflows/publish-react.yaml
Normal file
36
.gitea/workflows/publish-react.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Publish React package to npm
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'react/v*'
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: react
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
run: npm run build
|
||||
|
||||
- name: Configure Gitea npm registry
|
||||
env:
|
||||
REGISTRY_URL: ${{ gitea.server_url }}/api/packages/${{ gitea.repository_owner }}/npm/
|
||||
PUBLISH_TOKEN: ${{ secrets.PUBLISH_TOKEN }}
|
||||
run: |
|
||||
npm config set @rythazhur:registry "${REGISTRY_URL}"
|
||||
npm config set -- "${REGISTRY_URL#https:}:_authToken" "${PUBLISH_TOKEN}"
|
||||
|
||||
- name: Publish
|
||||
run: npm publish
|
||||
32
.gitignore
vendored
Normal file
32
.gitignore
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
.venv/
|
||||
*.db
|
||||
uv.lock
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
dist/
|
||||
package-lock.json
|
||||
|
||||
# Playwright
|
||||
/test-results/
|
||||
/playwright-report/
|
||||
/blob-report/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# Build artifacts
|
||||
desktop/frontend/dist/
|
||||
e2e/harness/src/api/generated.*
|
||||
e2e/harness/test-results/
|
||||
|
||||
# Env
|
||||
.env
|
||||
.env.*
|
||||
*.pem
|
||||
*.key
|
||||
21
Dockerfile.test
Normal file
21
Dockerfile.test
Normal file
@@ -0,0 +1,21 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install djarea from local source with channels support
|
||||
COPY django/ /app/django/
|
||||
RUN pip install --no-cache-dir /app/django[channels] daphne
|
||||
|
||||
# Copy example app
|
||||
COPY example/ /app/example/
|
||||
|
||||
WORKDIR /app/example
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["sh", "-c", "python manage.py migrate --run-syncdb && daphne -b 0.0.0.0 -p 8000 testapp.asgi:application"]
|
||||
46
Makefile
Normal file
46
Makefile
Normal file
@@ -0,0 +1,46 @@
|
||||
.PHONY: install test test-django test-react test-integration docker-up docker-down clean
|
||||
|
||||
# ─── Setup ───────────────────────────────────────────────────────────────────
|
||||
|
||||
install:
|
||||
cd django && pip install -e ".[dev,channels]"
|
||||
cd react && npm install
|
||||
|
||||
# ─── Unit Tests ──────────────────────────────────────────────────────────────
|
||||
|
||||
test: test-django test-react
|
||||
|
||||
test-django:
|
||||
cd django && pytest
|
||||
|
||||
test-react:
|
||||
cd react && npm test
|
||||
|
||||
# ─── Integration Tests ──────────────────────────────────────────────────────
|
||||
|
||||
test-integration: docker-up
|
||||
@echo "Waiting for backend..."
|
||||
@timeout 30 sh -c 'until curl -sf http://localhost:8000/api/djarea/session/ > /dev/null 2>&1; do sleep 1; done'
|
||||
cd react && npm run test:integration
|
||||
@$(MAKE) docker-down
|
||||
|
||||
# ─── Docker ──────────────────────────────────────────────────────────────────
|
||||
|
||||
docker-up:
|
||||
docker compose -f docker-compose.test.yml up -d --build
|
||||
@echo "Backend starting at http://localhost:8000"
|
||||
|
||||
docker-down:
|
||||
docker compose -f docker-compose.test.yml down
|
||||
|
||||
# ─── All ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
test-all: test test-integration
|
||||
|
||||
# ─── Cleanup ─────────────────────────────────────────────────────────────────
|
||||
|
||||
clean:
|
||||
docker compose -f docker-compose.test.yml down -v --remove-orphans 2>/dev/null || true
|
||||
rm -rf django/src/djarea.egg-info django/dist django/build
|
||||
rm -rf react/dist react/node_modules
|
||||
rm -f example/db.sqlite3
|
||||
22
README.md
Normal file
22
README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# djarea
|
||||
|
||||
Django + React server functions framework.
|
||||
|
||||
| Package | Path | Registry |
|
||||
|---------|------|----------|
|
||||
| `djarea` (Python) | `django/` | PyPI / git |
|
||||
| `djarea` (TypeScript) | `react/` | npm / git |
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Python
|
||||
uv add "djarea[channels,allauth] @ git+https://git.impactsoundworks.com/isw/djarea.git#subdirectory=django"
|
||||
|
||||
# TypeScript
|
||||
npm install djarea@git+https://git.impactsoundworks.com/isw/djarea.git#workspace=react
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
See [django/README.md](django/README.md) and [react/README.md](react/README.md).
|
||||
96
desktop/app.py
Normal file
96
desktop/app.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Djarea Desktop — PyWebView + Django local RPC.
|
||||
|
||||
Starts a local Django ASGI server and opens a native desktop window.
|
||||
All communication between the UI and backend uses Djarea server functions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings")
|
||||
|
||||
# Work around Qt WebEngine GPU crashes on some systems
|
||||
os.environ.setdefault("QTWEBENGINE_CHROMIUM_FLAGS", "--disable-gpu")
|
||||
|
||||
|
||||
def start_server(host: str, port: int):
|
||||
"""Start the Django ASGI server in a background thread."""
|
||||
import django
|
||||
|
||||
django.setup()
|
||||
|
||||
# Run migrations on first launch
|
||||
from django.core.management import call_command
|
||||
|
||||
call_command("migrate", "--run-syncdb", verbosity=0)
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"backend.asgi:application",
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="warning",
|
||||
)
|
||||
|
||||
|
||||
def wait_for_server(url: str, timeout: float = 10.0):
|
||||
"""Poll until the server responds."""
|
||||
from urllib.request import urlopen
|
||||
from urllib.error import URLError
|
||||
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
urlopen(url, timeout=1)
|
||||
return True
|
||||
except (URLError, OSError):
|
||||
time.sleep(0.1)
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
host = "127.0.0.1"
|
||||
port = 8765
|
||||
|
||||
# Start Django in a daemon thread
|
||||
server = threading.Thread(target=start_server, args=(host, port), daemon=True)
|
||||
server.start()
|
||||
|
||||
base_url = f"http://{host}:{port}"
|
||||
|
||||
if not wait_for_server(f"{base_url}/api/djarea/session/"):
|
||||
print("ERROR: Django server failed to start", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Backend running at {base_url}")
|
||||
|
||||
# Check if --headless flag is passed (for testing)
|
||||
if "--headless" in sys.argv:
|
||||
print("Headless mode — server running. Press Ctrl+C to stop.")
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
return
|
||||
|
||||
# Open native window
|
||||
import webview
|
||||
|
||||
window = webview.create_window(
|
||||
title="Djarea Desktop",
|
||||
url=base_url,
|
||||
width=1024,
|
||||
height=768,
|
||||
min_size=(640, 480),
|
||||
)
|
||||
webview.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
desktop/backend/__init__.py
Normal file
0
desktop/backend/__init__.py
Normal file
6
desktop/backend/apps.py
Normal file
6
desktop/backend/apps.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class DesktopBackendConfig(AppConfig):
|
||||
name = "backend"
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
13
desktop/backend/asgi.py
Normal file
13
desktop/backend/asgi.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
|
||||
import django
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings")
|
||||
django.setup()
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
from djarea import wrap_asgi
|
||||
|
||||
import backend.djarea_clients # noqa: F401
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
413
desktop/backend/djarea_clients.py
Normal file
413
desktop/backend/djarea_clients.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Desktop RPC server functions.
|
||||
|
||||
Tests Djarea's appropriateness for desktop apps:
|
||||
- Local file system access
|
||||
- SQLite CRUD
|
||||
- System introspection
|
||||
- Real-time channels (file watcher, app status)
|
||||
- No auth required (single-user desktop)
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client import client
|
||||
from djarea.channels import ReactChannel
|
||||
from djarea.setup.registry import register
|
||||
from djarea.channels import register as register_channel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# System Info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SystemInfoOutput(BaseModel):
|
||||
os_name: str
|
||||
os_version: str
|
||||
python_version: str
|
||||
hostname: str
|
||||
username: str
|
||||
home_dir: str
|
||||
cwd: str
|
||||
cpu_count: int
|
||||
djarea_version: str
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def system_info(request: HttpRequest) -> SystemInfoOutput:
|
||||
import djarea
|
||||
|
||||
return SystemInfoOutput(
|
||||
os_name=platform.system(),
|
||||
os_version=platform.version(),
|
||||
python_version=sys.version.split()[0],
|
||||
hostname=platform.node(),
|
||||
username=os.getenv("USER", os.getenv("USERNAME", "unknown")),
|
||||
home_dir=str(Path.home()),
|
||||
cwd=os.getcwd(),
|
||||
cpu_count=os.cpu_count() or 1,
|
||||
djarea_version=getattr(djarea, "__version__", "dev"),
|
||||
)
|
||||
|
||||
|
||||
register(system_info, "system_info")
|
||||
|
||||
|
||||
class DiskUsageOutput(BaseModel):
|
||||
path: str
|
||||
total_gb: float
|
||||
used_gb: float
|
||||
free_gb: float
|
||||
percent_used: float
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def disk_usage(request: HttpRequest, path: str = "/") -> DiskUsageOutput:
|
||||
usage = shutil.disk_usage(path)
|
||||
return DiskUsageOutput(
|
||||
path=path,
|
||||
total_gb=round(usage.total / (1024**3), 2),
|
||||
used_gb=round(usage.used / (1024**3), 2),
|
||||
free_gb=round(usage.free / (1024**3), 2),
|
||||
percent_used=round(usage.used / usage.total * 100, 1),
|
||||
)
|
||||
|
||||
|
||||
register(disk_usage, "disk_usage")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File System
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FileEntry(BaseModel):
|
||||
name: str
|
||||
path: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
modified: str
|
||||
|
||||
|
||||
class ListFilesOutput(BaseModel):
|
||||
directory: str
|
||||
entries: list[FileEntry]
|
||||
parent: str | None
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def list_files(request: HttpRequest, directory: str = "~") -> ListFilesOutput:
|
||||
dir_path = Path(directory).expanduser().resolve()
|
||||
|
||||
if not dir_path.is_dir():
|
||||
raise ValueError(f"Not a directory: {dir_path}")
|
||||
|
||||
entries = []
|
||||
try:
|
||||
for entry in sorted(dir_path.iterdir(), key=lambda e: (not e.is_dir(), e.name.lower())):
|
||||
try:
|
||||
stat = entry.stat()
|
||||
entries.append(FileEntry(
|
||||
name=entry.name,
|
||||
path=str(entry),
|
||||
is_dir=entry.is_dir(),
|
||||
size=stat.st_size if not entry.is_dir() else 0,
|
||||
modified=datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
))
|
||||
except (PermissionError, OSError):
|
||||
continue
|
||||
except PermissionError:
|
||||
raise PermissionError(f"Cannot read directory: {dir_path}")
|
||||
|
||||
parent = str(dir_path.parent) if dir_path.parent != dir_path else None
|
||||
|
||||
return ListFilesOutput(
|
||||
directory=str(dir_path),
|
||||
entries=entries,
|
||||
parent=parent,
|
||||
)
|
||||
|
||||
|
||||
register(list_files, "list_files")
|
||||
|
||||
|
||||
class FileContentOutput(BaseModel):
|
||||
path: str
|
||||
content: str
|
||||
size: int
|
||||
modified: str
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def read_file(request: HttpRequest, path: str) -> FileContentOutput:
|
||||
file_path = Path(path).expanduser().resolve()
|
||||
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
stat = file_path.stat()
|
||||
|
||||
# Safety: limit to 1MB text files
|
||||
if stat.st_size > 1_048_576:
|
||||
raise ValueError(f"File too large: {stat.st_size} bytes (max 1MB)")
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(f"Not a text file: {file_path}")
|
||||
|
||||
return FileContentOutput(
|
||||
path=str(file_path),
|
||||
content=content,
|
||||
size=stat.st_size,
|
||||
modified=datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
)
|
||||
|
||||
|
||||
register(read_file, "read_file")
|
||||
|
||||
|
||||
class WriteFileOutput(BaseModel):
|
||||
path: str
|
||||
size: int
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def write_file(request: HttpRequest, path: str, content: str) -> WriteFileOutput:
|
||||
file_path = Path(path).expanduser().resolve()
|
||||
|
||||
# Safety: only allow writing within home directory
|
||||
home = Path.home()
|
||||
if not str(file_path).startswith(str(home)):
|
||||
raise PermissionError(f"Can only write files within home directory: {home}")
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
|
||||
return WriteFileOutput(path=str(file_path), size=len(content.encode("utf-8")))
|
||||
|
||||
|
||||
register(write_file, "write_file")
|
||||
|
||||
|
||||
class DeleteFileOutput(BaseModel):
|
||||
path: str
|
||||
deleted: bool
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def delete_file(request: HttpRequest, path: str) -> DeleteFileOutput:
|
||||
file_path = Path(path).expanduser().resolve()
|
||||
|
||||
home = Path.home()
|
||||
if not str(file_path).startswith(str(home)):
|
||||
raise PermissionError(f"Can only delete files within home directory: {home}")
|
||||
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
return DeleteFileOutput(path=str(file_path), deleted=True)
|
||||
|
||||
return DeleteFileOutput(path=str(file_path), deleted=False)
|
||||
|
||||
|
||||
register(delete_file, "delete_file")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Notes CRUD (SQLite)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NoteOutput(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
content: str
|
||||
pinned: bool
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class NoteListOutput(BaseModel):
|
||||
notes: list[NoteOutput]
|
||||
count: int
|
||||
|
||||
|
||||
def _note_to_output(note) -> NoteOutput:
|
||||
return NoteOutput(
|
||||
id=note.id,
|
||||
title=note.title,
|
||||
content=note.content,
|
||||
pinned=note.pinned,
|
||||
created_at=note.created_at.isoformat(),
|
||||
updated_at=note.updated_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def list_notes(request: HttpRequest) -> NoteListOutput:
|
||||
from backend.models import Note
|
||||
|
||||
notes = Note.objects.all()
|
||||
return NoteListOutput(
|
||||
notes=[_note_to_output(n) for n in notes],
|
||||
count=notes.count(),
|
||||
)
|
||||
|
||||
|
||||
register(list_notes, "list_notes")
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def create_note(request: HttpRequest, title: str, content: str = "", pinned: bool = False) -> NoteOutput:
|
||||
from backend.models import Note
|
||||
|
||||
note = Note.objects.create(title=title, content=content, pinned=pinned)
|
||||
return _note_to_output(note)
|
||||
|
||||
|
||||
register(create_note, "create_note")
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def get_note(request: HttpRequest, id: int) -> NoteOutput:
|
||||
from backend.models import Note
|
||||
|
||||
try:
|
||||
note = Note.objects.get(pk=id)
|
||||
except Note.DoesNotExist:
|
||||
raise ValueError(f"Note {id} not found")
|
||||
|
||||
return _note_to_output(note)
|
||||
|
||||
|
||||
register(get_note, "get_note")
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def update_note(
|
||||
request: HttpRequest,
|
||||
id: int,
|
||||
title: str | None = None,
|
||||
content: str | None = None,
|
||||
pinned: bool | None = None,
|
||||
) -> NoteOutput:
|
||||
from backend.models import Note
|
||||
|
||||
try:
|
||||
note = Note.objects.get(pk=id)
|
||||
except Note.DoesNotExist:
|
||||
raise ValueError(f"Note {id} not found")
|
||||
|
||||
if title is not None:
|
||||
note.title = title
|
||||
if content is not None:
|
||||
note.content = content
|
||||
if pinned is not None:
|
||||
note.pinned = pinned
|
||||
|
||||
note.save()
|
||||
return _note_to_output(note)
|
||||
|
||||
|
||||
register(update_note, "update_note")
|
||||
|
||||
|
||||
class DeleteNoteOutput(BaseModel):
|
||||
id: int
|
||||
deleted: bool
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def delete_note(request: HttpRequest, id: int) -> DeleteNoteOutput:
|
||||
from backend.models import Note
|
||||
|
||||
try:
|
||||
note = Note.objects.get(pk=id)
|
||||
note.delete()
|
||||
return DeleteNoteOutput(id=id, deleted=True)
|
||||
except Note.DoesNotExist:
|
||||
return DeleteNoteOutput(id=id, deleted=False)
|
||||
|
||||
|
||||
register(delete_note, "delete_note")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Channels — Real-time Desktop Events
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AppStatusChannel(ReactChannel):
|
||||
"""Push app status updates to the UI (uptime, memory, etc.)."""
|
||||
|
||||
class DjangoMessage(BaseModel):
|
||||
uptime_seconds: float
|
||||
memory_mb: float
|
||||
note_count: int
|
||||
timestamp: str
|
||||
|
||||
def authorize(self, params=None):
|
||||
return True # Desktop app, no auth needed
|
||||
|
||||
def group(self, params=None):
|
||||
return "app_status"
|
||||
|
||||
|
||||
register_channel(AppStatusChannel, "app_status")
|
||||
|
||||
|
||||
class NotesChannel(ReactChannel):
|
||||
"""Push notifications when notes are modified."""
|
||||
|
||||
class DjangoMessage(BaseModel):
|
||||
action: str # "created", "updated", "deleted"
|
||||
note_id: int
|
||||
title: str
|
||||
|
||||
def authorize(self, params=None):
|
||||
return True
|
||||
|
||||
def group(self, params=None):
|
||||
return "notes_updates"
|
||||
|
||||
|
||||
register_channel(NotesChannel, "notes_updates")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# App Lifecycle
|
||||
# =============================================================================
|
||||
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
class AppInfoOutput(BaseModel):
|
||||
app_name: str
|
||||
uptime_seconds: float
|
||||
db_path: str
|
||||
pid: int
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def app_info(request: HttpRequest) -> AppInfoOutput:
|
||||
from django.conf import settings
|
||||
|
||||
return AppInfoOutput(
|
||||
app_name="Djarea Desktop",
|
||||
uptime_seconds=round(time.time() - _start_time, 2),
|
||||
db_path=str(settings.DATABASES["default"]["NAME"]),
|
||||
pid=os.getpid(),
|
||||
)
|
||||
|
||||
|
||||
register(app_info, "app_info")
|
||||
15
desktop/backend/models.py
Normal file
15
desktop/backend/models.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Note(models.Model):
|
||||
title = models.CharField(max_length=200)
|
||||
content = models.TextField(blank=True, default="")
|
||||
pinned = models.BooleanField(default=False)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
ordering = ["-pinned", "-updated_at"]
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
49
desktop/backend/settings.py
Normal file
49
desktop/backend/settings.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Django settings for the Djarea desktop integration test app.
|
||||
|
||||
Runs entirely local: SQLite database, in-memory channel layer,
|
||||
no external services required.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
SECRET_KEY = "desktop-app-local-only-secret-key"
|
||||
|
||||
DEBUG = True
|
||||
|
||||
ALLOWED_HOSTS = ["127.0.0.1", "localhost"]
|
||||
|
||||
INSTALLED_APPS = [
|
||||
"django.contrib.contenttypes",
|
||||
"backend",
|
||||
]
|
||||
|
||||
MIDDLEWARE = []
|
||||
|
||||
ROOT_URLCONF = "backend.urls"
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.sqlite3",
|
||||
"NAME": os.path.join(BASE_DIR, "app.db"),
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||
|
||||
ASGI_APPLICATION = "backend.asgi.application"
|
||||
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels.layers.InMemoryChannelLayer",
|
||||
},
|
||||
}
|
||||
|
||||
# Serve the built frontend
|
||||
STATIC_URL = "/static/"
|
||||
STATICFILES_DIRS = [os.path.join(BASE_DIR, "frontend", "dist")]
|
||||
|
||||
# No auth, no CSRF — local desktop app
|
||||
CSRF_COOKIE_HTTPONLY = False
|
||||
34
desktop/backend/urls.py
Normal file
34
desktop/backend/urls.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from django.urls import include, path, re_path
|
||||
from django.http import HttpResponse, HttpResponseNotFound
|
||||
from pathlib import Path
|
||||
|
||||
DIST_DIR = Path(__file__).resolve().parent.parent / "frontend" / "dist"
|
||||
|
||||
CONTENT_TYPES = {
|
||||
".html": "text/html",
|
||||
".js": "application/javascript",
|
||||
".css": "text/css",
|
||||
".svg": "image/svg+xml",
|
||||
".png": "image/png",
|
||||
".ico": "image/x-icon",
|
||||
".woff2": "font/woff2",
|
||||
".json": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def serve_dist(request, path="index.html"):
|
||||
file_path = (DIST_DIR / path).resolve()
|
||||
|
||||
if not str(file_path).startswith(str(DIST_DIR)) or not file_path.is_file():
|
||||
return HttpResponseNotFound()
|
||||
|
||||
ct = CONTENT_TYPES.get(file_path.suffix, "application/octet-stream")
|
||||
return HttpResponse(file_path.read_bytes(), content_type=ct)
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path("api/djarea/", include("djarea.urls")),
|
||||
re_path(r"^(?P<path>assets/.+)$", serve_dist),
|
||||
path("favicon.ico", serve_dist, {"path": "favicon.ico"}),
|
||||
path("", serve_dist),
|
||||
]
|
||||
16
desktop/frontend/index.html
Normal file
16
desktop/frontend/index.html
Normal file
@@ -0,0 +1,16 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Djarea Desktop</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body { font-family: system-ui, -apple-system, sans-serif; background: #0f0f0f; color: #e0e0e0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
21
desktop/frontend/package.json
Normal file
21
desktop/frontend/package.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "djarea-desktop-frontend",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite --port 5173",
|
||||
"build": "vite build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@rythazhur/djarea": "file:../../react",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@vitejs/plugin-react": "^4.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
"vite": "^6.0.0"
|
||||
}
|
||||
}
|
||||
215
desktop/frontend/src/App.tsx
Normal file
215
desktop/frontend/src/App.tsx
Normal file
@@ -0,0 +1,215 @@
|
||||
import { useState, useEffect, useCallback } from 'react'
|
||||
import { DjareaProvider, useDjarea, useDjareaStatus } from '@rythazhur/djarea'
|
||||
|
||||
// ─── System Info ────────────────────────────────────────────────────────────
|
||||
|
||||
function SystemInfo() {
|
||||
const { call } = useDjarea()
|
||||
const [info, setInfo] = useState<Record<string, unknown> | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
call('system_info').then(setInfo).catch(() => {})
|
||||
}, [call])
|
||||
|
||||
if (!info) return <div style={styles.card}>Loading system info...</div>
|
||||
|
||||
return (
|
||||
<div style={styles.card}>
|
||||
<h2 style={styles.h2}>System</h2>
|
||||
<table style={styles.table}>
|
||||
<tbody>
|
||||
{Object.entries(info).map(([k, v]) => (
|
||||
<tr key={k}>
|
||||
<td style={styles.label}>{k}</td>
|
||||
<td style={styles.value}>{String(v)}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Connection Status ──────────────────────────────────────────────────────
|
||||
|
||||
function StatusBar() {
|
||||
const status = useDjareaStatus()
|
||||
return (
|
||||
<div style={{ ...styles.statusBar, color: status === 'connected' ? '#4ade80' : '#f87171' }}>
|
||||
{status}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Notes ──────────────────────────────────────────────────────────────────
|
||||
|
||||
type Note = { id: number; title: string; content: string; pinned: boolean; updated_at: string }
|
||||
|
||||
function Notes() {
|
||||
const { call } = useDjarea()
|
||||
const [notes, setNotes] = useState<Note[]>([])
|
||||
const [selected, setSelected] = useState<Note | null>(null)
|
||||
const [title, setTitle] = useState('')
|
||||
const [content, setContent] = useState('')
|
||||
|
||||
const refresh = useCallback(() => {
|
||||
call<{ notes: Note[] }>('list_notes').then(d => setNotes(d.notes)).catch(() => {})
|
||||
}, [call])
|
||||
|
||||
useEffect(() => { refresh() }, [refresh])
|
||||
|
||||
const create = async () => {
|
||||
if (!title.trim()) return
|
||||
await call('create_note', { title, content })
|
||||
setTitle('')
|
||||
setContent('')
|
||||
refresh()
|
||||
}
|
||||
|
||||
const save = async () => {
|
||||
if (!selected) return
|
||||
await call('update_note', { id: selected.id, title, content })
|
||||
setSelected(null)
|
||||
setTitle('')
|
||||
setContent('')
|
||||
refresh()
|
||||
}
|
||||
|
||||
const remove = async (id: number) => {
|
||||
await call('delete_note', { id })
|
||||
if (selected?.id === id) { setSelected(null); setTitle(''); setContent('') }
|
||||
refresh()
|
||||
}
|
||||
|
||||
const select = (n: Note) => {
|
||||
setSelected(n)
|
||||
setTitle(n.title)
|
||||
setContent(n.content)
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={styles.card}>
|
||||
<h2 style={styles.h2}>Notes ({notes.length})</h2>
|
||||
<div style={{ display: 'flex', gap: 12 }}>
|
||||
<div style={{ flex: 1 }}>
|
||||
{notes.map(n => (
|
||||
<div
|
||||
key={n.id}
|
||||
onClick={() => select(n)}
|
||||
style={{
|
||||
...styles.noteItem,
|
||||
borderLeft: selected?.id === n.id ? '3px solid #6cf' : '3px solid transparent',
|
||||
}}
|
||||
>
|
||||
<span>{n.pinned ? '\u{1f4cc} ' : ''}{n.title}</span>
|
||||
<button onClick={e => { e.stopPropagation(); remove(n.id) }} style={styles.deleteBtn}>x</button>
|
||||
</div>
|
||||
))}
|
||||
{notes.length === 0 && <div style={{ color: '#666', padding: 8 }}>No notes yet</div>}
|
||||
</div>
|
||||
<div style={{ flex: 2 }}>
|
||||
<input
|
||||
value={title}
|
||||
onChange={e => setTitle(e.target.value)}
|
||||
placeholder="Title"
|
||||
style={styles.input}
|
||||
/>
|
||||
<textarea
|
||||
value={content}
|
||||
onChange={e => setContent(e.target.value)}
|
||||
placeholder="Content"
|
||||
rows={6}
|
||||
style={{ ...styles.input, resize: 'vertical' }}
|
||||
/>
|
||||
<button onClick={selected ? save : create} style={styles.btn}>
|
||||
{selected ? 'Save' : 'Create'}
|
||||
</button>
|
||||
{selected && (
|
||||
<button onClick={() => { setSelected(null); setTitle(''); setContent('') }} style={{ ...styles.btn, background: '#333' }}>
|
||||
Cancel
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── File Browser ───────────────────────────────────────────────────────────
|
||||
|
||||
type FileEntry = { name: string; path: string; is_dir: boolean; size: number }
|
||||
|
||||
function FileBrowser() {
|
||||
const { call } = useDjarea()
|
||||
const [dir, setDir] = useState('~')
|
||||
const [entries, setEntries] = useState<FileEntry[]>([])
|
||||
const [parent, setParent] = useState<string | null>(null)
|
||||
|
||||
const browse = useCallback((d: string) => {
|
||||
call<{ directory: string; entries: FileEntry[]; parent: string | null }>('list_files', { directory: d })
|
||||
.then(data => {
|
||||
setDir(data.directory)
|
||||
setEntries(data.entries.slice(0, 50))
|
||||
setParent(data.parent)
|
||||
})
|
||||
.catch(() => {})
|
||||
}, [call])
|
||||
|
||||
useEffect(() => { browse('~') }, [browse])
|
||||
|
||||
return (
|
||||
<div style={styles.card}>
|
||||
<h2 style={styles.h2}>Files</h2>
|
||||
<div style={{ color: '#888', fontSize: 13, marginBottom: 8 }}>{dir}</div>
|
||||
{parent && (
|
||||
<div onClick={() => browse(parent)} style={{ ...styles.fileItem, color: '#6cf', cursor: 'pointer' }}>
|
||||
../ (parent)
|
||||
</div>
|
||||
)}
|
||||
{entries.map(e => (
|
||||
<div
|
||||
key={e.path}
|
||||
onClick={() => e.is_dir && browse(e.path)}
|
||||
style={{ ...styles.fileItem, cursor: e.is_dir ? 'pointer' : 'default', color: e.is_dir ? '#6cf' : '#ccc' }}
|
||||
>
|
||||
{e.is_dir ? '\u{1f4c1}' : '\u{1f4c4}'} {e.name}
|
||||
{!e.is_dir && <span style={{ color: '#666', marginLeft: 8 }}>{(e.size / 1024).toFixed(1)}K</span>}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── App ────────────────────────────────────────────────────────────────────
|
||||
|
||||
export function App() {
|
||||
return (
|
||||
<DjareaProvider baseUrl="/api/djarea" autoConnect={false}>
|
||||
<div style={{ maxWidth: 960, margin: '0 auto', padding: 24 }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 24 }}>
|
||||
<h1 style={{ fontSize: 24, color: '#fff' }}>Djarea Desktop</h1>
|
||||
<StatusBar />
|
||||
</div>
|
||||
<SystemInfo />
|
||||
<Notes />
|
||||
<FileBrowser />
|
||||
</div>
|
||||
</DjareaProvider>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Styles ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const styles: Record<string, React.CSSProperties> = {
|
||||
card: { background: '#1a1a1a', borderRadius: 8, padding: 20, marginBottom: 16 },
|
||||
h2: { fontSize: 16, marginBottom: 12, color: '#aaa', textTransform: 'uppercase', letterSpacing: 1 },
|
||||
table: { width: '100%', fontSize: 14 },
|
||||
label: { padding: '4px 12px 4px 0', color: '#888', whiteSpace: 'nowrap' },
|
||||
value: { padding: '4px 0', wordBreak: 'break-all' },
|
||||
input: { width: '100%', padding: '8px 12px', marginBottom: 8, background: '#111', border: '1px solid #333', borderRadius: 4, color: '#e0e0e0', fontSize: 14 },
|
||||
btn: { padding: '8px 16px', background: '#2563eb', color: '#fff', border: 'none', borderRadius: 4, cursor: 'pointer', marginRight: 8, fontSize: 14 },
|
||||
noteItem: { display: 'flex', justifyContent: 'space-between', alignItems: 'center', padding: '8px 12px', cursor: 'pointer', borderRadius: 4, marginBottom: 2 },
|
||||
deleteBtn: { background: 'none', border: 'none', color: '#666', cursor: 'pointer', fontSize: 14, padding: '2px 6px' },
|
||||
fileItem: { padding: '4px 8px', fontSize: 14 },
|
||||
statusBar: { fontSize: 12, fontFamily: 'monospace' },
|
||||
}
|
||||
4
desktop/frontend/src/main.tsx
Normal file
4
desktop/frontend/src/main.tsx
Normal file
@@ -0,0 +1,4 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import { App } from './App'
|
||||
|
||||
createRoot(document.getElementById('root')!).render(<App />)
|
||||
11
desktop/frontend/tsconfig.json
Normal file
11
desktop/frontend/tsconfig.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"jsx": "react-jsx",
|
||||
"skipLibCheck": true
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
12
desktop/frontend/vite.config.ts
Normal file
12
desktop/frontend/vite.config.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/api': 'http://127.0.0.1:8765',
|
||||
'/ws': { target: 'ws://127.0.0.1:8765', ws: true },
|
||||
},
|
||||
},
|
||||
})
|
||||
8
desktop/manage.py
Normal file
8
desktop/manage.py
Normal file
@@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings")
|
||||
from django.core.management import execute_from_command_line
|
||||
execute_from_command_line(sys.argv)
|
||||
25
desktop/pyproject.toml
Normal file
25
desktop/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "djarea-desktop"
|
||||
version = "0.1.0"
|
||||
description = "Desktop integration test app for Djarea"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"djarea[channels]",
|
||||
"uvicorn[standard]>=0.30",
|
||||
"pywebview[qt]>=5.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
djarea = { path = "../django", editable = true }
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-django>=4.9",
|
||||
"httpx>=0.27",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
DJANGO_SETTINGS_MODULE = "backend.settings"
|
||||
pythonpath = ["."]
|
||||
testpaths = ["tests"]
|
||||
0
desktop/tests/__init__.py
Normal file
0
desktop/tests/__init__.py
Normal file
7
desktop/tests/conftest.py
Normal file
7
desktop/tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import django
|
||||
from django.conf import settings
|
||||
|
||||
# Ensure migrations run before tests
|
||||
def pytest_configure():
|
||||
# Import djarea_clients to trigger function registration
|
||||
import backend.djarea_clients # noqa: F401
|
||||
173
desktop/tests/test_desktop_rpc.py
Normal file
173
desktop/tests/test_desktop_rpc.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
REAL integration tests for the Djarea RPC framework layer.
|
||||
|
||||
Tests the actual HTTP stack: CSRF, middleware, error codes, validation.
|
||||
Every test makes a real HTTP request — no mocks, no RequestFactory.
|
||||
"""
|
||||
|
||||
import json
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError
|
||||
|
||||
from django.test import LiveServerTestCase
|
||||
|
||||
|
||||
class RealHTTPMixin:
|
||||
def _session_init(self):
|
||||
url = f"{self.live_server_url}/api/djarea/session/"
|
||||
resp = urlopen(Request(url))
|
||||
cookies = resp.headers.get_all("Set-Cookie") or []
|
||||
for cookie in cookies:
|
||||
if "csrftoken=" in cookie:
|
||||
self._csrf_token = cookie.split("csrftoken=")[1].split(";")[0]
|
||||
self._cookies = f"csrftoken={self._csrf_token}"
|
||||
return
|
||||
self._csrf_token = None
|
||||
self._cookies = ""
|
||||
|
||||
def _call(self, fn: str, args: dict | None = None):
|
||||
url = f"{self.live_server_url}/api/djarea/call/"
|
||||
body = json.dumps({"fn": fn, "args": args or {}}).encode()
|
||||
req = Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
if self._csrf_token:
|
||||
req.add_header("X-CSRFToken", self._csrf_token)
|
||||
if self._cookies:
|
||||
req.add_header("Cookie", self._cookies)
|
||||
resp = urlopen(req)
|
||||
return json.loads(resp.read())
|
||||
|
||||
def _raw_post(self, path: str, body: bytes | str, content_type: str = "application/json", include_csrf: bool = False):
|
||||
"""Raw POST without the call() envelope — for testing malformed requests."""
|
||||
url = f"{self.live_server_url}{path}"
|
||||
if isinstance(body, str):
|
||||
body = body.encode()
|
||||
req = Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", content_type)
|
||||
if include_csrf and self._csrf_token:
|
||||
req.add_header("X-CSRFToken", self._csrf_token)
|
||||
req.add_header("Cookie", self._cookies)
|
||||
return urlopen(req)
|
||||
|
||||
|
||||
class CSRFTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""CSRF handling over real HTTP — the thing that was broken."""
|
||||
|
||||
def test_session_endpoint_sets_csrf_cookie(self):
|
||||
"""GET /session/ must return a Set-Cookie with csrftoken."""
|
||||
url = f"{self.live_server_url}/api/djarea/session/"
|
||||
resp = urlopen(Request(url))
|
||||
cookies = resp.headers.get_all("Set-Cookie") or []
|
||||
|
||||
csrf_cookies = [c for c in cookies if "csrftoken=" in c]
|
||||
self.assertGreater(len(csrf_cookies), 0, "No csrftoken cookie set by /session/")
|
||||
|
||||
def test_call_without_csrf_is_rejected(self):
|
||||
"""POST /call/ without CSRF token must fail."""
|
||||
url = f"{self.live_server_url}/api/djarea/call/"
|
||||
body = json.dumps({"fn": "system_info", "args": {}}).encode()
|
||||
req = Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
|
||||
try:
|
||||
resp = urlopen(req)
|
||||
data = json.loads(resp.read())
|
||||
# If it doesn't raise, the response should indicate an error
|
||||
self.assertTrue(data.get("error"), "POST without CSRF should be rejected")
|
||||
except HTTPError as e:
|
||||
self.assertEqual(e.code, 403, f"Expected 403, got {e.code}")
|
||||
|
||||
def test_call_with_csrf_succeeds(self):
|
||||
"""POST /call/ with valid CSRF token must work."""
|
||||
self._session_init()
|
||||
data = self._call("system_info")
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertIn("os_name", data["data"])
|
||||
|
||||
|
||||
class ValidationTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""Pydantic validation errors over real HTTP."""
|
||||
|
||||
def setUp(self):
|
||||
self._session_init()
|
||||
|
||||
def test_missing_required_field(self):
|
||||
"""Calling create_note without title should return VALIDATION_ERROR."""
|
||||
data = self._call("create_note", {})
|
||||
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "VALIDATION_ERROR")
|
||||
|
||||
def test_wrong_type(self):
|
||||
"""Calling delete_note with string id should return VALIDATION_ERROR."""
|
||||
data = self._call("delete_note", {"id": "not-an-int"})
|
||||
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "VALIDATION_ERROR")
|
||||
|
||||
def test_missing_multiple_fields(self):
|
||||
"""write_file with no args should list all missing fields."""
|
||||
data = self._call("write_file", {})
|
||||
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "VALIDATION_ERROR")
|
||||
|
||||
|
||||
class ErrorCodeTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""Error codes over real HTTP."""
|
||||
|
||||
def setUp(self):
|
||||
self._session_init()
|
||||
|
||||
def test_not_found_function(self):
|
||||
data = self._call("this_does_not_exist")
|
||||
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "NOT_FOUND")
|
||||
|
||||
def test_forbidden_write_outside_home(self):
|
||||
data = self._call("write_file", {"path": "/etc/nope.txt", "content": "x"})
|
||||
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "FORBIDDEN")
|
||||
|
||||
def test_get_method_rejected(self):
|
||||
"""GET to /call/ should be rejected."""
|
||||
url = f"{self.live_server_url}/api/djarea/call/"
|
||||
try:
|
||||
resp = urlopen(Request(url))
|
||||
data = json.loads(resp.read())
|
||||
self.assertTrue(data.get("error"))
|
||||
except HTTPError as e:
|
||||
self.assertIn(e.code, [403, 405])
|
||||
|
||||
def test_invalid_json_body(self):
|
||||
"""Malformed JSON should return BAD_REQUEST."""
|
||||
self._session_init()
|
||||
try:
|
||||
resp = self._raw_post(
|
||||
"/api/djarea/call/",
|
||||
body="not valid json{{{",
|
||||
include_csrf=True,
|
||||
)
|
||||
data = json.loads(resp.read())
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "BAD_REQUEST")
|
||||
except HTTPError as e:
|
||||
self.assertIn(e.code, [400, 403])
|
||||
|
||||
def test_missing_fn_field(self):
|
||||
"""POST with valid JSON but no 'fn' field should return BAD_REQUEST."""
|
||||
self._session_init()
|
||||
try:
|
||||
resp = self._raw_post(
|
||||
"/api/djarea/call/",
|
||||
body=json.dumps({"not_fn": "hello"}),
|
||||
include_csrf=True,
|
||||
)
|
||||
data = json.loads(resp.read())
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "BAD_REQUEST")
|
||||
except HTTPError as e:
|
||||
self.assertIn(e.code, [400, 403])
|
||||
142
desktop/tests/test_notes.py
Normal file
142
desktop/tests/test_notes.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
REAL integration tests for notes CRUD over HTTP.
|
||||
|
||||
Every test makes actual HTTP requests to a live Django server.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from django.test import LiveServerTestCase
|
||||
from urllib.request import urlopen, Request
|
||||
|
||||
|
||||
class RealHTTPMixin:
|
||||
def _session_init(self):
|
||||
url = f"{self.live_server_url}/api/djarea/session/"
|
||||
resp = urlopen(Request(url))
|
||||
cookies = resp.headers.get_all("Set-Cookie") or []
|
||||
for cookie in cookies:
|
||||
if "csrftoken=" in cookie:
|
||||
self._csrf_token = cookie.split("csrftoken=")[1].split(";")[0]
|
||||
self._cookies = f"csrftoken={self._csrf_token}"
|
||||
return
|
||||
self._csrf_token = None
|
||||
self._cookies = ""
|
||||
|
||||
def _call(self, fn: str, args: dict | None = None):
|
||||
url = f"{self.live_server_url}/api/djarea/call/"
|
||||
body = json.dumps({"fn": fn, "args": args or {}}).encode()
|
||||
req = Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
if self._csrf_token:
|
||||
req.add_header("X-CSRFToken", self._csrf_token)
|
||||
if self._cookies:
|
||||
req.add_header("Cookie", self._cookies)
|
||||
resp = urlopen(req)
|
||||
return json.loads(resp.read())
|
||||
|
||||
|
||||
class NotesCRUDTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""Full CRUD lifecycle over real HTTP."""
|
||||
|
||||
def setUp(self):
|
||||
self._session_init()
|
||||
|
||||
def test_list_notes_empty(self):
|
||||
data = self._call("list_notes")
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["notes"], [])
|
||||
self.assertEqual(data["data"]["count"], 0)
|
||||
|
||||
def test_create_note(self):
|
||||
data = self._call("create_note", {"title": "First Note", "content": "Hello!"})
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["title"], "First Note")
|
||||
self.assertEqual(data["data"]["content"], "Hello!")
|
||||
self.assertFalse(data["data"]["pinned"])
|
||||
self.assertIn("id", data["data"])
|
||||
self.assertIn("created_at", data["data"])
|
||||
|
||||
def test_create_and_list(self):
|
||||
self._call("create_note", {"title": "Note A"})
|
||||
self._call("create_note", {"title": "Note B"})
|
||||
|
||||
data = self._call("list_notes")
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["count"], 2)
|
||||
titles = [n["title"] for n in data["data"]["notes"]]
|
||||
self.assertIn("Note A", titles)
|
||||
self.assertIn("Note B", titles)
|
||||
|
||||
def test_get_note_by_id(self):
|
||||
create = self._call("create_note", {"title": "Get Me", "content": "Specific"})
|
||||
note_id = create["data"]["id"]
|
||||
|
||||
data = self._call("get_note", {"id": note_id})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["id"], note_id)
|
||||
self.assertEqual(data["data"]["title"], "Get Me")
|
||||
|
||||
def test_update_note(self):
|
||||
create = self._call("create_note", {"title": "Original"})
|
||||
note_id = create["data"]["id"]
|
||||
|
||||
data = self._call("update_note", {"id": note_id, "title": "Updated"})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["title"], "Updated")
|
||||
|
||||
def test_update_note_pin(self):
|
||||
create = self._call("create_note", {"title": "Pin Me"})
|
||||
note_id = create["data"]["id"]
|
||||
|
||||
data = self._call("update_note", {"id": note_id, "pinned": True})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertTrue(data["data"]["pinned"])
|
||||
|
||||
def test_delete_note(self):
|
||||
create = self._call("create_note", {"title": "Delete Me"})
|
||||
note_id = create["data"]["id"]
|
||||
|
||||
data = self._call("delete_note", {"id": note_id})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertTrue(data["data"]["deleted"])
|
||||
|
||||
# Verify it's gone
|
||||
from urllib.error import HTTPError
|
||||
try:
|
||||
get_data = self._call("get_note", {"id": note_id})
|
||||
self.assertTrue(get_data["error"])
|
||||
except HTTPError:
|
||||
pass # 500 is also a valid failure signal
|
||||
|
||||
def test_pinned_notes_sort_first(self):
|
||||
self._call("create_note", {"title": "Unpinned"})
|
||||
self._call("create_note", {"title": "Pinned", "pinned": True})
|
||||
|
||||
data = self._call("list_notes")
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["notes"][0]["title"], "Pinned")
|
||||
|
||||
def test_full_lifecycle(self):
|
||||
"""Create -> update -> pin -> verify -> delete over real HTTP."""
|
||||
# Create
|
||||
create = self._call("create_note", {"title": "Lifecycle", "content": "v1"})
|
||||
note_id = create["data"]["id"]
|
||||
|
||||
# Update
|
||||
self._call("update_note", {"id": note_id, "content": "v2"})
|
||||
|
||||
# Pin
|
||||
self._call("update_note", {"id": note_id, "pinned": True})
|
||||
|
||||
# Verify
|
||||
get = self._call("get_note", {"id": note_id})
|
||||
self.assertEqual(get["data"]["title"], "Lifecycle")
|
||||
self.assertEqual(get["data"]["content"], "v2")
|
||||
self.assertTrue(get["data"]["pinned"])
|
||||
|
||||
# Delete
|
||||
delete = self._call("delete_note", {"id": note_id})
|
||||
self.assertTrue(delete["data"]["deleted"])
|
||||
162
desktop/tests/test_system.py
Normal file
162
desktop/tests/test_system.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
REAL integration tests for desktop system RPC functions.
|
||||
|
||||
These make actual HTTP requests to a running Django server.
|
||||
No RequestFactory, no mocks, no shortcuts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from django.test import LiveServerTestCase
|
||||
from urllib.request import urlopen, Request
|
||||
|
||||
|
||||
class RealHTTPMixin:
|
||||
"""Makes real HTTP requests to the live server."""
|
||||
|
||||
def _session_init(self):
|
||||
"""Hit /session/ to get CSRF cookie, like DjareaProvider does."""
|
||||
url = f"{self.live_server_url}/api/djarea/session/"
|
||||
req = Request(url)
|
||||
resp = urlopen(req)
|
||||
# Extract csrftoken from Set-Cookie header
|
||||
cookies = resp.headers.get_all("Set-Cookie") or []
|
||||
for cookie in cookies:
|
||||
if "csrftoken=" in cookie:
|
||||
self._csrf_token = cookie.split("csrftoken=")[1].split(";")[0]
|
||||
self._cookies = f"csrftoken={self._csrf_token}"
|
||||
return
|
||||
self._csrf_token = None
|
||||
self._cookies = ""
|
||||
|
||||
def _call(self, fn: str, args: dict | None = None):
|
||||
"""Make a real POST to /api/djarea/call/ with CSRF token."""
|
||||
url = f"{self.live_server_url}/api/djarea/call/"
|
||||
body = json.dumps({"fn": fn, "args": args or {}}).encode()
|
||||
req = Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
if self._csrf_token:
|
||||
req.add_header("X-CSRFToken", self._csrf_token)
|
||||
if self._cookies:
|
||||
req.add_header("Cookie", self._cookies)
|
||||
resp = urlopen(req)
|
||||
return json.loads(resp.read())
|
||||
|
||||
|
||||
class SystemInfoTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""system_info over real HTTP."""
|
||||
|
||||
def setUp(self):
|
||||
self._session_init()
|
||||
|
||||
def test_system_info_returns_os_data(self):
|
||||
data = self._call("system_info")
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["os_name"], platform.system())
|
||||
self.assertEqual(data["data"]["hostname"], platform.node())
|
||||
self.assertGreater(data["data"]["cpu_count"], 0)
|
||||
|
||||
def test_system_info_returns_paths(self):
|
||||
data = self._call("system_info")
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["home_dir"], str(Path.home()))
|
||||
self.assertEqual(data["data"]["cwd"], os.getcwd())
|
||||
|
||||
def test_disk_usage(self):
|
||||
data = self._call("disk_usage", {"path": "/"})
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertGreater(data["data"]["total_gb"], 0)
|
||||
self.assertGreater(data["data"]["free_gb"], 0)
|
||||
self.assertGreaterEqual(data["data"]["percent_used"], 0)
|
||||
self.assertLessEqual(data["data"]["percent_used"], 100)
|
||||
|
||||
def test_app_info(self):
|
||||
data = self._call("app_info")
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["app_name"], "Djarea Desktop")
|
||||
self.assertGreater(data["data"]["uptime_seconds"], 0)
|
||||
|
||||
|
||||
class FileSystemTests(RealHTTPMixin, LiveServerTestCase):
|
||||
"""File system RPC over real HTTP."""
|
||||
|
||||
def setUp(self):
|
||||
self._session_init()
|
||||
self.test_dir = Path.home() / ".djarea-test"
|
||||
self.test_dir.mkdir(exist_ok=True)
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
if self.test_dir.exists():
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_list_files_home(self):
|
||||
data = self._call("list_files", {"directory": "~"})
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertEqual(data["data"]["directory"], str(Path.home()))
|
||||
self.assertIsInstance(data["data"]["entries"], list)
|
||||
|
||||
def test_list_files_root_has_no_parent(self):
|
||||
data = self._call("list_files", {"directory": "/"})
|
||||
|
||||
self.assertFalse(data["error"])
|
||||
self.assertIsNone(data["data"]["parent"])
|
||||
|
||||
def test_write_and_read_file(self):
|
||||
"""Full round-trip over real HTTP: write, read back, verify."""
|
||||
test_path = str(self.test_dir / "test-note.txt")
|
||||
test_content = "Hello from a REAL HTTP integration test!"
|
||||
|
||||
# Write
|
||||
write_data = self._call("write_file", {"path": test_path, "content": test_content})
|
||||
self.assertFalse(write_data["error"])
|
||||
self.assertEqual(write_data["data"]["path"], test_path)
|
||||
|
||||
# Read back
|
||||
read_data = self._call("read_file", {"path": test_path})
|
||||
self.assertFalse(read_data["error"])
|
||||
self.assertEqual(read_data["data"]["content"], test_content)
|
||||
|
||||
def test_write_outside_home_rejected(self):
|
||||
"""Server should reject writes outside home directory."""
|
||||
from urllib.error import HTTPError
|
||||
|
||||
try:
|
||||
data = self._call("write_file", {"path": "/tmp/escape.txt", "content": "nope"})
|
||||
# If we get here, check the response has an error
|
||||
self.assertTrue(data["error"])
|
||||
self.assertEqual(data["code"], "FORBIDDEN")
|
||||
except HTTPError as e:
|
||||
# 403 is also acceptable
|
||||
self.assertEqual(e.code, 403)
|
||||
|
||||
def test_delete_file(self):
|
||||
test_path = str(self.test_dir / "to-delete.txt")
|
||||
(self.test_dir / "to-delete.txt").write_text("delete me")
|
||||
|
||||
data = self._call("delete_file", {"path": test_path})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertTrue(data["data"]["deleted"])
|
||||
self.assertFalse(Path(test_path).exists())
|
||||
|
||||
def test_file_entries_have_metadata(self):
|
||||
(self.test_dir / "metadata-test.txt").write_text("hello")
|
||||
|
||||
data = self._call("list_files", {"directory": str(self.test_dir)})
|
||||
self.assertFalse(data["error"])
|
||||
self.assertGreater(len(data["data"]["entries"]), 0)
|
||||
|
||||
entry = data["data"]["entries"][0]
|
||||
self.assertIn("name", entry)
|
||||
self.assertIn("path", entry)
|
||||
self.assertIn("is_dir", entry)
|
||||
self.assertIn("size", entry)
|
||||
self.assertIn("modified", entry)
|
||||
4
django/.gitignore
vendored
Normal file
4
django/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
29
django/README.md
Normal file
29
django/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# djarea
|
||||
|
||||
Django + React server functions framework. See the [monorepo root](../README.md) for full documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# From git
|
||||
uv add "djarea[channels,allauth] @ git+https://git.impactsoundworks.com/isw/djarea.git#subdirectory=django"
|
||||
|
||||
# Local editable
|
||||
uv add -e "../../web/djarea/django[channels,allauth]"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from djarea.client import client
|
||||
from pydantic import BaseModel
|
||||
|
||||
class UserOutput(BaseModel):
|
||||
email: str
|
||||
|
||||
@client(context='global')
|
||||
def current_user(request) -> UserOutput | None:
|
||||
if not request.user.is_authenticated:
|
||||
return None
|
||||
return UserOutput(email=request.user.email)
|
||||
```
|
||||
42
django/pyproject.toml
Normal file
42
django/pyproject.toml
Normal file
@@ -0,0 +1,42 @@
|
||||
[project]
|
||||
name = "djarea"
|
||||
version = "1.0.1"
|
||||
description = "Django + React server functions framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"django>=5.0",
|
||||
"django-ninja>=1.0",
|
||||
"pydantic>=2.0",
|
||||
"PyJWT>=2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
channels = [
|
||||
"channels>=4.0",
|
||||
"channels-redis>=4.0",
|
||||
]
|
||||
allauth = [
|
||||
"django-allauth>=65.0",
|
||||
]
|
||||
webauthn = [
|
||||
"fido2>=2.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-django>=4.9",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/djarea"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
DJANGO_SETTINGS_MODULE = "tests.settings"
|
||||
pythonpath = ["src", "."]
|
||||
testpaths = ["src/djarea/tests"]
|
||||
python_classes = ["*Tests", "*Test", "Test*"]
|
||||
python_functions = ["test_*"]
|
||||
176
django/src/djarea/__init__.py
Normal file
176
django/src/djarea/__init__.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Djarea - Django + React unified framework
|
||||
|
||||
Server functions are the core primitive. Everything else builds on them.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. urls.py - HTTP endpoint
|
||||
```python
|
||||
from djarea import urls as djarea_urls
|
||||
|
||||
urlpatterns = [
|
||||
path('api/djarea/', include(djarea_urls)),
|
||||
]
|
||||
```
|
||||
|
||||
### 2. asgi.py - WebSocket support (optional)
|
||||
```python
|
||||
from djarea import wrap_asgi
|
||||
from django.core.asgi import get_asgi_application
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
```
|
||||
|
||||
### 3. Define server functions
|
||||
```python
|
||||
# apps/myapp/clients.py
|
||||
from djarea import client
|
||||
from pydantic import BaseModel
|
||||
|
||||
class EchoOutput(BaseModel):
|
||||
message: str
|
||||
|
||||
# HTTP-only function (default)
|
||||
@client
|
||||
def echo(request, text: str) -> EchoOutput:
|
||||
return EchoOutput(message=f"Echo: {text}")
|
||||
|
||||
# Global context (singleton, SSR-hydrated)
|
||||
@client(context='global')
|
||||
def current_user(request) -> UserOutput:
|
||||
return UserOutput(email=request.user.email)
|
||||
|
||||
# WebSocket-enabled for real-time
|
||||
@client(websocket=True)
|
||||
def send_message(request, room_id: int, text: str) -> MessageOutput:
|
||||
return MessageOutput(...)
|
||||
```
|
||||
|
||||
### 4. Auto-discover in apps.py
|
||||
```python
|
||||
class MyAppConfig(AppConfig):
|
||||
def ready(self):
|
||||
from djarea.setup import djarea_clients
|
||||
djarea_clients('apps')
|
||||
```
|
||||
|
||||
### 5. Frontend - generate types and use
|
||||
```bash
|
||||
npm run schemas
|
||||
```
|
||||
```tsx
|
||||
import { useEcho, useCurrentUser } from '@/api'
|
||||
|
||||
const user = useCurrentUser()
|
||||
const echo = useEcho()
|
||||
await echo({ text: 'hello' })
|
||||
```
|
||||
|
||||
## What You Get
|
||||
|
||||
| Backend | Frontend | Transport |
|
||||
|------------------------------------|-----------------------|------------|
|
||||
| `@client` | `useXxx()` hook | HTTP |
|
||||
| `@client(context='global')` | `useXxx()` + SSR | HTTP |
|
||||
| `@client(context='local')` | `<XxxProvider>` + hook| HTTP |
|
||||
| `@client(websocket=True)` | `useXxx()` hook | WebSocket |
|
||||
| `@compose(...)` | `<XxxProvider>` combined | varies |
|
||||
| `DjareaFormMixin` | `useXxxForm()` + Zod | HTTP |
|
||||
| `ReactChannel` | `useXxxChannel()` | WebSocket |
|
||||
"""
|
||||
|
||||
# All imports at module level (sorted)
|
||||
from . import channels
|
||||
from . import client as client_module
|
||||
from . import export
|
||||
from . import forms
|
||||
from . import setup
|
||||
from .channels import ReactChannel
|
||||
from .channels import register as register_channel
|
||||
from .client import ComposedContext, ServerFunction, client, compose
|
||||
from .setup import (
|
||||
djarea_clients,
|
||||
djarea_module,
|
||||
get_channel,
|
||||
get_function,
|
||||
register,
|
||||
register_as,
|
||||
)
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy loading for urls to avoid circular imports."""
|
||||
if name == "urls":
|
||||
from .urls import urlpatterns as djarea_patterns
|
||||
|
||||
return djarea_patterns
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def wrap_asgi(http_application):
|
||||
"""
|
||||
Wrap an ASGI application with Djarea WebSocket support.
|
||||
|
||||
Usage in asgi.py:
|
||||
from django.core.asgi import get_asgi_application
|
||||
from djarea import wrap_asgi
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
|
||||
This adds:
|
||||
- WebSocket routing at /ws/ for RPC and channels
|
||||
- Authentication middleware for WebSocket connections
|
||||
"""
|
||||
try:
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
from django.urls import path
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"django-channels is required for WebSocket support.\n"
|
||||
"Install with: pip install channels channels-redis\n"
|
||||
"Add 'channels' to INSTALLED_APPS and configure CHANNEL_LAYERS."
|
||||
)
|
||||
|
||||
from .channels.connection import DjangoReactConsumer
|
||||
|
||||
return ProtocolTypeRouter(
|
||||
{
|
||||
"http": http_application,
|
||||
"websocket": AuthMiddlewareStack(
|
||||
URLRouter(
|
||||
[
|
||||
path("ws/", DjangoReactConsumer.as_asgi()),
|
||||
]
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Decorators
|
||||
"client",
|
||||
"compose",
|
||||
"ServerFunction",
|
||||
"ComposedContext",
|
||||
# Setup
|
||||
"djarea_clients",
|
||||
"djarea_module",
|
||||
"register",
|
||||
"register_as",
|
||||
"get_function",
|
||||
"get_channel",
|
||||
# ASGI
|
||||
"wrap_asgi",
|
||||
# Channels
|
||||
"ReactChannel",
|
||||
"register_channel",
|
||||
# Submodules
|
||||
"client_module",
|
||||
"setup",
|
||||
"forms",
|
||||
"channels",
|
||||
"export",
|
||||
]
|
||||
0
django/src/djarea/_vendor/__init__.py
Normal file
0
django/src/djarea/_vendor/__init__.py
Normal file
91
django/src/djarea/_vendor/app_visitor.py
Normal file
91
django/src/djarea/_vendor/app_visitor.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import inspect
|
||||
from importlib import import_module
|
||||
from inspect import isclass
|
||||
from typing import Protocol, Any
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def get_members(path):
|
||||
try:
|
||||
module = import_module(path)
|
||||
except ModuleNotFoundError:
|
||||
print('Could not import module "{}"'.format(path))
|
||||
return []
|
||||
|
||||
members = [
|
||||
(name, member)
|
||||
for name, member in inspect.getmembers(module)
|
||||
if not isclass(member) or (member.__module__ == module.__name__)
|
||||
]
|
||||
|
||||
return members
|
||||
|
||||
|
||||
class DjangoAppVisitorHandler(Protocol):
|
||||
def on_module(
|
||||
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class DjangoAppVisitor:
|
||||
"""
|
||||
Discovers Python modules under each Django app following conventions:
|
||||
- <app>/<module>.py -> url_prefix "<renamed>/"
|
||||
- <app>/<module>/**/*.py -> url_prefix "<renamed>/<subdirs...>/<module>/"
|
||||
|
||||
Example:
|
||||
<app>/<module>/forms/nksn.py -> url_prefix "<renamed>/forms/nksn/"
|
||||
module_path "<app>.module.forms.nksn"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
layer: str,
|
||||
apps_root: str = "",
|
||||
):
|
||||
self.apps_root = apps_root
|
||||
self.layer = layer
|
||||
|
||||
def visit(self, handler: DjangoAppVisitorHandler) -> None:
|
||||
apps_dir = (
|
||||
settings.BASE_DIR / self.apps_root if self.apps_root else settings.BASE_DIR
|
||||
)
|
||||
if not apps_dir.is_dir():
|
||||
apps_dir = settings.BASE_DIR
|
||||
|
||||
module_prefix = f"{self.apps_root}." if self.apps_root else ""
|
||||
|
||||
for app_name in settings.INSTALLED_APPS:
|
||||
if app_name.startswith(self.apps_root + "."):
|
||||
app_name = app_name[(len(self.apps_root) + 1) :]
|
||||
|
||||
app_dir = apps_dir / app_name
|
||||
if not app_dir.exists():
|
||||
continue
|
||||
|
||||
app_module = f"{module_prefix}{app_name}"
|
||||
|
||||
# 1) Visit package: <app>/<module>/**/*.py
|
||||
layer_dir = app_dir / self.layer
|
||||
if layer_dir.is_dir():
|
||||
for py_file in layer_dir.rglob("*.py"):
|
||||
if py_file.name == "__init__.py":
|
||||
continue
|
||||
|
||||
relative_path = py_file.relative_to(layer_dir).with_suffix("")
|
||||
parts = list(relative_path.parts)
|
||||
|
||||
dotted = ".".join(parts)
|
||||
handler.on_module(
|
||||
app_name,
|
||||
parts,
|
||||
get_members(f"{app_module}.{self.layer}.{dotted}"),
|
||||
)
|
||||
|
||||
# 2) Visit module module file: <app>/module.py
|
||||
layer_file = app_dir / f"{self.layer}.py"
|
||||
if layer_file.is_file():
|
||||
handler.on_module(
|
||||
app_name, [], get_members(f"{app_module}.{self.layer}")
|
||||
)
|
||||
527
django/src/djarea/channels/__init__.py
Normal file
527
django/src/djarea/channels/__init__.py
Normal file
@@ -0,0 +1,527 @@
|
||||
"""
|
||||
djarea.channels - Real-time WebSocket communication.
|
||||
|
||||
Type-safe bidirectional messaging between Django and React via WebSockets.
|
||||
Hooks are auto-generated with full TypeScript types.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
# channels.py
|
||||
from pydantic import BaseModel
|
||||
from djarea import channels
|
||||
|
||||
class ChatChannel(channels.ReactChannel):
|
||||
|
||||
class Params(BaseModel):
|
||||
room: str
|
||||
|
||||
class ReactMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class DjangoMessage(BaseModel):
|
||||
user: str
|
||||
text: str
|
||||
timestamp: datetime
|
||||
|
||||
def authorize(self, params: Params) -> bool:
|
||||
return self.user.is_authenticated
|
||||
|
||||
def group(self, params: Params) -> str:
|
||||
return f'chat_{params.room}'
|
||||
|
||||
def receive(self, params: Params, msg: ReactMessage) -> DjangoMessage | None:
|
||||
return self.DjangoMessage(
|
||||
user=self.user.email,
|
||||
text=msg.text,
|
||||
timestamp=now(),
|
||||
)
|
||||
|
||||
channels.register(ChatChannel, 'chat')
|
||||
```
|
||||
|
||||
```python
|
||||
# asgi.py
|
||||
from djarea import channels
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": get_asgi_application(),
|
||||
"websocket": channels.get_websocket_application(),
|
||||
})
|
||||
```
|
||||
|
||||
## Frontend Usage (auto-generated)
|
||||
|
||||
```tsx
|
||||
import { useChatChannel } from '@/api/generated.channels'
|
||||
|
||||
function Chat({ room }) {
|
||||
const chat = useChatChannel({ room })
|
||||
|
||||
chat.status // 'connecting' | 'connected' | 'disconnected'
|
||||
chat.messages // DjangoMessage[]
|
||||
chat.send({ text: 'Hello' }) // ReactMessage
|
||||
}
|
||||
```
|
||||
|
||||
## Server Push
|
||||
|
||||
```python
|
||||
await ChatChannel.push(room='general', message=ChatChannel.DjangoMessage(...))
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
|
||||
from ninja import NinjaAPI
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base Classes
|
||||
# =============================================================================
|
||||
|
||||
class ReactChannel:
|
||||
"""
|
||||
Base class for WebSocket channels.
|
||||
|
||||
Define nested Pydantic classes for typed messaging:
|
||||
- Params: Query parameters for subscribing (optional)
|
||||
- ReactMessage: Messages from browser to server (optional)
|
||||
- DjangoMessage: Messages from server to browser (optional)
|
||||
|
||||
Implement required methods:
|
||||
- authorize(): Permission check for connection
|
||||
- group(): Which group to broadcast to
|
||||
|
||||
Optionally implement:
|
||||
- receive(): Handle incoming ReactMessage, return DjangoMessage to broadcast
|
||||
- on_connect(): Called after successful connection
|
||||
- on_disconnect(): Called when connection closes
|
||||
"""
|
||||
|
||||
# Nested classes (optional, defined by subclasses)
|
||||
Params: ClassVar[Type[BaseModel] | None] = None
|
||||
ReactMessage: ClassVar[Type[BaseModel] | None] = None
|
||||
DjangoMessage: ClassVar[Type[BaseModel] | None] = None
|
||||
|
||||
# Set by the framework when handling a connection
|
||||
user: "AbstractBaseUser | AnonymousUser"
|
||||
_channel_layer: Any = None
|
||||
_channel_name: str = ""
|
||||
_registered_name: ClassVar[str] = ""
|
||||
_params_dict: dict = {}
|
||||
_groups: set[str]
|
||||
|
||||
def __init__(self):
|
||||
self._groups = set()
|
||||
self._params_dict = {}
|
||||
|
||||
def authorize(self, params: BaseModel | None = None) -> bool:
|
||||
"""
|
||||
Permission check. Return True to allow connection, False to reject.
|
||||
|
||||
Override this to implement custom authorization logic.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must implement authorize()"
|
||||
)
|
||||
|
||||
def group(self, params: BaseModel | None = None) -> str:
|
||||
"""
|
||||
Return the group name for broadcasting.
|
||||
|
||||
Messages returned from receive() are broadcast to this group.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must implement group()"
|
||||
)
|
||||
|
||||
def receive(self, params: BaseModel | None, msg: BaseModel) -> BaseModel | None:
|
||||
"""
|
||||
Handle incoming ReactMessage.
|
||||
|
||||
Return a DjangoMessage to broadcast to the group, or None to skip.
|
||||
Override this to implement message handling.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def on_connect(self, params: BaseModel | None = None) -> None:
|
||||
"""Called after successful connection and group join."""
|
||||
pass
|
||||
|
||||
async def on_disconnect(self) -> None:
|
||||
"""Called when the connection closes."""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal Methods (used by the consumer)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _join_group(self, group_name: str) -> None:
|
||||
"""Join a channel layer group."""
|
||||
if self._channel_layer:
|
||||
await self._channel_layer.group_add(group_name, self._channel_name)
|
||||
self._groups.add(group_name)
|
||||
|
||||
async def _leave_group(self, group_name: str) -> None:
|
||||
"""Leave a channel layer group."""
|
||||
if self._channel_layer and group_name in self._groups:
|
||||
await self._channel_layer.group_discard(group_name, self._channel_name)
|
||||
self._groups.discard(group_name)
|
||||
|
||||
async def _leave_all_groups(self) -> None:
|
||||
"""Leave all joined groups."""
|
||||
for group_name in list(self._groups):
|
||||
await self._leave_group(group_name)
|
||||
|
||||
async def _broadcast(self, group_name: str, message: BaseModel) -> None:
|
||||
"""Broadcast a message to a group."""
|
||||
if self._channel_layer:
|
||||
await self._channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "channel.message",
|
||||
"channel": self._registered_name,
|
||||
"params": self._params_dict,
|
||||
"data": message.model_dump(mode='json'),
|
||||
"message_type": message.__class__.__name__,
|
||||
}
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Class Methods for Server Push
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def push(cls, message: BaseModel, **params) -> None:
|
||||
"""
|
||||
Push a message from server code (views, tasks, signals).
|
||||
|
||||
Usage:
|
||||
await ChatChannel.push(
|
||||
room='general',
|
||||
message=ChatChannel.DjangoMessage(user='system', text='Hello')
|
||||
)
|
||||
"""
|
||||
from channels.layers import get_channel_layer
|
||||
|
||||
channel_layer = get_channel_layer()
|
||||
if not channel_layer:
|
||||
logger.warning(f"No channel layer configured, cannot push to {cls.__name__}")
|
||||
return
|
||||
|
||||
# Build params model if defined
|
||||
params_obj = None
|
||||
if cls.Params:
|
||||
params_obj = cls.Params(**params)
|
||||
|
||||
# Get group name
|
||||
instance = cls()
|
||||
group_name = instance.group(params_obj)
|
||||
|
||||
# Send to group
|
||||
await channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "channel.message",
|
||||
"channel": cls._registered_name,
|
||||
"params": params,
|
||||
"data": message.model_dump(mode='json'),
|
||||
"message_type": message.__class__.__name__,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
|
||||
_registry: dict[str, Type[ReactChannel]] = {}
|
||||
|
||||
|
||||
def register(channel_class: Type[ReactChannel], name: str) -> None:
|
||||
"""
|
||||
Register a channel.
|
||||
|
||||
Args:
|
||||
channel_class: The ReactChannel subclass to register
|
||||
name: URL-friendly name (used in subscriptions)
|
||||
"""
|
||||
if name in _registry:
|
||||
raise ValueError(f"Channel '{name}' is already registered")
|
||||
|
||||
channel_class._registered_name = name
|
||||
|
||||
# Validate the channel class
|
||||
if not hasattr(channel_class, 'authorize'):
|
||||
raise ValueError(f"{channel_class.__name__} must implement authorize()")
|
||||
if not hasattr(channel_class, 'group'):
|
||||
raise ValueError(f"{channel_class.__name__} must implement group()")
|
||||
|
||||
_registry[name] = channel_class
|
||||
logger.debug(f"Registered channel: {name} -> {channel_class.__name__}")
|
||||
|
||||
|
||||
def get_channel(name: str) -> Type[ReactChannel] | None:
|
||||
"""Get a registered channel class by name."""
|
||||
return _registry.get(name)
|
||||
|
||||
|
||||
def get_registered_channels() -> dict[str, Type[ReactChannel]]:
|
||||
"""Get all registered channel classes."""
|
||||
return dict(_registry)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket Consumer
|
||||
# =============================================================================
|
||||
|
||||
def get_websocket_application():
|
||||
"""
|
||||
Get the WebSocket application for ASGI.
|
||||
|
||||
Usage in asgi.py:
|
||||
from djarea import channels
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": get_asgi_application(),
|
||||
"websocket": channels.get_websocket_application(),
|
||||
})
|
||||
"""
|
||||
try:
|
||||
from channels.routing import URLRouter
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
from django.urls import path
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"django-channels is required for WebSocket support. "
|
||||
"Install it with: pip install channels channels-redis"
|
||||
)
|
||||
|
||||
from .connection import DjangoReactConsumer
|
||||
|
||||
return AuthMiddlewareStack(
|
||||
URLRouter([
|
||||
path("ws/", DjangoReactConsumer.as_asgi()),
|
||||
])
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Export (for TypeScript generation)
|
||||
# =============================================================================
|
||||
|
||||
def get_channels_schema() -> dict:
|
||||
"""
|
||||
Get schema for all registered channels (for TypeScript generation).
|
||||
|
||||
Returns a dict suitable for the frontend code generator.
|
||||
"""
|
||||
schema = {
|
||||
"channels": {}
|
||||
}
|
||||
|
||||
for name, channel_class in _registry.items():
|
||||
channel_schema = {
|
||||
"name": name,
|
||||
"params": None,
|
||||
"reactMessage": None,
|
||||
"djangoMessage": None,
|
||||
}
|
||||
|
||||
# Extract Params schema
|
||||
if hasattr(channel_class, 'Params') and channel_class.Params:
|
||||
channel_schema["params"] = channel_class.Params.model_json_schema()
|
||||
|
||||
# Extract ReactMessage schema
|
||||
if hasattr(channel_class, 'ReactMessage') and channel_class.ReactMessage:
|
||||
channel_schema["reactMessage"] = channel_class.ReactMessage.model_json_schema()
|
||||
|
||||
# Extract DjangoMessage schema
|
||||
if hasattr(channel_class, 'DjangoMessage') and channel_class.DjangoMessage:
|
||||
channel_schema["djangoMessage"] = channel_class.DjangoMessage.model_json_schema()
|
||||
|
||||
schema["channels"][name] = channel_schema
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _register_channel_schema_endpoint(
|
||||
api: "NinjaAPI",
|
||||
path: str,
|
||||
operation_id: str,
|
||||
summary: str,
|
||||
input_cls: type | None,
|
||||
output_cls: type,
|
||||
) -> None:
|
||||
"""Register a dummy endpoint for schema generation (avoids closure issues)."""
|
||||
if input_cls is not None:
|
||||
def endpoint(request, data):
|
||||
pass
|
||||
endpoint.__annotations__ = {"data": input_cls}
|
||||
else:
|
||||
def endpoint(request):
|
||||
pass
|
||||
|
||||
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(endpoint)
|
||||
|
||||
|
||||
def get_channels_openapi_schema() -> dict:
|
||||
"""
|
||||
Get OpenAPI schema for all registered channels.
|
||||
|
||||
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
|
||||
This schema is consumed by openapi-typescript for type generation.
|
||||
"""
|
||||
from ninja import NinjaAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Create temporary Ninja API for schema generation only
|
||||
schema_api = NinjaAPI(
|
||||
title="Djarea Channels",
|
||||
version="1.0.0",
|
||||
description="Auto-generated schema for djarea channels",
|
||||
docs_url=None,
|
||||
openapi_url=None,
|
||||
)
|
||||
|
||||
# Store dynamically created classes
|
||||
schema_classes: dict[str, type] = {}
|
||||
channel_metadata: list[dict] = []
|
||||
|
||||
for name, channel_class in _registry.items():
|
||||
pascal_name = name.replace("_", " ").title().replace(" ", "")
|
||||
|
||||
channel_meta = {
|
||||
"name": name,
|
||||
"pascalName": pascal_name,
|
||||
"hasParams": False,
|
||||
"hasReactMessage": False,
|
||||
"hasDjangoMessage": False,
|
||||
}
|
||||
|
||||
# Register Params type
|
||||
if hasattr(channel_class, 'Params') and channel_class.Params:
|
||||
params_name = f"{pascal_name}Params"
|
||||
schema_classes[params_name] = type(params_name, (channel_class.Params,), {})
|
||||
channel_meta["hasParams"] = True
|
||||
channel_meta["paramsType"] = params_name
|
||||
|
||||
# Create dummy endpoint to include in schema
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/params",
|
||||
operation_id=f"{name}Params",
|
||||
summary=f"{pascal_name} channel params",
|
||||
input_cls=schema_classes[params_name],
|
||||
output_cls=BaseModel,
|
||||
)
|
||||
|
||||
# Register ReactMessage type
|
||||
if hasattr(channel_class, 'ReactMessage') and channel_class.ReactMessage:
|
||||
react_name = f"{pascal_name}ReactMessage"
|
||||
schema_classes[react_name] = type(react_name, (channel_class.ReactMessage,), {})
|
||||
channel_meta["hasReactMessage"] = True
|
||||
channel_meta["reactMessageType"] = react_name
|
||||
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/react",
|
||||
operation_id=f"{name}ReactMessage",
|
||||
summary=f"{pascal_name} React→Django message",
|
||||
input_cls=schema_classes[react_name],
|
||||
output_cls=BaseModel,
|
||||
)
|
||||
|
||||
# Register DjangoMessage type
|
||||
if hasattr(channel_class, 'DjangoMessage') and channel_class.DjangoMessage:
|
||||
django_name = f"{pascal_name}DjangoMessage"
|
||||
schema_classes[django_name] = type(django_name, (channel_class.DjangoMessage,), {})
|
||||
channel_meta["hasDjangoMessage"] = True
|
||||
channel_meta["djangoMessageType"] = django_name
|
||||
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/django",
|
||||
operation_id=f"{name}DjangoMessage",
|
||||
summary=f"{pascal_name} Django→React message",
|
||||
input_cls=None,
|
||||
output_cls=schema_classes[django_name],
|
||||
)
|
||||
|
||||
channel_metadata.append(channel_meta)
|
||||
|
||||
# Get OpenAPI schema from Ninja
|
||||
# path_prefix="" avoids URL reverse() — this API is never mounted
|
||||
schema = schema_api.get_openapi_schema(path_prefix="")
|
||||
|
||||
# Add channel metadata extension
|
||||
schema["x-djarea-channels"] = channel_metadata
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Endpoint (for TypeScript generation)
|
||||
# =============================================================================
|
||||
|
||||
_schema_router = None
|
||||
|
||||
|
||||
def _get_schema_router():
|
||||
"""Get the Ninja router for the channels schema endpoint."""
|
||||
global _schema_router
|
||||
if _schema_router is None:
|
||||
from ninja import Router
|
||||
|
||||
_schema_router = Router(tags=["channels"])
|
||||
|
||||
@_schema_router.get("/schema/")
|
||||
def channels_schema(request):
|
||||
"""Get schema for all registered channels (for TypeScript generation)."""
|
||||
return get_channels_schema()
|
||||
|
||||
return _schema_router
|
||||
|
||||
|
||||
def get_urls():
|
||||
"""Get URL patterns for channels schema endpoint."""
|
||||
from ninja import NinjaAPI
|
||||
|
||||
api = NinjaAPI(urls_namespace="django_react_channels")
|
||||
api.add_router("/", _get_schema_router())
|
||||
return api.urls
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "urls":
|
||||
return get_urls()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
|
||||
__all__ = [
|
||||
# URLs
|
||||
"urls",
|
||||
# Base class
|
||||
"ReactChannel",
|
||||
# Registration
|
||||
"register",
|
||||
"get_channel",
|
||||
"get_registered_channels",
|
||||
# ASGI application
|
||||
"get_websocket_application",
|
||||
# Schema export
|
||||
"get_channels_schema",
|
||||
]
|
||||
482
django/src/djarea/channels/connection.py
Normal file
482
django/src/djarea/channels/connection.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
WebSocket consumer for djarea.channels.
|
||||
|
||||
Handles multiplexed channel subscriptions AND RPC calls over a single WebSocket connection.
|
||||
|
||||
Protocol:
|
||||
Browser sends:
|
||||
# Channel subscriptions
|
||||
{"action": "subscribe", "channel": "chat", "params": {"room": "general"}}
|
||||
{"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}}
|
||||
{"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}}
|
||||
|
||||
# RPC calls (server functions)
|
||||
{"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
|
||||
|
||||
Server sends:
|
||||
# Channel messages
|
||||
{"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}}
|
||||
|
||||
# RPC responses
|
||||
{"id": "request-id", "ok": true, "data": {...}}
|
||||
{"id": "request-id", "ok": false, "error": {...}}
|
||||
|
||||
{"error": "..."}
|
||||
|
||||
Authentication:
|
||||
Supports both session (cookie) and JWT authentication:
|
||||
- Session: Handled automatically via AuthMiddlewareStack (cookies in handshake)
|
||||
- JWT: Pass token as query parameter: ws://...?token=<jwt>
|
||||
|
||||
The WebSocket URL for JWT auth would be: ws://localhost/ws/?token=<access_token>
|
||||
|
||||
Security:
|
||||
- Functions must be explicitly registered (no arbitrary code execution)
|
||||
- Pydantic validation runs BEFORE any function code
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from channels.generic.websocket import AsyncJsonWebsocketConsumer
|
||||
from asgiref.sync import sync_to_async
|
||||
from . import get_channel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketRequest:
|
||||
"""
|
||||
Minimal request adapter for WebSocket context.
|
||||
|
||||
Provides the interface expected by ServerFunction without full HttpRequest.
|
||||
This is intentionally minimal - only expose what's needed.
|
||||
|
||||
Note: Some Django libraries (e.g., allauth rate limiting) check request.method.
|
||||
We set method="POST" since WebSocket RPC calls are semantically similar to POST.
|
||||
"""
|
||||
|
||||
# WebSocket RPC is semantically similar to POST (sends data, expects response)
|
||||
method = "POST"
|
||||
|
||||
def __init__(self, scope: dict, channel_name: str = None):
|
||||
self.user = scope.get("user")
|
||||
self.session = scope.get("session", {})
|
||||
self.channel_name = channel_name # For push subscriptions
|
||||
self._scope = scope
|
||||
|
||||
@property
|
||||
def META(self) -> dict:
|
||||
"""HTTP headers from WebSocket handshake."""
|
||||
headers = dict(self._scope.get("headers", []))
|
||||
return {
|
||||
"HTTP_" + k.decode().upper().replace("-", "_"): v.decode()
|
||||
for k, v in headers.items()
|
||||
}
|
||||
|
||||
|
||||
class DjangoReactConsumer(AsyncJsonWebsocketConsumer):
|
||||
"""
|
||||
Multiplexed WebSocket consumer for django_react channels.
|
||||
|
||||
Manages multiple channel subscriptions over a single WebSocket connection.
|
||||
|
||||
Authentication:
|
||||
- Session auth via cookies (handled by AuthMiddlewareStack)
|
||||
- JWT auth via query parameter: ws://...?token=<jwt>
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Track subscriptions: {(channel_name, params_json): channel_instance}
|
||||
self._subscriptions: dict[tuple[str, str], Any] = {}
|
||||
|
||||
async def connect(self):
|
||||
"""Accept the WebSocket connection, authenticating via JWT if provided."""
|
||||
# Check for JWT token in query parameters
|
||||
await self._try_jwt_auth()
|
||||
|
||||
await self.accept()
|
||||
logger.debug(f"WebSocket connected: {self.channel_name}, user={self.scope.get('user')}")
|
||||
|
||||
async def _try_jwt_auth(self):
|
||||
"""
|
||||
Attempt JWT authentication from query parameter.
|
||||
|
||||
If a valid JWT token is provided via ?token=<jwt>, authenticate the user
|
||||
using JWTUser (no database query).
|
||||
|
||||
Security: If JWT is provided but invalid, we log it but don't reject
|
||||
the connection - the session auth may still be valid. However, if JWT
|
||||
IS valid, it takes precedence over session auth.
|
||||
"""
|
||||
# Parse query string for token
|
||||
query_string = self.scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
token_list = params.get("token", [])
|
||||
|
||||
if not token_list:
|
||||
return # No JWT provided, use session auth
|
||||
|
||||
token = token_list[0]
|
||||
if not token:
|
||||
return
|
||||
|
||||
# Validate JWT and create JWTUser (no DB query)
|
||||
try:
|
||||
from djarea.client.jwt import decode_token
|
||||
from djarea.jwt.tokens import JWTUser
|
||||
|
||||
payload = await sync_to_async(decode_token)(token, expected_type="access")
|
||||
if payload is None:
|
||||
logger.debug("JWT token invalid or expired")
|
||||
return # Fall back to session auth
|
||||
|
||||
# Create JWTUser from token claims - NO DATABASE QUERY
|
||||
self.scope["user"] = JWTUser(payload)
|
||||
logger.debug(f"JWT auth successful for user {payload.user_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"JWT auth failed: {e}")
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
"""Clean up all subscriptions on disconnect."""
|
||||
for key, instance in list(self._subscriptions.items()):
|
||||
try:
|
||||
await instance.on_disconnect()
|
||||
await instance._leave_all_groups()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect cleanup: {e}")
|
||||
|
||||
self._subscriptions.clear()
|
||||
logger.debug(f"WebSocket disconnected: {self.channel_name}")
|
||||
|
||||
async def receive_json(self, content: dict):
|
||||
"""Handle incoming JSON messages."""
|
||||
action = content.get("action")
|
||||
|
||||
if action == "subscribe":
|
||||
await self._handle_subscribe(content)
|
||||
elif action == "unsubscribe":
|
||||
await self._handle_unsubscribe(content)
|
||||
elif action == "message":
|
||||
await self._handle_message(content)
|
||||
elif action == "rpc":
|
||||
await self._handle_rpc(content)
|
||||
else:
|
||||
await self.send_json({
|
||||
"error": f"Unknown action: {action}",
|
||||
})
|
||||
|
||||
async def _handle_subscribe(self, content: dict):
|
||||
"""Handle subscription request."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
|
||||
# Get channel class
|
||||
channel_class = get_channel(channel_name)
|
||||
if not channel_class:
|
||||
await self.send_json({
|
||||
"error": f"Unknown channel: {channel_name}",
|
||||
})
|
||||
return
|
||||
|
||||
# Create subscription key
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
# Check if already subscribed
|
||||
if sub_key in self._subscriptions:
|
||||
await self.send_json({
|
||||
"error": f"Already subscribed to {channel_name}",
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
})
|
||||
return
|
||||
|
||||
# Create channel instance
|
||||
instance = channel_class()
|
||||
instance.user = self.scope.get("user")
|
||||
instance._channel_layer = self.channel_layer
|
||||
instance._channel_name = self.channel_name
|
||||
instance._registered_name = channel_name
|
||||
instance._params_dict = params_dict
|
||||
|
||||
# Parse params
|
||||
params_obj = None
|
||||
if channel_class.Params:
|
||||
try:
|
||||
params_obj = channel_class.Params(**params_dict)
|
||||
except Exception as e:
|
||||
await self.send_json({
|
||||
"error": f"Invalid params: {e}",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
# Check authorization
|
||||
try:
|
||||
if params_obj:
|
||||
authorized = instance.authorize(params_obj)
|
||||
else:
|
||||
authorized = instance.authorize()
|
||||
except Exception as e:
|
||||
logger.error(f"Authorization error for {channel_name}: {e}")
|
||||
await self.send_json({
|
||||
"error": "Authorization failed",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
if not authorized:
|
||||
await self.send_json({
|
||||
"error": "Not authorized",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
# Get group and join
|
||||
try:
|
||||
if params_obj:
|
||||
group_name = instance.group(params_obj)
|
||||
else:
|
||||
group_name = instance.group()
|
||||
await instance._join_group(group_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to join group for {channel_name}: {e}")
|
||||
await self.send_json({
|
||||
"error": f"Failed to subscribe: {e}",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
# Store subscription
|
||||
self._subscriptions[sub_key] = instance
|
||||
|
||||
# Call on_connect hook
|
||||
try:
|
||||
await instance.on_connect(params_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"on_connect error for {channel_name}: {e}")
|
||||
|
||||
# Confirm subscription
|
||||
await self.send_json({
|
||||
"subscribed": True,
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
})
|
||||
|
||||
logger.debug(f"Subscribed to {channel_name} with params {params_dict}")
|
||||
|
||||
async def _handle_unsubscribe(self, content: dict):
|
||||
"""Handle unsubscription request."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
instance = self._subscriptions.pop(sub_key, None)
|
||||
if instance:
|
||||
try:
|
||||
await instance.on_disconnect()
|
||||
await instance._leave_all_groups()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during unsubscribe: {e}")
|
||||
|
||||
await self.send_json({
|
||||
"unsubscribed": True,
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
})
|
||||
|
||||
logger.debug(f"Unsubscribed from {channel_name}")
|
||||
|
||||
async def _handle_message(self, content: dict):
|
||||
"""Handle incoming message from browser."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
data = content.get("data", {})
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
instance = self._subscriptions.get(sub_key)
|
||||
if not instance:
|
||||
await self.send_json({
|
||||
"error": f"Not subscribed to {channel_name}",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
channel_class = instance.__class__
|
||||
|
||||
# Check if channel accepts messages
|
||||
if not channel_class.ReactMessage:
|
||||
await self.send_json({
|
||||
"error": f"Channel {channel_name} does not accept messages",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
msg = channel_class.ReactMessage(**data)
|
||||
except Exception as e:
|
||||
await self.send_json({
|
||||
"error": f"Invalid message: {e}",
|
||||
"channel": channel_name,
|
||||
})
|
||||
return
|
||||
|
||||
# Parse params
|
||||
params_obj = None
|
||||
if channel_class.Params:
|
||||
params_obj = channel_class.Params(**params_dict)
|
||||
|
||||
# Handle message
|
||||
try:
|
||||
response = instance.receive(params_obj, msg)
|
||||
|
||||
# If handler returned a message, broadcast it
|
||||
if response is not None:
|
||||
if params_obj:
|
||||
group_name = instance.group(params_obj)
|
||||
else:
|
||||
group_name = instance.group()
|
||||
|
||||
await instance._broadcast(group_name, response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message for {channel_name}: {e}")
|
||||
await self.send_json({
|
||||
"error": f"Message handling failed: {e}",
|
||||
"channel": channel_name,
|
||||
})
|
||||
|
||||
async def _handle_rpc(self, content: dict):
|
||||
"""
|
||||
Handle RPC (server function) call.
|
||||
|
||||
Protocol:
|
||||
Request: {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
|
||||
Response: {"id": "request-id", "ok": true, "data": {...}}
|
||||
or: {"id": "request-id", "ok": false, "error": {...}}
|
||||
|
||||
Security:
|
||||
- Only functions with @client(websocket=True) are allowed
|
||||
- Pydantic validation happens BEFORE any function code runs
|
||||
- Function must be explicitly registered (no arbitrary code execution)
|
||||
- User context from WebSocket session is passed to function
|
||||
"""
|
||||
from djarea.client.executor import execute_function, FunctionError
|
||||
from djarea.setup.registry import get_function
|
||||
|
||||
request_id = content.get("id")
|
||||
fn_name = content.get("fn")
|
||||
args = content.get("args", {})
|
||||
|
||||
# Validate request structure
|
||||
if not request_id:
|
||||
await self.send_json({
|
||||
"error": "RPC request missing 'id' field",
|
||||
})
|
||||
return
|
||||
|
||||
if not fn_name:
|
||||
await self.send_json({
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "BAD_REQUEST",
|
||||
"message": "Missing 'fn' field",
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
# Check if function exists and has websocket=True
|
||||
fn_class = get_function(fn_name)
|
||||
if fn_class is None:
|
||||
await self.send_json({
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "NOT_FOUND",
|
||||
"message": f"Function '{fn_name}' not found",
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
# Only allow functions explicitly marked with websocket=True
|
||||
fn_meta = getattr(fn_class, "_meta", {})
|
||||
if not fn_meta.get("websocket"):
|
||||
await self.send_json({
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "FORBIDDEN",
|
||||
"message": "This function is HTTP-only. Use POST /api/djarea/call/ instead.",
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
# Create request adapter from WebSocket scope
|
||||
ws_request = WebSocketRequest(self.scope, channel_name=getattr(self, 'channel_name', None))
|
||||
|
||||
# Execute function (Pydantic validation happens inside execute_function)
|
||||
# This is sync, so we need to run it in a thread pool
|
||||
result = await sync_to_async(execute_function, thread_sensitive=True)(
|
||||
ws_request,
|
||||
fn_name,
|
||||
args,
|
||||
)
|
||||
|
||||
# Send response
|
||||
if isinstance(result, FunctionError):
|
||||
await self.send_json({
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": result.code.value,
|
||||
"message": result.message,
|
||||
**({"details": result.details} if result.details else {}),
|
||||
},
|
||||
})
|
||||
else:
|
||||
await self.send_json({
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"data": result.data,
|
||||
})
|
||||
|
||||
async def channel_message(self, event: dict):
|
||||
"""
|
||||
Handle messages broadcast to a group.
|
||||
|
||||
Called when channel_layer.group_send() is used.
|
||||
Includes channel name and params so the client can route the message.
|
||||
"""
|
||||
await self.send_json({
|
||||
"channel": event.get("channel"),
|
||||
"params": event.get("params", {}),
|
||||
"type": event.get("message_type", "message"),
|
||||
"data": event.get("data", {}),
|
||||
})
|
||||
|
||||
async def push_message(self, event: dict):
|
||||
"""
|
||||
Handle push messages from server functions.
|
||||
|
||||
Called when push("topic", data) is used from a server function.
|
||||
The client receives this to update its local state.
|
||||
|
||||
Protocol:
|
||||
Server sends: {"type": "push", "topic": "room:42", "data": {...}}
|
||||
"""
|
||||
await self.send_json({
|
||||
"type": "push",
|
||||
"topic": event.get("topic"),
|
||||
"data": event.get("data", {}),
|
||||
})
|
||||
150
django/src/djarea/channels/push.py
Normal file
150
django/src/djarea/channels/push.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Djarea Push - Server-initiated messages to clients.
|
||||
|
||||
Simple API for pushing data to subscribed WebSocket connections.
|
||||
|
||||
Usage:
|
||||
# In a server function - push to all subscribers
|
||||
from djarea.push import push
|
||||
|
||||
push("room:42", {"type": "new_message", "data": {...}})
|
||||
|
||||
# Subscribe a connection to a topic (call during context fetch)
|
||||
from djarea.push import subscribe
|
||||
|
||||
subscribe(request, "room:42")
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Lazy import to avoid import errors when channels is not installed
|
||||
# (e.g., during schema generation)
|
||||
if TYPE_CHECKING:
|
||||
from channels.layers import BaseChannelLayer
|
||||
|
||||
|
||||
def _get_channel_layer() -> "BaseChannelLayer | None":
|
||||
"""Get channel layer, returning None if channels is not installed."""
|
||||
try:
|
||||
from channels.layers import get_channel_layer
|
||||
return get_channel_layer()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _async_to_sync(coro):
|
||||
"""Wrapper for async_to_sync that handles missing channels."""
|
||||
from asgiref.sync import async_to_sync
|
||||
return async_to_sync(coro)
|
||||
|
||||
|
||||
def get_topic_group_name(topic: str) -> str:
|
||||
"""Convert a topic string to a valid channel layer group name."""
|
||||
# Channel layer group names must be valid ASCII alphanumeric + hyphens/underscores/periods
|
||||
# Replace colons with underscores
|
||||
return topic.replace(":", "_")
|
||||
|
||||
|
||||
def subscribe(request, topic: str) -> None:
|
||||
"""
|
||||
Subscribe this WebSocket connection to a topic.
|
||||
|
||||
Call this in a context or server function to register the connection
|
||||
for push notifications on the given topic.
|
||||
|
||||
Args:
|
||||
request: The Django request (must have channel_name attribute from WebSocket)
|
||||
topic: Topic string, e.g., "room:42", "user:123:notifications"
|
||||
"""
|
||||
channel_name = getattr(request, "channel_name", None)
|
||||
if not channel_name:
|
||||
# HTTP request, not WebSocket - can't subscribe
|
||||
return
|
||||
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
_async_to_sync(channel_layer.group_add)(group_name, channel_name)
|
||||
|
||||
|
||||
def unsubscribe(request, topic: str) -> None:
|
||||
"""
|
||||
Unsubscribe this WebSocket connection from a topic.
|
||||
|
||||
Args:
|
||||
request: The Django request (must have channel_name attribute from WebSocket)
|
||||
topic: Topic string to unsubscribe from
|
||||
"""
|
||||
channel_name = getattr(request, "channel_name", None)
|
||||
if not channel_name:
|
||||
return
|
||||
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
_async_to_sync(channel_layer.group_discard)(group_name, channel_name)
|
||||
|
||||
|
||||
def push(topic: str, data: dict | BaseModel) -> None:
|
||||
"""
|
||||
Push data to all connections subscribed to a topic.
|
||||
|
||||
Args:
|
||||
topic: Topic string, e.g., "room:42"
|
||||
data: Data to send (dict or Pydantic model)
|
||||
|
||||
Example:
|
||||
push("room:42", {
|
||||
"type": "new_message",
|
||||
"message": {"id": 1, "text": "Hello", "user": "alice@example.com"}
|
||||
})
|
||||
"""
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"No channel layer configured, cannot push to topic '%s'", topic
|
||||
)
|
||||
return
|
||||
|
||||
# Convert Pydantic model to dict if needed
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
|
||||
_async_to_sync(channel_layer.group_send)(
|
||||
group_name,
|
||||
{
|
||||
"type": "push.message", # Maps to push_message handler in consumer
|
||||
"topic": topic,
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def push_async(topic: str, data: dict | BaseModel) -> None:
|
||||
"""Async version of push for use in async contexts."""
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
|
||||
await channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "push.message",
|
||||
"topic": topic,
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
60
django/src/djarea/client/__init__.py
Normal file
60
django/src/djarea/client/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
djarea.client - Server function implementation.
|
||||
|
||||
This subpackage contains everything needed to make server functions work:
|
||||
- The @client decorator
|
||||
- ServerFunction base class
|
||||
- Function execution logic
|
||||
- JWT authentication (integral to server functions)
|
||||
|
||||
Usage:
|
||||
from djarea.client import client, ServerFunction, compose
|
||||
"""
|
||||
|
||||
from .function import (
|
||||
# Decorator
|
||||
client,
|
||||
# Base classes
|
||||
ServerFunction,
|
||||
ComposedContext,
|
||||
# Composition
|
||||
compose,
|
||||
# Type aliases
|
||||
ContextMode,
|
||||
# Form helpers
|
||||
FormValidationOutput,
|
||||
FormSchemaField,
|
||||
FormSchemaOutput,
|
||||
create_form_functions,
|
||||
)
|
||||
|
||||
from .executor import (
|
||||
execute_function,
|
||||
function_call_view,
|
||||
ErrorCode,
|
||||
FunctionError,
|
||||
FunctionResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Decorator
|
||||
"client",
|
||||
# Base classes
|
||||
"ServerFunction",
|
||||
"ComposedContext",
|
||||
# Composition
|
||||
"compose",
|
||||
# Type aliases
|
||||
"ContextMode",
|
||||
# Execution
|
||||
"execute_function",
|
||||
"function_call_view",
|
||||
"ErrorCode",
|
||||
"FunctionError",
|
||||
"FunctionResult",
|
||||
# Form helpers
|
||||
"FormValidationOutput",
|
||||
"FormSchemaField",
|
||||
"FormSchemaOutput",
|
||||
"create_form_functions",
|
||||
]
|
||||
478
django/src/djarea/client/executor.py
Normal file
478
django/src/djarea/client/executor.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Djarea Function Executor
|
||||
|
||||
Handles execution of server functions.
|
||||
This is the core of the "Server Functions" feature - callable from React
|
||||
without REST boilerplate.
|
||||
|
||||
Security model:
|
||||
- All input validated against Pydantic schema BEFORE execution
|
||||
- Authentication: JWT (stateless) or Session (stateful) - auto-detected
|
||||
- JWT: Authorization header with Bearer token (no CSRF needed)
|
||||
- Session: Cookie-based with CSRF token (via X-CSRFToken header)
|
||||
- WebSocket RPC uses Origin header checking instead
|
||||
- No implicit function exposure - must be explicitly registered
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from django.http import HttpRequest, JsonResponse
|
||||
from django.views.decorators.csrf import csrf_protect
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from djarea.setup.registry import get_function
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Standard error codes for function execution."""
|
||||
|
||||
# Client errors (4xx)
|
||||
NOT_FOUND = "NOT_FOUND" # Function not registered
|
||||
VALIDATION_ERROR = "VALIDATION_ERROR" # Input failed Pydantic validation
|
||||
UNAUTHORIZED = "UNAUTHORIZED" # User not authenticated (when required)
|
||||
FORBIDDEN = "FORBIDDEN" # User lacks permission
|
||||
BAD_REQUEST = "BAD_REQUEST" # Malformed request
|
||||
|
||||
# Server errors (5xx)
|
||||
INTERNAL_ERROR = "INTERNAL_ERROR" # Unhandled exception
|
||||
NOT_IMPLEMENTED = "NOT_IMPLEMENTED" # Function exists but not implemented
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionError:
|
||||
"""Structured error response from function execution."""
|
||||
|
||||
code: ErrorCode
|
||||
message: str
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-serializable dict."""
|
||||
result = {
|
||||
"error": True,
|
||||
"code": self.code.value,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.details:
|
||||
result["details"] = self.details
|
||||
return result
|
||||
|
||||
def to_response(self, status: int = 400) -> JsonResponse:
|
||||
"""Convert to Django JsonResponse."""
|
||||
return JsonResponse(self.to_dict(), status=status)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionResult:
|
||||
"""Successful result from function execution."""
|
||||
|
||||
data: dict[str, Any]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-serializable dict."""
|
||||
return {
|
||||
"error": False,
|
||||
"data": self.data,
|
||||
}
|
||||
|
||||
def to_response(self) -> JsonResponse:
|
||||
"""Convert to Django JsonResponse."""
|
||||
return JsonResponse(self.to_dict())
|
||||
|
||||
|
||||
def _check_auth_requirement(
|
||||
request: HttpRequest,
|
||||
auth_requirement: str | Callable | None,
|
||||
) -> FunctionError | None:
|
||||
"""
|
||||
Check if the request meets the auth requirement.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest (with user set)
|
||||
auth_requirement: 'required', 'staff', 'superuser', callable, or None
|
||||
|
||||
Returns:
|
||||
FunctionError if auth check fails, None if it passes.
|
||||
|
||||
Note: This uses request.user which may be a JWTUser (stateless) or
|
||||
Django User (from session). Either way, no additional DB query is made
|
||||
for the built-in checks. Custom callables may query DB if they choose.
|
||||
"""
|
||||
if auth_requirement is None:
|
||||
return None
|
||||
|
||||
user = request.user
|
||||
|
||||
# Handle callable auth
|
||||
if callable(auth_requirement):
|
||||
try:
|
||||
result = auth_requirement(request)
|
||||
if result:
|
||||
return None # Authorized
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Access denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
# Custom error message from the callable
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message=str(e) or "Access denied",
|
||||
)
|
||||
|
||||
# Check authentication (required for all string-based auth)
|
||||
if not getattr(user, 'is_authenticated', False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Authentication required",
|
||||
)
|
||||
|
||||
# Check staff requirement
|
||||
if auth_requirement == 'staff':
|
||||
if not getattr(user, 'is_staff', False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Staff access required",
|
||||
)
|
||||
|
||||
# Check superuser requirement
|
||||
elif auth_requirement == 'superuser':
|
||||
if not getattr(user, 'is_superuser', False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Superuser access required",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def execute_function(
|
||||
request: HttpRequest,
|
||||
fn_name: str,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> FunctionResult | FunctionError:
|
||||
"""
|
||||
Execute a registered server function.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest
|
||||
fn_name: Name of the registered function
|
||||
input_data: Input data to pass to the function
|
||||
|
||||
Returns:
|
||||
FunctionResult on success, FunctionError on failure
|
||||
"""
|
||||
from django.conf import settings
|
||||
|
||||
# Look up the function by name
|
||||
view_class = get_function(fn_name)
|
||||
if view_class is None:
|
||||
# In DEBUG mode, include the name for easier debugging
|
||||
if settings.DEBUG:
|
||||
message = f"Function '{fn_name}' not found"
|
||||
else:
|
||||
message = "Function not found"
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message=message,
|
||||
)
|
||||
|
||||
# Check auth requirement BEFORE executing
|
||||
meta = getattr(view_class, "_meta", {})
|
||||
auth_requirement = meta.get("auth")
|
||||
auth_error = _check_auth_requirement(request, auth_requirement)
|
||||
if auth_error is not None:
|
||||
return auth_error
|
||||
|
||||
# Instantiate the view with the request
|
||||
view = view_class(request)
|
||||
|
||||
# Check if this is a form function that handles input specially
|
||||
meta = getattr(view_class, "_meta", {})
|
||||
is_form_multipart = meta.get("multipart", False)
|
||||
|
||||
# For form functions with Input=None, skip Pydantic validation
|
||||
# The form itself handles validation
|
||||
input_cls = view.Input
|
||||
if input_cls is None and is_form_multipart:
|
||||
# Form function - pass input_data directly (already parsed by view or will be)
|
||||
validated_input = input_data
|
||||
elif input_cls is BaseModel:
|
||||
has_input = False
|
||||
validated_input = None
|
||||
else:
|
||||
# Check if it has any fields defined
|
||||
has_input = bool(input_cls.model_fields) if input_cls else False
|
||||
|
||||
# Validate input against Pydantic schema
|
||||
try:
|
||||
if input_data:
|
||||
# Ensure input_data is a dict (not array or other type)
|
||||
if not isinstance(input_data, dict):
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Input must be an object, not " + type(input_data).__name__,
|
||||
)
|
||||
validated_input = input_cls(**input_data)
|
||||
elif has_input:
|
||||
# Check if function requires input fields
|
||||
input_schema = input_cls.model_json_schema()
|
||||
required_fields = input_schema.get("required", [])
|
||||
if required_fields:
|
||||
# Format as field errors for consistency
|
||||
errors = {field: ["Field required"] for field in required_fields}
|
||||
return FunctionError(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message="Input validation failed",
|
||||
details={"fields": errors},
|
||||
)
|
||||
validated_input = input_cls()
|
||||
else:
|
||||
# No input expected, create empty model
|
||||
validated_input = None
|
||||
except ValidationError as e:
|
||||
# Convert Pydantic errors to our format
|
||||
errors = {}
|
||||
for error in e.errors():
|
||||
field = ".".join(str(loc) for loc in error["loc"])
|
||||
if field not in errors:
|
||||
errors[field] = []
|
||||
errors[field].append(error["msg"])
|
||||
|
||||
return FunctionError(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message="Input validation failed",
|
||||
details={"fields": errors},
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
try:
|
||||
output = view.call(validated_input)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"Function {fn_name} not implemented: {e}")
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_IMPLEMENTED,
|
||||
message=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# Functions can raise PermissionError for auth issues
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
# Log the full exception for debugging
|
||||
logger.exception(f"Error executing function {fn_name}")
|
||||
return FunctionError(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="An internal error occurred",
|
||||
# Don't expose internal details in production
|
||||
details={"type": type(e).__name__} if logger.isEnabledFor(logging.DEBUG) else None,
|
||||
)
|
||||
|
||||
# Serialize output (handle None for Optional return types)
|
||||
if output is None:
|
||||
return FunctionResult(data=None)
|
||||
return FunctionResult(data=output.model_dump())
|
||||
|
||||
|
||||
def _try_jwt_auth(request: HttpRequest) -> bool:
|
||||
"""
|
||||
Attempt to authenticate the request using JWT.
|
||||
|
||||
If Authorization header contains a valid Bearer token, authenticates
|
||||
the request and sets request.user to a JWTUser. Returns True if JWT
|
||||
auth succeeded.
|
||||
|
||||
IMPORTANT: This is stateless - no database query is made. The JWTUser
|
||||
object is created from the token claims. If you need the full User
|
||||
object, query it explicitly in your function.
|
||||
|
||||
Security: If JWT is provided but invalid, we return False and do NOT
|
||||
fall back to session auth. The caller should reject the request.
|
||||
"""
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return False
|
||||
|
||||
token = auth_header[7:] # Strip "Bearer "
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
from djarea.client.jwt import decode_token
|
||||
from djarea.jwt.tokens import JWTUser
|
||||
|
||||
payload = decode_token(token, expected_type="access")
|
||||
if payload is None:
|
||||
return False
|
||||
|
||||
# Create JWTUser from token claims - NO DATABASE QUERY
|
||||
request.user = JWTUser(payload)
|
||||
request._djarea_jwt_authenticated = True
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _has_jwt_header(request: HttpRequest) -> bool:
|
||||
"""Check if request has a JWT Authorization header."""
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
return auth_header.startswith("Bearer ")
|
||||
|
||||
|
||||
def _csrf_protect_unless_jwt(view_func):
|
||||
"""
|
||||
Decorator that applies CSRF protection unless JWT auth is used.
|
||||
|
||||
JWT tokens are self-authenticating (the token itself proves the request
|
||||
is legitimate), so CSRF protection is not needed.
|
||||
|
||||
Security: If JWT is provided but invalid, reject the request - do NOT
|
||||
fall back to session auth. This prevents attacks where an invalid token
|
||||
is sent alongside a valid session cookie.
|
||||
"""
|
||||
csrf_protected_view = csrf_protect(view_func)
|
||||
|
||||
@wraps(view_func)
|
||||
def wrapper(request: HttpRequest, *args, **kwargs):
|
||||
# Check if JWT header is present
|
||||
has_jwt = _has_jwt_header(request)
|
||||
|
||||
if has_jwt:
|
||||
# JWT header present - try to authenticate
|
||||
if _try_jwt_auth(request):
|
||||
# JWT valid - skip CSRF, proceed
|
||||
return view_func(request, *args, **kwargs)
|
||||
else:
|
||||
# JWT invalid - reject (do NOT fall back to session)
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or expired JWT token",
|
||||
).to_response(status=401)
|
||||
else:
|
||||
# No JWT - use session auth with CSRF
|
||||
return csrf_protected_view(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_csrf_protect_unless_jwt
|
||||
def function_call_view(request: HttpRequest) -> JsonResponse:
|
||||
"""
|
||||
Django view for handling function calls (HTTP fallback for WebSocket RPC).
|
||||
|
||||
Authentication (auto-detected):
|
||||
- JWT: Authorization: Bearer <token> (stateless, no CSRF needed)
|
||||
- Session: Cookie-based with X-CSRFToken header (CSRF required)
|
||||
|
||||
Endpoint: POST /api/djarea/call/
|
||||
|
||||
Request body (JSON):
|
||||
{
|
||||
"fn": "function_name", // Function name
|
||||
"args": { ... } // Optional, depending on function
|
||||
}
|
||||
|
||||
Request body (multipart/form-data for form submit functions):
|
||||
fn: function_name
|
||||
<field>: <value>
|
||||
...
|
||||
|
||||
Response on success:
|
||||
{
|
||||
"error": false,
|
||||
"data": { ... } // Function output
|
||||
}
|
||||
|
||||
Response on error:
|
||||
{
|
||||
"error": true,
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Input validation failed",
|
||||
"details": { ... }
|
||||
}
|
||||
"""
|
||||
# Only allow POST
|
||||
if request.method != "POST":
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Only POST method allowed",
|
||||
).to_response(status=405)
|
||||
|
||||
# Check content type to determine parsing method
|
||||
content_type = request.content_type or ""
|
||||
is_multipart = content_type.startswith("multipart/form-data")
|
||||
|
||||
if is_multipart:
|
||||
# Multipart form data - used by form submit functions
|
||||
fn_name = request.POST.get("fn")
|
||||
if not fn_name:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Missing 'fn' field",
|
||||
).to_response()
|
||||
|
||||
# Get form data (excluding 'fn')
|
||||
input_data = {k: v for k, v in request.POST.dict().items() if k != "fn"}
|
||||
|
||||
# Attach parsed form data and files to request for form functions
|
||||
request._djarea_form_data = input_data
|
||||
request._djarea_form_files = request.FILES
|
||||
|
||||
else:
|
||||
# JSON body - standard RPC
|
||||
try:
|
||||
if request.body:
|
||||
body = json.loads(request.body)
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Request body required",
|
||||
).to_response()
|
||||
except json.JSONDecodeError:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Invalid JSON in request body",
|
||||
).to_response()
|
||||
|
||||
# Extract function name and args
|
||||
fn_name = body.get("fn")
|
||||
if not fn_name:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Missing 'fn' field",
|
||||
).to_response()
|
||||
|
||||
input_data = body.get("args")
|
||||
|
||||
# Execute the function
|
||||
result = execute_function(request, fn_name, input_data)
|
||||
|
||||
# Return appropriate response
|
||||
if isinstance(result, FunctionError):
|
||||
status = {
|
||||
ErrorCode.NOT_FOUND: 404,
|
||||
ErrorCode.VALIDATION_ERROR: 422,
|
||||
ErrorCode.UNAUTHORIZED: 401,
|
||||
ErrorCode.FORBIDDEN: 403,
|
||||
ErrorCode.BAD_REQUEST: 400,
|
||||
ErrorCode.INTERNAL_ERROR: 500,
|
||||
ErrorCode.NOT_IMPLEMENTED: 501,
|
||||
}.get(result.code, 400)
|
||||
return result.to_response(status=status)
|
||||
|
||||
return result.to_response()
|
||||
694
django/src/djarea/client/function.py
Normal file
694
django/src/djarea/client/function.py
Normal file
@@ -0,0 +1,694 @@
|
||||
"""
|
||||
Djarea Server Functions - Core Primitive
|
||||
|
||||
Server functions are the core primitive. Everything else builds on them.
|
||||
|
||||
Two styles supported:
|
||||
|
||||
1. Function-based (recommended, Django Ninja style):
|
||||
@client("update-profile")
|
||||
def update_profile(request, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
return UpdateProfileOutput(success=True)
|
||||
|
||||
2. Class-based (for complex cases):
|
||||
class UpdateProfile(ServerFunction):
|
||||
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
return UpdateProfileOutput(success=True)
|
||||
register(UpdateProfile, 'update-profile')
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, ClassVar, Generic, Literal, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Valid context modes: 'global', 'local', or False (not a context)
|
||||
ContextMode = Literal['global', 'local', False]
|
||||
|
||||
|
||||
TInput = TypeVar("TInput", bound=BaseModel)
|
||||
TOutput = TypeVar("TOutput", bound=BaseModel)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SERVER FUNCTION - The Core Primitive
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ServerFunction(ABC, Generic[TInput, TOutput]):
|
||||
"""
|
||||
Class-based server function (for complex cases).
|
||||
|
||||
For simple functions, use the @client decorator instead.
|
||||
|
||||
Usage:
|
||||
class UpdateProfile(ServerFunction):
|
||||
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
self.user.name = input.name
|
||||
self.user.save()
|
||||
return UpdateProfileOutput(success=True)
|
||||
|
||||
register(UpdateProfile, 'update-profile')
|
||||
"""
|
||||
|
||||
# Registration name (set by register())
|
||||
name: ClassVar[str]
|
||||
|
||||
# Metadata for code generation
|
||||
_meta: ClassVar[dict[str, Any]] = {}
|
||||
|
||||
# Schema classes (set automatically from type hints or explicitly)
|
||||
Input: ClassVar[type[BaseModel]] = BaseModel
|
||||
Output: ClassVar[type[BaseModel]] = BaseModel
|
||||
|
||||
def __init__(self, request: HttpRequest):
|
||||
"""Initialize with the Django request."""
|
||||
self.request = request
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
"""Shortcut to request.user."""
|
||||
return self.request.user
|
||||
|
||||
@abstractmethod
|
||||
def call(self, input: TInput) -> TOutput:
|
||||
"""
|
||||
Execute the function.
|
||||
|
||||
Args:
|
||||
input: Validated input data
|
||||
|
||||
Returns:
|
||||
Output instance
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} must implement call()")
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
export = {
|
||||
"name": getattr(cls, "name", cls.__name__),
|
||||
"type": "function",
|
||||
"meta": getattr(cls, "_meta", {}),
|
||||
}
|
||||
|
||||
# Get Input/Output from class attributes
|
||||
input_cls = getattr(cls, "Input", BaseModel)
|
||||
output_cls = getattr(cls, "Output", BaseModel)
|
||||
|
||||
# Check if Input has fields
|
||||
input_schema = input_cls.model_json_schema()
|
||||
has_input = bool(input_schema.get("properties"))
|
||||
if has_input:
|
||||
export["input"] = input_schema
|
||||
export["has_input"] = has_input
|
||||
|
||||
export["output"] = output_cls.model_json_schema()
|
||||
|
||||
return export
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FUNCTION DECORATOR - Django Ninja Style
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class _FunctionWrapper(ServerFunction):
|
||||
"""Internal wrapper that makes a plain function behave like a ServerFunction."""
|
||||
|
||||
# Will be set per-wrapper instance
|
||||
_wrapped_fn: ClassVar[Callable]
|
||||
_input_cls: ClassVar[type[BaseModel] | None]
|
||||
_output_cls: ClassVar[type[BaseModel]]
|
||||
_param_names: ClassVar[list[str]] = []
|
||||
_is_primitive_output: ClassVar[bool] = False
|
||||
|
||||
def call(self, input):
|
||||
"""Execute the wrapped function, unpacking input into individual args."""
|
||||
if input is not None and self._param_names:
|
||||
# Unpack validated model into keyword arguments
|
||||
kwargs = {name: getattr(input, name) for name in self._param_names}
|
||||
result = self._wrapped_fn(self.request, **kwargs)
|
||||
else:
|
||||
result = self._wrapped_fn(self.request)
|
||||
|
||||
# Wrap primitive returns in the generated output model
|
||||
if self._is_primitive_output:
|
||||
return self._output_cls(result=result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
export = {
|
||||
"name": getattr(cls, "name", cls.__name__),
|
||||
"type": "function",
|
||||
"meta": getattr(cls, "_meta", {}),
|
||||
}
|
||||
|
||||
# Use stored schema classes
|
||||
if cls._input_cls is not None:
|
||||
input_schema = cls._input_cls.model_json_schema()
|
||||
has_input = bool(input_schema.get("properties"))
|
||||
if has_input:
|
||||
export["input"] = input_schema
|
||||
export["has_input"] = has_input
|
||||
else:
|
||||
export["has_input"] = False
|
||||
|
||||
export["output"] = cls._output_cls.model_json_schema()
|
||||
|
||||
return export
|
||||
|
||||
|
||||
# Valid string values for auth parameter
|
||||
_VALID_AUTH_STRINGS = frozenset({'required', 'staff', 'superuser'})
|
||||
|
||||
|
||||
def client(
|
||||
fn: Callable = None,
|
||||
*,
|
||||
context: ContextMode = False,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | Callable[[Any], bool] | None = None,
|
||||
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
|
||||
"""
|
||||
Register a function as a server function.
|
||||
|
||||
Type annotations define the schema - just like Django Ninja/FastAPI.
|
||||
Function parameters become input fields automatically.
|
||||
|
||||
Args:
|
||||
context: Context mode for React state management.
|
||||
- False (default): Not a context, just a callable function
|
||||
- 'global': Embedded in root DjangoContext, no params, singleton
|
||||
- 'local': Standalone provider, supports params via flat props
|
||||
|
||||
websocket: Enable WebSocket RPC transport (default: False).
|
||||
By default, functions use HTTP-only transport. Enable this for
|
||||
real-time features (chat, gaming, live updates) that benefit
|
||||
from lower latency.
|
||||
|
||||
Note: Forms (DjareaFormMixin) always use HTTP because auth
|
||||
flows require full HTTP request semantics.
|
||||
|
||||
auth: Authentication requirement.
|
||||
- None (default): No auth required (AnonymousUser allowed)
|
||||
- True or 'required': Must be authenticated
|
||||
- 'staff': Must have is_staff=True
|
||||
- 'superuser': Must have is_superuser=True
|
||||
- callable(request) -> bool: Custom check function
|
||||
|
||||
Usage:
|
||||
# Basic HTTP-only function (not a context)
|
||||
@client
|
||||
def echo(request, message: str) -> EchoOutput:
|
||||
return EchoOutput(message=message)
|
||||
|
||||
# Global context - embedded in DjangoContext, no params
|
||||
@client(context='global')
|
||||
def current_user(request) -> UserOutput:
|
||||
return UserOutput(email=request.user.email)
|
||||
|
||||
# Local context - standalone provider, supports params
|
||||
@client(context='local')
|
||||
def user_profile(request, user_id: int) -> ProfileOutput:
|
||||
return ProfileOutput(...)
|
||||
|
||||
# WebSocket-enabled for real-time
|
||||
@client(websocket=True)
|
||||
def send_message(request, room_id: int, text: str) -> MessageOutput:
|
||||
return MessageOutput(...)
|
||||
|
||||
# Local context with WebSocket (live data)
|
||||
@client(context='local', websocket=True)
|
||||
def live_user_status(request, user_id: int) -> StatusOutput:
|
||||
return StatusOutput(...)
|
||||
|
||||
Returns:
|
||||
A ServerFunction class that wraps the function
|
||||
"""
|
||||
# Validate context parameter
|
||||
if context not in (False, 'global', 'local'):
|
||||
raise ValueError(
|
||||
f"Invalid context value '{context}'. "
|
||||
f"Must be False, 'global', or 'local'."
|
||||
)
|
||||
|
||||
# Validate auth parameter
|
||||
if auth is not None:
|
||||
if isinstance(auth, str) and auth not in _VALID_AUTH_STRINGS:
|
||||
raise ValueError(
|
||||
f"Invalid auth value '{auth}'. "
|
||||
f"Must be one of: {', '.join(sorted(_VALID_AUTH_STRINGS))}, True, or a callable."
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> type[ServerFunction]:
|
||||
return _create_server_function(fn, context=context, websocket=websocket, auth=auth)
|
||||
|
||||
# Support both @client and @client(...)
|
||||
if fn is not None:
|
||||
return _create_server_function(fn, context=context, websocket=websocket, auth=auth)
|
||||
return decorator
|
||||
|
||||
|
||||
def _create_server_function(
|
||||
fn: Callable,
|
||||
*,
|
||||
context: ContextMode = False,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | None = None,
|
||||
) -> type[ServerFunction]:
|
||||
"""Internal helper that creates a ServerFunction from a decorated function."""
|
||||
from pydantic import create_model
|
||||
|
||||
# Use function name directly
|
||||
name = fn.__name__
|
||||
|
||||
# Extract type hints and signature
|
||||
hints = get_type_hints(fn)
|
||||
sig = inspect.signature(fn)
|
||||
params = list(sig.parameters.items())
|
||||
|
||||
# Skip 'request' parameter (first param)
|
||||
input_params = params[1:] if params else []
|
||||
|
||||
# Build input schema from function parameters
|
||||
if input_params:
|
||||
# Build field definitions for create_model
|
||||
# Format: {field_name: (type, default) or (type, ...)}
|
||||
fields = {}
|
||||
for param_name, param in input_params:
|
||||
param_type = hints.get(param_name, Any)
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
# Required field
|
||||
fields[param_name] = (param_type, ...)
|
||||
else:
|
||||
# Optional field with default
|
||||
fields[param_name] = (param_type, param.default)
|
||||
|
||||
# Create dynamic Pydantic model
|
||||
input_cls = create_model(f"{fn.__name__}_Input", **fields)
|
||||
else:
|
||||
input_cls = None
|
||||
|
||||
# Get output type from return annotation
|
||||
output_type = hints.get("return")
|
||||
if output_type is None:
|
||||
raise TypeError(
|
||||
f"Server function '{name}' must have a return type annotation"
|
||||
)
|
||||
|
||||
# Support primitive return types by wrapping in a model with 'result' field
|
||||
# Also handle Optional[X] / X | None by extracting the non-None type
|
||||
import types
|
||||
|
||||
def is_basemodel_type(t: Any) -> bool:
|
||||
"""Check if type is a BaseModel subclass, handling Optional/Union."""
|
||||
if isinstance(t, type) and issubclass(t, BaseModel):
|
||||
return True
|
||||
# Handle Union types: typing.Union (Optional[X]) and types.UnionType (X | None)
|
||||
origin = get_origin(t)
|
||||
if origin is Union or isinstance(t, types.UnionType):
|
||||
args = get_args(t)
|
||||
# Check if any non-None arg is a BaseModel
|
||||
for arg in args:
|
||||
if arg is not type(None) and isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
return False
|
||||
|
||||
if is_basemodel_type(output_type):
|
||||
output_cls = output_type
|
||||
is_primitive_output = False
|
||||
else:
|
||||
# Create model wrapper for primitive types (int, str, list, etc.)
|
||||
output_cls = create_model(f"{fn.__name__}_Output", result=(output_type, ...))
|
||||
is_primitive_output = True
|
||||
|
||||
# Store param names for unpacking validated input
|
||||
param_names = [p[0] for p in input_params]
|
||||
|
||||
# Create a unique wrapper class for this function
|
||||
class FunctionWrapper(_FunctionWrapper):
|
||||
_param_names: ClassVar[list[str]] = param_names
|
||||
|
||||
FunctionWrapper.__name__ = fn.__name__
|
||||
FunctionWrapper.__doc__ = fn.__doc__
|
||||
FunctionWrapper.__module__ = fn.__module__ # Critical for discovery
|
||||
FunctionWrapper._wrapped_fn = staticmethod(fn)
|
||||
FunctionWrapper._input_cls = input_cls
|
||||
FunctionWrapper._output_cls = output_cls
|
||||
FunctionWrapper._is_primitive_output = is_primitive_output
|
||||
|
||||
# Set Input/Output class attributes for compatibility
|
||||
if input_cls is not None:
|
||||
FunctionWrapper.Input = input_cls
|
||||
FunctionWrapper.Output = output_cls
|
||||
|
||||
# Build metadata
|
||||
meta = {}
|
||||
|
||||
# Context mode: 'global' or 'local' (False means not a context)
|
||||
if context:
|
||||
meta["context"] = context
|
||||
|
||||
# WebSocket: enable WebSocket transport
|
||||
if websocket:
|
||||
meta["websocket"] = True
|
||||
|
||||
# Auth requirement
|
||||
if auth is not None:
|
||||
if auth is True:
|
||||
meta["auth"] = 'required'
|
||||
elif callable(auth):
|
||||
meta["auth"] = auth
|
||||
else:
|
||||
meta["auth"] = auth
|
||||
|
||||
if meta:
|
||||
FunctionWrapper._meta = {**FunctionWrapper._meta, **meta}
|
||||
|
||||
# Note: Registration happens via discovery (djarea_clients), not here.
|
||||
# This allows the decorator to be used without import-time side effects.
|
||||
|
||||
return FunctionWrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# COMPOSE - Combine multiple contexts into a single provider
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ComposedContext:
|
||||
"""
|
||||
Marker class for composed contexts.
|
||||
|
||||
Stores metadata about the composition for schema export.
|
||||
"""
|
||||
|
||||
name: str
|
||||
_meta: dict[str, Any]
|
||||
_children: list[type[ServerFunction] | "ComposedContext"]
|
||||
_leaves: list[type[ServerFunction]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
children: list,
|
||||
leaves: list,
|
||||
on_server: bool,
|
||||
websocket: bool,
|
||||
):
|
||||
self.name = name
|
||||
self._children = children
|
||||
self._leaves = leaves
|
||||
self._meta = {
|
||||
"compose": True,
|
||||
"on_server": on_server,
|
||||
"websocket": websocket,
|
||||
"children": [c.name for c in children],
|
||||
"leaves": [leaf.name for leaf in leaves],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
return {
|
||||
"name": cls.name,
|
||||
"type": "compose",
|
||||
"meta": cls._meta,
|
||||
"children": cls._meta.get("children", []),
|
||||
"leaves": cls._meta.get("leaves", []),
|
||||
}
|
||||
|
||||
|
||||
def _get_leaves(item) -> list[type[ServerFunction]]:
|
||||
"""Recursively collect all leaf contexts from a context or composition."""
|
||||
if isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
return [item]
|
||||
elif isinstance(item, ComposedContext):
|
||||
return item._leaves.copy()
|
||||
elif hasattr(item, '_leaves'):
|
||||
# Duck typing for composed contexts
|
||||
return item._leaves.copy()
|
||||
else:
|
||||
raise TypeError(f"Expected ServerFunction or ComposedContext, got {type(item)}")
|
||||
|
||||
|
||||
def _is_context_enabled(item) -> bool:
|
||||
"""Check if an item is a context-enabled function or composition."""
|
||||
if isinstance(item, ComposedContext) or hasattr(item, '_leaves'):
|
||||
return True
|
||||
if isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
meta = getattr(item, '_meta', {})
|
||||
return meta.get('context') in ('global', 'local')
|
||||
return False
|
||||
|
||||
|
||||
def compose(
|
||||
*children,
|
||||
on_server: bool = False,
|
||||
websocket: bool = False,
|
||||
):
|
||||
"""
|
||||
Compose multiple contexts into a single provider.
|
||||
|
||||
Args:
|
||||
*children: Context functions (@client with context='global'|'local')
|
||||
or other @compose functions. All must be unique after flattening.
|
||||
|
||||
on_server: Bundle all calls into a single server request (default: False).
|
||||
- False: Frontend makes individual calls (mixed HTTP/WS OK)
|
||||
- True: Single bundled call. Requires transport consistency:
|
||||
all children must be HTTP-only XOR all must be websocket=True.
|
||||
|
||||
websocket: Transport for bundled call when on_server=True (default: False).
|
||||
- False: Bundled call over HTTP. All children must be HTTP-only.
|
||||
- True: Bundled call over WebSocket. All children must have websocket=True.
|
||||
|
||||
Usage:
|
||||
@client(context='local')
|
||||
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
||||
|
||||
@client(context='local')
|
||||
def user_posts(request, user_id: int) -> PostsOutput: ...
|
||||
|
||||
@compose(user_profile, user_posts)
|
||||
def user_page():
|
||||
pass
|
||||
|
||||
# Frontend generates:
|
||||
# <UserPageProvider user_id={123}>
|
||||
# <App />
|
||||
# </UserPageProvider>
|
||||
|
||||
Nesting:
|
||||
@compose(ctx_a, ctx_b)
|
||||
def ab(): pass
|
||||
|
||||
@compose(ab, ctx_c) # Flattens to [ctx_a, ctx_b, ctx_c]
|
||||
def abc(): pass
|
||||
|
||||
Returns:
|
||||
A ComposedContext that can be used in other compositions.
|
||||
"""
|
||||
def decorator(fn: Callable) -> ComposedContext:
|
||||
from djarea.setup.registry import register_compose
|
||||
|
||||
name = fn.__name__
|
||||
|
||||
# Validate: all children must be context-enabled
|
||||
for i, child in enumerate(children):
|
||||
if not _is_context_enabled(child):
|
||||
child_name = getattr(child, 'name', getattr(child, '__name__', str(child)))
|
||||
raise ValueError(
|
||||
f"@compose argument {i} ({child_name}) is not context-enabled. "
|
||||
f"All children must have @client(context='global'|'local') or be @compose."
|
||||
)
|
||||
|
||||
# Flatten to collect all leaves
|
||||
leaves = []
|
||||
for child in children:
|
||||
leaves.extend(_get_leaves(child))
|
||||
|
||||
# Validate: no duplicate leaves (by identity)
|
||||
seen = set()
|
||||
for leaf in leaves:
|
||||
if id(leaf) in seen:
|
||||
raise ValueError(
|
||||
f"Duplicate context '{leaf.name}' in @compose({name}). "
|
||||
f"Each context can only appear once. Use named kwargs for reuse (future feature)."
|
||||
)
|
||||
seen.add(id(leaf))
|
||||
|
||||
# Validate transport consistency when on_server=True
|
||||
if on_server:
|
||||
has_websocket = [getattr(leaf, '_meta', {}).get('websocket', False) for leaf in leaves]
|
||||
|
||||
if websocket:
|
||||
# All must have websocket=True
|
||||
if not all(has_websocket):
|
||||
non_ws = [leaf.name for leaf, ws in zip(leaves, has_websocket) if not ws]
|
||||
raise ValueError(
|
||||
f"@compose({name}, on_server=True, websocket=True) requires all children "
|
||||
f"to have websocket=True. These are HTTP-only: {non_ws}"
|
||||
)
|
||||
else:
|
||||
# All must be HTTP-only
|
||||
if any(has_websocket):
|
||||
ws_enabled = [leaf.name for leaf, ws in zip(leaves, has_websocket) if ws]
|
||||
raise ValueError(
|
||||
f"@compose({name}, on_server=True, websocket=False) requires all children "
|
||||
f"to be HTTP-only. These have websocket=True: {ws_enabled}"
|
||||
)
|
||||
|
||||
# Create composed context
|
||||
composed = ComposedContext(
|
||||
name=name,
|
||||
children=list(children),
|
||||
leaves=leaves,
|
||||
on_server=on_server,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
# Make it a class-like object for consistency
|
||||
composed.__name__ = name
|
||||
composed.__doc__ = fn.__doc__
|
||||
|
||||
# Register the composition
|
||||
register_compose(composed, name)
|
||||
|
||||
return composed
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FORM HELPERS - Output types used by form server functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FormValidationOutput(BaseModel):
|
||||
"""Standard output for form validation."""
|
||||
|
||||
valid: bool
|
||||
errors: dict[str, list[str]]
|
||||
|
||||
|
||||
class FormSchemaField(BaseModel):
|
||||
"""Schema for a single form field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
required: bool
|
||||
label: str
|
||||
help_text: str | None = None
|
||||
choices: list[tuple[str, str]] | None = None
|
||||
initial: Any = None
|
||||
|
||||
|
||||
class FormSchemaOutput(BaseModel):
|
||||
"""Standard output for form schema."""
|
||||
|
||||
fields: list[FormSchemaField]
|
||||
|
||||
|
||||
def create_form_functions(
|
||||
form_class: type,
|
||||
name: str,
|
||||
submit_handler: Callable[[HttpRequest, dict], BaseModel] | None = None,
|
||||
) -> tuple[type[ServerFunction], type[ServerFunction], type[ServerFunction] | None]:
|
||||
"""
|
||||
Generate server functions for a Django Form.
|
||||
|
||||
Args:
|
||||
form_class: Django Form class
|
||||
name: Base name for the functions
|
||||
submit_handler: Optional handler for form submission
|
||||
|
||||
Returns:
|
||||
Tuple of (SchemaFunction, ValidateFunction, SubmitFunction or None)
|
||||
|
||||
Usage:
|
||||
SchemaFn, ValidateFn, SubmitFn = create_form_functions(
|
||||
ContactForm,
|
||||
'contact',
|
||||
submit_handler=lambda req, data: ContactSubmitOutput(success=True),
|
||||
)
|
||||
register(SchemaFn, 'contact-schema')
|
||||
register(ValidateFn, 'contact-validate')
|
||||
register(SubmitFn, 'contact-submit')
|
||||
|
||||
Or use the helper:
|
||||
register_form(ContactForm, 'contact', submit_handler=...)
|
||||
"""
|
||||
from djarea.forms.schema_utils import build_form_schema
|
||||
|
||||
# Schema function - returns field definitions
|
||||
class FormSchema(ServerFunction):
|
||||
class Output(FormSchemaOutput):
|
||||
pass
|
||||
|
||||
def call(self, input):
|
||||
schema = build_form_schema(form_class)
|
||||
fields = [
|
||||
FormSchemaField(
|
||||
name=field.name,
|
||||
type=field.type,
|
||||
required=field.required,
|
||||
label=field.label or field.name,
|
||||
help_text=field.help_text or None,
|
||||
choices=[(c.value, c.label) for c in field.choices] if field.choices else None,
|
||||
initial=field.initial,
|
||||
)
|
||||
for field in schema.fields
|
||||
]
|
||||
return self.Output(fields=fields)
|
||||
|
||||
FormSchema.__name__ = f"{name.title().replace('-', '')}Schema"
|
||||
FormSchema._meta = {"form": True, "form_name": name, "form_role": "schema"}
|
||||
|
||||
# Validation function
|
||||
class FormDataInput(BaseModel):
|
||||
data: dict[str, Any]
|
||||
|
||||
class FormValidate(ServerFunction):
|
||||
Input = FormDataInput
|
||||
|
||||
class Output(FormValidationOutput):
|
||||
pass
|
||||
|
||||
def call(self, input):
|
||||
form = form_class(data=input.data)
|
||||
if form.is_valid():
|
||||
return self.Output(valid=True, errors={})
|
||||
return self.Output(valid=False, errors=dict(form.errors))
|
||||
|
||||
FormValidate.__name__ = f"{name.title().replace('-', '')}Validate"
|
||||
FormValidate._meta = {"form": True, "form_name": name, "form_role": "validate"}
|
||||
|
||||
# Submit function (optional)
|
||||
FormSubmit = None
|
||||
if submit_handler:
|
||||
|
||||
class FormSubmit(ServerFunction):
|
||||
Input = FormDataInput
|
||||
|
||||
def call(self, input):
|
||||
# Validate first
|
||||
form = form_class(data=input.data)
|
||||
if not form.is_valid():
|
||||
raise ValueError("Form validation failed")
|
||||
# Call handler
|
||||
return submit_handler(self.request, form.cleaned_data)
|
||||
|
||||
FormSubmit.__name__ = f"{name.title().replace('-', '')}Submit"
|
||||
FormSubmit._meta = {"form": True, "form_name": name, "form_role": "submit"}
|
||||
|
||||
return FormSchema, FormValidate, FormSubmit
|
||||
44
django/src/djarea/client/jwt.py
Normal file
44
django/src/djarea/client/jwt.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
djarea.client.jwt - JWT authentication for server functions.
|
||||
|
||||
Provides:
|
||||
- Server functions for obtaining/refreshing JWT tokens
|
||||
- JWT authentication utilities for validating tokens
|
||||
|
||||
Server Functions:
|
||||
- jwt_obtain: Convert authenticated session to JWT tokens
|
||||
- jwt_refresh: Refresh tokens using a refresh token
|
||||
|
||||
Note: This module is purpose-built for Djarea server functions.
|
||||
For Django Ninja API authentication, use djarea.jwt.security directly.
|
||||
"""
|
||||
|
||||
# Token utilities (re-exports from django_jwt_session)
|
||||
from djarea.jwt.tokens import (
|
||||
create_token_pair,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
refresh_tokens,
|
||||
TokenPair,
|
||||
TokenPayload,
|
||||
JWTUser,
|
||||
)
|
||||
|
||||
# Settings
|
||||
from djarea.jwt.settings import get_settings, JWTSettings
|
||||
|
||||
__all__ = [
|
||||
# Token utilities
|
||||
"create_token_pair",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_token",
|
||||
"refresh_tokens",
|
||||
"TokenPair",
|
||||
"TokenPayload",
|
||||
"JWTUser",
|
||||
# Settings
|
||||
"get_settings",
|
||||
"JWTSettings",
|
||||
]
|
||||
299
django/src/djarea/export/__init__.py
Normal file
299
django/src/djarea/export/__init__.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Djarea OpenAPI Schema Generator
|
||||
|
||||
Generates OpenAPI 3.0 compatible schema from registered server functions.
|
||||
Uses Django Ninja's battle-tested schema generation for robust Pydantic→OpenAPI conversion.
|
||||
|
||||
This schema is consumed by the frontend generator which uses openapi-typescript
|
||||
for robust type generation.
|
||||
|
||||
NOTE: Schema export is only available via management command for security.
|
||||
HTTP endpoint has been removed to prevent function enumeration.
|
||||
|
||||
Usage:
|
||||
python manage.py export_djarea_schema
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Lazy imports to avoid Django settings access at module load time
|
||||
# (asgi.py imports djarea before Django is fully configured)
|
||||
if TYPE_CHECKING:
|
||||
from django import forms
|
||||
from ninja import NinjaAPI
|
||||
|
||||
from djarea.setup.registry import get_registry, get_schema
|
||||
|
||||
|
||||
__all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"]
|
||||
|
||||
|
||||
def _extract_form_fields(form_class: type) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract field definitions with constraints from a Django Form class.
|
||||
|
||||
Returns a list of field metadata suitable for Zod schema generation:
|
||||
- name: field name
|
||||
- zodType: base Zod type ("string", "number", "boolean", "array")
|
||||
- required: whether field is required
|
||||
- constraints: dict of Zod-compatible constraints
|
||||
|
||||
Constraints include:
|
||||
- min/max: for string length or number range
|
||||
- email/url: for format validation
|
||||
- regex: for pattern validation
|
||||
- choices: for enum validation
|
||||
"""
|
||||
try:
|
||||
# Try to instantiate form to get bound fields
|
||||
form = form_class()
|
||||
fields_dict = form.fields
|
||||
except TypeError:
|
||||
# Form requires extra args - use base_fields
|
||||
fields_dict = getattr(form_class, "base_fields", {})
|
||||
|
||||
result = []
|
||||
|
||||
for name, field in fields_dict.items():
|
||||
field_meta = _extract_field_constraints(name, field)
|
||||
result.append(field_meta)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_field_constraints(name: str, field: "forms.Field") -> dict[str, Any]:
|
||||
"""
|
||||
Extract Zod-compatible constraints from a single Django form field.
|
||||
"""
|
||||
from django import forms # Lazy import
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
"name": name,
|
||||
"required": field.required,
|
||||
"constraints": {},
|
||||
}
|
||||
|
||||
# Determine base Zod type
|
||||
if isinstance(field, forms.BooleanField):
|
||||
meta["zodType"] = "boolean"
|
||||
elif isinstance(field, (forms.IntegerField, forms.FloatField, forms.DecimalField)):
|
||||
meta["zodType"] = "number"
|
||||
if isinstance(field, forms.IntegerField):
|
||||
meta["constraints"]["int"] = True
|
||||
elif isinstance(field, forms.MultipleChoiceField):
|
||||
meta["zodType"] = "array"
|
||||
meta["constraints"]["items"] = "string"
|
||||
elif isinstance(field, forms.FileField):
|
||||
meta["zodType"] = "file"
|
||||
else:
|
||||
# Default to string (CharField, EmailField, URLField, etc.)
|
||||
meta["zodType"] = "string"
|
||||
|
||||
# Extract string constraints
|
||||
if hasattr(field, "max_length") and field.max_length is not None:
|
||||
meta["constraints"]["max"] = field.max_length
|
||||
if hasattr(field, "min_length") and field.min_length is not None:
|
||||
meta["constraints"]["min"] = field.min_length
|
||||
|
||||
# Extract number constraints
|
||||
if hasattr(field, "max_value") and field.max_value is not None:
|
||||
meta["constraints"]["max"] = field.max_value
|
||||
if hasattr(field, "min_value") and field.min_value is not None:
|
||||
meta["constraints"]["min"] = field.min_value
|
||||
|
||||
# Email/URL format
|
||||
if isinstance(field, forms.EmailField):
|
||||
meta["constraints"]["email"] = True
|
||||
elif isinstance(field, forms.URLField):
|
||||
meta["constraints"]["url"] = True
|
||||
|
||||
# Choices (for enum validation)
|
||||
if hasattr(field, "choices") and field.choices:
|
||||
# Extract choice values (not labels)
|
||||
choices = []
|
||||
for choice in field.choices:
|
||||
if isinstance(choice, (list, tuple)) and len(choice) >= 1:
|
||||
# Skip empty/blank choices
|
||||
if choice[0] != "":
|
||||
choices.append(str(choice[0]))
|
||||
else:
|
||||
choices.append(str(choice))
|
||||
if choices:
|
||||
meta["constraints"]["choices"] = choices
|
||||
|
||||
# Regex validators
|
||||
for validator in field.validators:
|
||||
if hasattr(validator, "regex"):
|
||||
# RegexValidator - extract pattern
|
||||
pattern = validator.regex.pattern
|
||||
meta["constraints"]["regex"] = pattern
|
||||
if hasattr(validator, "message"):
|
||||
meta["constraints"]["regexMessage"] = validator.message
|
||||
break # Only use first regex validator
|
||||
|
||||
return meta
|
||||
|
||||
|
||||
def snake_to_camel(name: str) -> str:
|
||||
"""Convert snake_case or dotted.name to camelCase.
|
||||
|
||||
Examples:
|
||||
- login -> login
|
||||
- login.schema -> loginSchema
|
||||
- activate_totp -> activateTotp
|
||||
- activate_totp.schema -> activateTotpSchema
|
||||
"""
|
||||
# Split on both underscores and dots
|
||||
components = re.split(r"[._]", name)
|
||||
return components[0] + "".join(x.title() for x in components[1:])
|
||||
|
||||
|
||||
def _register_schema_endpoint(
|
||||
api: "NinjaAPI",
|
||||
path: str,
|
||||
operation_id: str,
|
||||
summary: str,
|
||||
input_cls: type | None,
|
||||
output_cls: type,
|
||||
) -> None:
|
||||
"""
|
||||
Register a dummy endpoint on the API for schema generation.
|
||||
|
||||
Sets __annotations__ directly to avoid closure capture issues
|
||||
and exec() security concerns.
|
||||
"""
|
||||
if input_cls is not None:
|
||||
def endpoint(request, data):
|
||||
pass
|
||||
# Set annotations directly to the actual type objects (not strings)
|
||||
endpoint.__annotations__ = {"data": input_cls}
|
||||
else:
|
||||
def endpoint(request):
|
||||
pass
|
||||
|
||||
# Register with Ninja
|
||||
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(endpoint)
|
||||
|
||||
|
||||
def generate_openapi_schema() -> dict[str, Any]:
|
||||
"""
|
||||
Generate OpenAPI 3.0 schema for all registered djarea functions.
|
||||
|
||||
Uses Django Ninja's schema generation internally to ensure proper
|
||||
Pydantic→OpenAPI conversion (handling $refs, nested types, etc.).
|
||||
|
||||
Returns a complete OpenAPI document that can be processed by openapi-typescript.
|
||||
"""
|
||||
from ninja import NinjaAPI # Lazy import
|
||||
from pydantic import BaseModel, create_model # Lazy import
|
||||
|
||||
registry = get_registry()
|
||||
functions = registry.get("functions", {})
|
||||
|
||||
# Create a temporary Ninja API for schema generation only
|
||||
# This is NOT exposed as an HTTP endpoint - purely for leveraging Ninja's
|
||||
# battle-tested Pydantic→OpenAPI conversion
|
||||
schema_api = NinjaAPI(
|
||||
title="Djarea Server Functions",
|
||||
version="1.0.0",
|
||||
description="Auto-generated schema for djarea server functions",
|
||||
docs_url=None, # No docs endpoint
|
||||
openapi_url=None, # No openapi endpoint
|
||||
)
|
||||
|
||||
function_metadata: list[dict[str, Any]] = []
|
||||
|
||||
# Store dynamically created classes so they persist for schema generation
|
||||
schema_classes: dict[str, type] = {}
|
||||
|
||||
for name, fn_class in functions.items():
|
||||
camel_name = snake_to_camel(name)
|
||||
meta = getattr(fn_class, "_meta", {})
|
||||
|
||||
# Get Input/Output classes
|
||||
input_cls = getattr(fn_class, "Input", None)
|
||||
output_cls = getattr(fn_class, "Output", None) or BaseModel
|
||||
|
||||
# Check if input_cls is a valid Pydantic model with fields
|
||||
has_input = (
|
||||
input_cls is not None
|
||||
and input_cls is not BaseModel
|
||||
and hasattr(input_cls, "model_fields")
|
||||
and bool(input_cls.model_fields)
|
||||
)
|
||||
|
||||
# Determine type names for metadata
|
||||
input_type_name = f"{camel_name}Input" if has_input else None
|
||||
output_type_name = f"{camel_name}Output"
|
||||
|
||||
# Create renamed Pydantic classes for cleaner schema names
|
||||
# Store them in schema_classes so they persist beyond loop scope
|
||||
# Uses create_model to avoid metaclass conflicts with custom base classes
|
||||
if has_input:
|
||||
schema_classes[input_type_name] = create_model(input_type_name, __base__=input_cls)
|
||||
schema_classes[output_type_name] = create_model(output_type_name, __base__=output_cls)
|
||||
|
||||
# Register endpoint using helper to avoid closure capture issues
|
||||
_register_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/djarea/{name}",
|
||||
operation_id=camel_name,
|
||||
summary=fn_class.__doc__ or f"Call {name}",
|
||||
input_cls=schema_classes.get(input_type_name),
|
||||
output_cls=schema_classes[output_type_name],
|
||||
)
|
||||
|
||||
# Collect function metadata for provider generation
|
||||
fn_meta_entry: dict[str, Any] = {
|
||||
"name": name,
|
||||
"camelName": camel_name,
|
||||
"hasInput": has_input,
|
||||
"inputType": input_type_name,
|
||||
"outputType": output_type_name,
|
||||
"transport": "websocket" if meta.get("websocket") else "http",
|
||||
"isContext": meta.get("context", False),
|
||||
# Form metadata
|
||||
"isForm": meta.get("form", False),
|
||||
"formName": meta.get("form_name"),
|
||||
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
|
||||
}
|
||||
|
||||
# For form schema functions, extract field definitions for Zod generation
|
||||
if meta.get("form") and meta.get("form_role") == "schema":
|
||||
form_class = meta.get("form_class")
|
||||
if form_class is not None:
|
||||
try:
|
||||
fn_meta_entry["formFields"] = _extract_form_fields(form_class)
|
||||
except Exception as e:
|
||||
# Don't fail schema generation if field extraction fails
|
||||
fn_meta_entry["formFields"] = []
|
||||
fn_meta_entry["formFieldsError"] = str(e)
|
||||
|
||||
function_metadata.append(fn_meta_entry)
|
||||
|
||||
# Get the OpenAPI schema from Ninja (handles all Pydantic conversion properly)
|
||||
schema = schema_api.get_openapi_schema(path_prefix="")
|
||||
|
||||
# Add custom extension with function metadata for provider generation
|
||||
schema["x-djarea-functions"] = function_metadata
|
||||
|
||||
# Add x-djarea metadata to each operation
|
||||
for fn_meta in function_metadata:
|
||||
path = f"/djarea/{fn_meta['name']}"
|
||||
if path in schema.get("paths", {}):
|
||||
schema["paths"][path]["post"]["x-djarea"] = {
|
||||
"transport": fn_meta["transport"],
|
||||
"isContext": fn_meta["isContext"],
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def generate_openapi_json(indent: int = 2) -> str:
|
||||
"""Generate OpenAPI schema as formatted JSON string."""
|
||||
schema = generate_openapi_schema()
|
||||
return json.dumps(schema, indent=indent)
|
||||
623
django/src/djarea/forms/__init__.py
Normal file
623
django/src/djarea/forms/__init__.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
DjareaFormMixin - Turn Django Forms into server functions.
|
||||
|
||||
This mixin transforms any Django Form into Djarea server functions,
|
||||
preserving full Django Form functionality (validation, widgets, ModelChoiceField, etc.)
|
||||
while exposing them through the unified server function API.
|
||||
|
||||
Usage:
|
||||
from django import forms
|
||||
from djarea.forms import DjareaFormMixin, DjareaFormMeta
|
||||
|
||||
class ContactForm(DjareaFormMixin, forms.Form):
|
||||
djarea = DjareaFormMeta(
|
||||
name="contact",
|
||||
title="Contact Us",
|
||||
submit_label="Send",
|
||||
)
|
||||
|
||||
name = forms.CharField()
|
||||
email = forms.EmailField()
|
||||
message = forms.CharField(widget=forms.Textarea)
|
||||
|
||||
def on_submit_success(self, request):
|
||||
send_email(self.cleaned_data)
|
||||
return {"sent": True}
|
||||
|
||||
Auto-registers server functions:
|
||||
- contact.schema
|
||||
- contact.validate
|
||||
- contact.submit
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from django import forms
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .schemas import FormValidation
|
||||
|
||||
|
||||
def _django_field_to_python_type(field: forms.Field) -> type:
|
||||
"""
|
||||
Map a Django form field to a Python type for Pydantic schema generation.
|
||||
|
||||
This provides TypeScript with proper field types instead of generic `any`.
|
||||
"""
|
||||
# Handle common Django field types
|
||||
if isinstance(field, forms.BooleanField):
|
||||
return bool
|
||||
elif isinstance(field, forms.IntegerField):
|
||||
return int
|
||||
elif isinstance(field, forms.FloatField):
|
||||
return float
|
||||
elif isinstance(field, forms.DecimalField):
|
||||
return str # Decimals serialize as strings for precision
|
||||
elif isinstance(field, forms.DateTimeField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.DateField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.TimeField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.JSONField):
|
||||
return dict | list | str | int | float | bool | None
|
||||
elif isinstance(field, forms.MultipleChoiceField):
|
||||
return list[str]
|
||||
elif isinstance(field, forms.FileField):
|
||||
return str # File path/name as string
|
||||
elif isinstance(field, forms.ImageField):
|
||||
return str # File path/name as string
|
||||
else:
|
||||
# Default to string (covers CharField, EmailField, URLField, etc.)
|
||||
return str
|
||||
|
||||
|
||||
def _create_form_input_schema(
|
||||
form_class: type[forms.BaseForm],
|
||||
schema_name: str,
|
||||
) -> type[BaseModel]:
|
||||
"""
|
||||
Create a Pydantic model from Django Form fields.
|
||||
|
||||
This generates a typed schema for the form's input data, giving TypeScript
|
||||
full LSP support (autocomplete, type checking) for form fields.
|
||||
|
||||
Args:
|
||||
form_class: Django Form class to introspect
|
||||
schema_name: Name for the generated Pydantic model (e.g., "ContactFormData")
|
||||
|
||||
Returns:
|
||||
A Pydantic BaseModel subclass with fields matching the form
|
||||
"""
|
||||
# Instantiate form without data to get field definitions
|
||||
try:
|
||||
form = form_class()
|
||||
except TypeError:
|
||||
# Form requires extra args (like request) - use form_class.base_fields instead
|
||||
fields_dict = getattr(form_class, 'base_fields', {})
|
||||
else:
|
||||
fields_dict = form.fields
|
||||
|
||||
# Build Pydantic field definitions
|
||||
pydantic_fields: dict[str, Any] = {}
|
||||
|
||||
for field_name, field in fields_dict.items():
|
||||
python_type = _django_field_to_python_type(field)
|
||||
|
||||
# Optional fields (not required or has initial value)
|
||||
if not field.required:
|
||||
python_type = python_type | None
|
||||
default = None
|
||||
elif field.initial is not None:
|
||||
default = field.initial
|
||||
else:
|
||||
default = ... # Required field
|
||||
|
||||
pydantic_fields[field_name] = (python_type, default)
|
||||
|
||||
# Create the model with a unique name
|
||||
model = create_model(schema_name, **pydantic_fields)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class DjareaFormMeta(BaseModel):
|
||||
"""
|
||||
Configuration for a Djarea form.
|
||||
|
||||
This Pydantic model provides type-safe configuration with full LSP support,
|
||||
and serializes to JSON for the frontend schema.
|
||||
|
||||
Required:
|
||||
name: API identifier (e.g., "contact" → contact.schema, contact.validate, contact.submit)
|
||||
|
||||
Display options:
|
||||
title: Display title (default: derived from class name)
|
||||
subtitle: Display subtitle
|
||||
submit_label: Submit button text (default: "Submit")
|
||||
|
||||
Frontend behavior:
|
||||
live_validation: Enable live validation as user types (default: True)
|
||||
live_form_errors: Show form-level errors during live validation (default: False)
|
||||
refetch_schema_on_validate: Refetch schema on each validation - useful for
|
||||
dynamic choice fields (default: False)
|
||||
|
||||
Features:
|
||||
enable_formset: Generate formset endpoints (default: False)
|
||||
"""
|
||||
|
||||
# Required
|
||||
name: str
|
||||
|
||||
# Display
|
||||
title: str | None = None
|
||||
subtitle: str | None = None
|
||||
submit_label: str = "Submit"
|
||||
|
||||
# Frontend behavior
|
||||
live_validation: bool = True
|
||||
live_form_errors: bool = False
|
||||
refetch_schema_on_validate: bool = False
|
||||
|
||||
# Features
|
||||
enable_formset: bool = False
|
||||
|
||||
|
||||
class DjareaFormMixin:
|
||||
"""
|
||||
Mixin that exposes a Django Form as Djarea server functions.
|
||||
|
||||
Add this mixin to any Django Form class along with a `djarea` configuration:
|
||||
|
||||
class ContactForm(DjareaFormMixin, forms.Form):
|
||||
djarea = DjareaFormMeta(
|
||||
name="contact",
|
||||
title="Contact Us",
|
||||
)
|
||||
|
||||
name = forms.CharField()
|
||||
email = forms.EmailField()
|
||||
|
||||
def on_submit_success(self, request):
|
||||
return {"sent": True}
|
||||
|
||||
This auto-registers:
|
||||
- contact.schema - Get form field definitions
|
||||
- contact.validate - Validate form data
|
||||
- contact.submit - Submit form
|
||||
|
||||
Overridable methods:
|
||||
get_init_kwargs(cls, request) -> dict: Extra kwargs for form instantiation
|
||||
on_submit_success(self, request) -> dict | None: Handle successful submission
|
||||
on_submit_failure(self, request, errors) -> None: Handle failed submission
|
||||
"""
|
||||
|
||||
# Configuration - subclasses must define this
|
||||
djarea: ClassVar[DjareaFormMeta]
|
||||
|
||||
# Track registered forms to avoid duplicate registration
|
||||
_djarea_registered: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
"""
|
||||
Override to provide extra kwargs for form instantiation.
|
||||
|
||||
Common use: pass request or user to forms that need them.
|
||||
|
||||
Example:
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request):
|
||||
return {"request": request, "user": request.user}
|
||||
"""
|
||||
return {}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
"""
|
||||
Called after successful form validation and submission.
|
||||
|
||||
Override to handle the form submission logic.
|
||||
Return a dict to include data in the response.
|
||||
|
||||
Example:
|
||||
def on_submit_success(self, request):
|
||||
self.save()
|
||||
return {"id": self.instance.pk}
|
||||
"""
|
||||
# Default: call save() if available
|
||||
if hasattr(self, "save"):
|
||||
result = self.save()
|
||||
# If save returns something serializable, include it
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
|
||||
def on_submit_failure(
|
||||
self, request: HttpRequest, errors: "FormValidation"
|
||||
) -> None:
|
||||
"""
|
||||
Called after form validation fails.
|
||||
|
||||
Override to add custom error handling, logging, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Auto-register when a concrete form class is defined."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# Only register concrete forms with djarea config defined
|
||||
if _is_concrete_djarea_form(cls):
|
||||
_register_form_as_server_functions(cls)
|
||||
|
||||
|
||||
def _is_concrete_djarea_form(cls: type) -> bool:
|
||||
"""
|
||||
Check if a class is a concrete Djarea form ready for registration.
|
||||
|
||||
A form is concrete if:
|
||||
1. It has a `djarea` attribute that is a DjareaFormMeta instance
|
||||
2. It inherits from Django's BaseForm
|
||||
3. It hasn't been registered yet (for this class definition)
|
||||
"""
|
||||
# Must have djarea config (check cls.__dict__ to avoid inheriting)
|
||||
djarea_config = cls.__dict__.get("djarea")
|
||||
if not isinstance(djarea_config, DjareaFormMeta):
|
||||
return False
|
||||
|
||||
# Must be a Django form
|
||||
if not issubclass(cls, forms.BaseForm):
|
||||
return False
|
||||
|
||||
# Check if already registered (handle re-imports gracefully)
|
||||
if cls.__dict__.get("_djarea_registered", False):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _register_form_as_server_functions(form_class: type) -> None:
|
||||
"""
|
||||
Register a Django Form class as Djarea server functions.
|
||||
|
||||
Creates and registers:
|
||||
- {name}.schema - Returns form field definitions
|
||||
- {name}.validate - Validates form data
|
||||
- {name}.submit - Validates and submits form
|
||||
|
||||
Each function gets a unique typed schema for better TypeScript LSP support.
|
||||
"""
|
||||
from .schemas import FormSchema, FormSubmitFail, FormSubmitPass, FormValidation
|
||||
from .schema_utils import build_form_schema
|
||||
from .validation_utils import validate_form_instance
|
||||
from djarea.setup.registry import register
|
||||
from djarea.client.function import ServerFunction
|
||||
|
||||
config: DjareaFormMeta = form_class.djarea
|
||||
form_name = config.name
|
||||
|
||||
# Mark as registered
|
||||
form_class._djarea_registered = True
|
||||
|
||||
# Generate PascalCase name for schemas (e.g., "contact" -> "Contact")
|
||||
pascal_name = ''.join(word.capitalize() for word in form_name.replace('.', '_').replace('-', '_').split('_'))
|
||||
|
||||
# NOTE: We cannot create FormDataSchema here because form fields aren't
|
||||
# populated yet during __init_subclass__. We use lazy creation instead.
|
||||
_form_data_schema_cache: dict[str, type[BaseModel]] = {}
|
||||
|
||||
def get_form_data_schema() -> type[BaseModel]:
|
||||
"""Lazily create the form data schema (form fields aren't available at registration time)."""
|
||||
if "schema" not in _form_data_schema_cache:
|
||||
_form_data_schema_cache["schema"] = _create_form_input_schema(
|
||||
form_class, f"{pascal_name}FormData"
|
||||
)
|
||||
return _form_data_schema_cache["schema"]
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Schema Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Schema input wraps the form data for pre-populating dynamic fields
|
||||
FormSchemaInput = create_model(
|
||||
f"{pascal_name}SchemaInput",
|
||||
data=(dict[str, Any], {}),
|
||||
)
|
||||
|
||||
class SchemaFunction(ServerFunction):
|
||||
Input = FormSchemaInput
|
||||
Output = FormSchema
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "schema",
|
||||
"form_class": form_class, # Store reference for schema generation
|
||||
}
|
||||
|
||||
def call(self, input) -> FormSchema:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
schema = build_form_schema(
|
||||
form_class,
|
||||
data=input.data if input else {},
|
||||
**init_kwargs,
|
||||
)
|
||||
# Override with DjareaFormMeta values
|
||||
if config.title is not None:
|
||||
schema.title = config.title
|
||||
if config.subtitle is not None:
|
||||
schema.subtitle = config.subtitle
|
||||
schema.submit_label = config.submit_label
|
||||
# Behavior settings are nested in schema.meta
|
||||
schema.meta.live_validation = config.live_validation
|
||||
schema.meta.live_form_errors = config.live_form_errors
|
||||
schema.meta.refetch_schema_on_validate = config.refetch_schema_on_validate
|
||||
return schema
|
||||
|
||||
SchemaFunction.__name__ = f"{form_name}_schema"
|
||||
SchemaFunction.__qualname__ = f"{form_name}_schema"
|
||||
register(SchemaFunction, f"{form_name}.schema")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Validate Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Use generic dict input - form fields aren't available during __init_subclass__
|
||||
FormValidateInput = create_model(
|
||||
f"{pascal_name}ValidateInput",
|
||||
data=(dict[str, Any], ...),
|
||||
)
|
||||
|
||||
class ValidateFunction(ServerFunction):
|
||||
Input = FormValidateInput
|
||||
Output = FormValidation
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "validate",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormValidation:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
# Input data is already a dict
|
||||
data = input.data
|
||||
_, validation = validate_form_instance(
|
||||
form_class,
|
||||
data=data,
|
||||
files=None,
|
||||
**init_kwargs,
|
||||
)
|
||||
return validation
|
||||
|
||||
ValidateFunction.__name__ = f"{form_name}_validate"
|
||||
ValidateFunction.__qualname__ = f"{form_name}_validate"
|
||||
register(ValidateFunction, f"{form_name}.validate")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Submit Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class SubmitFunction(ServerFunction):
|
||||
"""
|
||||
Submit function handles both JSON and multipart/form-data.
|
||||
|
||||
The executor detects form functions and parses the request appropriately.
|
||||
"""
|
||||
|
||||
# Use dict for input - form fields unknown at registration time
|
||||
Input = None # Signals executor to pass raw dict
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "submit",
|
||||
"multipart": True, # Signal that this function accepts multipart
|
||||
}
|
||||
|
||||
def call(self, input) -> FormSubmitPass | FormSubmitFail:
|
||||
"""Execute form submission."""
|
||||
request = self.request
|
||||
|
||||
# Check if we have multipart data from executor
|
||||
if hasattr(request, "_djarea_form_data"):
|
||||
data = request._djarea_form_data
|
||||
files = request._djarea_form_files
|
||||
elif input is not None:
|
||||
# JSON input - already a dict
|
||||
data = input if isinstance(input, dict) else input.model_dump()
|
||||
files = None
|
||||
else:
|
||||
data = {}
|
||||
files = None
|
||||
|
||||
init_kwargs = form_class.get_init_kwargs(request)
|
||||
|
||||
# Create and validate form
|
||||
form, validation = validate_form_instance(
|
||||
form_class,
|
||||
data=data,
|
||||
files=files,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
if form.is_valid():
|
||||
# Call the form's on_submit_success
|
||||
result_data = form.on_submit_success(request)
|
||||
return FormSubmitPass(success=True, data=result_data)
|
||||
|
||||
# Call the form's on_submit_failure
|
||||
form.on_submit_failure(request, validation)
|
||||
return FormSubmitFail(success=False, errors=validation)
|
||||
|
||||
SubmitFunction.__name__ = f"{form_name}_submit"
|
||||
SubmitFunction.__qualname__ = f"{form_name}_submit"
|
||||
SubmitFunction.Output = FormSubmitPass # For schema generation
|
||||
register(SubmitFunction, f"{form_name}.submit")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Functions (if enabled)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
if config.enable_formset:
|
||||
_register_formset_functions(form_class, form_name)
|
||||
|
||||
|
||||
def _register_formset_functions(
|
||||
form_class: type,
|
||||
form_name: str,
|
||||
) -> None:
|
||||
"""Register formset server functions for a form."""
|
||||
from django.forms import formset_factory
|
||||
|
||||
from .schemas import FormsetSchema, FormsetSubmitFail, FormsetSubmitPass, FormsetValidation
|
||||
from .schema_utils import build_form_schema
|
||||
from .validation_utils import build_formset_validation
|
||||
from .formset_utils import forms_to_formset_post_data
|
||||
from djarea.setup.registry import register
|
||||
from djarea.client.function import ServerFunction
|
||||
|
||||
formset_class = formset_factory(form_class)
|
||||
|
||||
# Generate PascalCase name for schemas
|
||||
pascal_name = ''.join(word.capitalize() for word in form_name.replace('.', '_').replace('-', '_').split('_'))
|
||||
|
||||
# NOTE: We cannot create typed schemas here because form fields aren't
|
||||
# populated yet during __init_subclass__. We use generic dict inputs.
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Schema Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
FormsetSchemaInput = create_model(
|
||||
f"{pascal_name}FormsetSchemaInput",
|
||||
forms=(list[dict[str, Any]], []),
|
||||
)
|
||||
|
||||
class FormsetSchemaFunction(ServerFunction):
|
||||
Input = FormsetSchemaInput
|
||||
Output = FormsetSchema
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_schema",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetSchema:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
forms_data = input.forms if input else []
|
||||
|
||||
formset_data = forms_to_formset_post_data(forms_data)
|
||||
formset = formset_class(formset_data)
|
||||
|
||||
return FormsetSchema(
|
||||
forms=[
|
||||
build_form_schema(form_class, data=fd, **init_kwargs)
|
||||
for fd in forms_data
|
||||
],
|
||||
min_num=formset.min_num,
|
||||
max_num=formset.max_num,
|
||||
can_delete=formset.can_delete,
|
||||
can_order=formset.can_order,
|
||||
)
|
||||
|
||||
FormsetSchemaFunction.__name__ = f"{form_name}_formset_schema"
|
||||
register(FormsetSchemaFunction, f"{form_name}.formset.schema")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Validate Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Generic dict input - form fields aren't available during __init_subclass__
|
||||
FormsetValidateInput = create_model(
|
||||
f"{pascal_name}FormsetValidateInput",
|
||||
forms=(list[dict[str, Any]], ...),
|
||||
)
|
||||
|
||||
class FormsetValidateFunction(ServerFunction):
|
||||
Input = FormsetValidateInput
|
||||
Output = FormsetValidation
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_validate",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetValidation:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
# Input.forms is already a list of dicts
|
||||
forms_data = input.forms
|
||||
|
||||
formset_data = forms_to_formset_post_data(forms_data)
|
||||
formset = formset_class(formset_data, form_kwargs=init_kwargs)
|
||||
|
||||
for form in formset:
|
||||
form.empty_permitted = False
|
||||
|
||||
return build_formset_validation(formset)
|
||||
|
||||
FormsetValidateFunction.__name__ = f"{form_name}_formset_validate"
|
||||
register(FormsetValidateFunction, f"{form_name}.formset.validate")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Submit Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Generic dict input - form fields aren't available during __init_subclass__
|
||||
FormsetSubmitInput = create_model(
|
||||
f"{pascal_name}FormsetSubmitInput",
|
||||
forms=(list[dict[str, Any]], ...),
|
||||
)
|
||||
|
||||
class FormsetSubmitFunction(ServerFunction):
|
||||
Input = FormsetSubmitInput
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_submit",
|
||||
"multipart": True,
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetSubmitPass | FormsetSubmitFail:
|
||||
request = self.request
|
||||
init_kwargs = form_class.get_init_kwargs(request)
|
||||
|
||||
# Handle multipart vs JSON
|
||||
if hasattr(request, "_djarea_form_data"):
|
||||
post_data = request._djarea_form_data
|
||||
files = request._djarea_form_files
|
||||
elif input and hasattr(input, 'forms'):
|
||||
# Input.forms is already a list of dicts
|
||||
forms_data = input.forms
|
||||
post_data = forms_to_formset_post_data(forms_data)
|
||||
files = None
|
||||
else:
|
||||
post_data = {}
|
||||
files = None
|
||||
|
||||
formset = formset_class(post_data, files=files, form_kwargs=init_kwargs)
|
||||
|
||||
if formset.is_valid():
|
||||
for form in formset.forms:
|
||||
if form.cleaned_data:
|
||||
form.on_submit_success(request)
|
||||
return FormsetSubmitPass(success=True)
|
||||
|
||||
validation = build_formset_validation(formset)
|
||||
# Call failure handler on each form
|
||||
for form in formset.forms:
|
||||
if hasattr(form, "on_submit_failure"):
|
||||
form.on_submit_failure(request, validation)
|
||||
|
||||
return FormsetSubmitFail(success=False, errors=validation)
|
||||
|
||||
FormsetSubmitFunction.__name__ = f"{form_name}_formset_submit"
|
||||
FormsetSubmitFunction.Output = FormsetSubmitPass
|
||||
register(FormsetSubmitFunction, f"{form_name}.formset.submit")
|
||||
16
django/src/djarea/forms/formset_utils.py
Normal file
16
django/src/djarea/forms/formset_utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
def forms_to_formset_post_data(forms_data: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a list of form dicts into Django formset-compatible POST data.
|
||||
"""
|
||||
formset_data: dict[str, Any] = {
|
||||
"form-TOTAL_FORMS": str(len(forms_data)),
|
||||
"form-INITIAL_FORMS": "0",
|
||||
}
|
||||
for i, form_data in enumerate(forms_data):
|
||||
formset_data.update(
|
||||
{f"form-{i}-{key}": value for key, value in form_data.items()}
|
||||
)
|
||||
return formset_data
|
||||
187
django/src/djarea/forms/schema_utils.py
Normal file
187
django/src/djarea/forms/schema_utils.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
from django import forms
|
||||
from django.forms import Field
|
||||
|
||||
from .schemas import FieldChoice, FieldSchema, FormMeta, FormSchema
|
||||
|
||||
|
||||
def create_form_instance(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: Optional[dict] = None,
|
||||
files: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> forms.BaseForm:
|
||||
"""
|
||||
Create a form instance, gracefully handling kwargs that the form doesn't accept.
|
||||
|
||||
Some Django forms (like allauth's) accept `request` in __init__, others don't.
|
||||
This function tries with all kwargs first, then progressively removes kwargs
|
||||
that cause TypeErrors until instantiation succeeds.
|
||||
"""
|
||||
# Common kwargs that forms may or may not accept
|
||||
optional_kwargs = ['request', 'user', 'instance']
|
||||
|
||||
# Build init kwargs
|
||||
init_kwargs = dict(kwargs)
|
||||
if data is not None:
|
||||
init_kwargs['data'] = data
|
||||
if files is not None:
|
||||
init_kwargs['files'] = files
|
||||
|
||||
while True:
|
||||
try:
|
||||
return form_class(**init_kwargs)
|
||||
except TypeError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Check if it's an unexpected keyword argument error
|
||||
if "unexpected keyword argument" not in error_msg:
|
||||
raise
|
||||
|
||||
# Find which kwarg caused the problem and remove it
|
||||
removed = False
|
||||
for kwarg in optional_kwargs:
|
||||
if f"'{kwarg}'" in error_msg and kwarg in init_kwargs:
|
||||
init_kwargs.pop(kwarg)
|
||||
removed = True
|
||||
break
|
||||
|
||||
# If we couldn't identify/remove the problematic kwarg, re-raise
|
||||
if not removed:
|
||||
raise
|
||||
|
||||
|
||||
def _get_choices(field: Field) -> Optional[list[FieldChoice]]:
|
||||
"""
|
||||
Extract choices from a field, handling ModelChoiceField properly.
|
||||
ModelChoiceField returns ModelChoiceIteratorValue which is not JSON serializable.
|
||||
"""
|
||||
if not hasattr(field, "choices"):
|
||||
return None
|
||||
|
||||
choices: list[FieldChoice] = []
|
||||
for raw_value, label in field.choices:
|
||||
value = getattr(
|
||||
raw_value, "value", raw_value
|
||||
) # ModelChoiceIteratorValue -> .value
|
||||
choices.append(FieldChoice(value=str(value), label=str(label)))
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def _get_initial(value: Any) -> Any:
|
||||
"""Convert initial value to JSON-serializable format."""
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "isoformat"):
|
||||
return value.isoformat()
|
||||
if hasattr(value, "pk"):
|
||||
return value.pk
|
||||
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||
return [item.pk if hasattr(item, "pk") else item for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _class_name_to_title(name: str) -> str:
|
||||
"""
|
||||
Convert a class name to a human-readable title.
|
||||
e.g., 'LoginForm' -> 'Login', 'ResetPasswordForm' -> 'Reset Password'
|
||||
"""
|
||||
# Remove 'Form' suffix
|
||||
name = re.sub(r"Form$", "", name)
|
||||
# Insert spaces before capital letters
|
||||
name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
return name
|
||||
|
||||
|
||||
def _class_name_to_slug(name: str) -> str:
|
||||
"""
|
||||
Convert a class name to a slug.
|
||||
e.g., 'LoginForm' -> 'login', 'ResetPasswordForm' -> 'reset_password'
|
||||
"""
|
||||
# Remove 'Form' suffix
|
||||
name = re.sub(r"Form$", "", name)
|
||||
# Insert underscores before capital letters and lowercase
|
||||
name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
|
||||
return name.lower()
|
||||
|
||||
|
||||
def build_form_schema(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> FormSchema:
|
||||
"""
|
||||
Produce a FormSchema for the given Django form class and (optional) data.
|
||||
|
||||
The form class can define metadata via an inner Meta class:
|
||||
|
||||
class MyForm(forms.Form):
|
||||
class Meta:
|
||||
form_name = "my_form"
|
||||
title = "My Form Title"
|
||||
subtitle = "Optional description"
|
||||
submit_label = "Submit"
|
||||
|
||||
# Frontend behavior (optional)
|
||||
refetch_schema_on_validate = False # Set True for dynamic choice fields
|
||||
live_validation = True # Set False to disable live validation
|
||||
live_form_errors = False # Set True to show form errors live
|
||||
|
||||
If not provided, sensible defaults are derived from the class name.
|
||||
"""
|
||||
form = create_form_instance(form_class, data=data, **kwargs)
|
||||
|
||||
# Extract metadata from form's Meta class
|
||||
form_meta = getattr(form_class, "Meta", None)
|
||||
|
||||
# Get form name (used as identifier)
|
||||
name = getattr(form_meta, "form_name", None)
|
||||
if name is None:
|
||||
name = _class_name_to_slug(form_class.__name__)
|
||||
|
||||
# Get title (human-readable heading)
|
||||
title = getattr(form_meta, "title", None)
|
||||
if title is None:
|
||||
title = _class_name_to_title(form_class.__name__)
|
||||
|
||||
# Get optional subtitle
|
||||
subtitle = getattr(form_meta, "subtitle", None)
|
||||
|
||||
# Get submit button label
|
||||
submit_label = getattr(form_meta, "submit_label", None)
|
||||
if submit_label is None:
|
||||
submit_label = "Submit"
|
||||
|
||||
# Build frontend behavior metadata
|
||||
frontend_meta = FormMeta(
|
||||
refetch_schema_on_validate=getattr(form_meta, "refetch_schema_on_validate", False),
|
||||
live_validation=getattr(form_meta, "live_validation", True),
|
||||
live_form_errors=getattr(form_meta, "live_form_errors", False),
|
||||
)
|
||||
|
||||
return FormSchema(
|
||||
name=name,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
submit_label=submit_label,
|
||||
fields=[
|
||||
FieldSchema(
|
||||
name=name,
|
||||
label=str(field.label or name.replace("_", " ").title()),
|
||||
type=getattr(field.widget, "input_type", "text"),
|
||||
widget=field.widget.__class__.__name__,
|
||||
required=field.required,
|
||||
disabled=field.disabled,
|
||||
help_text=str(field.help_text) if field.help_text else "",
|
||||
initial=_get_initial(field.initial),
|
||||
max_length=getattr(field, "max_length", None),
|
||||
min_length=getattr(field, "min_length", None),
|
||||
choices=_get_choices(field),
|
||||
)
|
||||
for name, field in form.fields.items()
|
||||
],
|
||||
meta=frontend_meta,
|
||||
)
|
||||
103
django/src/djarea/forms/schemas.py
Normal file
103
django/src/djarea/forms/schemas.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from ninja import Schema
|
||||
|
||||
|
||||
# Form metadata schema
|
||||
class FormMeta(Schema):
|
||||
"""
|
||||
Metadata controlling frontend form behavior.
|
||||
|
||||
Attributes:
|
||||
refetch_schema_on_validate: If True, frontend should refetch schema on each
|
||||
validation (useful for dynamic choice fields). Default False.
|
||||
live_validation: If False, frontend should disable live validation entirely.
|
||||
Useful for sensitive forms like login. Default True.
|
||||
live_form_errors: If True, show form-level errors during live validation.
|
||||
Form errors are things like "Invalid credentials" vs field errors like
|
||||
"This field is required". Default False for security.
|
||||
"""
|
||||
refetch_schema_on_validate: bool = False
|
||||
live_validation: bool = True
|
||||
live_form_errors: bool = False
|
||||
|
||||
|
||||
# Field-level schemas
|
||||
class FieldChoice(Schema):
|
||||
value: str
|
||||
label: str
|
||||
|
||||
|
||||
class FieldError(Schema):
|
||||
message: str
|
||||
code: Optional[str]
|
||||
|
||||
|
||||
class FieldErrorList(Schema):
|
||||
field: str
|
||||
errors: list[FieldError]
|
||||
|
||||
|
||||
class FieldSchema(Schema):
|
||||
name: str
|
||||
label: str
|
||||
type: str
|
||||
widget: str
|
||||
required: bool
|
||||
disabled: bool
|
||||
help_text: str
|
||||
initial: Any
|
||||
max_length: Optional[int]
|
||||
min_length: Optional[int]
|
||||
choices: Optional[list[FieldChoice]]
|
||||
|
||||
|
||||
# Form-level schemas
|
||||
class FormSchema(Schema):
|
||||
"""Schema returned by /schema endpoint with form metadata and fields."""
|
||||
# Form metadata
|
||||
name: str
|
||||
title: str
|
||||
subtitle: Optional[str]
|
||||
submit_label: str
|
||||
# Fields
|
||||
fields: list[FieldSchema]
|
||||
# Frontend behavior metadata
|
||||
meta: FormMeta = FormMeta()
|
||||
|
||||
|
||||
class FormValidation(Schema):
|
||||
errors: list[FieldErrorList]
|
||||
|
||||
|
||||
class FormSubmitPass(Schema):
|
||||
success: bool
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
class FormSubmitFail(Schema):
|
||||
success: bool
|
||||
errors: FormValidation
|
||||
|
||||
|
||||
# Formset-level schemas
|
||||
class FormsetSchema(Schema):
|
||||
forms: list[FormSchema]
|
||||
min_num: int
|
||||
max_num: int
|
||||
can_delete: bool
|
||||
can_order: bool
|
||||
|
||||
|
||||
class FormsetValidation(Schema):
|
||||
general: list[str]
|
||||
per_form: list[FormValidation]
|
||||
|
||||
|
||||
class FormsetSubmitPass(Schema):
|
||||
success: bool
|
||||
|
||||
|
||||
class FormsetSubmitFail(Schema):
|
||||
success: bool
|
||||
errors: FormsetValidation
|
||||
72
django/src/djarea/forms/validation_utils.py
Normal file
72
django/src/djarea/forms/validation_utils.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any
|
||||
|
||||
from django import forms
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
|
||||
from .schemas import (
|
||||
FieldError,
|
||||
FieldErrorList,
|
||||
FormValidation,
|
||||
FormsetValidation,
|
||||
)
|
||||
from .schema_utils import create_form_instance
|
||||
|
||||
|
||||
def validate_form_instance(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: dict,
|
||||
files: MultiValueDict[str, UploadedFile] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[forms.BaseForm, FormValidation]:
|
||||
"""
|
||||
Build a form instance and return (form, structured_validation_errors).
|
||||
"""
|
||||
form = create_form_instance(form_class, data=data, files=files, initial=data, **kwargs)
|
||||
|
||||
# Run validation
|
||||
form.is_valid()
|
||||
|
||||
validation = FormValidation(
|
||||
errors=[
|
||||
FieldErrorList(
|
||||
field=field_name,
|
||||
errors=[
|
||||
FieldError(
|
||||
message=str(e.message) if hasattr(e, 'message') else str(e),
|
||||
code=getattr(e, "code", None),
|
||||
)
|
||||
for e in field_errors.as_data()
|
||||
],
|
||||
)
|
||||
for field_name, field_errors in form.errors.items()
|
||||
]
|
||||
)
|
||||
return form, validation
|
||||
|
||||
|
||||
def build_formset_validation(formset: forms.BaseFormSet) -> FormsetValidation:
|
||||
"""
|
||||
Turn a Django formset into a FormsetValidation structure.
|
||||
"""
|
||||
return FormsetValidation(
|
||||
general=[str(e) if e else "" for e in formset.non_form_errors()],
|
||||
per_form=[
|
||||
FormValidation(
|
||||
errors=[
|
||||
FieldErrorList(
|
||||
field=field_name,
|
||||
errors=[
|
||||
FieldError(
|
||||
message=str(e.message) if hasattr(e, 'message') else str(e),
|
||||
code=getattr(e, "code", None),
|
||||
)
|
||||
for e in field_errors.as_data()
|
||||
],
|
||||
)
|
||||
for field_name, field_errors in form.errors.items()
|
||||
]
|
||||
)
|
||||
for form in formset
|
||||
],
|
||||
)
|
||||
25
django/src/djarea/integrations/allauth/__init__.py
Normal file
25
django/src/djarea/integrations/allauth/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Djarea Allauth Integration
|
||||
|
||||
Backend support for django-allauth with Djarea server functions.
|
||||
|
||||
Provides:
|
||||
- Auth contexts (auth_status, user) - required by frontend allauth module
|
||||
- Allauth form wrappers - expose allauth forms as server functions
|
||||
|
||||
Usage:
|
||||
# In your app's apps.py
|
||||
class MyAppConfig(AppConfig):
|
||||
def ready(self):
|
||||
import djarea.allauth.forms # noqa - registers forms
|
||||
import djarea.allauth.contexts # noqa - registers contexts
|
||||
"""
|
||||
|
||||
from .contexts import auth_status, user, AuthStatusOutput, UserOutput
|
||||
|
||||
__all__ = [
|
||||
"auth_status",
|
||||
"user",
|
||||
"AuthStatusOutput",
|
||||
"UserOutput",
|
||||
]
|
||||
115
django/src/djarea/integrations/allauth/contexts.py
Normal file
115
django/src/djarea/integrations/allauth/contexts.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Auth contexts for Djarea Allauth integration.
|
||||
|
||||
These are the core auth primitives that the frontend allauth module depends on.
|
||||
Separated into two concerns:
|
||||
|
||||
- auth_status: Authentication state and permission guards (fast, no DB hit with JWT)
|
||||
- user: Full user profile data (may require DB query for JWT auth)
|
||||
|
||||
Both are registered as global contexts for SSR hydration.
|
||||
"""
|
||||
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client import client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auth Status Context
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AuthStatusOutput(BaseModel):
|
||||
"""Authentication status and permission guards."""
|
||||
is_authenticated: bool
|
||||
user_id: int | None = None
|
||||
is_staff: bool = False
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
@client(context='global')
|
||||
def auth_status(request: HttpRequest) -> AuthStatusOutput:
|
||||
"""
|
||||
Auth status context - provides authentication state and guards.
|
||||
|
||||
This works identically for both session and JWT auth. The data comes
|
||||
from the request.user object (either full User or JWTUser with claims).
|
||||
|
||||
Frontend:
|
||||
const auth = useAuthStatus()
|
||||
if (auth.is_authenticated) { ... }
|
||||
if (auth.is_staff) { ... }
|
||||
"""
|
||||
user = request.user
|
||||
|
||||
if not user.is_authenticated:
|
||||
return AuthStatusOutput(is_authenticated=False)
|
||||
|
||||
return AuthStatusOutput(
|
||||
is_authenticated=True,
|
||||
user_id=user.id,
|
||||
is_staff=user.is_staff,
|
||||
is_superuser=user.is_superuser,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# User Profile Context
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UserOutput(BaseModel):
|
||||
"""Full user profile data."""
|
||||
id: int
|
||||
email: str
|
||||
first_name: str = ""
|
||||
last_name: str = ""
|
||||
|
||||
|
||||
@client(context='global')
|
||||
def user(request: HttpRequest) -> UserOutput | None:
|
||||
"""
|
||||
User profile context - provides full user data.
|
||||
|
||||
Unlike auth_status, this may require a DB query (for JWT auth where
|
||||
the user object is a minimal JWTUser with only claims).
|
||||
|
||||
Returns None if not authenticated.
|
||||
|
||||
Frontend:
|
||||
const user = useUser()
|
||||
if (user) {
|
||||
console.log(user.email)
|
||||
}
|
||||
"""
|
||||
req_user = request.user
|
||||
|
||||
if not req_user.is_authenticated:
|
||||
return None
|
||||
|
||||
# Check if we have full user data or just JWT claims
|
||||
if hasattr(req_user, 'email') and req_user.email:
|
||||
# Full User object (session auth)
|
||||
return UserOutput(
|
||||
id=req_user.id,
|
||||
email=req_user.email,
|
||||
first_name=getattr(req_user, 'first_name', '') or '',
|
||||
last_name=getattr(req_user, 'last_name', '') or '',
|
||||
)
|
||||
|
||||
# JWTUser - need to fetch from DB
|
||||
from django.contrib.auth import get_user_model
|
||||
User = get_user_model()
|
||||
|
||||
try:
|
||||
db_user = User.objects.get(pk=req_user.id)
|
||||
return UserOutput(
|
||||
id=db_user.id,
|
||||
email=db_user.email,
|
||||
first_name=db_user.first_name or '',
|
||||
last_name=db_user.last_name or '',
|
||||
)
|
||||
except User.DoesNotExist:
|
||||
return None
|
||||
400
django/src/djarea/integrations/allauth/forms.py
Normal file
400
django/src/djarea/integrations/allauth/forms.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Allauth forms as Djarea server functions.
|
||||
|
||||
This module wraps allauth forms with DjareaFormMixin, exposing them as
|
||||
typed server functions for the React frontend.
|
||||
|
||||
Each form becomes three server functions:
|
||||
- {name}.schema - Get form field definitions
|
||||
- {name}.validate - Validate form data
|
||||
- {name}.submit - Submit form
|
||||
|
||||
Import this module in your app's ready() to register the forms:
|
||||
|
||||
class MyAppConfig(AppConfig):
|
||||
def ready(self):
|
||||
import djarea.allauth.forms # noqa
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.http import HttpRequest
|
||||
|
||||
from djarea.forms import DjareaFormMixin, DjareaFormMeta
|
||||
|
||||
# Account forms
|
||||
from allauth.account.forms import (
|
||||
AddEmailForm,
|
||||
ChangePasswordForm,
|
||||
ConfirmLoginCodeForm,
|
||||
LoginForm,
|
||||
RequestLoginCodeForm,
|
||||
ResetPasswordForm,
|
||||
ResetPasswordKeyForm,
|
||||
SetPasswordForm,
|
||||
SignupForm,
|
||||
UserTokenForm,
|
||||
)
|
||||
|
||||
# Password reauthentication form - conditionally import
|
||||
try:
|
||||
from allauth.account.forms import ReauthenticateForm
|
||||
HAS_REAUTH = True
|
||||
except ImportError:
|
||||
HAS_REAUTH = False
|
||||
|
||||
# MFA forms - conditionally import
|
||||
try:
|
||||
from allauth.mfa.base.forms import AuthenticateForm as MFAAuthenticateForm
|
||||
from allauth.mfa.base.forms import ReauthenticateForm as MFAReauthenticateForm
|
||||
from allauth.mfa.totp.forms import ActivateTOTPForm, DeactivateTOTPForm
|
||||
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
|
||||
HAS_MFA = True
|
||||
except ImportError:
|
||||
HAS_MFA = False
|
||||
|
||||
# WebAuthn forms (if available)
|
||||
try:
|
||||
from allauth.mfa.webauthn.forms import AuthenticateWebAuthnForm
|
||||
HAS_WEBAUTHN = True
|
||||
except ImportError:
|
||||
HAS_WEBAUTHN = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from djarea.forms.schemas import FormValidation
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Account Forms
|
||||
# =============================================================================
|
||||
|
||||
class DjareaLoginForm(LoginForm, DjareaFormMixin):
|
||||
"""Sign in with email and password."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="login",
|
||||
title="Sign In",
|
||||
subtitle="Welcome back. Enter your credentials to continue.",
|
||||
submit_label="Sign In",
|
||||
live_validation=False, # Don't validate credentials as user types
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.login(request)
|
||||
return None
|
||||
|
||||
|
||||
class DjareaSignupForm(SignupForm, DjareaFormMixin):
|
||||
"""Create a new account."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="signup",
|
||||
title="Create Account",
|
||||
subtitle="Enter your details to get started.",
|
||||
submit_label="Create Account",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save(request)
|
||||
return None
|
||||
|
||||
|
||||
class DjareaAddEmailForm(AddEmailForm, DjareaFormMixin):
|
||||
"""Add another email address to your account."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="add_email",
|
||||
title="Add Email Address",
|
||||
subtitle="Add another email address to your account.",
|
||||
submit_label="Add Email",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaChangePasswordForm(ChangePasswordForm, DjareaFormMixin):
|
||||
"""Change your account password."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="change_password",
|
||||
title="Change Password",
|
||||
subtitle="Update your password to keep your account secure.",
|
||||
submit_label="Change Password",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaSetPasswordForm(SetPasswordForm, DjareaFormMixin):
|
||||
"""Set a password for accounts created via social login."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="set_password",
|
||||
title="Set Password",
|
||||
subtitle="Create a password for your account.",
|
||||
submit_label="Set Password",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaResetPasswordForm(ResetPasswordForm, DjareaFormMixin):
|
||||
"""Request a password reset email."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="reset_password",
|
||||
title="Reset Password",
|
||||
subtitle="Enter your email address and we'll send you a link to reset your password.",
|
||||
submit_label="Send Reset Link",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save(request)
|
||||
return None
|
||||
|
||||
|
||||
class DjareaResetPasswordKeyForm(ResetPasswordKeyForm, DjareaFormMixin):
|
||||
"""Set a new password using a reset key."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="reset_password_from_key",
|
||||
title="Set New Password",
|
||||
subtitle="Enter your new password below.",
|
||||
submit_label="Reset Password",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaRequestLoginCodeForm(RequestLoginCodeForm, DjareaFormMixin):
|
||||
"""Request a login code via email."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="request_login_code",
|
||||
title="Sign In with Code",
|
||||
subtitle="Enter your email address and we'll send you a login code.",
|
||||
submit_label="Send Code",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaConfirmLoginCodeForm(ConfirmLoginCodeForm, DjareaFormMixin):
|
||||
"""Confirm a login code."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="confirm_login_code",
|
||||
title="Enter Code",
|
||||
subtitle="Enter the code we sent to your email.",
|
||||
submit_label="Verify Code",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
class DjareaUserTokenForm(UserTokenForm, DjareaFormMixin):
|
||||
"""Verify an email with a token."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="user_token",
|
||||
title="Verify Email",
|
||||
subtitle="Enter the verification code from your email.",
|
||||
submit_label="Verify",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
# Password reauthentication - conditionally define
|
||||
if HAS_REAUTH:
|
||||
class DjareaReauthenticateForm(ReauthenticateForm, DjareaFormMixin):
|
||||
"""Re-authenticate with password for sensitive actions."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="reauthenticate",
|
||||
title="Confirm Your Identity",
|
||||
subtitle="Please enter your password to continue.",
|
||||
submit_label="Confirm",
|
||||
live_validation=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
from allauth.account.internal.flows import reauthentication
|
||||
reauthentication.reauthenticate_by_password(request)
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MFA Forms
|
||||
# =============================================================================
|
||||
|
||||
if HAS_MFA:
|
||||
class DjareaMFAAuthenticateForm(MFAAuthenticateForm, DjareaFormMixin):
|
||||
"""Authenticate with MFA during login."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="mfa_authenticate",
|
||||
title="Two-Factor Authentication",
|
||||
subtitle="Enter your authentication code to continue.",
|
||||
submit_label="Verify",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
class DjareaMFAReauthenticateForm(MFAReauthenticateForm, DjareaFormMixin):
|
||||
"""Re-authenticate with MFA for sensitive actions."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="mfa_reauthenticate",
|
||||
title="Confirm Your Identity",
|
||||
subtitle="Enter your authentication code to continue.",
|
||||
submit_label="Confirm",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
class DjareaActivateTOTPForm(ActivateTOTPForm, DjareaFormMixin):
|
||||
"""Activate TOTP authenticator."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="activate_totp",
|
||||
title="Set Up Authenticator",
|
||||
subtitle="Enter the code from your authenticator app to complete setup.",
|
||||
submit_label="Activate",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
class DjareaDeactivateTOTPForm(DeactivateTOTPForm, DjareaFormMixin):
|
||||
"""Deactivate TOTP authenticator."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="deactivate_totp",
|
||||
title="Disable Authenticator",
|
||||
subtitle="Enter your password to disable two-factor authentication.",
|
||||
submit_label="Disable",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
class DjareaGenerateRecoveryCodesForm(GenerateRecoveryCodesForm, DjareaFormMixin):
|
||||
"""Generate new recovery codes."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="generate_recovery_codes",
|
||||
title="Recovery Codes",
|
||||
subtitle="Generate new recovery codes for your account.",
|
||||
submit_label="Generate Codes",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
|
||||
|
||||
if HAS_WEBAUTHN:
|
||||
class DjareaAuthenticateWebAuthnForm(AuthenticateWebAuthnForm, DjareaFormMixin):
|
||||
"""Authenticate with WebAuthn security key."""
|
||||
|
||||
djarea = DjareaFormMeta(
|
||||
name="webauthn_authenticate",
|
||||
title="Security Key",
|
||||
subtitle="Use your security key to authenticate.",
|
||||
submit_label="Use Security Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
return {"request": request, "user": request.user}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
self.save()
|
||||
return None
|
||||
70
django/src/djarea/jwt/__init__.py
Normal file
70
django/src/djarea/jwt/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
djarea.jwt - JWT authentication for server functions.
|
||||
|
||||
Provides:
|
||||
- Server functions for obtaining/refreshing JWT tokens
|
||||
- JWT authentication utilities for validating tokens
|
||||
|
||||
Server Functions:
|
||||
- jwt_obtain: Convert authenticated session to JWT tokens
|
||||
- jwt_refresh: Refresh tokens using a refresh token
|
||||
|
||||
Usage in apps.py or urls.py (to register the functions):
|
||||
import djarea.jwt.functions # noqa: F401
|
||||
|
||||
Note: This module is purpose-built for Djarea server functions.
|
||||
For Django Ninja API authentication, use djarea.jwt.security directly.
|
||||
"""
|
||||
|
||||
# Server functions (import to register with @client decorator)
|
||||
from .functions import jwt_obtain, jwt_refresh
|
||||
|
||||
# Token utilities
|
||||
from .tokens import (
|
||||
create_token_pair,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
refresh_tokens,
|
||||
TokenPair,
|
||||
TokenPayload,
|
||||
JWTUser,
|
||||
)
|
||||
|
||||
# Settings
|
||||
from .settings import get_settings, JWTSettings
|
||||
|
||||
# Security (Ninja API auth) - lazy import to avoid triggering
|
||||
# django-ninja's settings access at module load time.
|
||||
# Use: from djarea.jwt.security import jwt_auth
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name in ("JWTAuth", "jwt_auth"):
|
||||
from .security import JWTAuth, jwt_auth
|
||||
globals()["JWTAuth"] = JWTAuth
|
||||
globals()["jwt_auth"] = jwt_auth
|
||||
return globals()[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Server functions
|
||||
"jwt_obtain",
|
||||
"jwt_refresh",
|
||||
# Token utilities
|
||||
"create_token_pair",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_token",
|
||||
"refresh_tokens",
|
||||
"TokenPair",
|
||||
"TokenPayload",
|
||||
"JWTUser",
|
||||
# Settings
|
||||
"get_settings",
|
||||
"JWTSettings",
|
||||
# Security (lazy)
|
||||
"JWTAuth",
|
||||
"jwt_auth",
|
||||
]
|
||||
97
django/src/djarea/jwt/functions.py
Normal file
97
django/src/djarea/jwt/functions.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
JWT Server Functions
|
||||
|
||||
JWT token operations exposed as djarea server functions.
|
||||
Works over WebSocket RPC (primary) or HTTP fallback.
|
||||
"""
|
||||
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client import client
|
||||
from djarea.jwt.tokens import create_token_pair, refresh_tokens
|
||||
|
||||
|
||||
class TokenPairOutput(BaseModel):
|
||||
"""JWT token pair response."""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
class JWTError(BaseModel):
|
||||
"""JWT operation error."""
|
||||
error: str
|
||||
|
||||
|
||||
@client
|
||||
def jwt_obtain(request: HttpRequest) -> TokenPairOutput:
|
||||
"""
|
||||
Obtain JWT tokens from an authenticated session.
|
||||
|
||||
Requires session authentication (cookie or WebSocket session).
|
||||
Returns access and refresh tokens that can be used for stateless auth.
|
||||
|
||||
The tokens include user claims (is_staff, is_superuser) so that
|
||||
subsequent JWT-authenticated requests don't need a database query.
|
||||
|
||||
Usage:
|
||||
const { access_token, refresh_token } = await call('jwt_obtain')
|
||||
// Use access_token in Authorization: Bearer header
|
||||
"""
|
||||
user = request.user
|
||||
|
||||
if not user.is_authenticated:
|
||||
raise PermissionError("Authentication required")
|
||||
|
||||
# Get session key - for WebSocket, this comes from the scope
|
||||
session = getattr(request, 'session', None)
|
||||
if session is None:
|
||||
# WebSocket request adapter - session is a dict, not SessionBase
|
||||
session_key = getattr(request, '_scope', {}).get('session', {}).get('_session_key')
|
||||
if not session_key:
|
||||
raise PermissionError("No session available")
|
||||
else:
|
||||
# HTTP request - ensure session is saved
|
||||
if not session.session_key:
|
||||
session.save()
|
||||
session_key = session.session_key
|
||||
|
||||
# Include user claims in the token for stateless auth
|
||||
tokens = create_token_pair(
|
||||
user.pk,
|
||||
session_key,
|
||||
is_staff=getattr(user, 'is_staff', False),
|
||||
is_superuser=getattr(user, 'is_superuser', False),
|
||||
)
|
||||
|
||||
return TokenPairOutput(
|
||||
access_token=tokens.access_token,
|
||||
refresh_token=tokens.refresh_token,
|
||||
expires_in=tokens.expires_in,
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
def jwt_refresh(request: HttpRequest, refresh_token: str) -> TokenPairOutput:
|
||||
"""
|
||||
Refresh JWT tokens using a refresh token.
|
||||
|
||||
Does not require session authentication - the refresh token itself
|
||||
contains the session reference and is validated against the session store.
|
||||
|
||||
If the original session has been destroyed (user logged out), this fails.
|
||||
|
||||
Usage:
|
||||
const { access_token, refresh_token } = await call('jwt_refresh', { refresh_token })
|
||||
"""
|
||||
tokens = refresh_tokens(refresh_token)
|
||||
|
||||
if tokens is None:
|
||||
raise PermissionError("Invalid or expired refresh token")
|
||||
|
||||
return TokenPairOutput(
|
||||
access_token=tokens.access_token,
|
||||
refresh_token=tokens.refresh_token,
|
||||
expires_in=tokens.expires_in,
|
||||
)
|
||||
64
django/src/djarea/jwt/security.py
Normal file
64
django/src/djarea/jwt/security.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Django Ninja Security Classes for JWT Authentication
|
||||
|
||||
Provides authentication classes that can be used with Django Ninja's
|
||||
auth parameter to protect API endpoints.
|
||||
"""
|
||||
|
||||
from django.http import HttpRequest
|
||||
from ninja.security import HttpBearer
|
||||
|
||||
from .tokens import decode_token, JWTUser
|
||||
|
||||
|
||||
class JWTAuth(HttpBearer):
|
||||
"""
|
||||
JWT Bearer token authentication for Django Ninja.
|
||||
|
||||
Usage:
|
||||
from ninja_jwt_session import jwt_auth
|
||||
|
||||
@api.get("/protected/", auth=jwt_auth)
|
||||
def protected_endpoint(request):
|
||||
return {"user_id": request.user.id}
|
||||
|
||||
Or globally:
|
||||
api = NinjaExtraAPI(auth=[django_auth, jwt_auth])
|
||||
|
||||
The token must be passed in the Authorization header:
|
||||
Authorization: Bearer <access_token>
|
||||
|
||||
IMPORTANT: This is stateless - no database query is made.
|
||||
request.user is a JWTUser object with id, is_staff, is_superuser.
|
||||
If you need the full User object, query it explicitly:
|
||||
user = User.objects.get(pk=request.user.id)
|
||||
"""
|
||||
|
||||
def authenticate(self, request: HttpRequest, token: str):
|
||||
"""
|
||||
Validate the JWT and return a JWTUser if valid.
|
||||
|
||||
Returns None (authentication failed) if:
|
||||
- Token is invalid or expired
|
||||
- Token is not an access token
|
||||
|
||||
Note: No database query is made. The JWTUser is created from
|
||||
token claims. This is truly stateless authentication.
|
||||
"""
|
||||
# Decode and validate the token
|
||||
payload = decode_token(token, expected_type="access")
|
||||
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
# Create JWTUser from token claims - NO DATABASE QUERY
|
||||
jwt_user = JWTUser(payload)
|
||||
|
||||
# Set request.user for compatibility with code expecting it
|
||||
request.user = jwt_user
|
||||
|
||||
return jwt_user
|
||||
|
||||
|
||||
# Singleton instance for convenience
|
||||
jwt_auth = JWTAuth()
|
||||
118
django/src/djarea/jwt/settings.py
Normal file
118
django/src/djarea/jwt/settings.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
JWT Hybrid Settings
|
||||
|
||||
Configuration is read from Django settings with sensible defaults.
|
||||
Supports both symmetric (HS256) and asymmetric (RS256) algorithms.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
from django.conf import settings as django_settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class JWTSettings:
|
||||
"""JWT configuration."""
|
||||
|
||||
# Signing keys
|
||||
private_key: str # Used for signing (required)
|
||||
public_key: str # Used for verification (same as private for HS256)
|
||||
|
||||
# Algorithm
|
||||
algorithm: str # HS256, RS256, etc.
|
||||
|
||||
# Token lifetimes (seconds)
|
||||
access_token_expires_in: int
|
||||
refresh_token_expires_in: int
|
||||
|
||||
# Security options
|
||||
validate_session: bool # Check session exists on token validation
|
||||
rotate_refresh_token: bool # Issue new refresh token on refresh
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> JWTSettings:
|
||||
"""
|
||||
Load JWT settings from Django settings.
|
||||
|
||||
Settings:
|
||||
JWT_PRIVATE_KEY: Signing key (required)
|
||||
JWT_PUBLIC_KEY: Verification key (defaults to private key for HS256)
|
||||
JWT_ALGORITHM: Algorithm to use (default: HS256)
|
||||
JWT_ACCESS_TOKEN_EXPIRES_IN: Access token lifetime (default: 300)
|
||||
JWT_REFRESH_TOKEN_EXPIRES_IN: Refresh token lifetime (default: 604800)
|
||||
JWT_VALIDATE_SESSION: Validate session on token use (default: True)
|
||||
JWT_ROTATE_REFRESH_TOKEN: Rotate refresh tokens (default: True)
|
||||
"""
|
||||
private_key = getattr(django_settings, "JWT_PRIVATE_KEY", None)
|
||||
|
||||
if not private_key:
|
||||
# Fall back to allauth setting if available (for compatibility)
|
||||
headless_key = getattr(django_settings, "HEADLESS_JWT_PRIVATE_KEY", None)
|
||||
if headless_key:
|
||||
private_key = headless_key
|
||||
|
||||
if private_key is None:
|
||||
raise ValueError(
|
||||
"JWT_PRIVATE_KEY must be set in Django settings. "
|
||||
"For HS256, use a secure random string. "
|
||||
"For RS256, use a PEM-encoded RSA private key."
|
||||
)
|
||||
|
||||
# Auto-detect algorithm based on key format if not explicitly set
|
||||
algorithm = getattr(django_settings, "JWT_ALGORITHM", None)
|
||||
|
||||
if algorithm is None:
|
||||
# Auto-detect: if key looks like PEM, use RS256; otherwise HS256
|
||||
if isinstance(private_key, str) and private_key.strip().startswith("-----BEGIN"):
|
||||
algorithm = "RS256"
|
||||
else:
|
||||
algorithm = "HS256"
|
||||
|
||||
# For symmetric algorithms, public key = private key
|
||||
if algorithm.startswith("HS"):
|
||||
public_key = private_key
|
||||
else:
|
||||
public_key = getattr(django_settings, "JWT_PUBLIC_KEY", None)
|
||||
if public_key is None:
|
||||
# Try to extract public key from private key for RSA
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
|
||||
private_key_obj = load_pem_private_key(
|
||||
private_key.encode() if isinstance(private_key, str) else private_key,
|
||||
password=None,
|
||||
)
|
||||
public_key = private_key_obj.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
).decode()
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"JWT_PUBLIC_KEY must be set for {algorithm} algorithm, "
|
||||
"or JWT_PRIVATE_KEY must be a valid PEM-encoded RSA key."
|
||||
)
|
||||
|
||||
return JWTSettings(
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
algorithm=algorithm,
|
||||
access_token_expires_in=getattr(
|
||||
django_settings,
|
||||
"JWT_ACCESS_TOKEN_EXPIRES_IN",
|
||||
getattr(django_settings, "HEADLESS_JWT_ACCESS_TOKEN_EXPIRES_IN", 300),
|
||||
),
|
||||
refresh_token_expires_in=getattr(
|
||||
django_settings,
|
||||
"JWT_REFRESH_TOKEN_EXPIRES_IN",
|
||||
getattr(django_settings, "HEADLESS_JWT_REFRESH_TOKEN_EXPIRES_IN", 604800),
|
||||
),
|
||||
validate_session=getattr(
|
||||
django_settings, "JWT_VALIDATE_SESSION", True
|
||||
),
|
||||
rotate_refresh_token=getattr(
|
||||
django_settings, "JWT_ROTATE_REFRESH_TOKEN", True
|
||||
),
|
||||
)
|
||||
245
django/src/djarea/jwt/tokens.py
Normal file
245
django/src/djarea/jwt/tokens.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
JWT Token Creation and Validation
|
||||
|
||||
Uses PyJWT directly - no allauth dependency.
|
||||
Tokens are tied to Django sessions for immediate revocation on logout.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import NamedTuple
|
||||
|
||||
import jwt
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
|
||||
from .settings import get_settings
|
||||
|
||||
|
||||
class TokenPair(NamedTuple):
|
||||
"""Access and refresh token pair."""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
class TokenPayload(NamedTuple):
|
||||
"""Decoded token payload."""
|
||||
user_id: int | str
|
||||
session_key: str
|
||||
token_type: str
|
||||
is_staff: bool
|
||||
is_superuser: bool
|
||||
exp: int
|
||||
iat: int
|
||||
|
||||
|
||||
class JWTUser:
|
||||
"""
|
||||
Minimal user object created from JWT claims.
|
||||
|
||||
Used as request.user for JWT-authenticated requests.
|
||||
No database query required - all data comes from the token.
|
||||
|
||||
If you need the full User object with all fields, query explicitly:
|
||||
user = User.objects.get(pk=request.user.id)
|
||||
"""
|
||||
|
||||
def __init__(self, payload: TokenPayload):
|
||||
self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id
|
||||
self.pk = self.id
|
||||
self.is_staff = payload.is_staff
|
||||
self.is_superuser = payload.is_superuser
|
||||
self.is_authenticated = True
|
||||
self.is_anonymous = False
|
||||
self.is_active = True # Assumed active if they have a valid token
|
||||
|
||||
def __str__(self):
|
||||
return f"JWTUser(id={self.id})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})"
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: int | str,
|
||||
session_key: str,
|
||||
*,
|
||||
is_staff: bool = False,
|
||||
is_superuser: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Create a short-lived access token.
|
||||
|
||||
The token contains:
|
||||
- sub: user ID
|
||||
- sid: session key (for revocation checking)
|
||||
- staff: is_staff flag
|
||||
- super: is_superuser flag
|
||||
- type: "access"
|
||||
- iat: issued at
|
||||
- exp: expiration
|
||||
"""
|
||||
settings = get_settings()
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"sid": session_key,
|
||||
"staff": is_staff,
|
||||
"super": is_superuser,
|
||||
"type": "access",
|
||||
"iat": now,
|
||||
"exp": now + settings.access_token_expires_in,
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
settings.private_key,
|
||||
algorithm=settings.algorithm,
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
user_id: int | str,
|
||||
session_key: str,
|
||||
*,
|
||||
is_staff: bool = False,
|
||||
is_superuser: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Create a longer-lived refresh token.
|
||||
|
||||
The token contains:
|
||||
- sub: user ID
|
||||
- sid: session key (for revocation checking)
|
||||
- staff: is_staff flag
|
||||
- super: is_superuser flag
|
||||
- type: "refresh"
|
||||
- iat: issued at
|
||||
- exp: expiration
|
||||
"""
|
||||
settings = get_settings()
|
||||
now = int(time.time())
|
||||
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"sid": session_key,
|
||||
"staff": is_staff,
|
||||
"super": is_superuser,
|
||||
"type": "refresh",
|
||||
"iat": now,
|
||||
"exp": now + settings.refresh_token_expires_in,
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
payload,
|
||||
settings.private_key,
|
||||
algorithm=settings.algorithm,
|
||||
)
|
||||
|
||||
|
||||
def create_token_pair(
|
||||
user_id: int | str,
|
||||
session_key: str,
|
||||
*,
|
||||
is_staff: bool = False,
|
||||
is_superuser: bool = False,
|
||||
) -> TokenPair:
|
||||
"""Create both access and refresh tokens."""
|
||||
settings = get_settings()
|
||||
return TokenPair(
|
||||
access_token=create_access_token(
|
||||
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
|
||||
),
|
||||
refresh_token=create_refresh_token(
|
||||
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
|
||||
),
|
||||
expires_in=settings.access_token_expires_in,
|
||||
)
|
||||
|
||||
|
||||
def decode_token(token: str, expected_type: str = None) -> TokenPayload | None:
|
||||
"""
|
||||
Decode and validate a JWT token.
|
||||
|
||||
Returns None if:
|
||||
- Token is invalid or expired
|
||||
- Token type doesn't match expected_type (if specified)
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.public_key,
|
||||
algorithms=[settings.algorithm],
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
|
||||
# Validate token type if specified
|
||||
if expected_type and payload.get("type") != expected_type:
|
||||
return None
|
||||
|
||||
return TokenPayload(
|
||||
user_id=payload["sub"],
|
||||
session_key=payload["sid"],
|
||||
token_type=payload["type"],
|
||||
is_staff=payload.get("staff", False),
|
||||
is_superuser=payload.get("super", False),
|
||||
exp=payload["exp"],
|
||||
iat=payload["iat"],
|
||||
)
|
||||
|
||||
|
||||
def validate_session(session_key: str) -> bool:
|
||||
"""
|
||||
Check if a session is still valid (exists and not expired).
|
||||
|
||||
This is the key to immediate logout revocation - if the session
|
||||
is destroyed, tokens tied to it become invalid.
|
||||
"""
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings as django_settings
|
||||
|
||||
jwt_settings = get_settings()
|
||||
|
||||
if not jwt_settings.validate_session:
|
||||
return True
|
||||
|
||||
# Use the configured session engine
|
||||
engine = import_module(django_settings.SESSION_ENGINE)
|
||||
SessionStore = engine.SessionStore
|
||||
|
||||
# Try to load the session
|
||||
session = SessionStore(session_key=session_key)
|
||||
|
||||
# Check if session exists and is not empty
|
||||
# exists() is more reliable than checking load() result
|
||||
return session.exists(session_key)
|
||||
|
||||
|
||||
def refresh_tokens(refresh_token: str) -> TokenPair | None:
|
||||
"""
|
||||
Use a refresh token to obtain new tokens.
|
||||
|
||||
Returns None if:
|
||||
- Refresh token is invalid or expired
|
||||
- Associated session no longer exists
|
||||
"""
|
||||
payload = decode_token(refresh_token, expected_type="refresh")
|
||||
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
# Validate the session still exists
|
||||
if not validate_session(payload.session_key):
|
||||
return None
|
||||
|
||||
# Issue new token pair with same claims
|
||||
return create_token_pair(
|
||||
payload.user_id,
|
||||
payload.session_key,
|
||||
is_staff=payload.is_staff,
|
||||
is_superuser=payload.is_superuser,
|
||||
)
|
||||
0
django/src/djarea/management/__init__.py
Normal file
0
django/src/djarea/management/__init__.py
Normal file
0
django/src/djarea/management/commands/__init__.py
Normal file
0
django/src/djarea/management/commands/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Export channels schema as OpenAPI JSON for TypeScript generation.
|
||||
|
||||
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
|
||||
The schema is consumed by openapi-typescript for type generation.
|
||||
|
||||
Usage:
|
||||
python manage.py export_channels_schema
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Export channels schema as OpenAPI JSON for TypeScript code generation"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--indent",
|
||||
type=int,
|
||||
default=2,
|
||||
help="JSON indentation level (default: 2, use 0 for compact)",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
from djarea.channels import get_channels_openapi_schema
|
||||
|
||||
schema = get_channels_openapi_schema()
|
||||
|
||||
indent = options["indent"] if options["indent"] > 0 else None
|
||||
output = json.dumps(schema, indent=indent)
|
||||
|
||||
self.stdout.write(output)
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Export Djarea Schema
|
||||
|
||||
Management command to export the djarea OpenAPI schema for TypeScript code generation.
|
||||
The schema is consumed by openapi-typescript for robust type generation.
|
||||
|
||||
Usage:
|
||||
python manage.py export_djarea_schema # Output to stdout
|
||||
python manage.py export_djarea_schema --output schema.json # Output to file
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from djarea.export import generate_openapi_schema
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Export djarea OpenAPI schema for TypeScript code generation"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output file path. If not specified, outputs to stdout.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--indent",
|
||||
type=int,
|
||||
default=2,
|
||||
help="JSON indentation level (0 for compact output)",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
schema = generate_openapi_schema()
|
||||
indent = options["indent"] if options["indent"] > 0 else None
|
||||
json_output = json.dumps(schema, indent=indent)
|
||||
|
||||
if options["output"]:
|
||||
output_path = Path(options["output"])
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json_output)
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(f"Schema written to {output_path}")
|
||||
)
|
||||
else:
|
||||
self.stdout.write(json_output)
|
||||
69
django/src/djarea/setup/__init__.py
Normal file
69
django/src/djarea/setup/__init__.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
djarea.setup - Integration and registration utilities.
|
||||
|
||||
This subpackage contains everything developers need to integrate Djarea:
|
||||
- Registry for server functions and channels
|
||||
- Auto-discovery for apps
|
||||
- Configuration settings
|
||||
|
||||
Usage:
|
||||
from djarea.setup import djarea_clients, register, get_function
|
||||
"""
|
||||
|
||||
from .registry import (
|
||||
register,
|
||||
register_as,
|
||||
register_form,
|
||||
register_compose,
|
||||
get_function,
|
||||
get_channel,
|
||||
get_compose,
|
||||
get_view,
|
||||
get_all_functions,
|
||||
get_all_channels,
|
||||
get_all_compositions,
|
||||
get_registry,
|
||||
get_schema,
|
||||
get_contexts,
|
||||
get_forms,
|
||||
clear_registry,
|
||||
)
|
||||
|
||||
from .discovery import (
|
||||
djarea_clients,
|
||||
djarea_module,
|
||||
)
|
||||
|
||||
from .settings import (
|
||||
DjareaSettings,
|
||||
get_settings,
|
||||
clear_settings_cache,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registration
|
||||
"register",
|
||||
"register_as",
|
||||
"register_form",
|
||||
"register_compose",
|
||||
# Lookup
|
||||
"get_function",
|
||||
"get_channel",
|
||||
"get_compose",
|
||||
"get_view",
|
||||
"get_all_functions",
|
||||
"get_all_channels",
|
||||
"get_all_compositions",
|
||||
"get_registry",
|
||||
"get_schema",
|
||||
"get_contexts",
|
||||
"get_forms",
|
||||
"clear_registry",
|
||||
# Discovery
|
||||
"djarea_clients",
|
||||
"djarea_module",
|
||||
# Settings
|
||||
"DjareaSettings",
|
||||
"get_settings",
|
||||
"clear_settings_cache",
|
||||
]
|
||||
90
django/src/djarea/setup/discovery.py
Normal file
90
django/src/djarea/setup/discovery.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Djarea Auto-Discovery
|
||||
|
||||
Scans Django apps for server functions following the 'clients' layer convention:
|
||||
- <app>/clients.py
|
||||
- <app>/clients/**/*.py
|
||||
|
||||
Usage in urls.py:
|
||||
from djarea.setup.discovery import djarea_clients
|
||||
|
||||
djarea_clients('apps') # Scans apps/*/clients.py
|
||||
djarea_clients('djarea', 'allauth') # Scans djarea/allauth/**/*.py
|
||||
|
||||
This replaces manual "import to register" patterns with explicit auto-discovery.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from djarea._vendor.app_visitor import DjangoAppVisitor, get_members
|
||||
|
||||
from .registry import register, get_function
|
||||
from djarea.client.function import ServerFunction
|
||||
|
||||
|
||||
class _RegisterServerFunctions:
|
||||
"""Visitor handler that registers ServerFunction subclasses."""
|
||||
|
||||
def on_module(
|
||||
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
|
||||
) -> None:
|
||||
"""Process discovered module members."""
|
||||
for name, member in members:
|
||||
# Register ServerFunction subclasses
|
||||
if (
|
||||
isinstance(member, type)
|
||||
and issubclass(member, ServerFunction)
|
||||
and member is not ServerFunction
|
||||
and hasattr(member, '__name__')
|
||||
):
|
||||
# Use the function name as registration name
|
||||
fn_name = getattr(member, 'name', None) or member.__name__
|
||||
|
||||
# Skip already registered (idempotent)
|
||||
if get_function(fn_name) is member:
|
||||
continue
|
||||
|
||||
try:
|
||||
register(member, fn_name)
|
||||
except ValueError:
|
||||
# Already registered with different class - skip
|
||||
pass
|
||||
|
||||
|
||||
def djarea_clients(apps_root: str, layer: str = 'clients') -> None:
|
||||
"""
|
||||
Discover and register server functions from Django apps.
|
||||
|
||||
Scans for the specified layer (default: 'clients') in each app:
|
||||
- <app>/<layer>.py
|
||||
- <app>/<layer>/**/*.py
|
||||
|
||||
Args:
|
||||
apps_root: Root package containing Django apps (e.g., 'apps')
|
||||
layer: Module name pattern to scan (default: 'clients')
|
||||
|
||||
Example:
|
||||
# In urls.py
|
||||
djarea_clients('apps') # Scans apps/*/clients.py
|
||||
djarea_clients('apps', 'functions') # Scans apps/*/functions.py
|
||||
"""
|
||||
visitor = DjangoAppVisitor(layer=layer, apps_root=apps_root)
|
||||
visitor.visit(_RegisterServerFunctions())
|
||||
|
||||
|
||||
def djarea_module(module_path: str) -> None:
|
||||
"""
|
||||
Register server functions from a specific module.
|
||||
|
||||
Use this for library modules that don't follow the app convention.
|
||||
|
||||
Args:
|
||||
module_path: Full module path (e.g., 'djarea.integrations.allauth')
|
||||
|
||||
Example:
|
||||
djarea_module('djarea.integrations.allauth')
|
||||
djarea_module('djarea.jwt.functions')
|
||||
"""
|
||||
members = get_members(module_path)
|
||||
handler = _RegisterServerFunctions()
|
||||
handler.on_module('', [], members)
|
||||
316
django/src/djarea/setup/registry.py
Normal file
316
django/src/djarea/setup/registry.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Djarea Registry
|
||||
|
||||
Central registration for server functions, channels, and compositions.
|
||||
All items are identified by name.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from djarea.client.function import ServerFunction, ComposedContext
|
||||
from djarea.channels import ReactChannel
|
||||
|
||||
|
||||
# Global registries - all use name as key
|
||||
_functions: dict[str, type["ServerFunction"]] = {}
|
||||
_channels: dict[str, type["ReactChannel"]] = {}
|
||||
_compositions: dict[str, "ComposedContext"] = {}
|
||||
|
||||
|
||||
def register(
|
||||
view_class: type["ServerFunction"] | type["ReactChannel"],
|
||||
name: str,
|
||||
) -> type["ServerFunction"] | type["ReactChannel"]:
|
||||
"""
|
||||
Register a server function or channel.
|
||||
|
||||
Args:
|
||||
view_class: ServerFunction or ReactChannel subclass
|
||||
name: Registration name (used for API calls and code generation)
|
||||
|
||||
Returns:
|
||||
The view class (allows use as part of decorator chain)
|
||||
"""
|
||||
from djarea.client.function import ServerFunction
|
||||
from djarea.channels import ReactChannel
|
||||
|
||||
view_class.name = name
|
||||
|
||||
if issubclass(view_class, ReactChannel):
|
||||
if name in _channels:
|
||||
# Allow re-registration of the same class (idempotent for reloads)
|
||||
if _channels[name] is not view_class:
|
||||
raise ValueError(
|
||||
f"Channel '{name}' already registered by {_channels[name].__name__}"
|
||||
)
|
||||
return view_class
|
||||
_channels[name] = view_class
|
||||
elif issubclass(view_class, ServerFunction):
|
||||
if name in _functions:
|
||||
# Allow re-registration of the same class (idempotent for reloads)
|
||||
existing = _functions[name]
|
||||
if existing.__name__ == view_class.__name__:
|
||||
# Same function being re-registered (reload scenario)
|
||||
_functions[name] = view_class
|
||||
return view_class
|
||||
raise ValueError(
|
||||
f"Function '{name}' already registered by {existing.__name__}"
|
||||
)
|
||||
_functions[name] = view_class
|
||||
else:
|
||||
raise TypeError(f"{view_class} must be a ServerFunction or ReactChannel")
|
||||
|
||||
return view_class
|
||||
|
||||
|
||||
def register_as(name: str):
|
||||
"""
|
||||
Decorator for registering a server function or channel.
|
||||
|
||||
Usage:
|
||||
@register_as('update-profile')
|
||||
class UpdateProfile(ServerFunction):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(view_class):
|
||||
return register(view_class, name)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_form(
|
||||
form_class: type,
|
||||
name: str,
|
||||
submit_handler: Callable | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register a Django Form as server functions.
|
||||
|
||||
Creates and registers:
|
||||
- {name}.schema: Returns form field definitions
|
||||
- {name}.validate: Validates form data
|
||||
- {name}.submit: Submits form (if submit_handler provided)
|
||||
|
||||
Usage:
|
||||
register_form(ContactForm, 'contact', submit_handler=handle_contact)
|
||||
"""
|
||||
from djarea.client.function import create_form_functions
|
||||
|
||||
schema_fn, validate_fn, submit_fn = create_form_functions(
|
||||
form_class, name, submit_handler
|
||||
)
|
||||
|
||||
register(schema_fn, f"{name}.schema")
|
||||
register(validate_fn, f"{name}.validate")
|
||||
if submit_fn:
|
||||
register(submit_fn, f"{name}.submit")
|
||||
|
||||
|
||||
def register_compose(
|
||||
composed: "ComposedContext",
|
||||
name: str,
|
||||
) -> "ComposedContext":
|
||||
"""
|
||||
Register a composed context.
|
||||
|
||||
Args:
|
||||
composed: ComposedContext instance
|
||||
name: Registration name
|
||||
|
||||
Returns:
|
||||
The composed context
|
||||
"""
|
||||
if name in _compositions:
|
||||
existing = _compositions[name]
|
||||
if existing.name == composed.name:
|
||||
# Same composition being re-registered (reload scenario)
|
||||
_compositions[name] = composed
|
||||
return composed
|
||||
raise ValueError(
|
||||
f"Composition '{name}' already registered by {existing.name}"
|
||||
)
|
||||
_compositions[name] = composed
|
||||
return composed
|
||||
|
||||
|
||||
def get_function(name: str) -> type["ServerFunction"] | None:
|
||||
"""Get a registered server function by name."""
|
||||
return _functions.get(name)
|
||||
|
||||
|
||||
def get_channel(name: str) -> type["ReactChannel"] | None:
|
||||
"""Get a registered channel by name."""
|
||||
return _channels.get(name)
|
||||
|
||||
|
||||
def get_compose(name: str) -> "ComposedContext | None":
|
||||
"""Get a registered composition by name."""
|
||||
return _compositions.get(name)
|
||||
|
||||
|
||||
def get_view(name: str) -> type["ServerFunction"] | type["ReactChannel"] | None:
|
||||
"""Get any registered view by name (function or channel)."""
|
||||
return _functions.get(name) or _channels.get(name)
|
||||
|
||||
|
||||
def get_all_functions() -> dict[str, type["ServerFunction"]]:
|
||||
"""Get all registered functions."""
|
||||
return _functions.copy()
|
||||
|
||||
|
||||
def get_all_channels() -> dict[str, type["ReactChannel"]]:
|
||||
"""Get all registered channels."""
|
||||
return _channels.copy()
|
||||
|
||||
|
||||
def get_all_compositions() -> dict[str, "ComposedContext"]:
|
||||
"""Get all registered compositions."""
|
||||
return _compositions.copy()
|
||||
|
||||
|
||||
def get_registry() -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Get the full registry organized by type.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"functions": { name: class, ... },
|
||||
"channels": { name: class, ... },
|
||||
"compositions": { name: ComposedContext, ... },
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"functions": _functions.copy(),
|
||||
"channels": _channels.copy(),
|
||||
"compositions": _compositions.copy(),
|
||||
}
|
||||
|
||||
|
||||
def get_schema() -> dict[str, Any]:
|
||||
"""
|
||||
Export the full schema for TypeScript generation.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"functions": {
|
||||
"update_profile": {
|
||||
"name": "update_profile",
|
||||
"type": "function",
|
||||
"meta": { "context": "global", ... },
|
||||
"input": { ... },
|
||||
"output": { ... },
|
||||
},
|
||||
...
|
||||
},
|
||||
"channels": {
|
||||
"chat": {
|
||||
"name": "chat",
|
||||
"type": "channel",
|
||||
"params": { ... },
|
||||
"django_message": { ... },
|
||||
...
|
||||
},
|
||||
...
|
||||
},
|
||||
"compositions": {
|
||||
"user_page": {
|
||||
"name": "user_page",
|
||||
"type": "compose",
|
||||
"meta": { "on_server": false, ... },
|
||||
"children": ["user_profile", "user_posts"],
|
||||
"leaves": ["user_profile", "user_posts"],
|
||||
},
|
||||
...
|
||||
},
|
||||
}
|
||||
"""
|
||||
functions = {}
|
||||
for name, cls in _functions.items():
|
||||
schema = cls.get_schema_export()
|
||||
functions[name] = schema
|
||||
|
||||
compositions = {}
|
||||
for name, composed in _compositions.items():
|
||||
compositions[name] = {
|
||||
"name": composed.name,
|
||||
"type": "compose",
|
||||
"meta": composed._meta,
|
||||
"children": composed._meta.get("children", []),
|
||||
"leaves": composed._meta.get("leaves", []),
|
||||
}
|
||||
|
||||
# Build channel schemas from our registry
|
||||
# Only include keys when they have values (test expects absent keys, not None)
|
||||
channels_schema = {}
|
||||
for name, channel_class in _channels.items():
|
||||
channel_schema: dict[str, Any] = {
|
||||
"name": name,
|
||||
"type": "channel",
|
||||
"bidirectional": False,
|
||||
}
|
||||
|
||||
# Extract Params schema (only if defined)
|
||||
if hasattr(channel_class, 'Params') and channel_class.Params:
|
||||
channel_schema["params"] = channel_class.Params.model_json_schema()
|
||||
|
||||
# Extract ReactMessage schema (only if defined - indicates bidirectional)
|
||||
if hasattr(channel_class, 'ReactMessage') and channel_class.ReactMessage:
|
||||
channel_schema["react_message"] = channel_class.ReactMessage.model_json_schema()
|
||||
channel_schema["bidirectional"] = True
|
||||
|
||||
# Extract DjangoMessage schema (only if defined)
|
||||
if hasattr(channel_class, 'DjangoMessage') and channel_class.DjangoMessage:
|
||||
channel_schema["django_message"] = channel_class.DjangoMessage.model_json_schema()
|
||||
|
||||
channels_schema[name] = channel_schema
|
||||
|
||||
return {
|
||||
"functions": functions,
|
||||
"channels": channels_schema,
|
||||
"compositions": compositions,
|
||||
}
|
||||
|
||||
|
||||
def get_contexts() -> dict[str, type["ServerFunction"]]:
|
||||
"""
|
||||
Get all server functions marked as contexts.
|
||||
|
||||
These are functions with meta.context = True, used for SSR hydration.
|
||||
"""
|
||||
contexts = {}
|
||||
for name, cls in _functions.items():
|
||||
if getattr(cls, "_meta", {}).get("context"):
|
||||
contexts[name] = cls
|
||||
return contexts
|
||||
|
||||
|
||||
def get_forms() -> dict[str, list[type["ServerFunction"]]]:
|
||||
"""
|
||||
Get all server functions that are form-related, grouped by form name.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"contact": [ContactSchema, ContactValidate, ContactSubmit],
|
||||
...
|
||||
}
|
||||
"""
|
||||
forms: dict[str, list] = {}
|
||||
for name, cls in _functions.items():
|
||||
meta = getattr(cls, "_meta", {})
|
||||
if meta.get("form"):
|
||||
form_name = meta.get("form_name")
|
||||
if form_name not in forms:
|
||||
forms[form_name] = []
|
||||
forms[form_name].append(cls)
|
||||
return forms
|
||||
|
||||
|
||||
def clear_registry() -> None:
|
||||
"""Clear all registrations. Primarily for testing."""
|
||||
_functions.clear()
|
||||
_channels.clear()
|
||||
_compositions.clear()
|
||||
36
django/src/djarea/setup/settings.py
Normal file
36
django/src/djarea/setup/settings.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Djarea Settings
|
||||
|
||||
Configuration is read from Django settings with sensible defaults.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
from django.conf import settings as django_settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class DjareaSettings:
|
||||
"""Djarea configuration."""
|
||||
|
||||
# Whether to expose function names in DEBUG mode errors
|
||||
debug_expose_names: bool
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> DjareaSettings:
|
||||
"""
|
||||
Load Djarea settings from Django settings.
|
||||
|
||||
Settings:
|
||||
DJAREA_DEBUG_EXPOSE_NAMES: Show function names in errors when DEBUG=True (default: True)
|
||||
"""
|
||||
return DjareaSettings(
|
||||
debug_expose_names=getattr(django_settings, "DJAREA_DEBUG_EXPOSE_NAMES", True),
|
||||
)
|
||||
|
||||
|
||||
def clear_settings_cache():
|
||||
"""Clear the settings cache (for testing)."""
|
||||
get_settings.cache_clear()
|
||||
0
django/src/djarea/tests/__init__.py
Normal file
0
django/src/djarea/tests/__init__.py
Normal file
531
django/src/djarea/tests/test_auth.py
Normal file
531
django/src/djarea/tests/test_auth.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
Authentication Tests for Djarea Server Functions
|
||||
|
||||
Tests all combinations of:
|
||||
- Transport: HTTP vs WebSocket RPC
|
||||
- JWT: Present (valid), Present (invalid), Absent
|
||||
- Session: Present (valid), Absent
|
||||
|
||||
Expected behavior:
|
||||
- JWT present (valid) → JWTUser (no DB query)
|
||||
- JWT present (invalid) → Reject (401), do NOT fall back to session
|
||||
- JWT absent + Session present → Session auth (DB query)
|
||||
- JWT absent + Session absent → AnonymousUser
|
||||
"""
|
||||
|
||||
from django.test import TestCase, RequestFactory, override_settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from djarea.jwt.tokens import (
|
||||
create_token_pair,
|
||||
decode_token,
|
||||
JWTUser,
|
||||
)
|
||||
from djarea.client.executor import (
|
||||
_try_jwt_auth,
|
||||
execute_function,
|
||||
FunctionError,
|
||||
FunctionResult,
|
||||
ErrorCode,
|
||||
)
|
||||
from djarea.client import client
|
||||
from djarea.setup.registry import clear_registry, register
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Output Models (proper Pydantic models, not raw dicts)
|
||||
# =============================================================================
|
||||
|
||||
class WhoamiOutput(BaseModel):
|
||||
is_authenticated: bool
|
||||
user_id: int | None
|
||||
user_type: str
|
||||
is_staff: bool
|
||||
|
||||
|
||||
class OkOutput(BaseModel):
|
||||
ok: bool
|
||||
|
||||
|
||||
class UserTypeOutput(BaseModel):
|
||||
user_type: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Server Functions - defined as plain functions, registered in setUp
|
||||
# =============================================================================
|
||||
|
||||
def _whoami_fn(request) -> WhoamiOutput:
|
||||
"""Returns info about the authenticated user."""
|
||||
user = request.user
|
||||
return WhoamiOutput(
|
||||
is_authenticated=user.is_authenticated,
|
||||
user_id=getattr(user, "id", None),
|
||||
user_type=type(user).__name__,
|
||||
is_staff=getattr(user, "is_staff", False),
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class HTTPAuthTests(TestCase):
|
||||
"""Test HTTP transport authentication combinations."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(
|
||||
email="test@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
# Create a session
|
||||
self.session = SessionStore()
|
||||
self.session.create()
|
||||
self.session_key = self.session.session_key
|
||||
|
||||
# Register test function
|
||||
@client
|
||||
def whoami(request) -> WhoamiOutput:
|
||||
user = request.user
|
||||
return WhoamiOutput(
|
||||
is_authenticated=user.is_authenticated,
|
||||
user_id=getattr(user, "id", None),
|
||||
user_type=type(user).__name__,
|
||||
is_staff=getattr(user, "is_staff", False),
|
||||
)
|
||||
register(whoami, "whoami")
|
||||
|
||||
def tearDown(self):
|
||||
self.user.delete()
|
||||
self.session.delete()
|
||||
clear_registry()
|
||||
|
||||
def test_jwt_valid_no_session(self):
|
||||
"""Valid JWT without session → JWTUser (no DB query)."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = MagicMock(is_authenticated=False) # No session auth
|
||||
|
||||
# Try JWT auth
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsInstance(request.user, JWTUser)
|
||||
self.assertEqual(request.user.id, self.user.pk)
|
||||
self.assertTrue(request.user.is_staff)
|
||||
self.assertTrue(request.user.is_authenticated)
|
||||
|
||||
def test_jwt_valid_with_session(self):
|
||||
"""Valid JWT with session → JWT takes precedence (no DB query)."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = self.user # Session auth already set user
|
||||
|
||||
# JWT should still be processed and take precedence
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsInstance(request.user, JWTUser)
|
||||
|
||||
def test_jwt_invalid_with_session(self):
|
||||
"""Invalid JWT with valid session → Reject (do NOT fall back)."""
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = "Bearer invalid-token-here"
|
||||
request.user = self.user # Session would work
|
||||
|
||||
# JWT auth should fail
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
# User should NOT be changed to session user - that happens elsewhere
|
||||
# The point is _try_jwt_auth returns False, indicating JWT failed
|
||||
|
||||
def test_jwt_expired_with_session(self):
|
||||
"""Expired JWT with valid session → Reject (do NOT fall back)."""
|
||||
# Create token with past expiration by mocking time
|
||||
with patch("djarea.jwt.tokens.time.time", return_value=0):
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = self.user # Session would work
|
||||
|
||||
# JWT auth should fail (expired)
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_no_jwt_with_session(self):
|
||||
"""No JWT with valid session → Session auth (normal Django flow)."""
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # Session auth set user
|
||||
|
||||
# No JWT auth attempted
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result) # No JWT to process
|
||||
# User remains the session user
|
||||
self.assertEqual(request.user, self.user)
|
||||
|
||||
def test_no_jwt_no_session(self):
|
||||
"""No JWT, no session → AnonymousUser."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertIsInstance(request.user, AnonymousUser)
|
||||
|
||||
def test_execute_function_with_jwt(self):
|
||||
"""Execute server function with JWT auth."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
|
||||
# Simulate what the view does: try JWT auth first
|
||||
_try_jwt_auth(request)
|
||||
|
||||
# Use the whoami function which returns WhoamiOutput (Pydantic model)
|
||||
result = execute_function(request, "whoami", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertTrue(result.data["is_authenticated"])
|
||||
self.assertEqual(result.data["user_type"], "JWTUser")
|
||||
self.assertTrue(result.data["is_staff"])
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class JWTUserTests(TestCase):
|
||||
"""Test JWTUser behavior."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
|
||||
def tearDown(self):
|
||||
clear_registry()
|
||||
|
||||
def test_jwt_user_attributes(self):
|
||||
"""JWTUser has expected attributes."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
payload = TokenPayload(
|
||||
user_id=42,
|
||||
session_key="test-session",
|
||||
token_type="access",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
|
||||
user = JWTUser(payload)
|
||||
|
||||
self.assertEqual(user.id, 42)
|
||||
self.assertEqual(user.pk, 42)
|
||||
self.assertTrue(user.is_staff)
|
||||
self.assertFalse(user.is_superuser)
|
||||
self.assertTrue(user.is_authenticated)
|
||||
self.assertFalse(user.is_anonymous)
|
||||
self.assertTrue(user.is_active)
|
||||
|
||||
def test_jwt_user_string_id(self):
|
||||
"""JWTUser handles string user_id (converted to int)."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
payload = TokenPayload(
|
||||
user_id="42", # String, as stored in JWT
|
||||
session_key="test-session",
|
||||
token_type="access",
|
||||
is_staff=False,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
|
||||
user = JWTUser(payload)
|
||||
|
||||
self.assertEqual(user.id, 42)
|
||||
self.assertIsInstance(user.id, int)
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class AuthDecoratorTests(TestCase):
|
||||
"""Test @client(auth=...) decorator."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(
|
||||
email="test@example.com",
|
||||
password="testpass123",
|
||||
is_staff=False,
|
||||
is_superuser=False,
|
||||
)
|
||||
self.staff_user = User.objects.create_user(
|
||||
email="staff@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
self.superuser = User.objects.create_user(
|
||||
email="super@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
self.user.delete()
|
||||
self.staff_user.delete()
|
||||
self.superuser.delete()
|
||||
clear_registry()
|
||||
|
||||
def test_auth_required_with_anonymous(self):
|
||||
"""@client(auth=True) rejects anonymous users."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
# Register a test function with proper Pydantic model
|
||||
@client(auth=True)
|
||||
def protected_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(protected_fn, "protected_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_function(request, "protected_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
|
||||
|
||||
def test_auth_required_with_authenticated(self):
|
||||
"""@client(auth=True) allows authenticated users."""
|
||||
@client(auth=True)
|
||||
def protected_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(protected_fn2, "protected_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user
|
||||
|
||||
result = execute_function(request, "protected_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["ok"], True)
|
||||
|
||||
def test_auth_staff_with_regular_user(self):
|
||||
"""@client(auth='staff') rejects non-staff users."""
|
||||
@client(auth='staff')
|
||||
def staff_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(staff_fn, "staff_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # Not staff
|
||||
|
||||
result = execute_function(request, "staff_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
|
||||
def test_auth_staff_with_staff_user(self):
|
||||
"""@client(auth='staff') allows staff users."""
|
||||
@client(auth='staff')
|
||||
def staff_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(staff_fn2, "staff_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.staff_user
|
||||
|
||||
result = execute_function(request, "staff_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
|
||||
def test_auth_superuser_with_staff(self):
|
||||
"""@client(auth='superuser') rejects non-superusers."""
|
||||
@client(auth='superuser')
|
||||
def super_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(super_fn, "super_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.staff_user # Staff but not superuser
|
||||
|
||||
result = execute_function(request, "super_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
|
||||
def test_auth_superuser_with_superuser(self):
|
||||
"""@client(auth='superuser') allows superusers."""
|
||||
@client(auth='superuser')
|
||||
def super_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(super_fn2, "super_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.superuser
|
||||
|
||||
result = execute_function(request, "super_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
|
||||
def test_auth_with_jwt_user(self):
|
||||
"""Auth checks work with JWTUser (stateless)."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
@client(auth='staff')
|
||||
def jwt_staff_fn(request) -> UserTypeOutput:
|
||||
return UserTypeOutput(user_type=type(request.user).__name__)
|
||||
register(jwt_staff_fn, "jwt_staff_fn")
|
||||
|
||||
# Create JWTUser with is_staff=True
|
||||
payload = TokenPayload(
|
||||
user_id=99,
|
||||
session_key="test",
|
||||
token_type="access",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
jwt_user = JWTUser(payload)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = jwt_user
|
||||
|
||||
result = execute_function(request, "jwt_staff_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["user_type"], "JWTUser")
|
||||
|
||||
def test_auth_invalid_string_raises(self):
|
||||
"""Invalid auth string raises ValueError at decoration time."""
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
@client(auth='admin') # 'admin' is not valid
|
||||
def bad_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
|
||||
self.assertIn("Invalid auth value 'admin'", str(ctx.exception))
|
||||
self.assertIn("required", str(ctx.exception))
|
||||
|
||||
def test_auth_callable_returns_true(self):
|
||||
"""Callable auth returning True allows access."""
|
||||
@client(auth=lambda r: r.user.email.endswith('@example.com'))
|
||||
def email_check_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(email_check_fn, "email_check_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # email is test@example.com
|
||||
|
||||
result = execute_function(request, "email_check_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertTrue(result.data["ok"])
|
||||
|
||||
def test_auth_callable_returns_false(self):
|
||||
"""Callable auth returning False denies access."""
|
||||
@client(auth=lambda r: r.user.email.endswith('@admin.com'))
|
||||
def admin_email_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(admin_email_fn, "admin_email_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # email is test@example.com, not @admin.com
|
||||
|
||||
result = execute_function(request, "admin_email_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Access denied")
|
||||
|
||||
def test_auth_callable_raises_permission_error(self):
|
||||
"""Callable auth raising PermissionError uses custom message."""
|
||||
def check_premium(request):
|
||||
if not getattr(request.user, 'is_premium', False):
|
||||
raise PermissionError("Premium subscription required")
|
||||
return True
|
||||
|
||||
@client(auth=check_premium)
|
||||
def premium_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(premium_fn, "premium_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # No is_premium attribute
|
||||
|
||||
result = execute_function(request, "premium_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Premium subscription required")
|
||||
|
||||
def test_auth_callable_with_anonymous_user(self):
|
||||
"""Callable auth can check for anonymous users."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
def must_be_authenticated(request):
|
||||
if not request.user.is_authenticated:
|
||||
raise PermissionError("Please log in")
|
||||
return True
|
||||
|
||||
@client(auth=must_be_authenticated)
|
||||
def needs_login_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(needs_login_fn, "needs_login_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_function(request, "needs_login_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Please log in")
|
||||
548
django/src/djarea/tests/test_benchmarks.py
Normal file
548
django/src/djarea/tests/test_benchmarks.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Protocol Benchmark: HTTP vs WebSocket Server Functions
|
||||
|
||||
Compares performance of HTTP POST vs WebSocket RPC for server function calls.
|
||||
Includes realistic scenarios with ORM queries.
|
||||
|
||||
Usage:
|
||||
python manage.py test djarea.tests.test_benchmarks --verbosity=2
|
||||
|
||||
Note:
|
||||
These are not unit tests - they measure performance. Results are printed
|
||||
to stdout and should be run in isolation for accurate measurements.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.db import connection
|
||||
from django.http import HttpRequest
|
||||
from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client.executor import FunctionResult, execute_function, function_call_view
|
||||
from djarea.setup.registry import clear_registry
|
||||
from djarea.client import client
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Output Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SimpleOutput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class UserOutput(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
|
||||
|
||||
class UserListOutput(BaseModel):
|
||||
users: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
class StatsOutput(BaseModel):
|
||||
total_users: int
|
||||
active_users: int
|
||||
staff_count: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def setup_benchmark_functions():
|
||||
"""Register benchmark server functions."""
|
||||
from djarea.setup.registry import register
|
||||
|
||||
clear_registry()
|
||||
|
||||
# 1. Simple computation (no I/O)
|
||||
@client
|
||||
def bench_simple(request: HttpRequest, a: int, b: int) -> SimpleOutput:
|
||||
"""Simple addition - baseline with no I/O."""
|
||||
return SimpleOutput(value=a + b)
|
||||
register(bench_simple, "bench_simple")
|
||||
|
||||
# 2. Single ORM query
|
||||
@client
|
||||
def bench_get_user(request: HttpRequest, user_id: int) -> UserOutput:
|
||||
"""Fetch single user by ID."""
|
||||
user = User.objects.filter(id=user_id).first()
|
||||
if user:
|
||||
return UserOutput(id=user.id, email=user.email)
|
||||
return UserOutput(id=0, email="")
|
||||
register(bench_get_user, "bench_get_user")
|
||||
|
||||
# 3. List query with limit
|
||||
@client
|
||||
def bench_list_users(request: HttpRequest, limit: int) -> UserListOutput:
|
||||
"""Fetch list of users with limit."""
|
||||
users = User.objects.all()[:limit]
|
||||
return UserListOutput(
|
||||
users=[{"id": u.id, "email": u.email} for u in users],
|
||||
count=len(users),
|
||||
)
|
||||
register(bench_list_users, "bench_list_users")
|
||||
|
||||
# 4. Aggregation query
|
||||
@client
|
||||
def bench_user_stats(request: HttpRequest) -> StatsOutput:
|
||||
"""Compute user statistics with multiple queries."""
|
||||
total = User.objects.count()
|
||||
active = User.objects.filter(is_active=True).count()
|
||||
staff = User.objects.filter(is_staff=True).count()
|
||||
return StatsOutput(
|
||||
total_users=total,
|
||||
active_users=active,
|
||||
staff_count=staff,
|
||||
)
|
||||
register(bench_user_stats, "bench_user_stats")
|
||||
|
||||
# 5. Complex query with joins
|
||||
@client
|
||||
def bench_user_search(request: HttpRequest, email_contains: str, limit: int) -> UserListOutput:
|
||||
"""Search users by email pattern."""
|
||||
users = User.objects.filter(
|
||||
email__icontains=email_contains,
|
||||
is_active=True,
|
||||
).select_related()[:limit]
|
||||
return UserListOutput(
|
||||
users=[{"id": u.id, "email": u.email} for u in users],
|
||||
count=len(users),
|
||||
)
|
||||
register(bench_user_search, "bench_user_search")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Test Cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProtocolBenchmark(TransactionTestCase):
|
||||
"""
|
||||
Benchmark comparing HTTP vs WebSocket (simulated) performance.
|
||||
|
||||
Uses TransactionTestCase to ensure database state is realistic.
|
||||
"""
|
||||
|
||||
# Number of iterations for each benchmark
|
||||
ITERATIONS = 100
|
||||
WARMUP = 10
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
setup_benchmark_functions()
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
# Create test users for ORM benchmarks
|
||||
self._create_test_users()
|
||||
|
||||
def _create_test_users(self):
|
||||
"""Create test users for benchmarks."""
|
||||
# Create 100 test users
|
||||
users = []
|
||||
for i in range(100):
|
||||
users.append(User(
|
||||
email=f"bench{i}@example.com",
|
||||
is_active=i % 10 != 0, # 90% active
|
||||
is_staff=i < 5, # 5 staff
|
||||
))
|
||||
User.objects.bulk_create(users, ignore_conflicts=True)
|
||||
self.test_user = User.objects.first()
|
||||
|
||||
def _make_request(self, body: dict | None = None) -> HttpRequest:
|
||||
"""Create a request with optional JSON body."""
|
||||
if body:
|
||||
request = self.factory.post(
|
||||
"/api/djarea/call/",
|
||||
data=json.dumps(body),
|
||||
content_type="application/json",
|
||||
)
|
||||
else:
|
||||
request = self.factory.post("/api/djarea/call/")
|
||||
request.user = AnonymousUser()
|
||||
request._dont_enforce_csrf_checks = True
|
||||
return request
|
||||
|
||||
def _benchmark_executor(self, fn_name: str, args: dict, label: str) -> dict:
|
||||
"""
|
||||
Benchmark direct executor calls (simulates WebSocket RPC).
|
||||
|
||||
Returns timing statistics.
|
||||
"""
|
||||
request = self._make_request()
|
||||
times = []
|
||||
|
||||
# Warmup
|
||||
for _ in range(self.WARMUP):
|
||||
execute_function(request, fn_name, args)
|
||||
|
||||
# Benchmark
|
||||
for _ in range(self.ITERATIONS):
|
||||
start = time.perf_counter()
|
||||
result = execute_function(request, fn_name, args)
|
||||
end = time.perf_counter()
|
||||
times.append((end - start) * 1000) # ms
|
||||
|
||||
return self._compute_stats(times, f"Executor ({label})")
|
||||
|
||||
def _benchmark_http(self, fn_name: str, args: dict, label: str) -> dict:
|
||||
"""
|
||||
Benchmark HTTP view calls.
|
||||
|
||||
Returns timing statistics.
|
||||
"""
|
||||
times = []
|
||||
|
||||
# Warmup
|
||||
for _ in range(self.WARMUP):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
|
||||
# Benchmark
|
||||
for _ in range(self.ITERATIONS):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
start = time.perf_counter()
|
||||
response = function_call_view(request)
|
||||
end = time.perf_counter()
|
||||
times.append((end - start) * 1000) # ms
|
||||
|
||||
return self._compute_stats(times, f"HTTP ({label})")
|
||||
|
||||
def _compute_stats(self, times: list[float], label: str) -> dict:
|
||||
"""Compute statistics from timing data."""
|
||||
return {
|
||||
"label": label,
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"mean": statistics.mean(times),
|
||||
"median": statistics.median(times),
|
||||
"stdev": statistics.stdev(times) if len(times) > 1 else 0,
|
||||
"p95": sorted(times)[int(len(times) * 0.95)],
|
||||
"p99": sorted(times)[int(len(times) * 0.99)],
|
||||
"iterations": len(times),
|
||||
}
|
||||
|
||||
def _print_results(self, results: list[dict]):
|
||||
"""Print benchmark results in a table."""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{'Benchmark':<40} {'Mean':>8} {'Median':>8} {'P95':>8} {'P99':>8}")
|
||||
print("=" * 80)
|
||||
for r in results:
|
||||
print(f"{r['label']:<40} {r['mean']:>7.3f}ms {r['median']:>7.3f}ms {r['p95']:>7.3f}ms {r['p99']:>7.3f}ms")
|
||||
print("=" * 80)
|
||||
|
||||
def _print_comparison(self, executor_stats: dict, http_stats: dict):
|
||||
"""Print comparison between executor and HTTP."""
|
||||
overhead = ((http_stats["mean"] - executor_stats["mean"]) / executor_stats["mean"]) * 100
|
||||
print(f" HTTP overhead vs Executor: {overhead:+.1f}%")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Benchmark Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_benchmark_simple_computation(self):
|
||||
"""Benchmark: Simple computation (no I/O)."""
|
||||
print("\n\n### BENCHMARK: Simple Computation (no I/O) ###")
|
||||
|
||||
args = {"a": 100, "b": 200}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_simple", args, "simple")
|
||||
http_stats = self._benchmark_http("bench_simple", args, "simple")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: 100 + 200 = 300
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_simple", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 300)
|
||||
|
||||
def test_benchmark_single_query(self):
|
||||
"""Benchmark: Single ORM query."""
|
||||
print("\n\n### BENCHMARK: Single ORM Query ###")
|
||||
|
||||
args = {"user_id": self.test_user.id if self.test_user else 1}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_get_user", args, "single query")
|
||||
http_stats = self._benchmark_http("bench_get_user", args, "single query")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: should return the test user's data
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_get_user", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["id"], self.test_user.id)
|
||||
|
||||
def test_benchmark_list_query(self):
|
||||
"""Benchmark: List query with serialization."""
|
||||
print("\n\n### BENCHMARK: List Query (10 users) ###")
|
||||
|
||||
args = {"limit": 10}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_list_users", args, "list 10")
|
||||
http_stats = self._benchmark_http("bench_list_users", args, "list 10")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: should return up to 10 users
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_list_users", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 10)
|
||||
self.assertEqual(len(result.data["users"]), result.data["count"])
|
||||
|
||||
def test_benchmark_aggregation(self):
|
||||
"""Benchmark: Aggregation queries."""
|
||||
print("\n\n### BENCHMARK: Aggregation (3 COUNT queries) ###")
|
||||
|
||||
args = {}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_user_stats", args, "aggregation")
|
||||
http_stats = self._benchmark_http("bench_user_stats", args, "aggregation")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: stats should have non-negative counts
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_user_stats", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertGreaterEqual(result.data["total_users"], 0)
|
||||
self.assertGreaterEqual(result.data["active_users"], 0)
|
||||
self.assertGreaterEqual(result.data["staff_count"], 0)
|
||||
|
||||
def test_benchmark_search_query(self):
|
||||
"""Benchmark: Search with filter."""
|
||||
print("\n\n### BENCHMARK: Search Query (LIKE + LIMIT) ###")
|
||||
|
||||
args = {"email_contains": "bench", "limit": 20}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_user_search", args, "search")
|
||||
http_stats = self._benchmark_http("bench_user_search", args, "search")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: search results should contain "bench" in emails
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_user_search", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 20)
|
||||
for user in result.data["users"]:
|
||||
self.assertIn("bench", user["email"].lower())
|
||||
|
||||
def test_summary(self):
|
||||
"""Print summary of all benchmarks."""
|
||||
print("\n\n" + "=" * 80)
|
||||
print("BENCHMARK SUMMARY")
|
||||
print("=" * 80)
|
||||
print(f"Iterations per benchmark: {self.ITERATIONS}")
|
||||
print(f"Warmup iterations: {self.WARMUP}")
|
||||
print("\nKey findings:")
|
||||
print("- 'Executor' simulates WebSocket RPC (direct function call)")
|
||||
print("- 'HTTP' measures full request/response cycle")
|
||||
print("- HTTP overhead includes: JSON parsing, CSRF, view dispatch")
|
||||
print("- For I/O-bound operations, protocol overhead is negligible")
|
||||
print("=" * 80)
|
||||
|
||||
# Verify bench_simple still produces correct output after all benchmarks
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_simple", {"a": 7, "b": 8})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 15)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Throughput Benchmark
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ThroughputBenchmark(TransactionTestCase):
|
||||
"""
|
||||
Measure requests per second (throughput) for server functions.
|
||||
|
||||
Tests both sequential and concurrent scenarios.
|
||||
"""
|
||||
|
||||
DURATION_SECONDS = 2 # How long to run each throughput test
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
setup_benchmark_functions()
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
self._create_test_users()
|
||||
|
||||
def _create_test_users(self):
|
||||
"""Create test users for benchmarks."""
|
||||
users = []
|
||||
for i in range(100):
|
||||
users.append(User(
|
||||
email=f"bench{i}@example.com",
|
||||
is_active=i % 10 != 0,
|
||||
is_staff=i < 5,
|
||||
))
|
||||
User.objects.bulk_create(users, ignore_conflicts=True)
|
||||
self.test_user = User.objects.first()
|
||||
|
||||
def _make_request(self, body: dict) -> HttpRequest:
|
||||
"""Create a POST request with JSON body."""
|
||||
request = self.factory.post(
|
||||
"/api/djarea/call/",
|
||||
data=json.dumps(body),
|
||||
content_type="application/json",
|
||||
)
|
||||
request.user = AnonymousUser()
|
||||
request._dont_enforce_csrf_checks = True
|
||||
return request
|
||||
|
||||
def _measure_throughput_executor(self, fn_name: str, args: dict) -> float:
|
||||
"""Measure requests/second using direct executor calls."""
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
execute_function(request, fn_name, args)
|
||||
|
||||
# Measure
|
||||
count = 0
|
||||
start = time.perf_counter()
|
||||
deadline = start + self.DURATION_SECONDS
|
||||
|
||||
while time.perf_counter() < deadline:
|
||||
execute_function(request, fn_name, args)
|
||||
count += 1
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
return count / elapsed
|
||||
|
||||
def _measure_throughput_http(self, fn_name: str, args: dict) -> float:
|
||||
"""Measure requests/second using HTTP view calls."""
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
|
||||
# Measure
|
||||
count = 0
|
||||
start = time.perf_counter()
|
||||
deadline = start + self.DURATION_SECONDS
|
||||
|
||||
while time.perf_counter() < deadline:
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
count += 1
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
return count / elapsed
|
||||
|
||||
def _print_throughput(self, label: str, executor_rps: float, http_rps: float):
|
||||
"""Print throughput results."""
|
||||
print(f"\n{label}:")
|
||||
print(f" Executor (WebSocket): {executor_rps:,.0f} req/s")
|
||||
print(f" HTTP: {http_rps:,.0f} req/s")
|
||||
print(f" Ratio: {executor_rps/http_rps:.1f}x")
|
||||
|
||||
def test_throughput_simple(self):
|
||||
"""Throughput: Simple computation (no I/O)."""
|
||||
print("\n\n### THROUGHPUT: Simple Computation ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_simple", {"a": 1, "b": 2})
|
||||
http_rps = self._measure_throughput_http("bench_simple", {"a": 1, "b": 2})
|
||||
|
||||
self._print_throughput("Simple (no I/O)", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_simple", "args": {"a": 1, "b": 2}})
|
||||
result = execute_function(request, "bench_simple", {"a": 1, "b": 2})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 3)
|
||||
|
||||
def test_throughput_single_query(self):
|
||||
"""Throughput: Single ORM query."""
|
||||
print("\n\n### THROUGHPUT: Single ORM Query ###")
|
||||
|
||||
args = {"user_id": self.test_user.id if self.test_user else 1}
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_get_user", args)
|
||||
http_rps = self._measure_throughput_http("bench_get_user", args)
|
||||
|
||||
self._print_throughput("Single Query", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_get_user", "args": args})
|
||||
result = execute_function(request, "bench_get_user", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["id"], self.test_user.id)
|
||||
|
||||
def test_throughput_list_query(self):
|
||||
"""Throughput: List query."""
|
||||
print("\n\n### THROUGHPUT: List Query (10 users) ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_list_users", {"limit": 10})
|
||||
http_rps = self._measure_throughput_http("bench_list_users", {"limit": 10})
|
||||
|
||||
self._print_throughput("List Query", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_list_users", "args": {"limit": 10}})
|
||||
result = execute_function(request, "bench_list_users", {"limit": 10})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 10)
|
||||
|
||||
def test_throughput_aggregation(self):
|
||||
"""Throughput: Aggregation queries."""
|
||||
print("\n\n### THROUGHPUT: Aggregation ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_user_stats", {})
|
||||
http_rps = self._measure_throughput_http("bench_user_stats", {})
|
||||
|
||||
self._print_throughput("Aggregation", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_user_stats", "args": {}})
|
||||
result = execute_function(request, "bench_user_stats", {})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertGreaterEqual(result.data["total_users"], 0)
|
||||
|
||||
def test_throughput_summary(self):
|
||||
"""Print throughput summary."""
|
||||
print("\n\n" + "=" * 80)
|
||||
print("THROUGHPUT SUMMARY")
|
||||
print("=" * 80)
|
||||
print(f"Test duration: {self.DURATION_SECONDS}s per scenario")
|
||||
print("\nNotes:")
|
||||
print("- These are single-threaded sequential measurements")
|
||||
print("- Real throughput scales with worker processes (gunicorn -w N)")
|
||||
print("- Database queries are the bottleneck, not protocol overhead")
|
||||
print("- Async workers (uvicorn) can handle more concurrent connections")
|
||||
print("=" * 80)
|
||||
|
||||
# Verify bench_simple still produces correct output after all throughput tests
|
||||
request = self._make_request({"fn": "bench_simple", "args": {"a": 10, "b": 20}})
|
||||
result = execute_function(request, "bench_simple", {"a": 10, "b": 20})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 30)
|
||||
1170
django/src/djarea/tests/test_channels.py
Normal file
1170
django/src/djarea/tests/test_channels.py
Normal file
File diff suppressed because it is too large
Load Diff
1081
django/src/djarea/tests/test_core.py
Normal file
1081
django/src/djarea/tests/test_core.py
Normal file
File diff suppressed because it is too large
Load Diff
1224
django/src/djarea/tests/test_pentest.py
Normal file
1224
django/src/djarea/tests/test_pentest.py
Normal file
File diff suppressed because it is too large
Load Diff
1095
django/src/djarea/tests/test_security.py
Normal file
1095
django/src/djarea/tests/test_security.py
Normal file
File diff suppressed because it is too large
Load Diff
40
django/src/djarea/urls.py
Normal file
40
django/src/djarea/urls.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Djarea URL Configuration
|
||||
|
||||
Single integration point for all djarea HTTP endpoints:
|
||||
- GET /session/ - Initialize session and get CSRF token (for SSR)
|
||||
- POST /call/ - Server function calls (HTTP transport)
|
||||
|
||||
Security:
|
||||
- Schema export is NOT exposed over HTTP to prevent API enumeration
|
||||
- Use the management command instead: python manage.py export_djarea_schema
|
||||
"""
|
||||
|
||||
from django.http import JsonResponse
|
||||
from django.middleware.csrf import get_token
|
||||
from django.urls import path
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
|
||||
from .client.executor import function_call_view
|
||||
|
||||
app_name = "djarea"
|
||||
|
||||
|
||||
@ensure_csrf_cookie
|
||||
def session_init_view(request):
|
||||
"""
|
||||
Initialize a Django session and return the CSRF token.
|
||||
|
||||
Used by SSR to establish a session before making authenticated requests.
|
||||
The @ensure_csrf_cookie decorator ensures the csrftoken cookie is set.
|
||||
|
||||
Returns:
|
||||
{ "csrfToken": "..." }
|
||||
"""
|
||||
return JsonResponse({"csrfToken": get_token(request)})
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path("session/", session_init_view, name="session-init"),
|
||||
path("call/", function_call_view, name="function-call"),
|
||||
]
|
||||
0
django/tests/__init__.py
Normal file
0
django/tests/__init__.py
Normal file
40
django/tests/models.py
Normal file
40
django/tests/models.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from django.contrib.auth.models import AbstractBaseUser, BaseUserManager, PermissionsMixin
|
||||
from django.db import models
|
||||
|
||||
|
||||
class EmailUserManager(BaseUserManager):
|
||||
"""Custom user manager using email as the unique identifier."""
|
||||
|
||||
def create_user(self, email, password=None, **extra_fields):
|
||||
if not email:
|
||||
raise ValueError("Email is required")
|
||||
email = self.normalize_email(email)
|
||||
user = self.model(email=email, **extra_fields)
|
||||
user.set_password(password)
|
||||
user.save(using=self._db)
|
||||
return user
|
||||
|
||||
def create_superuser(self, email, password=None, **extra_fields):
|
||||
extra_fields.setdefault("is_staff", True)
|
||||
extra_fields.setdefault("is_superuser", True)
|
||||
return self.create_user(email, password, **extra_fields)
|
||||
|
||||
|
||||
class EmailUser(AbstractBaseUser, PermissionsMixin):
|
||||
"""Minimal user model with email as USERNAME_FIELD.
|
||||
|
||||
Matches the calling convention used in djarea's test suite:
|
||||
User.objects.create_user(email="...", password="...", is_staff=True)
|
||||
"""
|
||||
|
||||
email = models.EmailField(unique=True)
|
||||
is_staff = models.BooleanField(default=False)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
objects = EmailUserManager()
|
||||
|
||||
USERNAME_FIELD = "email"
|
||||
REQUIRED_FIELDS = []
|
||||
|
||||
class Meta:
|
||||
app_label = "tests"
|
||||
46
django/tests/settings.py
Normal file
46
django/tests/settings.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Django settings for running djarea's test suite standalone.
|
||||
|
||||
Usage:
|
||||
cd django/
|
||||
pip install -e ".[dev]"
|
||||
pytest
|
||||
"""
|
||||
|
||||
SECRET_KEY = "test-secret-key-for-standalone-tests-only"
|
||||
|
||||
DEBUG = True
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.sqlite3",
|
||||
"NAME": ":memory:",
|
||||
}
|
||||
}
|
||||
|
||||
INSTALLED_APPS = [
|
||||
"django.contrib.auth",
|
||||
"django.contrib.contenttypes",
|
||||
"django.contrib.sessions",
|
||||
"djarea",
|
||||
"tests",
|
||||
]
|
||||
|
||||
AUTH_USER_MODEL = "tests.EmailUser"
|
||||
|
||||
ROOT_URLCONF = "tests.urls"
|
||||
|
||||
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||
|
||||
# JWT settings for test_auth.py (can be overridden per-class with @override_settings)
|
||||
JWT_PRIVATE_KEY = "test-secret-key-for-testing-only"
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Session engine (for test_auth.py SessionStore usage)
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.db"
|
||||
|
||||
MIDDLEWARE = [
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"django.middleware.csrf.CsrfViewMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
]
|
||||
5
django/tests/urls.py
Normal file
5
django/tests/urls.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from django.urls import include, path
|
||||
|
||||
urlpatterns = [
|
||||
path("api/djarea/", include("djarea.urls")),
|
||||
]
|
||||
17
docker-compose.test.yml
Normal file
17
docker-compose.test.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
|
||||
django:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.test
|
||||
ports:
|
||||
- "8000:8000"
|
||||
depends_on:
|
||||
- redis
|
||||
environment:
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- DJANGO_SETTINGS_MODULE=testapp.settings
|
||||
186
e2e/djarea.spec.ts
Normal file
186
e2e/djarea.spec.ts
Normal file
@@ -0,0 +1,186 @@
|
||||
/**
|
||||
* Djarea E2E Integration Tests
|
||||
*
|
||||
* Real Chromium → Real React app (generated hooks) → Real Django backend
|
||||
*
|
||||
* Every test uses the generated Djarea API, not raw call() or fetch().
|
||||
*/
|
||||
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
const BASE = process.env.HARNESS_URL || 'http://localhost:5174'
|
||||
|
||||
async function fixture(page: any, name: string) {
|
||||
await page.goto(`${BASE}#${name}`)
|
||||
await page.waitForSelector('[data-testid="result"], [data-testid="error-type"]', { timeout: 10000 })
|
||||
}
|
||||
|
||||
async function getResult(page: any): Promise<any> {
|
||||
const el = page.locator('[data-testid="result"]')
|
||||
if (await el.count() > 0) return JSON.parse(await el.textContent())
|
||||
return null
|
||||
}
|
||||
|
||||
async function getError(page: any) {
|
||||
const typeEl = page.locator('[data-testid="error-type"]')
|
||||
if (await typeEl.count() === 0) return null
|
||||
return {
|
||||
type: await typeEl.textContent(),
|
||||
code: await page.locator('[data-testid="error-code"]').textContent(),
|
||||
message: await page.locator('[data-testid="error-message"]').textContent(),
|
||||
}
|
||||
}
|
||||
|
||||
// ─── useEcho, useAdd, useMultiply ───────────────────────────────────────────
|
||||
|
||||
test.describe('generated function hooks', () => {
|
||||
test('useEcho returns echoed text', async ({ page }) => {
|
||||
await fixture(page, 'echo')
|
||||
const result = await getResult(page)
|
||||
expect(result.message).toContain('e2e-test')
|
||||
})
|
||||
|
||||
test('useAdd returns correct sum', async ({ page }) => {
|
||||
await fixture(page, 'add')
|
||||
const result = await getResult(page)
|
||||
expect(result.result).toBe(42)
|
||||
})
|
||||
|
||||
test('useMultiply (class-based ServerFunction) returns product', async ({ page }) => {
|
||||
await fixture(page, 'multiply')
|
||||
const result = await getResult(page)
|
||||
expect(result.product).toBe(42)
|
||||
})
|
||||
|
||||
test('usePermissionCheckFn succeeds with correct secret', async ({ page }) => {
|
||||
await fixture(page, 'permission-success')
|
||||
const result = await getResult(page)
|
||||
expect(result.message).toBe('access granted')
|
||||
})
|
||||
})
|
||||
|
||||
// ─── Error handling ─────────────────────────────────────────────────────────
|
||||
|
||||
test.describe('error codes from generated hooks', () => {
|
||||
test('non-existent function → DjangoError NOT_FOUND', async ({ page }) => {
|
||||
await fixture(page, 'not-found')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(error!.code).toBe('NOT_FOUND')
|
||||
})
|
||||
|
||||
test('wrong input types → DjangoError VALIDATION_ERROR', async ({ page }) => {
|
||||
await fixture(page, 'validation-error')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(error!.code).toBe('VALIDATION_ERROR')
|
||||
})
|
||||
|
||||
test('useWhoami anonymous → auth error', async ({ page }) => {
|
||||
await fixture(page, 'auth-required')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(['UNAUTHORIZED', 'FORBIDDEN']).toContain(error!.code)
|
||||
})
|
||||
|
||||
test('useStaffOnly anonymous → UNAUTHORIZED', async ({ page }) => {
|
||||
await fixture(page, 'staff-only')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(['UNAUTHORIZED', 'FORBIDDEN']).toContain(error!.code)
|
||||
})
|
||||
|
||||
test('useSuperuserOnly anonymous → UNAUTHORIZED', async ({ page }) => {
|
||||
await fixture(page, 'superuser-only')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(['UNAUTHORIZED', 'FORBIDDEN']).toContain(error!.code)
|
||||
})
|
||||
|
||||
test('useVerifiedOnly anonymous → FORBIDDEN', async ({ page }) => {
|
||||
await fixture(page, 'verified-only')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(['UNAUTHORIZED', 'FORBIDDEN']).toContain(error!.code)
|
||||
})
|
||||
|
||||
test('useNotImplementedFn → NOT_IMPLEMENTED', async ({ page }) => {
|
||||
await fixture(page, 'not-implemented')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(error!.code).toBe('NOT_IMPLEMENTED')
|
||||
})
|
||||
|
||||
test('useBuggyFn → INTERNAL_ERROR', async ({ page }) => {
|
||||
await fixture(page, 'internal-error')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(error!.code).toBe('INTERNAL_ERROR')
|
||||
})
|
||||
|
||||
test('usePermissionCheckFn wrong secret → FORBIDDEN', async ({ page }) => {
|
||||
await fixture(page, 'permission-error')
|
||||
const error = await getError(page)
|
||||
expect(error!.type).toBe('DjangoError')
|
||||
expect(error!.code).toBe('FORBIDDEN')
|
||||
})
|
||||
})
|
||||
|
||||
// ─── Context hooks ──────────────────────────────────────────────────────────
|
||||
|
||||
test.describe('generated context hooks', () => {
|
||||
test('useCurrentUser returns anonymous data', async ({ page }) => {
|
||||
await page.goto(`${BASE}#context-current-user`)
|
||||
// Context loads async, wait for result
|
||||
await page.waitForSelector('[data-testid="result"]', { timeout: 10000 })
|
||||
const result = await getResult(page)
|
||||
expect(result.authenticated).toBe(false)
|
||||
expect(result.email).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
// ─── Form hooks ─────────────────────────────────────────────────────────────
|
||||
|
||||
test.describe('generated form hooks', () => {
|
||||
test('useLoginForm loads schema with field definitions', async ({ page }) => {
|
||||
await fixture(page, 'form-login-schema')
|
||||
const result = await getResult(page)
|
||||
expect(result.fields).toBeDefined()
|
||||
expect(result.fields.login).toBeDefined()
|
||||
expect(result.fields.password).toBeDefined()
|
||||
})
|
||||
|
||||
test('useContactForm loads schema with DjareaFormMeta', async ({ page }) => {
|
||||
await fixture(page, 'form-contact-schema')
|
||||
const result = await getResult(page)
|
||||
expect(result.title).toBe('Contact Us')
|
||||
expect(result.subtitle).toBe("We'd love to hear from you")
|
||||
expect(result.submit_label).toBe('Send Message')
|
||||
expect(result.meta.live_validation).toBe(true)
|
||||
})
|
||||
|
||||
test('useContactForm submit returns on_submit_success data', async ({ page }) => {
|
||||
await fixture(page, 'form-contact-submit')
|
||||
const result = await getResult(page)
|
||||
expect(result.success).toBe(true)
|
||||
expect(result.data.received).toBe(true)
|
||||
expect(result.data.from).toBe('test@example.com')
|
||||
})
|
||||
})
|
||||
|
||||
// ─── Channel hooks ──────────────────────────────────────────────────────────
|
||||
|
||||
test.describe('generated channel hooks', () => {
|
||||
test('useChatChannel receives echoed message', async ({ page }) => {
|
||||
await page.goto(`${BASE}#channel-chat`)
|
||||
await page.waitForFunction(
|
||||
() => {
|
||||
const el = document.querySelector('[data-testid="channel-message-count"]')
|
||||
return el && parseInt(el.textContent || '0') > 0
|
||||
},
|
||||
{ timeout: 15000 }
|
||||
)
|
||||
const msg = JSON.parse(await page.locator('[data-testid="channel-last-message"]').textContent())
|
||||
expect(msg.text).toBe('hello from e2e')
|
||||
})
|
||||
})
|
||||
22
e2e/harness/django.config.mjs
Normal file
22
e2e/harness/django.config.mjs
Normal file
@@ -0,0 +1,22 @@
|
||||
import path from 'path'
|
||||
import { fileURLToPath } from 'url'
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url))
|
||||
const root = path.resolve(__dirname, '../..')
|
||||
|
||||
export default {
|
||||
projectId: 'e2e-harness',
|
||||
|
||||
source: {
|
||||
django: {
|
||||
managePath: path.join(root, 'example/manage.py'),
|
||||
command: [path.join(root, 'django/.venv/bin/python')],
|
||||
env: {
|
||||
PYTHONPATH: `${path.join(root, 'django/src')}:${path.join(root, 'example')}`,
|
||||
DJANGO_SETTINGS_MODULE: 'testapp.settings',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
output: 'src/api/generated.ts',
|
||||
}
|
||||
5
e2e/harness/index.html
Normal file
5
e2e/harness/index.html
Normal file
@@ -0,0 +1,5 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><meta charset="UTF-8" /><title>Djarea E2E Harness</title></head>
|
||||
<body><div id="root"></div><script type="module" src="/src/main.tsx"></script></body>
|
||||
</html>
|
||||
22
e2e/harness/package.json
Normal file
22
e2e/harness/package.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "djarea-e2e-harness",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"build": "vite build",
|
||||
"dev": "vite --port 5174"
|
||||
},
|
||||
"dependencies": {
|
||||
"@rythazhur/djarea": "file:../../react",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"zod": "^4.3.6"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@vitejs/plugin-react": "^4.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
"vite": "^6.0.0"
|
||||
}
|
||||
}
|
||||
90
e2e/harness/src/api/index.ts
Normal file
90
e2e/harness/src/api/index.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Djarea API - Consolidated Exports
|
||||
*
|
||||
* Import everything from here:
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* import {
|
||||
* DjangoContext,
|
||||
* useUser,
|
||||
* useEcho,
|
||||
* useChatChannel,
|
||||
* DjangoError,
|
||||
* } from '@/api'
|
||||
* ```
|
||||
*/
|
||||
|
||||
// AUTO-GENERATED by djarea - do not edit manually
|
||||
// Regenerate with: npm run schemas
|
||||
|
||||
// =============================================================================
|
||||
// Djarea Provider & Hooks
|
||||
// =============================================================================
|
||||
|
||||
export {
|
||||
getDjangoHydration,
|
||||
type DjangoHydration,
|
||||
} from './generated.django.server'
|
||||
|
||||
export {
|
||||
// Provider
|
||||
DjangoContext,
|
||||
type DjangoContextProps,
|
||||
|
||||
// Context hooks
|
||||
useCurrentUser,
|
||||
useGreet,
|
||||
|
||||
// Refresh hooks
|
||||
useDjangoRefresh,
|
||||
|
||||
// Function hooks
|
||||
useEcho,
|
||||
useAdd,
|
||||
useWhoami,
|
||||
useHttpOnlyEcho,
|
||||
useStaffOnly,
|
||||
useSuperuserOnly,
|
||||
useVerifiedOnly,
|
||||
useMultiply,
|
||||
useNotImplementedFn,
|
||||
useBuggyFn,
|
||||
usePermissionCheckFn,
|
||||
useWsWhoami,
|
||||
useJwtObtain,
|
||||
useJwtRefresh,
|
||||
|
||||
// Re-exports from djarea library
|
||||
useDjarea,
|
||||
useDjareaStatus,
|
||||
usePush,
|
||||
DjangoError,
|
||||
type ConnectionStatus,
|
||||
type PushMessage,
|
||||
type PushListener,
|
||||
} from './generated.django'
|
||||
|
||||
// =============================================================================
|
||||
// Channel Hooks
|
||||
// =============================================================================
|
||||
|
||||
export {
|
||||
useChatChannel,
|
||||
useNotificationsChannel,
|
||||
usePresenceChannel,
|
||||
usePrivateChannel,
|
||||
} from './generated.channels.hooks'
|
||||
|
||||
// =============================================================================
|
||||
// Channel Types
|
||||
// =============================================================================
|
||||
|
||||
export type {
|
||||
ChatParams,
|
||||
ChatReactMessage,
|
||||
ChatDjangoMessage,
|
||||
NotificationsDjangoMessage,
|
||||
PresenceDjangoMessage,
|
||||
PrivateDjangoMessage,
|
||||
} from './generated.channels'
|
||||
264
e2e/harness/src/fixtures.tsx
Normal file
264
e2e/harness/src/fixtures.tsx
Normal file
@@ -0,0 +1,264 @@
|
||||
/**
|
||||
* E2E Test Fixtures
|
||||
*
|
||||
* Each fixture uses GENERATED Djarea hooks (not raw call()).
|
||||
* Playwright reads the DOM to verify behavior.
|
||||
*
|
||||
* URL hash selects the fixture: #echo, #add, #multiply, etc.
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useRef } from 'react'
|
||||
|
||||
// Generated typed hooks — the actual Djarea API
|
||||
import {
|
||||
DjangoContext,
|
||||
useEcho,
|
||||
useAdd,
|
||||
useMultiply,
|
||||
useWhoami,
|
||||
useStaffOnly,
|
||||
useSuperuserOnly,
|
||||
useVerifiedOnly,
|
||||
useNotImplementedFn,
|
||||
useBuggyFn,
|
||||
usePermissionCheckFn,
|
||||
useCurrentUser,
|
||||
DjangoError,
|
||||
useDjarea,
|
||||
} from './api/generated.django'
|
||||
import { useContactForm, useLoginForm } from './api/generated.forms'
|
||||
import { useChatChannel } from './api/generated.channels.hooks'
|
||||
|
||||
// ─── Fixture router ─────────────────────────────────────────────────────────
|
||||
|
||||
export function Fixtures() {
|
||||
const [hash, setHash] = useState(window.location.hash.slice(1))
|
||||
|
||||
useEffect(() => {
|
||||
const onHash = () => setHash(window.location.hash.slice(1))
|
||||
window.addEventListener('hashchange', onHash)
|
||||
return () => window.removeEventListener('hashchange', onHash)
|
||||
}, [])
|
||||
|
||||
switch (hash) {
|
||||
case 'echo': return <Echo />
|
||||
case 'add': return <Add />
|
||||
case 'multiply': return <Multiply />
|
||||
case 'not-found': return <NotFound />
|
||||
case 'validation-error': return <ValidationError />
|
||||
case 'auth-required': return <AuthRequired />
|
||||
case 'staff-only': return <StaffOnly />
|
||||
case 'superuser-only': return <SuperuserOnly />
|
||||
case 'verified-only': return <VerifiedOnly />
|
||||
case 'not-implemented': return <NotImplemented />
|
||||
case 'internal-error': return <InternalError />
|
||||
case 'permission-error': return <PermissionError_ />
|
||||
case 'permission-success': return <PermissionSuccess />
|
||||
case 'context-current-user': return <ContextCurrentUser />
|
||||
case 'form-login-schema': return <FormLoginSchema />
|
||||
case 'form-contact-schema': return <FormContactSchema />
|
||||
case 'form-contact-submit': return <FormContactSubmit />
|
||||
case 'channel-chat': return <ChannelChatFixture />
|
||||
default: return <div data-testid="ready">Harness ready. Set #hash.</div>
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Result helper ──────────────────────────────────────────────────────────
|
||||
|
||||
function Result({ data, error }: { data?: unknown; error?: unknown }) {
|
||||
return (
|
||||
<>
|
||||
{data !== undefined && (
|
||||
<pre data-testid="result">{JSON.stringify(data)}</pre>
|
||||
)}
|
||||
{error !== undefined && error !== null && (
|
||||
<>
|
||||
<div data-testid="error-type">
|
||||
{error instanceof DjangoError ? 'DjangoError' : 'Error'}
|
||||
</div>
|
||||
<div data-testid="error-code">
|
||||
{error instanceof DjangoError ? error.code : ''}
|
||||
</div>
|
||||
<pre data-testid="error-message">
|
||||
{error instanceof Error ? error.message : String(error)}
|
||||
</pre>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Hook runner: calls a generated hook and renders result ─────────────────
|
||||
|
||||
function useRun<T>(hook: () => (input?: any) => Promise<T>, input?: any) {
|
||||
const call = hook()
|
||||
const [data, setData] = useState<T>()
|
||||
const [error, setError] = useState<unknown>()
|
||||
|
||||
useEffect(() => {
|
||||
call(input).then(setData).catch(setError)
|
||||
}, []) // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return { data, error }
|
||||
}
|
||||
|
||||
// ─── Server function fixtures ───────────────────────────────────────────────
|
||||
|
||||
function Echo() {
|
||||
const { data, error } = useRun(useEcho, { text: 'e2e-test' })
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function Add() {
|
||||
const { data, error } = useRun(useAdd, { a: 17, b: 25 })
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function Multiply() {
|
||||
const { data, error } = useRun(useMultiply, { x: 6, y: 7 })
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function NotFound() {
|
||||
// Deliberately call a non-existent function via the raw primitive
|
||||
const { call } = useDjarea()
|
||||
const [error, setError] = useState<unknown>()
|
||||
useEffect(() => { call('does_not_exist').catch(setError) }, [call])
|
||||
return <Result error={error} />
|
||||
}
|
||||
|
||||
function ValidationError() {
|
||||
// Send wrong types to add (strings instead of numbers)
|
||||
const call = useAdd()
|
||||
const [error, setError] = useState<unknown>()
|
||||
useEffect(() => { (call as any)({ a: 'not_a_number', b: 'also_not' }).catch(setError) }, [call])
|
||||
return <Result error={error} />
|
||||
}
|
||||
|
||||
function AuthRequired() {
|
||||
const { data, error } = useRun(useWhoami)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function StaffOnly() {
|
||||
const { data, error } = useRun(useStaffOnly)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function SuperuserOnly() {
|
||||
const { data, error } = useRun(useSuperuserOnly)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function VerifiedOnly() {
|
||||
const { data, error } = useRun(useVerifiedOnly)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function NotImplemented() {
|
||||
const { data, error } = useRun(useNotImplementedFn)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function InternalError() {
|
||||
const { data, error } = useRun(useBuggyFn)
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function PermissionError_() {
|
||||
const { data, error } = useRun(usePermissionCheckFn, { secret: 'wrong' })
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
function PermissionSuccess() {
|
||||
const { data, error } = useRun(usePermissionCheckFn, { secret: 'open-sesame' })
|
||||
return <Result data={data} error={error} />
|
||||
}
|
||||
|
||||
// ─── Context fixtures ───────────────────────────────────────────────────────
|
||||
|
||||
function ContextCurrentUser() {
|
||||
// useCurrentUser throws if context not loaded yet, so catch that
|
||||
try {
|
||||
const user = useCurrentUser()
|
||||
return <pre data-testid="result">{JSON.stringify(user)}</pre>
|
||||
} catch {
|
||||
return <div>loading context...</div>
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Form fixtures (using generated form hooks) ─────────────────────────────
|
||||
|
||||
function FormLoginSchema() {
|
||||
const form = useLoginForm()
|
||||
if (form.loading) return <div>loading...</div>
|
||||
return <pre data-testid="result">{JSON.stringify(form.schema)}</pre>
|
||||
}
|
||||
|
||||
function FormContactSchema() {
|
||||
const form = useContactForm()
|
||||
if (form.loading) return <div>loading...</div>
|
||||
return <pre data-testid="result">{JSON.stringify(form.schema)}</pre>
|
||||
}
|
||||
|
||||
function FormContactSubmit() {
|
||||
const form = useContactForm()
|
||||
const [result, setResult] = useState<unknown>()
|
||||
const [submitted, setSubmitted] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
if (!form.loading && !submitted) {
|
||||
form.set('name', 'Test User')
|
||||
form.set('email', 'test@example.com')
|
||||
form.set('message', 'Hello from e2e')
|
||||
setSubmitted(true)
|
||||
}
|
||||
}, [form.loading, submitted, form])
|
||||
|
||||
useEffect(() => {
|
||||
if (submitted && !result) {
|
||||
form.submit().then(setResult)
|
||||
}
|
||||
}, [submitted, result, form])
|
||||
|
||||
if (!result) return <div>loading...</div>
|
||||
return <pre data-testid="result">{JSON.stringify(result)}</pre>
|
||||
}
|
||||
|
||||
// ─── Channel fixtures ───────────────────────────────────────────────────────
|
||||
|
||||
function ChannelChatFixture() {
|
||||
// DjangoContext already includes ChannelProvider
|
||||
return <ChannelChat />
|
||||
}
|
||||
|
||||
function ChannelChat() {
|
||||
const chat = useChatChannel({ room: 'e2e' })
|
||||
const [sent, setSent] = useState(false)
|
||||
const prevStatus = useRef(chat.status)
|
||||
|
||||
useEffect(() => {
|
||||
// Send once when status transitions to 'connected' (meaning subscribed)
|
||||
// The hook maps subscribed → 'connected', but we need to wait for it
|
||||
// to go through 'connecting' first (before subscription is confirmed)
|
||||
const wasConnecting = prevStatus.current === 'connecting'
|
||||
prevStatus.current = chat.status
|
||||
|
||||
if (wasConnecting && chat.status === 'connected' && !sent) {
|
||||
chat.send({ text: 'hello from e2e' })
|
||||
setSent(true)
|
||||
}
|
||||
}, [chat.status, sent, chat])
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div data-testid="channel-status">{chat.status}</div>
|
||||
<div data-testid="channel-message-count">{chat.messages.length}</div>
|
||||
{chat.messages.length > 0 && (
|
||||
<pre data-testid="channel-last-message">
|
||||
{JSON.stringify(chat.messages[chat.messages.length - 1])}
|
||||
</pre>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
13
e2e/harness/src/main.tsx
Normal file
13
e2e/harness/src/main.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import { DjangoContext } from './api/generated.django'
|
||||
import { Fixtures } from './fixtures'
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<DjangoContext baseUrl="/api/djarea">
|
||||
<Fixtures />
|
||||
</DjangoContext>
|
||||
)
|
||||
}
|
||||
|
||||
createRoot(document.getElementById('root')!).render(<App />)
|
||||
11
e2e/harness/tsconfig.json
Normal file
11
e2e/harness/tsconfig.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"jsx": "react-jsx",
|
||||
"skipLibCheck": true
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
30
e2e/harness/vite.config.ts
Normal file
30
e2e/harness/vite.config.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import path from 'path'
|
||||
|
||||
const reactPkg = path.resolve(__dirname, '../../react/src')
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'djarea/channels': path.join(reactPkg, 'channels/index.ts'),
|
||||
'djarea/client/react': path.join(reactPkg, 'client/react.ts'),
|
||||
'djarea/client/nextjs': path.join(reactPkg, 'client/nextjs.tsx'),
|
||||
'djarea/client': path.join(reactPkg, 'client/index.ts'),
|
||||
'djarea/jwt': path.join(reactPkg, 'jwt/index.ts'),
|
||||
'djarea/allauth/nextjs': path.join(reactPkg, 'allauth/nextjs.tsx'),
|
||||
'djarea/allauth': path.join(reactPkg, 'allauth/index.ts'),
|
||||
'djarea': path.join(reactPkg, 'index.ts'),
|
||||
'@rythazhur/djarea/channels': path.join(reactPkg, 'channels/index.ts'),
|
||||
'@rythazhur/djarea/jwt': path.join(reactPkg, 'jwt/index.ts'),
|
||||
'@rythazhur/djarea': path.join(reactPkg, 'index.ts'),
|
||||
},
|
||||
},
|
||||
server: {
|
||||
proxy: {
|
||||
'/api': 'http://localhost:8000',
|
||||
'/ws': { target: 'ws://localhost:8000', ws: true },
|
||||
},
|
||||
},
|
||||
})
|
||||
10
example/manage.py
Normal file
10
example/manage.py
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "testapp.settings")
|
||||
|
||||
from django.core.management import execute_from_command_line
|
||||
|
||||
execute_from_command_line(sys.argv)
|
||||
0
example/testapp/__init__.py
Normal file
0
example/testapp/__init__.py
Normal file
9
example/testapp/apps.py
Normal file
9
example/testapp/apps.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class TestAppConfig(AppConfig):
|
||||
name = "testapp"
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
|
||||
def ready(self):
|
||||
import testapp.djarea_clients # noqa: F401
|
||||
14
example/testapp/asgi.py
Normal file
14
example/testapp/asgi.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
|
||||
import django
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "testapp.settings")
|
||||
django.setup()
|
||||
|
||||
from django.core.asgi import get_asgi_application
|
||||
from djarea import wrap_asgi
|
||||
|
||||
# Register server functions and channels before building the ASGI app
|
||||
import testapp.djarea_clients # noqa: F401
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
393
example/testapp/djarea_clients.py
Normal file
393
example/testapp/djarea_clients.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Server functions and channels for integration tests.
|
||||
|
||||
Registers everything the React integration test suite expects:
|
||||
- echo, add (HTTP + WebSocket RPC)
|
||||
- login, signup, add_email forms
|
||||
- chat, notifications, presence channels
|
||||
"""
|
||||
|
||||
from django import forms
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client import ServerFunction, client
|
||||
from djarea.channels import ReactChannel
|
||||
from djarea.setup.registry import register, register_form, register_as
|
||||
from djarea.channels import register as register_channel
|
||||
from djarea.forms import DjareaFormMixin, DjareaFormMeta
|
||||
from djarea.jwt import jwt_obtain, jwt_refresh
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Server Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class EchoOutput(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def echo(request: HttpRequest, text: str) -> EchoOutput:
|
||||
return EchoOutput(message=text)
|
||||
|
||||
|
||||
register(echo, "echo")
|
||||
|
||||
|
||||
class AddOutput(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
@client(websocket=True)
|
||||
def add(request: HttpRequest, a: int, b: int) -> AddOutput:
|
||||
return AddOutput(result=a + b)
|
||||
|
||||
|
||||
register(add, "add")
|
||||
|
||||
|
||||
class WhoamiOutput(BaseModel):
|
||||
user_id: int | None
|
||||
email: str
|
||||
is_staff: bool
|
||||
|
||||
|
||||
@client(auth=True)
|
||||
def whoami(request: HttpRequest) -> WhoamiOutput:
|
||||
return WhoamiOutput(
|
||||
user_id=getattr(request.user, 'id', None),
|
||||
email=getattr(request.user, 'email', ''),
|
||||
is_staff=getattr(request.user, 'is_staff', False),
|
||||
)
|
||||
|
||||
|
||||
register(whoami, "whoami")
|
||||
|
||||
|
||||
@client
|
||||
def http_only_echo(request: HttpRequest, text: str) -> EchoOutput:
|
||||
return EchoOutput(message=text)
|
||||
|
||||
|
||||
register(http_only_echo, "http_only_echo")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Forms
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class LoginForm(forms.Form):
|
||||
login = forms.CharField(max_length=150, label="Login")
|
||||
password = forms.CharField(widget=forms.PasswordInput, label="Password")
|
||||
|
||||
|
||||
def handle_login(request, form):
|
||||
"""Login form submit handler."""
|
||||
from django.contrib.auth import authenticate, login
|
||||
|
||||
user = authenticate(
|
||||
request,
|
||||
username=form.cleaned_data["login"],
|
||||
password=form.cleaned_data["password"],
|
||||
)
|
||||
if user is not None:
|
||||
login(request, user)
|
||||
return {"success": True}
|
||||
form.add_error(None, "Invalid login credentials.")
|
||||
return None # Signals validation failure
|
||||
|
||||
|
||||
register_form(LoginForm, "login", submit_handler=handle_login)
|
||||
|
||||
|
||||
class SignupForm(forms.Form):
|
||||
email = forms.EmailField(label="Email")
|
||||
password1 = forms.CharField(widget=forms.PasswordInput, label="Password")
|
||||
|
||||
|
||||
def handle_signup(request, form):
|
||||
"""Signup form submit handler."""
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
try:
|
||||
user = User.objects.create_user(
|
||||
email=form.cleaned_data["email"],
|
||||
password=form.cleaned_data["password1"],
|
||||
)
|
||||
return {"success": True, "data": {"user_id": user.pk}}
|
||||
except Exception as e:
|
||||
form.add_error(None, str(e))
|
||||
return None
|
||||
|
||||
|
||||
register_form(SignupForm, "signup", submit_handler=handle_signup)
|
||||
|
||||
|
||||
class AddEmailForm(forms.Form):
|
||||
email = forms.EmailField(label="Email address")
|
||||
|
||||
|
||||
register_form(AddEmailForm, "add_email")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Channels
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ChatChannel(ReactChannel):
|
||||
class Params(BaseModel):
|
||||
room: str
|
||||
|
||||
class ReactMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class DjangoMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
def authorize(self, params=None):
|
||||
return True
|
||||
|
||||
def group(self, params=None):
|
||||
room = params.room if params else "default"
|
||||
return f"chat_{room}"
|
||||
|
||||
def receive(self, params, msg):
|
||||
return self.DjangoMessage(text=msg.text)
|
||||
|
||||
|
||||
register_channel(ChatChannel, "chat")
|
||||
|
||||
|
||||
class NotificationsChannel(ReactChannel):
|
||||
class DjangoMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
def authorize(self, params=None):
|
||||
return True
|
||||
|
||||
def group(self, params=None):
|
||||
return "notifications_global"
|
||||
|
||||
|
||||
register_channel(NotificationsChannel, "notifications")
|
||||
|
||||
|
||||
class PresenceChannel(ReactChannel):
|
||||
class DjangoMessage(BaseModel):
|
||||
value: int
|
||||
|
||||
def authorize(self, params=None):
|
||||
return True
|
||||
|
||||
def group(self, params=None):
|
||||
return "presence_global"
|
||||
|
||||
|
||||
register_channel(PresenceChannel, "presence")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auth Variations
|
||||
# =============================================================================
|
||||
|
||||
|
||||
# --- Staff-only ---
|
||||
@client(auth='staff')
|
||||
def staff_only(request: HttpRequest) -> EchoOutput:
|
||||
return EchoOutput(message=f"staff:{request.user.email}")
|
||||
|
||||
register(staff_only, "staff_only")
|
||||
|
||||
|
||||
# --- Superuser-only ---
|
||||
@client(auth='superuser')
|
||||
def superuser_only(request: HttpRequest) -> EchoOutput:
|
||||
return EchoOutput(message=f"superuser:{request.user.email}")
|
||||
|
||||
register(superuser_only, "superuser_only")
|
||||
|
||||
|
||||
# --- Callable auth ---
|
||||
def check_verified_email(request):
|
||||
if not request.user.is_authenticated:
|
||||
return False
|
||||
return getattr(request.user, 'email', '').endswith('@verified.com')
|
||||
|
||||
@client(auth=check_verified_email)
|
||||
def verified_only(request: HttpRequest) -> EchoOutput:
|
||||
return EchoOutput(message="verified")
|
||||
|
||||
register(verified_only, "verified_only")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Context Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CurrentUserOutput(BaseModel):
|
||||
authenticated: bool
|
||||
email: str
|
||||
is_staff: bool
|
||||
|
||||
@client(context='global')
|
||||
def current_user(request: HttpRequest) -> CurrentUserOutput:
|
||||
if request.user.is_authenticated:
|
||||
return CurrentUserOutput(
|
||||
authenticated=True,
|
||||
email=request.user.email,
|
||||
is_staff=request.user.is_staff,
|
||||
)
|
||||
return CurrentUserOutput(authenticated=False, email="", is_staff=False)
|
||||
|
||||
register(current_user, "current_user")
|
||||
|
||||
|
||||
class GreetOutput(BaseModel):
|
||||
greeting: str
|
||||
|
||||
@client(context='local')
|
||||
def greet(request: HttpRequest, name: str) -> GreetOutput:
|
||||
return GreetOutput(greeting=f"Hello, {name}!")
|
||||
|
||||
register(greet, "greet")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Class-based ServerFunction
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MultiplyInput(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
class MultiplyOutput(BaseModel):
|
||||
product: int
|
||||
|
||||
@register_as("multiply")
|
||||
class Multiply(ServerFunction):
|
||||
Input = MultiplyInput
|
||||
Output = MultiplyOutput
|
||||
|
||||
def call(self, input: MultiplyInput) -> MultiplyOutput:
|
||||
return MultiplyOutput(product=input.x * input.y)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Error-producing Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@client
|
||||
def not_implemented_fn(request: HttpRequest) -> EchoOutput:
|
||||
raise NotImplementedError("This feature is not yet implemented")
|
||||
|
||||
register(not_implemented_fn, "not_implemented_fn")
|
||||
|
||||
|
||||
@client
|
||||
def buggy_fn(request: HttpRequest) -> EchoOutput:
|
||||
raise RuntimeError("Unexpected internal failure")
|
||||
|
||||
register(buggy_fn, "buggy_fn")
|
||||
|
||||
|
||||
@client
|
||||
def permission_check_fn(request: HttpRequest, secret: str) -> EchoOutput:
|
||||
if secret != "open-sesame":
|
||||
raise PermissionError("Wrong secret")
|
||||
return EchoOutput(message="access granted")
|
||||
|
||||
register(permission_check_fn, "permission_check_fn")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket + Auth Function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@client(websocket=True, auth=True)
|
||||
def ws_whoami(request: HttpRequest) -> WhoamiOutput:
|
||||
return WhoamiOutput(
|
||||
user_id=getattr(request.user, 'id', None),
|
||||
email=getattr(request.user, 'email', ''),
|
||||
is_staff=getattr(request.user, 'is_staff', False),
|
||||
)
|
||||
|
||||
register(ws_whoami, "ws_whoami")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DjareaFormMixin Forms
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ContactForm(DjareaFormMixin, forms.Form):
|
||||
djarea = DjareaFormMeta(
|
||||
name="contact",
|
||||
title="Contact Us",
|
||||
subtitle="We'd love to hear from you",
|
||||
submit_label="Send Message",
|
||||
live_validation=True,
|
||||
live_form_errors=False,
|
||||
)
|
||||
|
||||
name = forms.CharField(max_length=100, label="Your Name")
|
||||
email = forms.EmailField(label="Email Address")
|
||||
message = forms.CharField(widget=forms.Textarea, label="Message")
|
||||
|
||||
def on_submit_success(self, request):
|
||||
return {"received": True, "from": self.cleaned_data["email"]}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Formset-enabled Form
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ItemForm(DjareaFormMixin, forms.Form):
|
||||
djarea = DjareaFormMeta(
|
||||
name="item",
|
||||
title="Items",
|
||||
submit_label="Save Items",
|
||||
enable_formset=True,
|
||||
)
|
||||
|
||||
label = forms.CharField(max_length=50, label="Item Label")
|
||||
quantity = forms.IntegerField(min_value=1, label="Quantity")
|
||||
|
||||
def on_submit_success(self, request):
|
||||
return {"label": self.cleaned_data["label"], "qty": self.cleaned_data["quantity"]}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auth-gated Channel
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PrivateChannel(ReactChannel):
|
||||
class DjangoMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
def authorize(self, params=None):
|
||||
return getattr(self.user, 'is_authenticated', False)
|
||||
|
||||
def group(self, params=None):
|
||||
return "private_global"
|
||||
|
||||
register_channel(PrivateChannel, "private")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JWT Function Registration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
register(jwt_obtain, "jwt_obtain")
|
||||
register(jwt_refresh, "jwt_refresh")
|
||||
29
example/testapp/models.py
Normal file
29
example/testapp/models.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from django.contrib.auth.models import AbstractBaseUser, BaseUserManager, PermissionsMixin
|
||||
from django.db import models
|
||||
|
||||
|
||||
class EmailUserManager(BaseUserManager):
|
||||
def create_user(self, email, password=None, **extra_fields):
|
||||
if not email:
|
||||
raise ValueError("Email is required")
|
||||
email = self.normalize_email(email)
|
||||
user = self.model(email=email, **extra_fields)
|
||||
user.set_password(password)
|
||||
user.save(using=self._db)
|
||||
return user
|
||||
|
||||
def create_superuser(self, email, password=None, **extra_fields):
|
||||
extra_fields.setdefault("is_staff", True)
|
||||
extra_fields.setdefault("is_superuser", True)
|
||||
return self.create_user(email, password, **extra_fields)
|
||||
|
||||
|
||||
class EmailUser(AbstractBaseUser, PermissionsMixin):
|
||||
email = models.EmailField(unique=True)
|
||||
is_staff = models.BooleanField(default=False)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
||||
objects = EmailUserManager()
|
||||
|
||||
USERNAME_FIELD = "email"
|
||||
REQUIRED_FIELDS = []
|
||||
76
example/testapp/settings.py
Normal file
76
example/testapp/settings.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Django settings for the integration test backend.
|
||||
|
||||
Provides:
|
||||
- HTTP server functions (echo, add)
|
||||
- WebSocket channels (chat, notifications, presence)
|
||||
- JWT authentication
|
||||
- Form integration (login, signup, add_email)
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
SECRET_KEY = "integration-test-secret-key-not-for-production"
|
||||
|
||||
DEBUG = True
|
||||
|
||||
ALLOWED_HOSTS = ["*"]
|
||||
|
||||
INSTALLED_APPS = [
|
||||
"django.contrib.auth",
|
||||
"django.contrib.contenttypes",
|
||||
"django.contrib.sessions",
|
||||
"djarea",
|
||||
"testapp",
|
||||
]
|
||||
|
||||
AUTH_USER_MODEL = "testapp.EmailUser"
|
||||
|
||||
MIDDLEWARE = [
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"django.middleware.csrf.CsrfViewMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
]
|
||||
|
||||
ROOT_URLCONF = "testapp.urls"
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.sqlite3",
|
||||
"NAME": os.path.join(os.path.dirname(__file__), "..", "db.sqlite3"),
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||
|
||||
ASGI_APPLICATION = "testapp.asgi.application"
|
||||
|
||||
# JWT
|
||||
JWT_PRIVATE_KEY = "integration-test-jwt-secret-key"
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Channel layers — Redis when available, in-memory fallback for local dev
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "")
|
||||
if REDIS_URL:
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
||||
"CONFIG": {"hosts": [REDIS_URL]},
|
||||
},
|
||||
}
|
||||
else:
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels.layers.InMemoryChannelLayer",
|
||||
},
|
||||
}
|
||||
|
||||
# Session
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.db"
|
||||
|
||||
# CORS — allow React dev server
|
||||
CSRF_TRUSTED_ORIGINS = [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:5174",
|
||||
]
|
||||
5
example/testapp/urls.py
Normal file
5
example/testapp/urls.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from django.urls import include, path
|
||||
|
||||
urlpatterns = [
|
||||
path("api/djarea/", include("djarea.urls")),
|
||||
]
|
||||
18
package.json
Normal file
18
package.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"name": "djarea",
|
||||
"version": "1.0.0",
|
||||
"description": "Django + React server functions framework.",
|
||||
"main": "index.js",
|
||||
"directories": {
|
||||
"example": "example"
|
||||
},
|
||||
"scripts": {},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"type": "commonjs",
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.58.2",
|
||||
"@types/node": "^25.5.0"
|
||||
}
|
||||
}
|
||||
14
playwright.config.ts
Normal file
14
playwright.config.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { defineConfig } from '@playwright/test'
|
||||
|
||||
export default defineConfig({
|
||||
testDir: './e2e',
|
||||
timeout: 15000,
|
||||
retries: 0,
|
||||
reporter: 'list',
|
||||
use: {
|
||||
baseURL: 'http://localhost:8000',
|
||||
},
|
||||
projects: [
|
||||
{ name: 'chromium', use: { browserName: 'chromium' } },
|
||||
],
|
||||
})
|
||||
2
react/.gitignore
vendored
Normal file
2
react/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
node_modules/
|
||||
dist/
|
||||
26
react/README.md
Normal file
26
react/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# djarea (TypeScript)
|
||||
|
||||
TypeScript client library for the Djarea framework. See the [monorepo root](../README.md) for full documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# From git
|
||||
npm install djarea@git+https://git.impactsoundworks.com/isw/djarea.git#workspace=react
|
||||
|
||||
# Local development
|
||||
npm install djarea@file:../../web/djarea/react
|
||||
```
|
||||
|
||||
## Exports
|
||||
|
||||
| Import | Purpose |
|
||||
|--------|---------|
|
||||
| `djarea` | Core: DjareaProvider, hooks, forms, errors |
|
||||
| `djarea/client` | HTTP clients, SSR helpers, `ensureDjangoSession()` |
|
||||
| `djarea/client/react` | React-specific client hooks |
|
||||
| `djarea/client/nextjs` | Next.js integration |
|
||||
| `djarea/channels` | WebSocket channels |
|
||||
| `djarea/jwt` | JWT token management |
|
||||
| `djarea/allauth` | Allauth UI components |
|
||||
| `djarea/allauth/nextjs` | Next.js allauth context |
|
||||
85
react/package.json
Normal file
85
react/package.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"name": "@rythazhur/djarea",
|
||||
"version": "0.1.1",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./client": {
|
||||
"types": "./dist/client/index.d.ts",
|
||||
"import": "./dist/client/index.js"
|
||||
},
|
||||
"./client/react": {
|
||||
"types": "./dist/client/react.d.ts",
|
||||
"import": "./dist/client/react.js"
|
||||
},
|
||||
"./client/nextjs": {
|
||||
"types": "./dist/client/nextjs.d.ts",
|
||||
"import": "./dist/client/nextjs.js"
|
||||
},
|
||||
"./channels": {
|
||||
"types": "./dist/channels/index.d.ts",
|
||||
"import": "./dist/channels/index.js"
|
||||
},
|
||||
"./jwt": {
|
||||
"types": "./dist/jwt/index.d.ts",
|
||||
"import": "./dist/jwt/index.js"
|
||||
},
|
||||
"./allauth": {
|
||||
"types": "./dist/allauth/index.d.ts",
|
||||
"import": "./dist/allauth/index.js"
|
||||
},
|
||||
"./allauth/nextjs": {
|
||||
"types": "./dist/allauth/nextjs.d.ts",
|
||||
"import": "./dist/allauth/nextjs.js"
|
||||
}
|
||||
},
|
||||
"bin": {
|
||||
"djarea-generate": "./dist/generator/cli.mjs"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.build.json && node -e \"require('fs').cpSync('src/generator','dist/generator',{recursive:true})\"",
|
||||
"dev": "tsc -p tsconfig.build.json --watch",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:integration": "RUN_INTEGRATION_TESTS=true NEXT_PUBLIC_HOST_URL=http://localhost:8000 vitest run",
|
||||
"prepublishOnly": "npm run build"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=18",
|
||||
"react-dom": ">=18"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"next": {
|
||||
"optional": true
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"ws": "^8.19.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.58.2",
|
||||
"@simplewebauthn/browser": "^13.2.2",
|
||||
"@testing-library/jest-dom": "^6.0.0",
|
||||
"@testing-library/react": "^16.0.0",
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@types/ws": "^8.5.0",
|
||||
"jsdom": "^25.0.0",
|
||||
"next": "^16.1.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
"vitest": "^3.0.0",
|
||||
"zod": "^4.3.6"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"chokidar": "^4.0.0",
|
||||
"minimatch": "^10.0.0",
|
||||
"openapi-typescript": "^7.0.0"
|
||||
}
|
||||
}
|
||||
314
react/src/__tests__/context.test.tsx
Normal file
314
react/src/__tests__/context.test.tsx
Normal file
@@ -0,0 +1,314 @@
|
||||
/**
|
||||
* Tests for Django Server React Context
|
||||
*
|
||||
* Unit tests run without backend.
|
||||
* Integration tests require: docker-compose up
|
||||
*
|
||||
* Run integration tests with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { render, screen, waitFor, act } from '@testing-library/react'
|
||||
import {
|
||||
DjareaProvider,
|
||||
useDjarea,
|
||||
useDjareaStatus,
|
||||
useDjareaCall,
|
||||
// Legacy aliases for backwards compatibility tests
|
||||
DjangoContext,
|
||||
useDjango,
|
||||
useDjangoStatus,
|
||||
useServerFunction,
|
||||
} from '../context'
|
||||
import { DjangoError } from '../errors'
|
||||
import { describeIntegration, BACKEND_URL } from '../testing'
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests (no backend required)
|
||||
// ============================================================================
|
||||
|
||||
describe('Djarea Context (unit)', () => {
|
||||
describe('useDjarea hook', () => {
|
||||
it('should throw when used outside provider', () => {
|
||||
function TestComponent() {
|
||||
useDjarea()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'error').mockImplementation()
|
||||
|
||||
expect(() => render(<TestComponent />)).toThrow(
|
||||
'useDjarea must be used within a DjareaProvider'
|
||||
)
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should return context value inside provider', () => {
|
||||
let contextValue: any = null
|
||||
|
||||
function TestComponent() {
|
||||
contextValue = useDjarea()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjareaProvider autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjareaProvider>
|
||||
)
|
||||
|
||||
expect(contextValue).not.toBeNull()
|
||||
expect(contextValue!.status).toBe('disconnected')
|
||||
})
|
||||
})
|
||||
|
||||
describe('useDjareaStatus hook', () => {
|
||||
it('should return disconnected when autoConnect is false', () => {
|
||||
function TestComponent() {
|
||||
const status = useDjareaStatus()
|
||||
return <div data-testid="status">{status}</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjareaProvider autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjareaProvider>
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('disconnected')
|
||||
})
|
||||
})
|
||||
|
||||
describe('hydration', () => {
|
||||
it('should initialize context store from hydration data', () => {
|
||||
let contextValue: any = null
|
||||
|
||||
function TestComponent() {
|
||||
contextValue = useDjarea()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
const hydration = {
|
||||
auth_status: { is_authenticated: false },
|
||||
user: null,
|
||||
}
|
||||
|
||||
render(
|
||||
<DjareaProvider hydration={hydration} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjareaProvider>
|
||||
)
|
||||
|
||||
expect(contextValue.getContext('auth_status')).toEqual({ is_authenticated: false })
|
||||
expect(contextValue.getContext('user')).toEqual(null)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests (require running backend)
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Djarea Context (integration)', () => {
|
||||
describe('server function calls via HTTP', () => {
|
||||
it('should call echo function and get response', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call, status } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
// Use HTTP fallback (status will be disconnected without WebSocket)
|
||||
call<{ text: string }, { message: string }>('echo', { text: 'context test' })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div data-testid="status">{status}</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('message')
|
||||
expect(result.message).toContain('context test')
|
||||
})
|
||||
|
||||
it('should call add function with correct result', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call<{ a: number; b: number }, { result: number }>('add', { a: 10, b: 20 })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toEqual({ result: 30 })
|
||||
})
|
||||
|
||||
it('should throw DjangoError for validation errors', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
// Call without required field
|
||||
call('echo', {})
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useServerFunction hook', () => {
|
||||
it('should create typed function that calls backend', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
interface EchoInput { text: string }
|
||||
interface EchoOutput { message: string }
|
||||
|
||||
function TestComponent() {
|
||||
const echo = useServerFunction<EchoInput, EchoOutput>('echo')
|
||||
|
||||
React.useEffect(() => {
|
||||
echo({ text: 'typed function test' })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [echo])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('message')
|
||||
expect(result.message).toContain('typed function test')
|
||||
})
|
||||
})
|
||||
|
||||
describe('form functions', () => {
|
||||
it('should call login.schema and get form fields', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call('login.schema', { data: {} })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('fields')
|
||||
expect(result).toHaveProperty('meta')
|
||||
// Login form should have login and password fields
|
||||
expect(result.fields).toHaveProperty('login')
|
||||
expect(result.fields).toHaveProperty('password')
|
||||
})
|
||||
|
||||
it('should call login.validate and get validation result', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call('login.validate', {
|
||||
data: { login: 'test@example.com', password: 'testpass' }
|
||||
})
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Should return validation result (may have errors for invalid creds, that's ok)
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('valid')
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
214
react/src/__tests__/errors.test.ts
Normal file
214
react/src/__tests__/errors.test.ts
Normal file
@@ -0,0 +1,214 @@
|
||||
/**
|
||||
* Tests for Django Server Error
|
||||
*/
|
||||
|
||||
import { DjangoError, type FunctionErrorResponse } from '../errors'
|
||||
|
||||
describe('DjangoError', () => {
|
||||
it('should create error with message and code', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Function not found',
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.message).toBe('Function not found')
|
||||
expect(error.code).toBe('NOT_FOUND')
|
||||
expect(error.name).toBe('DjangoError')
|
||||
})
|
||||
|
||||
it('should preserve details', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required', 'Too short'],
|
||||
email: ['Invalid format'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.details).toBeDefined()
|
||||
expect(error.details?.fields?.name).toEqual(['Required', 'Too short'])
|
||||
})
|
||||
|
||||
it('should preserve original response', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'INTERNAL_ERROR',
|
||||
message: 'Server error',
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.response).toBe(response)
|
||||
})
|
||||
|
||||
describe('isValidationError', () => {
|
||||
it('should return true for validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
})
|
||||
|
||||
expect(error.isValidationError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isValidationError()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isAuthError', () => {
|
||||
it('should return true for unauthorized', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'UNAUTHORIZED',
|
||||
message: 'Not authenticated',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for forbidden', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'FORBIDDEN',
|
||||
message: 'Access denied',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isNotFound', () => {
|
||||
it('should return true for not found errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isNotFound()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
})
|
||||
|
||||
expect(error.isNotFound()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFieldErrors', () => {
|
||||
it('should return field errors for validation error', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required'],
|
||||
email: ['Invalid'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const errors = error.getFieldErrors()
|
||||
|
||||
expect(errors).toEqual({
|
||||
name: ['Required'],
|
||||
email: ['Invalid'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should return null for non-validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.getFieldErrors()).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null if no fields in details', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
details: {},
|
||||
})
|
||||
|
||||
expect(error.getFieldErrors()).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFieldError', () => {
|
||||
it('should return first error for a field', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required', 'Too short'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(error.getFieldError('name')).toBe('Required')
|
||||
})
|
||||
|
||||
it('should return null for non-existent field', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(error.getFieldError('email')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null for non-validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.getFieldError('name')).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
362
react/src/__tests__/forms.test.tsx
Normal file
362
react/src/__tests__/forms.test.tsx
Normal file
@@ -0,0 +1,362 @@
|
||||
/**
|
||||
* Tests for Django Forms
|
||||
*
|
||||
* Integration tests call the REAL backend - no mocks.
|
||||
* Backend must be running: docker-compose up
|
||||
*
|
||||
* Run integration tests with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { z } from 'zod'
|
||||
|
||||
import {
|
||||
useDjangoFormCore,
|
||||
type FormCoreConfig,
|
||||
} from '../forms'
|
||||
import { DjangoContext } from '../context'
|
||||
import { describeIntegration, BACKEND_URL } from '../testing'
|
||||
|
||||
// ============================================================================
|
||||
// Test Setup
|
||||
// ============================================================================
|
||||
|
||||
// Helper to render hook with provider
|
||||
function renderFormHook<TData extends Record<string, unknown>>(
|
||||
config: FormCoreConfig<TData>
|
||||
) {
|
||||
return renderHook(() => useDjangoFormCore<TData>(config), {
|
||||
wrapper: ({ children }) => (
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/djarea`} autoConnect={false}>
|
||||
{children}
|
||||
</DjangoContext>
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests - Real Backend Calls
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('useDjangoFormCore (integration)', () => {
|
||||
describe('Schema loading from real backend', () => {
|
||||
it('loads login form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
expect(result.current.loading).toBe(true)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.name).toBe('login')
|
||||
// Login form should have login and password fields
|
||||
expect(result.current.schema?.fields).toHaveProperty('login')
|
||||
expect(result.current.schema?.fields).toHaveProperty('password')
|
||||
})
|
||||
|
||||
it('loads signup form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'signup',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.name).toBe('signup')
|
||||
// Signup form should have email and password fields
|
||||
expect(result.current.schema?.fields).toHaveProperty('email')
|
||||
expect(result.current.schema?.fields).toHaveProperty('password1')
|
||||
})
|
||||
|
||||
it('loads add_email form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'add_email',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.fields).toHaveProperty('email')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Form data management', () => {
|
||||
it('sets and gets form data', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'test@example.com')
|
||||
result.current.set('password', 'testpassword123')
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('test@example.com')
|
||||
expect(result.current.data.password).toBe('testpassword123')
|
||||
})
|
||||
|
||||
it('tracks touched fields', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.touchedFields.size).toBe(0)
|
||||
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.touchedFields.has('login')).toBe(true)
|
||||
expect(result.current.touchedFields.has('password')).toBe(false)
|
||||
})
|
||||
|
||||
it('resets form state', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'changed@example.com')
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('changed@example.com')
|
||||
expect(result.current.touchedFields.has('login')).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('')
|
||||
expect(result.current.touchedFields.size).toBe(0)
|
||||
expect(result.current.errors).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Zod validation with real schema', () => {
|
||||
// Define Zod schema matching login form
|
||||
const LoginZodSchema = z.object({
|
||||
login: z.string().min(1, 'Login is required').email('Invalid email'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
})
|
||||
|
||||
type LoginData = z.infer<typeof LoginZodSchema>
|
||||
|
||||
it('validates with Zod schema on touch', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid value
|
||||
act(() => {
|
||||
result.current.set('login', 'not-an-email')
|
||||
})
|
||||
|
||||
// Touch triggers validation
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
// Zod validation should show email format error
|
||||
expect(result.current.errors?.fields.login).toBeDefined()
|
||||
expect(result.current.errors?.fields.login?.[0]?.message).toBe('Invalid email')
|
||||
expect(result.current.errors?.fields.login?.[0]?.source).toBe('zod')
|
||||
})
|
||||
|
||||
it('clears errors when value becomes valid', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid value and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'bad')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.errors?.fields.login).toBeDefined()
|
||||
|
||||
// Set valid value and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'valid@example.com')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.errors?.fields.login).toBeUndefined()
|
||||
})
|
||||
|
||||
it('tracks hasErrors correctly', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.hasErrors).toBe(false)
|
||||
|
||||
// Set invalid and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'bad')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.hasErrors).toBe(true)
|
||||
|
||||
// Set valid and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'good@example.com')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.hasErrors).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Form submission', () => {
|
||||
it('submits login form and handles validation errors', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid credentials
|
||||
act(() => {
|
||||
result.current.set('login', 'nonexistent@example.com')
|
||||
result.current.set('password', 'wrongpassword')
|
||||
})
|
||||
|
||||
// Submit should fail with validation error
|
||||
let submitResult: any
|
||||
await act(async () => {
|
||||
submitResult = await result.current.submit()
|
||||
})
|
||||
|
||||
// Submit should return error (invalid credentials)
|
||||
// The exact error depends on backend behavior
|
||||
expect(submitResult).toBeDefined()
|
||||
})
|
||||
|
||||
it('submits signup form with missing required fields', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'signup',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Submit with empty fields should return validation errors
|
||||
let submitResult: any
|
||||
await act(async () => {
|
||||
submitResult = await result.current.submit()
|
||||
})
|
||||
|
||||
// Should have validation errors for required fields
|
||||
expect(submitResult).toBeDefined()
|
||||
expect(submitResult.success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error source tagging', () => {
|
||||
const LoginZodSchema = z.object({
|
||||
login: z.string().min(1, 'Login is required').email('Invalid email'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
})
|
||||
|
||||
type LoginData = z.infer<typeof LoginZodSchema>
|
||||
|
||||
it('tags Zod errors with source: zod', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'invalid')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
const errors = result.current.getFieldErrors('login')
|
||||
expect(errors.length).toBeGreaterThan(0)
|
||||
expect(errors[0].source).toBe('zod')
|
||||
})
|
||||
|
||||
it('filters errors by source', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'invalid')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
// Should have Zod errors
|
||||
const zodErrors = result.current.getFieldErrors('login', { source: 'zod' })
|
||||
expect(zodErrors.length).toBeGreaterThan(0)
|
||||
|
||||
// Should have no server errors yet
|
||||
const serverErrors = result.current.getFieldErrors('login', { source: 'server' })
|
||||
expect(serverErrors.length).toBe(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user