Full test infrastructure, code audit fixes, and real E2E integration tests
Test infrastructure: - Django standalone test runner (pytest-django, test settings, EmailUser model) - React unit tests via Vitest with jsdom, jest compat layer, path aliases - Playwright E2E tests using generated hooks in a real Chromium browser - Docker Compose test backend (Django + Redis) for integration testing - Desktop integration test app (PyWebView + Django + uvicorn) - Makefile with test/test-django/test-react/test-integration targets Library bugs found and fixed: - hasJWT truthiness: undefined !== null was true, skipping session init - process.env crash: CSR client referenced process.env in non-Node browsers - baseUrl not forwarded: DjareaProvider didn't pass baseUrl to CSR client - Relative URL handling: new URL() failed with relative base paths - call() race condition: HTTP requests fired before CSRF cookie was set - Session init await: added sessionRef promise so call() waits for session - path_prefix on schema export: both export commands failed with URL reverse - NullBooleanField removed: referenced field doesn't exist in Django 5.0+ - lru_cache on JWT settings: get_settings() now cached as intended - Channel message routing: broadcasts now include channel name and params - httpFunctionCall: fixed URL and request body format Generator fixes: - Removed 1,100 lines of REST/OpenAPI client generation (not part of Djarea) - Generator now works for djarea-only projects without django-ninja REST APIs - Generated DjangoContext now includes ChannelProvider when channels exist - Fixed env var passthrough for schema export commands - Deduplicated fetch logic into single runDjangoCommand helper Test quality: - Fixed 33 tautological Django tests with real assertions - Found hidden bug: benchmark functions were never registered - Found hidden bug: unicode lookalike test used plain ASCII - Deleted worthless React unit tests (duplicates, shape checks, Zod-tests-Zod) - Replaced jsdom integration tests with Playwright browser tests Example apps: - example/: Integration test backend with 33 server functions, 5 forms, 4 channels covering auth variations, contexts, class-based ServerFunction, error codes, DjareaFormMixin, formsets, and JWT - desktop/: PyWebView desktop app with file system access, SQLite CRUD, system introspection, and 39 real HTTP integration tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
django/src/djarea/tests/__init__.py
Normal file
0
django/src/djarea/tests/__init__.py
Normal file
531
django/src/djarea/tests/test_auth.py
Normal file
531
django/src/djarea/tests/test_auth.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
Authentication Tests for Djarea Server Functions
|
||||
|
||||
Tests all combinations of:
|
||||
- Transport: HTTP vs WebSocket RPC
|
||||
- JWT: Present (valid), Present (invalid), Absent
|
||||
- Session: Present (valid), Absent
|
||||
|
||||
Expected behavior:
|
||||
- JWT present (valid) → JWTUser (no DB query)
|
||||
- JWT present (invalid) → Reject (401), do NOT fall back to session
|
||||
- JWT absent + Session present → Session auth (DB query)
|
||||
- JWT absent + Session absent → AnonymousUser
|
||||
"""
|
||||
|
||||
from django.test import TestCase, RequestFactory, override_settings
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
|
||||
from djarea.jwt.tokens import (
|
||||
create_token_pair,
|
||||
decode_token,
|
||||
JWTUser,
|
||||
)
|
||||
from djarea.client.executor import (
|
||||
_try_jwt_auth,
|
||||
execute_function,
|
||||
FunctionError,
|
||||
FunctionResult,
|
||||
ErrorCode,
|
||||
)
|
||||
from djarea.client import client
|
||||
from djarea.setup.registry import clear_registry, register
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Output Models (proper Pydantic models, not raw dicts)
|
||||
# =============================================================================
|
||||
|
||||
class WhoamiOutput(BaseModel):
|
||||
is_authenticated: bool
|
||||
user_id: int | None
|
||||
user_type: str
|
||||
is_staff: bool
|
||||
|
||||
|
||||
class OkOutput(BaseModel):
|
||||
ok: bool
|
||||
|
||||
|
||||
class UserTypeOutput(BaseModel):
|
||||
user_type: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Server Functions - defined as plain functions, registered in setUp
|
||||
# =============================================================================
|
||||
|
||||
def _whoami_fn(request) -> WhoamiOutput:
|
||||
"""Returns info about the authenticated user."""
|
||||
user = request.user
|
||||
return WhoamiOutput(
|
||||
is_authenticated=user.is_authenticated,
|
||||
user_id=getattr(user, "id", None),
|
||||
user_type=type(user).__name__,
|
||||
is_staff=getattr(user, "is_staff", False),
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class HTTPAuthTests(TestCase):
|
||||
"""Test HTTP transport authentication combinations."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(
|
||||
email="test@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
# Create a session
|
||||
self.session = SessionStore()
|
||||
self.session.create()
|
||||
self.session_key = self.session.session_key
|
||||
|
||||
# Register test function
|
||||
@client
|
||||
def whoami(request) -> WhoamiOutput:
|
||||
user = request.user
|
||||
return WhoamiOutput(
|
||||
is_authenticated=user.is_authenticated,
|
||||
user_id=getattr(user, "id", None),
|
||||
user_type=type(user).__name__,
|
||||
is_staff=getattr(user, "is_staff", False),
|
||||
)
|
||||
register(whoami, "whoami")
|
||||
|
||||
def tearDown(self):
|
||||
self.user.delete()
|
||||
self.session.delete()
|
||||
clear_registry()
|
||||
|
||||
def test_jwt_valid_no_session(self):
|
||||
"""Valid JWT without session → JWTUser (no DB query)."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = MagicMock(is_authenticated=False) # No session auth
|
||||
|
||||
# Try JWT auth
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsInstance(request.user, JWTUser)
|
||||
self.assertEqual(request.user.id, self.user.pk)
|
||||
self.assertTrue(request.user.is_staff)
|
||||
self.assertTrue(request.user.is_authenticated)
|
||||
|
||||
def test_jwt_valid_with_session(self):
|
||||
"""Valid JWT with session → JWT takes precedence (no DB query)."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = self.user # Session auth already set user
|
||||
|
||||
# JWT should still be processed and take precedence
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsInstance(request.user, JWTUser)
|
||||
|
||||
def test_jwt_invalid_with_session(self):
|
||||
"""Invalid JWT with valid session → Reject (do NOT fall back)."""
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = "Bearer invalid-token-here"
|
||||
request.user = self.user # Session would work
|
||||
|
||||
# JWT auth should fail
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
# User should NOT be changed to session user - that happens elsewhere
|
||||
# The point is _try_jwt_auth returns False, indicating JWT failed
|
||||
|
||||
def test_jwt_expired_with_session(self):
|
||||
"""Expired JWT with valid session → Reject (do NOT fall back)."""
|
||||
# Create token with past expiration by mocking time
|
||||
with patch("djarea.jwt.tokens.time.time", return_value=0):
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
request.user = self.user # Session would work
|
||||
|
||||
# JWT auth should fail (expired)
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_no_jwt_with_session(self):
|
||||
"""No JWT with valid session → Session auth (normal Django flow)."""
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # Session auth set user
|
||||
|
||||
# No JWT auth attempted
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result) # No JWT to process
|
||||
# User remains the session user
|
||||
self.assertEqual(request.user, self.user)
|
||||
|
||||
def test_no_jwt_no_session(self):
|
||||
"""No JWT, no session → AnonymousUser."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = _try_jwt_auth(request)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertIsInstance(request.user, AnonymousUser)
|
||||
|
||||
def test_execute_function_with_jwt(self):
|
||||
"""Execute server function with JWT auth."""
|
||||
tokens = create_token_pair(
|
||||
self.user.pk,
|
||||
self.session_key,
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
|
||||
|
||||
# Simulate what the view does: try JWT auth first
|
||||
_try_jwt_auth(request)
|
||||
|
||||
# Use the whoami function which returns WhoamiOutput (Pydantic model)
|
||||
result = execute_function(request, "whoami", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertTrue(result.data["is_authenticated"])
|
||||
self.assertEqual(result.data["user_type"], "JWTUser")
|
||||
self.assertTrue(result.data["is_staff"])
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class JWTUserTests(TestCase):
|
||||
"""Test JWTUser behavior."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
|
||||
def tearDown(self):
|
||||
clear_registry()
|
||||
|
||||
def test_jwt_user_attributes(self):
|
||||
"""JWTUser has expected attributes."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
payload = TokenPayload(
|
||||
user_id=42,
|
||||
session_key="test-session",
|
||||
token_type="access",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
|
||||
user = JWTUser(payload)
|
||||
|
||||
self.assertEqual(user.id, 42)
|
||||
self.assertEqual(user.pk, 42)
|
||||
self.assertTrue(user.is_staff)
|
||||
self.assertFalse(user.is_superuser)
|
||||
self.assertTrue(user.is_authenticated)
|
||||
self.assertFalse(user.is_anonymous)
|
||||
self.assertTrue(user.is_active)
|
||||
|
||||
def test_jwt_user_string_id(self):
|
||||
"""JWTUser handles string user_id (converted to int)."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
payload = TokenPayload(
|
||||
user_id="42", # String, as stored in JWT
|
||||
session_key="test-session",
|
||||
token_type="access",
|
||||
is_staff=False,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
|
||||
user = JWTUser(payload)
|
||||
|
||||
self.assertEqual(user.id, 42)
|
||||
self.assertIsInstance(user.id, int)
|
||||
|
||||
|
||||
@override_settings(
|
||||
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
|
||||
JWT_ALGORITHM="HS256",
|
||||
)
|
||||
class AuthDecoratorTests(TestCase):
|
||||
"""Test @client(auth=...) decorator."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(
|
||||
email="test@example.com",
|
||||
password="testpass123",
|
||||
is_staff=False,
|
||||
is_superuser=False,
|
||||
)
|
||||
self.staff_user = User.objects.create_user(
|
||||
email="staff@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
self.superuser = User.objects.create_user(
|
||||
email="super@example.com",
|
||||
password="testpass123",
|
||||
is_staff=True,
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
self.user.delete()
|
||||
self.staff_user.delete()
|
||||
self.superuser.delete()
|
||||
clear_registry()
|
||||
|
||||
def test_auth_required_with_anonymous(self):
|
||||
"""@client(auth=True) rejects anonymous users."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
# Register a test function with proper Pydantic model
|
||||
@client(auth=True)
|
||||
def protected_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(protected_fn, "protected_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_function(request, "protected_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
|
||||
|
||||
def test_auth_required_with_authenticated(self):
|
||||
"""@client(auth=True) allows authenticated users."""
|
||||
@client(auth=True)
|
||||
def protected_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(protected_fn2, "protected_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user
|
||||
|
||||
result = execute_function(request, "protected_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["ok"], True)
|
||||
|
||||
def test_auth_staff_with_regular_user(self):
|
||||
"""@client(auth='staff') rejects non-staff users."""
|
||||
@client(auth='staff')
|
||||
def staff_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(staff_fn, "staff_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # Not staff
|
||||
|
||||
result = execute_function(request, "staff_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
|
||||
def test_auth_staff_with_staff_user(self):
|
||||
"""@client(auth='staff') allows staff users."""
|
||||
@client(auth='staff')
|
||||
def staff_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(staff_fn2, "staff_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.staff_user
|
||||
|
||||
result = execute_function(request, "staff_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
|
||||
def test_auth_superuser_with_staff(self):
|
||||
"""@client(auth='superuser') rejects non-superusers."""
|
||||
@client(auth='superuser')
|
||||
def super_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(super_fn, "super_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.staff_user # Staff but not superuser
|
||||
|
||||
result = execute_function(request, "super_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
|
||||
def test_auth_superuser_with_superuser(self):
|
||||
"""@client(auth='superuser') allows superusers."""
|
||||
@client(auth='superuser')
|
||||
def super_fn2(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(super_fn2, "super_fn2")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.superuser
|
||||
|
||||
result = execute_function(request, "super_fn2", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
|
||||
def test_auth_with_jwt_user(self):
|
||||
"""Auth checks work with JWTUser (stateless)."""
|
||||
from djarea.jwt.tokens import TokenPayload
|
||||
|
||||
@client(auth='staff')
|
||||
def jwt_staff_fn(request) -> UserTypeOutput:
|
||||
return UserTypeOutput(user_type=type(request.user).__name__)
|
||||
register(jwt_staff_fn, "jwt_staff_fn")
|
||||
|
||||
# Create JWTUser with is_staff=True
|
||||
payload = TokenPayload(
|
||||
user_id=99,
|
||||
session_key="test",
|
||||
token_type="access",
|
||||
is_staff=True,
|
||||
is_superuser=False,
|
||||
exp=9999999999,
|
||||
iat=0,
|
||||
)
|
||||
jwt_user = JWTUser(payload)
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = jwt_user
|
||||
|
||||
result = execute_function(request, "jwt_staff_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["user_type"], "JWTUser")
|
||||
|
||||
def test_auth_invalid_string_raises(self):
|
||||
"""Invalid auth string raises ValueError at decoration time."""
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
@client(auth='admin') # 'admin' is not valid
|
||||
def bad_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
|
||||
self.assertIn("Invalid auth value 'admin'", str(ctx.exception))
|
||||
self.assertIn("required", str(ctx.exception))
|
||||
|
||||
def test_auth_callable_returns_true(self):
|
||||
"""Callable auth returning True allows access."""
|
||||
@client(auth=lambda r: r.user.email.endswith('@example.com'))
|
||||
def email_check_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(email_check_fn, "email_check_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # email is test@example.com
|
||||
|
||||
result = execute_function(request, "email_check_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertTrue(result.data["ok"])
|
||||
|
||||
def test_auth_callable_returns_false(self):
|
||||
"""Callable auth returning False denies access."""
|
||||
@client(auth=lambda r: r.user.email.endswith('@admin.com'))
|
||||
def admin_email_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(admin_email_fn, "admin_email_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # email is test@example.com, not @admin.com
|
||||
|
||||
result = execute_function(request, "admin_email_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Access denied")
|
||||
|
||||
def test_auth_callable_raises_permission_error(self):
|
||||
"""Callable auth raising PermissionError uses custom message."""
|
||||
def check_premium(request):
|
||||
if not getattr(request.user, 'is_premium', False):
|
||||
raise PermissionError("Premium subscription required")
|
||||
return True
|
||||
|
||||
@client(auth=check_premium)
|
||||
def premium_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(premium_fn, "premium_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = self.user # No is_premium attribute
|
||||
|
||||
result = execute_function(request, "premium_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Premium subscription required")
|
||||
|
||||
def test_auth_callable_with_anonymous_user(self):
|
||||
"""Callable auth can check for anonymous users."""
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
||||
def must_be_authenticated(request):
|
||||
if not request.user.is_authenticated:
|
||||
raise PermissionError("Please log in")
|
||||
return True
|
||||
|
||||
@client(auth=must_be_authenticated)
|
||||
def needs_login_fn(request) -> OkOutput:
|
||||
return OkOutput(ok=True)
|
||||
register(needs_login_fn, "needs_login_fn")
|
||||
|
||||
request = self.factory.post("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_function(request, "needs_login_fn", {})
|
||||
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
|
||||
self.assertEqual(result.message, "Please log in")
|
||||
548
django/src/djarea/tests/test_benchmarks.py
Normal file
548
django/src/djarea/tests/test_benchmarks.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Protocol Benchmark: HTTP vs WebSocket Server Functions
|
||||
|
||||
Compares performance of HTTP POST vs WebSocket RPC for server function calls.
|
||||
Includes realistic scenarios with ORM queries.
|
||||
|
||||
Usage:
|
||||
python manage.py test djarea.tests.test_benchmarks --verbosity=2
|
||||
|
||||
Note:
|
||||
These are not unit tests - they measure performance. Results are printed
|
||||
to stdout and should be run in isolation for accurate measurements.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.db import connection
|
||||
from django.http import HttpRequest
|
||||
from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
from djarea.client.executor import FunctionResult, execute_function, function_call_view
|
||||
from djarea.setup.registry import clear_registry
|
||||
from djarea.client import client
|
||||
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Output Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class SimpleOutput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class UserOutput(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
|
||||
|
||||
class UserListOutput(BaseModel):
|
||||
users: list[dict[str, Any]]
|
||||
count: int
|
||||
|
||||
|
||||
class StatsOutput(BaseModel):
|
||||
total_users: int
|
||||
active_users: int
|
||||
staff_count: int
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def setup_benchmark_functions():
|
||||
"""Register benchmark server functions."""
|
||||
from djarea.setup.registry import register
|
||||
|
||||
clear_registry()
|
||||
|
||||
# 1. Simple computation (no I/O)
|
||||
@client
|
||||
def bench_simple(request: HttpRequest, a: int, b: int) -> SimpleOutput:
|
||||
"""Simple addition - baseline with no I/O."""
|
||||
return SimpleOutput(value=a + b)
|
||||
register(bench_simple, "bench_simple")
|
||||
|
||||
# 2. Single ORM query
|
||||
@client
|
||||
def bench_get_user(request: HttpRequest, user_id: int) -> UserOutput:
|
||||
"""Fetch single user by ID."""
|
||||
user = User.objects.filter(id=user_id).first()
|
||||
if user:
|
||||
return UserOutput(id=user.id, email=user.email)
|
||||
return UserOutput(id=0, email="")
|
||||
register(bench_get_user, "bench_get_user")
|
||||
|
||||
# 3. List query with limit
|
||||
@client
|
||||
def bench_list_users(request: HttpRequest, limit: int) -> UserListOutput:
|
||||
"""Fetch list of users with limit."""
|
||||
users = User.objects.all()[:limit]
|
||||
return UserListOutput(
|
||||
users=[{"id": u.id, "email": u.email} for u in users],
|
||||
count=len(users),
|
||||
)
|
||||
register(bench_list_users, "bench_list_users")
|
||||
|
||||
# 4. Aggregation query
|
||||
@client
|
||||
def bench_user_stats(request: HttpRequest) -> StatsOutput:
|
||||
"""Compute user statistics with multiple queries."""
|
||||
total = User.objects.count()
|
||||
active = User.objects.filter(is_active=True).count()
|
||||
staff = User.objects.filter(is_staff=True).count()
|
||||
return StatsOutput(
|
||||
total_users=total,
|
||||
active_users=active,
|
||||
staff_count=staff,
|
||||
)
|
||||
register(bench_user_stats, "bench_user_stats")
|
||||
|
||||
# 5. Complex query with joins
|
||||
@client
|
||||
def bench_user_search(request: HttpRequest, email_contains: str, limit: int) -> UserListOutput:
|
||||
"""Search users by email pattern."""
|
||||
users = User.objects.filter(
|
||||
email__icontains=email_contains,
|
||||
is_active=True,
|
||||
).select_related()[:limit]
|
||||
return UserListOutput(
|
||||
users=[{"id": u.id, "email": u.email} for u in users],
|
||||
count=len(users),
|
||||
)
|
||||
register(bench_user_search, "bench_user_search")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Benchmark Test Cases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProtocolBenchmark(TransactionTestCase):
|
||||
"""
|
||||
Benchmark comparing HTTP vs WebSocket (simulated) performance.
|
||||
|
||||
Uses TransactionTestCase to ensure database state is realistic.
|
||||
"""
|
||||
|
||||
# Number of iterations for each benchmark
|
||||
ITERATIONS = 100
|
||||
WARMUP = 10
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
setup_benchmark_functions()
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
# Create test users for ORM benchmarks
|
||||
self._create_test_users()
|
||||
|
||||
def _create_test_users(self):
|
||||
"""Create test users for benchmarks."""
|
||||
# Create 100 test users
|
||||
users = []
|
||||
for i in range(100):
|
||||
users.append(User(
|
||||
email=f"bench{i}@example.com",
|
||||
is_active=i % 10 != 0, # 90% active
|
||||
is_staff=i < 5, # 5 staff
|
||||
))
|
||||
User.objects.bulk_create(users, ignore_conflicts=True)
|
||||
self.test_user = User.objects.first()
|
||||
|
||||
def _make_request(self, body: dict | None = None) -> HttpRequest:
|
||||
"""Create a request with optional JSON body."""
|
||||
if body:
|
||||
request = self.factory.post(
|
||||
"/api/djarea/call/",
|
||||
data=json.dumps(body),
|
||||
content_type="application/json",
|
||||
)
|
||||
else:
|
||||
request = self.factory.post("/api/djarea/call/")
|
||||
request.user = AnonymousUser()
|
||||
request._dont_enforce_csrf_checks = True
|
||||
return request
|
||||
|
||||
def _benchmark_executor(self, fn_name: str, args: dict, label: str) -> dict:
|
||||
"""
|
||||
Benchmark direct executor calls (simulates WebSocket RPC).
|
||||
|
||||
Returns timing statistics.
|
||||
"""
|
||||
request = self._make_request()
|
||||
times = []
|
||||
|
||||
# Warmup
|
||||
for _ in range(self.WARMUP):
|
||||
execute_function(request, fn_name, args)
|
||||
|
||||
# Benchmark
|
||||
for _ in range(self.ITERATIONS):
|
||||
start = time.perf_counter()
|
||||
result = execute_function(request, fn_name, args)
|
||||
end = time.perf_counter()
|
||||
times.append((end - start) * 1000) # ms
|
||||
|
||||
return self._compute_stats(times, f"Executor ({label})")
|
||||
|
||||
def _benchmark_http(self, fn_name: str, args: dict, label: str) -> dict:
|
||||
"""
|
||||
Benchmark HTTP view calls.
|
||||
|
||||
Returns timing statistics.
|
||||
"""
|
||||
times = []
|
||||
|
||||
# Warmup
|
||||
for _ in range(self.WARMUP):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
|
||||
# Benchmark
|
||||
for _ in range(self.ITERATIONS):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
start = time.perf_counter()
|
||||
response = function_call_view(request)
|
||||
end = time.perf_counter()
|
||||
times.append((end - start) * 1000) # ms
|
||||
|
||||
return self._compute_stats(times, f"HTTP ({label})")
|
||||
|
||||
def _compute_stats(self, times: list[float], label: str) -> dict:
|
||||
"""Compute statistics from timing data."""
|
||||
return {
|
||||
"label": label,
|
||||
"min": min(times),
|
||||
"max": max(times),
|
||||
"mean": statistics.mean(times),
|
||||
"median": statistics.median(times),
|
||||
"stdev": statistics.stdev(times) if len(times) > 1 else 0,
|
||||
"p95": sorted(times)[int(len(times) * 0.95)],
|
||||
"p99": sorted(times)[int(len(times) * 0.99)],
|
||||
"iterations": len(times),
|
||||
}
|
||||
|
||||
def _print_results(self, results: list[dict]):
|
||||
"""Print benchmark results in a table."""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"{'Benchmark':<40} {'Mean':>8} {'Median':>8} {'P95':>8} {'P99':>8}")
|
||||
print("=" * 80)
|
||||
for r in results:
|
||||
print(f"{r['label']:<40} {r['mean']:>7.3f}ms {r['median']:>7.3f}ms {r['p95']:>7.3f}ms {r['p99']:>7.3f}ms")
|
||||
print("=" * 80)
|
||||
|
||||
def _print_comparison(self, executor_stats: dict, http_stats: dict):
|
||||
"""Print comparison between executor and HTTP."""
|
||||
overhead = ((http_stats["mean"] - executor_stats["mean"]) / executor_stats["mean"]) * 100
|
||||
print(f" HTTP overhead vs Executor: {overhead:+.1f}%")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Benchmark Tests
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def test_benchmark_simple_computation(self):
|
||||
"""Benchmark: Simple computation (no I/O)."""
|
||||
print("\n\n### BENCHMARK: Simple Computation (no I/O) ###")
|
||||
|
||||
args = {"a": 100, "b": 200}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_simple", args, "simple")
|
||||
http_stats = self._benchmark_http("bench_simple", args, "simple")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: 100 + 200 = 300
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_simple", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 300)
|
||||
|
||||
def test_benchmark_single_query(self):
|
||||
"""Benchmark: Single ORM query."""
|
||||
print("\n\n### BENCHMARK: Single ORM Query ###")
|
||||
|
||||
args = {"user_id": self.test_user.id if self.test_user else 1}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_get_user", args, "single query")
|
||||
http_stats = self._benchmark_http("bench_get_user", args, "single query")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: should return the test user's data
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_get_user", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["id"], self.test_user.id)
|
||||
|
||||
def test_benchmark_list_query(self):
|
||||
"""Benchmark: List query with serialization."""
|
||||
print("\n\n### BENCHMARK: List Query (10 users) ###")
|
||||
|
||||
args = {"limit": 10}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_list_users", args, "list 10")
|
||||
http_stats = self._benchmark_http("bench_list_users", args, "list 10")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: should return up to 10 users
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_list_users", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 10)
|
||||
self.assertEqual(len(result.data["users"]), result.data["count"])
|
||||
|
||||
def test_benchmark_aggregation(self):
|
||||
"""Benchmark: Aggregation queries."""
|
||||
print("\n\n### BENCHMARK: Aggregation (3 COUNT queries) ###")
|
||||
|
||||
args = {}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_user_stats", args, "aggregation")
|
||||
http_stats = self._benchmark_http("bench_user_stats", args, "aggregation")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: stats should have non-negative counts
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_user_stats", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertGreaterEqual(result.data["total_users"], 0)
|
||||
self.assertGreaterEqual(result.data["active_users"], 0)
|
||||
self.assertGreaterEqual(result.data["staff_count"], 0)
|
||||
|
||||
def test_benchmark_search_query(self):
|
||||
"""Benchmark: Search with filter."""
|
||||
print("\n\n### BENCHMARK: Search Query (LIKE + LIMIT) ###")
|
||||
|
||||
args = {"email_contains": "bench", "limit": 20}
|
||||
|
||||
exec_stats = self._benchmark_executor("bench_user_search", args, "search")
|
||||
http_stats = self._benchmark_http("bench_user_search", args, "search")
|
||||
|
||||
self._print_results([exec_stats, http_stats])
|
||||
self._print_comparison(exec_stats, http_stats)
|
||||
|
||||
# Correctness check: search results should contain "bench" in emails
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_user_search", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 20)
|
||||
for user in result.data["users"]:
|
||||
self.assertIn("bench", user["email"].lower())
|
||||
|
||||
def test_summary(self):
|
||||
"""Print summary of all benchmarks."""
|
||||
print("\n\n" + "=" * 80)
|
||||
print("BENCHMARK SUMMARY")
|
||||
print("=" * 80)
|
||||
print(f"Iterations per benchmark: {self.ITERATIONS}")
|
||||
print(f"Warmup iterations: {self.WARMUP}")
|
||||
print("\nKey findings:")
|
||||
print("- 'Executor' simulates WebSocket RPC (direct function call)")
|
||||
print("- 'HTTP' measures full request/response cycle")
|
||||
print("- HTTP overhead includes: JSON parsing, CSRF, view dispatch")
|
||||
print("- For I/O-bound operations, protocol overhead is negligible")
|
||||
print("=" * 80)
|
||||
|
||||
# Verify bench_simple still produces correct output after all benchmarks
|
||||
request = self._make_request()
|
||||
result = execute_function(request, "bench_simple", {"a": 7, "b": 8})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 15)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Throughput Benchmark
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ThroughputBenchmark(TransactionTestCase):
|
||||
"""
|
||||
Measure requests per second (throughput) for server functions.
|
||||
|
||||
Tests both sequential and concurrent scenarios.
|
||||
"""
|
||||
|
||||
DURATION_SECONDS = 2 # How long to run each throughput test
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
setup_benchmark_functions()
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
self._create_test_users()
|
||||
|
||||
def _create_test_users(self):
|
||||
"""Create test users for benchmarks."""
|
||||
users = []
|
||||
for i in range(100):
|
||||
users.append(User(
|
||||
email=f"bench{i}@example.com",
|
||||
is_active=i % 10 != 0,
|
||||
is_staff=i < 5,
|
||||
))
|
||||
User.objects.bulk_create(users, ignore_conflicts=True)
|
||||
self.test_user = User.objects.first()
|
||||
|
||||
def _make_request(self, body: dict) -> HttpRequest:
|
||||
"""Create a POST request with JSON body."""
|
||||
request = self.factory.post(
|
||||
"/api/djarea/call/",
|
||||
data=json.dumps(body),
|
||||
content_type="application/json",
|
||||
)
|
||||
request.user = AnonymousUser()
|
||||
request._dont_enforce_csrf_checks = True
|
||||
return request
|
||||
|
||||
def _measure_throughput_executor(self, fn_name: str, args: dict) -> float:
|
||||
"""Measure requests/second using direct executor calls."""
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
execute_function(request, fn_name, args)
|
||||
|
||||
# Measure
|
||||
count = 0
|
||||
start = time.perf_counter()
|
||||
deadline = start + self.DURATION_SECONDS
|
||||
|
||||
while time.perf_counter() < deadline:
|
||||
execute_function(request, fn_name, args)
|
||||
count += 1
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
return count / elapsed
|
||||
|
||||
def _measure_throughput_http(self, fn_name: str, args: dict) -> float:
|
||||
"""Measure requests/second using HTTP view calls."""
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
|
||||
# Measure
|
||||
count = 0
|
||||
start = time.perf_counter()
|
||||
deadline = start + self.DURATION_SECONDS
|
||||
|
||||
while time.perf_counter() < deadline:
|
||||
request = self._make_request({"fn": fn_name, "args": args})
|
||||
function_call_view(request)
|
||||
count += 1
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
return count / elapsed
|
||||
|
||||
def _print_throughput(self, label: str, executor_rps: float, http_rps: float):
|
||||
"""Print throughput results."""
|
||||
print(f"\n{label}:")
|
||||
print(f" Executor (WebSocket): {executor_rps:,.0f} req/s")
|
||||
print(f" HTTP: {http_rps:,.0f} req/s")
|
||||
print(f" Ratio: {executor_rps/http_rps:.1f}x")
|
||||
|
||||
def test_throughput_simple(self):
|
||||
"""Throughput: Simple computation (no I/O)."""
|
||||
print("\n\n### THROUGHPUT: Simple Computation ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_simple", {"a": 1, "b": 2})
|
||||
http_rps = self._measure_throughput_http("bench_simple", {"a": 1, "b": 2})
|
||||
|
||||
self._print_throughput("Simple (no I/O)", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_simple", "args": {"a": 1, "b": 2}})
|
||||
result = execute_function(request, "bench_simple", {"a": 1, "b": 2})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 3)
|
||||
|
||||
def test_throughput_single_query(self):
|
||||
"""Throughput: Single ORM query."""
|
||||
print("\n\n### THROUGHPUT: Single ORM Query ###")
|
||||
|
||||
args = {"user_id": self.test_user.id if self.test_user else 1}
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_get_user", args)
|
||||
http_rps = self._measure_throughput_http("bench_get_user", args)
|
||||
|
||||
self._print_throughput("Single Query", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_get_user", "args": args})
|
||||
result = execute_function(request, "bench_get_user", args)
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["id"], self.test_user.id)
|
||||
|
||||
def test_throughput_list_query(self):
|
||||
"""Throughput: List query."""
|
||||
print("\n\n### THROUGHPUT: List Query (10 users) ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_list_users", {"limit": 10})
|
||||
http_rps = self._measure_throughput_http("bench_list_users", {"limit": 10})
|
||||
|
||||
self._print_throughput("List Query", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_list_users", "args": {"limit": 10}})
|
||||
result = execute_function(request, "bench_list_users", {"limit": 10})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertLessEqual(result.data["count"], 10)
|
||||
|
||||
def test_throughput_aggregation(self):
|
||||
"""Throughput: Aggregation queries."""
|
||||
print("\n\n### THROUGHPUT: Aggregation ###")
|
||||
|
||||
executor_rps = self._measure_throughput_executor("bench_user_stats", {})
|
||||
http_rps = self._measure_throughput_http("bench_user_stats", {})
|
||||
|
||||
self._print_throughput("Aggregation", executor_rps, http_rps)
|
||||
|
||||
# Correctness check
|
||||
request = self._make_request({"fn": "bench_user_stats", "args": {}})
|
||||
result = execute_function(request, "bench_user_stats", {})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertGreaterEqual(result.data["total_users"], 0)
|
||||
|
||||
def test_throughput_summary(self):
|
||||
"""Print throughput summary."""
|
||||
print("\n\n" + "=" * 80)
|
||||
print("THROUGHPUT SUMMARY")
|
||||
print("=" * 80)
|
||||
print(f"Test duration: {self.DURATION_SECONDS}s per scenario")
|
||||
print("\nNotes:")
|
||||
print("- These are single-threaded sequential measurements")
|
||||
print("- Real throughput scales with worker processes (gunicorn -w N)")
|
||||
print("- Database queries are the bottleneck, not protocol overhead")
|
||||
print("- Async workers (uvicorn) can handle more concurrent connections")
|
||||
print("=" * 80)
|
||||
|
||||
# Verify bench_simple still produces correct output after all throughput tests
|
||||
request = self._make_request({"fn": "bench_simple", "args": {"a": 10, "b": 20}})
|
||||
result = execute_function(request, "bench_simple", {"a": 10, "b": 20})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["value"], 30)
|
||||
1170
django/src/djarea/tests/test_channels.py
Normal file
1170
django/src/djarea/tests/test_channels.py
Normal file
File diff suppressed because it is too large
Load Diff
1081
django/src/djarea/tests/test_core.py
Normal file
1081
django/src/djarea/tests/test_core.py
Normal file
File diff suppressed because it is too large
Load Diff
1224
django/src/djarea/tests/test_pentest.py
Normal file
1224
django/src/djarea/tests/test_pentest.py
Normal file
File diff suppressed because it is too large
Load Diff
1095
django/src/djarea/tests/test_security.py
Normal file
1095
django/src/djarea/tests/test_security.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user