Restructure tree by role; rename mizan-runtime → mizan-base

packages/ flattens into:
  backends/   server protocol adapters (mizan-django, mizan-ts)
  frontends/  client kernel + framework adapters (mizan-base, mizan-react, mizan-vue, mizan-svelte)
  workers/    runtime workers (mizan-ssr)
  cores/      shared language-level primitives (empty for now; mizan-python forthcoming)

The frontend kernel (was packages/mizan-runtime, now frontends/mizan-base) is
renamed to reflect its role — it's the shared base that frontend adapters
depend on directly. Reflects the substrate position that per-framework adapters
wrap a single shared kernel; codegen targets the adapter, not the raw kernel.

Path updates landed in: Makefile, two Gitea workflows, Dockerfile.test, four
example/harness config files, .claude/settings.local.json, four docs
(CLAUDE/ISSUES/ROADMAP/AFI_ARCHITECTURE), four codegen templates (stage1 +
react/vue/svelte adapters), and three package.jsons (the mizan-base rename
plus mizan-vue/svelte peerDeps).

Generated files under examples/django-react-site/harness/src/api/ still
reference @mizan/runtime — left as-is; they're regenerated artifacts and
the harness is non-functional pending the React wrapper-layer codegen.

Also folded in a pre-existing fix: the Gitea workflows had
working-directory: react / django pointing at a layout that predates
packages/, never updated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-05 20:55:37 -04:00
parent 6eca514777
commit fe39fcb229
126 changed files with 0 additions and 0 deletions

4
backends/mizan-django/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
__pycache__/
*.egg-info/
dist/
build/

View File

@@ -0,0 +1,105 @@
# mizan (Python)
Django server functions framework. See the [monorepo root](../README.md) for full documentation.
## Install
```bash
uv add "mizan[channels,allauth] @ git+https://git.impactsoundworks.com/isw/mizan.git#subdirectory=django"
```
## Setup
```python
# settings.py
INSTALLED_APPS = ["mizan", ...]
# urls.py
path("api/mizan/", include("mizan.urls"))
# asgi.py (optional, for WebSocket)
from mizan import wrap_asgi
application = wrap_asgi(get_asgi_application())
```
## Define Functions
```python
from mizan.client import client
from mizan.setup.registry import register
from pydantic import BaseModel
class Output(BaseModel):
message: str
@client
def echo(request, text: str) -> Output:
return Output(message=text)
register(echo, "echo")
```
Register in `apps.py`:
```python
def ready(self):
import myapp.mizan_clients
```
## Auth
```python
@client(auth=True) # requires authentication
@client(auth='staff') # requires is_staff
@client(auth='superuser') # requires is_superuser
@client(auth=my_callable) # custom check
```
## Contexts
```python
@client(context='global') # fetched once, SSR-hydrated, becomes useCurrentUser()
@client(context='local') # fetched with params, becomes <GreetProvider>
```
## Forms
```python
from mizan.forms import mizanFormMixin, mizanFormMeta
class ContactForm(mizanFormMixin, forms.Form):
mizan = mizanFormMeta(name="contact", title="Contact Us")
name = forms.CharField()
email = forms.EmailField()
def on_submit_success(self, request):
return {"sent": True}
```
Auto-registers `contact.schema`, `contact.validate`, `contact.submit`. Generates `useContactForm()` with Zod validation.
## Channels
```python
from mizan.channels import ReactChannel
class ChatChannel(ReactChannel):
class Params(BaseModel):
room: str
class DjangoMessage(BaseModel):
text: str
def authorize(self, params):
return self.user.is_authenticated
def group(self, params):
return f"chat_{params.room}"
```
Generates `useChatChannel({ room })`.
## Running Tests
```bash
uv sync --extra dev --extra channels
uv run pytest
```

View File

@@ -0,0 +1,223 @@
#!/usr/bin/env node
/**
* mizan Code Generator CLI
*
* Two-stage codegen:
* Stage 1: Framework-agnostic types + fetch/mutation functions
* Stage 2: Framework-specific wrappers (React hooks, Vue composables, Svelte stores)
*
* Usage:
* npx mizan-generate # React (default)
* npx mizan-generate --target vue # Vue
* npx mizan-generate --target react,vue,svelte # All three
*/
import { promises as fs } from 'fs'
import path from 'path'
import { fetchChannelsSchema, fetchMizanSchema } from './lib/fetch.mjs'
import { generateTypes, generateContextFile, generateMutationFile, generateFunctionFile, generateStage1Index } from './lib/stage1.mjs'
import { generateReactAdapter } from './lib/adapters/react.mjs'
import { generateVueAdapter } from './lib/adapters/vue.mjs'
import { generateSvelteAdapter } from './lib/adapters/svelte.mjs'
import { generateChannelsFiles } from './lib/channels.mjs'
const frontendDir = process.cwd()
async function loadConfig(configPath) {
const fullPath = path.resolve(frontendDir, configPath)
try { await fs.access(fullPath) } catch { throw new Error(`Config not found: ${fullPath}`) }
const fileUrl = new URL(`file://${fullPath.replace(/\\/g, '/')}`)
const module = await import(fileUrl)
return module.default
}
async function writeOutput(filePath, content) {
const dir = path.dirname(filePath)
await fs.mkdir(dir, { recursive: true })
await fs.writeFile(filePath, content, 'utf8')
}
function pascalCase(str) {
return str.split(/[.\-_]/).map(p => p.charAt(0).toUpperCase() + p.slice(1)).join('')
}
async function generate(config, options = {}) {
const { output, target: targetFlag } = options
const outputDir = output || config.output || 'src/api'
const targets = (targetFlag || config.target || 'react').split(',').map(t => t.trim())
console.log(`[mizan] Starting generation (targets: ${targets.join(', ')})...`)
const fullOutputDir = path.resolve(frontendDir, outputDir)
let mizanSchema = null
let channelsSchema = null
// ── Channels (React-only for now) ───────────────────────────────────
try {
console.log('[mizan] Fetching channels schema...')
channelsSchema = await fetchChannelsSchema(config.source, frontendDir)
const channelCount = channelsSchema['x-mizan-channels']?.length || 0
if (channelCount > 0 && targets.includes('react')) {
console.log(`[mizan] Found ${channelCount} channels`)
const { types, hooks } = await generateChannelsFiles(channelsSchema)
await writeOutput(path.join(fullOutputDir, 'channels.ts'), types)
if (hooks) await writeOutput(path.join(fullOutputDir, 'channels.hooks.tsx'), hooks)
}
} catch (err) {
console.log(`[mizan] Channels not available: ${err.message}`)
}
// ── Mizan functions ─────────────────────────────────────────────────
try {
console.log('[mizan] Fetching mizan schema...')
mizanSchema = await fetchMizanSchema(config.source, frontendDir)
const functions = mizanSchema['x-mizan-functions'] || []
const contextGroups = mizanSchema['x-mizan-contexts'] || {}
if (functions.length === 0) {
console.log('[mizan] No functions registered')
return
}
console.log(`[mizan] Found ${functions.length} functions`)
// ── Stage 1: Framework-agnostic ─────────────────────────────────
// Types
const types = await generateTypes(mizanSchema)
await writeOutput(path.join(fullOutputDir, 'types.ts'), types)
console.log('[mizan] Stage 1 -> types.ts')
// Context files
await fs.mkdir(path.join(fullOutputDir, 'contexts'), { recursive: true })
for (const [ctxName, ctxMeta] of Object.entries(contextGroups)) {
const content = generateContextFile(ctxName, ctxMeta, functions)
await writeOutput(path.join(fullOutputDir, 'contexts', `${ctxName}.ts`), content)
console.log(`[mizan] Stage 1 -> contexts/${ctxName}.ts`)
}
// Mutation + function files
const regularFns = functions.filter(fn => !fn.isContext && !fn.isForm)
if (regularFns.length > 0) {
await fs.mkdir(path.join(fullOutputDir, 'mutations'), { recursive: true })
await fs.mkdir(path.join(fullOutputDir, 'functions'), { recursive: true })
for (const fn of regularFns) {
const dir = fn.affects ? 'mutations' : 'functions'
const content = fn.affects ? generateMutationFile(fn) : generateFunctionFile(fn)
await writeOutput(path.join(fullOutputDir, dir, `${fn.camelName}.ts`), content)
console.log(`[mizan] Stage 1 -> ${dir}/${fn.camelName}.ts`)
}
}
// Stage 1 index
const stage1Index = generateStage1Index(mizanSchema)
await writeOutput(path.join(fullOutputDir, 'index.ts'), stage1Index)
console.log('[mizan] Stage 1 -> index.ts')
// ── Stage 2: Framework-specific ─────────────────────────────────
for (const target of targets) {
let content
let filename
switch (target) {
case 'react':
content = generateReactAdapter(mizanSchema)
filename = 'react.tsx'
break
case 'vue':
content = generateVueAdapter(mizanSchema)
filename = 'vue.ts'
break
case 'svelte':
content = generateSvelteAdapter(mizanSchema)
filename = 'svelte.ts'
break
default:
console.warn(`[mizan] Unknown target: ${target}`)
continue
}
if (content) {
await writeOutput(path.join(fullOutputDir, filename), content)
console.log(`[mizan] Stage 2 -> ${filename}`)
}
}
// Schema JSON
await writeOutput(
path.join(fullOutputDir, 'schema.json'),
JSON.stringify(mizanSchema, null, 2),
)
} catch (err) {
console.log(`[mizan] Schema not available: ${err.message}`)
}
console.log('[mizan] Generation complete!')
}
async function main() {
const args = process.argv.slice(2)
let configPath = 'django.config.mjs'
let watchMode = false
let output = null
let target = null
for (let i = 0; i < args.length; i++) {
if (args[i] === '--config' || args[i] === '-c') configPath = args[++i]
else if (args[i] === '--watch' || args[i] === '-w') watchMode = true
else if (args[i] === '--output' || args[i] === '-o') output = args[++i]
else if (args[i] === '--target' || args[i] === '-t') target = args[++i]
else if (args[i] === '--help' || args[i] === '-h') {
console.log(`
mizan Code Generator
Usage:
npx mizan-generate [options]
Options:
-c, --config <path> Config file (default: django.config.mjs)
-t, --target <targets> Comma-separated: react,vue,svelte (default: react)
-o, --output <dir> Output directory (default: src/api)
-w, --watch Watch mode
-h, --help Show help
`)
process.exit(0)
}
}
const config = await loadConfig(configPath)
const options = { output, target }
if (watchMode) {
await generate(config, options)
console.log('[mizan] Watching for changes...')
const { watch: chokidarWatch } = await import('chokidar')
if (config.source.django) {
const djangoDir = path.resolve(frontendDir, path.dirname(config.source.django.managePath))
let timeout = null
const watcher = chokidarWatch([path.join(djangoDir, '**/*.py')], {
ignored: ['**/node_modules/**', '**/__pycache__/**', '**/migrations/**'],
ignoreInitial: true,
})
watcher.on('change', () => {
if (timeout) clearTimeout(timeout)
timeout = setTimeout(() => generate(config, options), 1000)
})
}
process.on('SIGINT', () => process.exit(0))
} else {
await generate(config, options)
}
}
main().catch(err => {
console.error('[mizan] Error:', err.message)
process.exit(1)
})

View File

@@ -0,0 +1,180 @@
/**
* React Stage 2 — Generates hooks + context providers from Stage 1 output.
*
* Generated providers subscribe to the runtime kernel for state.
* The kernel owns data, status, and error. React just renders.
*/
function pascalCase(str) {
return str.split(/[.\-_]/).map(p => p.charAt(0).toUpperCase() + p.slice(1)).join('')
}
export function generateReactAdapter(schema) {
const functions = schema['x-mizan-functions'] || []
const contextGroups = schema['x-mizan-contexts'] || {}
const namedContexts = Object.entries(contextGroups).filter(([n]) => n !== 'global')
const globalContexts = functions.filter(fn => fn.isContext === 'global')
const mutations = functions.filter(fn => !fn.isContext && !fn.isForm && fn.affects)
const plainFns = functions.filter(fn => !fn.isContext && !fn.isForm && !fn.affects)
const lines = [
"'use client'",
'',
'// AUTO-GENERATED by mizan — do not edit',
'',
"import { createContext, useContext, useState, useEffect, useCallback, useRef, useSyncExternalStore, type ReactNode } from 'react'",
"import { registerContext, mizanFetch, mizanCall, type ContextState } from '@mizan/runtime'",
'',
]
// Import from Stage 1
const stage1Imports = []
for (const [ctxName] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
stage1Imports.push(`fetch${p}Context`, `type ${p}ContextData`, `type ${p}ContextParams`)
}
for (const fn of [...mutations, ...plainFns]) {
stage1Imports.push(`call${pascalCase(fn.camelName)}`)
}
if (stage1Imports.length > 0) {
lines.push(`import { ${stage1Imports.join(', ')} } from '../index'`)
lines.push('')
}
// ── Helper hook: subscribe to kernel state ──────────────────────────
lines.push('// Subscribe to kernel state via useSyncExternalStore')
lines.push('function useContextState<T>(')
lines.push(' name: string,')
lines.push(' params: Record<string, any>,')
lines.push(' fetchFn: () => Promise<T>,')
lines.push(' initialData?: T,')
lines.push('): ContextState<T> {')
lines.push(' const ref = useRef<ReturnType<typeof registerContext> | null>(null)')
lines.push('')
lines.push(' if (!ref.current) {')
lines.push(' ref.current = registerContext(name, params, fetchFn, initialData)')
lines.push(' }')
lines.push('')
lines.push(' const handle = ref.current')
lines.push('')
lines.push(' // Fetch on mount if no data')
lines.push(' useEffect(() => {')
lines.push(" if (handle.getState().status === 'idle') handle.refetch()")
lines.push(' return () => handle.unregister()')
lines.push(' }, [handle])')
lines.push('')
lines.push(' return useSyncExternalStore(')
lines.push(' handle.subscribe,')
lines.push(' handle.getState,')
lines.push(' handle.getState,')
lines.push(' )')
lines.push('}')
lines.push('')
// ── Mutation hook helper ────────────────────────────────────────────
lines.push('// Mutation hook with loading/error state')
lines.push('function useMutation<TArgs, TResult>(')
lines.push(' callFn: (args: TArgs) => Promise<TResult>,')
lines.push('): { mutate: (args: TArgs) => Promise<TResult>; isPending: boolean; error: Error | null } {')
lines.push(' const [isPending, setIsPending] = useState(false)')
lines.push(' const [error, setError] = useState<Error | null>(null)')
lines.push('')
lines.push(' const mutate = useCallback(async (args: TArgs) => {')
lines.push(' setIsPending(true)')
lines.push(' setError(null)')
lines.push(' try {')
lines.push(' const result = await callFn(args)')
lines.push(' return result')
lines.push(' } catch (e) {')
lines.push(' setError(e as Error)')
lines.push(' throw e')
lines.push(' } finally {')
lines.push(' setIsPending(false)')
lines.push(' }')
lines.push(' }, [callFn])')
lines.push('')
lines.push(' return { mutate, isPending, error }')
lines.push('}')
lines.push('')
// ── Context hooks ───────────────────────────────────────────────────
for (const [ctxName, ctxMeta] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
const paramEntries = Object.entries(ctxMeta.params || {})
lines.push(`// ── ${p} Context ──`)
lines.push('')
// Hook that returns the full kernel state
if (paramEntries.length > 0) {
lines.push(`export function use${p}Context(params: ${p}ContextParams): ContextState<${p}ContextData> {`)
lines.push(` const ssrData = typeof window !== 'undefined' ? (window as any).__MIZAN_SSR_DATA__ : null`)
lines.push(` return useContextState('${ctxName}', params, () => fetch${p}Context(params), ssrData)`)
} else {
lines.push(`export function use${p}Context(): ContextState<${p}ContextData> {`)
lines.push(` const ssrData = typeof window !== 'undefined' ? (window as any).__MIZAN_SSR_DATA__ : null`)
lines.push(` return useContextState('${ctxName}', {}, () => fetch${p}Context({} as any), ssrData)`)
}
lines.push('}')
lines.push('')
// Convenience hooks for individual data fields
for (const fn of ctxFunctions) {
const hookPascal = pascalCase(fn.camelName)
if (paramEntries.length > 0) {
lines.push(`export function use${hookPascal}(params: ${p}ContextParams): ${fn.outputType} | null {`)
lines.push(` const state = use${p}Context(params)`)
} else {
lines.push(`export function use${hookPascal}(): ${fn.outputType} | null {`)
lines.push(` const state = use${p}Context()`)
}
lines.push(` return state.data?.${fn.name} ?? null`)
lines.push('}')
lines.push('')
}
}
// ── Mutation hooks (with loading/error) ──────────────────────────────
for (const fn of mutations) {
const p = pascalCase(fn.camelName)
if (fn.hasInput) {
lines.push(`export function use${p}() {`)
lines.push(` return useMutation<Parameters<typeof call${p}>[0], Awaited<ReturnType<typeof call${p}>>>(call${p})`)
lines.push('}')
} else {
lines.push(`export function use${p}() {`)
lines.push(` return useMutation<void, Awaited<ReturnType<typeof call${p}>>>(() => call${p}() as any)`)
lines.push('}')
}
lines.push('')
}
// ── Plain function hooks ────────────────────────────────────────────
for (const fn of plainFns) {
const p = pascalCase(fn.camelName)
if (fn.hasInput) {
lines.push(`export function use${p}() {`)
lines.push(` return useMutation<Parameters<typeof call${p}>[0], Awaited<ReturnType<typeof call${p}>>>(call${p})`)
lines.push('}')
} else {
lines.push(`export function use${p}() {`)
lines.push(` return useMutation<void, Awaited<ReturnType<typeof call${p}>>>(() => call${p}() as any)`)
lines.push('}')
}
lines.push('')
}
// ── Re-export runtime types ─────────────────────────────────────────
lines.push("export type { ContextState } from '@mizan/runtime'")
lines.push("export { configure, initSession, MizanError } from '@mizan/runtime'")
lines.push('')
return lines.join('\n')
}

View File

@@ -0,0 +1,78 @@
/**
* Svelte Stage 2 — Generates stores from Stage 1 output.
*
* Subscribes to the kernel for state. Returns readable stores.
*/
function pascalCase(str) {
return str.split(/[.\-_]/).map(p => p.charAt(0).toUpperCase() + p.slice(1)).join('')
}
export function generateSvelteAdapter(schema) {
const functions = schema['x-mizan-functions'] || []
const contextGroups = schema['x-mizan-contexts'] || {}
const mutations = functions.filter(fn => !fn.isContext && !fn.isForm && fn.affects)
const plainFns = functions.filter(fn => !fn.isContext && !fn.isForm && !fn.affects)
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
"import { readable, type Readable } from 'svelte/store'",
"import { registerContext, type ContextState } from '@mizan/runtime'",
'',
]
const stage1Imports = []
for (const [ctxName] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
stage1Imports.push(`fetch${p}Context`, `type ${p}ContextData`, `type ${p}ContextParams`)
}
for (const fn of [...mutations, ...plainFns]) {
stage1Imports.push(`call${pascalCase(fn.camelName)}`)
}
if (stage1Imports.length > 0) {
lines.push(`import { ${stage1Imports.join(', ')} } from '../index'`)
lines.push('')
}
for (const [ctxName, ctxMeta] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
const paramEntries = Object.entries(ctxMeta.params || {})
const paramsArg = paramEntries.length > 0 ? 'params' : '{} as any'
if (paramEntries.length > 0) {
lines.push(`export function create${p}Context(params: ${p}ContextParams) {`)
} else {
lines.push(`export function create${p}Context() {`)
}
// Use readable store backed by kernel subscription
lines.push(` const store = readable<ContextState<${p}ContextData>>(`)
lines.push(` { data: null, status: 'idle', error: null },`)
lines.push(` (set) => {`)
lines.push(` const handle = registerContext('${ctxName}', ${paramsArg}, () => fetch${p}Context(${paramsArg}))`)
lines.push(` const unsub = handle.subscribe(() => set(handle.getState()))`)
lines.push(` handle.refetch()`)
lines.push(` return () => { unsub(); handle.unregister() }`)
lines.push(` },`)
lines.push(` )`)
lines.push('')
lines.push(` return store`)
lines.push('}')
lines.push('')
}
// Re-export mutations as-is from Stage 1
for (const fn of [...mutations, ...plainFns]) {
const p = pascalCase(fn.camelName)
lines.push(`export { call${p} } from '../index'`)
}
lines.push('')
lines.push("export type { ContextState } from '@mizan/runtime'")
lines.push("export { configure, initSession, MizanError } from '@mizan/runtime'")
lines.push('')
return lines.join('\n')
}

View File

@@ -0,0 +1,104 @@
/**
* Vue Stage 2 — Generates composables from Stage 1 output.
*
* Subscribes to the kernel for state. Vue reactivity wraps kernel notifications.
*/
function pascalCase(str) {
return str.split(/[.\-_]/).map(p => p.charAt(0).toUpperCase() + p.slice(1)).join('')
}
export function generateVueAdapter(schema) {
const functions = schema['x-mizan-functions'] || []
const contextGroups = schema['x-mizan-contexts'] || {}
const mutations = functions.filter(fn => !fn.isContext && !fn.isForm && fn.affects)
const plainFns = functions.filter(fn => !fn.isContext && !fn.isForm && !fn.affects)
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
"import { ref, computed, onMounted, onUnmounted, onServerPrefetch, type ComputedRef } from 'vue'",
"import { registerContext, type ContextState } from '@mizan/runtime'",
'',
]
const stage1Imports = []
for (const [ctxName] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
stage1Imports.push(`fetch${p}Context`, `type ${p}ContextData`, `type ${p}ContextParams`)
}
for (const fn of [...mutations, ...plainFns]) {
stage1Imports.push(`call${pascalCase(fn.camelName)}`)
}
if (stage1Imports.length > 0) {
lines.push(`import { ${stage1Imports.join(', ')} } from '../index'`)
lines.push('')
}
for (const [ctxName, ctxMeta] of Object.entries(contextGroups)) {
const p = pascalCase(ctxName)
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
const paramEntries = Object.entries(ctxMeta.params || {})
const paramsArg = paramEntries.length > 0 ? 'params' : '{} as any'
if (paramEntries.length > 0) {
lines.push(`export function use${p}Context(params: ${p}ContextParams) {`)
} else {
lines.push(`export function use${p}Context() {`)
}
lines.push(` const state = ref<ContextState<${p}ContextData>>({ data: null, status: 'idle', error: null })`)
lines.push(` let handle: ReturnType<typeof registerContext> | null = null`)
lines.push('')
lines.push(` onMounted(() => {`)
lines.push(` handle = registerContext('${ctxName}', ${paramsArg}, () => fetch${p}Context(${paramsArg}))`)
lines.push(` handle.subscribe(() => { state.value = handle!.getState() })`)
lines.push(` handle.refetch()`)
lines.push(` })`)
lines.push('')
lines.push(` onServerPrefetch(async () => {`)
lines.push(` handle = registerContext('${ctxName}', ${paramsArg}, () => fetch${p}Context(${paramsArg}))`)
lines.push(` await handle.refetch()`)
lines.push(` state.value = handle.getState()`)
lines.push(` })`)
lines.push('')
lines.push(` onUnmounted(() => { handle?.unregister() })`)
lines.push('')
lines.push(` return {`)
lines.push(` state,`)
for (const fn of ctxFunctions) {
lines.push(` ${fn.camelName}: computed(() => state.value.data?.${fn.name} ?? null) as ComputedRef<${fn.outputType} | null>,`)
}
lines.push(` loading: computed(() => state.value.status === 'loading'),`)
lines.push(` error: computed(() => state.value.error),`)
lines.push(` }`)
lines.push('}')
lines.push('')
}
for (const fn of [...mutations, ...plainFns]) {
const p = pascalCase(fn.camelName)
lines.push(`export function use${p}() {`)
lines.push(` const isPending = ref(false)`)
lines.push(` const error = ref<Error | null>(null)`)
if (fn.hasInput) {
lines.push(` async function mutate(args: Parameters<typeof call${p}>[0]) {`)
} else {
lines.push(` async function mutate() {`)
}
lines.push(` isPending.value = true; error.value = null`)
lines.push(` try { return await call${p}(${fn.hasInput ? 'args' : ''}) }`)
lines.push(` catch (e) { error.value = e as Error; throw e }`)
lines.push(` finally { isPending.value = false }`)
lines.push(` }`)
lines.push(` return { mutate, isPending, error }`)
lines.push('}')
lines.push('')
}
lines.push("export type { ContextState } from '@mizan/runtime'")
lines.push("export { configure, initSession, MizanError } from '@mizan/runtime'")
lines.push('')
return lines.join('\n')
}

View File

@@ -0,0 +1,155 @@
/**
* Channels Code Generator
*
* Generates TypeScript types and React hooks from Channels OpenAPI schema.
* Uses openapi-typescript for robust type generation.
*/
import openapiTS, { astToString } from 'openapi-typescript'
/**
* Generate channels TypeScript types using openapi-typescript.
*/
export async function generateChannelsTypes(schema) {
// Generate types using openapi-typescript
const ast = await openapiTS(schema)
const typesCode = astToString(ast)
const lines = [
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
'// ============================================================================',
'// OpenAPI Types (generated by openapi-typescript)',
'// ============================================================================',
'',
typesCode,
'',
]
// Extract channel metadata from x-mizan-channels extension
const channels = schema['x-mizan-channels'] || []
if (channels.length > 0) {
lines.push('// ============================================================================')
lines.push('// Convenience Type Exports')
lines.push('// ============================================================================')
lines.push('')
for (const channel of channels) {
if (channel.hasParams) {
lines.push(`export type ${channel.paramsType} = components["schemas"]["${channel.paramsType}"]`)
}
if (channel.hasReactMessage) {
lines.push(`export type ${channel.reactMessageType} = components["schemas"]["${channel.reactMessageType}"]`)
}
if (channel.hasDjangoMessage) {
lines.push(`export type ${channel.djangoMessageType} = components["schemas"]["${channel.djangoMessageType}"]`)
}
}
lines.push('')
lines.push('// ============================================================================')
lines.push('// Channel Registry')
lines.push('// ============================================================================')
lines.push('')
lines.push('export const CHANNELS = {')
for (const channel of channels) {
lines.push(` ${channel.name}: {`)
lines.push(` name: '${channel.name}',`)
lines.push(` pascalName: '${channel.pascalName}',`)
lines.push(` hasParams: ${channel.hasParams},`)
lines.push(` hasReactMessage: ${channel.hasReactMessage},`)
lines.push(` hasDjangoMessage: ${channel.hasDjangoMessage},`)
if (channel.hasParams) {
lines.push(` paramsType: '${channel.paramsType}',`)
}
if (channel.hasReactMessage) {
lines.push(` reactMessageType: '${channel.reactMessageType}',`)
}
if (channel.hasDjangoMessage) {
lines.push(` djangoMessageType: '${channel.djangoMessageType}',`)
}
lines.push(` },`)
}
lines.push('} as const')
} else {
lines.push('export const CHANNELS = {} as const')
}
lines.push('')
return lines.join('\n')
}
/**
* Generate channel hooks from metadata.
*/
export function generateChannelsHooks(schema) {
const channels = schema['x-mizan-channels'] || []
if (channels.length === 0) {
return null
}
const lines = [
"'use client'",
'',
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
"import { useChannel, type ChannelSubscription } from 'mizan/channels'",
'',
]
// Collect type imports
const typeImports = []
for (const channel of channels) {
if (channel.hasParams) typeImports.push(channel.paramsType)
if (channel.hasReactMessage) typeImports.push(channel.reactMessageType)
if (channel.hasDjangoMessage) typeImports.push(channel.djangoMessageType)
}
if (typeImports.length > 0) {
lines.push(`import type { ${typeImports.join(', ')} } from './generated.channels'`)
lines.push('')
}
// Generate hooks for each channel
lines.push('// ============================================================================')
lines.push('// Channel Hooks')
lines.push('// ============================================================================')
lines.push('')
for (const channel of channels) {
const paramsType = channel.hasParams ? channel.paramsType : 'Record<string, never>'
const reactMsgType = channel.hasReactMessage ? channel.reactMessageType : 'never'
const djangoMsgType = channel.hasDjangoMessage ? channel.djangoMessageType : 'never'
lines.push(`/**`)
lines.push(` * Hook for the ${channel.name} channel.`)
lines.push(` */`)
if (channel.hasParams) {
lines.push(`export function use${channel.pascalName}Channel(params: ${paramsType}): ChannelSubscription<${paramsType}, ${djangoMsgType}, ${reactMsgType}> {`)
lines.push(` return useChannel('${channel.name}', params)`)
} else {
lines.push(`export function use${channel.pascalName}Channel(): ChannelSubscription<Record<string, never>, ${djangoMsgType}, ${reactMsgType}> {`)
lines.push(` return useChannel('${channel.name}', {})`)
}
lines.push('}')
lines.push('')
}
return lines.join('\n')
}
/**
* Generate all channels files.
*/
export async function generateChannelsFiles(schema) {
const types = await generateChannelsTypes(schema)
const hooks = generateChannelsHooks(schema)
return { types, hooks }
}

View File

@@ -0,0 +1,88 @@
/**
* Schema Fetching
*
* Fetches mizan and channels schemas from Django management commands.
*/
import { spawn } from 'child_process'
import path from 'path'
/**
* Run a Django management command and parse JSON output.
*/
function runDjangoCommand(source, cwd, command) {
const managePath = path.resolve(cwd, source.django.managePath)
const manageDir = path.dirname(managePath)
let cmd, args
if (source.django.command) {
cmd = source.django.command[0]
args = [...source.django.command.slice(1), 'manage.py', command, '--indent', '0']
} else {
const python = source.django.python || 'python'
cmd = python
args = [managePath, command, '--indent', '0']
}
const env = source.django.env
? { ...process.env, ...source.django.env }
: undefined
return new Promise((resolve, reject) => {
const proc = spawn(cmd, args, {
cwd: manageDir,
stdio: ['ignore', 'pipe', 'pipe'],
shell: process.platform === 'win32',
env,
})
let stdout = ''
let stderr = ''
proc.stdout.on('data', (data) => { stdout += data.toString() })
proc.stderr.on('data', (data) => { stderr += data.toString() })
proc.on('close', (code) => {
if (code !== 0) {
reject(new Error(`Django command failed (exit ${code}):\n${stderr}`))
return
}
const jsonStart = stdout.indexOf('{')
if (jsonStart === -1) {
reject(new Error(`No JSON found in Django output:\n${stdout}\n${stderr}`))
return
}
try {
resolve(JSON.parse(stdout.slice(jsonStart)))
} catch (err) {
reject(new Error(`Failed to parse JSON from Django:\n${err.message}\n${stdout}`))
}
})
proc.on('error', (err) => {
reject(new Error(`Failed to spawn Django command: ${err.message}`))
})
})
}
/**
* Fetch channels schema from Django.
*/
export async function fetchChannelsSchema(source, cwd) {
if (!source.django) {
throw new Error('Channels schema export requires django source configuration')
}
return runDjangoCommand(source, cwd, 'export_channels_schema')
}
/**
* Fetch mizan schema from Django.
*/
export async function fetchMizanSchema(source, cwd) {
if (!source.django) {
throw new Error('mizan schema export requires django source configuration')
}
return runDjangoCommand(source, cwd, 'export_mizan_schema')
}

View File

@@ -0,0 +1,164 @@
/**
* Index File Generator
*
* Generates a consolidated index.ts that re-exports everything
* from the generated files for clean imports.
*/
function pascalCase(str) {
return str.charAt(0).toUpperCase() + str.slice(1)
}
function toPascalCase(str) {
return str
.split(/[.\-_]/)
.map(part => part.charAt(0).toUpperCase() + part.slice(1))
.join('')
}
/**
* Generate the consolidated index.ts file.
*/
export function generateIndex({ channelsSchema, mizanSchema }) {
const lines = [
'/**',
' * mizan API - Consolidated Exports',
' *',
' * Import everything from here:',
' *',
' * @example',
' * ```tsx',
' * import {',
' * MizanContext,',
' * useCurrentUser,',
' * useEcho,',
' * useChatChannel,',
' * } from \'@/api\'',
' * ```',
' */',
'',
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
]
const functions = mizanSchema?.['x-mizan-functions'] || []
const contextGroups = mizanSchema?.['x-mizan-contexts'] || {}
const hasMizan = functions.length > 0
if (hasMizan) {
const globalContexts = functions.filter(fn => fn.isContext === 'global')
const regularFunctions = functions.filter(fn => !fn.isContext && !fn.isForm)
const namedContextEntries = Object.entries(contextGroups).filter(([name]) => name !== 'global')
lines.push('// =============================================================================')
lines.push('// mizan Provider & Hooks')
lines.push('// =============================================================================')
lines.push('')
// Server exports
if (globalContexts.length > 0) {
lines.push('export {')
lines.push(' getMizanHydration,')
lines.push(' getDjangoHydration,')
lines.push(' type MizanHydrationData,')
lines.push(' type DjangoHydration,')
lines.push("} from './generated.server'")
lines.push('')
}
// Client exports
lines.push('export {')
lines.push(' // Provider')
lines.push(' MizanContext,')
lines.push(' type MizanContextProps,')
lines.push(' DjangoContext,')
lines.push(' type DjangoContextProps,')
// Global context hooks
if (globalContexts.length > 0) {
lines.push('')
lines.push(' // Global context hooks')
for (const ctx of globalContexts) {
const hookPascal = pascalCase(ctx.camelName)
lines.push(` use${hookPascal},`)
}
lines.push('')
lines.push(' // Refresh hooks')
lines.push(' useMizanRefresh,')
lines.push(' useDjangoRefresh,')
}
// Named context providers and hooks
if (namedContextEntries.length > 0) {
lines.push('')
lines.push(' // Named context providers')
for (const [ctxName, ctxMeta] of namedContextEntries) {
const ctxPascal = toPascalCase(ctxName)
lines.push(` ${ctxPascal}Context,`)
// Hooks for this context's functions
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
for (const fn of ctxFunctions) {
const hookPascal = pascalCase(fn.camelName)
lines.push(` use${hookPascal},`)
}
}
}
// Function hooks (mutations + plain)
if (regularFunctions.length > 0) {
lines.push('')
lines.push(' // Function hooks')
for (const fn of regularFunctions) {
const pascal = pascalCase(fn.camelName)
lines.push(` use${pascal},`)
}
}
lines.push('')
lines.push(' // Re-exports from mizan library')
lines.push(' useMizan,')
lines.push(' useMizanStatus,')
lines.push(' usePush,')
lines.push(' DjangoError,')
lines.push(' type ConnectionStatus,')
lines.push(' type PushMessage,')
lines.push(' type PushListener,')
lines.push("} from './generated.provider'")
lines.push('')
}
// ==========================================================================
// Channel Hooks
// ==========================================================================
const channels = channelsSchema?.['x-mizan-channels'] || []
if (channels.length > 0) {
lines.push('// =============================================================================')
lines.push('// Channel Hooks')
lines.push('// =============================================================================')
lines.push('')
lines.push('export {')
for (const ch of channels) {
lines.push(` use${ch.pascalName}Channel,`)
}
lines.push("} from './generated.channels.hooks'")
lines.push('')
lines.push('// =============================================================================')
lines.push('// Channel Types')
lines.push('// =============================================================================')
lines.push('')
lines.push('export type {')
for (const ch of channels) {
if (ch.hasParams) lines.push(` ${ch.paramsType},`)
if (ch.hasReactMessage) lines.push(` ${ch.reactMessageType},`)
if (ch.hasDjangoMessage) lines.push(` ${ch.djangoMessageType},`)
}
lines.push("} from './generated.channels'")
lines.push('')
}
return lines.join('\n')
}

View File

@@ -0,0 +1,980 @@
/**
* mizan Code Generator
*
* Generates TypeScript types and React provider from mizan OpenAPI schema.
* Uses openapi-typescript for robust type generation.
*
* Output structure:
* - generated.mizan.ts - Types only (from OpenAPI)
* - generated.provider.tsx - Typed provider wrapping MizanProvider + hooks
* - generated.forms.ts - Typed form hooks with Zod schemas
*/
import openapiTS, { astToString } from 'openapi-typescript'
// TypeScript SyntaxKind values for AST manipulation
const SyntaxKind = {
InterfaceDeclaration: 265,
TypeAliasDeclaration: 266,
PropertySignature: 172,
TypeReference: 184,
IndexedAccessType: 200,
Identifier: 80,
StringLiteral: 11,
}
/**
* Get identifier name from AST node.
*/
function idName(node) {
return node?.kind === SyntaxKind.Identifier ? node.escapedText : undefined
}
/**
* Extract schema names from openapi-typescript AST.
*/
function getSchemaNamesFromAst(ast) {
if (!Array.isArray(ast)) return []
const componentsNode = ast.find(
node =>
node?.kind === SyntaxKind.InterfaceDeclaration &&
idName(node?.name) === 'components'
)
if (!componentsNode?.members) return []
const schemasProp = componentsNode.members.find(
member =>
member?.kind === SyntaxKind.PropertySignature &&
idName(member?.name) === 'schemas' &&
Array.isArray(member?.type?.members)
)
if (!schemasProp) return []
return schemasProp.type.members
.map(member =>
member?.kind === SyntaxKind.PropertySignature ? idName(member.name) : undefined
)
.filter(n => typeof n === 'string')
}
/**
* Build convenience type exports for schemas.
*/
function buildSchemaExports(schemaNames) {
if (!schemaNames.length) return ''
return schemaNames
.map(name => `export type ${name} = components["schemas"]["${name}"]`)
.join('\n')
}
/**
* Generate the types file using openapi-typescript.
*/
export async function generateMizanTypes(schema) {
// Generate types using openapi-typescript
const ast = await openapiTS(schema)
const schemaNames = getSchemaNamesFromAst(ast)
const typesCode = astToString(ast)
const lines = [
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
'// ============================================================================',
'// OpenAPI Types (generated by openapi-typescript)',
'// ============================================================================',
'',
typesCode,
'',
'// ============================================================================',
'// Convenience Type Exports',
'// ============================================================================',
'',
buildSchemaExports(schemaNames),
'',
'// ============================================================================',
'// Function Registry (for reference)',
'// ============================================================================',
'',
"export type Transport = 'http' | 'websocket' | 'both'",
'',
]
// Extract function metadata from x-mizan-functions extension
const functions = schema['x-mizan-functions'] || []
if (functions.length > 0) {
lines.push('export const MIZAN_FUNCTIONS = {')
for (const fn of functions) {
lines.push(` ${fn.camelName}: {`)
lines.push(` name: '${fn.name}',`)
lines.push(` hasInput: ${fn.hasInput},`)
lines.push(` isContext: ${fn.isContext},`)
lines.push(` transport: '${fn.transport}' as Transport,`)
lines.push(` },`)
}
lines.push('} as const')
} else {
lines.push('export const MIZAN_FUNCTIONS = {} as const')
}
lines.push('')
return lines.join('\n')
}
/**
* Extract unique context names from an affects array.
* Both context-level and function-level affects resolve to context names.
*/
function getAffectedContexts(affects) {
const contexts = new Set()
for (const target of affects) {
if (target.type === 'context') {
contexts.add(target.name)
} else if (target.type === 'function' && target.context) {
contexts.add(target.context)
}
}
return [...contexts]
}
/**
* Map JSON schema type string to TypeScript type.
*/
function jsonTypeToTS(type) {
if (type === 'integer' || type === 'number') return 'number'
if (type === 'boolean') return 'boolean'
return 'string'
}
/**
* Generate the React provider that wraps MizanProvider with typed hooks.
*
* The generated provider:
* - MizanContext: Root provider with global context bundled fetch
* - Named context providers: <UserContext user_id={...}>
* - Mutation hooks with auto-invalidation
* - Plain function hooks
*/
export function generateMizanProvider(schema, options = {}) {
const { hasChannels = false } = options
const functions = schema['x-mizan-functions'] || []
const contextGroups = schema['x-mizan-contexts'] || {}
if (functions.length === 0) {
return null
}
// Partition functions
const globalContexts = functions.filter(fn => fn.isContext === 'global')
const regularFunctions = functions.filter(fn => !fn.isContext && !fn.isForm)
const mutationFunctions = regularFunctions.filter(fn => fn.affects)
const plainFunctions = regularFunctions.filter(fn => !fn.affects)
// Named context groups (everything except 'global')
const namedContextEntries = Object.entries(contextGroups).filter(([name]) => name !== 'global')
// Collect type imports
const typeImports = []
for (const fn of functions) {
if (fn.hasInput && fn.inputType) {
typeImports.push(fn.inputType)
}
if (fn.outputType) {
typeImports.push(fn.outputType)
}
}
const uniqueTypeImports = [...new Set(typeImports)].sort()
const lines = [
"'use client'",
'',
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
'// This file provides typed wrappers around the mizan library.',
'// - MizanContext: Root provider with global context',
'// - Named context providers: <UserContext user_id={...}>',
'// - Typed hooks with auto-invalidation',
'',
"import { type ReactNode, useCallback, useState, useEffect, useRef, createContext, useContext } from 'react'",
"import {",
" MizanProvider,",
" useMizan,",
" useMizanContext,",
" useMizanCall,",
" type MizanHydration,",
" type Transport,",
"} from 'mizan'",
...(hasChannels ? [
"import { ChannelProvider, ChannelConnection } from 'mizan/channels'",
] : []),
'',
]
if (uniqueTypeImports.length > 0) {
lines.push(`import type { ${uniqueTypeImports.join(', ')} } from './generated.mizan'`)
lines.push('')
}
// ============================================================================
// Hydration types (global contexts only)
// ============================================================================
lines.push('// ============================================================================')
lines.push('// Hydration Types')
lines.push('// ============================================================================')
lines.push('')
if (globalContexts.length > 0) {
lines.push('/** Typed hydration data for SSR (global contexts only) */')
lines.push('export interface MizanHydrationData {')
for (const ctx of globalContexts) {
lines.push(` ${ctx.camelName}?: ${ctx.outputType}`)
}
lines.push('}')
lines.push('')
lines.push('function toMizanHydration(hydration?: MizanHydrationData): MizanHydration | undefined {')
lines.push(' if (!hydration) return undefined')
lines.push(' const result: MizanHydration = {}')
for (const ctx of globalContexts) {
lines.push(` if (hydration.${ctx.camelName} !== undefined) result['${ctx.name}'] = hydration.${ctx.camelName}`)
}
lines.push(' return result')
lines.push('}')
lines.push('')
}
// ============================================================================
// Global Context Loader (inner component, fetches GET /ctx/global/)
// ============================================================================
if (globalContexts.length > 0) {
lines.push('// ============================================================================')
lines.push('// Global Context Loader')
lines.push('// ============================================================================')
lines.push('')
lines.push('function GlobalContextLoader({ children }: { children: ReactNode }) {')
lines.push(' const mizan = useMizan()')
lines.push(' const loaded = useRef(false)')
lines.push('')
lines.push(' useEffect(() => {')
lines.push(' if (loaded.current) return')
lines.push(' loaded.current = true')
lines.push('')
lines.push(' // Check for SSR hydration data first')
lines.push(" const ssr = typeof window !== 'undefined' && (window as any).__MIZAN_SSR_DATA__")
lines.push(' if (ssr) {')
lines.push(' for (const [name, data] of Object.entries(ssr)) {')
lines.push(' mizan.setContextData(name, data)')
lines.push(' }')
lines.push(' return')
lines.push(' }')
lines.push('')
lines.push(' ;(async () => {')
lines.push(' await mizan.whenReady')
lines.push(' try {')
lines.push(" const response = await mizan.request('GET', `${mizan.baseUrl}/ctx/global/`)")
lines.push(' const result = await response.json()')
lines.push(' for (const [name, data] of Object.entries(result)) {')
lines.push(' mizan.setContextData(name, data)')
lines.push(' }')
lines.push(' } catch (e) {')
lines.push(" console.error('[MizanContext] Global context fetch failed:', e)")
lines.push(' }')
lines.push(' })()')
lines.push(' }, [mizan])')
lines.push('')
lines.push(' return <>{children}</>')
lines.push('}')
lines.push('')
}
// ============================================================================
// Root Provider (MizanContext)
// ============================================================================
lines.push('// ============================================================================')
lines.push('// Root Provider')
lines.push('// ============================================================================')
lines.push('')
lines.push('export interface MizanContextProps {')
lines.push(' children: ReactNode')
if (globalContexts.length > 0) {
lines.push(' /** SSR hydration data (global contexts only) */')
lines.push(' hydration?: MizanHydrationData')
}
lines.push(' /** WebSocket URL for RPC calls (default: /ws/) */')
lines.push(' wsUrl?: string')
lines.push(' /** Base URL for HTTP calls (default: /api/mizan) */')
lines.push(' baseUrl?: string')
lines.push('}')
lines.push('')
lines.push('/**')
lines.push(' * Root mizan provider. Mount at your app root.')
lines.push(' *')
lines.push(' * Usage:')
lines.push(' * <MizanContext hydration={hydration}>')
lines.push(' * <App />')
lines.push(' * </MizanContext>')
lines.push(' */')
lines.push('export function MizanContext({')
lines.push(' children,')
if (globalContexts.length > 0) {
lines.push(' hydration,')
}
lines.push(' wsUrl,')
lines.push(' baseUrl,')
lines.push('}: MizanContextProps) {')
if (hasChannels) {
lines.push(' const connectionRef = useRef<ChannelConnection | null>(null)')
lines.push(' if (!connectionRef.current) {')
lines.push(" connectionRef.current = new ChannelConnection({ url: wsUrl || '/ws/' })")
lines.push(' }')
lines.push('')
}
// Build the JSX tree
lines.push(' return (')
lines.push(' <MizanProvider')
if (globalContexts.length > 0) {
lines.push(' hydration={toMizanHydration(hydration)}')
}
lines.push(' wsUrl={wsUrl}')
lines.push(' baseUrl={baseUrl}')
if (hasChannels) {
lines.push(' connection={connectionRef.current}')
}
lines.push(' >')
// Inner content: GlobalContextLoader wraps children if needed
let innerContent = '{children}'
if (globalContexts.length > 0) {
innerContent = `<GlobalContextLoader>{children}</GlobalContextLoader>`
}
if (hasChannels) {
lines.push(` <ChannelProvider connection={connectionRef.current} autoConnect={true}>`)
lines.push(` ${innerContent}`)
lines.push(` </ChannelProvider>`)
} else {
lines.push(` ${innerContent}`)
}
lines.push(' </MizanProvider>')
lines.push(' )')
lines.push('}')
lines.push('')
// Legacy alias
lines.push('/** @deprecated Use MizanContext instead */')
lines.push('export const DjangoContext = MizanContext')
lines.push('/** @deprecated Use MizanContextProps instead */')
lines.push('export type DjangoContextProps = MizanContextProps')
if (globalContexts.length > 0) {
lines.push('/** @deprecated Use MizanHydrationData instead */')
lines.push('export type DjangoHydration = MizanHydrationData')
}
lines.push('')
// ============================================================================
// Global Context Hooks
// ============================================================================
if (globalContexts.length > 0) {
lines.push('// ============================================================================')
lines.push('// Global Context Hooks')
lines.push('// ============================================================================')
lines.push('')
for (const ctx of globalContexts) {
const pascal = pascalCase(ctx.camelName)
lines.push(`/** Get ${ctx.name} context data. @throws if not loaded yet */`)
lines.push(`export function use${pascal}(): ${ctx.outputType} {`)
lines.push(` const data = useMizanContext<${ctx.outputType}>('${ctx.name}')`)
lines.push(` if (data === undefined) throw new Error('use${pascal}: context not loaded yet')`)
lines.push(` return data`)
lines.push(`}`)
lines.push('')
}
lines.push('/** Refresh functions for global contexts. */')
lines.push('export function useMizanRefresh() {')
lines.push(' const { invalidateContext } = useMizan()')
lines.push(' return {')
for (const ctx of globalContexts) {
const pascal = pascalCase(ctx.camelName)
lines.push(` refresh${pascal}: () => invalidateContext('${ctx.name}'),`)
}
lines.push(' }')
lines.push('}')
lines.push('')
// Legacy alias
lines.push('/** @deprecated Use useMizanRefresh instead */')
lines.push('export const useDjangoRefresh = useMizanRefresh')
lines.push('')
}
// ============================================================================
// Named Context Providers
// ============================================================================
if (namedContextEntries.length > 0) {
lines.push('// ============================================================================')
lines.push('// Named Context Providers')
lines.push('// ============================================================================')
lines.push('')
for (const [ctxName, ctxMeta] of namedContextEntries) {
const ctxPascal = toPascalCase(ctxName)
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
const params = ctxMeta.params || {}
const paramEntries = Object.entries(params)
// Internal React context type
lines.push(`const ${ctxPascal}ContextInternal = createContext<{`)
for (const fn of ctxFunctions) {
lines.push(` ${fn.name}: ${fn.outputType}`)
}
lines.push(`} | null>(null)`)
lines.push('')
// Props interface
lines.push(`export interface ${ctxPascal}ContextProps {`)
lines.push(` children: ReactNode`)
for (const [pName, pMeta] of paramEntries) {
const tsType = jsonTypeToTS(pMeta.type)
const optional = pMeta.required ? '' : '?'
lines.push(` ${pName}${optional}: ${tsType}`)
}
lines.push(`}`)
lines.push('')
// Provider component
lines.push(`export function ${ctxPascal}Context({ children, ...params }: ${ctxPascal}ContextProps) {`)
lines.push(` const mizan = useMizan()`)
// SSR hydration check — initialize from __MIZAN_SSR_DATA__ if available
lines.push(` const [data, setData] = useState<{`)
for (const fn of ctxFunctions) {
lines.push(` ${fn.name}: ${fn.outputType}`)
}
lines.push(` } | null>(() => {`)
lines.push(` if (typeof window === 'undefined') return null`)
lines.push(` const ssr = (window as any).__MIZAN_SSR_DATA__`)
lines.push(` if (!ssr) return null`)
// Check if all functions for this context have SSR data
const firstFn = ctxFunctions[0]
lines.push(` if (ssr.${firstFn.name} === undefined) return null`)
lines.push(` return {`)
for (const fn of ctxFunctions) {
lines.push(` ${fn.name}: ssr.${fn.name},`)
}
lines.push(` }`)
lines.push(` })`)
lines.push('')
lines.push(` const refetch = useCallback(async () => {`)
lines.push(` await mizan.whenReady`)
lines.push(` const qs = new URLSearchParams()`)
for (const [pName] of paramEntries) {
lines.push(` if (params.${pName} !== undefined) qs.set('${pName}', String(params.${pName}))`)
}
lines.push(` const resp = await mizan.request('GET', \`\${mizan.baseUrl}/ctx/${ctxName}/?\${qs}\`)`)
lines.push(` const result = await resp.json()`)
lines.push(` setData(result)`)
// Dependency array: mizan + each param
const deps = ['mizan', ...paramEntries.map(([pName]) => `params.${pName}`)]
lines.push(` }, [${deps.join(', ')}])`)
lines.push('')
lines.push(` useEffect(() => { refetch() }, [refetch])`)
lines.push(` useEffect(() => mizan.registerContextProvider('${ctxName}', refetch), [mizan, refetch])`)
lines.push('')
lines.push(` return <${ctxPascal}ContextInternal value={data}>{children}</${ctxPascal}ContextInternal>`)
lines.push(`}`)
lines.push('')
// Individual data hooks
for (const fn of ctxFunctions) {
const hookPascal = pascalCase(fn.camelName)
lines.push(`export function use${hookPascal}(): ${fn.outputType} {`)
lines.push(` const ctx = useContext(${ctxPascal}ContextInternal)`)
lines.push(` if (!ctx) throw new Error('use${hookPascal} must be used within ${ctxPascal}Context')`)
lines.push(` return ctx.${fn.name}`)
lines.push(`}`)
lines.push('')
}
}
}
// ============================================================================
// Mutation Hooks (with auto-invalidation)
// ============================================================================
if (mutationFunctions.length > 0) {
lines.push('// ============================================================================')
lines.push('// Mutation Hooks (auto-invalidate on success)')
lines.push('// ============================================================================')
lines.push('')
for (const fn of mutationFunctions) {
const pascal = pascalCase(fn.camelName)
const transport = fn.transport || 'http'
const affectedContexts = getAffectedContexts(fn.affects)
lines.push(`/** Call ${fn.name}. Auto-invalidates: ${affectedContexts.join(', ')} */`)
lines.push(`export function use${pascal}() {`)
lines.push(` const mizan = useMizan()`)
if (fn.hasInput) {
lines.push(` return useCallback(async (input: ${fn.inputType}) => {`)
lines.push(` const result = await mizan.call<${fn.inputType}, ${fn.outputType}>('${fn.name}', input, '${transport}')`)
} else {
lines.push(` return useCallback(async () => {`)
lines.push(` const result = await mizan.call<void, ${fn.outputType}>('${fn.name}', undefined, '${transport}')`)
}
// Invalidation
if (affectedContexts.length === 1) {
lines.push(` await mizan.invalidateContext('${affectedContexts[0]}')`)
} else if (affectedContexts.length > 1) {
lines.push(` await Promise.all([`)
for (const ctx of affectedContexts) {
lines.push(` mizan.invalidateContext('${ctx}'),`)
}
lines.push(` ])`)
}
lines.push(` return result`)
lines.push(` }, [mizan])`)
lines.push(`}`)
lines.push('')
}
}
// ============================================================================
// Plain Function Hooks
// ============================================================================
if (plainFunctions.length > 0) {
lines.push('// ============================================================================')
lines.push('// Function Hooks')
lines.push('// ============================================================================')
lines.push('')
for (const fn of plainFunctions) {
const pascal = pascalCase(fn.camelName)
const transport = fn.transport || 'http'
if (fn.hasInput) {
lines.push(`/** Call ${fn.name}. Transport: ${transport} */`)
lines.push(`export function use${pascal}() {`)
lines.push(` return useMizanCall<${fn.inputType}, ${fn.outputType}>('${fn.name}', '${transport}')`)
lines.push(`}`)
} else {
lines.push(`/** Call ${fn.name}. Transport: ${transport} */`)
lines.push(`export function use${pascal}() {`)
lines.push(` return useMizanCall<void, ${fn.outputType}>('${fn.name}', '${transport}')`)
lines.push(`}`)
}
lines.push('')
}
}
// ============================================================================
// Re-exports
// ============================================================================
lines.push('// ============================================================================')
lines.push('// Re-exports from mizan library')
lines.push('// ============================================================================')
lines.push('')
lines.push("export { useMizan, useMizanStatus, usePush, DjangoError } from 'mizan'")
lines.push("export type { ConnectionStatus, PushMessage, PushListener } from 'mizan'")
lines.push('')
return lines.join('\n')
}
/**
* Generate server-side hydration helper (runs in Next.js server components).
* This is separate from the client file because it needs to run on the server.
*/
export function generateMizanServer(schema) {
const functions = schema['x-mizan-functions'] || []
const globalContexts = functions.filter(fn => fn.isContext === 'global')
if (globalContexts.length === 0) {
return null
}
// Collect type imports for global contexts
const typeImports = globalContexts.map(ctx => ctx.outputType).filter(Boolean)
const uniqueTypeImports = [...new Set(typeImports)].sort()
const lines = [
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'//',
'// Server-side functions for SSR hydration.',
'// These run in Next.js server components/layouts.',
'',
]
if (uniqueTypeImports.length > 0) {
lines.push(`import type { ${uniqueTypeImports.join(', ')} } from './generated.mizan'`)
lines.push('')
}
// Hydration type
lines.push('// ============================================================================')
lines.push('// Hydration Types')
lines.push('// ============================================================================')
lines.push('')
lines.push('/** Typed hydration data for SSR (global contexts only) */')
lines.push('export interface MizanHydrationData {')
for (const ctx of globalContexts) {
lines.push(` ${ctx.camelName}?: ${ctx.outputType}`)
}
lines.push('}')
lines.push('')
lines.push('/** @deprecated Use MizanHydrationData instead */')
lines.push('export type DjangoHydration = MizanHydrationData')
lines.push('')
// SSR Hydration Helper — single bundled GET
lines.push('// ============================================================================')
lines.push('// SSR Hydration Helper')
lines.push('// ============================================================================')
lines.push('')
lines.push('/**')
lines.push(' * Fetch hydration data for SSR via bundled context endpoint.')
lines.push(' *')
lines.push(' * Call this in your server component:')
lines.push(' * const hydration = await getMizanHydration(client)')
lines.push(' * return <MizanContext hydration={hydration}>...</MizanContext>')
lines.push(' */')
lines.push('export async function getMizanHydration(')
lines.push(" client: { request: (method: string, url: string, body?: unknown) => Promise<Response> }")
lines.push('): Promise<MizanHydrationData> {')
lines.push(' const hydration: MizanHydrationData = {}')
lines.push('')
lines.push(' try {')
lines.push(" const response = await client.request('GET', '/api/mizan/ctx/global/')")
lines.push(' if (response.ok) {')
lines.push(' const result = await response.json()')
for (const ctx of globalContexts) {
lines.push(` if (result?.${ctx.name} !== undefined) hydration.${ctx.camelName} = result.${ctx.name}`)
}
lines.push(' } else {')
lines.push(" console.error('[getMizanHydration] Global context fetch failed:', response.status)")
lines.push(' }')
lines.push(' } catch (e) {')
lines.push(" console.error('[getMizanHydration] Request failed:', e)")
lines.push(' }')
lines.push('')
lines.push(' return hydration')
lines.push('}')
lines.push('')
lines.push('/** @deprecated Use getMizanHydration instead */')
lines.push('export const getDjangoHydration = getMizanHydration')
lines.push('')
return lines.join('\n')
}
/**
* Generate all mizan files.
*/
export async function generateMizanFiles(schema, options = {}) {
const types = await generateMizanTypes(schema)
const provider = generateMizanProvider(schema, options)
const server = generateMizanServer(schema)
const forms = generateMizanForms(schema)
return { types, provider, server, forms }
}
/**
* Generate typed form hooks with Zod schemas.
*/
export function generateMizanForms(schema) {
const functions = schema['x-mizan-functions'] || []
// Group form functions by form name
const formFunctions = functions.filter(fn => fn.isForm)
const formGroups = new Map()
for (const fn of formFunctions) {
const formName = fn.formName
if (!formGroups.has(formName)) {
formGroups.set(formName, { schema: null, validate: null, submit: null, formset: {} })
}
const group = formGroups.get(formName)
if (fn.formRole === 'schema') {
group.schema = fn
group.formFields = fn.formFields || []
} else if (fn.formRole === 'validate') {
group.validate = fn
} else if (fn.formRole === 'submit') {
group.submit = fn
} else if (fn.formRole === 'formset_schema') {
group.formset.schema = fn
} else if (fn.formRole === 'formset_validate') {
group.formset.validate = fn
} else if (fn.formRole === 'formset_submit') {
group.formset.submit = fn
}
}
if (formGroups.size === 0) {
return null
}
const lines = [
"'use client'",
'',
'// AUTO-GENERATED by mizan - do not edit manually',
'// Regenerate with: npm run schemas',
'',
'// Typed form hooks with Zod validation.',
'// Zod schemas are generated from Django form field definitions.',
'// Client-side validation matches Django constraints (required, max_length, email, etc.)',
'',
"import { z } from 'zod'",
"import {",
" useDjangoFormCore,",
" useDjangoFormsetCore,",
" type DjangoFormState,",
" type DjangoFormsetState,",
" type FormOptions,",
"} from 'mizan'",
'',
'// ============================================================================',
'// Zod Schemas',
'// ============================================================================',
'',
]
// Generate Zod schemas for each form
for (const [formName, group] of formGroups) {
if (!group.schema) continue
const pascalName = toPascalCase(formName)
const schemaName = `${pascalName}Schema`
const fields = group.formFields || []
lines.push(`/**`)
lines.push(` * Zod schema for ${formName} form`)
lines.push(` * Generated from Django form field definitions`)
lines.push(` */`)
lines.push(`export const ${schemaName} = z.object({`)
for (const field of fields) {
const zodField = generateZodField(field)
lines.push(` ${field.name}: ${zodField},`)
}
lines.push(`})`)
lines.push('')
}
// Generate TypeScript types from Zod schemas
lines.push('// ============================================================================')
lines.push('// Form Data Types (inferred from Zod schemas)')
lines.push('// ============================================================================')
lines.push('')
for (const [formName, group] of formGroups) {
if (!group.schema) continue
const pascalName = toPascalCase(formName)
const schemaName = `${pascalName}Schema`
const typeName = `${pascalName}FormData`
lines.push(`/** Form data type for ${formName}, inferred from Zod schema */`)
lines.push(`export type ${typeName} = z.infer<typeof ${schemaName}>`)
lines.push('')
}
lines.push('// ============================================================================')
lines.push('// Form Hooks')
lines.push('// ============================================================================')
lines.push('')
// Generate hooks for each form
for (const [formName, group] of formGroups) {
if (!group.schema) continue
const pascalName = toPascalCase(formName)
const hookName = `use${pascalName}Form`
const typeName = `${pascalName}FormData`
const schemaName = `${pascalName}Schema`
lines.push(`/**`)
lines.push(` * Typed form hook for ${formName}`)
lines.push(` *`)
lines.push(` * Features:`)
lines.push(` * - Full TypeScript inference for form fields`)
lines.push(` * - Client-side Zod validation (instant feedback)`)
lines.push(` * - Server-side Django validation (authoritative)`)
lines.push(` */`)
lines.push(`export function ${hookName}(`)
lines.push(` options?: FormOptions`)
lines.push(`): DjangoFormState<${typeName}> {`)
lines.push(` return useDjangoFormCore<${typeName}>({`)
lines.push(` name: '${formName}',`)
lines.push(` zodSchema: ${schemaName},`)
lines.push(` options,`)
lines.push(` })`)
lines.push(`}`)
lines.push('')
// Generate formset hook if formset is enabled
if (group.formset.schema) {
const formsetHookName = `use${pascalName}Formset`
lines.push(`/**`)
lines.push(` * Typed formset hook for ${formName}`)
lines.push(` */`)
lines.push(`export function ${formsetHookName}(`)
lines.push(` initialCount?: number,`)
lines.push(` liveValidation?: boolean`)
lines.push(`): DjangoFormsetState<${typeName}> {`)
lines.push(` return useDjangoFormsetCore<${typeName}>({`)
lines.push(` name: '${formName}',`)
lines.push(` zodSchema: ${schemaName},`)
lines.push(` initialCount,`)
lines.push(` liveValidation,`)
lines.push(` })`)
lines.push(`}`)
lines.push('')
}
}
// Export list of form names for reference
lines.push('// ============================================================================')
lines.push('// Form Registry')
lines.push('// ============================================================================')
lines.push('')
lines.push('export const MIZAN_FORMS = {')
for (const [formName, group] of formGroups) {
if (!group.schema) continue
const pascalName = toPascalCase(formName)
lines.push(` ${toCamelCase(formName)}: {`)
lines.push(` name: '${formName}',`)
lines.push(` schema: ${pascalName}Schema,`)
lines.push(` hook: 'use${pascalName}Form',`)
lines.push(` hasFormset: ${!!group.formset.schema},`)
lines.push(` },`)
}
lines.push('} as const')
lines.push('')
return lines.join('\n')
}
/**
* Generate a Zod field definition from Django field metadata.
*/
function generateZodField(field) {
const { zodType, required, constraints } = field
let zodCode = ''
// Base type
switch (zodType) {
case 'boolean':
zodCode = 'z.boolean()'
break
case 'number':
zodCode = 'z.number()'
if (constraints.int) {
zodCode += '.int()'
}
break
case 'array':
zodCode = `z.array(z.${constraints.items || 'string'}())`
break
case 'file':
zodCode = 'z.any()'
break
default:
zodCode = 'z.string()'
}
// Add constraints
if (zodType === 'string') {
if (constraints.email) {
zodCode += ".email('Invalid email address')"
} else if (constraints.url) {
zodCode += ".url('Invalid URL')"
}
if (constraints.regex) {
const escapedRegex = constraints.regex.replace(/\\/g, '\\\\').replace(/'/g, "\\'")
const message = constraints.regexMessage || 'Invalid format'
zodCode += `.regex(new RegExp('${escapedRegex}'), '${message}')`
}
if (constraints.min !== undefined) {
zodCode += `.min(${constraints.min})`
}
if (constraints.max !== undefined) {
zodCode += `.max(${constraints.max})`
}
} else if (zodType === 'number') {
if (constraints.min !== undefined) {
zodCode += `.min(${constraints.min})`
}
if (constraints.max !== undefined) {
zodCode += `.max(${constraints.max})`
}
}
// Handle optional fields
if (!required) {
if (zodType === 'boolean') {
zodCode += '.default(false)'
} else {
zodCode += '.optional()'
}
}
return zodCode
}
/**
* Convert form name to PascalCase for type names.
*/
function toPascalCase(str) {
return str
.split(/[.\-_]/)
.map(part => part.charAt(0).toUpperCase() + part.slice(1))
.join('')
}
/**
* Convert form name to camelCase for object keys.
*/
function toCamelCase(str) {
const pascal = toPascalCase(str)
return pascal.charAt(0).toLowerCase() + pascal.slice(1)
}
/**
* Convert camelCase to PascalCase.
*/
function pascalCase(str) {
return str.charAt(0).toUpperCase() + str.slice(1)
}

View File

@@ -0,0 +1,198 @@
/**
* Stage 1 Codegen — Framework-agnostic TypeScript output.
*
* Produces:
* types.ts — interfaces from OpenAPI schema
* contexts/<name>.ts — fetchXxxContext(params) per context group
* mutations/<name>.ts — callXxx(args) per mutation
* functions/<name>.ts — callXxx(args) per plain function
* index.ts — re-exports
*/
import openapiTS, { astToString } from 'openapi-typescript'
// ─── Helpers ────────────────────────────────────────────────────────────────
function pascalCase(str) {
return str
.split(/[.\-_]/)
.map(part => part.charAt(0).toUpperCase() + part.slice(1))
.join('')
}
function camelCase(str) {
const p = pascalCase(str)
return p.charAt(0).toLowerCase() + p.slice(1)
}
// TypeScript SyntaxKind values for openapi-typescript AST
const SyntaxKind = {
InterfaceDeclaration: 265,
PropertySignature: 172,
Identifier: 80,
}
function idName(node) {
return node?.kind === SyntaxKind.Identifier ? node.escapedText : undefined
}
function getSchemaNamesFromAst(ast) {
if (!Array.isArray(ast)) return []
const componentsNode = ast.find(
n => n?.kind === SyntaxKind.InterfaceDeclaration && idName(n?.name) === 'components'
)
if (!componentsNode?.members) return []
const schemasProp = componentsNode.members.find(
m => m?.kind === SyntaxKind.PropertySignature && idName(m?.name) === 'schemas' && Array.isArray(m?.type?.members)
)
if (!schemasProp) return []
return schemasProp.type.members
.map(m => m?.kind === SyntaxKind.PropertySignature ? idName(m.name) : undefined)
.filter(n => typeof n === 'string')
}
// ─── Types ──────────────────────────────────────────────────────────────────
export async function generateTypes(schema) {
const ast = await openapiTS(schema)
const schemaNames = getSchemaNamesFromAst(ast)
const typesCode = astToString(ast)
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
typesCode,
'',
'// Convenience type exports',
...schemaNames.map(name => `export type ${name} = components["schemas"]["${name}"]`),
'',
]
return lines.join('\n')
}
// ─── Context Files ──────────────────────────────────────────────────────────
export function generateContextFile(ctxName, ctxMeta, functions) {
const pascal = pascalCase(ctxName)
const ctxFunctions = functions.filter(fn => fn.isContext === ctxName)
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
"import { mizanFetch } from '@mizan/runtime'",
'',
]
// Import output types
const typeImports = ctxFunctions.map(fn => fn.outputType).filter(Boolean)
if (typeImports.length > 0) {
lines.push(`import type { ${[...new Set(typeImports)].join(', ')} } from '../types'`)
lines.push('')
}
// Data interface
lines.push(`export interface ${pascal}ContextData {`)
for (const fn of ctxFunctions) {
lines.push(` ${fn.name}: ${fn.outputType}`)
}
lines.push('}')
lines.push('')
// Params interface (from x-mizan-contexts)
const params = ctxMeta?.params || {}
const paramEntries = Object.entries(params)
if (paramEntries.length > 0) {
lines.push(`export interface ${pascal}ContextParams {`)
for (const [pName, pMeta] of paramEntries) {
const tsType = pMeta.type === 'integer' || pMeta.type === 'number' ? 'number' : pMeta.type === 'boolean' ? 'boolean' : 'string'
const optional = pMeta.required ? '' : '?'
lines.push(` ${pName}${optional}: ${tsType}`)
}
lines.push('}')
} else {
lines.push(`export type ${pascal}ContextParams = Record<string, never>`)
}
lines.push('')
// Fetch function
lines.push(`export function fetch${pascal}Context(params: ${pascal}ContextParams): Promise<${pascal}ContextData> {`)
lines.push(` return mizanFetch('${ctxName}', params)`)
lines.push('}')
lines.push('')
return lines.join('\n')
}
// ─── Mutation Files ─────────────────────────────────────────────────────────
export function generateMutationFile(fn) {
const pascal = pascalCase(fn.camelName)
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
"import { mizanCall } from '@mizan/runtime'",
'',
]
// Import types
const typeImports = []
if (fn.hasInput && fn.inputType) typeImports.push(fn.inputType)
if (fn.outputType) typeImports.push(fn.outputType)
if (typeImports.length > 0) {
lines.push(`import type { ${[...new Set(typeImports)].join(', ')} } from '../types'`)
lines.push('')
}
// Call function
if (fn.hasInput) {
lines.push(`export function call${pascal}(args: ${fn.inputType}): Promise<${fn.outputType}> {`)
} else {
lines.push(`export function call${pascal}(): Promise<${fn.outputType}> {`)
}
lines.push(` return mizanCall('${fn.name}', ${fn.hasInput ? 'args' : '{}'})`)
lines.push('}')
lines.push('')
return lines.join('\n')
}
// ─── Function Files (plain, no context, no affects) ─────────────────────────
export function generateFunctionFile(fn) {
// Same shape as mutation, just different semantics
return generateMutationFile(fn)
}
// ─── Index ──────────────────────────────────────────────────────────────────
export function generateStage1Index(schema) {
const functions = schema['x-mizan-functions'] || []
const contextGroups = schema['x-mizan-contexts'] || {}
const lines = [
'// AUTO-GENERATED by mizan — do not edit',
'',
"export * from './types'",
'',
]
// Context exports
for (const ctxName of Object.keys(contextGroups)) {
const pascal = pascalCase(ctxName)
lines.push(`export { fetch${pascal}Context, type ${pascal}ContextData, type ${pascal}ContextParams } from './contexts/${ctxName}'`)
}
if (Object.keys(contextGroups).length > 0) lines.push('')
// Mutation + function exports
const regularFns = functions.filter(fn => !fn.isContext && !fn.isForm)
for (const fn of regularFns) {
const pascal = pascalCase(fn.camelName)
lines.push(`export { call${pascal} } from './${fn.affects ? 'mutations' : 'functions'}/${fn.camelName}'`)
}
if (regularFns.length > 0) lines.push('')
return lines.join('\n')
}

View File

@@ -0,0 +1,16 @@
{
"name": "generate",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"type": "commonjs",
"dependencies": {
"openapi-typescript": "^7.13.0"
}
}

View File

@@ -0,0 +1,49 @@
[project]
name = "mizan"
version = "1.0.1"
description = "Django + React server functions framework"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"django>=5.0",
"django-ninja>=1.0",
"django-readers>=2.0",
"pydantic>=2.0",
"PyJWT>=2.0",
]
[project.optional-dependencies]
cache = [
"redis>=5.0",
]
channels = [
"channels>=4.0",
"channels-redis>=4.0",
]
allauth = [
"django-allauth>=65.0",
]
webauthn = [
"fido2>=2.0",
]
shapes = [
"django-readers>=2.0",
]
dev = [
"pytest>=8.0",
"pytest-django>=4.9",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/mizan"]
[tool.pytest.ini_options]
DJANGO_SETTINGS_MODULE = "tests.settings"
pythonpath = ["src", "."]
testpaths = ["src/mizan/tests"]
python_classes = ["*Tests", "*Test", "Test*"]
python_functions = ["test_*"]

View File

@@ -0,0 +1,187 @@
"""
mizan - Django + React unified framework
Server functions are the core primitive. Everything else builds on them.
## Quick Start
### 1. urls.py - HTTP endpoint
```python
from mizan import urls as mizan_urls
urlpatterns = [
path('api/mizan/', include(mizan_urls)),
]
```
### 2. asgi.py - WebSocket support (optional)
```python
from mizan import wrap_asgi
from django.core.asgi import get_asgi_application
application = wrap_asgi(get_asgi_application())
```
### 3. Define server functions
```python
# apps/myapp/clients.py
from mizan import client
from pydantic import BaseModel
class EchoOutput(BaseModel):
message: str
# HTTP-only function (default)
@client
def echo(request, text: str) -> EchoOutput:
return EchoOutput(message=f"Echo: {text}")
# Global context (singleton, SSR-hydrated)
@client(context='global')
def current_user(request) -> UserOutput:
return UserOutput(email=request.user.email)
# WebSocket-enabled for real-time
@client(websocket=True)
def send_message(request, room_id: int, text: str) -> MessageOutput:
return MessageOutput(...)
```
### 4. Auto-discover in apps.py
```python
class MyAppConfig(AppConfig):
def ready(self):
from mizan.setup import mizan_clients
mizan_clients('apps')
```
### 5. Frontend - generate types and use
```bash
npm run schemas
```
```tsx
import { useEcho, useCurrentUser } from '@/api'
const user = useCurrentUser()
const echo = useEcho()
await echo({ text: 'hello' })
```
## What You Get
| Backend | Frontend | Transport |
|------------------------------------|-----------------------|------------|
| `@client` | `useXxx()` hook | HTTP |
| `@client(context='global')` | `useXxx()` + SSR | HTTP |
| `@client(context='local')` | `<XxxProvider>` + hook| HTTP |
| `@client(websocket=True)` | `useXxx()` hook | WebSocket |
| `@compose(...)` | `<XxxProvider>` combined | varies |
| `mizanFormMixin` | `useXxxForm()` + Zod | HTTP |
| `ReactChannel` | `useXxxChannel()` | WebSocket |
"""
# All imports at module level (sorted)
from . import channels
from . import client as client_module
from . import export
from . import forms
from . import setup
from .channels import ReactChannel
from .channels import register as register_channel
from .client import ComposedContext, GlobalContext, ReactContext, ServerFunction, client, compose
# Shape is lazy-loaded via __getattr__ because django_readers
# imports contenttypes, which can't happen during apps.populate()
from .setup import (
mizan_clients,
mizan_module,
get_channel,
get_function,
register,
register_as,
)
def __getattr__(name):
"""Lazy loading for modules that can't be imported at app load time."""
if name == "urls":
from .urls import urlpatterns as mizan_patterns
return mizan_patterns
if name == "Shape":
from .shapes import Shape
return Shape
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def wrap_asgi(http_application):
"""
Wrap an ASGI application with mizan WebSocket support.
Usage in asgi.py:
from django.core.asgi import get_asgi_application
from mizan import wrap_asgi
application = wrap_asgi(get_asgi_application())
This adds:
- WebSocket routing at /ws/ for RPC and channels
- Authentication middleware for WebSocket connections
"""
try:
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path
except ImportError:
raise ImportError(
"django-channels is required for WebSocket support.\n"
"Install with: pip install channels channels-redis\n"
"Add 'channels' to INSTALLED_APPS and configure CHANNEL_LAYERS."
)
from .channels.connection import DjangoReactConsumer
return ProtocolTypeRouter(
{
"http": http_application,
"websocket": AuthMiddlewareStack(
URLRouter(
[
path("ws/", DjangoReactConsumer.as_asgi()),
]
)
),
}
)
__all__ = [
# Decorators & Contexts
"client",
"compose",
"ReactContext",
"GlobalContext",
"ServerFunction",
"ComposedContext",
# Setup
"mizan_clients",
"mizan_module",
"register",
"register_as",
"get_function",
"get_channel",
# ASGI
"wrap_asgi",
# Channels
"ReactChannel",
"register_channel",
# Shapes
"Shape",
# Submodules
"client_module",
"setup",
"forms",
"channels",
"export",
]

View File

@@ -0,0 +1,91 @@
import inspect
from importlib import import_module
from inspect import isclass
from typing import Protocol, Any
from django.conf import settings
def get_members(path):
try:
module = import_module(path)
except ModuleNotFoundError:
print('Could not import module "{}"'.format(path))
return []
members = [
(name, member)
for name, member in inspect.getmembers(module)
if not isclass(member) or (member.__module__ == module.__name__)
]
return members
class DjangoAppVisitorHandler(Protocol):
def on_module(
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
) -> None: ...
class DjangoAppVisitor:
"""
Discovers Python modules under each Django app following conventions:
- <app>/<module>.py -> url_prefix "<renamed>/"
- <app>/<module>/**/*.py -> url_prefix "<renamed>/<subdirs...>/<module>/"
Example:
<app>/<module>/forms/nksn.py -> url_prefix "<renamed>/forms/nksn/"
module_path "<app>.module.forms.nksn"
"""
def __init__(
self,
*,
layer: str,
apps_root: str = "",
):
self.apps_root = apps_root
self.layer = layer
def visit(self, handler: DjangoAppVisitorHandler) -> None:
apps_dir = (
settings.BASE_DIR / self.apps_root if self.apps_root else settings.BASE_DIR
)
if not apps_dir.is_dir():
apps_dir = settings.BASE_DIR
module_prefix = f"{self.apps_root}." if self.apps_root else ""
for app_name in settings.INSTALLED_APPS:
if app_name.startswith(self.apps_root + "."):
app_name = app_name[(len(self.apps_root) + 1) :]
app_dir = apps_dir / app_name
if not app_dir.exists():
continue
app_module = f"{module_prefix}{app_name}"
# 1) Visit package: <app>/<module>/**/*.py
layer_dir = app_dir / self.layer
if layer_dir.is_dir():
for py_file in layer_dir.rglob("*.py"):
if py_file.name == "__init__.py":
continue
relative_path = py_file.relative_to(layer_dir).with_suffix("")
parts = list(relative_path.parts)
dotted = ".".join(parts)
handler.on_module(
app_name,
parts,
get_members(f"{app_module}.{self.layer}.{dotted}"),
)
# 2) Visit module module file: <app>/module.py
layer_file = app_dir / f"{self.layer}.py"
if layer_file.is_file():
handler.on_module(
app_name, [], get_members(f"{app_module}.{self.layer}")
)

View File

@@ -0,0 +1,65 @@
# Cache Module — Known Issues
Issues identified by 8-domain-expert review. Status tracked here.
## Critical (Security / Data Corruption)
### 1. ~~User-scoped content cached without user_id~~ FIXED
`context_fetch_view` now extracts `user_id` from `request.user.pk` and
passes it to `cache_get`/`cache_put`.
### 2. Purge race condition (non-atomic index operations)
`cache_purge` does index reads and deletes as separate operations.
Concurrent `cache_put` between steps can orphan entries.
**Status:** Partially mitigated by AND semantics fix. Full atomicity
(Lua script or WATCH/MULTI) still needed for Redis backend.
### 3. ~~No Redis error handling~~ FIXED
All cache operations in `executor.py` wrapped in try/except with
`logger.warning`. Redis failure falls through to uncached execution.
### 4. ~~Scoped purge uses OR semantics~~ FIXED
Changed to AND (intersection). `{user_id: 5, org_id: 3}` now only
deletes entries matching BOTH params.
## High (Correctness / Operability)
### 5. ~~No TTL on Redis entries~~ FIXED
`RedisCache.put` now sets `ex=86400` (24h safety-net TTL) by default.
### 6. Cross-language str() vs String() divergence
Python `str(True)` -> `"True"`, JS `String(true)` -> `"true"`.
**Status:** Open. Needs canonical stringification rules in protocol spec.
### 7. Broad purge doesn't clean per-param sub-indexes
**Status:** Open. Slow memory leak in Redis.
### 8. ~~build_index_keys doesn't stringify values~~ FIXED
Now calls `str(v)` on all values, matching `derive_cache_key`.
### 9. ~~Silent exception swallowing in get_cache()~~ FIXED
Now logs warnings for partial config and connection failures.
### 10. ~~_initialized flag not thread-safe~~ FIXED
Now uses `threading.Lock` for thread-safe initialization.
## Medium (Design / Performance)
### 11. No thundering-herd protection
**Status:** Open. Concurrent cold misses all execute and write.
### 12. ~~Wire-protocol internals in __all__~~ FIXED
`derive_cache_key` and `build_index_keys` removed from `__all__`.
### 13. Inconsistent API pattern
**Status:** Open. `cache_get`/`cache_put` take explicit args but executor
fetches from globals.
### 14. ~~clear() uses SCAN + DELETE without pipeline~~ FIXED
Now uses pipeline with UNLINK for batched async deletes.
### 15. ~~No Redis connection timeouts~~ FIXED
`socket_connect_timeout=5`, `socket_timeout=5`, `health_check_interval=30`.
### 16. No RedisCache test coverage
**Status:** Open. Only MemoryCache is tested.

View File

@@ -0,0 +1,142 @@
"""
mizan.cache — Origin-side cache implementing the Mizan cache protocol.
Simple key-value cache with HMAC-derived keys. No reverse indexes.
Scoped purge recomputes the key and deletes directly.
Broad purge uses key-prefix scan (rare operation).
Usage:
from mizan.cache import get_cache, cache_get, cache_put, cache_purge
"""
from __future__ import annotations
import logging
import threading
from typing import Any
from .backend import CacheBackend, MemoryCache, RedisCache
from .keys import derive_cache_key, CONTEXT_KEY_PREFIX
logger = logging.getLogger("mizan.cache")
_cache_instance: CacheBackend | None = None
_initialized = False
_init_lock = threading.Lock()
def get_cache() -> CacheBackend | None:
"""
Get the configured cache backend, or None if caching is disabled.
Thread-safe.
"""
global _cache_instance, _initialized
if _initialized:
return _cache_instance
with _init_lock:
if _initialized:
return _cache_instance
_initialized = True
try:
from mizan.setup.settings import get_settings
settings = get_settings()
if settings.cache_secret and settings.cache_redis_url:
_cache_instance = RedisCache(settings.cache_redis_url)
logger.info("Mizan cache enabled (Redis: %s)", settings.cache_redis_url)
elif settings.cache_secret and not settings.cache_redis_url:
logger.warning(
"MIZAN_CACHE_SECRET is set but MIZAN_CACHE_REDIS_URL is missing. "
"Cache is disabled."
)
elif settings.cache_redis_url and not settings.cache_secret:
logger.warning(
"MIZAN_CACHE_REDIS_URL is set but MIZAN_CACHE_SECRET is missing. "
"Cache is disabled."
)
except Exception:
logger.warning("Failed to initialize Mizan cache", exc_info=True)
_cache_instance = None
return _cache_instance
def set_cache(backend: CacheBackend | None) -> None:
"""Override the cache backend. For testing."""
global _cache_instance, _initialized
_cache_instance = backend
_initialized = True
def reset_cache() -> None:
"""Reset to uninitialized state. For testing teardown."""
global _cache_instance, _initialized
_cache_instance = None
_initialized = False
def cache_get(
secret: str,
backend: CacheBackend,
context: str,
params: dict[str, Any],
user_id: str | None = None,
rev: int = 0,
) -> bytes | None:
"""Look up a cached context response."""
key = derive_cache_key(secret, context, params, user_id, rev)
return backend.get(key)
def cache_put(
secret: str,
backend: CacheBackend,
context: str,
params: dict[str, Any],
value: bytes,
user_id: str | None = None,
rev: int = 0,
) -> None:
"""Store a context response in the cache."""
key = derive_cache_key(secret, context, params, user_id, rev)
backend.set(key, value)
def cache_purge(
backend: CacheBackend,
context: str,
params: dict[str, Any] | None = None,
secret: str | None = None,
user_id: str | None = None,
rev: int = 0,
) -> int:
"""
Purge cached entries for a context.
Scoped purge (params provided): recomputes the HMAC key and deletes
it directly. One DELETE, no index needed.
Broad purge (no params): scans by key prefix "ctx:{context}:*".
This is a rare operation (Tier 3 fallback in invalidation).
"""
if params is not None and len(params) > 0 and secret:
key = derive_cache_key(secret, context, params, user_id, rev)
return 1 if backend.delete(key) else 0
else:
prefix = f"{CONTEXT_KEY_PREFIX}{context}:"
return backend.delete_by_prefix(prefix)
__all__ = [
"CacheBackend",
"MemoryCache",
"RedisCache",
"get_cache",
"set_cache",
"reset_cache",
"cache_get",
"cache_put",
"cache_purge",
]

View File

@@ -0,0 +1,115 @@
"""
Cache backends — MemoryCache (testing) and RedisCache (production).
Simple key-value stores. No reverse indexes. Cache keys are derived
from HMAC, so scoped purge just recomputes the key and deletes it.
Broad purge uses key-prefix scan (rare operation).
"""
from __future__ import annotations
from typing import Protocol
class CacheBackend(Protocol):
"""Interface that all Mizan cache backends implement."""
def get(self, key: str) -> bytes | None: ...
def set(self, key: str, value: bytes) -> None: ...
def delete(self, key: str) -> bool: ...
def delete_by_prefix(self, prefix: str) -> int: ...
def clear(self) -> None: ...
class MemoryCache:
"""
In-memory cache backend for testing.
Uses a Python dict. No persistence, no cross-process sharing.
"""
def __init__(self) -> None:
self._store: dict[str, bytes] = {}
def get(self, key: str) -> bytes | None:
return self._store.get(key)
def set(self, key: str, value: bytes) -> None:
self._store[key] = value
def delete(self, key: str) -> bool:
if key in self._store:
del self._store[key]
return True
return False
def delete_by_prefix(self, prefix: str) -> int:
to_delete = [k for k in self._store if k.startswith(prefix)]
for k in to_delete:
del self._store[k]
return len(to_delete)
def clear(self) -> None:
self._store.clear()
class RedisCache:
"""
Redis-backed cache backend for production.
Simple GET/SET/DEL. No reverse indexes. Scoped purge recomputes
the HMAC key and deletes directly. Broad purge uses SCAN.
"""
DEFAULT_TTL = 86400 # 24h safety-net
def __init__(
self,
redis_url: str,
prefix: str = "mizan:",
ttl: int | None = None,
) -> None:
try:
import redis as redis_lib
except ImportError:
raise ImportError(
"Redis is required for Mizan's cache backend. "
"Install it with: pip install mizan[cache]"
)
self._client = redis_lib.from_url(
redis_url,
socket_connect_timeout=5,
socket_timeout=5,
health_check_interval=30,
retry_on_timeout=True,
max_connections=50,
)
self._prefix = prefix
self._ttl = ttl if ttl is not None else self.DEFAULT_TTL
def _key(self, key: str) -> str:
return f"{self._prefix}{key}"
def get(self, key: str) -> bytes | None:
return self._client.get(self._key(key))
def set(self, key: str, value: bytes) -> None:
self._client.set(self._key(key), value, ex=self._ttl)
def delete(self, key: str) -> bool:
return self._client.unlink(self._key(key)) > 0
def delete_by_prefix(self, prefix: str) -> int:
pattern = f"{self._prefix}{prefix}*"
count = 0
cursor = 0
while True:
cursor, keys = self._client.scan(cursor, match=pattern, count=1000)
if keys:
count += self._client.unlink(*keys)
if cursor == 0:
break
return count
def clear(self) -> None:
self.delete_by_prefix("")

View File

@@ -0,0 +1,59 @@
"""
Cache key derivation — HMAC-SHA256 over JSON-canonical form.
Protocol-critical: every Mizan adapter must produce identical output
for identical inputs. Cross-language conformance verified by pin tests.
Scoped purge recomputes the key directly — no reverse index needed.
Broad purge uses a context prefix scan.
"""
from __future__ import annotations
import hashlib
import hmac
import json
from typing import Any
# Context prefix for broad purge (SCAN pattern)
CONTEXT_KEY_PREFIX = "ctx:"
def derive_cache_key(
secret: str,
context: str,
params: dict[str, Any],
user_id: str | None = None,
rev: int = 0,
) -> str:
"""
Derive a deterministic HMAC-SHA256 cache key.
Returns a prefixed key: "ctx:{context}:{hmac_hex}" so that
broad purge can SCAN by prefix "ctx:{context}:*".
"""
def _normalize(v: Any) -> str:
"""Normalize values for cross-language HMAC consistency.
Python str(True)="True" but JS String(true)="true". Use JSON-native forms."""
if v is True:
return "true"
if v is False:
return "false"
if v is None:
return "null"
return str(v)
sorted_params = {k: _normalize(v) for k, v in sorted(params.items())}
key_data: dict[str, Any] = {"c": context, "p": sorted_params, "r": rev}
if user_id is not None:
key_data["u"] = str(user_id)
message = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
hmac_hex = hmac.new(
secret.encode("utf-8"),
message.encode("utf-8"),
hashlib.sha256,
).hexdigest()
return f"{CONTEXT_KEY_PREFIX}{context}:{hmac_hex}"

View File

@@ -0,0 +1,543 @@
"""
mizan.channels - Real-time WebSocket communication.
Type-safe bidirectional messaging between Django and React via WebSockets.
Hooks are auto-generated with full TypeScript types.
## Basic Usage
```python
# channels.py
from pydantic import BaseModel
from mizan import channels
class ChatChannel(channels.ReactChannel):
class Params(BaseModel):
room: str
class ReactMessage(BaseModel):
text: str
class DjangoMessage(BaseModel):
user: str
text: str
timestamp: datetime
def authorize(self, params: Params) -> bool:
return self.user.is_authenticated
def group(self, params: Params) -> str:
return f'chat_{params.room}'
def receive(self, params: Params, msg: ReactMessage) -> DjangoMessage | None:
return self.DjangoMessage(
user=self.user.email,
text=msg.text,
timestamp=now(),
)
channels.register(ChatChannel, 'chat')
```
```python
# asgi.py
from mizan import channels
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": channels.get_websocket_application(),
})
```
## Frontend Usage (auto-generated)
```tsx
import { useChatChannel } from '@/api/generated.channels'
function Chat({ room }) {
const chat = useChatChannel({ room })
chat.status // 'connecting' | 'connected' | 'disconnected'
chat.messages // DjangoMessage[]
chat.send({ text: 'Hello' }) // ReactMessage
}
```
## Server Push
```python
await ChatChannel.push(room='general', message=ChatChannel.DjangoMessage(...))
```
"""
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Type
from pydantic import BaseModel
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from ninja import NinjaAPI
logger = logging.getLogger(__name__)
# =============================================================================
# Base Classes
# =============================================================================
class ReactChannel:
"""
Base class for WebSocket channels.
Define nested Pydantic classes for typed messaging:
- Params: Query parameters for subscribing (optional)
- ReactMessage: Messages from browser to server (optional)
- DjangoMessage: Messages from server to browser (optional)
Implement required methods:
- authorize(): Permission check for connection
- group(): Which group to broadcast to
Optionally implement:
- receive(): Handle incoming ReactMessage, return DjangoMessage to broadcast
- on_connect(): Called after successful connection
- on_disconnect(): Called when connection closes
"""
# Nested classes (optional, defined by subclasses)
Params: ClassVar[Type[BaseModel] | None] = None
ReactMessage: ClassVar[Type[BaseModel] | None] = None
DjangoMessage: ClassVar[Type[BaseModel] | None] = None
# Set by the framework when handling a connection
user: "AbstractBaseUser | AnonymousUser"
_channel_layer: Any = None
_channel_name: str = ""
_registered_name: ClassVar[str] = ""
_params_dict: dict = {}
_groups: set[str]
def __init__(self):
self._groups = set()
self._params_dict = {}
def authorize(self, params: BaseModel | None = None) -> bool:
"""
Permission check. Return True to allow connection, False to reject.
Override this to implement custom authorization logic.
"""
raise NotImplementedError(
f"{self.__class__.__name__} must implement authorize()"
)
def group(self, params: BaseModel | None = None) -> str:
"""
Return the group name for broadcasting.
Messages returned from receive() are broadcast to this group.
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement group()")
def receive(self, params: BaseModel | None, msg: BaseModel) -> BaseModel | None:
"""
Handle incoming ReactMessage.
Return a DjangoMessage to broadcast to the group, or None to skip.
Override this to implement message handling.
"""
return None
async def on_connect(self, params: BaseModel | None = None) -> None:
"""Called after successful connection and group join."""
pass
async def on_disconnect(self) -> None:
"""Called when the connection closes."""
pass
# -------------------------------------------------------------------------
# Internal Methods (used by the consumer)
# -------------------------------------------------------------------------
async def _join_group(self, group_name: str) -> None:
"""Join a channel layer group."""
if self._channel_layer:
await self._channel_layer.group_add(group_name, self._channel_name)
self._groups.add(group_name)
async def _leave_group(self, group_name: str) -> None:
"""Leave a channel layer group."""
if self._channel_layer and group_name in self._groups:
await self._channel_layer.group_discard(group_name, self._channel_name)
self._groups.discard(group_name)
async def _leave_all_groups(self) -> None:
"""Leave all joined groups."""
for group_name in list(self._groups):
await self._leave_group(group_name)
async def _broadcast(self, group_name: str, message: BaseModel) -> None:
"""Broadcast a message to a group."""
if self._channel_layer:
await self._channel_layer.group_send(
group_name,
{
"type": "channel.message",
"channel": self._registered_name,
"params": self._params_dict,
"data": message.model_dump(mode="json"),
"message_type": message.__class__.__name__,
},
)
# -------------------------------------------------------------------------
# Class Methods for Server Push
# -------------------------------------------------------------------------
@classmethod
async def push(cls, message: BaseModel, **params) -> None:
"""
Push a message from server code (views, tasks, signals).
Usage:
await ChatChannel.push(
room='general',
message=ChatChannel.DjangoMessage(user='system', text='Hello')
)
"""
from channels.layers import get_channel_layer
channel_layer = get_channel_layer()
if not channel_layer:
logger.warning(
f"No channel layer configured, cannot push to {cls.__name__}"
)
return
# Build params model if defined
params_obj = None
if cls.Params:
params_obj = cls.Params(**params)
# Get group name
instance = cls()
group_name = instance.group(params_obj)
# Send to group
await channel_layer.group_send(
group_name,
{
"type": "channel.message",
"channel": cls._registered_name,
"params": params,
"data": message.model_dump(mode="json"),
"message_type": message.__class__.__name__,
},
)
# =============================================================================
# Registry
# =============================================================================
_registry: dict[str, Type[ReactChannel]] = {}
def register(channel_class: Type[ReactChannel], name: str) -> None:
"""
Register a channel.
Args:
channel_class: The ReactChannel subclass to register
name: URL-friendly name (used in subscriptions)
"""
if name in _registry:
raise ValueError(f"Channel '{name}' is already registered")
channel_class._registered_name = name
# Validate the channel class
if not hasattr(channel_class, "authorize"):
raise ValueError(f"{channel_class.__name__} must implement authorize()")
if not hasattr(channel_class, "group"):
raise ValueError(f"{channel_class.__name__} must implement group()")
_registry[name] = channel_class
logger.debug(f"Registered channel: {name} -> {channel_class.__name__}")
def get_channel(name: str) -> Type[ReactChannel] | None:
"""Get a registered channel class by name."""
return _registry.get(name)
def get_registered_channels() -> dict[str, Type[ReactChannel]]:
"""Get all registered channel classes."""
return dict(_registry)
# =============================================================================
# WebSocket Consumer
# =============================================================================
def get_websocket_application():
"""
Get the WebSocket application for ASGI.
Usage in asgi.py:
from mizan import channels
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": channels.get_websocket_application(),
})
"""
try:
from channels.routing import URLRouter
from channels.auth import AuthMiddlewareStack
from django.urls import path
except ImportError:
raise ImportError(
"django-channels is required for WebSocket support. "
"Install it with: pip install channels channels-redis"
)
from .connection import DjangoReactConsumer
return AuthMiddlewareStack(
URLRouter(
[
path("ws/", DjangoReactConsumer.as_asgi()),
]
)
)
# =============================================================================
# Schema Export (for TypeScript generation)
# =============================================================================
def get_channels_schema() -> dict:
"""
Get schema for all registered channels (for TypeScript generation).
Returns a dict suitable for the frontend code generator.
"""
schema = {"channels": {}}
for name, channel_class in _registry.items():
channel_schema = {
"name": name,
"params": None,
"reactMessage": None,
"djangoMessage": None,
}
# Extract Params schema
if hasattr(channel_class, "Params") and channel_class.Params:
channel_schema["params"] = channel_class.Params.model_json_schema()
# Extract ReactMessage schema
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
channel_schema[
"reactMessage"
] = channel_class.ReactMessage.model_json_schema()
# Extract DjangoMessage schema
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
channel_schema[
"djangoMessage"
] = channel_class.DjangoMessage.model_json_schema()
schema["channels"][name] = channel_schema
return schema
def _register_channel_schema_endpoint(
api: "NinjaAPI",
path: str,
operation_id: str,
summary: str,
input_cls: type | None,
output_cls: type,
) -> None:
"""Register a dummy endpoint for schema generation (avoids closure issues)."""
if input_cls is not None:
def endpoint(request, data):
pass
endpoint.__annotations__ = {"data": input_cls}
else:
def endpoint(request):
pass
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
endpoint
)
def get_channels_openapi_schema() -> dict:
"""
Get OpenAPI schema for all registered channels.
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
This schema is consumed by openapi-typescript for type generation.
"""
from ninja import NinjaAPI
from pydantic import BaseModel
# Create temporary Ninja API for schema generation only
schema_api = NinjaAPI(
title="mizan Channels",
version="1.0.0",
description="Auto-generated schema for mizan channels",
docs_url=None,
openapi_url=None,
)
# Store dynamically created classes
schema_classes: dict[str, type] = {}
channel_metadata: list[dict] = []
for name, channel_class in _registry.items():
pascal_name = name.replace("_", " ").title().replace(" ", "")
channel_meta = {
"name": name,
"pascalName": pascal_name,
"hasParams": False,
"hasReactMessage": False,
"hasDjangoMessage": False,
}
# Register Params type
if hasattr(channel_class, "Params") and channel_class.Params:
params_name = f"{pascal_name}Params"
schema_classes[params_name] = type(params_name, (channel_class.Params,), {})
channel_meta["hasParams"] = True
channel_meta["paramsType"] = params_name
# Create dummy endpoint to include in schema
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/params",
operation_id=f"{name}Params",
summary=f"{pascal_name} channel params",
input_cls=schema_classes[params_name],
output_cls=BaseModel,
)
# Register ReactMessage type
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
react_name = f"{pascal_name}ReactMessage"
schema_classes[react_name] = type(
react_name, (channel_class.ReactMessage,), {}
)
channel_meta["hasReactMessage"] = True
channel_meta["reactMessageType"] = react_name
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/react",
operation_id=f"{name}ReactMessage",
summary=f"{pascal_name} React→Django message",
input_cls=schema_classes[react_name],
output_cls=BaseModel,
)
# Register DjangoMessage type
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
django_name = f"{pascal_name}DjangoMessage"
schema_classes[django_name] = type(
django_name, (channel_class.DjangoMessage,), {}
)
channel_meta["hasDjangoMessage"] = True
channel_meta["djangoMessageType"] = django_name
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/django",
operation_id=f"{name}DjangoMessage",
summary=f"{pascal_name} Django→React message",
input_cls=None,
output_cls=schema_classes[django_name],
)
channel_metadata.append(channel_meta)
# Get OpenAPI schema from Ninja
# path_prefix="" avoids URL reverse() — this API is never mounted
schema = schema_api.get_openapi_schema(path_prefix="")
# Add channel metadata extension
schema["x-mizan-channels"] = channel_metadata
return schema
# =============================================================================
# Schema Endpoint (for TypeScript generation)
# =============================================================================
_schema_router = None
def _get_schema_router():
"""Get the Ninja router for the channels schema endpoint."""
global _schema_router
if _schema_router is None:
from ninja import Router
_schema_router = Router(tags=["channels"])
@_schema_router.get("/schema/")
def channels_schema(request):
"""Get schema for all registered channels (for TypeScript generation)."""
return get_channels_schema()
return _schema_router
def get_urls():
"""Get URL patterns for channels schema endpoint."""
from ninja import NinjaAPI
api = NinjaAPI(urls_namespace="django_react_channels")
api.add_router("/", _get_schema_router())
return api.urls
def __getattr__(name):
if name == "urls":
return get_urls()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# =============================================================================
# Exports
# =============================================================================
__all__ = [
# URLs
"urls",
# Base class
"ReactChannel",
# Registration
"register",
"get_channel",
"get_registered_channels",
# ASGI application
"get_websocket_application",
# Schema export
"get_channels_schema",
]

View File

@@ -0,0 +1,528 @@
"""
WebSocket consumer for mizan.channels.
Handles multiplexed channel subscriptions AND RPC calls over a single WebSocket connection.
Protocol:
Browser sends:
# Channel subscriptions
{"action": "subscribe", "channel": "chat", "params": {"room": "general"}}
{"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}}
{"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}}
# RPC calls (server functions)
{"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
Server sends:
# Channel messages
{"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}}
# RPC responses
{"id": "request-id", "ok": true, "data": {...}}
{"id": "request-id", "ok": false, "error": {...}}
{"error": "..."}
Authentication:
Supports both session (cookie) and JWT authentication:
- Session: Handled automatically via AuthMiddlewareStack (cookies in handshake)
- JWT: Pass token as query parameter: ws://...?token=<jwt>
The WebSocket URL for JWT auth would be: ws://localhost/ws/?token=<access_token>
Security:
- Functions must be explicitly registered (no arbitrary code execution)
- Pydantic validation runs BEFORE any function code
"""
import json
import logging
from typing import Any
from urllib.parse import parse_qs
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from asgiref.sync import sync_to_async
from . import get_channel
logger = logging.getLogger(__name__)
class WebSocketRequest:
"""
Minimal request adapter for WebSocket context.
Provides the interface expected by ServerFunction without full HttpRequest.
This is intentionally minimal - only expose what's needed.
Note: Some Django libraries (e.g., allauth rate limiting) check request.method.
We set method="POST" since WebSocket RPC calls are semantically similar to POST.
"""
# WebSocket RPC is semantically similar to POST (sends data, expects response)
method = "POST"
def __init__(self, scope: dict, channel_name: str = None):
self.user = scope.get("user")
self.session = scope.get("session", {})
self.channel_name = channel_name # For push subscriptions
self._scope = scope
@property
def META(self) -> dict:
"""HTTP headers from WebSocket handshake."""
headers = dict(self._scope.get("headers", []))
return {
"HTTP_" + k.decode().upper().replace("-", "_"): v.decode()
for k, v in headers.items()
}
class DjangoReactConsumer(AsyncJsonWebsocketConsumer):
"""
Multiplexed WebSocket consumer for django_react channels.
Manages multiple channel subscriptions over a single WebSocket connection.
Authentication:
- Session auth via cookies (handled by AuthMiddlewareStack)
- JWT auth via query parameter: ws://...?token=<jwt>
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Track subscriptions: {(channel_name, params_json): channel_instance}
self._subscriptions: dict[tuple[str, str], Any] = {}
async def connect(self):
"""Accept the WebSocket connection, authenticating via JWT if provided."""
# Check for JWT token in query parameters
await self._try_jwt_auth()
await self.accept()
logger.debug(
f"WebSocket connected: {self.channel_name}, user={self.scope.get('user')}"
)
async def _try_jwt_auth(self):
"""
Attempt JWT authentication from query parameter.
If a valid JWT token is provided via ?token=<jwt>, authenticate the user
using JWTUser (no database query).
Security: If JWT is provided but invalid, we log it but don't reject
the connection - the session auth may still be valid. However, if JWT
IS valid, it takes precedence over session auth.
"""
# Parse query string for token
query_string = self.scope.get("query_string", b"").decode()
params = parse_qs(query_string)
token_list = params.get("token", [])
if not token_list:
return # No JWT provided, use session auth
token = token_list[0]
if not token:
return
# Validate JWT and create JWTUser (no DB query)
try:
from mizan.client.jwt import decode_token
from mizan.jwt.tokens import JWTUser
payload = await sync_to_async(decode_token)(token, expected_type="access")
if payload is None:
logger.debug("JWT token invalid or expired")
return # Fall back to session auth
# Create JWTUser from token claims - NO DATABASE QUERY
self.scope["user"] = JWTUser(payload)
logger.debug(f"JWT auth successful for user {payload.user_id}")
except Exception as e:
logger.debug(f"JWT auth failed: {e}")
async def disconnect(self, close_code):
"""Clean up all subscriptions on disconnect."""
for key, instance in list(self._subscriptions.items()):
try:
await instance.on_disconnect()
await instance._leave_all_groups()
except Exception as e:
logger.error(f"Error during disconnect cleanup: {e}")
self._subscriptions.clear()
logger.debug(f"WebSocket disconnected: {self.channel_name}")
async def receive_json(self, content: dict):
"""Handle incoming JSON messages."""
action = content.get("action")
if action == "subscribe":
await self._handle_subscribe(content)
elif action == "unsubscribe":
await self._handle_unsubscribe(content)
elif action == "message":
await self._handle_message(content)
elif action == "rpc":
await self._handle_rpc(content)
else:
await self.send_json(
{
"error": f"Unknown action: {action}",
}
)
async def _handle_subscribe(self, content: dict):
"""Handle subscription request."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
# Get channel class
channel_class = get_channel(channel_name)
if not channel_class:
await self.send_json(
{
"error": f"Unknown channel: {channel_name}",
}
)
return
# Create subscription key
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
# Check if already subscribed
if sub_key in self._subscriptions:
await self.send_json(
{
"error": f"Already subscribed to {channel_name}",
"channel": channel_name,
"params": params_dict,
}
)
return
# Create channel instance
instance = channel_class()
instance.user = self.scope.get("user")
instance._channel_layer = self.channel_layer
instance._channel_name = self.channel_name
instance._registered_name = channel_name
instance._params_dict = params_dict
# Parse params
params_obj = None
if channel_class.Params:
try:
params_obj = channel_class.Params(**params_dict)
except Exception as e:
await self.send_json(
{
"error": f"Invalid params: {e}",
"channel": channel_name,
}
)
return
# Check authorization
try:
if params_obj:
authorized = instance.authorize(params_obj)
else:
authorized = instance.authorize()
except Exception as e:
logger.error(f"Authorization error for {channel_name}: {e}")
await self.send_json(
{
"error": "Authorization failed",
"channel": channel_name,
}
)
return
if not authorized:
await self.send_json(
{
"error": "Not authorized",
"channel": channel_name,
}
)
return
# Get group and join
try:
if params_obj:
group_name = instance.group(params_obj)
else:
group_name = instance.group()
await instance._join_group(group_name)
except Exception as e:
logger.error(f"Failed to join group for {channel_name}: {e}")
await self.send_json(
{
"error": f"Failed to subscribe: {e}",
"channel": channel_name,
}
)
return
# Store subscription
self._subscriptions[sub_key] = instance
# Call on_connect hook
try:
await instance.on_connect(params_obj)
except Exception as e:
logger.error(f"on_connect error for {channel_name}: {e}")
# Confirm subscription
await self.send_json(
{
"subscribed": True,
"channel": channel_name,
"params": params_dict,
}
)
logger.debug(f"Subscribed to {channel_name} with params {params_dict}")
async def _handle_unsubscribe(self, content: dict):
"""Handle unsubscription request."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
instance = self._subscriptions.pop(sub_key, None)
if instance:
try:
await instance.on_disconnect()
await instance._leave_all_groups()
except Exception as e:
logger.error(f"Error during unsubscribe: {e}")
await self.send_json(
{
"unsubscribed": True,
"channel": channel_name,
"params": params_dict,
}
)
logger.debug(f"Unsubscribed from {channel_name}")
async def _handle_message(self, content: dict):
"""Handle incoming message from browser."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
data = content.get("data", {})
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
instance = self._subscriptions.get(sub_key)
if not instance:
await self.send_json(
{
"error": f"Not subscribed to {channel_name}",
"channel": channel_name,
}
)
return
channel_class = instance.__class__
# Check if channel accepts messages
if not channel_class.ReactMessage:
await self.send_json(
{
"error": f"Channel {channel_name} does not accept messages",
"channel": channel_name,
}
)
return
# Parse message
try:
msg = channel_class.ReactMessage(**data)
except Exception as e:
await self.send_json(
{
"error": f"Invalid message: {e}",
"channel": channel_name,
}
)
return
# Parse params
params_obj = None
if channel_class.Params:
params_obj = channel_class.Params(**params_dict)
# Handle message
try:
response = instance.receive(params_obj, msg)
# If handler returned a message, broadcast it
if response is not None:
if params_obj:
group_name = instance.group(params_obj)
else:
group_name = instance.group()
await instance._broadcast(group_name, response)
except Exception as e:
logger.error(f"Error handling message for {channel_name}: {e}")
await self.send_json(
{
"error": f"Message handling failed: {e}",
"channel": channel_name,
}
)
async def _handle_rpc(self, content: dict):
"""
Handle RPC (server function) call.
Protocol:
Request: {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
Response: {"id": "request-id", "ok": true, "data": {...}}
or: {"id": "request-id", "ok": false, "error": {...}}
Security:
- Only functions with @client(websocket=True) are allowed
- Pydantic validation happens BEFORE any function code runs
- Function must be explicitly registered (no arbitrary code execution)
- User context from WebSocket session is passed to function
"""
from mizan.client.executor import execute_function, FunctionError
from mizan.setup.registry import get_function
request_id = content.get("id")
fn_name = content.get("fn")
args = content.get("args", {})
# Validate request structure
if not request_id:
await self.send_json(
{
"error": "RPC request missing 'id' field",
}
)
return
if not fn_name:
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "BAD_REQUEST",
"message": "Missing 'fn' field",
},
}
)
return
# Check if function exists and has websocket=True
fn_class = get_function(fn_name)
if fn_class is None:
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "NOT_FOUND",
"message": f"Function '{fn_name}' not found",
},
}
)
return
# Only allow functions explicitly marked with websocket=True
fn_meta = getattr(fn_class, "_meta", {})
if not fn_meta.get("websocket"):
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "FORBIDDEN",
"message": "This function is HTTP-only. Use POST /api/mizan/call/ instead.",
},
}
)
return
# Create request adapter from WebSocket scope
ws_request = WebSocketRequest(
self.scope, channel_name=getattr(self, "channel_name", None)
)
# Execute function (Pydantic validation happens inside execute_function)
# This is sync, so we need to run it in a thread pool
result = await sync_to_async(execute_function, thread_sensitive=True)(
ws_request,
fn_name,
args,
)
# Send response
if isinstance(result, FunctionError):
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": result.code.value,
"message": result.message,
**({"details": result.details} if result.details else {}),
},
}
)
else:
await self.send_json(
{
"id": request_id,
"ok": True,
"data": result.data,
}
)
async def channel_message(self, event: dict):
"""
Handle messages broadcast to a group.
Called when channel_layer.group_send() is used.
Includes channel name and params so the client can route the message.
"""
await self.send_json(
{
"channel": event.get("channel"),
"params": event.get("params", {}),
"type": event.get("message_type", "message"),
"data": event.get("data", {}),
}
)
async def push_message(self, event: dict):
"""
Handle push messages from server functions.
Called when push("topic", data) is used from a server function.
The client receives this to update its local state.
Protocol:
Server sends: {"type": "push", "topic": "room:42", "data": {...}}
"""
await self.send_json(
{
"type": "push",
"topic": event.get("topic"),
"data": event.get("data", {}),
}
)

View File

@@ -0,0 +1,153 @@
"""
mizan Push - Server-initiated messages to clients.
Simple API for pushing data to subscribed WebSocket connections.
Usage:
# In a server function - push to all subscribers
from mizan.push import push
push("room:42", {"type": "new_message", "data": {...}})
# Subscribe a connection to a topic (call during context fetch)
from mizan.push import subscribe
subscribe(request, "room:42")
"""
from typing import TYPE_CHECKING
from pydantic import BaseModel
# Lazy import to avoid import errors when channels is not installed
# (e.g., during schema generation)
if TYPE_CHECKING:
from channels.layers import BaseChannelLayer
def _get_channel_layer() -> "BaseChannelLayer | None":
"""Get channel layer, returning None if channels is not installed."""
try:
from channels.layers import get_channel_layer
return get_channel_layer()
except ImportError:
return None
def _async_to_sync(coro):
"""Wrapper for async_to_sync that handles missing channels."""
from asgiref.sync import async_to_sync
return async_to_sync(coro)
def get_topic_group_name(topic: str) -> str:
"""Convert a topic string to a valid channel layer group name."""
# Channel layer group names must be valid ASCII alphanumeric + hyphens/underscores/periods
# Replace colons with underscores
return topic.replace(":", "_")
def subscribe(request, topic: str) -> None:
"""
Subscribe this WebSocket connection to a topic.
Call this in a context or server function to register the connection
for push notifications on the given topic.
Args:
request: The Django request (must have channel_name attribute from WebSocket)
topic: Topic string, e.g., "room:42", "user:123:notifications"
"""
channel_name = getattr(request, "channel_name", None)
if not channel_name:
# HTTP request, not WebSocket - can't subscribe
return
channel_layer = _get_channel_layer()
if not channel_layer:
return
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_add)(group_name, channel_name)
def unsubscribe(request, topic: str) -> None:
"""
Unsubscribe this WebSocket connection from a topic.
Args:
request: The Django request (must have channel_name attribute from WebSocket)
topic: Topic string to unsubscribe from
"""
channel_name = getattr(request, "channel_name", None)
if not channel_name:
return
channel_layer = _get_channel_layer()
if not channel_layer:
return
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_discard)(group_name, channel_name)
def push(topic: str, data: dict | BaseModel) -> None:
"""
Push data to all connections subscribed to a topic.
Args:
topic: Topic string, e.g., "room:42"
data: Data to send (dict or Pydantic model)
Example:
push("room:42", {
"type": "new_message",
"message": {"id": 1, "text": "Hello", "user": "alice@example.com"}
})
"""
channel_layer = _get_channel_layer()
if not channel_layer:
import logging
logging.getLogger(__name__).warning(
"No channel layer configured, cannot push to topic '%s'", topic
)
return
# Convert Pydantic model to dict if needed
if isinstance(data, BaseModel):
data = data.model_dump()
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "push.message", # Maps to push_message handler in consumer
"topic": topic,
"data": data,
},
)
async def push_async(topic: str, data: dict | BaseModel) -> None:
"""Async version of push for use in async contexts."""
channel_layer = _get_channel_layer()
if not channel_layer:
return
if isinstance(data, BaseModel):
data = data.model_dump()
group_name = get_topic_group_name(topic)
await channel_layer.group_send(
group_name,
{
"type": "push.message",
"topic": topic,
"data": data,
},
)

View File

@@ -0,0 +1,66 @@
"""
mizan.client - Server function implementation.
This subpackage contains everything needed to make server functions work:
- The @client decorator
- ServerFunction base class
- Function execution logic
- JWT authentication (integral to server functions)
Usage:
from mizan.client import client, ServerFunction, compose
"""
from .function import (
# Decorator
client,
# Context markers
ReactContext,
GlobalContext,
# Base classes
ServerFunction,
ComposedContext,
# Composition
compose,
# Type aliases
ContextMode,
# Form helpers
FormValidationOutput,
FormSchemaField,
FormSchemaOutput,
create_form_functions,
)
from .executor import (
execute_function,
function_call_view,
ErrorCode,
FunctionError,
FunctionResult,
)
__all__ = [
# Decorator
"client",
# Context markers
"ReactContext",
"GlobalContext",
# Base classes
"ServerFunction",
"ComposedContext",
# Composition
"compose",
# Type aliases
"ContextMode",
# Execution
"execute_function",
"function_call_view",
"ErrorCode",
"FunctionError",
"FunctionResult",
# Form helpers
"FormValidationOutput",
"FormSchemaField",
"FormSchemaOutput",
"create_form_functions",
]

View File

@@ -0,0 +1,935 @@
"""
mizan Function Executor
Handles execution of server functions.
This is the core of the "Server Functions" feature - callable from React
without REST boilerplate.
Security model:
- All input validated against Pydantic schema BEFORE execution
- Authentication: JWT (stateless) or Session (stateful) - auto-detected
- JWT: Authorization header with Bearer token (no CSRF needed)
- Session: Cookie-based with CSRF token (via X-CSRFToken header)
- WebSocket RPC uses Origin header checking instead
- No implicit function exposure - must be explicitly registered
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable
from django.http import HttpRequest, HttpResponse, HttpResponseBase, JsonResponse
from django.views.decorators.csrf import csrf_protect
from pydantic import BaseModel, ValidationError
from mizan.cache import get_cache, cache_get, cache_put, cache_purge
from mizan.setup.registry import get_function, get_context_groups
from mizan.setup.settings import get_settings
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class ErrorCode(str, Enum):
"""Standard error codes for function execution."""
# Client errors (4xx)
NOT_FOUND = "NOT_FOUND" # Function not registered
VALIDATION_ERROR = "VALIDATION_ERROR" # Input failed Pydantic validation
UNAUTHORIZED = "UNAUTHORIZED" # User not authenticated (when required)
FORBIDDEN = "FORBIDDEN" # User lacks permission
BAD_REQUEST = "BAD_REQUEST" # Malformed request
# Server errors (5xx)
INTERNAL_ERROR = "INTERNAL_ERROR" # Unhandled exception
NOT_IMPLEMENTED = "NOT_IMPLEMENTED" # Function exists but not implemented
@dataclass
class FunctionError:
"""Structured error response from function execution."""
code: ErrorCode
message: str
details: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-serializable dict."""
result = {
"error": True,
"code": self.code.value,
"message": self.message,
}
if self.details:
result["details"] = self.details
return result
def to_response(self, status: int = 400) -> JsonResponse:
"""Convert to Django JsonResponse."""
return JsonResponse(self.to_dict(), status=status)
@dataclass
class FunctionResult:
"""Successful result from function execution."""
data: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-serializable dict."""
return {
"error": False,
"data": self.data,
}
def to_response(self) -> JsonResponse:
"""Convert to Django JsonResponse."""
return JsonResponse(self.to_dict())
def _check_auth_requirement(
request: HttpRequest,
auth_requirement: str | Callable | None,
) -> FunctionError | None:
"""
Check if the request meets the auth requirement.
Args:
request: The Django HttpRequest (with user set)
auth_requirement: 'required', 'staff', 'superuser', callable, or None
Returns:
FunctionError if auth check fails, None if it passes.
Note: This uses request.user which may be a JWTUser (stateless) or
Django User (from session). Either way, no additional DB query is made
for the built-in checks. Custom callables may query DB if they choose.
"""
if auth_requirement is None:
return None
user = request.user
# Handle callable auth
if callable(auth_requirement):
try:
result = auth_requirement(request)
if result:
return None # Authorized
else:
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Access denied",
)
except PermissionError as e:
# Custom error message from the callable
return FunctionError(
code=ErrorCode.FORBIDDEN,
message=str(e) or "Access denied",
)
# Check authentication (required for all string-based auth)
if not getattr(user, "is_authenticated", False):
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Authentication required",
)
# Check staff requirement
if auth_requirement == "staff":
if not getattr(user, "is_staff", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Staff access required",
)
# Check superuser requirement
elif auth_requirement == "superuser":
if not getattr(user, "is_superuser", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Superuser access required",
)
return None
_cache_log = logging.getLogger("mizan.cache")
def _purge_cache_for_invalidation(
invalidate: list,
request: HttpRequest | None = None,
) -> None:
"""Purge origin-side cache for invalidation targets. Includes user_id if available."""
cache = get_cache()
if cache is None:
return
settings = get_settings()
if not settings.cache_secret:
return
user_id = None
if request and hasattr(request, 'user') and hasattr(request.user, 'pk'):
uid = getattr(request.user, 'pk', None)
if uid is not None:
user_id = str(uid)
try:
for entry in invalidate:
if isinstance(entry, str):
cache_purge(cache, entry)
elif isinstance(entry, dict):
cache_purge(
cache, entry["context"], entry.get("params"),
secret=settings.cache_secret,
user_id=user_id,
)
except Exception:
_cache_log.warning("Cache purge failed", exc_info=True)
def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]:
"""
Determine whether an affects target is a context name or function name.
Returns:
("context", "user", None) — full context invalidation
("function", "user_profile", "user") — function within context
"""
groups = get_context_groups()
# Check if it's a context name directly
if target_name in groups:
return ("context", target_name, None)
# Check if it's a function name within a context
for ctx_name, fn_names in groups.items():
if target_name in fn_names:
return ("function", target_name, ctx_name)
# Not a context or context function — treat as context name anyway
# (it might be a non-context function or an as-yet-unregistered context)
return ("context", target_name, None)
def _get_context_param_names(context_name: str) -> set[str]:
"""
Get the set of parameter names used by functions in a context.
Returns the union of all Input field names across context functions.
"""
groups = get_context_groups()
fn_names = groups.get(context_name, [])
param_names: set[str] = set()
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
param_names.update(input_cls.model_fields.keys())
return param_names
def _resolve_invalidation(
view_class: type | None,
input_data: dict[str, Any] | None = None,
) -> list[str | dict[str, Any]] | None:
"""
Resolve invalidation targets with three-tier auto-scoping.
Tier 1: Argument name matching — if the mutation's input args overlap
with the context's params by name, auto-scope.
Tier 2: Auth inference — Edge-side concern, not handled here.
Tier 3: Broad fallback — invalidate all instances.
Also handles function-level targeting: affects='user_profile' resolves
to the function name (v1: runtime refetches the whole context anyway).
Returns a list suitable for both JSON body and header serialization.
Returns None if no invalidation needed.
"""
if view_class is None:
return None
meta = getattr(view_class, "_meta", {})
affects = meta.get("affects")
if not affects:
return None
result = []
seen = set()
for target in affects:
if target["type"] == "context":
target_name = target["name"]
elif target["type"] == "function" and target.get("context"):
# Function-level: use the function name as the invalidation key
target_name = target["name"]
else:
continue
if target_name in seen:
continue
seen.add(target_name)
# Resolve the context this target belongs to (for param lookup)
resolved = _resolve_affects_target(target_name)
ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1]
# Tier 1: argument name matching
if input_data and ctx_for_params:
context_params = _get_context_param_names(ctx_for_params)
matched = {
k: v for k, v in input_data.items()
if k in context_params
}
if matched:
result.append({"context": target_name, "params": matched})
continue
# Tier 3: broad fallback
result.append(target_name)
return result if result else None
def _format_invalidate_header(
invalidate: list[str | dict[str, Any]],
) -> str:
"""
Format invalidation targets as X-Mizan-Invalidate header value.
Format: comma-separated contexts. Semicolon-separated params per context.
Param values are URL-encoded to prevent delimiter collisions.
Examples:
["user"] → "user"
["user", "notifications"] → "user, notifications"
[{"context": "user", "params": {"user_id": 5}}]
"user;user_id=5"
[{"context": "search", "params": {"q": "hello world"}}]
"search;q=hello%20world"
"""
from urllib.parse import quote
parts = []
for entry in invalidate:
if isinstance(entry, str):
parts.append(entry)
elif isinstance(entry, dict):
ctx = entry["context"]
params = entry.get("params", {})
if params:
param_str = ";".join(
f"{quote(str(k), safe='')}={quote(str(v), safe='')}"
for k, v in sorted(params.items())
)
parts.append(f"{ctx};{param_str}")
else:
parts.append(ctx)
return ", ".join(parts)
def execute_function(
request: HttpRequest,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> "FunctionResult | FunctionError | HttpResponseBase":
"""
Execute a registered server function.
Args:
request: The Django HttpRequest
fn_name: Name of the registered function
input_data: Input data to pass to the function
Returns:
FunctionResult on success, FunctionError on failure
"""
from django.conf import settings
# Look up the function by name
view_class = get_function(fn_name)
if view_class is None:
# In DEBUG mode, include the name for easier debugging
if settings.DEBUG:
message = f"Function '{fn_name}' not found"
else:
message = "Function not found"
return FunctionError(
code=ErrorCode.NOT_FOUND,
message=message,
)
# Reject private functions from RPC dispatch
meta = getattr(view_class, "_meta", {})
if meta.get("private"):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Function is not client-callable",
)
# Check auth requirement BEFORE executing
auth_requirement = meta.get("auth")
auth_error = _check_auth_requirement(request, auth_requirement)
if auth_error is not None:
return auth_error
# Instantiate the view with the request
view = view_class(request)
# Check if this is a form function that handles input specially
meta = getattr(view_class, "_meta", {})
is_form_multipart = meta.get("multipart", False)
# For form functions with Input=None, skip Pydantic validation
# The form itself handles validation
input_cls = view.Input
if input_cls is None and is_form_multipart:
# Form function - pass input_data directly (already parsed by view or will be)
validated_input = input_data
elif input_cls is BaseModel:
has_input = False
validated_input = None
else:
# Check if it has any fields defined
has_input = bool(input_cls.model_fields) if input_cls else False
# Validate input against Pydantic schema
try:
if input_data:
# Ensure input_data is a dict (not array or other type)
if not isinstance(input_data, dict):
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Input must be an object, not "
+ type(input_data).__name__,
)
validated_input = input_cls(**input_data)
elif has_input:
# Check if function requires input fields
input_schema = input_cls.model_json_schema()
required_fields = input_schema.get("required", [])
if required_fields:
# Format as field errors for consistency
errors = {field: ["Field required"] for field in required_fields}
return FunctionError(
code=ErrorCode.VALIDATION_ERROR,
message="Input validation failed",
details={"fields": errors},
)
validated_input = input_cls()
else:
# No input expected, create empty model
validated_input = None
except ValidationError as e:
# Convert Pydantic errors to our format
errors = {}
for error in e.errors():
field = ".".join(str(loc) for loc in error["loc"])
if field not in errors:
errors[field] = []
errors[field].append(error["msg"])
return FunctionError(
code=ErrorCode.VALIDATION_ERROR,
message="Input validation failed",
details={"fields": errors},
)
# Execute the function
try:
output = view.call(validated_input)
except NotImplementedError as e:
logger.error(f"Function {fn_name} not implemented: {e}")
return FunctionError(
code=ErrorCode.NOT_IMPLEMENTED,
message=str(e),
)
except PermissionError as e:
# Functions can raise PermissionError for auth issues
return FunctionError(
code=ErrorCode.FORBIDDEN,
message=str(e) or "Permission denied",
)
except Exception as e:
# Log the full exception for debugging
logger.exception(f"Error executing function {fn_name}")
return FunctionError(
code=ErrorCode.INTERNAL_ERROR,
message="An internal error occurred",
# Don't expose internal details in production
details={"type": type(e).__name__}
if logger.isEnabledFor(logging.DEBUG)
else None,
)
# Return-type branching: HttpResponse (view path) vs data (RPC path)
from django.http import HttpResponseBase
if isinstance(output, HttpResponseBase):
# View path — add invalidation header + purge origin cache
invalidate = _resolve_invalidation(view_class, input_data)
if invalidate:
output["X-Mizan-Invalidate"] = _format_invalidate_header(invalidate)
_purge_cache_for_invalidation(invalidate, request)
output["Cache-Control"] = "no-store"
return output
# RPC path — serialize output
if output is None:
return FunctionResult(data=None)
return FunctionResult(data=output.model_dump())
def _try_mwt_auth(request: HttpRequest) -> bool:
"""
Attempt to authenticate the request using MWT (Mizan Web Token).
Checks the X-Mizan-Token header. If present and valid, sets request.user
to an MWTUser. Returns True on success, False if no MWT header or invalid.
"""
token = request.META.get("HTTP_X_MIZAN_TOKEN", "")
if not token:
return False
try:
settings = get_settings()
if not settings.mwt_secret:
logging.getLogger("mizan.mwt").warning(
"X-Mizan-Token header present but MIZAN_MWT_SECRET is not configured"
)
return False
from mizan.mwt import decode_mwt, MWTUser
payload = decode_mwt(token, settings.mwt_secret)
if payload is None:
return False
request.user = MWTUser(payload)
request._mizan_mwt_authenticated = True
return True
except Exception:
logging.getLogger("mizan.mwt").warning(
"MWT authentication failed unexpectedly", exc_info=True
)
return False
def _has_mwt_header(request: HttpRequest) -> bool:
"""Check if request has an X-Mizan-Token header."""
return bool(request.META.get("HTTP_X_MIZAN_TOKEN", ""))
def _try_jwt_auth(request: HttpRequest) -> bool:
"""
Attempt to authenticate the request using JWT.
If Authorization header contains a valid Bearer token, authenticates
the request and sets request.user to a JWTUser. Returns True if JWT
auth succeeded.
IMPORTANT: This is stateless - no database query is made. The JWTUser
object is created from the token claims. If you need the full User
object, query it explicitly in your function.
Security: If JWT is provided but invalid, we return False and do NOT
fall back to session auth. The caller should reject the request.
"""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
if not auth_header.startswith("Bearer "):
return False
token = auth_header[7:] # Strip "Bearer "
if not token:
return False
try:
from mizan.client.jwt import decode_token
from mizan.jwt.tokens import JWTUser
payload = decode_token(token, expected_type="access")
if payload is None:
return False
# Create JWTUser from token claims - NO DATABASE QUERY
request.user = JWTUser(payload)
request._mizan_jwt_authenticated = True
return True
except Exception:
return False
def _has_jwt_header(request: HttpRequest) -> bool:
"""Check if request has a JWT Authorization header."""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
return auth_header.startswith("Bearer ")
def _csrf_protect_unless_token(view_func):
"""
Decorator that applies CSRF protection unless token auth is used.
MWT (X-Mizan-Token) is checked first, then legacy JWT (Authorization: Bearer).
Both are self-authenticating, so CSRF protection is not needed.
Security: If a token is provided but invalid, reject the request - do NOT
fall back to session auth.
"""
csrf_protected_view = csrf_protect(view_func)
@wraps(view_func)
def wrapper(request: HttpRequest, *args, **kwargs):
# MWT takes priority
if _has_mwt_header(request):
if _try_mwt_auth(request):
return view_func(request, *args, **kwargs)
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired MWT",
).to_response(status=401)
# Legacy JWT fallback
if _has_jwt_header(request):
if _try_jwt_auth(request):
return view_func(request, *args, **kwargs)
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired JWT token",
).to_response(status=401)
# No token — session auth with CSRF
return csrf_protected_view(request, *args, **kwargs)
return wrapper
@_csrf_protect_unless_token
def function_call_view(request: HttpRequest) -> JsonResponse:
"""
Django view for handling function calls (HTTP fallback for WebSocket RPC).
Authentication (auto-detected):
- JWT: Authorization: Bearer <token> (stateless, no CSRF needed)
- Session: Cookie-based with X-CSRFToken header (CSRF required)
Endpoint: POST /api/mizan/call/
Request body (JSON):
{
"fn": "function_name", // Function name
"args": { ... } // Optional, depending on function
}
Request body (multipart/form-data for form submit functions):
fn: function_name
<field>: <value>
...
Response on success:
{
"error": false,
"data": { ... } // Function output
}
Response on error:
{
"error": true,
"code": "VALIDATION_ERROR",
"message": "Input validation failed",
"details": { ... }
}
"""
# Only allow POST
if request.method != "POST":
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Only POST method allowed",
).to_response(status=405)
# Check content type to determine parsing method
content_type = request.content_type or ""
is_multipart = content_type.startswith("multipart/form-data")
if is_multipart:
# Multipart form data - used by form submit functions
fn_name = request.POST.get("fn")
if not fn_name:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Missing 'fn' field",
).to_response()
# Get form data (excluding 'fn')
input_data = {k: v for k, v in request.POST.dict().items() if k != "fn"}
# Attach parsed form data and files to request for form functions
request._mizan_form_data = input_data
request._mizan_form_files = request.FILES
else:
# JSON body - standard RPC
try:
if request.body:
body = json.loads(request.body)
else:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Request body required",
).to_response()
except json.JSONDecodeError:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Invalid JSON in request body",
).to_response()
# Extract function name and args
fn_name = body.get("fn")
if not fn_name:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Missing 'fn' field",
).to_response()
input_data = body.get("args")
# Execute the function
result = execute_function(request, fn_name, input_data)
# View path — function returned an HttpResponse directly
from django.http import HttpResponseBase
if isinstance(result, HttpResponseBase):
return result
# Return appropriate response
if isinstance(result, FunctionError):
status = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.INTERNAL_ERROR: 500,
ErrorCode.NOT_IMPLEMENTED: 501,
}.get(result.code, 400)
return result.to_response(status=status)
# RPC path — build response with server-driven invalidation
view_class = get_function(fn_name)
response_data = {"result": result.data}
invalidate_contexts = _resolve_invalidation(view_class, input_data)
if invalidate_contexts:
response_data["invalidate"] = invalidate_contexts
response = JsonResponse(response_data)
response["Cache-Control"] = "no-store"
if invalidate_contexts:
response["X-Mizan-Invalidate"] = _format_invalidate_header(invalidate_contexts)
_purge_cache_for_invalidation(invalidate_contexts, request)
return response
def execute_context(
request: HttpRequest,
context_name: str,
params: dict[str, str],
) -> FunctionResult | FunctionError:
"""
Execute all functions in a named context with merged params.
Each function receives only the params it declares in its Input schema.
If any function fails (auth, validation, execution), the entire request fails.
Args:
request: The Django HttpRequest
context_name: Name of the context (e.g., 'user', 'global')
params: Query parameters (strings — Pydantic coerces types)
Returns:
FunctionResult with bundled data, or FunctionError
"""
groups = get_context_groups()
fn_names = groups.get(context_name)
if not fn_names:
return FunctionError(
code=ErrorCode.NOT_FOUND,
message=f"Context '{context_name}' not found",
)
results = {}
for fn_name in fn_names:
view_class = get_function(fn_name)
if view_class is None:
continue
# Filter params to only those in this function's Input schema
input_cls = getattr(view_class, "Input", None)
if input_cls and input_cls is not BaseModel and input_cls.model_fields:
fn_params = {
k: v for k, v in params.items()
if k in input_cls.model_fields
}
else:
fn_params = None
result = execute_function(request, fn_name, fn_params)
if isinstance(result, FunctionError):
return result
results[fn_name] = result.data
return FunctionResult(data=results)
def _jwt_auth_only(view_func):
"""
Decorator that handles token auth for GET endpoints (no CSRF needed for GET).
Checks MWT first, then legacy JWT.
"""
@wraps(view_func)
def wrapper(request: HttpRequest, *args, **kwargs):
# MWT takes priority
if _has_mwt_header(request):
if _try_mwt_auth(request):
return view_func(request, *args, **kwargs)
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired MWT",
).to_response(status=401)
# Legacy JWT fallback
if _has_jwt_header(request):
if _try_jwt_auth(request):
return view_func(request, *args, **kwargs)
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired JWT token",
).to_response(status=401)
# No token — session auth (no CSRF needed for GET)
return view_func(request, *args, **kwargs)
return wrapper
@_jwt_auth_only
def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse:
"""
Fetch all functions in a named context in a single bundled GET request.
Endpoint: GET /api/mizan/ctx/<context_name>/?param1=val1&param2=val2
Response: raw bundled data, CDN-cacheable.
{
"user_profile": { ... },
"user_orders": [ ... ]
}
Headers:
Cache-Control: public, max-age=0, s-maxage=31536000
"""
if request.method != "GET":
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Only GET method allowed",
).to_response(status=405)
params = request.GET.dict()
# Resolve effective rev and cache policy across all functions in this context
_cache_log = logging.getLogger("mizan.cache")
groups = get_context_groups()
fn_names = groups.get(context_name, [])
effective_rev = 0
effective_cache: int | bool = True # True=forever, False=no-store, int=TTL
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls:
meta = getattr(fn_cls, "_meta", {})
fn_rev = meta.get("rev", 0)
effective_rev = max(effective_rev, fn_rev)
fn_cache = meta.get("cache", True)
if fn_cache is False:
effective_cache = False
break
elif isinstance(fn_cache, int):
if effective_cache is True:
effective_cache = fn_cache
else:
effective_cache = min(effective_cache, fn_cache)
# Origin-side cache lookup (skip if cache=False)
cache_backend = get_cache()
cache_settings = get_settings()
user_id = None
if hasattr(request, "user") and hasattr(request.user, "pk") and request.user.pk:
user_id = str(request.user.pk)
use_cache = (
cache_backend is not None
and cache_settings.cache_secret
and effective_cache is not False
)
if use_cache:
try:
cached = cache_get(
cache_settings.cache_secret, cache_backend, context_name, params,
user_id=user_id, rev=effective_rev,
)
if cached is not None:
response = HttpResponse(cached, content_type="application/json")
response["Cache-Control"] = "no-store"
response["X-Mizan-Cache"] = "HIT"
return response
except Exception:
_cache_log.warning("Cache lookup failed, falling through", exc_info=True)
result = execute_context(request, context_name, params)
if isinstance(result, FunctionError):
status = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.INTERNAL_ERROR: 500,
ErrorCode.NOT_IMPLEMENTED: 501,
}.get(result.code, 400)
error_response = result.to_response(status=status)
error_response["Cache-Control"] = "no-store"
return error_response
# Deterministic JSON (sorted keys) for consistent cache keys
response = JsonResponse(result.data, json_dumps_params={"sort_keys": True})
# Mizan's protocol layers handle caching (origin Redis, Edge Worker).
# The browser and non-Mizan intermediaries must not cache.
response["Cache-Control"] = "no-store"
# Store in origin-side cache (skip if cache=False)
if use_cache:
try:
cache_put(
cache_settings.cache_secret, cache_backend, context_name, params,
response.content, user_id=user_id, rev=effective_rev,
)
response["X-Mizan-Cache"] = "MISS"
except Exception:
_cache_log.warning("Cache store failed", exc_info=True)
return response

View File

@@ -0,0 +1,870 @@
"""
mizan Server Functions - Core Primitive
Server functions are the core primitive. Everything else builds on them.
Two styles supported:
1. Function-based (recommended, Django Ninja style):
@client("update-profile")
def update_profile(request, input: UpdateProfileInput) -> UpdateProfileOutput:
return UpdateProfileOutput(success=True)
2. Class-based (for complex cases):
class UpdateProfile(ServerFunction):
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
return UpdateProfileOutput(success=True)
register(UpdateProfile, 'update-profile')
"""
from __future__ import annotations
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
ClassVar,
Generic,
Literal,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
from django.http import HttpRequest
from pydantic import BaseModel
# =============================================================================
# REACT CONTEXT - Named context marker
# =============================================================================
class ReactContext:
"""
A named context that groups server functions into one provider and one fetch.
Usage:
UserContext = ReactContext('user')
@client(context=UserContext)
def user_profile(request, user_id: int) -> ProfileShape: ...
@client(context=UserContext)
def user_orders(request, user_id: int) -> list[OrderShape]: ...
@client(affects=UserContext)
def edit_profile(request, name: str) -> dict: ...
@client(affects=[UserContext, OrderContext])
def change_plan(request) -> dict: ...
"""
def __init__(self, name: str):
if not name or not isinstance(name, str):
raise ValueError("ReactContext name must be a non-empty string")
self.name = name
def __repr__(self) -> str:
return f"ReactContext({self.name!r})"
# Built-in global context (auto-mounted at root, SSR-hydrated)
GlobalContext = ReactContext("global")
# Context parameter type: a ReactContext instance, a raw string, or False
ContextMode = ReactContext | str | Literal[False]
TInput = TypeVar("TInput", bound=BaseModel)
TOutput = TypeVar("TOutput", bound=BaseModel)
# =============================================================================
# SERVER FUNCTION - The Core Primitive
# =============================================================================
class ServerFunction(ABC, Generic[TInput, TOutput]):
"""
Class-based server function (for complex cases).
For simple functions, use the @client decorator instead.
Usage:
class UpdateProfile(ServerFunction):
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
self.user.name = input.name
self.user.save()
return UpdateProfileOutput(success=True)
register(UpdateProfile, 'update-profile')
"""
# Registration name (set by register())
name: ClassVar[str]
# Metadata for code generation
_meta: ClassVar[dict[str, Any]] = {}
# Schema classes (set automatically from type hints or explicitly)
Input: ClassVar[type[BaseModel]] = BaseModel
Output: ClassVar[type[BaseModel]] = BaseModel
def __init__(self, request: HttpRequest):
"""Initialize with the Django request."""
self.request = request
@property
def user(self):
"""Shortcut to request.user."""
return self.request.user
@abstractmethod
def call(self, input: TInput) -> TOutput:
"""
Execute the function.
Args:
input: Validated input data
Returns:
Output instance
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement call()")
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
export = {
"name": getattr(cls, "name", cls.__name__),
"type": "function",
"meta": getattr(cls, "_meta", {}),
}
# Get Input/Output from class attributes
input_cls = getattr(cls, "Input", BaseModel)
output_cls = getattr(cls, "Output", BaseModel)
# Check if Input has fields
input_schema = input_cls.model_json_schema()
has_input = bool(input_schema.get("properties"))
if has_input:
export["input"] = input_schema
export["has_input"] = has_input
export["output"] = output_cls.model_json_schema()
return export
# =============================================================================
# FUNCTION DECORATOR - Django Ninja Style
# =============================================================================
class _FunctionWrapper(ServerFunction):
"""Internal wrapper that makes a plain function behave like a ServerFunction."""
# Will be set per-wrapper instance
_wrapped_fn: ClassVar[Callable]
_input_cls: ClassVar[type[BaseModel] | None]
_output_cls: ClassVar[type[BaseModel]]
_param_names: ClassVar[list[str]] = []
_is_primitive_output: ClassVar[bool] = False
def call(self, input):
"""Execute the wrapped function, unpacking input into individual args."""
if input is not None and self._param_names:
# Unpack validated model into keyword arguments
kwargs = {name: getattr(input, name) for name in self._param_names}
result = self._wrapped_fn(self.request, **kwargs)
else:
result = self._wrapped_fn(self.request)
# View path — return HttpResponse directly (no serialization)
from django.http import HttpResponseBase
if isinstance(result, HttpResponseBase):
return result
# Wrap primitive returns in the generated output model
if self._is_primitive_output:
return self._output_cls(result=result)
return result
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
export = {
"name": getattr(cls, "name", cls.__name__),
"type": "function",
"meta": getattr(cls, "_meta", {}),
}
# Use stored schema classes
if cls._input_cls is not None:
input_schema = cls._input_cls.model_json_schema()
has_input = bool(input_schema.get("properties"))
if has_input:
export["input"] = input_schema
export["has_input"] = has_input
else:
export["has_input"] = False
export["output"] = cls._output_cls.model_json_schema()
return export
# Valid string values for auth parameter
_VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"})
def _resolve_context(context: ContextMode) -> str | Literal[False]:
"""Resolve a context parameter to its name string."""
if context is False:
return False
if isinstance(context, ReactContext):
return context.name
if isinstance(context, str):
if not context.strip():
raise ValueError("context must be a non-empty string, ReactContext, or False.")
if context == "local":
warnings.warn(
"context='local' is deprecated. Use ReactContext('name') instead.",
DeprecationWarning,
stacklevel=3,
)
return context
raise ValueError(
f"context must be a ReactContext, a string, or False. Got {type(context).__name__}."
)
# Affects parameter type
AffectsTarget = ReactContext | str | type["ServerFunction"]
AffectsMode = AffectsTarget | list[AffectsTarget] | None
def client(
fn: Callable = None,
*,
context: ContextMode = False,
affects: AffectsMode = None,
private: bool = False,
route: str | None = None,
methods: list[str] | None = None,
websocket: bool = False,
auth: bool | str | Callable[[Any], bool] | None = None,
rev: int = 0,
cache: int | bool = True,
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
"""
Register a function as a server function.
Args:
context: Named context for React state management.
- False (default): Not a context, just a callable function.
- ReactContext instance: groups functions into a named context.
- GlobalContext: reserved, auto-mounted at root, SSR-hydrated.
affects: Declare which contexts or functions this mutation invalidates.
Mutually exclusive with context=.
Scoping is automatic via argument name matching.
private: If True, the function is not client-callable.
- Not exposed as an RPC endpoint
- No generated TypeScript
- Still participates in the invalidation graph
- Use for webhooks, cron jobs, internal mutations
route: URL route pattern for view-path functions.
Mizan registers this route during autodiscovery.
Example: '/profile/<user_id>/', '/webhooks/stripe/'
methods: HTTP methods allowed for the route.
Default: ['GET'] for context functions, ['POST'] for mutations.
Example: ['POST'], ['GET', 'POST']
websocket: Enable WebSocket RPC transport (default: False).
auth: Authentication requirement.
Usage:
UserContext = ReactContext('user')
@client(context=UserContext)
def user_profile(request, user_id: int) -> ProfileOutput: ...
@client(affects=UserContext)
def update_profile(request, user_id: int, name: str) -> dict: ...
# View with route — Mizan owns the URL
@client(context=UserContext, route='/profile/<user_id>/')
def profile_page(request, user_id: int) -> HttpResponse: ...
# Private webhook — not client-callable, emits invalidation
@client(affects='subscription', private=True, route='/webhooks/stripe/', methods=['POST'])
def stripe_webhook(request) -> HttpResponse: ...
Returns:
A ServerFunction class that wraps the function
"""
# Resolve context to name string
resolved_context = _resolve_context(context)
# Validate affects parameter
if affects is not None:
if resolved_context is not False:
raise ValueError(
"context= and affects= are mutually exclusive. "
"A function cannot be both a context reader and a mutation."
)
# Validate auth parameter
if auth is not None:
if isinstance(auth, str) and auth not in _VALID_AUTH_STRINGS:
raise ValueError(
f"Invalid auth value '{auth}'. "
f"Must be one of: {', '.join(sorted(_VALID_AUTH_STRINGS))}, True, or a callable."
)
def decorator(fn: Callable) -> type[ServerFunction]:
return _create_server_function(
fn, context=resolved_context, affects=affects,
private=private, route=route, methods=methods,
websocket=websocket, auth=auth, rev=rev, cache=cache,
)
# Support both @client and @client(...)
if fn is not None:
return _create_server_function(
fn, context=resolved_context, affects=affects,
private=private, route=route, methods=methods,
websocket=websocket, auth=auth, rev=rev, cache=cache,
)
return decorator
def _normalize_affects(affects: AffectsMode) -> list[dict[str, str]] | None:
"""Normalize the affects parameter into a list of target descriptors."""
if affects is None:
return None
items = affects if isinstance(affects, list) else [affects]
result = []
for item in items:
if isinstance(item, ReactContext):
result.append({"type": "context", "name": item.name})
elif isinstance(item, str):
result.append({"type": "context", "name": item})
elif isinstance(item, type) and issubclass(item, ServerFunction):
fn_meta = getattr(item, "_meta", {})
fn_ctx = fn_meta.get("context")
result.append({
"type": "function",
"name": getattr(item, "__name__", str(item)),
"context": fn_ctx or None,
})
else:
raise ValueError(
f"affects items must be ReactContext instances, context name strings, "
f"or @client function references. Got {type(item)}"
)
return result
def _create_server_function(
fn: Callable,
*,
context: str | Literal[False] = False,
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
private: bool = False,
route: str | None = None,
methods: list[str] | None = None,
websocket: bool = False,
auth: bool | str | None = None,
rev: int = 0,
cache: int | bool = True,
) -> type[ServerFunction]:
"""Internal helper that creates a ServerFunction from a decorated function."""
from pydantic import create_model
# Use function name directly
name = fn.__name__
# Extract type hints and signature
hints = get_type_hints(fn)
sig = inspect.signature(fn)
params = list(sig.parameters.items())
# Skip 'request' parameter (first param)
input_params = params[1:] if params else []
# Build input schema from function parameters
if input_params:
# Build field definitions for create_model
# Format: {field_name: (type, default) or (type, ...)}
fields = {}
for param_name, param in input_params:
param_type = hints.get(param_name, Any)
if param.default is inspect.Parameter.empty:
# Required field
fields[param_name] = (param_type, ...)
else:
# Optional field with default
fields[param_name] = (param_type, param.default)
# Create dynamic Pydantic model
input_cls = create_model(f"{fn.__name__}_Input", **fields)
else:
input_cls = None
# Get output type from return annotation
output_type = hints.get("return")
if output_type is None:
raise TypeError(f"Server function '{name}' must have a return type annotation")
# Detect view path: function returns HttpResponse (or has no return annotation
# that maps to a model — view functions often just have -> HttpResponse)
from django.http import HttpResponseBase
is_view_path = (
isinstance(output_type, type) and issubclass(output_type, HttpResponseBase)
)
if is_view_path:
# View path — no Pydantic output wrapping needed
output_cls = BaseModel # placeholder, never used for serialization
is_primitive_output = False
else:
# RPC path — resolve output type
import types
def is_basemodel_type(t: Any) -> bool:
"""Check if type is a BaseModel subclass, handling Optional/Union."""
if isinstance(t, type) and issubclass(t, BaseModel):
return True
origin = get_origin(t)
if origin is Union or isinstance(t, types.UnionType):
args = get_args(t)
for arg in args:
if (
arg is not type(None)
and isinstance(arg, type)
and issubclass(arg, BaseModel)
):
return True
return False
if is_basemodel_type(output_type):
output_cls = output_type
is_primitive_output = False
else:
output_cls = create_model(f"{fn.__name__}_Output", result=(output_type, ...))
is_primitive_output = True
# Store param names for unpacking validated input
param_names = [p[0] for p in input_params]
# Create a unique wrapper class for this function
class FunctionWrapper(_FunctionWrapper):
_param_names: ClassVar[list[str]] = param_names
FunctionWrapper.__name__ = fn.__name__
FunctionWrapper.__doc__ = fn.__doc__
FunctionWrapper.__module__ = fn.__module__ # Critical for discovery
FunctionWrapper._wrapped_fn = staticmethod(fn)
FunctionWrapper._input_cls = input_cls
FunctionWrapper._output_cls = output_cls
FunctionWrapper._is_primitive_output = is_primitive_output
# Set Input/Output class attributes for compatibility
if input_cls is not None:
FunctionWrapper.Input = input_cls
FunctionWrapper.Output = output_cls
# Build metadata
meta = {}
# View path flag (function returns HttpResponse, no codegen)
if is_view_path:
meta["view_path"] = True
# Private flag (not client-callable, no codegen, no RPC endpoint)
if private:
meta["private"] = True
# Route (Mizan-owned URL pattern for view-path functions)
if route:
meta["route"] = route
meta["methods"] = methods or (["GET"] if context else ["POST"])
# Context name (any non-empty string)
if context:
meta["context"] = context
# Affects: mutation invalidation targets
normalized_affects = _normalize_affects(affects)
if normalized_affects:
meta["affects"] = normalized_affects
# WebSocket: enable WebSocket transport
if websocket:
meta["websocket"] = True
# Auth requirement
if auth is not None:
if auth is True:
meta["auth"] = "required"
elif callable(auth):
meta["auth"] = auth
else:
meta["auth"] = auth
# Revision: bumped by developer when function logic changes.
# Part of the HMAC cache key — old entries become unreachable orphans.
if rev != 0:
meta["rev"] = rev
# Cache policy: True=forever (default), False=no-store, int=TTL seconds
if cache is not True:
meta["cache"] = cache
# Always assign a fresh dict to prevent shared-dict mutation across classes
FunctionWrapper._meta = {**meta}
# Note: Registration happens via discovery (mizan_clients), not here.
# This allows the decorator to be used without import-time side effects.
return FunctionWrapper
# =============================================================================
# COMPOSE - Combine multiple contexts into a single provider
# =============================================================================
class ComposedContext:
"""
Marker class for composed contexts.
Stores metadata about the composition for schema export.
"""
name: str
_meta: dict[str, Any]
_children: list[type[ServerFunction] | "ComposedContext"]
_leaves: list[type[ServerFunction]]
def __init__(
self,
name: str,
children: list,
leaves: list,
on_server: bool,
websocket: bool,
):
self.name = name
self._children = children
self._leaves = leaves
self._meta = {
"compose": True,
"on_server": on_server,
"websocket": websocket,
"children": [c.name for c in children],
"leaves": [leaf.name for leaf in leaves],
}
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
return {
"name": cls.name,
"type": "compose",
"meta": cls._meta,
"children": cls._meta.get("children", []),
"leaves": cls._meta.get("leaves", []),
}
def _get_leaves(item) -> list[type[ServerFunction]]:
"""Recursively collect all leaf contexts from a context or composition."""
if isinstance(item, type) and issubclass(item, ServerFunction):
return [item]
elif isinstance(item, ComposedContext):
return item._leaves.copy()
elif hasattr(item, "_leaves"):
# Duck typing for composed contexts
return item._leaves.copy()
else:
raise TypeError(f"Expected ServerFunction or ComposedContext, got {type(item)}")
def _is_context_enabled(item) -> bool:
"""Check if an item is a context-enabled function or composition."""
if isinstance(item, ComposedContext) or hasattr(item, "_leaves"):
return True
if isinstance(item, type) and issubclass(item, ServerFunction):
meta = getattr(item, "_meta", {})
return bool(meta.get("context"))
return False
def compose(
*children,
on_server: bool = False,
websocket: bool = False,
):
"""
Compose multiple contexts into a single provider.
Args:
*children: Context functions (@client with a context name)
or other @compose functions. All must be unique after flattening.
on_server: Bundle all calls into a single server request (default: False).
- False: Frontend makes individual calls (mixed HTTP/WS OK)
- True: Single bundled call. Requires transport consistency:
all children must be HTTP-only XOR all must be websocket=True.
websocket: Transport for bundled call when on_server=True (default: False).
- False: Bundled call over HTTP. All children must be HTTP-only.
- True: Bundled call over WebSocket. All children must have websocket=True.
Usage:
@client(context='local')
def user_profile(request, user_id: int) -> ProfileOutput: ...
@client(context='local')
def user_posts(request, user_id: int) -> PostsOutput: ...
@compose(user_profile, user_posts)
def user_page():
pass
# Frontend generates:
# <UserPageProvider user_id={123}>
# <App />
# </UserPageProvider>
Nesting:
@compose(ctx_a, ctx_b)
def ab(): pass
@compose(ab, ctx_c) # Flattens to [ctx_a, ctx_b, ctx_c]
def abc(): pass
Returns:
A ComposedContext that can be used in other compositions.
"""
def decorator(fn: Callable) -> ComposedContext:
from mizan.setup.registry import register_compose
name = fn.__name__
# Validate: all children must be context-enabled
for i, child in enumerate(children):
if not _is_context_enabled(child):
child_name = getattr(
child, "name", getattr(child, "__name__", str(child))
)
raise ValueError(
f"@compose argument {i} ({child_name}) is not context-enabled. "
f"All children must have @client(context=...) or be @compose."
)
# Flatten to collect all leaves
leaves = []
for child in children:
leaves.extend(_get_leaves(child))
# Validate: no duplicate leaves (by identity)
seen = set()
for leaf in leaves:
if id(leaf) in seen:
raise ValueError(
f"Duplicate context '{leaf.name}' in @compose({name}). "
f"Each context can only appear once. Use named kwargs for reuse (future feature)."
)
seen.add(id(leaf))
# Validate transport consistency when on_server=True
if on_server:
has_websocket = [
getattr(leaf, "_meta", {}).get("websocket", False) for leaf in leaves
]
if websocket:
# All must have websocket=True
if not all(has_websocket):
non_ws = [
leaf.name for leaf, ws in zip(leaves, has_websocket) if not ws
]
raise ValueError(
f"@compose({name}, on_server=True, websocket=True) requires all children "
f"to have websocket=True. These are HTTP-only: {non_ws}"
)
else:
# All must be HTTP-only
if any(has_websocket):
ws_enabled = [
leaf.name for leaf, ws in zip(leaves, has_websocket) if ws
]
raise ValueError(
f"@compose({name}, on_server=True, websocket=False) requires all children "
f"to be HTTP-only. These have websocket=True: {ws_enabled}"
)
# Create composed context
composed = ComposedContext(
name=name,
children=list(children),
leaves=leaves,
on_server=on_server,
websocket=websocket,
)
# Make it a class-like object for consistency
composed.__name__ = name
composed.__doc__ = fn.__doc__
# Register the composition
register_compose(composed, name)
return composed
return decorator
# =============================================================================
# FORM HELPERS - Output types used by form server functions
# =============================================================================
class FormValidationOutput(BaseModel):
"""Standard output for form validation."""
valid: bool
errors: dict[str, list[str]]
class FormSchemaField(BaseModel):
"""Schema for a single form field."""
name: str
type: str
required: bool
label: str
help_text: str | None = None
choices: list[tuple[str, str]] | None = None
initial: Any = None
class FormSchemaOutput(BaseModel):
"""Standard output for form schema."""
fields: list[FormSchemaField]
def create_form_functions(
form_class: type,
name: str,
submit_handler: Callable[[HttpRequest, dict], BaseModel] | None = None,
) -> tuple[type[ServerFunction], type[ServerFunction], type[ServerFunction] | None]:
"""
Generate server functions for a Django Form.
Args:
form_class: Django Form class
name: Base name for the functions
submit_handler: Optional handler for form submission
Returns:
Tuple of (SchemaFunction, ValidateFunction, SubmitFunction or None)
Usage:
SchemaFn, ValidateFn, SubmitFn = create_form_functions(
ContactForm,
'contact',
submit_handler=lambda req, data: ContactSubmitOutput(success=True),
)
register(SchemaFn, 'contact-schema')
register(ValidateFn, 'contact-validate')
register(SubmitFn, 'contact-submit')
Or use the helper:
register_form(ContactForm, 'contact', submit_handler=...)
"""
from mizan.forms.schema_utils import build_form_schema
# Schema function - returns field definitions
class FormSchema(ServerFunction):
class Output(FormSchemaOutput):
pass
def call(self, input):
schema = build_form_schema(form_class)
fields = [
FormSchemaField(
name=field.name,
type=field.type,
required=field.required,
label=field.label or field.name,
help_text=field.help_text or None,
choices=[(c.value, c.label) for c in field.choices]
if field.choices
else None,
initial=field.initial,
)
for field in schema.fields
]
return self.Output(fields=fields)
FormSchema.__name__ = f"{name.title().replace('-', '')}Schema"
FormSchema._meta = {"form": True, "form_name": name, "form_role": "schema"}
# Validation function
class FormDataInput(BaseModel):
data: dict[str, Any]
class FormValidate(ServerFunction):
Input = FormDataInput
class Output(FormValidationOutput):
pass
def call(self, input):
form = form_class(data=input.data)
if form.is_valid():
return self.Output(valid=True, errors={})
return self.Output(valid=False, errors=dict(form.errors))
FormValidate.__name__ = f"{name.title().replace('-', '')}Validate"
FormValidate._meta = {"form": True, "form_name": name, "form_role": "validate"}
# Submit function (optional)
FormSubmit = None
if submit_handler:
class FormSubmit(ServerFunction):
Input = FormDataInput
def call(self, input):
# Validate first
form = form_class(data=input.data)
if not form.is_valid():
raise ValueError("Form validation failed")
# Call handler
return submit_handler(self.request, form.cleaned_data)
FormSubmit.__name__ = f"{name.title().replace('-', '')}Submit"
FormSubmit._meta = {"form": True, "form_name": name, "form_role": "submit"}
return FormSchema, FormValidate, FormSubmit

View File

@@ -0,0 +1,44 @@
"""
mizan.client.jwt - JWT authentication for server functions.
Provides:
- Server functions for obtaining/refreshing JWT tokens
- JWT authentication utilities for validating tokens
Server Functions:
- jwt_obtain: Convert authenticated session to JWT tokens
- jwt_refresh: Refresh tokens using a refresh token
Note: This module is purpose-built for mizan server functions.
For Django Ninja API authentication, use mizan.jwt.security directly.
"""
# Token utilities (re-exports from django_jwt_session)
from mizan.jwt.tokens import (
create_token_pair,
create_access_token,
create_refresh_token,
decode_token,
refresh_tokens,
TokenPair,
TokenPayload,
JWTUser,
)
# Settings
from mizan.jwt.settings import get_settings, JWTSettings
__all__ = [
# Token utilities
"create_token_pair",
"create_access_token",
"create_refresh_token",
"decode_token",
"refresh_tokens",
"TokenPair",
"TokenPayload",
"JWTUser",
# Settings
"get_settings",
"JWTSettings",
]

View File

@@ -0,0 +1,515 @@
"""
mizan OpenAPI Schema Generator
Generates OpenAPI 3.0 compatible schema from registered server functions.
Uses Django Ninja's battle-tested schema generation for robust Pydantic→OpenAPI conversion.
This schema is consumed by the frontend generator which uses openapi-typescript
for robust type generation.
NOTE: Schema export is only available via management command for security.
HTTP endpoint has been removed to prevent function enumeration.
Usage:
python manage.py export_mizan_schema
"""
from __future__ import annotations
import json
import re
from typing import TYPE_CHECKING, Any
# Lazy imports to avoid Django settings access at module load time
# (asgi.py imports mizan before Django is fully configured)
if TYPE_CHECKING:
from django import forms
from ninja import NinjaAPI
from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function
__all__ = [
"get_schema",
"generate_openapi_schema",
"generate_openapi_json",
"generate_edge_manifest",
"generate_edge_manifest_json",
]
def _extract_form_fields(form_class: type) -> list[dict[str, Any]]:
"""
Extract field definitions with constraints from a Django Form class.
Returns a list of field metadata suitable for Zod schema generation:
- name: field name
- zodType: base Zod type ("string", "number", "boolean", "array")
- required: whether field is required
- constraints: dict of Zod-compatible constraints
Constraints include:
- min/max: for string length or number range
- email/url: for format validation
- regex: for pattern validation
- choices: for enum validation
"""
try:
# Try to instantiate form to get bound fields
form = form_class()
fields_dict = form.fields
except TypeError:
# Form requires extra args - use base_fields
fields_dict = getattr(form_class, "base_fields", {})
result = []
for name, field in fields_dict.items():
field_meta = _extract_field_constraints(name, field)
result.append(field_meta)
return result
def _extract_field_constraints(name: str, field: "forms.Field") -> dict[str, Any]:
"""
Extract Zod-compatible constraints from a single Django form field.
"""
from django import forms # Lazy import
meta: dict[str, Any] = {
"name": name,
"required": field.required,
"constraints": {},
}
# Determine base Zod type
if isinstance(field, forms.BooleanField):
meta["zodType"] = "boolean"
elif isinstance(field, (forms.IntegerField, forms.FloatField, forms.DecimalField)):
meta["zodType"] = "number"
if isinstance(field, forms.IntegerField):
meta["constraints"]["int"] = True
elif isinstance(field, forms.MultipleChoiceField):
meta["zodType"] = "array"
meta["constraints"]["items"] = "string"
elif isinstance(field, forms.FileField):
meta["zodType"] = "file"
else:
# Default to string (CharField, EmailField, URLField, etc.)
meta["zodType"] = "string"
# Extract string constraints
if hasattr(field, "max_length") and field.max_length is not None:
meta["constraints"]["max"] = field.max_length
if hasattr(field, "min_length") and field.min_length is not None:
meta["constraints"]["min"] = field.min_length
# Extract number constraints
if hasattr(field, "max_value") and field.max_value is not None:
meta["constraints"]["max"] = field.max_value
if hasattr(field, "min_value") and field.min_value is not None:
meta["constraints"]["min"] = field.min_value
# Email/URL format
if isinstance(field, forms.EmailField):
meta["constraints"]["email"] = True
elif isinstance(field, forms.URLField):
meta["constraints"]["url"] = True
# Choices (for enum validation)
if hasattr(field, "choices") and field.choices:
# Extract choice values (not labels)
choices = []
for choice in field.choices:
if isinstance(choice, (list, tuple)) and len(choice) >= 1:
# Skip empty/blank choices
if choice[0] != "":
choices.append(str(choice[0]))
else:
choices.append(str(choice))
if choices:
meta["constraints"]["choices"] = choices
# Regex validators
for validator in field.validators:
if hasattr(validator, "regex"):
# RegexValidator - extract pattern
pattern = validator.regex.pattern
meta["constraints"]["regex"] = pattern
if hasattr(validator, "message"):
meta["constraints"]["regexMessage"] = validator.message
break # Only use first regex validator
return meta
def snake_to_camel(name: str) -> str:
"""Convert snake_case or dotted.name to camelCase.
Examples:
- login -> login
- login.schema -> loginSchema
- activate_totp -> activateTotp
- activate_totp.schema -> activateTotpSchema
"""
# Split on both underscores and dots
components = re.split(r"[._]", name)
return components[0] + "".join(x.title() for x in components[1:])
def _register_schema_endpoint(
api: "NinjaAPI",
path: str,
operation_id: str,
summary: str,
input_cls: type | None,
output_cls: type,
) -> None:
"""
Register a dummy endpoint on the API for schema generation.
Sets __annotations__ directly to avoid closure capture issues
and exec() security concerns.
"""
if input_cls is not None:
def endpoint(request, data):
pass
# Set annotations directly to the actual type objects (not strings)
endpoint.__annotations__ = {"data": input_cls}
else:
def endpoint(request):
pass
# Register with Ninja
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
endpoint
)
def generate_openapi_schema() -> dict[str, Any]:
"""
Generate OpenAPI 3.0 schema for all registered mizan functions.
Uses Django Ninja's schema generation internally to ensure proper
Pydantic→OpenAPI conversion (handling $refs, nested types, etc.).
Returns a complete OpenAPI document that can be processed by openapi-typescript.
"""
from ninja import NinjaAPI # Lazy import
from pydantic import BaseModel, create_model # Lazy import
registry = get_registry()
functions = registry.get("functions", {})
# Create a temporary Ninja API for schema generation only
# This is NOT exposed as an HTTP endpoint - purely for leveraging Ninja's
# battle-tested Pydantic→OpenAPI conversion
schema_api = NinjaAPI(
title="mizan Server Functions",
version="1.0.0",
description="Auto-generated schema for mizan server functions",
docs_url=None, # No docs endpoint
openapi_url=None, # No openapi endpoint
)
function_metadata: list[dict[str, Any]] = []
# Store dynamically created classes so they persist for schema generation
schema_classes: dict[str, type] = {}
for name, fn_class in functions.items():
camel_name = snake_to_camel(name)
meta = getattr(fn_class, "_meta", {})
# Get Input/Output classes
input_cls = getattr(fn_class, "Input", None)
output_cls = getattr(fn_class, "Output", None) or BaseModel
# Check if input_cls is a valid Pydantic model with fields
has_input = (
input_cls is not None
and input_cls is not BaseModel
and hasattr(input_cls, "model_fields")
and bool(input_cls.model_fields)
)
# Determine type names for metadata
input_type_name = f"{camel_name}Input" if has_input else None
output_type_name = f"{camel_name}Output"
# Create renamed Pydantic classes for cleaner schema names
# Store them in schema_classes so they persist beyond loop scope
# Uses create_model to avoid metaclass conflicts with custom base classes
if has_input:
schema_classes[input_type_name] = create_model(
input_type_name, __base__=input_cls
)
schema_classes[output_type_name] = create_model(
output_type_name, __base__=output_cls
)
# Register endpoint using helper to avoid closure capture issues
_register_schema_endpoint(
api=schema_api,
path=f"/mizan/{name}",
operation_id=camel_name,
summary=fn_class.__doc__ or f"Call {name}",
input_cls=schema_classes.get(input_type_name),
output_cls=schema_classes[output_type_name],
)
# Collect function metadata for provider generation
fn_meta_entry: dict[str, Any] = {
"name": name,
"camelName": camel_name,
"hasInput": has_input,
"inputType": input_type_name,
"outputType": output_type_name,
"transport": "websocket" if meta.get("websocket") else "http",
"isContext": meta.get("context", False),
# Form metadata
"isForm": meta.get("form", False),
"formName": meta.get("form_name"),
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
}
# Affects metadata (mutation invalidation)
if meta.get("affects"):
fn_meta_entry["affects"] = meta["affects"]
# For form schema functions, extract field definitions for Zod generation
if meta.get("form") and meta.get("form_role") == "schema":
form_class = meta.get("form_class")
if form_class is not None:
try:
fn_meta_entry["formFields"] = _extract_form_fields(form_class)
except Exception as e:
# Don't fail schema generation if field extraction fails
fn_meta_entry["formFields"] = []
fn_meta_entry["formFieldsError"] = str(e)
function_metadata.append(fn_meta_entry)
# Get the OpenAPI schema from Ninja (handles all Pydantic conversion properly)
schema = schema_api.get_openapi_schema(path_prefix="")
# Add custom extension with function metadata for provider generation
schema["x-mizan-functions"] = function_metadata
# Add x-mizan-contexts: grouped context metadata with param elevation
context_groups = get_context_groups()
if context_groups:
contexts_meta: dict[str, Any] = {}
for ctx_name, fn_names in context_groups.items():
# Analyze params across all functions in the context
param_info: dict[str, dict[str, Any]] = {}
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
for field_name, field_info in input_cls.model_fields.items():
if field_name not in param_info:
annotation = field_info.annotation
# Map Python types to JSON schema types
type_name = "string"
if annotation in (int,):
type_name = "integer"
elif annotation in (float,):
type_name = "number"
elif annotation in (bool,):
type_name = "boolean"
param_info[field_name] = {
"type": type_name,
"sharedBy": [],
}
param_info[field_name]["sharedBy"].append(fn_name)
# A param is required if ALL functions in the context declare it
for p_name, p_meta in param_info.items():
p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names)
contexts_meta[ctx_name] = {
"functions": fn_names,
"params": param_info,
}
schema["x-mizan-contexts"] = contexts_meta
# Add x-mizan metadata to each operation
for fn_meta in function_metadata:
path = f"/mizan/{fn_meta['name']}"
if path in schema.get("paths", {}):
schema["paths"][path]["post"]["x-mizan"] = {
"transport": fn_meta["transport"],
"isContext": fn_meta["isContext"],
}
return schema
def generate_openapi_json(indent: int = 2) -> str:
"""Generate OpenAPI schema as formatted JSON string."""
schema = generate_openapi_schema()
return json.dumps(schema, indent=indent)
def generate_edge_manifest(
base_url: str = "/api/mizan",
view_urls: dict[str, list[str]] | None = None,
) -> dict[str, Any]:
"""
Generate the Edge manifest — a static JSON mapping contexts to URL
patterns and params for CDN cache purging.
The manifest is consumed by Mizan Edge at deploy time. When Edge
receives X-Mizan-Invalidate: user;user_id=5, it:
1. Looks up 'user' in the manifest
2. Resolves URL patterns with params: /profile/:user_id/ → /profile/5/
3. Purges the resolved URLs + the context API endpoint
Args:
base_url: The Mizan API mount point (default: /api/mizan)
view_urls: Optional mapping of context names to URL patterns for
view-path functions. These are URLs that Edge should
also purge when a context is invalidated.
Example: {"user": ["/profile/:user_id/"]}
Returns:
Manifest dict suitable for JSON serialization.
"""
from pydantic import BaseModel as PydanticBaseModel
# Common user identity param names for user_scoped detection
_USER_SCOPED_PARAMS = {"user_id", "user", "owner_id", "account_id"}
groups = get_context_groups()
registry = get_registry()
all_functions = registry.get("functions", {})
manifest: dict[str, Any] = {"version": 1, "contexts": {}, "mutations": {}}
for ctx_name, fn_names in groups.items():
# Collect params and routes from all functions in this context
param_names: set[str] = set()
functions_meta: list[dict[str, Any]] = []
page_routes: list[str] = []
for fn_name in fn_names:
fn_cls = all_functions.get(fn_name)
if fn_cls is None:
continue
meta = getattr(fn_cls, "_meta", {})
is_view = meta.get("view_path", False)
# Collect param names from Input schema
input_cls = getattr(fn_cls, "Input", None)
if (
input_cls
and input_cls is not PydanticBaseModel
and hasattr(input_cls, "model_fields")
):
param_names.update(input_cls.model_fields.keys())
fn_entry: dict[str, Any] = {
"name": fn_name,
"path": "view" if is_view else "rpc",
}
# Collect routes from view-path functions
fn_route = meta.get("route")
if fn_route:
fn_entry["route"] = fn_route
fn_entry["methods"] = meta.get("methods", ["GET"])
page_routes.append(fn_route)
# Cache protocol metadata
if "rev" in meta:
fn_entry["rev"] = meta["rev"]
if "cache" in meta:
fn_entry["cache"] = meta["cache"]
functions_meta.append(fn_entry)
sorted_params = sorted(param_names)
user_scoped = bool(param_names & _USER_SCOPED_PARAMS)
ctx_entry: dict[str, Any] = {
"functions": functions_meta,
"endpoints": [f"{base_url}/ctx/{ctx_name}/"],
"params": sorted_params,
"user_scoped": user_scoped,
"render_strategy": "dynamic_cached" if user_scoped else "psr",
}
# Add page routes from view-path functions with route=
if page_routes:
ctx_entry["page_routes"] = page_routes
# Add externally-declared view URLs
if view_urls and ctx_name in view_urls:
ctx_entry.setdefault("page_routes", []).extend(view_urls[ctx_name])
manifest["contexts"][ctx_name] = ctx_entry
# Mutations section — all functions with affects=
for fn_name, fn_cls in all_functions.items():
meta = getattr(fn_cls, "_meta", {})
affects = meta.get("affects")
if not affects:
continue
# Resolve context names from affects targets
affected_contexts = []
for target in affects:
if target["type"] == "context":
affected_contexts.append(target["name"])
elif target["type"] == "function" and target.get("context"):
affected_contexts.append(target["context"])
affected_contexts = list(dict.fromkeys(affected_contexts))
# Determine which params auto-scope
auto_scoped = []
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not PydanticBaseModel and hasattr(input_cls, "model_fields"):
fn_params = set(input_cls.model_fields.keys())
for ctx_name in affected_contexts:
ctx_params = set()
for ctx_fn_name in groups.get(ctx_name, []):
ctx_fn_cls = all_functions.get(ctx_fn_name)
if ctx_fn_cls:
ctx_input = getattr(ctx_fn_cls, "Input", None)
if ctx_input and ctx_input is not PydanticBaseModel and hasattr(ctx_input, "model_fields"):
ctx_params.update(ctx_input.model_fields.keys())
auto_scoped.extend(sorted(fn_params & ctx_params))
auto_scoped = list(dict.fromkeys(auto_scoped))
mutation_entry: dict[str, Any] = {
"affects": affected_contexts,
}
if auto_scoped:
mutation_entry["auto_scoped_params"] = auto_scoped
if meta.get("private"):
mutation_entry["private"] = True
if meta.get("route"):
mutation_entry["route"] = meta["route"]
mutation_entry["methods"] = meta.get("methods", ["POST"])
manifest["mutations"][fn_name] = mutation_entry
return manifest
def generate_edge_manifest_json(
indent: int = 2,
base_url: str = "/api/mizan",
view_urls: dict[str, list[str]] | None = None,
) -> str:
"""Generate Edge manifest as formatted JSON string."""
manifest = generate_edge_manifest(base_url=base_url, view_urls=view_urls)
return json.dumps(manifest, indent=indent, sort_keys=True)

View File

@@ -0,0 +1,632 @@
"""
mizanFormMixin - Turn Django Forms into server functions.
This mixin transforms any Django Form into mizan server functions,
preserving full Django Form functionality (validation, widgets, ModelChoiceField, etc.)
while exposing them through the unified server function API.
Usage:
from django import forms
from mizan.forms import mizanFormMixin, mizanFormMeta
class ContactForm(mizanFormMixin, forms.Form):
mizan = mizanFormMeta(
name="contact",
title="Contact Us",
submit_label="Send",
)
name = forms.CharField()
email = forms.EmailField()
message = forms.CharField(widget=forms.Textarea)
def on_submit_success(self, request):
send_email(self.cleaned_data)
return {"sent": True}
Auto-registers server functions:
- contact.schema
- contact.validate
- contact.submit
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar
from django import forms
from django.http import HttpRequest
from pydantic import BaseModel, create_model
if TYPE_CHECKING:
from .schemas import FormValidation
def _django_field_to_python_type(field: forms.Field) -> type:
"""
Map a Django form field to a Python type for Pydantic schema generation.
This provides TypeScript with proper field types instead of generic `any`.
"""
# Handle common Django field types
if isinstance(field, forms.BooleanField):
return bool
elif isinstance(field, forms.IntegerField):
return int
elif isinstance(field, forms.FloatField):
return float
elif isinstance(field, forms.DecimalField):
return str # Decimals serialize as strings for precision
elif isinstance(field, forms.DateTimeField):
return str # ISO format string
elif isinstance(field, forms.DateField):
return str # ISO format string
elif isinstance(field, forms.TimeField):
return str # ISO format string
elif isinstance(field, forms.JSONField):
return dict | list | str | int | float | bool | None
elif isinstance(field, forms.MultipleChoiceField):
return list[str]
elif isinstance(field, forms.FileField):
return str # File path/name as string
elif isinstance(field, forms.ImageField):
return str # File path/name as string
else:
# Default to string (covers CharField, EmailField, URLField, etc.)
return str
def _create_form_input_schema(
form_class: type[forms.BaseForm],
schema_name: str,
) -> type[BaseModel]:
"""
Create a Pydantic model from Django Form fields.
This generates a typed schema for the form's input data, giving TypeScript
full LSP support (autocomplete, type checking) for form fields.
Args:
form_class: Django Form class to introspect
schema_name: Name for the generated Pydantic model (e.g., "ContactFormData")
Returns:
A Pydantic BaseModel subclass with fields matching the form
"""
# Instantiate form without data to get field definitions
try:
form = form_class()
except TypeError:
# Form requires extra args (like request) - use form_class.base_fields instead
fields_dict = getattr(form_class, "base_fields", {})
else:
fields_dict = form.fields
# Build Pydantic field definitions
pydantic_fields: dict[str, Any] = {}
for field_name, field in fields_dict.items():
python_type = _django_field_to_python_type(field)
# Optional fields (not required or has initial value)
if not field.required:
python_type = python_type | None
default = None
elif field.initial is not None:
default = field.initial
else:
default = ... # Required field
pydantic_fields[field_name] = (python_type, default)
# Create the model with a unique name
model = create_model(schema_name, **pydantic_fields)
return model
class mizanFormMeta(BaseModel):
"""
Configuration for a mizan form.
This Pydantic model provides type-safe configuration with full LSP support,
and serializes to JSON for the frontend schema.
Required:
name: API identifier (e.g., "contact" → contact.schema, contact.validate, contact.submit)
Display options:
title: Display title (default: derived from class name)
subtitle: Display subtitle
submit_label: Submit button text (default: "Submit")
Frontend behavior:
live_validation: Enable live validation as user types (default: True)
live_form_errors: Show form-level errors during live validation (default: False)
refetch_schema_on_validate: Refetch schema on each validation - useful for
dynamic choice fields (default: False)
Features:
enable_formset: Generate formset endpoints (default: False)
"""
# Required
name: str
# Display
title: str | None = None
subtitle: str | None = None
submit_label: str = "Submit"
# Frontend behavior
live_validation: bool = True
live_form_errors: bool = False
refetch_schema_on_validate: bool = False
# Features
enable_formset: bool = False
class mizanFormMixin:
"""
Mixin that exposes a Django Form as mizan server functions.
Add this mixin to any Django Form class along with a `mizan` configuration:
class ContactForm(mizanFormMixin, forms.Form):
mizan = mizanFormMeta(
name="contact",
title="Contact Us",
)
name = forms.CharField()
email = forms.EmailField()
def on_submit_success(self, request):
return {"sent": True}
This auto-registers:
- contact.schema - Get form field definitions
- contact.validate - Validate form data
- contact.submit - Submit form
Overridable methods:
get_init_kwargs(cls, request) -> dict: Extra kwargs for form instantiation
on_submit_success(self, request) -> dict | None: Handle successful submission
on_submit_failure(self, request, errors) -> None: Handle failed submission
"""
# Configuration - subclasses must define this
mizan: ClassVar[mizanFormMeta]
# Track registered forms to avoid duplicate registration
_mizan_registered: ClassVar[bool] = False
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
"""
Override to provide extra kwargs for form instantiation.
Common use: pass request or user to forms that need them.
Example:
@classmethod
def get_init_kwargs(cls, request):
return {"request": request, "user": request.user}
"""
return {}
def on_submit_success(self, request: HttpRequest) -> dict | None:
"""
Called after successful form validation and submission.
Override to handle the form submission logic.
Return a dict to include data in the response.
Example:
def on_submit_success(self, request):
self.save()
return {"id": self.instance.pk}
"""
# Default: call save() if available
if hasattr(self, "save"):
result = self.save()
# If save returns something serializable, include it
if isinstance(result, dict):
return result
return None
def on_submit_failure(self, request: HttpRequest, errors: "FormValidation") -> None:
"""
Called after form validation fails.
Override to add custom error handling, logging, etc.
"""
pass
def __init_subclass__(cls, **kwargs):
"""Auto-register when a concrete form class is defined."""
super().__init_subclass__(**kwargs)
# Only register concrete forms with mizan config defined
if _is_concrete_mizan_form(cls):
_register_form_as_server_functions(cls)
def _is_concrete_mizan_form(cls: type) -> bool:
"""
Check if a class is a concrete mizan form ready for registration.
A form is concrete if:
1. It has a `mizan` attribute that is a mizanFormMeta instance
2. It inherits from Django's BaseForm
3. It hasn't been registered yet (for this class definition)
"""
# Must have mizan config (check cls.__dict__ to avoid inheriting)
mizan_config = cls.__dict__.get("mizan")
if not isinstance(mizan_config, mizanFormMeta):
return False
# Must be a Django form
if not issubclass(cls, forms.BaseForm):
return False
# Check if already registered (handle re-imports gracefully)
if cls.__dict__.get("_mizan_registered", False):
return False
return True
def _register_form_as_server_functions(form_class: type) -> None:
"""
Register a Django Form class as mizan server functions.
Creates and registers:
- {name}.schema - Returns form field definitions
- {name}.validate - Validates form data
- {name}.submit - Validates and submits form
Each function gets a unique typed schema for better TypeScript LSP support.
"""
from .schemas import FormSchema, FormSubmitFail, FormSubmitPass, FormValidation
from .schema_utils import build_form_schema
from .validation_utils import validate_form_instance
from mizan.setup.registry import register
from mizan.client.function import ServerFunction
config: mizanFormMeta = form_class.mizan
form_name = config.name
# Mark as registered
form_class._mizan_registered = True
# Generate PascalCase name for schemas (e.g., "contact" -> "Contact")
pascal_name = "".join(
word.capitalize()
for word in form_name.replace(".", "_").replace("-", "_").split("_")
)
# NOTE: We cannot create FormDataSchema here because form fields aren't
# populated yet during __init_subclass__. We use lazy creation instead.
_form_data_schema_cache: dict[str, type[BaseModel]] = {}
def get_form_data_schema() -> type[BaseModel]:
"""Lazily create the form data schema (form fields aren't available at registration time)."""
if "schema" not in _form_data_schema_cache:
_form_data_schema_cache["schema"] = _create_form_input_schema(
form_class, f"{pascal_name}FormData"
)
return _form_data_schema_cache["schema"]
# -------------------------------------------------------------------------
# Schema Function
# -------------------------------------------------------------------------
# Schema input wraps the form data for pre-populating dynamic fields
FormSchemaInput = create_model(
f"{pascal_name}SchemaInput",
data=(dict[str, Any], {}),
)
class SchemaFunction(ServerFunction):
Input = FormSchemaInput
Output = FormSchema
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "schema",
"form_class": form_class, # Store reference for schema generation
}
def call(self, input) -> FormSchema:
init_kwargs = form_class.get_init_kwargs(self.request)
schema = build_form_schema(
form_class,
data=input.data if input else {},
**init_kwargs,
)
# Override with mizanFormMeta values
if config.title is not None:
schema.title = config.title
if config.subtitle is not None:
schema.subtitle = config.subtitle
schema.submit_label = config.submit_label
# Behavior settings are nested in schema.meta
schema.meta.live_validation = config.live_validation
schema.meta.live_form_errors = config.live_form_errors
schema.meta.refetch_schema_on_validate = config.refetch_schema_on_validate
return schema
SchemaFunction.__name__ = f"{form_name}_schema"
SchemaFunction.__qualname__ = f"{form_name}_schema"
register(SchemaFunction, f"{form_name}.schema")
# -------------------------------------------------------------------------
# Validate Function
# -------------------------------------------------------------------------
# Use generic dict input - form fields aren't available during __init_subclass__
FormValidateInput = create_model(
f"{pascal_name}ValidateInput",
data=(dict[str, Any], ...),
)
class ValidateFunction(ServerFunction):
Input = FormValidateInput
Output = FormValidation
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "validate",
}
def call(self, input) -> FormValidation:
init_kwargs = form_class.get_init_kwargs(self.request)
# Input data is already a dict
data = input.data
_, validation = validate_form_instance(
form_class,
data=data,
files=None,
**init_kwargs,
)
return validation
ValidateFunction.__name__ = f"{form_name}_validate"
ValidateFunction.__qualname__ = f"{form_name}_validate"
register(ValidateFunction, f"{form_name}.validate")
# -------------------------------------------------------------------------
# Submit Function
# -------------------------------------------------------------------------
class SubmitFunction(ServerFunction):
"""
Submit function handles both JSON and multipart/form-data.
The executor detects form functions and parses the request appropriately.
"""
# Use dict for input - form fields unknown at registration time
Input = None # Signals executor to pass raw dict
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "submit",
"multipart": True, # Signal that this function accepts multipart
}
def call(self, input) -> FormSubmitPass | FormSubmitFail:
"""Execute form submission."""
request = self.request
# Check if we have multipart data from executor
if hasattr(request, "_mizan_form_data"):
data = request._mizan_form_data
files = request._mizan_form_files
elif input is not None:
# JSON input - already a dict
data = input if isinstance(input, dict) else input.model_dump()
files = None
else:
data = {}
files = None
init_kwargs = form_class.get_init_kwargs(request)
# Create and validate form
form, validation = validate_form_instance(
form_class,
data=data,
files=files,
**init_kwargs,
)
if form.is_valid():
# Call the form's on_submit_success
result_data = form.on_submit_success(request)
return FormSubmitPass(success=True, data=result_data)
# Call the form's on_submit_failure
form.on_submit_failure(request, validation)
return FormSubmitFail(success=False, errors=validation)
SubmitFunction.__name__ = f"{form_name}_submit"
SubmitFunction.__qualname__ = f"{form_name}_submit"
SubmitFunction.Output = FormSubmitPass # For schema generation
register(SubmitFunction, f"{form_name}.submit")
# -------------------------------------------------------------------------
# Formset Functions (if enabled)
# -------------------------------------------------------------------------
if config.enable_formset:
_register_formset_functions(form_class, form_name)
def _register_formset_functions(
form_class: type,
form_name: str,
) -> None:
"""Register formset server functions for a form."""
from django.forms import formset_factory
from .schemas import (
FormsetSchema,
FormsetSubmitFail,
FormsetSubmitPass,
FormsetValidation,
)
from .schema_utils import build_form_schema
from .validation_utils import build_formset_validation
from .formset_utils import forms_to_formset_post_data
from mizan.setup.registry import register
from mizan.client.function import ServerFunction
formset_class = formset_factory(form_class)
# Generate PascalCase name for schemas
pascal_name = "".join(
word.capitalize()
for word in form_name.replace(".", "_").replace("-", "_").split("_")
)
# NOTE: We cannot create typed schemas here because form fields aren't
# populated yet during __init_subclass__. We use generic dict inputs.
# -------------------------------------------------------------------------
# Formset Schema Function
# -------------------------------------------------------------------------
FormsetSchemaInput = create_model(
f"{pascal_name}FormsetSchemaInput",
forms=(list[dict[str, Any]], []),
)
class FormsetSchemaFunction(ServerFunction):
Input = FormsetSchemaInput
Output = FormsetSchema
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_schema",
}
def call(self, input) -> FormsetSchema:
init_kwargs = form_class.get_init_kwargs(self.request)
forms_data = input.forms if input else []
formset_data = forms_to_formset_post_data(forms_data)
formset = formset_class(formset_data)
return FormsetSchema(
forms=[
build_form_schema(form_class, data=fd, **init_kwargs)
for fd in forms_data
],
min_num=formset.min_num,
max_num=formset.max_num,
can_delete=formset.can_delete,
can_order=formset.can_order,
)
FormsetSchemaFunction.__name__ = f"{form_name}_formset_schema"
register(FormsetSchemaFunction, f"{form_name}.formset.schema")
# -------------------------------------------------------------------------
# Formset Validate Function
# -------------------------------------------------------------------------
# Generic dict input - form fields aren't available during __init_subclass__
FormsetValidateInput = create_model(
f"{pascal_name}FormsetValidateInput",
forms=(list[dict[str, Any]], ...),
)
class FormsetValidateFunction(ServerFunction):
Input = FormsetValidateInput
Output = FormsetValidation
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_validate",
}
def call(self, input) -> FormsetValidation:
init_kwargs = form_class.get_init_kwargs(self.request)
# Input.forms is already a list of dicts
forms_data = input.forms
formset_data = forms_to_formset_post_data(forms_data)
formset = formset_class(formset_data, form_kwargs=init_kwargs)
for form in formset:
form.empty_permitted = False
return build_formset_validation(formset)
FormsetValidateFunction.__name__ = f"{form_name}_formset_validate"
register(FormsetValidateFunction, f"{form_name}.formset.validate")
# -------------------------------------------------------------------------
# Formset Submit Function
# -------------------------------------------------------------------------
# Generic dict input - form fields aren't available during __init_subclass__
FormsetSubmitInput = create_model(
f"{pascal_name}FormsetSubmitInput",
forms=(list[dict[str, Any]], ...),
)
class FormsetSubmitFunction(ServerFunction):
Input = FormsetSubmitInput
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_submit",
"multipart": True,
}
def call(self, input) -> FormsetSubmitPass | FormsetSubmitFail:
request = self.request
init_kwargs = form_class.get_init_kwargs(request)
# Handle multipart vs JSON
if hasattr(request, "_mizan_form_data"):
post_data = request._mizan_form_data
files = request._mizan_form_files
elif input and hasattr(input, "forms"):
# Input.forms is already a list of dicts
forms_data = input.forms
post_data = forms_to_formset_post_data(forms_data)
files = None
else:
post_data = {}
files = None
formset = formset_class(post_data, files=files, form_kwargs=init_kwargs)
if formset.is_valid():
for form in formset.forms:
if form.cleaned_data:
form.on_submit_success(request)
return FormsetSubmitPass(success=True)
validation = build_formset_validation(formset)
# Call failure handler on each form
for form in formset.forms:
if hasattr(form, "on_submit_failure"):
form.on_submit_failure(request, validation)
return FormsetSubmitFail(success=False, errors=validation)
FormsetSubmitFunction.__name__ = f"{form_name}_formset_submit"
FormsetSubmitFunction.Output = FormsetSubmitPass
register(FormsetSubmitFunction, f"{form_name}.formset.submit")

View File

@@ -0,0 +1,16 @@
from typing import Any
def forms_to_formset_post_data(forms_data: list[dict[str, Any]]) -> dict[str, Any]:
"""
Convert a list of form dicts into Django formset-compatible POST data.
"""
formset_data: dict[str, Any] = {
"form-TOTAL_FORMS": str(len(forms_data)),
"form-INITIAL_FORMS": "0",
}
for i, form_data in enumerate(forms_data):
formset_data.update(
{f"form-{i}-{key}": value for key, value in form_data.items()}
)
return formset_data

View File

@@ -0,0 +1,187 @@
import re
from typing import Any, Optional
from django import forms
from django.forms import Field
from .schemas import FieldChoice, FieldSchema, FormMeta, FormSchema
def create_form_instance(
form_class: type[forms.BaseForm],
data: Optional[dict] = None,
files: Optional[dict] = None,
**kwargs,
) -> forms.BaseForm:
"""
Create a form instance, gracefully handling kwargs that the form doesn't accept.
Some Django forms (like allauth's) accept `request` in __init__, others don't.
This function tries with all kwargs first, then progressively removes kwargs
that cause TypeErrors until instantiation succeeds.
"""
# Common kwargs that forms may or may not accept
optional_kwargs = ['request', 'user', 'instance']
# Build init kwargs
init_kwargs = dict(kwargs)
if data is not None:
init_kwargs['data'] = data
if files is not None:
init_kwargs['files'] = files
while True:
try:
return form_class(**init_kwargs)
except TypeError as e:
error_msg = str(e)
# Check if it's an unexpected keyword argument error
if "unexpected keyword argument" not in error_msg:
raise
# Find which kwarg caused the problem and remove it
removed = False
for kwarg in optional_kwargs:
if f"'{kwarg}'" in error_msg and kwarg in init_kwargs:
init_kwargs.pop(kwarg)
removed = True
break
# If we couldn't identify/remove the problematic kwarg, re-raise
if not removed:
raise
def _get_choices(field: Field) -> Optional[list[FieldChoice]]:
"""
Extract choices from a field, handling ModelChoiceField properly.
ModelChoiceField returns ModelChoiceIteratorValue which is not JSON serializable.
"""
if not hasattr(field, "choices"):
return None
choices: list[FieldChoice] = []
for raw_value, label in field.choices:
value = getattr(
raw_value, "value", raw_value
) # ModelChoiceIteratorValue -> .value
choices.append(FieldChoice(value=str(value), label=str(label)))
return choices
def _get_initial(value: Any) -> Any:
"""Convert initial value to JSON-serializable format."""
if value is None:
return None
if hasattr(value, "isoformat"):
return value.isoformat()
if hasattr(value, "pk"):
return value.pk
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
return [item.pk if hasattr(item, "pk") else item for item in value]
return value
def _class_name_to_title(name: str) -> str:
"""
Convert a class name to a human-readable title.
e.g., 'LoginForm' -> 'Login', 'ResetPasswordForm' -> 'Reset Password'
"""
# Remove 'Form' suffix
name = re.sub(r"Form$", "", name)
# Insert spaces before capital letters
name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
return name
def _class_name_to_slug(name: str) -> str:
"""
Convert a class name to a slug.
e.g., 'LoginForm' -> 'login', 'ResetPasswordForm' -> 'reset_password'
"""
# Remove 'Form' suffix
name = re.sub(r"Form$", "", name)
# Insert underscores before capital letters and lowercase
name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
return name.lower()
def build_form_schema(
form_class: type[forms.BaseForm],
data: Optional[dict] = None,
**kwargs,
) -> FormSchema:
"""
Produce a FormSchema for the given Django form class and (optional) data.
The form class can define metadata via an inner Meta class:
class MyForm(forms.Form):
class Meta:
form_name = "my_form"
title = "My Form Title"
subtitle = "Optional description"
submit_label = "Submit"
# Frontend behavior (optional)
refetch_schema_on_validate = False # Set True for dynamic choice fields
live_validation = True # Set False to disable live validation
live_form_errors = False # Set True to show form errors live
If not provided, sensible defaults are derived from the class name.
"""
form = create_form_instance(form_class, data=data, **kwargs)
# Extract metadata from form's Meta class
form_meta = getattr(form_class, "Meta", None)
# Get form name (used as identifier)
name = getattr(form_meta, "form_name", None)
if name is None:
name = _class_name_to_slug(form_class.__name__)
# Get title (human-readable heading)
title = getattr(form_meta, "title", None)
if title is None:
title = _class_name_to_title(form_class.__name__)
# Get optional subtitle
subtitle = getattr(form_meta, "subtitle", None)
# Get submit button label
submit_label = getattr(form_meta, "submit_label", None)
if submit_label is None:
submit_label = "Submit"
# Build frontend behavior metadata
frontend_meta = FormMeta(
refetch_schema_on_validate=getattr(form_meta, "refetch_schema_on_validate", False),
live_validation=getattr(form_meta, "live_validation", True),
live_form_errors=getattr(form_meta, "live_form_errors", False),
)
return FormSchema(
name=name,
title=title,
subtitle=subtitle,
submit_label=submit_label,
fields=[
FieldSchema(
name=name,
label=str(field.label or name.replace("_", " ").title()),
type=getattr(field.widget, "input_type", "text"),
widget=field.widget.__class__.__name__,
required=field.required,
disabled=field.disabled,
help_text=str(field.help_text) if field.help_text else "",
initial=_get_initial(field.initial),
max_length=getattr(field, "max_length", None),
min_length=getattr(field, "min_length", None),
choices=_get_choices(field),
)
for name, field in form.fields.items()
],
meta=frontend_meta,
)

View File

@@ -0,0 +1,103 @@
from typing import Any, Optional
from ninja import Schema
# Form metadata schema
class FormMeta(Schema):
"""
Metadata controlling frontend form behavior.
Attributes:
refetch_schema_on_validate: If True, frontend should refetch schema on each
validation (useful for dynamic choice fields). Default False.
live_validation: If False, frontend should disable live validation entirely.
Useful for sensitive forms like login. Default True.
live_form_errors: If True, show form-level errors during live validation.
Form errors are things like "Invalid credentials" vs field errors like
"This field is required". Default False for security.
"""
refetch_schema_on_validate: bool = False
live_validation: bool = True
live_form_errors: bool = False
# Field-level schemas
class FieldChoice(Schema):
value: str
label: str
class FieldError(Schema):
message: str
code: Optional[str]
class FieldErrorList(Schema):
field: str
errors: list[FieldError]
class FieldSchema(Schema):
name: str
label: str
type: str
widget: str
required: bool
disabled: bool
help_text: str
initial: Any
max_length: Optional[int]
min_length: Optional[int]
choices: Optional[list[FieldChoice]]
# Form-level schemas
class FormSchema(Schema):
"""Schema returned by /schema endpoint with form metadata and fields."""
# Form metadata
name: str
title: str
subtitle: Optional[str]
submit_label: str
# Fields
fields: list[FieldSchema]
# Frontend behavior metadata
meta: FormMeta = FormMeta()
class FormValidation(Schema):
errors: list[FieldErrorList]
class FormSubmitPass(Schema):
success: bool
data: Optional[dict] = None
class FormSubmitFail(Schema):
success: bool
errors: FormValidation
# Formset-level schemas
class FormsetSchema(Schema):
forms: list[FormSchema]
min_num: int
max_num: int
can_delete: bool
can_order: bool
class FormsetValidation(Schema):
general: list[str]
per_form: list[FormValidation]
class FormsetSubmitPass(Schema):
success: bool
class FormsetSubmitFail(Schema):
success: bool
errors: FormsetValidation

View File

@@ -0,0 +1,72 @@
from typing import Any
from django import forms
from django.core.files.uploadedfile import UploadedFile
from django.utils.datastructures import MultiValueDict
from .schemas import (
FieldError,
FieldErrorList,
FormValidation,
FormsetValidation,
)
from .schema_utils import create_form_instance
def validate_form_instance(
form_class: type[forms.BaseForm],
data: dict,
files: MultiValueDict[str, UploadedFile] | None = None,
**kwargs: Any,
) -> tuple[forms.BaseForm, FormValidation]:
"""
Build a form instance and return (form, structured_validation_errors).
"""
form = create_form_instance(form_class, data=data, files=files, initial=data, **kwargs)
# Run validation
form.is_valid()
validation = FormValidation(
errors=[
FieldErrorList(
field=field_name,
errors=[
FieldError(
message=str(e.message) if hasattr(e, 'message') else str(e),
code=getattr(e, "code", None),
)
for e in field_errors.as_data()
],
)
for field_name, field_errors in form.errors.items()
]
)
return form, validation
def build_formset_validation(formset: forms.BaseFormSet) -> FormsetValidation:
"""
Turn a Django formset into a FormsetValidation structure.
"""
return FormsetValidation(
general=[str(e) if e else "" for e in formset.non_form_errors()],
per_form=[
FormValidation(
errors=[
FieldErrorList(
field=field_name,
errors=[
FieldError(
message=str(e.message) if hasattr(e, 'message') else str(e),
code=getattr(e, "code", None),
)
for e in field_errors.as_data()
],
)
for field_name, field_errors in form.errors.items()
]
)
for form in formset
],
)

View File

@@ -0,0 +1,25 @@
"""
mizan Allauth Integration
Backend support for django-allauth with mizan server functions.
Provides:
- Auth contexts (auth_status, user) - required by frontend allauth module
- Allauth form wrappers - expose allauth forms as server functions
Usage:
# In your app's apps.py
class MyAppConfig(AppConfig):
def ready(self):
import mizan.allauth.forms # noqa - registers forms
import mizan.allauth.contexts # noqa - registers contexts
"""
from .contexts import auth_status, user, AuthStatusOutput, UserOutput
__all__ = [
"auth_status",
"user",
"AuthStatusOutput",
"UserOutput",
]

View File

@@ -0,0 +1,118 @@
"""
Auth contexts for mizan Allauth integration.
These are the core auth primitives that the frontend allauth module depends on.
Separated into two concerns:
- auth_status: Authentication state and permission guards (fast, no DB hit with JWT)
- user: Full user profile data (may require DB query for JWT auth)
Both are registered as global contexts for SSR hydration.
"""
from django.http import HttpRequest
from pydantic import BaseModel
from mizan.client import client
# =============================================================================
# Auth Status Context
# =============================================================================
class AuthStatusOutput(BaseModel):
"""Authentication status and permission guards."""
is_authenticated: bool
user_id: int | None = None
is_staff: bool = False
is_superuser: bool = False
@client(context="global")
def auth_status(request: HttpRequest) -> AuthStatusOutput:
"""
Auth status context - provides authentication state and guards.
This works identically for both session and JWT auth. The data comes
from the request.user object (either full User or JWTUser with claims).
Frontend:
const auth = useAuthStatus()
if (auth.is_authenticated) { ... }
if (auth.is_staff) { ... }
"""
user = request.user
if not user.is_authenticated:
return AuthStatusOutput(is_authenticated=False)
return AuthStatusOutput(
is_authenticated=True,
user_id=user.id,
is_staff=user.is_staff,
is_superuser=user.is_superuser,
)
# =============================================================================
# User Profile Context
# =============================================================================
class UserOutput(BaseModel):
"""Full user profile data."""
id: int
email: str
first_name: str = ""
last_name: str = ""
@client(context="global")
def user(request: HttpRequest) -> UserOutput | None:
"""
User profile context - provides full user data.
Unlike auth_status, this may require a DB query (for JWT auth where
the user object is a minimal JWTUser with only claims).
Returns None if not authenticated.
Frontend:
const user = useUser()
if (user) {
console.log(user.email)
}
"""
req_user = request.user
if not req_user.is_authenticated:
return None
# Check if we have full user data or just JWT claims
if hasattr(req_user, "email") and req_user.email:
# Full User object (session auth)
return UserOutput(
id=req_user.id,
email=req_user.email,
first_name=getattr(req_user, "first_name", "") or "",
last_name=getattr(req_user, "last_name", "") or "",
)
# JWTUser - need to fetch from DB
from django.contrib.auth import get_user_model
User = get_user_model()
try:
db_user = User.objects.get(pk=req_user.id)
return UserOutput(
id=db_user.id,
email=db_user.email,
first_name=db_user.first_name or "",
last_name=db_user.last_name or "",
)
except User.DoesNotExist:
return None

View File

@@ -0,0 +1,408 @@
"""
Allauth forms as mizan server functions.
This module wraps allauth forms with mizanFormMixin, exposing them as
typed server functions for the React frontend.
Each form becomes three server functions:
- {name}.schema - Get form field definitions
- {name}.validate - Validate form data
- {name}.submit - Submit form
Import this module in your app's ready() to register the forms:
class MyAppConfig(AppConfig):
def ready(self):
import mizan.allauth.forms # noqa
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from django.http import HttpRequest
from mizan.forms import mizanFormMixin, mizanFormMeta
# Account forms
from allauth.account.forms import (
AddEmailForm,
ChangePasswordForm,
ConfirmLoginCodeForm,
LoginForm,
RequestLoginCodeForm,
ResetPasswordForm,
ResetPasswordKeyForm,
SetPasswordForm,
SignupForm,
UserTokenForm,
)
# Password reauthentication form - conditionally import
try:
from allauth.account.forms import ReauthenticateForm
HAS_REAUTH = True
except ImportError:
HAS_REAUTH = False
# MFA forms - conditionally import
try:
from allauth.mfa.base.forms import AuthenticateForm as MFAAuthenticateForm
from allauth.mfa.base.forms import ReauthenticateForm as MFAReauthenticateForm
from allauth.mfa.totp.forms import ActivateTOTPForm, DeactivateTOTPForm
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
HAS_MFA = True
except ImportError:
HAS_MFA = False
# WebAuthn forms (if available)
try:
from allauth.mfa.webauthn.forms import AuthenticateWebAuthnForm
HAS_WEBAUTHN = True
except ImportError:
HAS_WEBAUTHN = False
if TYPE_CHECKING:
from mizan.forms.schemas import FormValidation
# =============================================================================
# Account Forms
# =============================================================================
class mizanLoginForm(LoginForm, mizanFormMixin):
"""Sign in with email and password."""
mizan = mizanFormMeta(
name="login",
title="Sign In",
subtitle="Welcome back. Enter your credentials to continue.",
submit_label="Sign In",
live_validation=False, # Don't validate credentials as user types
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.login(request)
return None
class mizanSignupForm(SignupForm, mizanFormMixin):
"""Create a new account."""
mizan = mizanFormMeta(
name="signup",
title="Create Account",
subtitle="Enter your details to get started.",
submit_label="Create Account",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save(request)
return None
class mizanAddEmailForm(AddEmailForm, mizanFormMixin):
"""Add another email address to your account."""
mizan = mizanFormMeta(
name="add_email",
title="Add Email Address",
subtitle="Add another email address to your account.",
submit_label="Add Email",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanChangePasswordForm(ChangePasswordForm, mizanFormMixin):
"""Change your account password."""
mizan = mizanFormMeta(
name="change_password",
title="Change Password",
subtitle="Update your password to keep your account secure.",
submit_label="Change Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanSetPasswordForm(SetPasswordForm, mizanFormMixin):
"""Set a password for accounts created via social login."""
mizan = mizanFormMeta(
name="set_password",
title="Set Password",
subtitle="Create a password for your account.",
submit_label="Set Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanResetPasswordForm(ResetPasswordForm, mizanFormMixin):
"""Request a password reset email."""
mizan = mizanFormMeta(
name="reset_password",
title="Reset Password",
subtitle="Enter your email address and we'll send you a link to reset your password.",
submit_label="Send Reset Link",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save(request)
return None
class mizanResetPasswordKeyForm(ResetPasswordKeyForm, mizanFormMixin):
"""Set a new password using a reset key."""
mizan = mizanFormMeta(
name="reset_password_from_key",
title="Set New Password",
subtitle="Enter your new password below.",
submit_label="Reset Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanRequestLoginCodeForm(RequestLoginCodeForm, mizanFormMixin):
"""Request a login code via email."""
mizan = mizanFormMeta(
name="request_login_code",
title="Sign In with Code",
subtitle="Enter your email address and we'll send you a login code.",
submit_label="Send Code",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanConfirmLoginCodeForm(ConfirmLoginCodeForm, mizanFormMixin):
"""Confirm a login code."""
mizan = mizanFormMeta(
name="confirm_login_code",
title="Enter Code",
subtitle="Enter the code we sent to your email.",
submit_label="Verify Code",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanUserTokenForm(UserTokenForm, mizanFormMixin):
"""Verify an email with a token."""
mizan = mizanFormMeta(
name="user_token",
title="Verify Email",
subtitle="Enter the verification code from your email.",
submit_label="Verify",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
# Password reauthentication - conditionally define
if HAS_REAUTH:
class mizanReauthenticateForm(ReauthenticateForm, mizanFormMixin):
"""Re-authenticate with password for sensitive actions."""
mizan = mizanFormMeta(
name="reauthenticate",
title="Confirm Your Identity",
subtitle="Please enter your password to continue.",
submit_label="Confirm",
live_validation=False,
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
from allauth.account.internal.flows import reauthentication
reauthentication.reauthenticate_by_password(request)
return None
# =============================================================================
# MFA Forms
# =============================================================================
if HAS_MFA:
class mizanMFAAuthenticateForm(MFAAuthenticateForm, mizanFormMixin):
"""Authenticate with MFA during login."""
mizan = mizanFormMeta(
name="mfa_authenticate",
title="Two-Factor Authentication",
subtitle="Enter your authentication code to continue.",
submit_label="Verify",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanMFAReauthenticateForm(MFAReauthenticateForm, mizanFormMixin):
"""Re-authenticate with MFA for sensitive actions."""
mizan = mizanFormMeta(
name="mfa_reauthenticate",
title="Confirm Your Identity",
subtitle="Enter your authentication code to continue.",
submit_label="Confirm",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanActivateTOTPForm(ActivateTOTPForm, mizanFormMixin):
"""Activate TOTP authenticator."""
mizan = mizanFormMeta(
name="activate_totp",
title="Set Up Authenticator",
subtitle="Enter the code from your authenticator app to complete setup.",
submit_label="Activate",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanDeactivateTOTPForm(DeactivateTOTPForm, mizanFormMixin):
"""Deactivate TOTP authenticator."""
mizan = mizanFormMeta(
name="deactivate_totp",
title="Disable Authenticator",
subtitle="Enter your password to disable two-factor authentication.",
submit_label="Disable",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanGenerateRecoveryCodesForm(GenerateRecoveryCodesForm, mizanFormMixin):
"""Generate new recovery codes."""
mizan = mizanFormMeta(
name="generate_recovery_codes",
title="Recovery Codes",
subtitle="Generate new recovery codes for your account.",
submit_label="Generate Codes",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
if HAS_WEBAUTHN:
class mizanAuthenticateWebAuthnForm(AuthenticateWebAuthnForm, mizanFormMixin):
"""Authenticate with WebAuthn security key."""
mizan = mizanFormMeta(
name="webauthn_authenticate",
title="Security Key",
subtitle="Use your security key to authenticate.",
submit_label="Use Security Key",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None

View File

@@ -0,0 +1,71 @@
"""
mizan.jwt - JWT authentication for server functions.
Provides:
- Server functions for obtaining/refreshing JWT tokens
- JWT authentication utilities for validating tokens
Server Functions:
- jwt_obtain: Convert authenticated session to JWT tokens
- jwt_refresh: Refresh tokens using a refresh token
Usage in apps.py or urls.py (to register the functions):
import mizan.jwt.functions # noqa: F401
Note: This module is purpose-built for mizan server functions.
For Django Ninja API authentication, use mizan.jwt.security directly.
"""
# Server functions (import to register with @client decorator)
from .functions import jwt_obtain, jwt_refresh
# Token utilities
from .tokens import (
create_token_pair,
create_access_token,
create_refresh_token,
decode_token,
refresh_tokens,
TokenPair,
TokenPayload,
JWTUser,
)
# Settings
from .settings import get_settings, JWTSettings
# Security (Ninja API auth) - lazy import to avoid triggering
# django-ninja's settings access at module load time.
# Use: from mizan.jwt.security import jwt_auth
def __getattr__(name):
if name in ("JWTAuth", "jwt_auth"):
from .security import JWTAuth, jwt_auth
globals()["JWTAuth"] = JWTAuth
globals()["jwt_auth"] = jwt_auth
return globals()[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = [
# Server functions
"jwt_obtain",
"jwt_refresh",
# Token utilities
"create_token_pair",
"create_access_token",
"create_refresh_token",
"decode_token",
"refresh_tokens",
"TokenPair",
"TokenPayload",
"JWTUser",
# Settings
"get_settings",
"JWTSettings",
# Security (lazy)
"JWTAuth",
"jwt_auth",
]

View File

@@ -0,0 +1,141 @@
"""
JWT & MWT Server Functions
Token operations exposed as mizan server functions.
Works over WebSocket RPC (primary) or HTTP fallback.
"""
from django.http import HttpRequest
from pydantic import BaseModel
from mizan.client import client
from mizan.jwt.tokens import create_token_pair, refresh_tokens
from mizan.mwt import create_mwt
class TokenPairOutput(BaseModel):
"""JWT token pair response."""
access_token: str
refresh_token: str
expires_in: int
class JWTError(BaseModel):
"""JWT operation error."""
error: str
@client
def jwt_obtain(request: HttpRequest) -> TokenPairOutput:
"""
Obtain JWT tokens from an authenticated session.
Requires session authentication (cookie or WebSocket session).
Returns access and refresh tokens that can be used for stateless auth.
The tokens include user claims (is_staff, is_superuser) so that
subsequent JWT-authenticated requests don't need a database query.
Usage:
const { access_token, refresh_token } = await call('jwt_obtain')
// Use access_token in Authorization: Bearer header
"""
user = request.user
if not user.is_authenticated:
raise PermissionError("Authentication required")
# Get session key - for WebSocket, this comes from the scope
session = getattr(request, "session", None)
if session is None:
# WebSocket request adapter - session is a dict, not SessionBase
session_key = (
getattr(request, "_scope", {}).get("session", {}).get("_session_key")
)
if not session_key:
raise PermissionError("No session available")
else:
# HTTP request - ensure session is saved
if not session.session_key:
session.save()
session_key = session.session_key
# Include user claims in the token for stateless auth
tokens = create_token_pair(
user.pk,
session_key,
is_staff=getattr(user, "is_staff", False),
is_superuser=getattr(user, "is_superuser", False),
)
return TokenPairOutput(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
expires_in=tokens.expires_in,
)
@client
def jwt_refresh(request: HttpRequest, refresh_token: str) -> TokenPairOutput:
"""
Refresh JWT tokens using a refresh token.
Does not require session authentication - the refresh token itself
contains the session reference and is validated against the session store.
If the original session has been destroyed (user logged out), this fails.
Usage:
const { access_token, refresh_token } = await call('jwt_refresh', { refresh_token })
"""
tokens = refresh_tokens(refresh_token)
if tokens is None:
raise PermissionError("Invalid or expired refresh token")
return TokenPairOutput(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
expires_in=tokens.expires_in,
)
# ── MWT (Mizan Web Token) ──────────────────────────────────────────────
class MWTOutput(BaseModel):
"""MWT token response."""
token: str
expires_in: int
@client
def mwt_obtain(request: HttpRequest) -> MWTOutput:
"""
Obtain a Mizan Web Token from an authenticated session.
Requires session authentication (cookie-based login).
Returns an MWT for the X-Mizan-Token header — stateless,
cache-aware authentication with permission staleness detection.
Usage (from frontend):
const { token, expires_in } = await call('mwt_obtain')
// Use token in X-Mizan-Token header
"""
user = request.user
if not user.is_authenticated:
raise PermissionError("Authentication required")
from mizan.setup.settings import get_settings
settings = get_settings()
if not settings.mwt_secret:
raise ValueError(
"MIZAN_MWT_SECRET is not configured. MWT requires a signing secret."
)
token = create_mwt(user, settings.mwt_secret, ttl=settings.mwt_ttl)
return MWTOutput(token=token, expires_in=settings.mwt_ttl)

View File

@@ -0,0 +1,64 @@
"""
Django Ninja Security Classes for JWT Authentication
Provides authentication classes that can be used with Django Ninja's
auth parameter to protect API endpoints.
"""
from django.http import HttpRequest
from ninja.security import HttpBearer
from .tokens import decode_token, JWTUser
class JWTAuth(HttpBearer):
"""
JWT Bearer token authentication for Django Ninja.
Usage:
from ninja_jwt_session import jwt_auth
@api.get("/protected/", auth=jwt_auth)
def protected_endpoint(request):
return {"user_id": request.user.id}
Or globally:
api = NinjaExtraAPI(auth=[django_auth, jwt_auth])
The token must be passed in the Authorization header:
Authorization: Bearer <access_token>
IMPORTANT: This is stateless - no database query is made.
request.user is a JWTUser object with id, is_staff, is_superuser.
If you need the full User object, query it explicitly:
user = User.objects.get(pk=request.user.id)
"""
def authenticate(self, request: HttpRequest, token: str):
"""
Validate the JWT and return a JWTUser if valid.
Returns None (authentication failed) if:
- Token is invalid or expired
- Token is not an access token
Note: No database query is made. The JWTUser is created from
token claims. This is truly stateless authentication.
"""
# Decode and validate the token
payload = decode_token(token, expected_type="access")
if payload is None:
return None
# Create JWTUser from token claims - NO DATABASE QUERY
jwt_user = JWTUser(payload)
# Set request.user for compatibility with code expecting it
request.user = jwt_user
return jwt_user
# Singleton instance for convenience
jwt_auth = JWTAuth()

View File

@@ -0,0 +1,118 @@
"""
JWT Hybrid Settings
Configuration is read from Django settings with sensible defaults.
Supports both symmetric (HS256) and asymmetric (RS256) algorithms.
"""
from dataclasses import dataclass
from functools import lru_cache
from django.conf import settings as django_settings
@dataclass
class JWTSettings:
"""JWT configuration."""
# Signing keys
private_key: str # Used for signing (required)
public_key: str # Used for verification (same as private for HS256)
# Algorithm
algorithm: str # HS256, RS256, etc.
# Token lifetimes (seconds)
access_token_expires_in: int
refresh_token_expires_in: int
# Security options
validate_session: bool # Check session exists on token validation
rotate_refresh_token: bool # Issue new refresh token on refresh
@lru_cache
def get_settings() -> JWTSettings:
"""
Load JWT settings from Django settings.
Settings:
JWT_PRIVATE_KEY: Signing key (required)
JWT_PUBLIC_KEY: Verification key (defaults to private key for HS256)
JWT_ALGORITHM: Algorithm to use (default: HS256)
JWT_ACCESS_TOKEN_EXPIRES_IN: Access token lifetime (default: 300)
JWT_REFRESH_TOKEN_EXPIRES_IN: Refresh token lifetime (default: 604800)
JWT_VALIDATE_SESSION: Validate session on token use (default: True)
JWT_ROTATE_REFRESH_TOKEN: Rotate refresh tokens (default: True)
"""
private_key = getattr(django_settings, "JWT_PRIVATE_KEY", None)
if not private_key:
# Fall back to allauth setting if available (for compatibility)
headless_key = getattr(django_settings, "HEADLESS_JWT_PRIVATE_KEY", None)
if headless_key:
private_key = headless_key
if private_key is None:
raise ValueError(
"JWT_PRIVATE_KEY must be set in Django settings. "
"For HS256, use a secure random string. "
"For RS256, use a PEM-encoded RSA private key."
)
# Auto-detect algorithm based on key format if not explicitly set
algorithm = getattr(django_settings, "JWT_ALGORITHM", None)
if algorithm is None:
# Auto-detect: if key looks like PEM, use RS256; otherwise HS256
if isinstance(private_key, str) and private_key.strip().startswith("-----BEGIN"):
algorithm = "RS256"
else:
algorithm = "HS256"
# For symmetric algorithms, public key = private key
if algorithm.startswith("HS"):
public_key = private_key
else:
public_key = getattr(django_settings, "JWT_PUBLIC_KEY", None)
if public_key is None:
# Try to extract public key from private key for RSA
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import load_pem_private_key
private_key_obj = load_pem_private_key(
private_key.encode() if isinstance(private_key, str) else private_key,
password=None,
)
public_key = private_key_obj.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode()
except Exception:
raise ValueError(
f"JWT_PUBLIC_KEY must be set for {algorithm} algorithm, "
"or JWT_PRIVATE_KEY must be a valid PEM-encoded RSA key."
)
return JWTSettings(
private_key=private_key,
public_key=public_key,
algorithm=algorithm,
access_token_expires_in=getattr(
django_settings,
"JWT_ACCESS_TOKEN_EXPIRES_IN",
getattr(django_settings, "HEADLESS_JWT_ACCESS_TOKEN_EXPIRES_IN", 300),
),
refresh_token_expires_in=getattr(
django_settings,
"JWT_REFRESH_TOKEN_EXPIRES_IN",
getattr(django_settings, "HEADLESS_JWT_REFRESH_TOKEN_EXPIRES_IN", 604800),
),
validate_session=getattr(
django_settings, "JWT_VALIDATE_SESSION", True
),
rotate_refresh_token=getattr(
django_settings, "JWT_ROTATE_REFRESH_TOKEN", True
),
)

View File

@@ -0,0 +1,245 @@
"""
JWT Token Creation and Validation
Uses PyJWT directly - no allauth dependency.
Tokens are tied to Django sessions for immediate revocation on logout.
"""
import time
from typing import NamedTuple
import jwt
from django.contrib.sessions.backends.base import SessionBase
from .settings import get_settings
class TokenPair(NamedTuple):
"""Access and refresh token pair."""
access_token: str
refresh_token: str
expires_in: int
class TokenPayload(NamedTuple):
"""Decoded token payload."""
user_id: int | str
session_key: str
token_type: str
is_staff: bool
is_superuser: bool
exp: int
iat: int
class JWTUser:
"""
Minimal user object created from JWT claims.
Used as request.user for JWT-authenticated requests.
No database query required - all data comes from the token.
If you need the full User object with all fields, query explicitly:
user = User.objects.get(pk=request.user.id)
"""
def __init__(self, payload: TokenPayload):
self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id
self.pk = self.id
self.is_staff = payload.is_staff
self.is_superuser = payload.is_superuser
self.is_authenticated = True
self.is_anonymous = False
self.is_active = True # Assumed active if they have a valid token
def __str__(self):
return f"JWTUser(id={self.id})"
def __repr__(self):
return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})"
def create_access_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a short-lived access token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "access"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "access",
"iat": now,
"exp": now + settings.access_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
)
def create_refresh_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a longer-lived refresh token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "refresh"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "refresh",
"iat": now,
"exp": now + settings.refresh_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
)
def create_token_pair(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> TokenPair:
"""Create both access and refresh tokens."""
settings = get_settings()
return TokenPair(
access_token=create_access_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
refresh_token=create_refresh_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
expires_in=settings.access_token_expires_in,
)
def decode_token(token: str, expected_type: str = None) -> TokenPayload | None:
"""
Decode and validate a JWT token.
Returns None if:
- Token is invalid or expired
- Token type doesn't match expected_type (if specified)
"""
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.public_key,
algorithms=[settings.algorithm],
)
except jwt.PyJWTError:
return None
# Validate token type if specified
if expected_type and payload.get("type") != expected_type:
return None
return TokenPayload(
user_id=payload["sub"],
session_key=payload["sid"],
token_type=payload["type"],
is_staff=payload.get("staff", False),
is_superuser=payload.get("super", False),
exp=payload["exp"],
iat=payload["iat"],
)
def validate_session(session_key: str) -> bool:
"""
Check if a session is still valid (exists and not expired).
This is the key to immediate logout revocation - if the session
is destroyed, tokens tied to it become invalid.
"""
from importlib import import_module
from django.conf import settings as django_settings
jwt_settings = get_settings()
if not jwt_settings.validate_session:
return True
# Use the configured session engine
engine = import_module(django_settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
# Try to load the session
session = SessionStore(session_key=session_key)
# Check if session exists and is not empty
# exists() is more reliable than checking load() result
return session.exists(session_key)
def refresh_tokens(refresh_token: str) -> TokenPair | None:
"""
Use a refresh token to obtain new tokens.
Returns None if:
- Refresh token is invalid or expired
- Associated session no longer exists
"""
payload = decode_token(refresh_token, expected_type="refresh")
if payload is None:
return None
# Validate the session still exists
if not validate_session(payload.session_key):
return None
# Issue new token pair with same claims
return create_token_pair(
payload.user_id,
payload.session_key,
is_staff=payload.is_staff,
is_superuser=payload.is_superuser,
)

View File

@@ -0,0 +1,35 @@
"""
Export channels schema as OpenAPI JSON for TypeScript generation.
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
The schema is consumed by openapi-typescript for type generation.
Usage:
python manage.py export_channels_schema
"""
import json
from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = "Export channels schema as OpenAPI JSON for TypeScript code generation"
def add_arguments(self, parser):
parser.add_argument(
"--indent",
type=int,
default=2,
help="JSON indentation level (default: 2, use 0 for compact)",
)
def handle(self, *args, **options):
from mizan.channels import get_channels_openapi_schema
schema = get_channels_openapi_schema()
indent = options["indent"] if options["indent"] > 0 else None
output = json.dumps(schema, indent=indent)
self.stdout.write(output)

View File

@@ -0,0 +1,56 @@
"""
Export Edge Manifest
Generates the static JSON manifest that Mizan Edge reads at deploy time
to configure CDN cache rules and invalidation routing.
Usage:
python manage.py export_edge_manifest
python manage.py export_edge_manifest --output mizan-manifest.json
python manage.py export_edge_manifest --base-url /api/mizan
"""
import json
from pathlib import Path
from django.core.management.base import BaseCommand
from mizan.export import generate_edge_manifest
class Command(BaseCommand):
help = "Export Edge manifest for CDN cache invalidation"
def add_arguments(self, parser):
parser.add_argument(
"--output",
"-o",
type=str,
default=None,
help="Output file path. If not specified, outputs to stdout.",
)
parser.add_argument(
"--indent",
type=int,
default=2,
help="JSON indentation level (0 for compact output)",
)
parser.add_argument(
"--base-url",
type=str,
default="/api/mizan",
help="Mizan API mount point (default: /api/mizan)",
)
def handle(self, *args, **options):
manifest = generate_edge_manifest(base_url=options["base_url"])
indent = options["indent"] if options["indent"] > 0 else None
json_output = json.dumps(manifest, indent=indent, sort_keys=True)
if options["output"]:
output_path = Path(options["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json_output)
self.stdout.write(self.style.SUCCESS(f"Manifest written to {output_path}"))
else:
self.stdout.write(json_output)

View File

@@ -0,0 +1,49 @@
"""
Export mizan Schema
Management command to export the mizan OpenAPI schema for TypeScript code generation.
The schema is consumed by openapi-typescript for robust type generation.
Usage:
python manage.py export_mizan_schema # Output to stdout
python manage.py export_mizan_schema --output schema.json # Output to file
"""
import json
from pathlib import Path
from django.core.management.base import BaseCommand
from mizan.export import generate_openapi_schema
class Command(BaseCommand):
help = "Export mizan OpenAPI schema for TypeScript code generation"
def add_arguments(self, parser):
parser.add_argument(
"--output",
"-o",
type=str,
default=None,
help="Output file path. If not specified, outputs to stdout.",
)
parser.add_argument(
"--indent",
type=int,
default=2,
help="JSON indentation level (0 for compact output)",
)
def handle(self, *args, **options):
schema = generate_openapi_schema()
indent = options["indent"] if options["indent"] > 0 else None
json_output = json.dumps(schema, indent=indent)
if options["output"]:
output_path = Path(options["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json_output)
self.stdout.write(self.style.SUCCESS(f"Schema written to {output_path}"))
else:
self.stdout.write(json_output)

View File

@@ -0,0 +1,169 @@
"""
MWT (Mizan Web Token) — Protocol-owned identity layer.
MWT is a standard JWT (RFC 7519, HMAC-SHA256) with Mizan-specific claims,
traveling on the `X-Mizan-Token` header. It provides:
- `sub`: user_id for HMAC cache key derivation
- `pkey`: permission state hash for staleness detection
- `kid`: key ID in the JOSE header (per RFC 7515) for secret rotation
- `aud`: audience binding to prevent cross-tenant replay
- `nbf`: not-before to handle clock skew
MWT is issued from an authenticated Django session. The app handles
authentication (session, social auth, etc.); Mizan issues MWT from
the authenticated identity. Edge Workers and the origin-side cache
validate MWT to extract user identity for cache operations.
Usage:
from mizan.mwt import create_mwt, decode_mwt, MWTUser
Configuration:
MIZAN_MWT_SECRET: MWT signing key (separate from MIZAN_CACHE_SECRET)
MIZAN_MWT_TTL: token lifetime in seconds (default: 300)
"""
from __future__ import annotations
import hashlib
import logging
import time
from dataclasses import dataclass
from typing import Any
import jwt
logger = logging.getLogger("mizan.mwt")
@dataclass
class MWTPayload:
"""Decoded MWT claims."""
sub: str # user_id
staff: bool # is_staff
super: bool # is_superuser
pkey: str # permission state hash (full SHA-256 hex)
kid: str # key ID (from JOSE header)
aud: str # audience
iat: int # issued at
exp: int # expiration
class MWTUser:
"""
Minimal user object created from MWT claims.
Used as request.user for MWT-authenticated requests.
No database query required — all data comes from the token.
"""
def __init__(self, payload: MWTPayload):
self.id = int(payload.sub)
self.pk = self.id
self.is_staff = payload.staff
self.is_superuser = payload.super
self.is_authenticated = True
self.is_anonymous = False
self.is_active = True
self.pkey = payload.pkey
def __str__(self) -> str:
return f"MWTUser(id={self.id})"
def __repr__(self) -> str:
return f"MWTUser(id={self.id}, pkey={self.pkey[:8]}...)"
def compute_permission_key(user: Any) -> str:
"""
Compute a deterministic hash of the user's permission state.
Includes is_staff, is_superuser, and all Django permissions.
When the MWT expires and is refreshed, the new pkey reflects
any permission changes. The short TTL controls the staleness window.
Returns the full 64-character SHA-256 hex digest.
"""
perms = sorted(user.get_all_permissions()) if hasattr(user, "get_all_permissions") else []
staff = "1" if getattr(user, "is_staff", False) else "0"
superuser = "1" if getattr(user, "is_superuser", False) else "0"
blob = f"{staff}:{superuser}:{','.join(perms)}"
return hashlib.sha256(blob.encode("utf-8")).hexdigest()
def create_mwt(
user: Any,
secret: str,
ttl: int = 300,
audience: str = "mizan",
kid: str = "v1",
) -> str:
"""
Create an MWT from an authenticated Django user.
Args:
user: Django user object (must have pk, is_staff, is_superuser).
secret: MIZAN_MWT_SECRET signing key.
ttl: Token lifetime in seconds (default: 300 = 5 minutes).
audience: Audience claim for cross-tenant protection.
kid: Key ID placed in JOSE header (per RFC 7515) for rotation.
Returns:
Encoded JWT string.
"""
now = int(time.time())
payload = {
"sub": str(user.pk),
"staff": getattr(user, "is_staff", False),
"super": getattr(user, "is_superuser", False),
"pkey": compute_permission_key(user),
"aud": audience,
"iat": now,
"nbf": now,
"exp": now + ttl,
}
# kid goes in the JOSE header per RFC 7515, not the payload
headers = {"kid": kid}
return jwt.encode(payload, secret, algorithm="HS256", headers=headers)
def decode_mwt(
token: str,
secret: str,
audience: str = "mizan",
) -> MWTPayload | None:
"""
Decode and validate an MWT.
Returns MWTPayload on success, None on any failure (expired, invalid
signature, wrong audience, not-yet-valid, malformed).
"""
try:
# Decode header first to extract kid
unverified_header = jwt.get_unverified_header(token)
kid = unverified_header.get("kid", "v1")
data = jwt.decode(
token,
secret,
algorithms=["HS256"],
audience=audience,
)
except jwt.PyJWTError:
logger.debug("MWT decode failed", exc_info=True)
return None
try:
return MWTPayload(
sub=data["sub"],
staff=data.get("staff", False),
super=data.get("super", False),
pkey=data.get("pkey", ""),
kid=kid,
aud=audience,
iat=data["iat"],
exp=data["exp"],
)
except (KeyError, TypeError):
logger.debug("MWT payload missing required claims", exc_info=True)
return None

View File

@@ -0,0 +1,73 @@
"""
mizan.setup - Integration and registration utilities.
This subpackage contains everything developers need to integrate mizan:
- Registry for server functions and channels
- Auto-discovery for apps
- Configuration settings
Usage:
from mizan.setup import mizan_clients, register, get_function
"""
from .registry import (
register,
register_as,
register_form,
register_compose,
get_function,
get_channel,
get_compose,
get_view,
get_all_functions,
get_all_channels,
get_all_compositions,
get_registry,
get_schema,
get_contexts,
get_context_groups,
get_forms,
validate_registry,
clear_registry,
)
from .discovery import (
mizan_clients,
mizan_module,
)
from .settings import (
mizanSettings,
get_settings,
clear_settings_cache,
)
__all__ = [
# Registration
"register",
"register_as",
"register_form",
"register_compose",
# Lookup
"get_function",
"get_channel",
"get_compose",
"get_view",
"get_all_functions",
"get_all_channels",
"get_all_compositions",
"get_registry",
"get_schema",
"get_contexts",
"get_context_groups",
"get_forms",
"validate_registry",
"clear_registry",
# Discovery
"mizan_clients",
"mizan_module",
# Settings
"mizanSettings",
"get_settings",
"clear_settings_cache",
]

View File

@@ -0,0 +1,93 @@
"""
mizan Auto-Discovery
Scans Django apps for server functions following the 'clients' layer convention:
- <app>/clients.py
- <app>/clients/**/*.py
Usage in urls.py:
from mizan.setup.discovery import mizan_clients
mizan_clients('apps') # Scans apps/*/clients.py
mizan_clients('mizan', 'allauth') # Scans mizan/allauth/**/*.py
This replaces manual "import to register" patterns with explicit auto-discovery.
"""
from typing import Any
from mizan._vendor.app_visitor import DjangoAppVisitor, get_members
from .registry import register, get_function
from mizan.client.function import ServerFunction
class _RegisterServerFunctions:
"""Visitor handler that registers ServerFunction subclasses."""
def on_module(
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
) -> None:
"""Process discovered module members."""
for name, member in members:
# Register ServerFunction subclasses
if (
isinstance(member, type)
and issubclass(member, ServerFunction)
and member is not ServerFunction
and hasattr(member, "__name__")
):
# Use the function name as registration name
fn_name = getattr(member, "name", None) or member.__name__
# Skip already registered (idempotent)
if get_function(fn_name) is member:
continue
try:
register(member, fn_name)
except ValueError:
# Already registered with different class - skip
pass
def mizan_clients(apps_root: str, layer: str = "clients") -> None:
"""
Discover and register server functions from Django apps.
Scans for the specified layer (default: 'clients') in each app:
- <app>/<layer>.py
- <app>/<layer>/**/*.py
Args:
apps_root: Root package containing Django apps (e.g., 'apps')
layer: Module name pattern to scan (default: 'clients')
Example:
# In urls.py
mizan_clients('apps') # Scans apps/*/clients.py
mizan_clients('apps', 'functions') # Scans apps/*/functions.py
"""
visitor = DjangoAppVisitor(layer=layer, apps_root=apps_root)
visitor.visit(_RegisterServerFunctions())
from .registry import validate_registry
validate_registry()
def mizan_module(module_path: str) -> None:
"""
Register server functions from a specific module.
Use this for library modules that don't follow the app convention.
Args:
module_path: Full module path (e.g., 'mizan.integrations.allauth')
Example:
mizan_module('mizan.integrations.allauth')
mizan_module('mizan.jwt.functions')
"""
members = get_members(module_path)
handler = _RegisterServerFunctions()
handler.on_module("", [], members)

View File

@@ -0,0 +1,373 @@
"""
mizan Registry
Central registration for server functions, channels, and compositions.
All items are identified by name.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from mizan.client.function import ServerFunction, ComposedContext
from mizan.channels import ReactChannel
# Global registries - all use name as key
_functions: dict[str, type["ServerFunction"]] = {}
_channels: dict[str, type["ReactChannel"]] = {}
_compositions: dict[str, "ComposedContext"] = {}
def register(
view_class: type["ServerFunction"] | type["ReactChannel"],
name: str,
) -> type["ServerFunction"] | type["ReactChannel"]:
"""
Register a server function or channel.
Args:
view_class: ServerFunction or ReactChannel subclass
name: Registration name (used for API calls and code generation)
Returns:
The view class (allows use as part of decorator chain)
"""
from mizan.client.function import ServerFunction
from mizan.channels import ReactChannel
view_class.name = name
if issubclass(view_class, ReactChannel):
if name in _channels:
# Allow re-registration of the same class (idempotent for reloads)
if _channels[name] is not view_class:
raise ValueError(
f"Channel '{name}' already registered by {_channels[name].__name__}"
)
return view_class
_channels[name] = view_class
elif issubclass(view_class, ServerFunction):
if name in _functions:
# Allow re-registration of the same class (idempotent for reloads)
existing = _functions[name]
if existing.__name__ == view_class.__name__:
# Same function being re-registered (reload scenario)
_functions[name] = view_class
return view_class
raise ValueError(
f"Function '{name}' already registered by {existing.__name__}"
)
_functions[name] = view_class
else:
raise TypeError(f"{view_class} must be a ServerFunction or ReactChannel")
return view_class
def register_as(name: str):
"""
Decorator for registering a server function or channel.
Usage:
@register_as('update-profile')
class UpdateProfile(ServerFunction):
...
"""
def decorator(view_class):
return register(view_class, name)
return decorator
def register_form(
form_class: type,
name: str,
submit_handler: Callable | None = None,
) -> None:
"""
Register a Django Form as server functions.
Creates and registers:
- {name}.schema: Returns form field definitions
- {name}.validate: Validates form data
- {name}.submit: Submits form (if submit_handler provided)
Usage:
register_form(ContactForm, 'contact', submit_handler=handle_contact)
"""
from mizan.client.function import create_form_functions
schema_fn, validate_fn, submit_fn = create_form_functions(
form_class, name, submit_handler
)
register(schema_fn, f"{name}.schema")
register(validate_fn, f"{name}.validate")
if submit_fn:
register(submit_fn, f"{name}.submit")
def register_compose(
composed: "ComposedContext",
name: str,
) -> "ComposedContext":
"""
Register a composed context.
Args:
composed: ComposedContext instance
name: Registration name
Returns:
The composed context
"""
if name in _compositions:
existing = _compositions[name]
if existing.name == composed.name:
# Same composition being re-registered (reload scenario)
_compositions[name] = composed
return composed
raise ValueError(f"Composition '{name}' already registered by {existing.name}")
_compositions[name] = composed
return composed
def get_function(name: str) -> type["ServerFunction"] | None:
"""Get a registered server function by name."""
return _functions.get(name)
def get_channel(name: str) -> type["ReactChannel"] | None:
"""Get a registered channel by name."""
return _channels.get(name)
def get_compose(name: str) -> "ComposedContext | None":
"""Get a registered composition by name."""
return _compositions.get(name)
def get_view(name: str) -> type["ServerFunction"] | type["ReactChannel"] | None:
"""Get any registered view by name (function or channel)."""
return _functions.get(name) or _channels.get(name)
def get_all_functions() -> dict[str, type["ServerFunction"]]:
"""Get all registered functions."""
return _functions.copy()
def get_all_channels() -> dict[str, type["ReactChannel"]]:
"""Get all registered channels."""
return _channels.copy()
def get_all_compositions() -> dict[str, "ComposedContext"]:
"""Get all registered compositions."""
return _compositions.copy()
def get_registry() -> dict[str, dict[str, Any]]:
"""
Get the full registry organized by type.
Returns:
{
"functions": { name: class, ... },
"channels": { name: class, ... },
"compositions": { name: ComposedContext, ... },
}
"""
return {
"functions": _functions.copy(),
"channels": _channels.copy(),
"compositions": _compositions.copy(),
}
def get_schema() -> dict[str, Any]:
"""
Export the full schema for TypeScript generation.
Returns:
{
"functions": {
"update_profile": {
"name": "update_profile",
"type": "function",
"meta": { "context": "global", ... },
"input": { ... },
"output": { ... },
},
...
},
"channels": {
"chat": {
"name": "chat",
"type": "channel",
"params": { ... },
"django_message": { ... },
...
},
...
},
"compositions": {
"user_page": {
"name": "user_page",
"type": "compose",
"meta": { "on_server": false, ... },
"children": ["user_profile", "user_posts"],
"leaves": ["user_profile", "user_posts"],
},
...
},
}
"""
functions = {}
for name, cls in _functions.items():
schema = cls.get_schema_export()
functions[name] = schema
compositions = {}
for name, composed in _compositions.items():
compositions[name] = {
"name": composed.name,
"type": "compose",
"meta": composed._meta,
"children": composed._meta.get("children", []),
"leaves": composed._meta.get("leaves", []),
}
# Build channel schemas from our registry
# Only include keys when they have values (test expects absent keys, not None)
channels_schema = {}
for name, channel_class in _channels.items():
channel_schema: dict[str, Any] = {
"name": name,
"type": "channel",
"bidirectional": False,
}
# Extract Params schema (only if defined)
if hasattr(channel_class, "Params") and channel_class.Params:
channel_schema["params"] = channel_class.Params.model_json_schema()
# Extract ReactMessage schema (only if defined - indicates bidirectional)
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
channel_schema[
"react_message"
] = channel_class.ReactMessage.model_json_schema()
channel_schema["bidirectional"] = True
# Extract DjangoMessage schema (only if defined)
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
channel_schema[
"django_message"
] = channel_class.DjangoMessage.model_json_schema()
channels_schema[name] = channel_schema
return {
"functions": functions,
"channels": channels_schema,
"compositions": compositions,
}
def get_contexts() -> dict[str, type["ServerFunction"]]:
"""
Get all server functions marked as contexts.
These are functions with meta.context = True, used for SSR hydration.
"""
contexts = {}
for name, cls in _functions.items():
if getattr(cls, "_meta", {}).get("context"):
contexts[name] = cls
return contexts
def get_context_groups() -> dict[str, list[str]]:
"""
Group function names by their context string.
Returns:
{"global": ["current_user"], "user": ["user_profile", "user_orders"]}
"""
groups: dict[str, list[str]] = {}
for name, cls in _functions.items():
ctx = getattr(cls, "_meta", {}).get("context")
if ctx:
groups.setdefault(ctx, []).append(name)
return groups
def get_forms() -> dict[str, list[type["ServerFunction"]]]:
"""
Get all server functions that are form-related, grouped by form name.
Returns:
{
"contact": [ContactSchema, ContactValidate, ContactSubmit],
...
}
"""
forms: dict[str, list] = {}
for name, cls in _functions.items():
meta = getattr(cls, "_meta", {})
if meta.get("form"):
form_name = meta.get("form_name")
if form_name not in forms:
forms[form_name] = []
forms[form_name].append(cls)
return forms
def validate_registry() -> list[str]:
"""
Validate that all affects targets resolve to known contexts or functions.
Called automatically after discovery. Emits warnings for unresolved targets
(e.g., typos in string-based affects declarations).
Returns a list of warning messages (empty if everything resolves).
"""
import warnings
issues: list[str] = []
groups = get_context_groups()
all_fn_names = set(_functions.keys())
for fn_name, fn_cls in _functions.items():
meta = getattr(fn_cls, "_meta", {})
affects = meta.get("affects")
if not affects:
continue
for target in affects:
target_name = target.get("name", "")
target_type = target.get("type", "")
if target_type == "context" and target_name not in groups:
issues.append(
f"@client function '{fn_name}' declares affects='{target_name}', "
f"but no context named '{target_name}' is registered."
)
elif target_type == "function" and target_name not in all_fn_names:
issues.append(
f"@client function '{fn_name}' targets function '{target_name}', "
f"but no function named '{target_name}' is registered."
)
for msg in issues:
warnings.warn(msg, stacklevel=2)
return issues
def clear_registry() -> None:
"""Clear all registrations. Primarily for testing."""
_functions.clear()
_channels.clear()
_compositions.clear()

View File

@@ -0,0 +1,54 @@
"""
mizan Settings
Configuration is read from Django settings with sensible defaults.
"""
from dataclasses import dataclass
from functools import lru_cache
from django.conf import settings as django_settings
@dataclass
class mizanSettings:
"""mizan configuration."""
# Whether to expose function names in DEBUG mode errors
debug_expose_names: bool
# Cache HMAC signing secret (required when cache is enabled)
cache_secret: str | None
# Redis URL for cache backend (None = cache disabled)
cache_redis_url: str | None
# MWT signing secret (separate from cache secret for blast radius containment)
mwt_secret: str | None
# MWT token lifetime in seconds (default: 300 = 5 minutes)
mwt_ttl: int
@lru_cache
def get_settings() -> mizanSettings:
"""
Load mizan settings from Django settings.
Settings:
mizan_DEBUG_EXPOSE_NAMES: Show function names in errors when DEBUG=True (default: True)
MIZAN_CACHE_SECRET: HMAC signing key for cache keys (default: None)
MIZAN_CACHE_REDIS_URL: Redis connection URL (default: None)
"""
return mizanSettings(
debug_expose_names=getattr(django_settings, "mizan_DEBUG_EXPOSE_NAMES", True),
cache_secret=getattr(django_settings, "MIZAN_CACHE_SECRET", None),
cache_redis_url=getattr(django_settings, "MIZAN_CACHE_REDIS_URL", None),
mwt_secret=getattr(django_settings, "MIZAN_MWT_SECRET", None),
mwt_ttl=getattr(django_settings, "MIZAN_MWT_TTL", 300),
)
def clear_settings_cache():
"""Clear the settings cache (for testing)."""
get_settings.cache_clear()

View File

@@ -0,0 +1,3 @@
from mizan.shapes.core import Diff, NestedDiff, Shape
__all__ = ["Diff", "NestedDiff", "Shape"]

View File

@@ -0,0 +1,265 @@
from __future__ import annotations
import types
from typing import Any, ClassVar, Generic, TypeVar, Union, get_type_hints
from pydantic import BaseModel
from django_readers import pairs, specs
from django_readers import qs as readers_qs
_M = TypeVar("_M")
_S = TypeVar("_S", bound="Shape")
def _extract_shape_class(hint) -> type[Shape] | None:
origin = getattr(hint, "__origin__", None)
args = getattr(hint, "__args__", ())
# list[SomeShape]
if (
origin is list
and args
and isinstance(args[0], type)
and issubclass(args[0], Shape)
):
return args[0]
# SomeShape (bare)
if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape:
return hint
# SomeShape | None (Union/Optional)
if origin is Union or isinstance(hint, types.UnionType):
for arg in args:
if arg is type(None):
continue
if isinstance(arg, type) and issubclass(arg, Shape) and arg is not Shape:
return arg
return None
def _resolve_model(cls) -> Any | None:
for base in cls.__bases__:
meta = getattr(base, "__pydantic_generic_metadata__", None) or {}
if meta.get("origin") is Shape and (args := meta.get("args")):
return args[0]
return None
class Shape(BaseModel, Generic[_M]):
_model: ClassVar[Any]
_nested: ClassVar[dict[str, type[Shape]]]
_field_names: ClassVar[list[str]]
_pk_field: ClassVar[str]
_spec: ClassVar[list]
_pair: ClassVar[tuple]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not (model := _resolve_model(cls)):
return
cls._model = model
cls._nested = {}
cls._pk_field = model._meta.pk.name if model._meta.pk else "id"
hints = get_type_hints(cls, include_extras=False, localns={cls.__name__: cls}) or cls.__annotations__
field_names = []
for name, hint in hints.items():
if name.startswith("_"):
continue
if shape_cls := _extract_shape_class(hint):
cls._nested[name] = shape_cls
else:
field_names.append(name)
cls._field_names = field_names
# Set field-only spec first so self-references can find it
cls._spec = [*field_names]
# Now rebuild with nested — self-refs resolve because cls._spec exists
cls._spec = [
*field_names,
*({name: shape._spec} for name, shape in cls._nested.items()),
]
cls._pair = specs.process(cls._spec)
@classmethod
def _build_pair(cls, relation_qs: dict[str, Any]):
field_pairs = [
pairs.producer_to_projector(name, pairs.field(name))
for name in cls._field_names
]
rel_pairs = []
for name, shape_cls in cls._nested.items():
child_prepare, child_project = shape_cls._pair
prepare = (
readers_qs.pipe(relation_qs[name], child_prepare)
if name in relation_qs
else child_prepare
)
rel_pairs.append(
pairs.producer_to_projector(
name, pairs.relationship(name, (prepare, child_project))
)
)
return pairs.combine(*field_pairs, *rel_pairs)
@classmethod
def _get_pk(cls, instance) -> Any | None:
return getattr(instance, cls._pk_field, None)
@classmethod
def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair
base = cls._model.objects.all()
# Accept a raw QuerySet as the first arg, or qs functions, or nothing
if qs_fns and hasattr(qs_fns[0], "query"):
base, qs_fns = qs_fns[0], qs_fns[1:]
queryset = readers_qs.pipe(prepare, *qs_fns)(base)
return [cls.model_validate(project(obj)) for obj in queryset]
@classmethod
def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]:
pk_field = cls._pk_field
pk_map: dict[Any, _S] = {}
new_items: list[_S] = []
for item in items:
pk = cls._get_pk(item)
if pk is not None:
pk_map[pk] = item
else:
new_items.append(item)
# Single query for all existing items
current_map: dict[Any, _S] = {}
if pk_map:
current_items = cls.query(
cls._model.objects.filter(**{f"{pk_field}__in": pk_map.keys()})
)
current_map = {cls._get_pk(c): c for c in current_items}
results: list[tuple[_S, Diff]] = []
for item in new_items:
results.append((item, cls._diff_one(item, None)))
for pk, item in pk_map.items():
current = current_map.get(pk)
if current is None:
raise cls._model.DoesNotExist(
f"{cls._model.__name__} with {pk_field}={pk} does not exist"
)
results.append((item, cls._diff_one(item, current)))
return results
@classmethod
def _diff_one(cls, incoming: _S, current: _S | None) -> Diff:
pk_field = cls._pk_field
changed = (
{
k: getattr(incoming, k)
for k in cls._field_names
if k != pk_field and getattr(incoming, k) != getattr(current, k)
}
if current
else {k: getattr(incoming, k) for k in cls._field_names if k != pk_field}
)
nested = {}
for name, shape_cls in cls._nested.items():
incoming_items = getattr(incoming, name, None) or []
current_items = getattr(current, name, None) or [] if current else []
if not isinstance(incoming_items, list):
incoming_items = [incoming_items]
if not isinstance(current_items, list):
current_items = [current_items]
child_pk = shape_cls._pk_field
current_by_pk = {
shape_cls._get_pk(c): c
for c in current_items
if shape_cls._get_pk(c) is not None
}
incoming_by_pk = {
shape_cls._get_pk(c): c
for c in incoming_items
if shape_cls._get_pk(c) is not None
}
nested[name] = NestedDiff(
created=[c for c in incoming_items if shape_cls._get_pk(c) is None],
updated=[
c
for pk, c in incoming_by_pk.items()
if pk in current_by_pk and c != current_by_pk[pk]
],
deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk],
)
return Diff(is_new=current is None, changed=changed, _nested=nested)
def diff(self) -> Diff:
cls = type(self)
pk = cls._get_pk(self)
if pk is not None:
results = cls.query(cls._model.objects.filter(pk=pk))
if not results:
raise cls._model.DoesNotExist(
f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist"
)
current = results[0]
else:
current = None
return cls._diff_one(self, current)
class NestedDiff:
__slots__ = ("created", "updated", "deleted")
def __init__(self, created=(), updated=(), deleted=()):
self.created = list(created)
self.updated = list(updated)
self.deleted = list(deleted)
class Diff:
__slots__ = ("is_new", "changed", "_nested")
def __init__(
self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]
):
self.is_new = is_new
self.changed = changed
self._nested = _nested
def nested(self, name: str) -> NestedDiff:
"""Strict access to nested diffs. Raises KeyError for invalid names."""
if name not in self._nested:
valid = ", ".join(sorted(self._nested)) or "(none)"
raise KeyError(f"No nested diff for '{name}'. Valid nested shapes: {valid}")
return self._nested[name]
def __getattr__(self, name: str) -> NestedDiff:
if name.startswith("_"):
raise AttributeError(name)
if name not in self._nested:
valid = ", ".join(sorted(self._nested)) or "(none)"
raise AttributeError(
f"No nested diff for '{name}'. Valid nested shapes: {valid}"
)
return self._nested[name]

View File

@@ -0,0 +1,25 @@
"""
mizan.ssr — Server-side rendering via Bun subprocess.
Mizan's SSR is a Django template backend. Configure it in TEMPLATES:
TEMPLATES = [
{
'BACKEND': 'mizan.ssr.MizanTemplates',
'OPTIONS': {
'worker_path': 'frontend/ssr-worker.tsx',
'timeout': 5,
},
},
]
Then use Django's standard render():
return render(request, 'ProfilePage', {'user_id': 5})
The component name is the template name. The context dict becomes props.
"""
from .backend import MizanTemplates
__all__ = ["MizanTemplates"]

View File

@@ -0,0 +1,100 @@
"""
Mizan SSR Template Backend — Django template engine that renders React via Bun.
TEMPLATES = [
{
'BACKEND': 'mizan.ssr.MizanTemplates',
'DIRS': [BASE_DIR / 'frontend'],
'OPTIONS': {
'worker': 'path/to/mizan-ssr/src/worker.tsx',
},
},
]
Then: render(request, 'components/Hello.tsx', {'name': 'World'})
"""
from __future__ import annotations
import os
from typing import Any
from django.template import TemplateDoesNotExist
from django.template.backends.base import BaseEngine
from django.utils.safestring import mark_safe
from .bridge import SSRBridge
class MizanTemplate:
"""Renders a .tsx/.jsx file via the SSR bridge."""
def __init__(self, file_path: str, bridge: SSRBridge) -> None:
self.file_path = file_path
self.origin = None
self._bridge = bridge
def render(self, context: dict[str, Any] | None = None, request: Any = None) -> str:
import json as _json
props = dict(context) if context else {}
props.pop("request", None)
props.pop("csrf_token", None)
result = self._bridge.render(self.file_path, props)
# Serialize props as hydration data for client-side React
hydration_json = _json.dumps(props, sort_keys=True, default=str)
return mark_safe(
f'<div id="mizan-root">{result.html}</div>'
f'<script>window.__MIZAN_SSR_DATA__={hydration_json}</script>'
)
class MizanTemplates(BaseEngine):
"""
Django template backend that renders React components via Bun.
Template names are file paths resolved against DIRS.
Same model as Django's built-in template engines.
"""
def __init__(self, params: dict[str, Any]) -> None:
options = params.pop("OPTIONS", {})
params.setdefault("NAME", "mizan")
params.setdefault("APP_DIRS", False)
super().__init__(params)
self._worker = options.get("worker")
self._timeout = options.get("timeout", 5)
self._bridge: SSRBridge | None = None
if not self._worker:
raise ValueError(
"MizanTemplates requires OPTIONS['worker'] — "
"the path to mizan-ssr's worker.tsx"
)
def get_bridge(self) -> SSRBridge:
if self._bridge is None:
self._bridge = SSRBridge(
worker_path=self._worker,
timeout=self._timeout,
)
return self._bridge
def get_template(self, template_name: str) -> MizanTemplate:
for dir_path in self.dirs:
file_path = os.path.join(dir_path, template_name)
if os.path.isfile(file_path):
return MizanTemplate(
os.path.abspath(file_path),
self.get_bridge(),
)
raise TemplateDoesNotExist(template_name)
def from_string(self, template_code: str) -> MizanTemplate:
raise TemplateDoesNotExist(
"MizanTemplates renders .tsx files, not template strings."
)

View File

@@ -0,0 +1,181 @@
"""
SSR Bridge — Manages a persistent Bun subprocess for React rendering.
Protocol: newline-delimited JSON-RPC over stdin/stdout.
Request: {"id": 1, "method": "render", "params": {"file": "/abs/path/Hello.tsx", "props": {...}}}
Response: {"id": 1, "html": "<div>...</div>"}
The subprocess stays alive across requests. It is started on first use
and restarted automatically if it crashes.
"""
from __future__ import annotations
import atexit
import json
import logging
import subprocess
import threading
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger("mizan.ssr")
@dataclass
class RenderResult:
"""Result of an SSR render call."""
html: str
class SSRBridge:
"""
Manages a persistent Bun subprocess for server-side rendering.
Thread-safe. Multiple Django workers can call render() concurrently.
Request-response matching via message IDs.
"""
def __init__(self, worker_path: str, timeout: float = 5.0) -> None:
self._worker_path = worker_path
self._timeout = timeout
self._proc: subprocess.Popen | None = None
self._lock = threading.Lock()
self._write_lock = threading.Lock() # Serializes stdin writes
self._counter = 0
self._pending: dict[int, threading.Event] = {}
self._results: dict[int, dict] = {}
self._reader_thread: threading.Thread | None = None
self._ready = threading.Event()
# Ensure cleanup on process exit
atexit.register(self.shutdown)
def _ensure_running(self) -> None:
"""Start the Bun subprocess if it's not running."""
if self._proc is not None and self._proc.poll() is None:
return
if self._proc is not None:
logger.warning("Bun SSR worker died (exit code %s), restarting", self._proc.returncode)
self._ready.clear()
self._proc = subprocess.Popen(
["bun", "run", self._worker_path],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
self._reader_thread = threading.Thread(
target=self._read_responses, daemon=True, name="mizan-ssr-reader",
)
self._reader_thread.start()
# Wait for the "ready" signal from the worker
if not self._ready.wait(timeout=self._timeout):
logger.error("Bun SSR worker failed to start within %ss", self._timeout)
self.shutdown()
raise TimeoutError("SSR worker failed to start")
logger.info("Bun SSR worker started (pid %s)", self._proc.pid)
def _read_responses(self) -> None:
"""Background thread that reads JSON responses from stdout."""
try:
for line in self._proc.stdout:
if isinstance(line, bytes):
line = line.decode("utf-8")
line = line.strip()
if not line:
continue
try:
msg = json.loads(line)
except json.JSONDecodeError:
logger.warning("Malformed JSON from SSR worker: %s", line[:200])
continue
msg_id = msg.get("id")
# Ready signal (id=0)
if msg_id == 0 and msg.get("ready"):
self._ready.set()
continue
if msg_id is not None and msg_id in self._pending:
self._results[msg_id] = msg
self._pending[msg_id].set()
except Exception:
logger.warning("SSR reader thread exited", exc_info=True)
def render(self, file: str, props: dict[str, Any] | None = None) -> RenderResult:
"""
Render a React component to HTML.
Args:
file: Absolute path to the .tsx/.jsx file to render.
props: Props to pass to the component.
Returns:
RenderResult with the HTML string.
Raises:
TimeoutError: If the render takes longer than the configured timeout.
RuntimeError: If the render fails.
"""
with self._lock:
self._ensure_running()
self._counter += 1
msg_id = self._counter
event = threading.Event()
self._pending[msg_id] = event
request = json.dumps({
"id": msg_id,
"method": "render",
"params": {"file": file, "props": props or {}},
}) + "\n"
# Serialize stdin writes to prevent interleaving from concurrent threads
with self._write_lock:
try:
self._proc.stdin.write(request.encode("utf-8"))
self._proc.stdin.flush()
except (BrokenPipeError, OSError) as e:
self._pending.pop(msg_id, None)
raise RuntimeError(f"SSR worker pipe broken: {e}")
if not event.wait(self._timeout):
self._pending.pop(msg_id, None)
raise TimeoutError(
f"SSR render of '{file}' timed out after {self._timeout}s"
)
self._pending.pop(msg_id, None)
result = self._results.pop(msg_id)
if "error" in result:
raise RuntimeError(f"SSR render failed: {result['error']}")
return RenderResult(html=result["html"])
def shutdown(self) -> None:
"""Stop the Bun subprocess."""
if self._proc is not None:
try:
self._proc.stdin.close()
except Exception:
pass
try:
self._proc.terminate()
self._proc.wait(timeout=3)
except Exception:
try:
self._proc.kill()
except Exception:
pass
self._proc = None
logger.info("Bun SSR worker stopped")

View File

@@ -0,0 +1,554 @@
"""
Authentication Tests for mizan Server Functions
Tests all combinations of:
- Transport: HTTP vs WebSocket RPC
- JWT: Present (valid), Present (invalid), Absent
- Session: Present (valid), Absent
Expected behavior:
- JWT present (valid) → JWTUser (no DB query)
- JWT present (invalid) → Reject (401), do NOT fall back to session
- JWT absent + Session present → Session auth (DB query)
- JWT absent + Session absent → AnonymousUser
"""
from django.test import TestCase, RequestFactory, override_settings
from django.contrib.auth import get_user_model
from django.contrib.sessions.backends.db import SessionStore
from unittest.mock import patch, MagicMock
import json
from mizan.jwt.tokens import (
create_token_pair,
decode_token,
JWTUser,
)
from mizan.client.executor import (
_try_jwt_auth,
execute_function,
FunctionError,
FunctionResult,
ErrorCode,
)
from mizan.client import client
from mizan.setup.registry import clear_registry, register
from pydantic import BaseModel
User = get_user_model()
# =============================================================================
# Test Output Models (proper Pydantic models, not raw dicts)
# =============================================================================
class WhoamiOutput(BaseModel):
is_authenticated: bool
user_id: int | None
user_type: str
is_staff: bool
class OkOutput(BaseModel):
ok: bool
class UserTypeOutput(BaseModel):
user_type: str
# =============================================================================
# Test Server Functions - defined as plain functions, registered in setUp
# =============================================================================
def _whoami_fn(request) -> WhoamiOutput:
"""Returns info about the authenticated user."""
user = request.user
return WhoamiOutput(
is_authenticated=user.is_authenticated,
user_id=getattr(user, "id", None),
user_type=type(user).__name__,
is_staff=getattr(user, "is_staff", False),
)
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class HTTPAuthTests(TestCase):
"""Test HTTP transport authentication combinations."""
def setUp(self):
clear_registry()
self.factory = RequestFactory()
self.user = User.objects.create_user(
email="test@example.com",
password="testpass123",
is_staff=True,
is_superuser=False,
)
# Create a session
self.session = SessionStore()
self.session.create()
self.session_key = self.session.session_key
# Register test function
@client
def whoami(request) -> WhoamiOutput:
user = request.user
return WhoamiOutput(
is_authenticated=user.is_authenticated,
user_id=getattr(user, "id", None),
user_type=type(user).__name__,
is_staff=getattr(user, "is_staff", False),
)
register(whoami, "whoami")
def tearDown(self):
self.user.delete()
self.session.delete()
clear_registry()
def test_jwt_valid_no_session(self):
"""Valid JWT without session → JWTUser (no DB query)."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = MagicMock(is_authenticated=False) # No session auth
# Try JWT auth
result = _try_jwt_auth(request)
self.assertTrue(result)
self.assertIsInstance(request.user, JWTUser)
self.assertEqual(request.user.id, self.user.pk)
self.assertTrue(request.user.is_staff)
self.assertTrue(request.user.is_authenticated)
def test_jwt_valid_with_session(self):
"""Valid JWT with session → JWT takes precedence (no DB query)."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = self.user # Session auth already set user
# JWT should still be processed and take precedence
result = _try_jwt_auth(request)
self.assertTrue(result)
self.assertIsInstance(request.user, JWTUser)
def test_jwt_invalid_with_session(self):
"""Invalid JWT with valid session → Reject (do NOT fall back)."""
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = "Bearer invalid-token-here"
request.user = self.user # Session would work
# JWT auth should fail
result = _try_jwt_auth(request)
self.assertFalse(result)
# User should NOT be changed to session user - that happens elsewhere
# The point is _try_jwt_auth returns False, indicating JWT failed
def test_jwt_expired_with_session(self):
"""Expired JWT with valid session → Reject (do NOT fall back)."""
# Create token with past expiration by mocking time
with patch("mizan.jwt.tokens.time.time", return_value=0):
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = self.user # Session would work
# JWT auth should fail (expired)
result = _try_jwt_auth(request)
self.assertFalse(result)
def test_no_jwt_with_session(self):
"""No JWT with valid session → Session auth (normal Django flow)."""
request = self.factory.post("/")
request.user = self.user # Session auth set user
# No JWT auth attempted
result = _try_jwt_auth(request)
self.assertFalse(result) # No JWT to process
# User remains the session user
self.assertEqual(request.user, self.user)
def test_no_jwt_no_session(self):
"""No JWT, no session → AnonymousUser."""
from django.contrib.auth.models import AnonymousUser
request = self.factory.post("/")
request.user = AnonymousUser()
result = _try_jwt_auth(request)
self.assertFalse(result)
self.assertIsInstance(request.user, AnonymousUser)
def test_execute_function_with_jwt(self):
"""Execute server function with JWT auth."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
# Simulate what the view does: try JWT auth first
_try_jwt_auth(request)
# Use the whoami function which returns WhoamiOutput (Pydantic model)
result = execute_function(request, "whoami", {})
self.assertIsInstance(result, FunctionResult)
self.assertTrue(result.data["is_authenticated"])
self.assertEqual(result.data["user_type"], "JWTUser")
self.assertTrue(result.data["is_staff"])
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class JWTUserTests(TestCase):
"""Test JWTUser behavior."""
def setUp(self):
clear_registry()
def tearDown(self):
clear_registry()
def test_jwt_user_attributes(self):
"""JWTUser has expected attributes."""
from mizan.jwt.tokens import TokenPayload
payload = TokenPayload(
user_id=42,
session_key="test-session",
token_type="access",
is_staff=True,
is_superuser=False,
exp=9999999999,
iat=0,
)
user = JWTUser(payload)
self.assertEqual(user.id, 42)
self.assertEqual(user.pk, 42)
self.assertTrue(user.is_staff)
self.assertFalse(user.is_superuser)
self.assertTrue(user.is_authenticated)
self.assertFalse(user.is_anonymous)
self.assertTrue(user.is_active)
def test_jwt_user_string_id(self):
"""JWTUser handles string user_id (converted to int)."""
from mizan.jwt.tokens import TokenPayload
payload = TokenPayload(
user_id="42", # String, as stored in JWT
session_key="test-session",
token_type="access",
is_staff=False,
is_superuser=False,
exp=9999999999,
iat=0,
)
user = JWTUser(payload)
self.assertEqual(user.id, 42)
self.assertIsInstance(user.id, int)
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class AuthDecoratorTests(TestCase):
"""Test @client(auth=...) decorator."""
def setUp(self):
clear_registry()
self.factory = RequestFactory()
self.user = User.objects.create_user(
email="test@example.com",
password="testpass123",
is_staff=False,
is_superuser=False,
)
self.staff_user = User.objects.create_user(
email="staff@example.com",
password="testpass123",
is_staff=True,
is_superuser=False,
)
self.superuser = User.objects.create_user(
email="super@example.com",
password="testpass123",
is_staff=True,
is_superuser=True,
)
def tearDown(self):
self.user.delete()
self.staff_user.delete()
self.superuser.delete()
clear_registry()
def test_auth_required_with_anonymous(self):
"""@client(auth=True) rejects anonymous users."""
from django.contrib.auth.models import AnonymousUser
# Register a test function with proper Pydantic model
@client(auth=True)
def protected_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(protected_fn, "protected_fn")
request = self.factory.post("/")
request.user = AnonymousUser()
result = execute_function(request, "protected_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
def test_auth_required_with_authenticated(self):
"""@client(auth=True) allows authenticated users."""
@client(auth=True)
def protected_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(protected_fn2, "protected_fn2")
request = self.factory.post("/")
request.user = self.user
result = execute_function(request, "protected_fn2", {})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["ok"], True)
def test_auth_staff_with_regular_user(self):
"""@client(auth='staff') rejects non-staff users."""
@client(auth="staff")
def staff_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(staff_fn, "staff_fn")
request = self.factory.post("/")
request.user = self.user # Not staff
result = execute_function(request, "staff_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
def test_auth_staff_with_staff_user(self):
"""@client(auth='staff') allows staff users."""
@client(auth="staff")
def staff_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(staff_fn2, "staff_fn2")
request = self.factory.post("/")
request.user = self.staff_user
result = execute_function(request, "staff_fn2", {})
self.assertIsInstance(result, FunctionResult)
def test_auth_superuser_with_staff(self):
"""@client(auth='superuser') rejects non-superusers."""
@client(auth="superuser")
def super_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(super_fn, "super_fn")
request = self.factory.post("/")
request.user = self.staff_user # Staff but not superuser
result = execute_function(request, "super_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
def test_auth_superuser_with_superuser(self):
"""@client(auth='superuser') allows superusers."""
@client(auth="superuser")
def super_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(super_fn2, "super_fn2")
request = self.factory.post("/")
request.user = self.superuser
result = execute_function(request, "super_fn2", {})
self.assertIsInstance(result, FunctionResult)
def test_auth_with_jwt_user(self):
"""Auth checks work with JWTUser (stateless)."""
from mizan.jwt.tokens import TokenPayload
@client(auth="staff")
def jwt_staff_fn(request) -> UserTypeOutput:
return UserTypeOutput(user_type=type(request.user).__name__)
register(jwt_staff_fn, "jwt_staff_fn")
# Create JWTUser with is_staff=True
payload = TokenPayload(
user_id=99,
session_key="test",
token_type="access",
is_staff=True,
is_superuser=False,
exp=9999999999,
iat=0,
)
jwt_user = JWTUser(payload)
request = self.factory.post("/")
request.user = jwt_user
result = execute_function(request, "jwt_staff_fn", {})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["user_type"], "JWTUser")
def test_auth_invalid_string_raises(self):
"""Invalid auth string raises ValueError at decoration time."""
with self.assertRaises(ValueError) as ctx:
@client(auth="admin") # 'admin' is not valid
def bad_fn(request) -> OkOutput:
return OkOutput(ok=True)
self.assertIn("Invalid auth value 'admin'", str(ctx.exception))
self.assertIn("required", str(ctx.exception))
def test_auth_callable_returns_true(self):
"""Callable auth returning True allows access."""
@client(auth=lambda r: r.user.email.endswith("@example.com"))
def email_check_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(email_check_fn, "email_check_fn")
request = self.factory.post("/")
request.user = self.user # email is test@example.com
result = execute_function(request, "email_check_fn", {})
self.assertIsInstance(result, FunctionResult)
self.assertTrue(result.data["ok"])
def test_auth_callable_returns_false(self):
"""Callable auth returning False denies access."""
@client(auth=lambda r: r.user.email.endswith("@admin.com"))
def admin_email_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(admin_email_fn, "admin_email_fn")
request = self.factory.post("/")
request.user = self.user # email is test@example.com, not @admin.com
result = execute_function(request, "admin_email_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Access denied")
def test_auth_callable_raises_permission_error(self):
"""Callable auth raising PermissionError uses custom message."""
def check_premium(request):
if not getattr(request.user, "is_premium", False):
raise PermissionError("Premium subscription required")
return True
@client(auth=check_premium)
def premium_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(premium_fn, "premium_fn")
request = self.factory.post("/")
request.user = self.user # No is_premium attribute
result = execute_function(request, "premium_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Premium subscription required")
def test_auth_callable_with_anonymous_user(self):
"""Callable auth can check for anonymous users."""
from django.contrib.auth.models import AnonymousUser
def must_be_authenticated(request):
if not request.user.is_authenticated:
raise PermissionError("Please log in")
return True
@client(auth=must_be_authenticated)
def needs_login_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(needs_login_fn, "needs_login_fn")
request = self.factory.post("/")
request.user = AnonymousUser()
result = execute_function(request, "needs_login_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Please log in")

View File

@@ -0,0 +1,567 @@
"""
Protocol Benchmark: HTTP vs WebSocket Server Functions
Compares performance of HTTP POST vs WebSocket RPC for server function calls.
Includes realistic scenarios with ORM queries.
Usage:
python manage.py test mizan.tests.test_benchmarks --verbosity=2
Note:
These are not unit tests - they measure performance. Results are printed
to stdout and should be run in isolation for accurate measurements.
"""
import asyncio
import json
import statistics
import time
from typing import Any
from unittest.mock import MagicMock, AsyncMock
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.db import connection
from django.http import HttpRequest
from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings
from pydantic import BaseModel
from mizan.client.executor import FunctionResult, execute_function, function_call_view
from mizan.setup.registry import clear_registry
from mizan.client import client
User = get_user_model()
# =============================================================================
# Benchmark Output Models
# =============================================================================
class SimpleOutput(BaseModel):
value: int
class UserOutput(BaseModel):
id: int
email: str
class UserListOutput(BaseModel):
users: list[dict[str, Any]]
count: int
class StatsOutput(BaseModel):
total_users: int
active_users: int
staff_count: int
# =============================================================================
# Benchmark Functions
# =============================================================================
def setup_benchmark_functions():
"""Register benchmark server functions."""
from mizan.setup.registry import register
clear_registry()
# 1. Simple computation (no I/O)
@client
def bench_simple(request: HttpRequest, a: int, b: int) -> SimpleOutput:
"""Simple addition - baseline with no I/O."""
return SimpleOutput(value=a + b)
register(bench_simple, "bench_simple")
# 2. Single ORM query
@client
def bench_get_user(request: HttpRequest, user_id: int) -> UserOutput:
"""Fetch single user by ID."""
user = User.objects.filter(id=user_id).first()
if user:
return UserOutput(id=user.id, email=user.email)
return UserOutput(id=0, email="")
register(bench_get_user, "bench_get_user")
# 3. List query with limit
@client
def bench_list_users(request: HttpRequest, limit: int) -> UserListOutput:
"""Fetch list of users with limit."""
users = User.objects.all()[:limit]
return UserListOutput(
users=[{"id": u.id, "email": u.email} for u in users],
count=len(users),
)
register(bench_list_users, "bench_list_users")
# 4. Aggregation query
@client
def bench_user_stats(request: HttpRequest) -> StatsOutput:
"""Compute user statistics with multiple queries."""
total = User.objects.count()
active = User.objects.filter(is_active=True).count()
staff = User.objects.filter(is_staff=True).count()
return StatsOutput(
total_users=total,
active_users=active,
staff_count=staff,
)
register(bench_user_stats, "bench_user_stats")
# 5. Complex query with joins
@client
def bench_user_search(
request: HttpRequest, email_contains: str, limit: int
) -> UserListOutput:
"""Search users by email pattern."""
users = User.objects.filter(
email__icontains=email_contains,
is_active=True,
).select_related()[:limit]
return UserListOutput(
users=[{"id": u.id, "email": u.email} for u in users],
count=len(users),
)
register(bench_user_search, "bench_user_search")
# =============================================================================
# Benchmark Test Cases
# =============================================================================
class ProtocolBenchmark(TransactionTestCase):
"""
Benchmark comparing HTTP vs WebSocket (simulated) performance.
Uses TransactionTestCase to ensure database state is realistic.
"""
# Number of iterations for each benchmark
ITERATIONS = 100
WARMUP = 10
@classmethod
def setUpClass(cls):
super().setUpClass()
setup_benchmark_functions()
def setUp(self):
self.factory = RequestFactory()
# Create test users for ORM benchmarks
self._create_test_users()
def _create_test_users(self):
"""Create test users for benchmarks."""
# Create 100 test users
users = []
for i in range(100):
users.append(
User(
email=f"bench{i}@example.com",
is_active=i % 10 != 0, # 90% active
is_staff=i < 5, # 5 staff
)
)
User.objects.bulk_create(users, ignore_conflicts=True)
self.test_user = User.objects.first()
def _make_request(self, body: dict | None = None) -> HttpRequest:
"""Create a request with optional JSON body."""
if body:
request = self.factory.post(
"/api/mizan/call/",
data=json.dumps(body),
content_type="application/json",
)
else:
request = self.factory.post("/api/mizan/call/")
request.user = AnonymousUser()
request._dont_enforce_csrf_checks = True
return request
def _benchmark_executor(self, fn_name: str, args: dict, label: str) -> dict:
"""
Benchmark direct executor calls (simulates WebSocket RPC).
Returns timing statistics.
"""
request = self._make_request()
times = []
# Warmup
for _ in range(self.WARMUP):
execute_function(request, fn_name, args)
# Benchmark
for _ in range(self.ITERATIONS):
start = time.perf_counter()
result = execute_function(request, fn_name, args)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
return self._compute_stats(times, f"Executor ({label})")
def _benchmark_http(self, fn_name: str, args: dict, label: str) -> dict:
"""
Benchmark HTTP view calls.
Returns timing statistics.
"""
times = []
# Warmup
for _ in range(self.WARMUP):
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
# Benchmark
for _ in range(self.ITERATIONS):
request = self._make_request({"fn": fn_name, "args": args})
start = time.perf_counter()
response = function_call_view(request)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
return self._compute_stats(times, f"HTTP ({label})")
def _compute_stats(self, times: list[float], label: str) -> dict:
"""Compute statistics from timing data."""
return {
"label": label,
"min": min(times),
"max": max(times),
"mean": statistics.mean(times),
"median": statistics.median(times),
"stdev": statistics.stdev(times) if len(times) > 1 else 0,
"p95": sorted(times)[int(len(times) * 0.95)],
"p99": sorted(times)[int(len(times) * 0.99)],
"iterations": len(times),
}
def _print_results(self, results: list[dict]):
"""Print benchmark results in a table."""
print("\n" + "=" * 80)
print(f"{'Benchmark':<40} {'Mean':>8} {'Median':>8} {'P95':>8} {'P99':>8}")
print("=" * 80)
for r in results:
print(
f"{r['label']:<40} {r['mean']:>7.3f}ms {r['median']:>7.3f}ms {r['p95']:>7.3f}ms {r['p99']:>7.3f}ms"
)
print("=" * 80)
def _print_comparison(self, executor_stats: dict, http_stats: dict):
"""Print comparison between executor and HTTP."""
overhead = (
(http_stats["mean"] - executor_stats["mean"]) / executor_stats["mean"]
) * 100
print(f" HTTP overhead vs Executor: {overhead:+.1f}%")
# -------------------------------------------------------------------------
# Benchmark Tests
# -------------------------------------------------------------------------
def test_benchmark_simple_computation(self):
"""Benchmark: Simple computation (no I/O)."""
print("\n\n### BENCHMARK: Simple Computation (no I/O) ###")
args = {"a": 100, "b": 200}
exec_stats = self._benchmark_executor("bench_simple", args, "simple")
http_stats = self._benchmark_http("bench_simple", args, "simple")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: 100 + 200 = 300
request = self._make_request()
result = execute_function(request, "bench_simple", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 300)
def test_benchmark_single_query(self):
"""Benchmark: Single ORM query."""
print("\n\n### BENCHMARK: Single ORM Query ###")
args = {"user_id": self.test_user.id if self.test_user else 1}
exec_stats = self._benchmark_executor("bench_get_user", args, "single query")
http_stats = self._benchmark_http("bench_get_user", args, "single query")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: should return the test user's data
request = self._make_request()
result = execute_function(request, "bench_get_user", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["id"], self.test_user.id)
def test_benchmark_list_query(self):
"""Benchmark: List query with serialization."""
print("\n\n### BENCHMARK: List Query (10 users) ###")
args = {"limit": 10}
exec_stats = self._benchmark_executor("bench_list_users", args, "list 10")
http_stats = self._benchmark_http("bench_list_users", args, "list 10")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: should return up to 10 users
request = self._make_request()
result = execute_function(request, "bench_list_users", args)
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 10)
self.assertEqual(len(result.data["users"]), result.data["count"])
def test_benchmark_aggregation(self):
"""Benchmark: Aggregation queries."""
print("\n\n### BENCHMARK: Aggregation (3 COUNT queries) ###")
args = {}
exec_stats = self._benchmark_executor("bench_user_stats", args, "aggregation")
http_stats = self._benchmark_http("bench_user_stats", args, "aggregation")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: stats should have non-negative counts
request = self._make_request()
result = execute_function(request, "bench_user_stats", args)
self.assertIsInstance(result, FunctionResult)
self.assertGreaterEqual(result.data["total_users"], 0)
self.assertGreaterEqual(result.data["active_users"], 0)
self.assertGreaterEqual(result.data["staff_count"], 0)
def test_benchmark_search_query(self):
"""Benchmark: Search with filter."""
print("\n\n### BENCHMARK: Search Query (LIKE + LIMIT) ###")
args = {"email_contains": "bench", "limit": 20}
exec_stats = self._benchmark_executor("bench_user_search", args, "search")
http_stats = self._benchmark_http("bench_user_search", args, "search")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: search results should contain "bench" in emails
request = self._make_request()
result = execute_function(request, "bench_user_search", args)
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 20)
for user in result.data["users"]:
self.assertIn("bench", user["email"].lower())
def test_summary(self):
"""Print summary of all benchmarks."""
print("\n\n" + "=" * 80)
print("BENCHMARK SUMMARY")
print("=" * 80)
print(f"Iterations per benchmark: {self.ITERATIONS}")
print(f"Warmup iterations: {self.WARMUP}")
print("\nKey findings:")
print("- 'Executor' simulates WebSocket RPC (direct function call)")
print("- 'HTTP' measures full request/response cycle")
print("- HTTP overhead includes: JSON parsing, CSRF, view dispatch")
print("- For I/O-bound operations, protocol overhead is negligible")
print("=" * 80)
# Verify bench_simple still produces correct output after all benchmarks
request = self._make_request()
result = execute_function(request, "bench_simple", {"a": 7, "b": 8})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 15)
# =============================================================================
# Throughput Benchmark
# =============================================================================
class ThroughputBenchmark(TransactionTestCase):
"""
Measure requests per second (throughput) for server functions.
Tests both sequential and concurrent scenarios.
"""
DURATION_SECONDS = 2 # How long to run each throughput test
@classmethod
def setUpClass(cls):
super().setUpClass()
setup_benchmark_functions()
def setUp(self):
self.factory = RequestFactory()
self._create_test_users()
def _create_test_users(self):
"""Create test users for benchmarks."""
users = []
for i in range(100):
users.append(
User(
email=f"bench{i}@example.com",
is_active=i % 10 != 0,
is_staff=i < 5,
)
)
User.objects.bulk_create(users, ignore_conflicts=True)
self.test_user = User.objects.first()
def _make_request(self, body: dict) -> HttpRequest:
"""Create a POST request with JSON body."""
request = self.factory.post(
"/api/mizan/call/",
data=json.dumps(body),
content_type="application/json",
)
request.user = AnonymousUser()
request._dont_enforce_csrf_checks = True
return request
def _measure_throughput_executor(self, fn_name: str, args: dict) -> float:
"""Measure requests/second using direct executor calls."""
request = self._make_request({"fn": fn_name, "args": args})
# Warmup
for _ in range(10):
execute_function(request, fn_name, args)
# Measure
count = 0
start = time.perf_counter()
deadline = start + self.DURATION_SECONDS
while time.perf_counter() < deadline:
execute_function(request, fn_name, args)
count += 1
elapsed = time.perf_counter() - start
return count / elapsed
def _measure_throughput_http(self, fn_name: str, args: dict) -> float:
"""Measure requests/second using HTTP view calls."""
# Warmup
for _ in range(10):
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
# Measure
count = 0
start = time.perf_counter()
deadline = start + self.DURATION_SECONDS
while time.perf_counter() < deadline:
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
count += 1
elapsed = time.perf_counter() - start
return count / elapsed
def _print_throughput(self, label: str, executor_rps: float, http_rps: float):
"""Print throughput results."""
print(f"\n{label}:")
print(f" Executor (WebSocket): {executor_rps:,.0f} req/s")
print(f" HTTP: {http_rps:,.0f} req/s")
print(f" Ratio: {executor_rps/http_rps:.1f}x")
def test_throughput_simple(self):
"""Throughput: Simple computation (no I/O)."""
print("\n\n### THROUGHPUT: Simple Computation ###")
executor_rps = self._measure_throughput_executor(
"bench_simple", {"a": 1, "b": 2}
)
http_rps = self._measure_throughput_http("bench_simple", {"a": 1, "b": 2})
self._print_throughput("Simple (no I/O)", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_simple", "args": {"a": 1, "b": 2}})
result = execute_function(request, "bench_simple", {"a": 1, "b": 2})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 3)
def test_throughput_single_query(self):
"""Throughput: Single ORM query."""
print("\n\n### THROUGHPUT: Single ORM Query ###")
args = {"user_id": self.test_user.id if self.test_user else 1}
executor_rps = self._measure_throughput_executor("bench_get_user", args)
http_rps = self._measure_throughput_http("bench_get_user", args)
self._print_throughput("Single Query", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_get_user", "args": args})
result = execute_function(request, "bench_get_user", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["id"], self.test_user.id)
def test_throughput_list_query(self):
"""Throughput: List query."""
print("\n\n### THROUGHPUT: List Query (10 users) ###")
executor_rps = self._measure_throughput_executor(
"bench_list_users", {"limit": 10}
)
http_rps = self._measure_throughput_http("bench_list_users", {"limit": 10})
self._print_throughput("List Query", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_list_users", "args": {"limit": 10}})
result = execute_function(request, "bench_list_users", {"limit": 10})
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 10)
def test_throughput_aggregation(self):
"""Throughput: Aggregation queries."""
print("\n\n### THROUGHPUT: Aggregation ###")
executor_rps = self._measure_throughput_executor("bench_user_stats", {})
http_rps = self._measure_throughput_http("bench_user_stats", {})
self._print_throughput("Aggregation", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_user_stats", "args": {}})
result = execute_function(request, "bench_user_stats", {})
self.assertIsInstance(result, FunctionResult)
self.assertGreaterEqual(result.data["total_users"], 0)
def test_throughput_summary(self):
"""Print throughput summary."""
print("\n\n" + "=" * 80)
print("THROUGHPUT SUMMARY")
print("=" * 80)
print(f"Test duration: {self.DURATION_SECONDS}s per scenario")
print("\nNotes:")
print("- These are single-threaded sequential measurements")
print("- Real throughput scales with worker processes (gunicorn -w N)")
print("- Database queries are the bottleneck, not protocol overhead")
print("- Async workers (uvicorn) can handle more concurrent connections")
print("=" * 80)
# Verify bench_simple still produces correct output after all throughput tests
request = self._make_request({"fn": "bench_simple", "args": {"a": 10, "b": 20}})
result = execute_function(request, "bench_simple", {"a": 10, "b": 20})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 30)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,679 @@
"""
Stress tests for mizan.shapes — edge cases and deep nesting.
Models: Publisher → Author → Book → Chapter → Section (5 levels deep),
two FKs to same model, slug PK, UUID PK, self-referential FK, M2M,
nullable FKs, abstract bases, empty/zero/false values.
"""
import pytest
from typing import get_type_hints
from django.test import TestCase
from mizan.shapes import Shape, Diff, NestedDiff
import uuid
from tests.models import (
Publisher,
Author,
Book,
Chapter,
Section,
Tag,
Category,
)
# =============================================================================
# Shapes — varying projections
# =============================================================================
class TagShape(Shape[Tag]):
slug: str
label: str
class FlatAuthorShape(Shape[Author]):
id: int | None = None
name: str
class FlatBookShape(Shape[Book]):
id: int | None = None
title: str
is_published: bool
class BookCardShape(Shape[Book]):
id: int | None = None
title: str
isbn: str
page_count: int
is_published: bool
author: FlatAuthorShape # single nested, not list
class AuthorCardShape(Shape[Author]):
id: int | None = None
name: str
bio: str
books: list[FlatBookShape] = []
class SectionShape(Shape[Section]):
id: uuid.UUID | None = None
heading: str
body: str
position: int
class ChapterShape(Shape[Chapter]):
id: int | None = None
number: int
title: str
word_count: int
sections: list[SectionShape] = []
class BookDetailShape(Shape[Book]):
id: int | None = None
title: str
isbn: str
page_count: int
is_published: bool
author: FlatAuthorShape
chapters: list[ChapterShape] = []
tags: list[TagShape] = []
class AuthorDetailShape(Shape[Author]):
id: int | None = None
name: str
bio: str
books: list[BookDetailShape] = []
class PublisherDetailShape(Shape[Publisher]):
id: int | None = None
name: str
country: str
authors: list[AuthorDetailShape] = []
class BookWithEditorShape(Shape[Book]):
"""Two FKs to the same model (author + editor)."""
id: int | None = None
title: str
author: FlatAuthorShape
editor: FlatAuthorShape | None = None
class CategoryShape(Shape[Category]):
id: int | None = None
name: str
children: list["CategoryShape"] = []
# =============================================================================
# Shape class creation
# =============================================================================
class TestShapeClassCreation(TestCase):
def test_flat_shape_has_no_nested(self):
self.assertEqual(FlatAuthorShape._nested, {})
self.assertEqual(FlatAuthorShape._field_names, ["id", "name"])
def test_nested_shape_detected(self):
self.assertIn("books", AuthorCardShape._nested)
self.assertIs(AuthorCardShape._nested["books"], FlatBookShape)
def test_deep_nesting_spec_depth(self):
"""PublisherDetailShape → Author → Book → Chapter → Section."""
nested_keys = {
k for d in PublisherDetailShape._spec if isinstance(d, dict) for k in d
}
self.assertIn("authors", nested_keys)
author_spec = next(
d["authors"]
for d in PublisherDetailShape._spec
if isinstance(d, dict) and "authors" in d
)
author_nested = {k for d in author_spec if isinstance(d, dict) for k in d}
self.assertIn("books", author_nested)
def test_pk_field_resolution_integer(self):
self.assertEqual(FlatAuthorShape._pk_field, "id")
def test_pk_field_resolution_slug(self):
self.assertEqual(TagShape._pk_field, "slug")
def test_pk_field_resolution_uuid(self):
self.assertEqual(SectionShape._pk_field, "id")
def test_single_nested_not_list(self):
self.assertIn("author", BookCardShape._nested)
self.assertIs(BookCardShape._nested["author"], FlatAuthorShape)
def test_optional_nested(self):
"""BookWithEditorShape.editor is FlatAuthorShape | None.
_extract_shape_class needs to handle Optional/Union."""
# If this doesn't detect editor as nested, it's a known gap
if "editor" in BookWithEditorShape._nested:
self.assertIs(BookWithEditorShape._nested["editor"], FlatAuthorShape)
else:
self.skipTest(
"_extract_shape_class does not unwrap Optional[Shape] — known gap"
)
def test_self_referential_shape(self):
"""CategoryShape.children references itself."""
self.assertIn("children", CategoryShape._nested)
self.assertIs(CategoryShape._nested["children"], CategoryShape)
def test_multiple_shapes_same_model_independent(self):
self.assertLess(
len(FlatBookShape._field_names), len(BookDetailShape._field_names)
)
self.assertNotEqual(FlatBookShape._spec, BookDetailShape._spec)
# =============================================================================
# Queries
# =============================================================================
class TestShapeQuery(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Orbit", country="UK")
cls.mentor = Author.objects.create(
name="Ursula", bio="Legend", publisher=cls.publisher
)
cls.author = Author.objects.create(
name="Ann Leckie",
bio="Imperial Radch",
publisher=cls.publisher,
mentor=cls.mentor,
)
cls.editor = Author.objects.create(
name="Devi Pillai", bio="Editor", publisher=cls.publisher
)
cls.tag_sf = Tag.objects.create(slug="sci-fi", label="Science Fiction")
cls.tag_space = Tag.objects.create(slug="space-opera", label="Space Opera")
cls.book = Book.objects.create(
title="Ancillary Justice",
isbn="9780316246620",
page_count=386,
is_published=True,
author=cls.author,
editor=cls.editor,
)
cls.book.tags.add(cls.tag_sf, cls.tag_space)
cls.ch1 = Chapter.objects.create(
book=cls.book, number=1, title="The Body", word_count=5200
)
cls.ch2 = Chapter.objects.create(
book=cls.book, number=2, title="The Ship", word_count=4800
)
Section.objects.create(
chapter=cls.ch1, heading="Opening", body="...", position=0
)
Section.objects.create(
chapter=cls.ch1, heading="Discovery", body="...", position=1
)
cls.root_cat = Category.objects.create(name="Fiction")
cls.child_cat = Category.objects.create(name="Sci-Fi", parent=cls.root_cat)
Category.objects.create(name="Hard SF", parent=cls.child_cat)
# ── Flat ──
def test_flat_query_returns_minimal_fields(self):
results = FlatAuthorShape.query()
self.assertEqual(len(results), 3)
for r in results:
self.assertTrue(hasattr(r, "name"))
self.assertTrue(hasattr(r, "id"))
def test_flat_query_with_lambda_filter(self):
results = FlatAuthorShape.query(lambda qs: qs.filter(name="Ann Leckie"))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].name, "Ann Leckie")
def test_flat_query_with_raw_queryset(self):
qs = Author.objects.filter(mentor__isnull=False)
results = FlatAuthorShape.query(qs)
self.assertEqual(len(results), 1)
self.assertEqual(results[0].name, "Ann Leckie")
# ── Nested ──
def test_single_nested_fk(self):
results = BookCardShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].author.name, "Ann Leckie")
def test_list_nested_reverse_fk(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.author.pk))
self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].books), 1)
self.assertEqual(results[0].books[0].title, "Ancillary Justice")
def test_deep_nesting_book_chapters_sections(self):
results = BookDetailShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
book = results[0]
self.assertEqual(len(book.chapters), 2)
ch1 = next(c for c in book.chapters if c.number == 1)
self.assertEqual(len(ch1.sections), 2)
def test_full_depth_publisher_to_section(self):
"""5 levels: Publisher → Author → Book → Chapter → Section."""
results = PublisherDetailShape.query(lambda qs: qs.filter(pk=self.publisher.pk))
self.assertEqual(len(results), 1)
pub = results[0]
self.assertEqual(len(pub.authors), 3)
leckie = next(a for a in pub.authors if a.name == "Ann Leckie")
self.assertEqual(len(leckie.books), 1)
self.assertEqual(len(leckie.books[0].chapters), 2)
def test_two_fks_to_same_model(self):
results = BookWithEditorShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].author.name, "Ann Leckie")
if "editor" in BookWithEditorShape._nested:
self.assertIsNotNone(results[0].editor)
self.assertEqual(results[0].editor.name, "Devi Pillai")
def test_nullable_fk_returns_none(self):
book_no_editor = Book.objects.create(
title="Provenance",
isbn="9780316246699",
page_count=448,
is_published=True,
author=self.author,
editor=None,
)
results = BookWithEditorShape.query(lambda qs: qs.filter(pk=book_no_editor.pk))
self.assertEqual(len(results), 1)
if "editor" in BookWithEditorShape._nested:
self.assertIsNone(results[0].editor)
def test_m2m_tags(self):
results = BookDetailShape.query(lambda qs: qs.filter(pk=self.book.pk))
book = results[0]
self.assertEqual(len(book.tags), 2)
slugs = {t.slug for t in book.tags}
self.assertEqual(slugs, {"sci-fi", "space-opera"})
def test_slug_pk_shape(self):
results = TagShape.query()
self.assertEqual(len(results), 2)
self.assertTrue(all(isinstance(r.slug, str) for r in results))
def test_relation_qs_filters_nested(self):
results = AuthorCardShape.query(
lambda qs: qs.filter(pk=self.author.pk),
books=lambda qs: qs.filter(is_published=True),
)
self.assertEqual(len(results), 1)
self.assertTrue(all(b.is_published for b in results[0].books))
def test_empty_nested_list(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.editor.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].books, [])
# ── Query efficiency ──
def test_flat_query_is_single_query(self):
with self.assertNumQueries(1):
FlatAuthorShape.query()
def test_nested_query_uses_prefetch(self):
with self.assertNumQueries(2):
AuthorCardShape.query()
# =============================================================================
# Diff
# =============================================================================
class TestDiff(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Tor", country="US")
cls.author = Author.objects.create(
name="Brandon Sanderson", bio="Cosmere", publisher=cls.publisher
)
cls.book = Book.objects.create(
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=cls.author,
)
cls.ch1 = Chapter.objects.create(
book=cls.book, number=1, title="Ash", word_count=6000
)
cls.ch2 = Chapter.objects.create(
book=cls.book, number=2, title="Mist", word_count=5500
)
# ── Single item ──
def test_diff_no_changes(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertFalse(d.is_new)
self.assertEqual(d.changed, {})
def test_diff_detects_field_change(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn: The Final Empire",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertIn("title", d.changed)
self.assertEqual(d.changed["title"], "Mistborn: The Final Empire")
def test_diff_multiple_field_changes(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn: TFE",
isbn="9780765311788",
page_count=600,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertIn("title", d.changed)
self.assertIn("page_count", d.changed)
self.assertNotIn("isbn", d.changed)
def test_diff_new_item(self):
shape = FlatBookShape(id=None, title="Elantris", is_published=True)
d = shape.diff()
self.assertTrue(d.is_new)
self.assertIn("title", d.changed)
def test_diff_nonexistent_pk_raises(self):
shape = FlatBookShape(id=999999, title="Nope", is_published=False)
with self.assertRaises(Book.DoesNotExist):
shape.diff()
# ── Nested ──
def test_nested_diff_detects_updated_chapter(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk,
number=1,
title="Ash Falls",
word_count=6000,
sections=[],
),
ChapterShape(
id=self.ch2.pk, number=2, title="Mist", word_count=5500, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.updated), 1)
self.assertEqual(d.chapters.updated[0].title, "Ash Falls")
def test_nested_diff_detects_created(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk, number=1, title="Ash", word_count=6000, sections=[]
),
ChapterShape(
id=self.ch2.pk, number=2, title="Mist", word_count=5500, sections=[]
),
ChapterShape(
id=None, number=3, title="New Chapter", word_count=0, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.created), 1)
def test_nested_diff_detects_deleted(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk, number=1, title="Ash", word_count=6000, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertIn(self.ch2.pk, d.chapters.deleted)
def test_nested_diff_combined_operations(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk,
number=1,
title="Ash Rewritten",
word_count=7000,
sections=[],
),
ChapterShape(
id=None, number=3, title="Epilogue", word_count=2000, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.updated), 1)
self.assertEqual(len(d.chapters.deleted), 1)
self.assertEqual(len(d.chapters.created), 1)
# ── Strict Diff access ──
def test_diff_strict_getattr_raises_on_typo(self):
shape = FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True)
d = shape.diff()
with self.assertRaises(AttributeError):
_ = d.chapterz
def test_diff_strict_nested_raises_on_typo(self):
shape = FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True)
d = shape.diff()
with self.assertRaises(KeyError):
d.nested("chapterz")
def test_diff_strict_shows_valid_names(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[],
tags=[],
)
d = shape.diff()
with self.assertRaises(AttributeError) as ctx:
_ = d.bogus
self.assertIn("chapters", str(ctx.exception))
# ── diff_many ──
def test_diff_many_single_query_for_existing(self):
items = [FlatBookShape(id=self.book.pk, title="Renamed", is_published=True)]
results = FlatBookShape.diff_many(items)
self.assertEqual(len(results), 1)
_, d = results[0]
self.assertIn("title", d.changed)
def test_diff_many_mixed_new_and_existing(self):
items = [
FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True),
FlatBookShape(id=None, title="New Book", is_published=False),
]
results = FlatBookShape.diff_many(items)
new = [d for _, d in results if d.is_new]
existing = [d for _, d in results if not d.is_new]
self.assertEqual(len(new), 1)
self.assertEqual(len(existing), 1)
def test_diff_many_nonexistent_raises(self):
items = [FlatBookShape(id=999999, title="Ghost", is_published=False)]
with self.assertRaises(Book.DoesNotExist):
FlatBookShape.diff_many(items)
def test_diff_many_batched_query(self):
book2 = Book.objects.create(
title="Warbreaker",
isbn="9780765320308",
page_count=592,
is_published=True,
author=self.author,
)
items = [
FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True),
FlatBookShape(id=book2.pk, title="Warbreaker Updated", is_published=True),
]
with self.assertNumQueries(1):
FlatBookShape.diff_many(items)
def test_diff_many_empty(self):
self.assertEqual(FlatBookShape.diff_many([]), [])
# =============================================================================
# Edge cases
# =============================================================================
class TestEdgeCases(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Edge Cases Ltd", country="XX")
cls.author = Author.objects.create(
name="Edge Author", bio="", publisher=cls.publisher
)
def test_empty_table_returns_empty_list(self):
Tag.objects.all().delete()
results = TagShape.query()
self.assertEqual(results, [])
def test_empty_string_fields(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.author.pk))
self.assertEqual(results[0].bio, "")
def test_boolean_false_is_not_missing(self):
book = Book.objects.create(
title="Unpublished",
isbn="0000000000000",
page_count=0,
is_published=False,
author=self.author,
)
results = FlatBookShape.query(lambda qs: qs.filter(pk=book.pk))
self.assertIs(results[0].is_published, False)
def test_zero_integer_is_not_missing(self):
book = Book.objects.create(
title="Empty",
isbn="0000000000001",
page_count=0,
is_published=False,
author=self.author,
)
results = BookCardShape.query(lambda qs: qs.filter(pk=book.pk))
self.assertEqual(results[0].page_count, 0)
def test_large_queryset(self):
books = [
Book(
title=f"Book {i}",
isbn=f"{i:013d}",
page_count=i * 10,
is_published=i % 2 == 0,
author=self.author,
)
for i in range(100)
]
Book.objects.bulk_create(books)
results = FlatBookShape.query(lambda qs: qs.filter(author=self.author))
self.assertGreaterEqual(len(results), 100)
def test_diff_on_boolean_change(self):
book = Book.objects.create(
title="Toggle",
isbn="1111111111111",
page_count=100,
is_published=False,
author=self.author,
)
shape = FlatBookShape(id=book.pk, title="Toggle", is_published=True)
d = shape.diff()
self.assertIn("is_published", d.changed)
self.assertIs(d.changed["is_published"], True)
def test_diff_unchanged_returns_empty(self):
book = Book.objects.create(
title="Same",
isbn="2222222222222",
page_count=200,
is_published=True,
author=self.author,
)
shape = FlatBookShape(id=book.pk, title="Same", is_published=True)
d = shape.diff()
self.assertEqual(d.changed, {})
self.assertFalse(d.is_new)

View File

@@ -0,0 +1,162 @@
"""
Tests for the Mizan SSR bridge and template backend.
Requires Bun installed and the test worker at packages/mizan-ssr/src/test-worker.tsx.
Tests skip gracefully if Bun is not available.
"""
import os
import shutil
import threading
from django.test import SimpleTestCase, RequestFactory
# Path to the test worker
_SSR_WORKER = os.path.join(
os.path.dirname(__file__),
"..", "..", "..", "..", "..", # up to repo root
"packages", "mizan-ssr", "src", "test-worker.tsx",
)
_SSR_WORKER = os.path.normpath(_SSR_WORKER)
_BUN_AVAILABLE = shutil.which("bun") is not None
_SKIP_MSG = "Bun not available"
class SSRBridgeTests(SimpleTestCase):
"""Tests for the SSR bridge subprocess manager."""
def setUp(self):
if not _BUN_AVAILABLE:
self.skipTest(_SKIP_MSG)
if not os.path.exists(_SSR_WORKER):
self.skipTest(f"Test worker not found at {_SSR_WORKER}")
from mizan.ssr.bridge import SSRBridge
self.bridge = SSRBridge(worker_path=_SSR_WORKER, timeout=5.0)
def tearDown(self):
if hasattr(self, "bridge"):
self.bridge.shutdown()
def test_ping(self):
"""Worker starts and responds to ping."""
self.assertTrue(self.bridge.ping())
def test_render_simple(self):
"""Renders a simple component to HTML."""
result = self.bridge.render("Hello", {"name": "World"})
self.assertIn("Hello,", result.html)
self.assertIn("World", result.html)
def test_render_with_props(self):
"""Renders a component with multiple props."""
result = self.bridge.render("UserProfile", {"user_id": 42, "name": "Alice"})
self.assertIn("Alice", result.html)
self.assertIn("42", result.html)
def test_render_missing_component(self):
"""Rendering an unregistered component raises RuntimeError."""
with self.assertRaises(RuntimeError) as ctx:
self.bridge.render("NonExistent", {})
self.assertIn("not registered", str(ctx.exception))
def test_render_error(self):
"""Component that throws during render raises RuntimeError."""
with self.assertRaises(RuntimeError) as ctx:
self.bridge.render("Broken", {})
self.assertIn("Render error", str(ctx.exception))
def test_crash_recovery(self):
"""Bridge restarts the worker if it dies."""
# First render works
result = self.bridge.render("Hello", {"name": "Before"})
self.assertIn("Before", result.html)
# Kill the subprocess
self.bridge._proc.kill()
self.bridge._proc.wait()
# Next render should restart and work
result = self.bridge.render("Hello", {"name": "After"})
self.assertIn("After", result.html)
def test_concurrent_renders(self):
"""Multiple threads can render simultaneously."""
results = {}
errors = {}
def render_in_thread(name: str, idx: int):
try:
result = self.bridge.render("Hello", {"name": name})
results[idx] = result.html
except Exception as e:
errors[idx] = e
threads = []
for i in range(5):
t = threading.Thread(target=render_in_thread, args=(f"User{i}", i))
threads.append(t)
t.start()
for t in threads:
t.join(timeout=10)
self.assertEqual(len(errors), 0, f"Errors in concurrent renders: {errors}")
self.assertEqual(len(results), 5)
for i in range(5):
self.assertIn(f"User{i}", results[i])
class SSRTemplateBackendTests(SimpleTestCase):
"""Tests for the MizanTemplates Django template backend."""
def setUp(self):
if not _BUN_AVAILABLE:
self.skipTest(_SKIP_MSG)
if not os.path.exists(_SSR_WORKER):
self.skipTest(f"Test worker not found at {_SSR_WORKER}")
from mizan.ssr.backend import MizanTemplates
self.engine = MizanTemplates({
"NAME": "mizan-test",
"DIRS": [],
"APP_DIRS": False,
"OPTIONS": {
"worker_path": _SSR_WORKER,
"timeout": 5,
},
})
self.factory = RequestFactory()
def tearDown(self):
if hasattr(self, "engine") and self.engine._bridge is not None:
self.engine._bridge.shutdown()
def test_get_template(self):
"""get_template returns a MizanTemplate."""
from mizan.ssr.backend import MizanTemplate
template = self.engine.get_template("Hello")
self.assertIsInstance(template, MizanTemplate)
self.assertEqual(template.component_name, "Hello")
def test_template_render(self):
"""MizanTemplate.render() produces HTML."""
template = self.engine.get_template("Hello")
html = template.render({"name": "Django"})
self.assertIn("Hello,", html)
self.assertIn("Django", html)
self.assertIn('data-mizan-component="Hello"', html)
def test_template_render_strips_django_internals(self):
"""Django-internal context keys (request, csrf_token) are not passed as props."""
template = self.engine.get_template("Hello")
request = self.factory.get("/")
html = template.render({"name": "Test", "request": request, "csrf_token": "abc"}, request)
self.assertIn("Test", html)
def test_from_string_raises(self):
"""from_string is not supported."""
from django.template import TemplateDoesNotExist
with self.assertRaises(TemplateDoesNotExist):
self.engine.from_string("<div>Not supported</div>")

View File

@@ -0,0 +1,42 @@
"""
mizan URL Configuration
HTTP endpoints:
- GET /session/ - Initialize session and get CSRF token (for SSR)
- POST /call/ - Server function calls (HTTP transport)
- GET /ctx/<name>/ - Bundled context fetch (all functions in a named context)
Security:
- Schema export is NOT exposed over HTTP to prevent API enumeration
- Use the management command instead: python manage.py export_mizan_schema
"""
from django.http import JsonResponse
from django.middleware.csrf import get_token
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from .client.executor import function_call_view, context_fetch_view
app_name = "mizan"
@ensure_csrf_cookie
def session_init_view(request):
"""
Initialize a Django session and return the CSRF token.
Used by SSR to establish a session before making authenticated requests.
The @ensure_csrf_cookie decorator ensures the csrftoken cookie is set.
Returns:
{ "csrfToken": "..." }
"""
return JsonResponse({"csrfToken": get_token(request)})
urlpatterns = [
path("session/", session_init_view, name="session-init"),
path("call/", function_call_view, name="function-call"),
path("ctx/<str:context_name>/", context_fetch_view, name="context-fetch"),
]

View File

View File

@@ -0,0 +1,142 @@
from django.contrib.auth.models import (
AbstractBaseUser,
BaseUserManager,
PermissionsMixin,
)
from django.db import models
class EmailUserManager(BaseUserManager):
"""Custom user manager using email as the unique identifier."""
def create_user(self, email, password=None, **extra_fields):
if not email:
raise ValueError("Email is required")
email = self.normalize_email(email)
user = self.model(email=email, **extra_fields)
user.set_password(password)
user.save(using=self._db)
return user
def create_superuser(self, email, password=None, **extra_fields):
extra_fields.setdefault("is_staff", True)
extra_fields.setdefault("is_superuser", True)
return self.create_user(email, password, **extra_fields)
class EmailUser(AbstractBaseUser, PermissionsMixin):
"""Minimal user model with email as USERNAME_FIELD.
Matches the calling convention used in mizan's test suite:
User.objects.create_user(email="...", password="...", is_staff=True)
"""
email = models.EmailField(unique=True)
is_staff = models.BooleanField(default=False)
is_active = models.BooleanField(default=True)
objects = EmailUserManager()
USERNAME_FIELD = "email"
REQUIRED_FIELDS = []
class Meta:
app_label = "tests"
# ─── Shape test models ──────────────────────────────────────────────────────
import uuid
class TimestampMixin(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
abstract = True
class Publisher(TimestampMixin):
name = models.CharField(max_length=200)
country = models.CharField(max_length=100, default="")
class Meta:
app_label = "tests"
class Author(TimestampMixin):
name = models.CharField(max_length=200)
bio = models.TextField(default="")
publisher = models.ForeignKey(
Publisher, on_delete=models.CASCADE, related_name="authors"
)
mentor = models.ForeignKey(
"self", on_delete=models.SET_NULL, null=True, blank=True, related_name="mentees"
)
class Meta:
app_label = "tests"
class Tag(models.Model):
slug = models.SlugField(primary_key=True, max_length=100)
label = models.CharField(max_length=100)
class Meta:
app_label = "tests"
class Book(TimestampMixin):
title = models.CharField(max_length=300)
isbn = models.CharField(max_length=13, unique=True)
page_count = models.IntegerField(default=0)
is_published = models.BooleanField(default=False)
author = models.ForeignKey(Author, on_delete=models.CASCADE, related_name="books")
editor = models.ForeignKey(
Author,
on_delete=models.SET_NULL,
null=True,
blank=True,
related_name="edited_books",
)
tags = models.ManyToManyField(Tag, blank=True, related_name="books")
class Meta:
app_label = "tests"
class Chapter(TimestampMixin):
book = models.ForeignKey(Book, on_delete=models.CASCADE, related_name="chapters")
number = models.IntegerField()
title = models.CharField(max_length=300)
word_count = models.IntegerField(default=0)
class Meta:
app_label = "tests"
ordering = ["number"]
unique_together = [("book", "number")]
class Section(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
chapter = models.ForeignKey(
Chapter, on_delete=models.CASCADE, related_name="sections"
)
heading = models.CharField(max_length=300)
body = models.TextField(default="")
position = models.IntegerField(default=0)
class Meta:
app_label = "tests"
ordering = ["position"]
class Category(models.Model):
name = models.CharField(max_length=200)
parent = models.ForeignKey(
"self", on_delete=models.CASCADE, null=True, blank=True, related_name="children"
)
class Meta:
app_label = "tests"

View File

@@ -0,0 +1,46 @@
"""
Django settings for running mizan's test suite standalone.
Usage:
cd django/
pip install -e ".[dev]"
pytest
"""
SECRET_KEY = "test-secret-key-for-standalone-tests-only"
DEBUG = True
DATABASES = {
"default": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": ":memory:",
}
}
INSTALLED_APPS = [
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"mizan",
"tests",
]
AUTH_USER_MODEL = "tests.EmailUser"
ROOT_URLCONF = "tests.urls"
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
# JWT settings for test_auth.py (can be overridden per-class with @override_settings)
JWT_PRIVATE_KEY = "test-secret-key-for-testing-only"
JWT_ALGORITHM = "HS256"
# Session engine (for test_auth.py SessionStore usage)
SESSION_ENGINE = "django.contrib.sessions.backends.db"
MIDDLEWARE = [
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
]

View File

@@ -0,0 +1,5 @@
from django.urls import include, path
urlpatterns = [
path("api/mizan/", include("mizan.urls")),
]