Restructure repo into five-package AFI architecture
Mizan is an Application Framework Interface (AFI) with five
independent packages:
packages/
mizan-ast/ Language layer (source → KDL schema)
mizan-schema/ IR layer (KDL schema definition)
mizan-rpc/ Protocol layer (client gen + server adapters)
adapters/django/ ← was django/
generator/ ← was react/src/generator/
mizan-csr/ State layer (client state engine)
adapters/react/ ← was react/
mizan-ssr/ Rendering layer (server-side rendering)
Each package is independent. The adapter directories contain the
framework-specific implementations. Stub packages (ast, schema, ssr)
establish the structure for future work.
264 Django tests + 33 React tests pass from new locations.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2
packages/mizan-csr/adapters/react/.gitignore
vendored
Normal file
2
packages/mizan-csr/adapters/react/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
node_modules/
|
||||
dist/
|
||||
103
packages/mizan-csr/adapters/react/README.md
Normal file
103
packages/mizan-csr/adapters/react/README.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# @rythazhur/mizan (TypeScript)
|
||||
|
||||
React client for the mizan framework. See the [monorepo root](../README.md) for full documentation.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
npm install @rythazhur/mizan@git+https://git.impactsoundworks.com/isw/mizan.git#workspace=react
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
You don't use this package directly. You use the **generated hooks**.
|
||||
|
||||
### 1. Configure
|
||||
|
||||
```js
|
||||
// django.config.mjs
|
||||
export default {
|
||||
source: {
|
||||
django: {
|
||||
managePath: '../backend/manage.py',
|
||||
command: ['uv', 'run', 'python'],
|
||||
},
|
||||
},
|
||||
output: 'src/api/generated.ts',
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Generate
|
||||
|
||||
```bash
|
||||
npx mizan-generate # once
|
||||
npx mizan-generate --watch # dev mode
|
||||
```
|
||||
|
||||
### 3. Wrap your app
|
||||
|
||||
```tsx
|
||||
import { DjangoContext } from '@/api'
|
||||
|
||||
<DjangoContext>
|
||||
<App />
|
||||
</DjangoContext>
|
||||
```
|
||||
|
||||
`DjangoContext` is the only provider you need. It handles HTTP, WebSocket, CSRF, session init, context auto-fetching, and channel connections.
|
||||
|
||||
### 4. Use generated hooks
|
||||
|
||||
```tsx
|
||||
import { useCurrentUser, useEcho, useContactForm, useChatChannel } from '@/api'
|
||||
|
||||
// Context (SSR-hydrated, auto-refreshed)
|
||||
const user = useCurrentUser()
|
||||
|
||||
// Server function
|
||||
const echo = useEcho()
|
||||
const result = await echo({ text: 'hello' })
|
||||
|
||||
// Form (Zod + server validation)
|
||||
const form = useContactForm()
|
||||
form.set('email', 'test@example.com')
|
||||
await form.submit()
|
||||
|
||||
// Channel (WebSocket)
|
||||
const chat = useChatChannel({ room: 'general' })
|
||||
chat.send({ text: 'hello' })
|
||||
chat.messages // typed, reactive
|
||||
```
|
||||
|
||||
## Generated Files
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| `generated.django.tsx` | `DjangoContext` + typed hooks |
|
||||
| `generated.mizan.ts` | Pydantic types |
|
||||
| `generated.forms.ts` | Form hooks with Zod |
|
||||
| `generated.channels.hooks.tsx` | Channel hooks |
|
||||
| `index.ts` | Re-exports everything |
|
||||
|
||||
## Sub-exports
|
||||
|
||||
| Import | When to use |
|
||||
|--------|------------|
|
||||
| `@rythazhur/mizan` | Core: mizanProvider, hooks, forms, errors |
|
||||
| `@rythazhur/mizan/channels` | WebSocket channels |
|
||||
| `@rythazhur/mizan/jwt` | JWT token management |
|
||||
| `@rythazhur/mizan/client` | HTTP clients (CSR/SSR) |
|
||||
| `@rythazhur/mizan/allauth` | Allauth UI components |
|
||||
|
||||
These are **library internals** used by the generated code. You should import from `@/api` (your generated index), not from the library directly.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Unit tests (Vitest, jsdom)
|
||||
npm test
|
||||
|
||||
# E2E tests (Playwright, real browser)
|
||||
# Requires Docker backend running
|
||||
npx playwright test
|
||||
```
|
||||
85
packages/mizan-csr/adapters/react/package.json
Normal file
85
packages/mizan-csr/adapters/react/package.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"name": "@rythazhur/mizan",
|
||||
"version": "0.1.1",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./client": {
|
||||
"types": "./dist/client/index.d.ts",
|
||||
"import": "./dist/client/index.js"
|
||||
},
|
||||
"./client/react": {
|
||||
"types": "./dist/client/react.d.ts",
|
||||
"import": "./dist/client/react.js"
|
||||
},
|
||||
"./client/nextjs": {
|
||||
"types": "./dist/client/nextjs.d.ts",
|
||||
"import": "./dist/client/nextjs.js"
|
||||
},
|
||||
"./channels": {
|
||||
"types": "./dist/channels/index.d.ts",
|
||||
"import": "./dist/channels/index.js"
|
||||
},
|
||||
"./jwt": {
|
||||
"types": "./dist/jwt/index.d.ts",
|
||||
"import": "./dist/jwt/index.js"
|
||||
},
|
||||
"./allauth": {
|
||||
"types": "./dist/allauth/index.d.ts",
|
||||
"import": "./dist/allauth/index.js"
|
||||
},
|
||||
"./allauth/nextjs": {
|
||||
"types": "./dist/allauth/nextjs.d.ts",
|
||||
"import": "./dist/allauth/nextjs.js"
|
||||
}
|
||||
},
|
||||
"bin": {
|
||||
"mizan-generate": "./dist/generator/cli.mjs"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.build.json && node -e \"require('fs').cpSync('src/generator','dist/generator',{recursive:true})\"",
|
||||
"dev": "tsc -p tsconfig.build.json --watch",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:integration": "RUN_INTEGRATION_TESTS=true NEXT_PUBLIC_HOST_URL=http://localhost:8000 vitest run",
|
||||
"prepublishOnly": "npm run build"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=18",
|
||||
"react-dom": ">=18"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"next": {
|
||||
"optional": true
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"ws": "^8.19.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.58.2",
|
||||
"@simplewebauthn/browser": "^13.2.2",
|
||||
"@testing-library/jest-dom": "^6.0.0",
|
||||
"@testing-library/react": "^16.0.0",
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@types/ws": "^8.5.0",
|
||||
"jsdom": "^25.0.0",
|
||||
"next": "^16.1.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
"vitest": "^3.0.0",
|
||||
"zod": "^4.3.6"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"chokidar": "^4.0.0",
|
||||
"minimatch": "^10.0.0",
|
||||
"openapi-typescript": "^7.0.0"
|
||||
}
|
||||
}
|
||||
314
packages/mizan-csr/adapters/react/src/__tests__/context.test.tsx
Normal file
314
packages/mizan-csr/adapters/react/src/__tests__/context.test.tsx
Normal file
@@ -0,0 +1,314 @@
|
||||
/**
|
||||
* Tests for Django Server React Context
|
||||
*
|
||||
* Unit tests run without backend.
|
||||
* Integration tests require: docker-compose up
|
||||
*
|
||||
* Run integration tests with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { render, screen, waitFor, act } from '@testing-library/react'
|
||||
import {
|
||||
MizanProvider,
|
||||
useMizan,
|
||||
useMizanStatus,
|
||||
useMizanCall,
|
||||
// Legacy aliases for backwards compatibility tests
|
||||
DjangoContext,
|
||||
useDjango,
|
||||
useDjangoStatus,
|
||||
useServerFunction,
|
||||
} from '../context'
|
||||
import { DjangoError } from '../errors'
|
||||
import { describeIntegration, BACKEND_URL } from '../testing'
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests (no backend required)
|
||||
// ============================================================================
|
||||
|
||||
describe('mizan Context (unit)', () => {
|
||||
describe('useMizan hook', () => {
|
||||
it('should throw when used outside provider', () => {
|
||||
function TestComponent() {
|
||||
useMizan()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'error').mockImplementation()
|
||||
|
||||
expect(() => render(<TestComponent />)).toThrow(
|
||||
'useMizan must be used within a MizanProvider'
|
||||
)
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should return context value inside provider', () => {
|
||||
let contextValue: any = null
|
||||
|
||||
function TestComponent() {
|
||||
contextValue = useMizan()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<MizanProvider autoConnect={false}>
|
||||
<TestComponent />
|
||||
</MizanProvider>
|
||||
)
|
||||
|
||||
expect(contextValue).not.toBeNull()
|
||||
expect(contextValue!.status).toBe('disconnected')
|
||||
})
|
||||
})
|
||||
|
||||
describe('useMizanStatus hook', () => {
|
||||
it('should return disconnected when autoConnect is false', () => {
|
||||
function TestComponent() {
|
||||
const status = useMizanStatus()
|
||||
return <div data-testid="status">{status}</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<MizanProvider autoConnect={false}>
|
||||
<TestComponent />
|
||||
</MizanProvider>
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('status')).toHaveTextContent('disconnected')
|
||||
})
|
||||
})
|
||||
|
||||
describe('hydration', () => {
|
||||
it('should initialize context store from hydration data', () => {
|
||||
let contextValue: any = null
|
||||
|
||||
function TestComponent() {
|
||||
contextValue = useMizan()
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
const hydration = {
|
||||
auth_status: { is_authenticated: false },
|
||||
user: null,
|
||||
}
|
||||
|
||||
render(
|
||||
<MizanProvider hydration={hydration} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</MizanProvider>
|
||||
)
|
||||
|
||||
expect(contextValue.getContext('auth_status')).toEqual({ is_authenticated: false })
|
||||
expect(contextValue.getContext('user')).toEqual(null)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests (require running backend)
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('mizan Context (integration)', () => {
|
||||
describe('server function calls via HTTP', () => {
|
||||
it('should call echo function and get response', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call, status } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
// Use HTTP fallback (status will be disconnected without WebSocket)
|
||||
call<{ text: string }, { message: string }>('echo', { text: 'context test' })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div data-testid="status">{status}</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('message')
|
||||
expect(result.message).toContain('context test')
|
||||
})
|
||||
|
||||
it('should call add function with correct result', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call<{ a: number; b: number }, { result: number }>('add', { a: 10, b: 20 })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toEqual({ result: 30 })
|
||||
})
|
||||
|
||||
it('should throw DjangoError for validation errors', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
// Call without required field
|
||||
call('echo', {})
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useServerFunction hook', () => {
|
||||
it('should create typed function that calls backend', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
interface EchoInput { text: string }
|
||||
interface EchoOutput { message: string }
|
||||
|
||||
function TestComponent() {
|
||||
const echo = useServerFunction<EchoInput, EchoOutput>('echo')
|
||||
|
||||
React.useEffect(() => {
|
||||
echo({ text: 'typed function test' })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [echo])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('message')
|
||||
expect(result.message).toContain('typed function test')
|
||||
})
|
||||
})
|
||||
|
||||
describe('form functions', () => {
|
||||
it('should call login.schema and get form fields', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call('login.schema', { data: {} })
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('fields')
|
||||
expect(result).toHaveProperty('meta')
|
||||
// Login form should have login and password fields
|
||||
expect(result.fields).toHaveProperty('login')
|
||||
expect(result.fields).toHaveProperty('password')
|
||||
})
|
||||
|
||||
it('should call login.validate and get validation result', async () => {
|
||||
let result: any = null
|
||||
let error: any = null
|
||||
|
||||
function TestComponent() {
|
||||
const { call } = useDjango()
|
||||
|
||||
React.useEffect(() => {
|
||||
call('login.validate', {
|
||||
data: { login: 'test@example.com', password: 'testpass' }
|
||||
})
|
||||
.then((r) => { result = r })
|
||||
.catch((e) => { error = e })
|
||||
}, [call])
|
||||
|
||||
return <div>Test</div>
|
||||
}
|
||||
|
||||
render(
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
<TestComponent />
|
||||
</DjangoContext>
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result || error).not.toBeNull()
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Should return validation result (may have errors for invalid creds, that's ok)
|
||||
expect(error).toBeNull()
|
||||
expect(result).toHaveProperty('valid')
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
214
packages/mizan-csr/adapters/react/src/__tests__/errors.test.ts
Normal file
214
packages/mizan-csr/adapters/react/src/__tests__/errors.test.ts
Normal file
@@ -0,0 +1,214 @@
|
||||
/**
|
||||
* Tests for Django Server Error
|
||||
*/
|
||||
|
||||
import { DjangoError, type FunctionErrorResponse } from '../errors'
|
||||
|
||||
describe('DjangoError', () => {
|
||||
it('should create error with message and code', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Function not found',
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.message).toBe('Function not found')
|
||||
expect(error.code).toBe('NOT_FOUND')
|
||||
expect(error.name).toBe('DjangoError')
|
||||
})
|
||||
|
||||
it('should preserve details', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required', 'Too short'],
|
||||
email: ['Invalid format'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.details).toBeDefined()
|
||||
expect(error.details?.fields?.name).toEqual(['Required', 'Too short'])
|
||||
})
|
||||
|
||||
it('should preserve original response', () => {
|
||||
const response: FunctionErrorResponse = {
|
||||
error: true,
|
||||
code: 'INTERNAL_ERROR',
|
||||
message: 'Server error',
|
||||
}
|
||||
|
||||
const error = new DjangoError(response)
|
||||
|
||||
expect(error.response).toBe(response)
|
||||
})
|
||||
|
||||
describe('isValidationError', () => {
|
||||
it('should return true for validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
})
|
||||
|
||||
expect(error.isValidationError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isValidationError()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isAuthError', () => {
|
||||
it('should return true for unauthorized', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'UNAUTHORIZED',
|
||||
message: 'Not authenticated',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for forbidden', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'FORBIDDEN',
|
||||
message: 'Access denied',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isAuthError()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isNotFound', () => {
|
||||
it('should return true for not found errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.isNotFound()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for other errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
})
|
||||
|
||||
expect(error.isNotFound()).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFieldErrors', () => {
|
||||
it('should return field errors for validation error', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required'],
|
||||
email: ['Invalid'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const errors = error.getFieldErrors()
|
||||
|
||||
expect(errors).toEqual({
|
||||
name: ['Required'],
|
||||
email: ['Invalid'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should return null for non-validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.getFieldErrors()).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null if no fields in details', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid',
|
||||
details: {},
|
||||
})
|
||||
|
||||
expect(error.getFieldErrors()).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFieldError', () => {
|
||||
it('should return first error for a field', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required', 'Too short'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(error.getFieldError('name')).toBe('Required')
|
||||
})
|
||||
|
||||
it('should return null for non-existent field', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'VALIDATION_ERROR',
|
||||
message: 'Invalid input',
|
||||
details: {
|
||||
fields: {
|
||||
name: ['Required'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(error.getFieldError('email')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null for non-validation errors', () => {
|
||||
const error = new DjangoError({
|
||||
error: true,
|
||||
code: 'NOT_FOUND',
|
||||
message: 'Not found',
|
||||
})
|
||||
|
||||
expect(error.getFieldError('name')).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
362
packages/mizan-csr/adapters/react/src/__tests__/forms.test.tsx
Normal file
362
packages/mizan-csr/adapters/react/src/__tests__/forms.test.tsx
Normal file
@@ -0,0 +1,362 @@
|
||||
/**
|
||||
* Tests for Django Forms
|
||||
*
|
||||
* Integration tests call the REAL backend - no mocks.
|
||||
* Backend must be running: docker-compose up
|
||||
*
|
||||
* Run integration tests with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { z } from 'zod'
|
||||
|
||||
import {
|
||||
useDjangoFormCore,
|
||||
type FormCoreConfig,
|
||||
} from '../forms'
|
||||
import { DjangoContext } from '../context'
|
||||
import { describeIntegration, BACKEND_URL } from '../testing'
|
||||
|
||||
// ============================================================================
|
||||
// Test Setup
|
||||
// ============================================================================
|
||||
|
||||
// Helper to render hook with provider
|
||||
function renderFormHook<TData extends Record<string, unknown>>(
|
||||
config: FormCoreConfig<TData>
|
||||
) {
|
||||
return renderHook(() => useDjangoFormCore<TData>(config), {
|
||||
wrapper: ({ children }) => (
|
||||
<DjangoContext baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
{children}
|
||||
</DjangoContext>
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests - Real Backend Calls
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('useDjangoFormCore (integration)', () => {
|
||||
describe('Schema loading from real backend', () => {
|
||||
it('loads login form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
expect(result.current.loading).toBe(true)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.name).toBe('login')
|
||||
// Login form should have login and password fields
|
||||
expect(result.current.schema?.fields).toHaveProperty('login')
|
||||
expect(result.current.schema?.fields).toHaveProperty('password')
|
||||
})
|
||||
|
||||
it('loads signup form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'signup',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.name).toBe('signup')
|
||||
// Signup form should have email and password fields
|
||||
expect(result.current.schema?.fields).toHaveProperty('email')
|
||||
expect(result.current.schema?.fields).toHaveProperty('password1')
|
||||
})
|
||||
|
||||
it('loads add_email form schema', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'add_email',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.schema).not.toBeNull()
|
||||
expect(result.current.schema?.fields).toHaveProperty('email')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Form data management', () => {
|
||||
it('sets and gets form data', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'test@example.com')
|
||||
result.current.set('password', 'testpassword123')
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('test@example.com')
|
||||
expect(result.current.data.password).toBe('testpassword123')
|
||||
})
|
||||
|
||||
it('tracks touched fields', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.touchedFields.size).toBe(0)
|
||||
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.touchedFields.has('login')).toBe(true)
|
||||
expect(result.current.touchedFields.has('password')).toBe(false)
|
||||
})
|
||||
|
||||
it('resets form state', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'changed@example.com')
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('changed@example.com')
|
||||
expect(result.current.touchedFields.has('login')).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.data.login).toBe('')
|
||||
expect(result.current.touchedFields.size).toBe(0)
|
||||
expect(result.current.errors).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Zod validation with real schema', () => {
|
||||
// Define Zod schema matching login form
|
||||
const LoginZodSchema = z.object({
|
||||
login: z.string().min(1, 'Login is required').email('Invalid email'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
})
|
||||
|
||||
type LoginData = z.infer<typeof LoginZodSchema>
|
||||
|
||||
it('validates with Zod schema on touch', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid value
|
||||
act(() => {
|
||||
result.current.set('login', 'not-an-email')
|
||||
})
|
||||
|
||||
// Touch triggers validation
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
// Zod validation should show email format error
|
||||
expect(result.current.errors?.fields.login).toBeDefined()
|
||||
expect(result.current.errors?.fields.login?.[0]?.message).toBe('Invalid email')
|
||||
expect(result.current.errors?.fields.login?.[0]?.source).toBe('zod')
|
||||
})
|
||||
|
||||
it('clears errors when value becomes valid', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid value and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'bad')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.errors?.fields.login).toBeDefined()
|
||||
|
||||
// Set valid value and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'valid@example.com')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.errors?.fields.login).toBeUndefined()
|
||||
})
|
||||
|
||||
it('tracks hasErrors correctly', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
expect(result.current.hasErrors).toBe(false)
|
||||
|
||||
// Set invalid and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'bad')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.hasErrors).toBe(true)
|
||||
|
||||
// Set valid and touch
|
||||
act(() => {
|
||||
result.current.set('login', 'good@example.com')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
expect(result.current.hasErrors).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Form submission', () => {
|
||||
it('submits login form and handles validation errors', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'login',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Set invalid credentials
|
||||
act(() => {
|
||||
result.current.set('login', 'nonexistent@example.com')
|
||||
result.current.set('password', 'wrongpassword')
|
||||
})
|
||||
|
||||
// Submit should fail with validation error
|
||||
let submitResult: any
|
||||
await act(async () => {
|
||||
submitResult = await result.current.submit()
|
||||
})
|
||||
|
||||
// Submit should return error (invalid credentials)
|
||||
// The exact error depends on backend behavior
|
||||
expect(submitResult).toBeDefined()
|
||||
})
|
||||
|
||||
it('submits signup form with missing required fields', async () => {
|
||||
const { result } = renderFormHook({
|
||||
name: 'signup',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Submit with empty fields should return validation errors
|
||||
let submitResult: any
|
||||
await act(async () => {
|
||||
submitResult = await result.current.submit()
|
||||
})
|
||||
|
||||
// Should have validation errors for required fields
|
||||
expect(submitResult).toBeDefined()
|
||||
expect(submitResult.success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error source tagging', () => {
|
||||
const LoginZodSchema = z.object({
|
||||
login: z.string().min(1, 'Login is required').email('Invalid email'),
|
||||
password: z.string().min(1, 'Password is required'),
|
||||
})
|
||||
|
||||
type LoginData = z.infer<typeof LoginZodSchema>
|
||||
|
||||
it('tags Zod errors with source: zod', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'invalid')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
const errors = result.current.getFieldErrors('login')
|
||||
expect(errors.length).toBeGreaterThan(0)
|
||||
expect(errors[0].source).toBe('zod')
|
||||
})
|
||||
|
||||
it('filters errors by source', async () => {
|
||||
const { result } = renderFormHook<LoginData>({
|
||||
name: 'login',
|
||||
zodSchema: LoginZodSchema,
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
act(() => {
|
||||
result.current.set('login', 'invalid')
|
||||
})
|
||||
act(() => {
|
||||
result.current.touch('login')
|
||||
})
|
||||
|
||||
// Should have Zod errors
|
||||
const zodErrors = result.current.getFieldErrors('login', { source: 'zod' })
|
||||
expect(zodErrors.length).toBeGreaterThan(0)
|
||||
|
||||
// Should have no server errors yet
|
||||
const serverErrors = result.current.getFieldErrors('login', { source: 'server' })
|
||||
expect(serverErrors.length).toBe(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,824 @@
|
||||
/**
|
||||
* Cross-cutting integration tests for mizan
|
||||
*
|
||||
* Tests error paths and protocol correctness across HTTP, Forms, and WebSocket.
|
||||
* Requires a running backend: docker-compose up
|
||||
*
|
||||
* Run with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { ReactNode } from 'react'
|
||||
import { describeIntegration, BACKEND_URL, WS_URL } from '../testing'
|
||||
import { MizanProvider, useMizan } from '../context'
|
||||
import { DjangoError } from '../errors'
|
||||
import { ChannelConnection } from '../channels/connection'
|
||||
import { RPCError } from '../channels/connection'
|
||||
|
||||
function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<MizanProvider baseUrl={`${BACKEND_URL}/api/mizan`} autoConnect={false}>
|
||||
{children}
|
||||
</MizanProvider>
|
||||
)
|
||||
}
|
||||
|
||||
// Helper to get call function
|
||||
function useCall() {
|
||||
const { call } = useMizan()
|
||||
return call
|
||||
}
|
||||
|
||||
// Helper to wait for a ChannelConnection to reach 'connected' status
|
||||
function waitForConnected(connection: ChannelConnection): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (connection.status === 'connected') { resolve(); return }
|
||||
const unsub = connection.onStatusChange((status) => {
|
||||
if (status === 'connected') { unsub(); resolve() }
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Group 1: Executor framework validation
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Executor framework validation', () => {
|
||||
it('should return VALIDATION_ERROR with field details for wrong input types', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('add', { a: 'hello', b: 'world' })
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('VALIDATION_ERROR')
|
||||
expect(error!.isValidationError()).toBe(true)
|
||||
const fieldErrors = error!.getFieldErrors()
|
||||
expect(fieldErrors).not.toBeNull()
|
||||
expect(Object.keys(fieldErrors!).length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should return NOT_FOUND for non-existent function', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('this_function_does_not_exist', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('NOT_FOUND')
|
||||
})
|
||||
|
||||
it('should return FORBIDDEN for auth-required function when anonymous', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('whoami', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should return VALIDATION_ERROR with specific field for missing required input', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('echo', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('VALIDATION_ERROR')
|
||||
const fieldErrors = error!.getFieldErrors()
|
||||
expect(fieldErrors).not.toBeNull()
|
||||
expect(fieldErrors!).toHaveProperty('text')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 2: Form framework validation
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Form framework validation', () => {
|
||||
it('should return field metadata with types and required flags', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('login.schema', { data: {} })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('fields')
|
||||
|
||||
// Each field should have name, label, type, required, widget
|
||||
const fields = response.fields
|
||||
for (const fieldKey of Object.keys(fields)) {
|
||||
const field = fields[fieldKey]
|
||||
expect(field).toHaveProperty('name')
|
||||
expect(field).toHaveProperty('label')
|
||||
expect(field).toHaveProperty('type')
|
||||
expect(field).toHaveProperty('required')
|
||||
expect(field).toHaveProperty('widget')
|
||||
}
|
||||
|
||||
// login field should be required
|
||||
expect(fields.login.required).toBe(true)
|
||||
|
||||
// password field widget should contain 'password'
|
||||
expect(fields.password.widget.toLowerCase()).toContain('password')
|
||||
})
|
||||
|
||||
it('should return field-level errors for empty form validation', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('login.validate', { data: {} })
|
||||
})
|
||||
|
||||
expect(response.valid).toBe(false)
|
||||
expect(response.errors).toBeInstanceOf(Array)
|
||||
expect(response.errors.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should return form-level error for wrong login credentials', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('login.submit', {
|
||||
login: 'wrong@example.com',
|
||||
password: 'wrongpass',
|
||||
})
|
||||
})
|
||||
|
||||
expect(response.success).toBe(false)
|
||||
expect(JSON.stringify(response.errors)).toContain('Invalid login credentials')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 3: WebSocket framework validation
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('WebSocket framework validation', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(async () => {
|
||||
connection = new ChannelConnection({ url: WS_URL, reconnect: false })
|
||||
connection.connect()
|
||||
|
||||
// Wait for connected status
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsub = connection.onStatusChange((status) => {
|
||||
if (status === 'connected') {
|
||||
unsub()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
if (connection.status === 'connected') {
|
||||
unsub()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
it('should deliver messages back through channel subscription', async () => {
|
||||
// Subscribe to chat channel and wait for confirmation
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Subscribe timeout')), 5000)
|
||||
const unsub = connection.onMessage((msg) => {
|
||||
if ('subscribed' in msg && msg.channel === 'chat') {
|
||||
clearTimeout(timeout)
|
||||
unsub()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
connection.send({
|
||||
action: 'subscribe',
|
||||
channel: 'chat',
|
||||
params: { room: 'integration-test' },
|
||||
})
|
||||
})
|
||||
|
||||
// Listen for the echoed message
|
||||
const messagePromise = new Promise<any>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Message timeout')), 5000)
|
||||
const unsub = connection.onMessage((msg) => {
|
||||
if ('data' in msg) {
|
||||
clearTimeout(timeout)
|
||||
unsub()
|
||||
resolve(msg)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// Send a message
|
||||
connection.send({
|
||||
action: 'message',
|
||||
channel: 'chat',
|
||||
params: { room: 'integration-test' },
|
||||
data: { text: 'hello from integration test' },
|
||||
})
|
||||
|
||||
const received = await messagePromise
|
||||
expect(received.data).toHaveProperty('text')
|
||||
expect(received.data.text).toBe('hello from integration test')
|
||||
})
|
||||
|
||||
it('should return error for unknown channel subscription', async () => {
|
||||
const errorPromise = new Promise<any>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Error response timeout')), 5000)
|
||||
const unsub = connection.onMessage((msg) => {
|
||||
if ('error' in msg) {
|
||||
clearTimeout(timeout)
|
||||
unsub()
|
||||
resolve(msg)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
connection.send({
|
||||
action: 'subscribe',
|
||||
channel: 'nonexistent_channel',
|
||||
params: {},
|
||||
})
|
||||
|
||||
const errorMsg = await errorPromise
|
||||
expect(errorMsg.error.toLowerCase()).toContain('unknown channel')
|
||||
})
|
||||
|
||||
it('should reject HTTP-only function via WebSocket RPC', async () => {
|
||||
let rpcError: RPCError | null = null
|
||||
try {
|
||||
await connection.rpc('http_only_echo', { text: 'test' })
|
||||
} catch (e) {
|
||||
rpcError = e as RPCError
|
||||
}
|
||||
|
||||
expect(rpcError).toBeInstanceOf(RPCError)
|
||||
})
|
||||
|
||||
it('should return NOT_FOUND for non-existent RPC function', async () => {
|
||||
let rpcError: RPCError | null = null
|
||||
try {
|
||||
await connection.rpc('does_not_exist', {})
|
||||
} catch (e) {
|
||||
rpcError = e as RPCError
|
||||
}
|
||||
|
||||
expect(rpcError).toBeInstanceOf(RPCError)
|
||||
expect(rpcError!.code).toBe('NOT_FOUND')
|
||||
})
|
||||
|
||||
it('should return VALIDATION_ERROR for wrong RPC input types', async () => {
|
||||
let rpcError: RPCError | null = null
|
||||
try {
|
||||
await connection.rpc('add', { a: 'not_number', b: 'also_not' })
|
||||
} catch (e) {
|
||||
rpcError = e as RPCError
|
||||
}
|
||||
|
||||
expect(rpcError).toBeInstanceOf(RPCError)
|
||||
expect(rpcError!.code).toBe('VALIDATION_ERROR')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 4: HTTP happy path
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('HTTP happy path', () => {
|
||||
it('should call echo and receive echoed text', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('echo', { text: 'hello world' })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('message')
|
||||
expect(response.message).toContain('hello world')
|
||||
})
|
||||
|
||||
it('should call add and receive correct sum', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('add', { a: 17, b: 25 })
|
||||
})
|
||||
|
||||
expect(response).toEqual({ result: 42 })
|
||||
})
|
||||
|
||||
it('should call class-based ServerFunction (multiply)', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('multiply', { x: 7, y: 6 })
|
||||
})
|
||||
|
||||
expect(response).toEqual({ product: 42 })
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 5: Auth variations
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Auth variations', () => {
|
||||
it('should reject staff_only for anonymous with UNAUTHORIZED', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('staff_only', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('UNAUTHORIZED')
|
||||
expect(error!.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should reject superuser_only for anonymous with UNAUTHORIZED', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('superuser_only', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('UNAUTHORIZED')
|
||||
expect(error!.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should reject verified_only for anonymous (callable auth)', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('verified_only', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
// Callable auth returns False for anonymous, which maps to FORBIDDEN
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('FORBIDDEN')
|
||||
expect(error!.isAuthError()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 6: Context functions
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Context functions', () => {
|
||||
it('should call global context current_user and get anonymous response', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('current_user', {})
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('authenticated', false)
|
||||
expect(response).toHaveProperty('email', '')
|
||||
expect(response).toHaveProperty('is_staff', false)
|
||||
})
|
||||
|
||||
it('should call local context greet with name parameter', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('greet', { name: 'World' })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('greeting', 'Hello, World!')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 7: Error code coverage
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Error code coverage', () => {
|
||||
it('should return NOT_IMPLEMENTED for NotImplementedError', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('not_implemented_fn', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('NOT_IMPLEMENTED')
|
||||
})
|
||||
|
||||
it('should return INTERNAL_ERROR for unhandled RuntimeError', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('buggy_fn', {})
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('INTERNAL_ERROR')
|
||||
})
|
||||
|
||||
it('should return FORBIDDEN for PermissionError with wrong secret', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let error: DjangoError | null = null
|
||||
await act(async () => {
|
||||
try {
|
||||
await result.current('permission_check_fn', { secret: 'wrong' })
|
||||
} catch (e) {
|
||||
error = e as DjangoError
|
||||
}
|
||||
})
|
||||
|
||||
expect(error).toBeInstanceOf(DjangoError)
|
||||
expect(error!.code).toBe('FORBIDDEN')
|
||||
expect(error!.isAuthError()).toBe(true)
|
||||
})
|
||||
|
||||
it('should succeed with correct secret for permission_check_fn', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('permission_check_fn', { secret: 'open-sesame' })
|
||||
})
|
||||
|
||||
expect(response).toEqual({ message: 'access granted' })
|
||||
})
|
||||
|
||||
it('should return BAD_REQUEST for invalid JSON body', async () => {
|
||||
const response = await fetch(`${BACKEND_URL}/api/mizan/call/`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: '{not valid json}',
|
||||
})
|
||||
const data = await response.json()
|
||||
expect(data.error).toBe(true)
|
||||
expect(data.code).toBe('BAD_REQUEST')
|
||||
})
|
||||
|
||||
it('should return BAD_REQUEST for missing fn field', async () => {
|
||||
const response = await fetch(`${BACKEND_URL}/api/mizan/call/`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({ args: {} }),
|
||||
})
|
||||
const data = await response.json()
|
||||
expect(data.error).toBe(true)
|
||||
expect(data.code).toBe('BAD_REQUEST')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 8: mizanFormMixin integration
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('mizanFormMixin integration', () => {
|
||||
it('should return schema with title, subtitle, and submit_label from mizanFormMeta', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('contact.schema', { data: {} })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('fields')
|
||||
const fields = response.fields
|
||||
|
||||
// Contact form should have name, email, and message fields
|
||||
expect(fields).toHaveProperty('name')
|
||||
expect(fields).toHaveProperty('email')
|
||||
expect(fields).toHaveProperty('message')
|
||||
|
||||
// Meta should include title, subtitle, and submit_label
|
||||
expect(response).toHaveProperty('meta')
|
||||
expect(response.meta.title).toBe('Contact Us')
|
||||
expect(response.meta).toHaveProperty('subtitle')
|
||||
expect(response.meta.submit_label).toBe('Send Message')
|
||||
})
|
||||
|
||||
it('should return form meta with live_validation and live_form_errors', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('contact.schema', { data: {} })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('meta')
|
||||
expect(response.meta.live_validation).toBe(true)
|
||||
expect(response.meta.live_form_errors).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate contact form and return field errors for missing fields', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('contact.validate', { data: {} })
|
||||
})
|
||||
|
||||
expect(response.valid).toBe(false)
|
||||
expect(response.errors).toBeInstanceOf(Array)
|
||||
expect(response.errors.length).toBeGreaterThan(0)
|
||||
|
||||
// Should have errors for name, email, and message
|
||||
const errorFieldNames = response.errors.map((e: any) => e.field)
|
||||
expect(errorFieldNames).toContain('name')
|
||||
expect(errorFieldNames).toContain('email')
|
||||
expect(errorFieldNames).toContain('message')
|
||||
})
|
||||
|
||||
it('should submit contact form successfully and get on_submit_success data', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('contact.submit', {
|
||||
name: 'Test',
|
||||
email: 'test@test.com',
|
||||
message: 'Hello',
|
||||
})
|
||||
})
|
||||
|
||||
expect(response.success).toBe(true)
|
||||
expect(response.data).toHaveProperty('received', true)
|
||||
expect(response.data).toHaveProperty('from', 'test@test.com')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 9: Formset integration
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Formset integration', () => {
|
||||
it('should return formset schema for item form', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('item.formset.schema', { forms: [] })
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('min_num')
|
||||
expect(response).toHaveProperty('max_num')
|
||||
})
|
||||
|
||||
it('should validate formset with invalid data', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('item.formset.validate', {
|
||||
forms: [{ label: '', quantity: 0 }],
|
||||
})
|
||||
})
|
||||
|
||||
// Should have validation errors for the invalid form data
|
||||
expect(response).toHaveProperty('errors')
|
||||
})
|
||||
|
||||
it('should submit formset with valid data', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('item.formset.submit', {
|
||||
forms: [{ label: 'Widget', quantity: 5 }],
|
||||
})
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('success', true)
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 10: Channel authorization
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Channel authorization', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(async () => {
|
||||
connection = new ChannelConnection({ url: WS_URL, reconnect: false })
|
||||
connection.connect()
|
||||
await waitForConnected(connection)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
it('should reject subscription to private channel when anonymous', async () => {
|
||||
const msgPromise = new Promise<any>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Auth rejection timeout')), 5000)
|
||||
connection.onMessage((msg) => {
|
||||
if ('error' in msg) {
|
||||
clearTimeout(timeout)
|
||||
resolve(msg)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
connection.send({ action: 'subscribe', channel: 'private' })
|
||||
|
||||
const msg = await msgPromise
|
||||
expect(msg.error).toContain('Not authorized')
|
||||
})
|
||||
|
||||
it('should successfully unsubscribe from a channel', async () => {
|
||||
// First subscribe to chat
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Subscribe timeout')), 5000)
|
||||
const unsub = connection.onMessage((msg) => {
|
||||
if ('subscribed' in msg && msg.channel === 'chat') {
|
||||
clearTimeout(timeout)
|
||||
unsub()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
connection.send({
|
||||
action: 'subscribe',
|
||||
channel: 'chat',
|
||||
params: { room: 'unsub-test' },
|
||||
})
|
||||
})
|
||||
|
||||
// Now unsubscribe
|
||||
const unsubPromise = new Promise<any>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Unsubscribe timeout')), 5000)
|
||||
const unsub = connection.onMessage((msg) => {
|
||||
if ('unsubscribed' in msg && msg.channel === 'chat') {
|
||||
clearTimeout(timeout)
|
||||
unsub()
|
||||
resolve(msg)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
connection.send({
|
||||
action: 'unsubscribe',
|
||||
channel: 'chat',
|
||||
params: { room: 'unsub-test' },
|
||||
})
|
||||
|
||||
const unsubMsg = await unsubPromise
|
||||
expect(unsubMsg).toHaveProperty('unsubscribed', true)
|
||||
expect(unsubMsg).toHaveProperty('channel', 'chat')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 11: WebSocket RPC happy path
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('WebSocket RPC happy path', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(async () => {
|
||||
connection = new ChannelConnection({ url: WS_URL, reconnect: false })
|
||||
connection.connect()
|
||||
await waitForConnected(connection)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
it('should call echo via RPC and get correct response', async () => {
|
||||
const response = await connection.rpc<{ text: string }, { message: string }>(
|
||||
'echo',
|
||||
{ text: 'ws rpc echo' }
|
||||
)
|
||||
|
||||
expect(response).toHaveProperty('message')
|
||||
expect(response.message).toContain('ws rpc echo')
|
||||
})
|
||||
|
||||
it('should call add via RPC and get correct sum', async () => {
|
||||
const response = await connection.rpc<{ a: number; b: number }, { result: number }>(
|
||||
'add',
|
||||
{ a: 100, b: 200 }
|
||||
)
|
||||
|
||||
expect(response).toEqual({ result: 300 })
|
||||
})
|
||||
|
||||
it('should reject multiply via RPC if not websocket-enabled', async () => {
|
||||
// multiply uses @register_as which may not set websocket=True
|
||||
// If it's HTTP-only, RPC should fail; if it supports WS, it should succeed
|
||||
let response: any = null
|
||||
let rpcError: RPCError | null = null
|
||||
try {
|
||||
response = await connection.rpc<{ x: number; y: number }, { product: number }>(
|
||||
'multiply',
|
||||
{ x: 7, y: 6 }
|
||||
)
|
||||
} catch (e) {
|
||||
rpcError = e as RPCError
|
||||
}
|
||||
|
||||
// Either it succeeds with the correct product, or it fails because it's HTTP-only
|
||||
if (rpcError) {
|
||||
expect(rpcError).toBeInstanceOf(RPCError)
|
||||
} else {
|
||||
expect(response).toEqual({ product: 42 })
|
||||
}
|
||||
})
|
||||
|
||||
it('should reject ws_whoami via RPC when anonymous', async () => {
|
||||
let rpcError: RPCError | null = null
|
||||
try {
|
||||
await connection.rpc('ws_whoami', {})
|
||||
} catch (e) {
|
||||
rpcError = e as RPCError
|
||||
}
|
||||
|
||||
expect(rpcError).toBeInstanceOf(RPCError)
|
||||
expect(rpcError!.code).toBe('UNAUTHORIZED')
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Group 12: Successful form submit flow
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('Form submit success flow', () => {
|
||||
it('should sign up a new user via signup form', async () => {
|
||||
const { result } = renderHook(() => useCall(), { wrapper: Wrapper })
|
||||
|
||||
// Use a unique email per run to avoid duplicate-user errors
|
||||
const uniqueEmail = `newuser+${Date.now()}@example.com`
|
||||
|
||||
let response: any = null
|
||||
await act(async () => {
|
||||
response = await result.current('signup.submit', {
|
||||
email: uniqueEmail,
|
||||
password1: 'testpass123',
|
||||
})
|
||||
})
|
||||
|
||||
expect(response).toHaveProperty('success', true)
|
||||
expect(response).toHaveProperty('data')
|
||||
expect(response.data).toHaveProperty('user_id')
|
||||
expect(typeof response.data.user_id).toBe('number')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,11 @@
|
||||
/**
|
||||
* Re-export RouterAdapter from mizan/client.
|
||||
*
|
||||
* Allauth extends this with a required getParam method.
|
||||
*/
|
||||
import type { RouterAdapter as BaseRouterAdapter } from 'mizan/client'
|
||||
|
||||
export interface RouterAdapter extends BaseRouterAdapter {
|
||||
/** Get a specific route param (e.g., from /auth/[...path]) - required for allauth */
|
||||
getParam: (name: string) => string | string[] | undefined
|
||||
}
|
||||
309
packages/mizan-csr/adapters/react/src/allauth/api.ts
Normal file
309
packages/mizan-csr/adapters/react/src/allauth/api.ts
Normal file
@@ -0,0 +1,309 @@
|
||||
import { OAuthProcess, apiURL } from './defines'
|
||||
|
||||
import {
|
||||
type RegistrationResponseJSON,
|
||||
type AuthenticationResponseJSON,
|
||||
} from '@simplewebauthn/browser'
|
||||
|
||||
import type {
|
||||
// Core types
|
||||
AuthError,
|
||||
User,
|
||||
Flow,
|
||||
Authenticated,
|
||||
AuthenticationMeta,
|
||||
// Request types
|
||||
LoginRequest,
|
||||
SignupRequest,
|
||||
ProviderSignupRequest,
|
||||
ReauthenticateRequest,
|
||||
ChangePasswordRequest,
|
||||
ResetPasswordRequest,
|
||||
MFAAuthenticateRequest,
|
||||
WebAuthnUpdateRequest,
|
||||
// Response types
|
||||
AllauthResponse,
|
||||
AuthenticatedResponse,
|
||||
AuthenticationRequiredResponse,
|
||||
ReauthenticationRequiredResponse,
|
||||
ConfigurationResponse,
|
||||
EmailListResponse,
|
||||
SessionListResponse,
|
||||
AuthenticatorListResponse,
|
||||
ProviderAccountListResponse,
|
||||
TOTPStatusResponse,
|
||||
RecoveryCodesResponse,
|
||||
WebAuthnCreationOptionsResponse,
|
||||
WebAuthnRequestOptionsResponse,
|
||||
EmailVerificationInfoResponse,
|
||||
ErrorResponse,
|
||||
} from './types'
|
||||
|
||||
export type { AuthError } from './types'
|
||||
|
||||
// Registration = creating new credentials (signup, add)
|
||||
// Authentication = verifying existing credentials (login, authenticate, reauthenticate)
|
||||
type RegistrationCredential = RegistrationResponseJSON
|
||||
type AuthenticationCredential = AuthenticationResponseJSON
|
||||
|
||||
/**
|
||||
* Union of all possible auth responses
|
||||
*/
|
||||
export type AuthResponse =
|
||||
| AuthenticatedResponse
|
||||
| AuthenticationRequiredResponse
|
||||
| ReauthenticationRequiredResponse
|
||||
| ConfigurationResponse
|
||||
| EmailListResponse
|
||||
| SessionListResponse
|
||||
| AuthenticatorListResponse
|
||||
| ProviderAccountListResponse
|
||||
| TOTPStatusResponse
|
||||
| RecoveryCodesResponse
|
||||
| WebAuthnCreationOptionsResponse
|
||||
| WebAuthnRequestOptionsResponse
|
||||
| EmailVerificationInfoResponse
|
||||
| ErrorResponse
|
||||
| AllauthResponse
|
||||
|
||||
export interface AuthDetails {
|
||||
isAuthenticated: boolean
|
||||
requiresReauthentication: boolean
|
||||
user: User | null
|
||||
pendingFlow: Flow | undefined
|
||||
}
|
||||
|
||||
export const getAuthDetails = (auth: AllauthResponse | null | undefined): AuthDetails => {
|
||||
const meta = auth?.meta as AuthenticationMeta | undefined
|
||||
const isAuthenticated = !!auth && (auth?.status === 200 || (auth?.status === 401 && !!meta?.is_authenticated))
|
||||
const requiresReauthentication = !!(isAuthenticated && auth?.status === 401)
|
||||
const data = auth?.data as Authenticated | { flows?: Flow[]; user?: User } | undefined
|
||||
const pendingFlow = (data as { flows?: Flow[] })?.flows?.find((flow: Flow) => flow.is_pending)
|
||||
|
||||
return {
|
||||
isAuthenticated,
|
||||
requiresReauthentication,
|
||||
user: isAuthenticated ? (data as Authenticated)?.user ?? null : null,
|
||||
pendingFlow
|
||||
}
|
||||
}
|
||||
|
||||
export type BrowserFormAction = (action: string, data: Record<string, string>) => void
|
||||
|
||||
type RequestFn = (method: string, path: string, data?: unknown, headers?: Record<string, string>) => Promise<AllauthResponse>
|
||||
|
||||
export const createAPI = (
|
||||
request: RequestFn,
|
||||
browserFormAction?: BrowserFormAction
|
||||
) => {
|
||||
return {
|
||||
getConfig: async (): Promise<ConfigurationResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.CONFIG) as ConfigurationResponse | ErrorResponse,
|
||||
|
||||
session: {
|
||||
getStatus: async (): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.SESSION) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
list: async (): Promise<SessionListResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.SESSIONS) as SessionListResponse | ErrorResponse,
|
||||
|
||||
logout: async (): Promise<AllauthResponse> =>
|
||||
await request('DELETE', apiURL.SESSION),
|
||||
|
||||
remove: async (ids: number[]): Promise<AllauthResponse> =>
|
||||
await request('DELETE', apiURL.SESSIONS, { sessions: ids }),
|
||||
},
|
||||
|
||||
account: {
|
||||
signup: async (data: SignupRequest): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.SIGNUP, data) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
login: async (data: LoginRequest): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.LOGIN, data) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
reauthenticate: async (data: ReauthenticateRequest): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.REAUTHENTICATE, data) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
emails: {
|
||||
list: async (): Promise<EmailListResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.EMAIL) as EmailListResponse | ErrorResponse,
|
||||
|
||||
add: async (email: string): Promise<EmailListResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.EMAIL, { email }) as EmailListResponse | ErrorResponse,
|
||||
|
||||
remove: async (email: string): Promise<EmailListResponse | ErrorResponse> =>
|
||||
await request('DELETE', apiURL.EMAIL, { email }) as EmailListResponse | ErrorResponse,
|
||||
|
||||
setPrimary: async (email: string): Promise<EmailListResponse | ErrorResponse> =>
|
||||
await request('PATCH', apiURL.EMAIL, { email, primary: true }) as EmailListResponse | ErrorResponse,
|
||||
|
||||
verification: {
|
||||
dispatch: async (email: string): Promise<AllauthResponse> =>
|
||||
await request('PUT', apiURL.EMAIL, { email }),
|
||||
|
||||
checkKey: async (key: string): Promise<EmailVerificationInfoResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.VERIFY_EMAIL, undefined, { 'X-Email-Verification-Key': key }) as EmailVerificationInfoResponse | ErrorResponse,
|
||||
|
||||
confirmKey: async (key: string): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.VERIFY_EMAIL, { key }) as AuthenticatedResponse | ErrorResponse,
|
||||
}
|
||||
},
|
||||
|
||||
password: {
|
||||
set: async (data: ResetPasswordRequest): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.RESET_PASSWORD, data) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
change: async (data: ChangePasswordRequest): Promise<AllauthResponse> =>
|
||||
await request('POST', apiURL.CHANGE_PASSWORD, data),
|
||||
|
||||
reset: {
|
||||
dispatch: async (email: string): Promise<AllauthResponse> =>
|
||||
await request('POST', apiURL.REQUEST_PASSWORD_RESET, { email }),
|
||||
|
||||
checkKey: async (key: string): Promise<AllauthResponse> =>
|
||||
await request('GET', apiURL.RESET_PASSWORD, undefined, { 'X-Password-Reset-Key': key }),
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
loginCodes: {
|
||||
request: async (email: string): Promise<AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.REQUEST_LOGIN_CODE, { email }) as AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
confirm: async (code: string): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.CONFIRM_LOGIN_CODE, { code }) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
},
|
||||
|
||||
oauth: {
|
||||
list: async (): Promise<ProviderAccountListResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.PROVIDERS) as ProviderAccountListResponse | ErrorResponse,
|
||||
|
||||
signup: async (data: ProviderSignupRequest): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.PROVIDER_SIGNUP, data) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
provider: (providerID: string) => {
|
||||
const buildAuths = (processType: string) => {
|
||||
return {
|
||||
withToken: async (token: string): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request(
|
||||
'POST',
|
||||
apiURL.PROVIDER_TOKEN,
|
||||
{
|
||||
provider: providerID,
|
||||
process: processType,
|
||||
token: token,
|
||||
}
|
||||
) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
withRedirect: (endpoint: string): void => {
|
||||
if (browserFormAction) {
|
||||
if (!process.env.NEXT_PUBLIC_HOST_URL) {
|
||||
throw new Error('NEXT_PUBLIC_HOST_URL environment variable is not set. OAuth redirects require this to be set at build time.')
|
||||
}
|
||||
browserFormAction(
|
||||
apiURL.REDIRECT_TO_PROVIDER,
|
||||
{
|
||||
provider: providerID,
|
||||
process: processType,
|
||||
callback_url: new URL(`${process.env.NEXT_PUBLIC_HOST_URL}/${endpoint}`).toString(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
removeFrom: async (accountUID: string): Promise<ProviderAccountListResponse | ErrorResponse> =>
|
||||
await request('DELETE', apiURL.PROVIDERS, { provider: providerID, account: accountUID }) as ProviderAccountListResponse | ErrorResponse,
|
||||
|
||||
login: buildAuths(OAuthProcess.LOGIN),
|
||||
connect: buildAuths(OAuthProcess.CONNECT),
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
mfa: {
|
||||
list: async (): Promise<AuthenticatorListResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.AUTHENTICATORS) as AuthenticatorListResponse | ErrorResponse,
|
||||
|
||||
authenticate: async (code: string): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.MFA_AUTHENTICATE, { code } as MFAAuthenticateRequest) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
reauthenticate: async (code: string): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.MFA_REAUTHENTICATE, { code } as MFAAuthenticateRequest) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
trust: async (trust: boolean): Promise<AllauthResponse> =>
|
||||
await request('POST', apiURL.MFA_TRUST, { trust }),
|
||||
|
||||
totp: {
|
||||
getStatus: async (): Promise<TOTPStatusResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.TOTP_AUTHENTICATOR) as TOTPStatusResponse | ErrorResponse,
|
||||
|
||||
activate: async (code: string): Promise<TOTPStatusResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.TOTP_AUTHENTICATOR, { code }) as TOTPStatusResponse | ErrorResponse,
|
||||
|
||||
deactivate: async (): Promise<AllauthResponse> =>
|
||||
await request('DELETE', apiURL.TOTP_AUTHENTICATOR),
|
||||
},
|
||||
|
||||
recoveryCodes: {
|
||||
list: async (): Promise<RecoveryCodesResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.RECOVERY_CODES) as RecoveryCodesResponse | ErrorResponse,
|
||||
|
||||
regenerate: async (): Promise<RecoveryCodesResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.RECOVERY_CODES) as RecoveryCodesResponse | ErrorResponse,
|
||||
}
|
||||
},
|
||||
|
||||
webauthn: {
|
||||
signup: async (name: string, credential: RegistrationCredential): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('PUT', apiURL.SIGNUP_WEBAUTHN, { name, credential }) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
add: async (name: string, credential: RegistrationCredential): Promise<AllauthResponse> =>
|
||||
await request('POST', apiURL.WEBAUTHN_AUTHENTICATOR, { name, credential }),
|
||||
|
||||
login: async (credential: AuthenticationCredential): Promise<AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.LOGIN_WEBAUTHN, { credential }) as AuthenticatedResponse | AuthenticationRequiredResponse | ErrorResponse,
|
||||
|
||||
authenticate: async (credential: AuthenticationCredential): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.AUTHENTICATE_WEBAUTHN, { credential }) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
reauthenticate: async (credential: AuthenticationCredential): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('POST', apiURL.REAUTHENTICATE_WEBAUTHN, { credential }) as AuthenticatedResponse | ErrorResponse,
|
||||
|
||||
update: async (id: number, data: Omit<WebAuthnUpdateRequest, 'id'>): Promise<AllauthResponse> =>
|
||||
await request('PUT', apiURL.WEBAUTHN_AUTHENTICATOR, { id, ...data }),
|
||||
|
||||
delete: async (ids: number[]): Promise<AllauthResponse> =>
|
||||
await request('DELETE', apiURL.WEBAUTHN_AUTHENTICATOR, { authenticators: ids }),
|
||||
|
||||
passkey: {
|
||||
signup: async (email: string): Promise<AllauthResponse> =>
|
||||
await request('POST', apiURL.SIGNUP_WEBAUTHN, { email }),
|
||||
|
||||
confirm: async (): Promise<AuthenticatedResponse | ErrorResponse> =>
|
||||
await request('PUT', apiURL.SIGNUP_WEBAUTHN) as AuthenticatedResponse | ErrorResponse,
|
||||
},
|
||||
|
||||
requestOptions: {
|
||||
creation: async (passwordless: boolean): Promise<WebAuthnCreationOptionsResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.WEBAUTHN_AUTHENTICATOR + (passwordless ? '?passwordless' : '')) as WebAuthnCreationOptionsResponse | ErrorResponse,
|
||||
|
||||
creationAtSignup: async (): Promise<WebAuthnCreationOptionsResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.SIGNUP_WEBAUTHN) as WebAuthnCreationOptionsResponse | ErrorResponse,
|
||||
|
||||
login: async (): Promise<WebAuthnRequestOptionsResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.LOGIN_WEBAUTHN) as WebAuthnRequestOptionsResponse | ErrorResponse,
|
||||
|
||||
authentication: async (): Promise<WebAuthnRequestOptionsResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.AUTHENTICATE_WEBAUTHN) as WebAuthnRequestOptionsResponse | ErrorResponse,
|
||||
|
||||
reauthentication: async (): Promise<WebAuthnRequestOptionsResponse | ErrorResponse> =>
|
||||
await request('GET', apiURL.REAUTHENTICATE_WEBAUTHN) as WebAuthnRequestOptionsResponse | ErrorResponse,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export type AllauthAPI = ReturnType<typeof createAPI>
|
||||
@@ -0,0 +1,220 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useRouter } from '../contexts/RouterContext'
|
||||
import { useAllauthAPI } from '../contexts/APIContext'
|
||||
import { useAllauthConfig } from '../contexts/ConfigContext'
|
||||
import { DjangoFlowPaths } from '../config'
|
||||
import { AuthCard } from './AuthCard'
|
||||
import { AuthDjangoForm } from './AuthDjangoForm'
|
||||
|
||||
interface AllauthRouterProps {
|
||||
/** Called after successful completion of any flow */
|
||||
onComplete?: () => void
|
||||
/** Called when user wants to go back to login */
|
||||
onLoginClick?: () => void
|
||||
}
|
||||
|
||||
/**
|
||||
* AllauthRouter handles Django-initiated flows (email verification, password reset, OAuth).
|
||||
*
|
||||
* Mount this at a catch-all route matching your basePath config:
|
||||
* app/auth/[...path]/page.tsx -> <AllauthRouter />
|
||||
*
|
||||
* The path determines which flow to render:
|
||||
* /auth/verify-email/[key] -> Email verification
|
||||
* /auth/reset-password?key=xxx -> Password reset form
|
||||
* /auth/oauth/callback -> OAuth completion
|
||||
*/
|
||||
export function AllauthRouter({ onComplete, onLoginClick }: AllauthRouterProps) {
|
||||
const router = useRouter()
|
||||
const config = useAllauthConfig()
|
||||
|
||||
// Parse the path segments after basePath
|
||||
// The router provides getParam('path') which returns the catch-all segments
|
||||
const pathParam = router.getParam('path')
|
||||
const pathSegments = Array.isArray(pathParam) ? pathParam : pathParam ? [pathParam] : []
|
||||
const path = pathSegments.length > 0 ? `/${pathSegments.join('/')}` : '/'
|
||||
|
||||
// Determine which flow based on path
|
||||
if (path.startsWith(DjangoFlowPaths.VERIFY_EMAIL)) {
|
||||
const key = pathSegments[1] || router.searchParams.get('key')
|
||||
return (
|
||||
<EmailVerifyView
|
||||
verificationKey={key}
|
||||
onComplete={onComplete}
|
||||
onLoginClick={onLoginClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (path.startsWith(DjangoFlowPaths.RESET_PASSWORD)) {
|
||||
const key = pathSegments[1] || router.searchParams.get('key')
|
||||
return (
|
||||
<PasswordResetView
|
||||
resetKey={key}
|
||||
onComplete={onComplete}
|
||||
onLoginClick={onLoginClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (path.startsWith(DjangoFlowPaths.OAUTH_ERROR)) {
|
||||
return (
|
||||
<OAuthErrorView
|
||||
onLoginClick={onLoginClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Unknown path
|
||||
return (
|
||||
<AuthCard
|
||||
title="Not Found"
|
||||
subtitle="This page doesn't exist."
|
||||
footerLinks={onLoginClick ? [
|
||||
{ label: 'Back to Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Email Verification View
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
interface EmailVerifyViewProps {
|
||||
verificationKey: string | null | undefined
|
||||
onComplete?: () => void
|
||||
onLoginClick?: () => void
|
||||
}
|
||||
|
||||
function EmailVerifyView({ verificationKey, onComplete, onLoginClick }: EmailVerifyViewProps) {
|
||||
const api = useAllauthAPI()
|
||||
const [status, setStatus] = useState<'loading' | 'success' | 'error'>('loading')
|
||||
const [error, setError] = useState('')
|
||||
|
||||
useEffect(() => {
|
||||
if (!verificationKey) {
|
||||
setStatus('error')
|
||||
setError('Invalid verification link')
|
||||
return
|
||||
}
|
||||
|
||||
const verify = async () => {
|
||||
const res = await api.account.emails.verification.confirmKey(verificationKey)
|
||||
|
||||
if (res.status === 200) {
|
||||
setStatus('success')
|
||||
if (onComplete) {
|
||||
setTimeout(onComplete, 2000)
|
||||
}
|
||||
} else {
|
||||
setStatus('error')
|
||||
setError(res.errors?.[0]?.message || 'Invalid or expired verification link')
|
||||
}
|
||||
}
|
||||
|
||||
verify()
|
||||
}, [verificationKey, api, onComplete])
|
||||
|
||||
if (status === 'loading') {
|
||||
return <AuthCard title="" loading loadingText="Verifying your email..." />
|
||||
}
|
||||
|
||||
if (status === 'success') {
|
||||
return (
|
||||
<AuthCard
|
||||
title="Email Verified"
|
||||
subtitle="Your email has been verified successfully."
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<AuthCard
|
||||
title="Verification Failed"
|
||||
error={error}
|
||||
footerLinks={onLoginClick ? [
|
||||
{ label: 'Back to Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Password Reset View
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
interface PasswordResetViewProps {
|
||||
resetKey: string | null | undefined
|
||||
onComplete?: () => void
|
||||
onLoginClick?: () => void
|
||||
}
|
||||
|
||||
function PasswordResetView({ resetKey, onComplete, onLoginClick }: PasswordResetViewProps) {
|
||||
const [success, setSuccess] = useState(false)
|
||||
|
||||
if (!resetKey) {
|
||||
return (
|
||||
<AuthCard
|
||||
title="Invalid Link"
|
||||
subtitle="This password reset link is invalid or has expired."
|
||||
footerLinks={onLoginClick ? [
|
||||
{ label: 'Back to Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (success) {
|
||||
return (
|
||||
<AuthCard
|
||||
title="Password Changed"
|
||||
subtitle="Your password has been successfully reset."
|
||||
footerLinks={onLoginClick ? [
|
||||
{ label: 'Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="reset_password_from_key"
|
||||
onSuccess={() => {
|
||||
setSuccess(true)
|
||||
// Give user time to see success message before redirect
|
||||
if (onComplete) {
|
||||
setTimeout(onComplete, 2000)
|
||||
}
|
||||
}}
|
||||
footerLinks={onLoginClick ? [
|
||||
{ href: '#', label: 'Back to Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// OAuth Error View
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
interface OAuthErrorViewProps {
|
||||
onLoginClick?: () => void
|
||||
}
|
||||
|
||||
function OAuthErrorView({ onLoginClick }: OAuthErrorViewProps) {
|
||||
const router = useRouter()
|
||||
const error = router.searchParams.get('error') || 'An error occurred during authentication'
|
||||
|
||||
return (
|
||||
<AuthCard
|
||||
title="Authentication Failed"
|
||||
error={error}
|
||||
footerLinks={onLoginClick ? [
|
||||
{ label: 'Back to Sign In', onClick: onLoginClick },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,447 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect, useRef } from 'react'
|
||||
import { useAuth, useAuthContext, useFeatures } from '../contexts/AuthContext'
|
||||
import { useAllauthAPI } from '../contexts/APIContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
import { getAuthDetails } from '../api'
|
||||
import { AuthenticatorType } from '../defines'
|
||||
import { AuthSettings } from './settings/AuthSettings'
|
||||
import { AuthCard } from './AuthCard'
|
||||
import { AuthDjangoForm } from './AuthDjangoForm'
|
||||
import { Button } from './settings/SettingsComponents'
|
||||
import { LoginView } from './views/LoginView'
|
||||
import { SignupView } from './views/SignupView'
|
||||
import { MFAChooserView } from './views/MFAChooserView'
|
||||
import { MFAWebAuthnView } from './views/MFAWebAuthnView'
|
||||
import { MFATOTPView } from './views/MFATOTPView'
|
||||
import { MFARecoveryCodesView } from './views/MFARecoveryCodesView'
|
||||
|
||||
/**
|
||||
* All possible views in the AllauthUI component.
|
||||
* Views are rendered based on state, not URLs.
|
||||
*/
|
||||
export type AllauthUIView =
|
||||
// Auth views (for unauthenticated users)
|
||||
| 'login'
|
||||
| 'signup'
|
||||
| 'resetPassword'
|
||||
| 'resetPasswordSent'
|
||||
| 'requestCode'
|
||||
| 'confirmCode'
|
||||
// MFA views (during auth flow)
|
||||
| 'mfaChooser'
|
||||
| 'mfaTotp'
|
||||
| 'mfaWebauthn'
|
||||
| 'mfaRecoveryCodes'
|
||||
// Authenticated views
|
||||
| 'settings'
|
||||
| 'logout'
|
||||
|
||||
/**
|
||||
* Controls how AllauthUI behaves regarding auth/settings transitions.
|
||||
*
|
||||
* - `'auto'` (default): Full SPA - shows auth views when not authenticated,
|
||||
* automatically transitions to settings after login, and back to login after logout.
|
||||
*
|
||||
* - `'auth'`: Auth-only mode - only shows auth views (login, signup, MFA, etc.).
|
||||
* Never shows settings. Use `onAuthenticated` to handle post-login navigation.
|
||||
* Ideal for a dedicated login page.
|
||||
*
|
||||
* - `'settings'`: Settings-only mode - only shows settings views.
|
||||
* If not authenticated, calls `onUnauthenticated` or shows nothing.
|
||||
* Ideal for a dedicated settings page.
|
||||
*/
|
||||
export type AllauthUIMode = 'auto' | 'auth' | 'settings'
|
||||
|
||||
interface AllauthUIProps {
|
||||
/**
|
||||
* Controls auth/settings transition behavior.
|
||||
* @default 'auto'
|
||||
*/
|
||||
mode?: AllauthUIMode
|
||||
|
||||
/**
|
||||
* Initial view when component mounts (for 'auto' and 'auth' modes).
|
||||
* Defaults to 'login' for unauthenticated, 'settings' for authenticated (in auto mode).
|
||||
*/
|
||||
initialView?: AllauthUIView
|
||||
|
||||
/**
|
||||
* Called when authentication completes successfully.
|
||||
* Required for 'auth' mode to handle post-login navigation.
|
||||
*/
|
||||
onAuthenticated?: () => void
|
||||
|
||||
/**
|
||||
* Called when user is not authenticated (for 'settings' mode).
|
||||
* Use this to redirect to login page.
|
||||
*/
|
||||
onUnauthenticated?: () => void
|
||||
|
||||
/**
|
||||
* Called when user logs out.
|
||||
* In 'auto' mode, defaults to showing login view.
|
||||
*/
|
||||
onLogout?: () => void
|
||||
|
||||
/**
|
||||
* Which settings sections to show.
|
||||
* Defaults to all sections.
|
||||
*/
|
||||
settingsSections?: Array<'profile' | 'emails' | 'password' | 'passkeys' | 'connections' | 'mfa' | 'sessions'>
|
||||
|
||||
/**
|
||||
* OAuth callback URL for social login providers.
|
||||
*/
|
||||
oauthCallbackUrl?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* AllauthUI is the main component for rendering auth UI.
|
||||
*
|
||||
* It can operate in three modes:
|
||||
* - `'auto'` (default): Full SPA handling login, MFA, settings, and logout
|
||||
* - `'auth'`: Auth-only for dedicated login pages
|
||||
* - `'settings'`: Settings-only for dedicated settings pages
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Full SPA mode (default) - handles everything
|
||||
* <AllauthUI />
|
||||
*
|
||||
* // Auth-only mode - for a dedicated login page
|
||||
* <AllauthUI mode="auth" onAuthenticated={() => router.push('/dashboard')} />
|
||||
*
|
||||
* // Settings-only mode - for a dedicated settings page
|
||||
* <AllauthUI mode="settings" onUnauthenticated={() => router.push('/login')} />
|
||||
* ```
|
||||
*/
|
||||
export function AllauthUI({
|
||||
mode = 'auto',
|
||||
initialView,
|
||||
onAuthenticated,
|
||||
onUnauthenticated,
|
||||
onLogout,
|
||||
settingsSections,
|
||||
oauthCallbackUrl,
|
||||
}: AllauthUIProps) {
|
||||
const { isAuthenticated, pendingFlow } = useAuth()
|
||||
const { refresh } = useAuthContext()
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
const features = useFeatures()
|
||||
|
||||
// Get available MFA types from pending flow
|
||||
const mfaTypes = pendingFlow?.types || []
|
||||
|
||||
// Internal view state
|
||||
const [view, setView] = useState<AllauthUIView>(() => {
|
||||
if (initialView) return initialView
|
||||
|
||||
// Settings mode always starts at settings
|
||||
if (mode === 'settings') return 'settings'
|
||||
|
||||
// Auth mode always starts at login (or MFA if pending)
|
||||
if (mode === 'auth') {
|
||||
if (pendingFlow) {
|
||||
return mfaTypes.length === 1 ? getMFAView(mfaTypes[0]) : 'mfaChooser'
|
||||
}
|
||||
return 'login'
|
||||
}
|
||||
|
||||
// Auto mode: settings if authenticated, login otherwise
|
||||
if (isAuthenticated) return 'settings'
|
||||
if (pendingFlow) {
|
||||
return mfaTypes.length === 1 ? getMFAView(mfaTypes[0]) : 'mfaChooser'
|
||||
}
|
||||
return 'login'
|
||||
})
|
||||
|
||||
// Track auth state changes
|
||||
const wasAuthenticated = useRef(isAuthenticated)
|
||||
const hadPendingFlow = useRef(!!pendingFlow)
|
||||
|
||||
// Handle auth state transitions
|
||||
useEffect(() => {
|
||||
// User just became authenticated
|
||||
if (!wasAuthenticated.current && isAuthenticated) {
|
||||
if (onAuthenticated) {
|
||||
onAuthenticated()
|
||||
} else if (mode === 'auto') {
|
||||
setView('settings')
|
||||
}
|
||||
// In 'auth' mode without onAuthenticated, do nothing (stay on current view)
|
||||
}
|
||||
|
||||
// User just logged out
|
||||
if (wasAuthenticated.current && !isAuthenticated) {
|
||||
if (onLogout) {
|
||||
onLogout()
|
||||
} else if (mode === 'auto') {
|
||||
setView('login')
|
||||
} else if (mode === 'settings' && onUnauthenticated) {
|
||||
onUnauthenticated()
|
||||
}
|
||||
}
|
||||
|
||||
wasAuthenticated.current = isAuthenticated
|
||||
}, [isAuthenticated, onAuthenticated, onUnauthenticated, onLogout, mode])
|
||||
|
||||
// Handle MFA flow transitions
|
||||
useEffect(() => {
|
||||
if (pendingFlow && !hadPendingFlow.current) {
|
||||
// New MFA flow started
|
||||
if (mfaTypes.length === 1) {
|
||||
setView(getMFAView(mfaTypes[0]))
|
||||
} else if (mfaTypes.length > 1) {
|
||||
setView('mfaChooser')
|
||||
}
|
||||
}
|
||||
if (!pendingFlow && hadPendingFlow.current && isAuthenticated) {
|
||||
// MFA completed successfully
|
||||
if (onAuthenticated) {
|
||||
onAuthenticated()
|
||||
} else if (mode === 'auto') {
|
||||
setView('settings')
|
||||
}
|
||||
}
|
||||
hadPendingFlow.current = !!pendingFlow
|
||||
}, [pendingFlow, mfaTypes, isAuthenticated, onAuthenticated, mode])
|
||||
|
||||
// Settings mode: handle unauthenticated state
|
||||
useEffect(() => {
|
||||
if (mode === 'settings' && !isAuthenticated && onUnauthenticated) {
|
||||
onUnauthenticated()
|
||||
}
|
||||
}, [mode, isAuthenticated, onUnauthenticated])
|
||||
|
||||
// Handle logout
|
||||
const handleLogout = async () => {
|
||||
await api.session.logout()
|
||||
await refresh()
|
||||
if (onLogout) {
|
||||
onLogout()
|
||||
} else if (mode === 'auto') {
|
||||
setView('login')
|
||||
}
|
||||
// In settings mode, the useEffect will call onUnauthenticated
|
||||
}
|
||||
|
||||
// Called after successful login/signup - check for MFA or complete auth
|
||||
const handleAuthSuccess = async () => {
|
||||
const newAuth = await refresh()
|
||||
const details = getAuthDetails(newAuth)
|
||||
|
||||
// If fully authenticated, handle completion
|
||||
if (details.isAuthenticated) {
|
||||
if (onAuthenticated) {
|
||||
onAuthenticated()
|
||||
} else if (mode === 'auto') {
|
||||
setView('settings')
|
||||
}
|
||||
// In 'auth' mode without onAuthenticated, stay on current view
|
||||
}
|
||||
// If MFA pending, the useEffect will handle the view transition
|
||||
}
|
||||
|
||||
// Render based on current view
|
||||
switch (view) {
|
||||
// ============================================
|
||||
// Authenticated views
|
||||
// ============================================
|
||||
case 'settings':
|
||||
// In auth mode, never show settings
|
||||
if (mode === 'auth') {
|
||||
return null
|
||||
}
|
||||
// Not authenticated - handle based on mode
|
||||
if (!isAuthenticated) {
|
||||
if (mode === 'settings' && onUnauthenticated) {
|
||||
// Will be handled by useEffect
|
||||
return null
|
||||
}
|
||||
// Auto mode: switch to login
|
||||
setView('login')
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<AuthSettings
|
||||
sections={settingsSections}
|
||||
onSignOut={() => setView('logout')}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'logout':
|
||||
if (!isAuthenticated) {
|
||||
if (mode === 'auto') {
|
||||
setView('login')
|
||||
}
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<AuthCard
|
||||
title="Sign Out"
|
||||
subtitle="Are you sure you want to sign out?"
|
||||
footerLinks={[
|
||||
{ label: 'Cancel', onClick: () => setView('settings') },
|
||||
]}
|
||||
>
|
||||
<div className={styles.form}>
|
||||
<Button onClick={handleLogout}>
|
||||
Sign Out
|
||||
</Button>
|
||||
</div>
|
||||
</AuthCard>
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// MFA views
|
||||
// ============================================
|
||||
case 'mfaChooser':
|
||||
return (
|
||||
<MFAChooserView
|
||||
types={mfaTypes}
|
||||
onSuccess={handleAuthSuccess}
|
||||
onCancel={() => setView('login')}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'mfaTotp':
|
||||
return (
|
||||
<MFATOTPView
|
||||
onSuccess={handleAuthSuccess}
|
||||
onCancel={() => setView('login')}
|
||||
onBack={mfaTypes.length > 1 ? () => setView('mfaChooser') : undefined}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'mfaWebauthn':
|
||||
return (
|
||||
<MFAWebAuthnView
|
||||
onSuccess={handleAuthSuccess}
|
||||
onCancel={() => setView('login')}
|
||||
onBack={mfaTypes.length > 1 ? () => setView('mfaChooser') : undefined}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'mfaRecoveryCodes':
|
||||
return (
|
||||
<MFARecoveryCodesView
|
||||
onSuccess={handleAuthSuccess}
|
||||
onCancel={() => setView('login')}
|
||||
onBack={mfaTypes.length > 1 ? () => setView('mfaChooser') : undefined}
|
||||
/>
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// Password reset views
|
||||
// ============================================
|
||||
case 'resetPassword':
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="reset_password"
|
||||
onSuccess={() => setView('resetPasswordSent')}
|
||||
footerLinks={[
|
||||
{ label: 'Back to Sign In', onClick: () => setView('login') },
|
||||
]}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'resetPasswordSent':
|
||||
return (
|
||||
<AuthCard
|
||||
title="Check Your Email"
|
||||
subtitle="If an account exists with that email, we've sent password reset instructions."
|
||||
footerLinks={[
|
||||
{ label: 'Back to Sign In', onClick: () => setView('login') },
|
||||
]}
|
||||
/>
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// Login by code views
|
||||
// ============================================
|
||||
case 'requestCode':
|
||||
// If login by code is disabled, redirect to login
|
||||
if (!features.loginByCodeEnabled) {
|
||||
setView('login')
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="request_login_code"
|
||||
onSuccess={() => setView('confirmCode')}
|
||||
footerLinks={[
|
||||
{ label: 'Sign in with password instead', onClick: () => setView('login') },
|
||||
]}
|
||||
/>
|
||||
)
|
||||
|
||||
case 'confirmCode':
|
||||
// If login by code is disabled, redirect to login
|
||||
if (!features.loginByCodeEnabled) {
|
||||
setView('login')
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="confirm_login_code"
|
||||
onSuccess={handleAuthSuccess}
|
||||
footerLinks={[
|
||||
{ label: 'Request a new code', onClick: () => setView('requestCode') },
|
||||
{ label: 'Sign in with password instead', onClick: () => setView('login') },
|
||||
]}
|
||||
/>
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// Signup view
|
||||
// ============================================
|
||||
case 'signup':
|
||||
// If signup is disabled, redirect to login
|
||||
if (!features.signupEnabled) {
|
||||
setView('login')
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<SignupView
|
||||
onSuccess={handleAuthSuccess}
|
||||
onLoginClick={() => setView('login')}
|
||||
/>
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// Login view (default)
|
||||
// ============================================
|
||||
case 'login':
|
||||
default:
|
||||
return (
|
||||
<LoginView
|
||||
onSuccess={handleAuthSuccess}
|
||||
// Only provide signup callback if signups are enabled
|
||||
onSignupClick={features.signupEnabled ? () => setView('signup') : undefined}
|
||||
onForgotPasswordClick={() => setView('resetPassword')}
|
||||
// Only provide login-by-code callback if feature is enabled
|
||||
onLoginByCodeClick={features.loginByCodeEnabled ? () => setView('requestCode') : undefined}
|
||||
oauthCallbackUrl={oauthCallbackUrl}
|
||||
/>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the view name for a given MFA authenticator type.
|
||||
*/
|
||||
function getMFAView(type: string): AllauthUIView {
|
||||
switch (type) {
|
||||
case AuthenticatorType.TOTP:
|
||||
return 'mfaTotp'
|
||||
case AuthenticatorType.WEBAUTHN:
|
||||
return 'mfaWebauthn'
|
||||
case AuthenticatorType.RECOVERY_CODES:
|
||||
return 'mfaRecoveryCodes'
|
||||
default:
|
||||
return 'mfaChooser'
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
'use client'
|
||||
|
||||
import { ReactNode } from 'react'
|
||||
import { useRouter } from '../contexts/RouterContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
|
||||
interface FooterLink {
|
||||
label: string
|
||||
href?: string
|
||||
onClick?: () => void
|
||||
}
|
||||
|
||||
interface AuthCardProps {
|
||||
title: string
|
||||
subtitle?: string
|
||||
children?: ReactNode
|
||||
footerLinks?: FooterLink[]
|
||||
error?: string
|
||||
success?: string
|
||||
loading?: boolean
|
||||
loadingText?: string
|
||||
}
|
||||
|
||||
export function AuthCard({
|
||||
title,
|
||||
subtitle,
|
||||
children,
|
||||
footerLinks,
|
||||
error,
|
||||
success,
|
||||
loading,
|
||||
loadingText = 'Loading...',
|
||||
}: AuthCardProps) {
|
||||
const router = useRouter()
|
||||
const styles = useStyles()
|
||||
|
||||
const handleLinkClick = (e: React.MouseEvent, href: string) => {
|
||||
e.preventDefault()
|
||||
router.push(href)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
{loading ? (
|
||||
<div className={styles.loading}>
|
||||
<div className={styles.spinner} />
|
||||
<p className={styles.subtitle}>{loadingText}</p>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<h1 className={styles.title}>{title}</h1>
|
||||
{subtitle && <p className={styles.subtitle}>{subtitle}</p>}
|
||||
|
||||
{error && <div className={styles.error}>{error}</div>}
|
||||
{success && <div className={styles.success}>{success}</div>}
|
||||
|
||||
{children}
|
||||
|
||||
{footerLinks && footerLinks.length > 0 && (
|
||||
<div className={styles.footer}>
|
||||
{footerLinks.map((link, i) => (
|
||||
link.onClick ? (
|
||||
<button key={i} onClick={link.onClick} className={styles.link}>
|
||||
{link.label}
|
||||
</button>
|
||||
) : link.href ? (
|
||||
<a
|
||||
key={i}
|
||||
href={link.href}
|
||||
onClick={(e) => handleLinkClick(e, link.href!)}
|
||||
className={styles.link}
|
||||
>
|
||||
{link.label}
|
||||
</a>
|
||||
) : null
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,326 @@
|
||||
'use client'
|
||||
|
||||
import { FormEvent, useEffect, useState } from 'react'
|
||||
import {
|
||||
useDjangoFormCore,
|
||||
type DjangoFormState,
|
||||
type FormOptions,
|
||||
type FormErrors,
|
||||
} from 'mizan'
|
||||
import { useAuthContext } from '../contexts/AuthContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
import { getAuthDetails, AuthDetails } from '../api'
|
||||
|
||||
interface FooterLink {
|
||||
label: string
|
||||
href?: string
|
||||
onClick?: () => void
|
||||
}
|
||||
|
||||
interface AuthDjangoFormProps {
|
||||
/** Form name (e.g., "login", "signup", "change_password") */
|
||||
formName: string
|
||||
/** Callback after successful form submission */
|
||||
onSuccess?: (result: any, authDetails: AuthDetails) => void
|
||||
/** Callback after failed form submission */
|
||||
onError?: (errors: any) => void
|
||||
/** Links to show in footer (e.g., "Forgot password?") */
|
||||
footerLinks?: FooterLink[]
|
||||
/** Content to render before form fields */
|
||||
preFields?: React.ReactNode
|
||||
/** Content to render after form fields (before submit button) */
|
||||
postFields?: React.ReactNode
|
||||
/** Override the submit button label from schema */
|
||||
submitLabel?: string
|
||||
/** Override the title from schema */
|
||||
title?: string
|
||||
/** Override the subtitle from schema */
|
||||
subtitle?: string
|
||||
/** Options for form behavior (validation, schema refetch, etc.) */
|
||||
formOptions?: FormOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* AuthDjangoForm renders a form from the mizan server functions
|
||||
* with styling consistent with the auth UI.
|
||||
*
|
||||
* It fetches the form schema (including title, subtitle, fields, submit label)
|
||||
* from the backend and renders it dynamically with real-time validation.
|
||||
*/
|
||||
export function AuthDjangoForm({
|
||||
formName,
|
||||
onSuccess,
|
||||
onError,
|
||||
footerLinks,
|
||||
preFields,
|
||||
postFields,
|
||||
submitLabel,
|
||||
title,
|
||||
subtitle,
|
||||
formOptions,
|
||||
}: AuthDjangoFormProps) {
|
||||
const form = useDjangoFormCore<Record<string, unknown>>({
|
||||
name: formName,
|
||||
options: formOptions,
|
||||
})
|
||||
const { refresh } = useAuthContext()
|
||||
const styles = useStyles()
|
||||
const [mounted, setMounted] = useState(false)
|
||||
|
||||
// Hydration safety: only render inputs after mount
|
||||
useEffect(() => {
|
||||
setMounted(true)
|
||||
}, [])
|
||||
|
||||
const handleSubmit = async (e: FormEvent) => {
|
||||
e.preventDefault()
|
||||
const result = await form.submit()
|
||||
|
||||
if (result.success) {
|
||||
// Refresh auth state and get the updated auth for callbacks
|
||||
const newAuth = await refresh()
|
||||
onSuccess?.(result.data, getAuthDetails(newAuth))
|
||||
} else {
|
||||
onError?.(result.errors)
|
||||
}
|
||||
}
|
||||
|
||||
// Loading state
|
||||
if (form.loading) {
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
<div className={styles.loading}>
|
||||
<div className={styles.spinner} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Get form-level errors (non-field errors like "Invalid credentials")
|
||||
// These only appear after submission due to 'field-only' default
|
||||
const formErrors = form.getFormErrors()
|
||||
|
||||
// Use prop overrides or schema values
|
||||
const displayTitle = title ?? form.schema?.title
|
||||
const displaySubtitle = subtitle ?? form.schema?.subtitle
|
||||
const displaySubmitLabel = submitLabel ?? form.schema?.submit_label ?? 'Submit'
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
{displayTitle && (
|
||||
<h1 className={styles.title}>{displayTitle}</h1>
|
||||
)}
|
||||
{displaySubtitle && (
|
||||
<p className={styles.subtitle}>{displaySubtitle}</p>
|
||||
)}
|
||||
|
||||
{/* Form-level errors (shown after submission) */}
|
||||
{formErrors.length > 0 && (
|
||||
<div className={styles.error}>
|
||||
{formErrors.map((err, i) => (
|
||||
<p key={i}>{err.message}</p>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form onSubmit={handleSubmit} className={styles.form}>
|
||||
{preFields}
|
||||
|
||||
<div className={styles.fieldsContainer}>
|
||||
{form.schema?.fieldOrder.map(fieldName => {
|
||||
const field = form.schema!.fields[fieldName]
|
||||
return (
|
||||
<AuthField
|
||||
key={fieldName}
|
||||
field={{
|
||||
name: fieldName,
|
||||
label: field.label,
|
||||
type: field.type,
|
||||
widget: field.widget,
|
||||
required: field.required,
|
||||
disabled: field.disabled,
|
||||
help_text: field.help_text,
|
||||
max_length: field.max_length,
|
||||
choices: field.choices,
|
||||
}}
|
||||
value={form.data[fieldName]}
|
||||
mounted={mounted}
|
||||
touched={form.touchedFields.has(fieldName)}
|
||||
errors={form.getFieldErrors(fieldName)}
|
||||
onChange={(value) => form.set(fieldName, value)}
|
||||
onBlur={() => form.touch(fieldName)}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{postFields}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={form.submitting || form.validating}
|
||||
className={styles.submit}
|
||||
>
|
||||
{form.submitting ? 'Submitting...' : displaySubmitLabel}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
{footerLinks && footerLinks.length > 0 && (
|
||||
<div className={styles.footer}>
|
||||
{footerLinks.map((link, i) => (
|
||||
link.onClick ? (
|
||||
<button key={i} type="button" onClick={link.onClick} className={styles.link}>
|
||||
{link.label}
|
||||
</button>
|
||||
) : link.href ? (
|
||||
<a key={i} href={link.href} className={styles.link}>
|
||||
{link.label}
|
||||
</a>
|
||||
) : null
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal field component with hydration-safe rendering
|
||||
*/
|
||||
interface AuthFieldProps {
|
||||
field: {
|
||||
name: string
|
||||
label: string
|
||||
type: string
|
||||
widget: string
|
||||
required: boolean
|
||||
disabled: boolean
|
||||
help_text: string
|
||||
max_length?: number | null
|
||||
choices?: Array<{ value: string; label: string }> | null
|
||||
}
|
||||
value: any
|
||||
mounted: boolean
|
||||
touched: boolean
|
||||
errors: Array<{ message: string }>
|
||||
onChange: (value: any) => void
|
||||
onBlur: () => void
|
||||
}
|
||||
|
||||
function AuthField({ field, value, mounted, touched, errors, onChange, onBlur }: AuthFieldProps) {
|
||||
const styles = useStyles()
|
||||
|
||||
const renderInput = () => {
|
||||
// Select dropdown
|
||||
if (field.choices && (field.widget === 'Select' || field.type === 'select')) {
|
||||
return (
|
||||
<select
|
||||
value={value || ''}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onBlur={onBlur}
|
||||
required={field.required}
|
||||
disabled={field.disabled}
|
||||
className={styles.fieldInput}
|
||||
>
|
||||
{field.choices.map((choice) => (
|
||||
<option key={choice.value} value={choice.value}>
|
||||
{choice.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
)
|
||||
}
|
||||
|
||||
// Radio buttons
|
||||
if (field.choices && field.widget === 'RadioSelect') {
|
||||
return (
|
||||
<div className={styles.radioGroup}>
|
||||
{field.choices.map((choice) => (
|
||||
<label key={choice.value} className={styles.radioItem}>
|
||||
<input
|
||||
type="radio"
|
||||
name={field.name}
|
||||
value={choice.value}
|
||||
checked={value === choice.value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onBlur={onBlur}
|
||||
required={field.required}
|
||||
disabled={field.disabled}
|
||||
/>
|
||||
<span>{choice.label}</span>
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Checkbox
|
||||
if (field.type === 'checkbox') {
|
||||
return (
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={!!value}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
onBlur={onBlur}
|
||||
required={field.required}
|
||||
disabled={field.disabled}
|
||||
className={styles.checkbox}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Textarea
|
||||
if (field.widget === 'Textarea') {
|
||||
return (
|
||||
<textarea
|
||||
value={value || ''}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onBlur={onBlur}
|
||||
required={field.required}
|
||||
disabled={field.disabled}
|
||||
maxLength={field.max_length || undefined}
|
||||
className={styles.fieldInput}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Default: text input (text, password, email, etc.)
|
||||
return (
|
||||
<input
|
||||
type={field.type}
|
||||
value={value || ''}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onBlur={onBlur}
|
||||
required={field.required}
|
||||
disabled={field.disabled}
|
||||
maxLength={field.max_length || undefined}
|
||||
className={styles.fieldInput}
|
||||
autoComplete="off"
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.field}>
|
||||
<label className={styles.fieldLabel}>
|
||||
{field.label}
|
||||
</label>
|
||||
|
||||
{/* Hydration-safe: render placeholder until mounted */}
|
||||
{mounted ? (
|
||||
renderInput()
|
||||
) : (
|
||||
<div className={styles.fieldInput} style={{ minHeight: '2.75rem' }} />
|
||||
)}
|
||||
|
||||
{/* Field errors (only show if touched) */}
|
||||
{touched && errors.map((err, i) => (
|
||||
<p key={i} className={styles.fieldError}>{err.message}</p>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
'use client'
|
||||
|
||||
import { ReactNode, useState, useEffect } from 'react'
|
||||
import { AuthDetails, AuthError, AuthResponse, getAuthDetails } from '../api'
|
||||
import { useAuthContext } from '../contexts/AuthContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
|
||||
interface AuthForm {
|
||||
submit: () => void
|
||||
authDetails: AuthDetails
|
||||
fetching: boolean
|
||||
response: AuthResponse | null
|
||||
errors: AuthError[]
|
||||
}
|
||||
|
||||
export default function useAuthForm(
|
||||
submissionAction: () => Promise<AuthResponse>,
|
||||
responseAction?: (response: AuthResponse, authDetails: AuthDetails) => void,
|
||||
): AuthForm {
|
||||
const auth = useAuthContext().auth
|
||||
const [fetching, setFetching] = useState<boolean>(false)
|
||||
const [response, setResponse] = useState<AuthResponse | null>(null)
|
||||
const [errors, setErrors] = useState<AuthError[]>([])
|
||||
const [authDetails, setAuthDetails] = useState<AuthDetails>(getAuthDetails(auth))
|
||||
|
||||
function submit() {
|
||||
setFetching(true)
|
||||
submissionAction()
|
||||
.then((r) => {
|
||||
setResponse(r)
|
||||
setErrors(r.errors || [])
|
||||
setFetching(false)
|
||||
if (r && responseAction) {
|
||||
responseAction(r, authDetails)
|
||||
}
|
||||
setAuthDetails(getAuthDetails(auth))
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error(e)
|
||||
setFetching(false)
|
||||
})
|
||||
}
|
||||
|
||||
return { submit, authDetails, fetching, response, errors }
|
||||
}
|
||||
|
||||
interface AuthFieldProps {
|
||||
title: string
|
||||
name: string
|
||||
type: string
|
||||
init: string
|
||||
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void
|
||||
authErrors: AuthError[]
|
||||
placeholder?: string
|
||||
children?: ReactNode
|
||||
}
|
||||
|
||||
export function AuthField({
|
||||
title,
|
||||
name,
|
||||
type,
|
||||
init,
|
||||
onChange,
|
||||
authErrors,
|
||||
placeholder,
|
||||
children,
|
||||
}: AuthFieldProps) {
|
||||
const styles = useStyles()
|
||||
const [mounted, setMounted] = useState(false)
|
||||
const fieldErrors = authErrors.filter(err => err.param === name)
|
||||
|
||||
useEffect(() => {
|
||||
setMounted(true)
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<div className={styles.field}>
|
||||
<label className={styles.fieldLabel}>{title}</label>
|
||||
{mounted ? (
|
||||
<input
|
||||
type={type}
|
||||
value={init}
|
||||
onChange={onChange}
|
||||
placeholder={placeholder}
|
||||
className={styles.fieldInput}
|
||||
autoComplete="off"
|
||||
/>
|
||||
) : (
|
||||
<div className={styles.fieldInput} style={{ minHeight: '2.75rem' }} />
|
||||
)}
|
||||
{fieldErrors.map((err, i) => (
|
||||
<p key={i} className={styles.fieldError}>{err.message}</p>
|
||||
))}
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
'use client'
|
||||
|
||||
import { useState, ReactNode } from 'react'
|
||||
import { useAllauthAPI } from '../contexts/APIContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
import useAuthForm, { AuthField } from './AuthForm'
|
||||
import { AuthResponse, AuthDetails } from '../api'
|
||||
|
||||
interface FieldConfig {
|
||||
name: string
|
||||
title: string
|
||||
type: string
|
||||
placeholder?: string
|
||||
}
|
||||
|
||||
interface FooterLink {
|
||||
href: string
|
||||
label: string
|
||||
}
|
||||
|
||||
interface AuthFormPageProps {
|
||||
title: string
|
||||
subtitle?: string
|
||||
fields: FieldConfig[]
|
||||
submitLabel?: string
|
||||
submittingLabel?: string
|
||||
submitFn: (api: ReturnType<typeof useAllauthAPI>, data: Record<string, string>) => Promise<AuthResponse>
|
||||
onResponse: (response: AuthResponse, authDetails: AuthDetails, data: Record<string, string>) => void
|
||||
footerLinks?: FooterLink[]
|
||||
preFields?: ReactNode
|
||||
postFields?: ReactNode
|
||||
error?: string | null
|
||||
}
|
||||
|
||||
export function AuthFormPage({
|
||||
title,
|
||||
subtitle,
|
||||
fields,
|
||||
submitLabel = 'Submit',
|
||||
submittingLabel = 'Submitting...',
|
||||
submitFn,
|
||||
onResponse,
|
||||
footerLinks,
|
||||
preFields,
|
||||
postFields,
|
||||
error: externalError,
|
||||
}: AuthFormPageProps) {
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
|
||||
const [data, setData] = useState<Record<string, string>>(() =>
|
||||
Object.fromEntries(fields.map(f => [f.name, '']))
|
||||
)
|
||||
|
||||
const authForm = useAuthForm(
|
||||
() => submitFn(api, data),
|
||||
(response, authDetails) => onResponse(response, authDetails, data)
|
||||
)
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
authForm.submit()
|
||||
}
|
||||
|
||||
const handleFieldChange = (fieldName: string) => (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setData(prev => ({ ...prev, [fieldName]: e.target.value }))
|
||||
}
|
||||
|
||||
const formErrors = authForm.errors.filter(err => !err.param || err.param === '__all__')
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
<h1 className={styles.title}>{title}</h1>
|
||||
{subtitle && <p className={styles.subtitle}>{subtitle}</p>}
|
||||
|
||||
{externalError && <p className={styles.error}>{externalError}</p>}
|
||||
{formErrors.length > 0 && (
|
||||
<div className={styles.error}>
|
||||
{formErrors.map((err, i) => (
|
||||
<p key={i}>{err.message}</p>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form onSubmit={handleSubmit} className={styles.form} suppressHydrationWarning>
|
||||
{preFields}
|
||||
|
||||
<div className={styles.fieldsContainer}>
|
||||
{fields.map(field => (
|
||||
<AuthField
|
||||
key={field.name}
|
||||
title={field.title}
|
||||
name={field.name}
|
||||
type={field.type}
|
||||
init={data[field.name]}
|
||||
onChange={handleFieldChange(field.name)}
|
||||
authErrors={authForm.errors}
|
||||
placeholder={field.placeholder}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{postFields}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
className={styles.submit}
|
||||
disabled={authForm.fetching}
|
||||
>
|
||||
{authForm.fetching ? submittingLabel : submitLabel}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
{footerLinks && footerLinks.length > 0 && (
|
||||
<div className={styles.footer}>
|
||||
{footerLinks.map((link, i) => (
|
||||
<a key={i} href={link.href} className={styles.link}>
|
||||
{link.label}
|
||||
</a>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useRouter } from '../contexts/RouterContext'
|
||||
import { useConfig } from '../contexts/AuthContext'
|
||||
import { useAllauthAPI } from '../contexts/APIContext'
|
||||
import { useAuthContext } from '../contexts/AuthContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
|
||||
interface PasskeyLoginProps {
|
||||
onSuccess?: () => void
|
||||
}
|
||||
|
||||
export function PasskeyLogin({ onSuccess }: PasskeyLoginProps) {
|
||||
const router = useRouter()
|
||||
const config = useConfig()
|
||||
const api = useAllauthAPI()
|
||||
const { refresh } = useAuthContext()
|
||||
const styles = useStyles()
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [authenticating, setAuthenticating] = useState(false)
|
||||
|
||||
// Check if passkey login is enabled
|
||||
const passkeyLoginEnabled = config?.data?.mfa?.passkey_login_enabled
|
||||
|
||||
if (!passkeyLoginEnabled) {
|
||||
return null
|
||||
}
|
||||
|
||||
const handlePasskeyLogin = async () => {
|
||||
setError(null)
|
||||
setAuthenticating(true)
|
||||
|
||||
try {
|
||||
const { startAuthentication } = await import('@simplewebauthn/browser')
|
||||
|
||||
// Get login options (challenge) from server
|
||||
const optionsRes = await api.webauthn.requestOptions.login()
|
||||
|
||||
if (optionsRes.status !== 200) {
|
||||
throw new Error('Failed to get login options')
|
||||
}
|
||||
|
||||
// Extract publicKey options - allauth returns { request_options: { publicKey: {...} } }
|
||||
const publicKeyOptions = optionsRes.data?.request_options?.publicKey
|
||||
|
||||
if (!publicKeyOptions?.challenge) {
|
||||
throw new Error('Invalid login options')
|
||||
}
|
||||
|
||||
// Perform WebAuthn authentication in browser
|
||||
// @simplewebauthn/browser v13+ expects { optionsJSON: ... }
|
||||
const credential = await startAuthentication({ optionsJSON: publicKeyOptions as any })
|
||||
|
||||
// Submit credential to server for login
|
||||
const res = await api.webauthn.login(credential)
|
||||
|
||||
if (res.status === 200) {
|
||||
await refresh()
|
||||
if (onSuccess) {
|
||||
onSuccess()
|
||||
} else {
|
||||
const next = router.searchParams.get('next')
|
||||
router.push(next?.startsWith('/') ? next : '/dashboard')
|
||||
}
|
||||
} else {
|
||||
setError('Login failed. Please try again.')
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (e.name === 'AbortError' || e.name === 'NotAllowedError') {
|
||||
// User cancelled - not an error
|
||||
setError(null)
|
||||
} else {
|
||||
setError(e.message || 'Failed to sign in with passkey')
|
||||
}
|
||||
} finally {
|
||||
setAuthenticating(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.passkeyContainer}>
|
||||
<div className={styles.divider}>
|
||||
<span className={styles.dividerText}>or</span>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className={styles.error}>
|
||||
<p>{error}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
onClick={handlePasskeyLogin}
|
||||
disabled={authenticating}
|
||||
className={styles.passkeyButton}
|
||||
>
|
||||
{authenticating ? 'Waiting for passkey...' : 'Sign in with Passkey'}
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
'use client'
|
||||
|
||||
import { useConfig } from '../contexts/AuthContext'
|
||||
import { useAllauthAPI } from '../contexts/APIContext'
|
||||
import { useStyles } from '../contexts/StylesContext'
|
||||
|
||||
interface Provider {
|
||||
id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
interface ProviderListProps {
|
||||
callbackUrl: string
|
||||
process?: 'login' | 'connect'
|
||||
}
|
||||
|
||||
export function ProviderList({ callbackUrl, process = 'login' }: ProviderListProps) {
|
||||
const config = useConfig()
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
|
||||
const providers: Provider[] = config?.data?.socialaccount?.providers || []
|
||||
|
||||
if (providers.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
const handleProviderClick = (providerId: string) => {
|
||||
const provider = api.oauth.provider(providerId)
|
||||
if (process === 'connect') {
|
||||
provider.connect.withRedirect(callbackUrl)
|
||||
} else {
|
||||
provider.login.withRedirect(callbackUrl)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.providersContainer}>
|
||||
<div className={styles.divider}>
|
||||
<span className={styles.dividerText}>or continue with</span>
|
||||
</div>
|
||||
<div className={styles.providerButtons}>
|
||||
{providers.map((provider) => (
|
||||
<button
|
||||
key={provider.id}
|
||||
type="button"
|
||||
onClick={() => handleProviderClick(provider.id)}
|
||||
className={styles.providerButton}
|
||||
>
|
||||
{provider.name}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Main UI component
|
||||
export { AllauthUI } from './AllauthUI'
|
||||
export type { AllauthUIView, AllauthUIMode } from './AllauthUI'
|
||||
|
||||
// Core components
|
||||
export { AuthCard } from './AuthCard'
|
||||
export { AuthFormPage } from './AuthFormPage'
|
||||
export { AuthDjangoForm } from './AuthDjangoForm'
|
||||
export { ProviderList } from './ProviderList'
|
||||
export { PasskeyLogin } from './PasskeyLogin'
|
||||
export { default as useAuthForm, AuthField } from './AuthForm'
|
||||
|
||||
// Django-initiated flow handler (email verification, password reset links, OAuth)
|
||||
export { AllauthRouter } from './AllauthRouter'
|
||||
|
||||
// Settings components
|
||||
export {
|
||||
AuthSettings,
|
||||
ProfileSection,
|
||||
EmailsSection,
|
||||
PasswordSection,
|
||||
PasskeysSection,
|
||||
ConnectionsSection,
|
||||
MFASection,
|
||||
SessionsSection,
|
||||
SettingsSection,
|
||||
SettingsItem,
|
||||
SettingsList,
|
||||
Badge,
|
||||
Button,
|
||||
} from './settings'
|
||||
|
||||
// Individual auth views (for granular control)
|
||||
export {
|
||||
LoginView,
|
||||
SignupView,
|
||||
MFAChooserView,
|
||||
MFAWebAuthnView,
|
||||
MFATOTPView,
|
||||
MFARecoveryCodesView,
|
||||
} from './views'
|
||||
@@ -0,0 +1,79 @@
|
||||
'use client'
|
||||
|
||||
import { useStyles, cx } from '../../contexts/StylesContext'
|
||||
import { ProfileSection } from './ProfileSection'
|
||||
import { EmailsSection } from './EmailsSection'
|
||||
import { PasswordSection } from './PasswordSection'
|
||||
import { PasskeysSection } from './PasskeysSection'
|
||||
import { ConnectionsSection } from './ConnectionsSection'
|
||||
import { MFASection } from './MFASection'
|
||||
import { SessionsSection } from './SessionsSection'
|
||||
import { Button } from './SettingsComponents'
|
||||
|
||||
type SettingsSectionType = 'profile' | 'emails' | 'password' | 'passkeys' | 'connections' | 'mfa' | 'sessions'
|
||||
|
||||
interface AuthSettingsProps {
|
||||
/** Title shown at the top of the settings page */
|
||||
title?: string
|
||||
/** Called when user clicks sign out */
|
||||
onSignOut?: () => void
|
||||
/** Which sections to show. Defaults to all. */
|
||||
sections?: SettingsSectionType[]
|
||||
/** URL to redirect back to after OAuth connect (for connections section) */
|
||||
oauthRedirectUrl?: string
|
||||
}
|
||||
|
||||
const DEFAULT_SECTIONS: SettingsSectionType[] = ['profile', 'emails', 'password', 'passkeys', 'connections', 'mfa', 'sessions']
|
||||
|
||||
/**
|
||||
* AuthSettings renders a complete account settings page.
|
||||
*
|
||||
* It includes sections for:
|
||||
* - Profile (display user info)
|
||||
* - Email addresses (manage, verify, set primary)
|
||||
* - Password change
|
||||
* - Passkeys (add/remove passwordless login)
|
||||
* - Connected accounts (OAuth providers)
|
||||
* - Two-factor authentication (TOTP, recovery codes)
|
||||
* - Active sessions (view/end sessions)
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <AuthSettings
|
||||
* onSignOut={() => router.push('/logout')}
|
||||
* sections={['profile', 'password', 'mfa']} // Only show these sections
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export function AuthSettings({
|
||||
title = 'Account Settings',
|
||||
onSignOut,
|
||||
sections = DEFAULT_SECTIONS,
|
||||
oauthRedirectUrl,
|
||||
}: AuthSettingsProps) {
|
||||
const styles = useStyles()
|
||||
const sectionSet = new Set(sections)
|
||||
|
||||
return (
|
||||
<div className={styles.settingsContainer}>
|
||||
<h1 className={styles.settingsPageTitle}>{title}</h1>
|
||||
|
||||
{sectionSet.has('profile') && <ProfileSection />}
|
||||
{sectionSet.has('emails') && <EmailsSection />}
|
||||
{sectionSet.has('password') && <PasswordSection />}
|
||||
{sectionSet.has('passkeys') && <PasskeysSection />}
|
||||
{sectionSet.has('connections') && <ConnectionsSection redirectUrl={oauthRedirectUrl} />}
|
||||
{sectionSet.has('mfa') && <MFASection />}
|
||||
{sectionSet.has('sessions') && <SessionsSection />}
|
||||
|
||||
{/* Sign Out */}
|
||||
{onSignOut && (
|
||||
<section className={styles.settingsCard}>
|
||||
<Button variant="danger" onClick={onSignOut}>
|
||||
Sign Out
|
||||
</Button>
|
||||
</section>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useConfig } from '../../contexts/AuthContext'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { SettingsSection, SettingsItem, SettingsList, Button } from './SettingsComponents'
|
||||
|
||||
interface Connection {
|
||||
uid: string
|
||||
provider: { id: string; name: string }
|
||||
display: string
|
||||
}
|
||||
|
||||
interface ConnectionsSectionProps {
|
||||
/** URL to redirect back to after OAuth connect */
|
||||
redirectUrl?: string
|
||||
}
|
||||
|
||||
export function ConnectionsSection({ redirectUrl = '/account' }: ConnectionsSectionProps) {
|
||||
const api = useAllauthAPI()
|
||||
const config = useConfig()
|
||||
const [connections, setConnections] = useState<Connection[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
|
||||
const availableProviders = config?.data?.socialaccount?.providers || []
|
||||
|
||||
const fetchConnections = async () => {
|
||||
const res = await api.oauth.list()
|
||||
if (res.status === 200 && res.data) {
|
||||
setConnections(res.data)
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => { fetchConnections() }, [])
|
||||
|
||||
const handleConnect = (providerId: string) => {
|
||||
api.oauth.provider(providerId).connect.withRedirect(redirectUrl)
|
||||
}
|
||||
|
||||
const handleDisconnect = async (providerId: string, uid: string) => {
|
||||
if (!confirm('Disconnect this account?')) return
|
||||
await api.oauth.provider(providerId).removeFrom(uid)
|
||||
fetchConnections()
|
||||
}
|
||||
|
||||
// Don't render if no providers configured or still loading
|
||||
if (loading) return null
|
||||
|
||||
const connectedProviderIds = connections.map(c => c.provider.id)
|
||||
const unconnectedProviders = availableProviders.filter(
|
||||
(p: { id: string }) => !connectedProviderIds.includes(p.id)
|
||||
)
|
||||
|
||||
// Hide section entirely if no social providers
|
||||
if (connections.length === 0 && availableProviders.length === 0) return null
|
||||
|
||||
return (
|
||||
<SettingsSection title="Connected Accounts">
|
||||
<SettingsList>
|
||||
{connections.map(conn => (
|
||||
<SettingsItem
|
||||
key={conn.uid}
|
||||
label={conn.provider.name}
|
||||
meta={conn.display}
|
||||
actions={
|
||||
<Button variant="danger" onClick={() => handleDisconnect(conn.provider.id, conn.uid)}>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
))}
|
||||
{unconnectedProviders.map((provider: { id: string; name: string }) => (
|
||||
<SettingsItem
|
||||
key={provider.id}
|
||||
label={provider.name}
|
||||
actions={
|
||||
<Button onClick={() => handleConnect(provider.id)}>
|
||||
Connect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</SettingsList>
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
import { useDjangoFormCore } from 'mizan'
|
||||
import { SettingsSection, SettingsItem, SettingsList, Badge, Button } from './SettingsComponents'
|
||||
|
||||
interface Email {
|
||||
email: string
|
||||
primary: boolean
|
||||
verified: boolean
|
||||
}
|
||||
|
||||
export function EmailsSection() {
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
const [emails, setEmails] = useState<Email[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const addEmailForm = useDjangoFormCore<Record<string, unknown>>({ name: 'add_email' })
|
||||
|
||||
const fetchEmails = async () => {
|
||||
const res = await api.account.emails.list()
|
||||
if (res.status === 200 && res.data) {
|
||||
setEmails(res.data)
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => { fetchEmails() }, [])
|
||||
|
||||
const handleAdd = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
const result = await addEmailForm.submit()
|
||||
if (result.success) {
|
||||
addEmailForm.reset()
|
||||
fetchEmails()
|
||||
}
|
||||
}
|
||||
|
||||
const handleRemove = async (email: string) => {
|
||||
if (!confirm(`Remove ${email}?`)) return
|
||||
await api.account.emails.remove(email)
|
||||
fetchEmails()
|
||||
}
|
||||
|
||||
const handleSetPrimary = async (email: string) => {
|
||||
await api.account.emails.setPrimary(email)
|
||||
fetchEmails()
|
||||
}
|
||||
|
||||
const handleResendVerification = async (email: string) => {
|
||||
await api.account.emails.verification.dispatch(email)
|
||||
alert('Verification email sent!')
|
||||
}
|
||||
|
||||
if (loading) return null
|
||||
|
||||
return (
|
||||
<SettingsSection title="Email Addresses">
|
||||
<SettingsList>
|
||||
{emails.map(email => (
|
||||
<SettingsItem
|
||||
key={email.email}
|
||||
label={
|
||||
<>
|
||||
{email.email}
|
||||
{email.primary && <Badge variant="primary">Primary</Badge>}
|
||||
{!email.verified && <Badge variant="warning">Unverified</Badge>}
|
||||
</>
|
||||
}
|
||||
actions={
|
||||
<>
|
||||
{!email.verified && (
|
||||
<Button variant="secondary" onClick={() => handleResendVerification(email.email)}>
|
||||
Verify
|
||||
</Button>
|
||||
)}
|
||||
{!email.primary && email.verified && (
|
||||
<Button onClick={() => handleSetPrimary(email.email)}>
|
||||
Make Primary
|
||||
</Button>
|
||||
)}
|
||||
{!email.primary && (
|
||||
<Button variant="danger" onClick={() => handleRemove(email.email)}>
|
||||
Remove
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</SettingsList>
|
||||
|
||||
{!addEmailForm.loading && (
|
||||
<form onSubmit={handleAdd} className={styles.inlineForm}>
|
||||
<div className={styles.field}>
|
||||
<label className={styles.fieldLabel}>
|
||||
{addEmailForm.schema?.fields.email?.label || 'Add Email'}
|
||||
</label>
|
||||
<input
|
||||
type="email"
|
||||
value={(addEmailForm.data.email as string) || ''}
|
||||
onChange={(e) => addEmailForm.set('email', e.target.value)}
|
||||
onBlur={() => addEmailForm.touch('email')}
|
||||
className={styles.fieldInput}
|
||||
required
|
||||
/>
|
||||
{addEmailForm.getFieldErrors('email').map((err, i) => (
|
||||
<p key={i} className={styles.fieldError}>{err.message}</p>
|
||||
))}
|
||||
</div>
|
||||
<Button type="submit">
|
||||
{addEmailForm.schema?.submit_label || 'Add'}
|
||||
</Button>
|
||||
</form>
|
||||
)}
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
import { SettingsSection, SettingsItem, Badge, Button } from './SettingsComponents'
|
||||
import type { Authenticator, TOTPStatus } from '../../types'
|
||||
|
||||
interface TOTPSetup {
|
||||
secret: string
|
||||
totp_url: string
|
||||
}
|
||||
|
||||
export function MFASection() {
|
||||
const api = useAllauthAPI()
|
||||
const [authenticators, setAuthenticators] = useState<Authenticator[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [available, setAvailable] = useState(true)
|
||||
|
||||
const fetchAuthenticators = async () => {
|
||||
try {
|
||||
const res = await api.mfa.list()
|
||||
if (res.status === 200 && res.data) {
|
||||
setAuthenticators(res.data as Authenticator[])
|
||||
} else {
|
||||
// Non-200 status means MFA not available
|
||||
setAvailable(false)
|
||||
}
|
||||
} catch {
|
||||
setAvailable(false)
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => { fetchAuthenticators() }, [])
|
||||
|
||||
if (loading || !available) return null
|
||||
|
||||
const hasTOTP = authenticators.some(a => a.type === 'totp')
|
||||
|
||||
return (
|
||||
<SettingsSection title="Two-Factor Authentication">
|
||||
<TOTPSubsection
|
||||
hasTOTP={hasTOTP}
|
||||
onUpdate={fetchAuthenticators}
|
||||
/>
|
||||
|
||||
{hasTOTP && (
|
||||
<RecoveryCodesSubsection />
|
||||
)}
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
|
||||
// --- TOTP Subsection ---
|
||||
|
||||
function TOTPSubsection({ hasTOTP, onUpdate }: { hasTOTP: boolean; onUpdate: () => void }) {
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
const [showSetup, setShowSetup] = useState(false)
|
||||
const [setup, setSetup] = useState<TOTPSetup | null>(null)
|
||||
const [code, setCode] = useState('')
|
||||
|
||||
const handleStartSetup = async () => {
|
||||
const res = await api.mfa.totp.getStatus()
|
||||
// allauth returns TOTP status with secret and totp_url for setup
|
||||
const data = res.data as TOTPStatus | undefined
|
||||
if (data?.secret && data?.totp_url) {
|
||||
setSetup({ secret: data.secret, totp_url: data.totp_url })
|
||||
setShowSetup(true)
|
||||
}
|
||||
}
|
||||
|
||||
const handleActivate = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
const res = await api.mfa.totp.activate(code)
|
||||
if (res.status === 200) {
|
||||
setShowSetup(false)
|
||||
setSetup(null)
|
||||
setCode('')
|
||||
onUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
const handleDeactivate = async () => {
|
||||
if (!confirm('Disable authenticator app?')) return
|
||||
await api.mfa.totp.deactivate()
|
||||
onUpdate()
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<h3 className={styles.settingsSubtitle}>Authenticator App</h3>
|
||||
|
||||
{showSetup && setup ? (
|
||||
<div className={styles.totpSetup}>
|
||||
<p>Scan this QR code with your authenticator app:</p>
|
||||
<img
|
||||
src={`https://api.qrserver.com/v1/create-qr-code/?size=180x180&data=${encodeURIComponent(setup.totp_url)}`}
|
||||
alt="TOTP QR Code"
|
||||
className={styles.qrCode}
|
||||
/>
|
||||
<p className={styles.settingsItemMeta}>Secret: {setup.secret}</p>
|
||||
<form onSubmit={handleActivate} className={styles.inlineForm}>
|
||||
<div className={styles.field}>
|
||||
<input
|
||||
type="text"
|
||||
value={code}
|
||||
onChange={(e) => setCode(e.target.value)}
|
||||
placeholder="Verification Code"
|
||||
className={styles.fieldInput}
|
||||
/>
|
||||
</div>
|
||||
<Button type="submit">Activate</Button>
|
||||
<Button type="button" variant="secondary" onClick={() => setShowSetup(false)}>
|
||||
Cancel
|
||||
</Button>
|
||||
</form>
|
||||
</div>
|
||||
) : hasTOTP ? (
|
||||
<SettingsItem
|
||||
label={<>Authenticator App <Badge variant="success">Active</Badge></>}
|
||||
actions={<Button variant="danger" onClick={handleDeactivate}>Disable</Button>}
|
||||
/>
|
||||
) : (
|
||||
<Button onClick={handleStartSetup}>Set Up Authenticator</Button>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
// --- Recovery Codes Subsection ---
|
||||
|
||||
function RecoveryCodesSubsection() {
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
const [codes, setCodes] = useState<string[]>([])
|
||||
|
||||
const handleView = async () => {
|
||||
const res = await api.mfa.recoveryCodes.list()
|
||||
if (res.status === 200) {
|
||||
setCodes(res.data?.unused_codes || [])
|
||||
}
|
||||
}
|
||||
|
||||
const handleRegenerate = async () => {
|
||||
if (!confirm('Generate new codes? Old codes will stop working.')) return
|
||||
const res = await api.mfa.recoveryCodes.regenerate()
|
||||
if (res.status === 200) {
|
||||
setCodes(res.data?.unused_codes || [])
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<h3 className={styles.settingsSubtitle}>Recovery Codes</h3>
|
||||
|
||||
{codes.length > 0 ? (
|
||||
<div>
|
||||
<div className={styles.recoveryCodes}>
|
||||
{codes.map((code, i) => <span key={i}>{code}</span>)}
|
||||
</div>
|
||||
<p className={styles.settingsItemMeta}>Store these safely. Each code works once.</p>
|
||||
<Button variant="secondary" onClick={handleRegenerate}>Regenerate</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button variant="secondary" onClick={handleView}>View Recovery Codes</Button>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { useConfig } from '../../contexts/AuthContext'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
import { SettingsSection, SettingsItem, SettingsList, Button } from './SettingsComponents'
|
||||
import type { Authenticator, WebAuthnAuthenticator } from '../../types'
|
||||
|
||||
export function PasskeysSection() {
|
||||
const api = useAllauthAPI()
|
||||
const config = useConfig()
|
||||
const styles = useStyles()
|
||||
const [passkeys, setPasskeys] = useState<WebAuthnAuthenticator[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
|
||||
// Check if passkey login is enabled
|
||||
const passkeyLoginEnabled = config?.data?.mfa?.passkey_login_enabled
|
||||
|
||||
const fetchPasskeys = async () => {
|
||||
try {
|
||||
const res = await api.mfa.list()
|
||||
if (res.status === 200 && res.data) {
|
||||
const authenticators = res.data as Authenticator[]
|
||||
setPasskeys(authenticators.filter((a): a is WebAuthnAuthenticator => a.type === 'webauthn'))
|
||||
}
|
||||
} catch {
|
||||
// Silently fail - passkeys just won't show
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => { fetchPasskeys() }, [])
|
||||
|
||||
// Don't render if passkey login isn't enabled
|
||||
if (!passkeyLoginEnabled) return null
|
||||
if (loading) return null
|
||||
|
||||
const handleAdd = async () => {
|
||||
try {
|
||||
const { startRegistration } = await import('@simplewebauthn/browser')
|
||||
|
||||
// Request creation options - use passwordless=true for login passkeys
|
||||
const optionsRes = await api.webauthn.requestOptions.creation(true)
|
||||
|
||||
if (optionsRes.status !== 200) {
|
||||
return
|
||||
}
|
||||
|
||||
const publicKeyOptions = optionsRes.data?.creation_options?.publicKey
|
||||
if (!publicKeyOptions) throw new Error('Invalid options response')
|
||||
|
||||
// @simplewebauthn/browser v13+ expects { optionsJSON: ... }
|
||||
const credential = await startRegistration({ optionsJSON: publicKeyOptions as any })
|
||||
const name = prompt('Name this passkey:') || 'Passkey'
|
||||
|
||||
const res = await api.webauthn.add(name, credential)
|
||||
if (res.status === 200) {
|
||||
fetchPasskeys()
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (e.name !== 'AbortError') {
|
||||
alert(e.message || 'Failed to add passkey')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleRemove = async (id: number) => {
|
||||
if (!confirm('Remove this passkey? You won\'t be able to use it to sign in anymore.')) return
|
||||
await api.webauthn.delete([id])
|
||||
fetchPasskeys()
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsSection title="Passkeys">
|
||||
<p className={styles.settingsItemMeta} style={{ marginBottom: '1rem' }}>
|
||||
Passkeys let you sign in quickly using your device's biometrics or security key.
|
||||
No password needed.
|
||||
</p>
|
||||
|
||||
{passkeys.length > 0 && (
|
||||
<SettingsList>
|
||||
{passkeys.map(passkey => (
|
||||
<SettingsItem
|
||||
key={passkey.id}
|
||||
label={passkey.name}
|
||||
meta={`Added ${new Date(passkey.created_at * 1000).toLocaleDateString()}`}
|
||||
actions={
|
||||
<Button variant="danger" onClick={() => handleRemove(passkey.id)}>
|
||||
Remove
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</SettingsList>
|
||||
)}
|
||||
|
||||
<Button onClick={handleAdd}>
|
||||
{passkeys.length > 0 ? 'Add Another Passkey' : 'Set Up Passkey'}
|
||||
</Button>
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
'use client'
|
||||
|
||||
import { useDjangoFormCore } from 'mizan'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
import { SettingsSection, Button } from './SettingsComponents'
|
||||
|
||||
export function PasswordSection() {
|
||||
const styles = useStyles()
|
||||
const form = useDjangoFormCore<Record<string, unknown>>({ name: 'change_password' })
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
const result = await form.submit()
|
||||
if (result.success) {
|
||||
form.reset()
|
||||
alert('Password changed successfully!')
|
||||
}
|
||||
}
|
||||
|
||||
if (form.loading) return null
|
||||
|
||||
return (
|
||||
<SettingsSection title={form.schema?.title || 'Change Password'}>
|
||||
<form onSubmit={handleSubmit} className={styles.form}>
|
||||
<div className={styles.fieldsContainer}>
|
||||
{form.schema?.fieldOrder.map(fieldName => {
|
||||
const field = form.schema!.fields[fieldName]
|
||||
return (
|
||||
<div key={fieldName} className={styles.field}>
|
||||
<label className={styles.fieldLabel}>{field.label}</label>
|
||||
<input
|
||||
type={field.type}
|
||||
value={(form.data[fieldName] as string) || ''}
|
||||
onChange={(e) => form.set(fieldName, e.target.value)}
|
||||
onBlur={() => form.touch(fieldName)}
|
||||
className={styles.fieldInput}
|
||||
required={field.required}
|
||||
/>
|
||||
{form.touchedFields.has(fieldName) &&
|
||||
form.getFieldErrors(fieldName).map((err, i) => (
|
||||
<p key={i} className={styles.fieldError}>{err.message}</p>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
<Button type="submit" disabled={form.submitting}>
|
||||
{form.submitting ? 'Changing...' : (form.schema?.submit_label || 'Change Password')}
|
||||
</Button>
|
||||
</form>
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
'use client'
|
||||
|
||||
import { useUser } from '../../contexts/AuthContext'
|
||||
import { SettingsSection, SettingsItem, SettingsList } from './SettingsComponents'
|
||||
|
||||
export function ProfileSection() {
|
||||
const user = useUser()
|
||||
|
||||
return (
|
||||
<SettingsSection title="Profile">
|
||||
<SettingsList>
|
||||
<SettingsItem label="Email" meta={user?.email} />
|
||||
{user?.first_name && (
|
||||
<SettingsItem
|
||||
label="Name"
|
||||
meta={`${user.first_name} ${user.last_name || ''}`.trim()}
|
||||
/>
|
||||
)}
|
||||
</SettingsList>
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
'use client'
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { SettingsSection, SettingsItem, SettingsList, Badge, Button } from './SettingsComponents'
|
||||
import type { Session } from '../../types'
|
||||
|
||||
function parseUserAgent(ua: string): string {
|
||||
if (ua.includes('Chrome')) return 'Chrome'
|
||||
if (ua.includes('Firefox')) return 'Firefox'
|
||||
if (ua.includes('Safari')) return 'Safari'
|
||||
if (ua.includes('Edge')) return 'Edge'
|
||||
return 'Unknown Browser'
|
||||
}
|
||||
|
||||
export function SessionsSection() {
|
||||
const api = useAllauthAPI()
|
||||
const [sessions, setSessions] = useState<Session[]>([])
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [available, setAvailable] = useState(true)
|
||||
|
||||
const fetchSessions = async () => {
|
||||
try {
|
||||
const res = await api.session.list()
|
||||
if (res.status === 200 && res.data) {
|
||||
setSessions(res.data as Session[])
|
||||
} else {
|
||||
// Non-200 status means sessions feature not available
|
||||
setAvailable(false)
|
||||
}
|
||||
} catch {
|
||||
setAvailable(false)
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
useEffect(() => { fetchSessions() }, [])
|
||||
|
||||
const handleEnd = async (id: number) => {
|
||||
if (!confirm('End this session?')) return
|
||||
await api.session.remove([id])
|
||||
fetchSessions()
|
||||
}
|
||||
|
||||
const handleEndAllOthers = async () => {
|
||||
const otherIds = sessions.filter(s => !s.is_current).map(s => s.id)
|
||||
if (otherIds.length === 0) return
|
||||
if (!confirm(`End ${otherIds.length} other session(s)?`)) return
|
||||
await api.session.remove(otherIds)
|
||||
fetchSessions()
|
||||
}
|
||||
|
||||
if (loading || !available) return null
|
||||
|
||||
const otherSessions = sessions.filter(s => !s.is_current)
|
||||
|
||||
return (
|
||||
<SettingsSection title="Active Sessions">
|
||||
<SettingsList>
|
||||
{sessions.map(session => (
|
||||
<SettingsItem
|
||||
key={session.id}
|
||||
label={
|
||||
<>
|
||||
{parseUserAgent(session.user_agent)}
|
||||
{session.is_current && <Badge variant="success">Current</Badge>}
|
||||
</>
|
||||
}
|
||||
meta={`${session.ip} · ${session.last_seen_at ? new Date(session.last_seen_at * 1000).toLocaleString() : 'Unknown'}`}
|
||||
actions={
|
||||
!session.is_current && (
|
||||
<Button variant="danger" onClick={() => handleEnd(session.id)}>
|
||||
End
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</SettingsList>
|
||||
|
||||
{otherSessions.length > 0 && (
|
||||
<Button variant="danger" onClick={handleEndAllOthers}>
|
||||
End All Other Sessions
|
||||
</Button>
|
||||
)}
|
||||
</SettingsSection>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
'use client'
|
||||
|
||||
import { useStyles, cx } from '../../contexts/StylesContext'
|
||||
|
||||
interface SettingsSectionProps {
|
||||
title: string
|
||||
children: React.ReactNode
|
||||
}
|
||||
|
||||
export function SettingsSection({ title, children }: SettingsSectionProps) {
|
||||
const styles = useStyles()
|
||||
return (
|
||||
<section className={styles.settingsCard}>
|
||||
<h2 className={styles.settingsSectionTitle}>{title}</h2>
|
||||
{children}
|
||||
</section>
|
||||
)
|
||||
}
|
||||
|
||||
interface SettingsItemProps {
|
||||
label: React.ReactNode
|
||||
meta?: React.ReactNode
|
||||
actions?: React.ReactNode
|
||||
}
|
||||
|
||||
export function SettingsItem({ label, meta, actions }: SettingsItemProps) {
|
||||
const styles = useStyles()
|
||||
return (
|
||||
<div className={styles.settingsItem}>
|
||||
<div className={styles.settingsItemInfo}>
|
||||
<span className={styles.settingsItemLabel}>{label}</span>
|
||||
{meta && <span className={styles.settingsItemMeta}>{meta}</span>}
|
||||
</div>
|
||||
{actions && <div className={styles.settingsItemActions}>{actions}</div>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export function SettingsList({ children }: { children: React.ReactNode }) {
|
||||
const styles = useStyles()
|
||||
return <div className={styles.settingsList}>{children}</div>
|
||||
}
|
||||
|
||||
type BadgeVariant = 'primary' | 'success' | 'warning' | 'danger'
|
||||
|
||||
export function Badge({ variant, children }: { variant: BadgeVariant, children: React.ReactNode }) {
|
||||
const styles = useStyles()
|
||||
const variantClass = {
|
||||
primary: styles.badgePrimary,
|
||||
success: styles.badgeSuccess,
|
||||
warning: styles.badgeUnverified,
|
||||
danger: styles.badgeDanger,
|
||||
}[variant]
|
||||
|
||||
return <span className={cx(styles.badge, variantClass)}>{children}</span>
|
||||
}
|
||||
|
||||
interface ButtonProps extends React.ButtonHTMLAttributes<HTMLButtonElement> {
|
||||
variant?: 'primary' | 'secondary' | 'danger'
|
||||
size?: 'small' | 'normal'
|
||||
}
|
||||
|
||||
export function Button({ variant = 'primary', size = 'small', className, children, ...props }: ButtonProps) {
|
||||
const styles = useStyles()
|
||||
const variantClass = {
|
||||
primary: styles.smallButtonPrimary,
|
||||
secondary: styles.smallButtonSecondary,
|
||||
danger: styles.smallButtonDanger,
|
||||
}[variant]
|
||||
|
||||
return (
|
||||
<button className={cx(styles.smallButton, variantClass, className)} {...props}>
|
||||
{children}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Main settings component
|
||||
export { AuthSettings } from './AuthSettings'
|
||||
|
||||
// Individual sections (for custom layouts)
|
||||
export { ProfileSection } from './ProfileSection'
|
||||
export { EmailsSection } from './EmailsSection'
|
||||
export { PasswordSection } from './PasswordSection'
|
||||
export { PasskeysSection } from './PasskeysSection'
|
||||
export { ConnectionsSection } from './ConnectionsSection'
|
||||
export { MFASection } from './MFASection'
|
||||
export { SessionsSection } from './SessionsSection'
|
||||
|
||||
// Building blocks (for custom components)
|
||||
export {
|
||||
SettingsSection,
|
||||
SettingsItem,
|
||||
SettingsList,
|
||||
Badge,
|
||||
Button,
|
||||
} from './SettingsComponents'
|
||||
@@ -0,0 +1,75 @@
|
||||
'use client'
|
||||
|
||||
import { useAuthContext, useConfig } from '../../contexts/AuthContext'
|
||||
import { getAuthDetails } from '../../api'
|
||||
import { AuthDjangoForm } from '../AuthDjangoForm'
|
||||
import { PasskeyLogin } from '../PasskeyLogin'
|
||||
import { ProviderList } from '../ProviderList'
|
||||
import type { AllauthConfiguration } from '../../types'
|
||||
|
||||
interface LoginViewProps {
|
||||
/** Called after successful login (or when MFA is triggered) */
|
||||
onSuccess?: () => void
|
||||
/** Called when user clicks "Create account" */
|
||||
onSignupClick?: () => void
|
||||
/** Called when user clicks "Forgot password" */
|
||||
onForgotPasswordClick?: () => void
|
||||
/** Called when user clicks "Sign in with code" */
|
||||
onLoginByCodeClick?: () => void
|
||||
/** OAuth callback URL for social providers */
|
||||
oauthCallbackUrl?: string
|
||||
}
|
||||
|
||||
export function LoginView({
|
||||
onSuccess,
|
||||
onSignupClick,
|
||||
onForgotPasswordClick,
|
||||
onLoginByCodeClick,
|
||||
oauthCallbackUrl,
|
||||
}: LoginViewProps) {
|
||||
const { refresh } = useAuthContext()
|
||||
const config = useConfig()
|
||||
|
||||
// Get feature flags from backend config
|
||||
const allauthConfig = config?.data as AllauthConfiguration | undefined
|
||||
const isSignupEnabled = allauthConfig?.account?.is_open_for_signup ?? true
|
||||
const isLoginByCodeEnabled = allauthConfig?.account?.login_by_code_enabled ?? false
|
||||
|
||||
const handleSuccess = async () => {
|
||||
const newAuth = await refresh()
|
||||
const details = getAuthDetails(newAuth)
|
||||
|
||||
// Only call onSuccess if fully authenticated (no pending MFA)
|
||||
// If MFA is pending, AllauthUI will handle showing the MFA view
|
||||
if (details.isAuthenticated) {
|
||||
onSuccess?.()
|
||||
}
|
||||
}
|
||||
|
||||
// Build footer links based on provided callbacks AND backend config
|
||||
const footerLinks: Array<{ href?: string; label: string; onClick?: () => void }> = []
|
||||
|
||||
if (onForgotPasswordClick) {
|
||||
footerLinks.push({ label: 'Forgot your password?', onClick: onForgotPasswordClick })
|
||||
}
|
||||
if (onLoginByCodeClick && isLoginByCodeEnabled) {
|
||||
footerLinks.push({ label: 'Sign in with a code instead', onClick: onLoginByCodeClick })
|
||||
}
|
||||
if (onSignupClick && isSignupEnabled) {
|
||||
footerLinks.push({ label: "Don't have an account? Sign up", onClick: onSignupClick })
|
||||
}
|
||||
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="login"
|
||||
onSuccess={handleSuccess}
|
||||
footerLinks={footerLinks}
|
||||
postFields={
|
||||
<>
|
||||
<PasskeyLogin onSuccess={onSuccess} />
|
||||
{oauthCallbackUrl && <ProviderList callbackUrl={oauthCallbackUrl} />}
|
||||
</>
|
||||
}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { AuthenticatorType } from '../../defines'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
import { AuthCard } from '../AuthCard'
|
||||
import { MFATOTPView } from './MFATOTPView'
|
||||
import { MFAWebAuthnView } from './MFAWebAuthnView'
|
||||
import { MFARecoveryCodesView } from './MFARecoveryCodesView'
|
||||
|
||||
const MFA_OPTIONS: Record<string, { label: string; description: string }> = {
|
||||
[AuthenticatorType.WEBAUTHN]: {
|
||||
label: 'Security Key / Passkey',
|
||||
description: 'Use your registered security key or passkey',
|
||||
},
|
||||
[AuthenticatorType.TOTP]: {
|
||||
label: 'Authenticator App',
|
||||
description: 'Enter a code from your authenticator app',
|
||||
},
|
||||
[AuthenticatorType.RECOVERY_CODES]: {
|
||||
label: 'Recovery Code',
|
||||
description: 'Use one of your recovery codes',
|
||||
},
|
||||
}
|
||||
|
||||
interface MFAChooserViewProps {
|
||||
types: string[]
|
||||
onSuccess?: () => void
|
||||
onCancel?: () => void
|
||||
isReauth?: boolean
|
||||
}
|
||||
|
||||
export function MFAChooserView({ types, onSuccess, onCancel, isReauth }: MFAChooserViewProps) {
|
||||
const api = useAllauthAPI()
|
||||
const styles = useStyles()
|
||||
const [selectedType, setSelectedType] = useState<string | null>(null)
|
||||
const [cancelling, setCancelling] = useState(false)
|
||||
|
||||
// Filter to only show options that are available
|
||||
const availableOptions = types
|
||||
.filter(type => MFA_OPTIONS[type])
|
||||
.map(type => ({ type, ...MFA_OPTIONS[type] }))
|
||||
|
||||
const handleCancel = async () => {
|
||||
setCancelling(true)
|
||||
try {
|
||||
await api.session.logout()
|
||||
onCancel?.()
|
||||
} catch {
|
||||
setCancelling(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleBack = types.length > 1 ? () => setSelectedType(null) : undefined
|
||||
|
||||
// If a type is selected, show that method's view
|
||||
if (selectedType === AuthenticatorType.TOTP) {
|
||||
return (
|
||||
<MFATOTPView
|
||||
onSuccess={onSuccess}
|
||||
onCancel={onCancel}
|
||||
onBack={handleBack}
|
||||
isReauth={isReauth}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (selectedType === AuthenticatorType.WEBAUTHN) {
|
||||
return (
|
||||
<MFAWebAuthnView
|
||||
onSuccess={onSuccess}
|
||||
onCancel={onCancel}
|
||||
onBack={handleBack}
|
||||
isReauth={isReauth}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (selectedType === AuthenticatorType.RECOVERY_CODES) {
|
||||
return (
|
||||
<MFARecoveryCodesView
|
||||
onSuccess={onSuccess}
|
||||
onCancel={onCancel}
|
||||
onBack={handleBack}
|
||||
isReauth={isReauth}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
// Show chooser
|
||||
if (availableOptions.length === 0) {
|
||||
return (
|
||||
<AuthCard
|
||||
title="Two-Factor Authentication"
|
||||
subtitle="No authentication methods available."
|
||||
footerLinks={onCancel ? [
|
||||
{ label: 'Cancel and go back', onClick: handleCancel },
|
||||
] : []}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
<h1 className={styles.title}>Two-Factor Authentication</h1>
|
||||
<p className={styles.subtitle}>Choose how you want to verify your identity.</p>
|
||||
|
||||
<div className={styles.form}>
|
||||
{availableOptions.map(option => (
|
||||
<button
|
||||
key={option.type}
|
||||
onClick={() => setSelectedType(option.type)}
|
||||
className={styles.providerButton}
|
||||
>
|
||||
<div style={{ textAlign: 'left' }}>
|
||||
<div style={{ fontWeight: 600 }}>{option.label}</div>
|
||||
<div style={{ fontSize: '0.8125rem', opacity: 0.7, marginTop: '0.25rem' }}>
|
||||
{option.description}
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{onCancel && (
|
||||
<div className={styles.footer}>
|
||||
<button onClick={handleCancel} disabled={cancelling} className={styles.link}>
|
||||
{cancelling ? 'Cancelling...' : 'Cancel and go back'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { AuthDjangoForm } from '../AuthDjangoForm'
|
||||
|
||||
interface MFARecoveryCodesViewProps {
|
||||
onSuccess?: () => void
|
||||
onCancel?: () => void
|
||||
onBack?: () => void
|
||||
isReauth?: boolean
|
||||
}
|
||||
|
||||
export function MFARecoveryCodesView({ onSuccess, onCancel, onBack, isReauth }: MFARecoveryCodesViewProps) {
|
||||
const api = useAllauthAPI()
|
||||
const [cancelling, setCancelling] = useState(false)
|
||||
|
||||
const handleCancel = async () => {
|
||||
setCancelling(true)
|
||||
try {
|
||||
await api.session.logout()
|
||||
onCancel?.()
|
||||
} catch {
|
||||
setCancelling(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Build footer links
|
||||
const footerLinks = []
|
||||
if (onBack) {
|
||||
footerLinks.push({ label: 'Use a different method', onClick: onBack })
|
||||
}
|
||||
if (onCancel) {
|
||||
footerLinks.push({
|
||||
label: cancelling ? 'Cancelling...' : 'Cancel',
|
||||
onClick: handleCancel
|
||||
})
|
||||
}
|
||||
|
||||
const formName = isReauth ? 'mfa_reauthenticate' : 'mfa_authenticate'
|
||||
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName={formName}
|
||||
title="Recovery Code"
|
||||
subtitle="Enter one of your recovery codes."
|
||||
onSuccess={() => onSuccess?.()}
|
||||
footerLinks={footerLinks}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { AuthDjangoForm } from '../AuthDjangoForm'
|
||||
|
||||
interface MFATOTPViewProps {
|
||||
onSuccess?: () => void
|
||||
onCancel?: () => void
|
||||
onBack?: () => void
|
||||
isReauth?: boolean
|
||||
}
|
||||
|
||||
export function MFATOTPView({ onSuccess, onCancel, onBack, isReauth }: MFATOTPViewProps) {
|
||||
const api = useAllauthAPI()
|
||||
const [cancelling, setCancelling] = useState(false)
|
||||
|
||||
const handleCancel = async () => {
|
||||
setCancelling(true)
|
||||
try {
|
||||
await api.session.logout()
|
||||
onCancel?.()
|
||||
} catch {
|
||||
setCancelling(false)
|
||||
}
|
||||
}
|
||||
|
||||
// Build footer links
|
||||
const footerLinks = []
|
||||
if (onBack) {
|
||||
footerLinks.push({ label: 'Use a different method', onClick: onBack })
|
||||
}
|
||||
if (onCancel) {
|
||||
footerLinks.push({
|
||||
label: cancelling ? 'Cancelling...' : 'Cancel',
|
||||
onClick: handleCancel
|
||||
})
|
||||
}
|
||||
|
||||
const formName = isReauth ? 'mfa_reauthenticate' : 'mfa_authenticate'
|
||||
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName={formName}
|
||||
title="Authenticator App"
|
||||
subtitle="Enter the 6-digit code from your authenticator app."
|
||||
onSuccess={() => onSuccess?.()}
|
||||
footerLinks={footerLinks}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useAllauthAPI } from '../../contexts/APIContext'
|
||||
import { useAuthContext } from '../../contexts/AuthContext'
|
||||
import { useStyles } from '../../contexts/StylesContext'
|
||||
|
||||
interface MFAWebAuthnViewProps {
|
||||
onSuccess?: () => void
|
||||
onCancel?: () => void
|
||||
onBack?: () => void
|
||||
isReauth?: boolean
|
||||
}
|
||||
|
||||
export function MFAWebAuthnView({ onSuccess, onCancel, onBack, isReauth }: MFAWebAuthnViewProps) {
|
||||
const api = useAllauthAPI()
|
||||
const { refresh } = useAuthContext()
|
||||
const styles = useStyles()
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [authenticating, setAuthenticating] = useState(false)
|
||||
const [cancelling, setCancelling] = useState(false)
|
||||
|
||||
const handleCancel = async () => {
|
||||
setCancelling(true)
|
||||
try {
|
||||
await api.session.logout()
|
||||
onCancel?.()
|
||||
} catch {
|
||||
setCancelling(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleWebAuthn = async () => {
|
||||
setError(null)
|
||||
setAuthenticating(true)
|
||||
|
||||
try {
|
||||
const { startAuthentication } = await import('@simplewebauthn/browser')
|
||||
|
||||
// Get challenge from server
|
||||
const optionsRes = isReauth
|
||||
? await api.webauthn.requestOptions.reauthentication()
|
||||
: await api.webauthn.requestOptions.authentication()
|
||||
|
||||
if (optionsRes.status !== 200 || !optionsRes.data?.request_options?.publicKey) {
|
||||
throw new Error('Failed to get authentication options')
|
||||
}
|
||||
|
||||
// Perform WebAuthn authentication
|
||||
// The allauth API returns { request_options: { publicKey: {...} } }
|
||||
// @simplewebauthn/browser v13+ expects { optionsJSON: ... }
|
||||
const credential = await startAuthentication({ optionsJSON: optionsRes.data.request_options.publicKey as any })
|
||||
|
||||
// Verify with server
|
||||
const res = isReauth
|
||||
? await api.webauthn.reauthenticate(credential)
|
||||
: await api.webauthn.authenticate(credential)
|
||||
|
||||
if (res.status === 200) {
|
||||
await refresh()
|
||||
onSuccess?.()
|
||||
} else {
|
||||
setError('Authentication failed. Please try again.')
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (e.name === 'AbortError' || e.name === 'NotAllowedError') {
|
||||
setError(null)
|
||||
} else {
|
||||
setError(e.message || 'Failed to authenticate with security key')
|
||||
}
|
||||
} finally {
|
||||
setAuthenticating(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={styles.card}>
|
||||
<h1 className={styles.title}>Security Key</h1>
|
||||
<p className={styles.subtitle}>Use your security key to verify your identity.</p>
|
||||
|
||||
{error && (
|
||||
<div className={styles.error}>
|
||||
<p>{error}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className={styles.form}>
|
||||
<button
|
||||
onClick={handleWebAuthn}
|
||||
disabled={authenticating}
|
||||
className={styles.submit}
|
||||
>
|
||||
{authenticating ? 'Waiting for security key...' : 'Use Security Key'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className={styles.footer}>
|
||||
{onBack && (
|
||||
<button onClick={onBack} className={styles.link}>
|
||||
Use a different method
|
||||
</button>
|
||||
)}
|
||||
{onCancel && (
|
||||
<button onClick={handleCancel} disabled={cancelling} className={styles.link}>
|
||||
{cancelling ? 'Cancelling...' : 'Cancel'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
'use client'
|
||||
|
||||
import { useAuthContext } from '../../contexts/AuthContext'
|
||||
import { getAuthDetails } from '../../api'
|
||||
import { AuthDjangoForm } from '../AuthDjangoForm'
|
||||
|
||||
interface SignupViewProps {
|
||||
/** Called after successful signup */
|
||||
onSuccess?: () => void
|
||||
/** Called when user clicks "Already have an account? Sign in" */
|
||||
onLoginClick?: () => void
|
||||
}
|
||||
|
||||
export function SignupView({
|
||||
onSuccess,
|
||||
onLoginClick,
|
||||
}: SignupViewProps) {
|
||||
const { refresh } = useAuthContext()
|
||||
|
||||
const handleSuccess = async () => {
|
||||
const newAuth = await refresh()
|
||||
const details = getAuthDetails(newAuth)
|
||||
|
||||
if (details.isAuthenticated) {
|
||||
onSuccess?.()
|
||||
}
|
||||
}
|
||||
|
||||
const footerLinks: Array<{ label: string; onClick?: () => void }> = []
|
||||
|
||||
if (onLoginClick) {
|
||||
footerLinks.push({ label: 'Already have an account? Sign in', onClick: onLoginClick })
|
||||
}
|
||||
|
||||
return (
|
||||
<AuthDjangoForm
|
||||
formName="signup"
|
||||
onSuccess={handleSuccess}
|
||||
footerLinks={footerLinks}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
export { LoginView } from './LoginView'
|
||||
export { SignupView } from './SignupView'
|
||||
export { MFAChooserView } from './MFAChooserView'
|
||||
export { MFAWebAuthnView } from './MFAWebAuthnView'
|
||||
export { MFATOTPView } from './MFATOTPView'
|
||||
export { MFARecoveryCodesView } from './MFARecoveryCodesView'
|
||||
67
packages/mizan-csr/adapters/react/src/allauth/config.ts
Normal file
67
packages/mizan-csr/adapters/react/src/allauth/config.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Configuration for the allauth library.
|
||||
*
|
||||
* This config serves two purposes:
|
||||
* 1. Define the base path for Django-initiated routes (must match HEADLESS_FRONTEND_URLS)
|
||||
* 2. Define where to navigate for various auth events (developer controls these)
|
||||
*
|
||||
* For JWT-based API calls, use mizan/jwt separately.
|
||||
*/
|
||||
|
||||
export interface AllauthConfig {
|
||||
/**
|
||||
* Base path for Django-initiated routes (email verification, password reset, OAuth).
|
||||
* This must match the base path configured in Django's HEADLESS_FRONTEND_URLS.
|
||||
*
|
||||
* Example: '/auth' means Django sends users to '/auth/verify-email/{key}'
|
||||
*/
|
||||
basePath: string
|
||||
|
||||
/**
|
||||
* Navigation targets for auth events.
|
||||
* These are the URLs/paths the developer wants users sent to.
|
||||
*/
|
||||
routes: {
|
||||
/** Where to go after successful authentication */
|
||||
authenticated: string
|
||||
/** Where to go after logout */
|
||||
logout: string
|
||||
/** Where the login page is (for "Back to login" links) */
|
||||
login: string
|
||||
/** Where the signup page is (for "Create account" links) */
|
||||
signup: string
|
||||
}
|
||||
}
|
||||
|
||||
export const defaultConfig: AllauthConfig = {
|
||||
basePath: '/auth',
|
||||
routes: {
|
||||
authenticated: '/dashboard',
|
||||
logout: '/',
|
||||
login: '/login',
|
||||
signup: '/signup',
|
||||
},
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a config by merging provided options with defaults.
|
||||
*/
|
||||
export function createAllauthConfig(config: Partial<AllauthConfig>): AllauthConfig {
|
||||
return {
|
||||
basePath: config.basePath ?? defaultConfig.basePath,
|
||||
routes: {
|
||||
...defaultConfig.routes,
|
||||
...config.routes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Django-initiated flow paths (relative to basePath).
|
||||
* These must match what's configured in Django's HEADLESS_FRONTEND_URLS.
|
||||
*/
|
||||
export const DjangoFlowPaths = {
|
||||
VERIFY_EMAIL: '/verify-email',
|
||||
RESET_PASSWORD: '/reset-password',
|
||||
OAUTH_ERROR: '/oauth/error',
|
||||
} as const
|
||||
@@ -0,0 +1,72 @@
|
||||
'use client'
|
||||
|
||||
import { useMemo } from 'react'
|
||||
import { useDjangoCSRClient, Auth } from 'mizan/client/react'
|
||||
import { useAuthContext } from './AuthContext'
|
||||
import { createAPI, AllauthAPI, BrowserFormAction } from '../api'
|
||||
|
||||
/**
|
||||
* Browser form action for OAuth redirects.
|
||||
* Creates and submits a form programmatically.
|
||||
*/
|
||||
const browserFormAction: BrowserFormAction = (action: string, data: Record<string, string>) => {
|
||||
const form = document.createElement('form')
|
||||
form.method = 'POST'
|
||||
form.action = action
|
||||
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
const input = document.createElement('input')
|
||||
input.type = 'hidden'
|
||||
input.name = key
|
||||
input.value = value
|
||||
form.appendChild(input)
|
||||
}
|
||||
|
||||
document.body.appendChild(form)
|
||||
form.submit()
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook that returns the Allauth API with automatic auth refresh on relevant responses.
|
||||
*
|
||||
* Automatically triggers auth refresh when:
|
||||
* - 401 with flows (authentication required)
|
||||
* - 410 (session gone)
|
||||
* - 200 with is_authenticated (successful auth)
|
||||
*/
|
||||
export function useAllauthAPI(): AllauthAPI {
|
||||
const client = useDjangoCSRClient(Auth.SESSION)
|
||||
const { refresh } = useAuthContext()
|
||||
|
||||
return useMemo(() => {
|
||||
const authRequest = async (method: string, path: string, data?: any, headers?: Record<string, string>) => {
|
||||
const resp = await client.request(method, `/_allauth/browser/v1${path}`, data, headers)
|
||||
|
||||
if (resp.status >= 500) {
|
||||
throw new Error(`Allauth request failed: ${resp.status} ${resp.statusText}`)
|
||||
}
|
||||
|
||||
try {
|
||||
return await resp.json()
|
||||
} catch {
|
||||
throw new Error(`Allauth request failed: ${resp.status} ${resp.statusText}`)
|
||||
}
|
||||
}
|
||||
|
||||
return createAPI(
|
||||
async (method, path, data?, headers?) => {
|
||||
const resp = await authRequest(method, path, { ...(data as object), client: 'browser' }, headers)
|
||||
|
||||
// Auto-refresh auth state on relevant responses
|
||||
if (resp.status === 401 && resp.data?.flows) {
|
||||
refresh(resp)
|
||||
} else if ([401, 410].includes(resp.status) || (resp.status === 200 && resp.meta?.is_authenticated)) {
|
||||
refresh()
|
||||
}
|
||||
|
||||
return resp
|
||||
},
|
||||
browserFormAction
|
||||
)
|
||||
}, [client, refresh])
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
'use client'
|
||||
|
||||
import { ReactNode, useEffect, useState } from 'react'
|
||||
import { useDjangoCSRClient, Auth } from 'mizan/client/react'
|
||||
import type { RouterAdapter } from '../adapters/router'
|
||||
import type { InitialAuth } from '../hydration'
|
||||
import { AuthContext } from './AuthContext'
|
||||
import { ConfigContext } from './ConfigContext'
|
||||
import { StylesContext } from './StylesContext'
|
||||
import { RouterContext } from './RouterContext'
|
||||
import { AllauthConfig } from '../config'
|
||||
import { AuthClassNames } from '../styles/types'
|
||||
import { createAPI } from '../api'
|
||||
|
||||
export interface AllauthContextProps {
|
||||
children: ReactNode
|
||||
|
||||
/** Router adapter for navigation */
|
||||
router: RouterAdapter
|
||||
|
||||
/** Optional initial auth state from getInitialAuth() - if not provided, fetches client-side */
|
||||
hydration?: InitialAuth
|
||||
|
||||
/** Library configuration (basePath, routes) */
|
||||
allauthConfig?: Partial<AllauthConfig>
|
||||
|
||||
/** CSS class names for styling components */
|
||||
classNames?: AuthClassNames
|
||||
}
|
||||
|
||||
/**
|
||||
* Core AllauthContext - sets up all contexts for the allauth library.
|
||||
*
|
||||
* IMPORTANT: AllauthContext must be wrapped by DjangoContext, which provides
|
||||
* user data via useUser(). The typical setup is:
|
||||
*
|
||||
* ```tsx
|
||||
* <DjangoContext client={client} hydration={djangoHydration}>
|
||||
* <AllauthContext hydration={allauthHydration}>
|
||||
* {children}
|
||||
* </AllauthContext>
|
||||
* </DjangoContext>
|
||||
* ```
|
||||
*
|
||||
* If hydration is provided (from SSR), uses it immediately.
|
||||
* If not provided, fetches initial auth state client-side using the CSR client.
|
||||
*
|
||||
* For Next.js apps, use NextAllauthContext instead which handles the router automatically.
|
||||
*/
|
||||
export function AllauthContext({
|
||||
children,
|
||||
router,
|
||||
hydration,
|
||||
allauthConfig,
|
||||
classNames,
|
||||
}: AllauthContextProps) {
|
||||
const client = useDjangoCSRClient(Auth.SESSION)
|
||||
const [initialAuth, setInitialAuth] = useState<InitialAuth | null>(hydration ?? null)
|
||||
const [loading, setLoading] = useState(!hydration)
|
||||
|
||||
useEffect(() => {
|
||||
if (hydration) return // Already have SSR hydration
|
||||
|
||||
const fetchInitialAuth = async () => {
|
||||
try {
|
||||
const authRequest = async (method: string, path: string, data?: any, headers?: Record<string, string>) => {
|
||||
const resp = await client.request(method, `/_allauth/browser/v1${path}`, data, headers)
|
||||
if (resp.status >= 500) {
|
||||
throw new Error(`Allauth request failed: ${resp.status} ${resp.statusText}`)
|
||||
}
|
||||
return resp.json()
|
||||
}
|
||||
|
||||
const api = createAPI((method, path, data?, headers?) =>
|
||||
authRequest(method, path, { ...(data as object), client: 'browser' }, headers)
|
||||
)
|
||||
|
||||
const [config, auth] = await Promise.all([
|
||||
api.getConfig(),
|
||||
api.session.getStatus(),
|
||||
])
|
||||
|
||||
setInitialAuth({ config, auth })
|
||||
} catch (e) {
|
||||
console.error('[AllauthContext] Failed to fetch initial auth:', e)
|
||||
setInitialAuth({
|
||||
config: { status: 200, data: {} },
|
||||
auth: { status: 401, data: {} },
|
||||
})
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
fetchInitialAuth()
|
||||
}, [client, hydration])
|
||||
|
||||
if (loading || !initialAuth) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<RouterContext router={router}>
|
||||
<ConfigContext config={allauthConfig}>
|
||||
<StylesContext classNames={classNames}>
|
||||
<AuthContext
|
||||
config={initialAuth.config}
|
||||
auth={initialAuth.auth}
|
||||
>
|
||||
{children}
|
||||
</AuthContext>
|
||||
</StylesContext>
|
||||
</ConfigContext>
|
||||
</RouterContext>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
'use client'
|
||||
|
||||
import { createContext, ReactNode, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useDjangoCSRClient, Auth } from 'mizan/client/react'
|
||||
import { useMizan, useMizanContext } from 'mizan'
|
||||
import { getAuthDetails, createAPI } from '../api'
|
||||
import type { AllauthResponse } from '../types'
|
||||
import getAuthChangeEvent from '../events'
|
||||
|
||||
export interface AuthState {
|
||||
config: AllauthResponse
|
||||
auth: AllauthResponse
|
||||
event: string
|
||||
refresh: (newAuth?: AllauthResponse) => Promise<AllauthResponse>
|
||||
}
|
||||
|
||||
const Context = createContext<AuthState | null>(null)
|
||||
|
||||
export interface AuthContextProps {
|
||||
children: ReactNode
|
||||
/** Initial config from hydration */
|
||||
config: AllauthResponse
|
||||
/** Initial auth from hydration */
|
||||
auth: AllauthResponse
|
||||
}
|
||||
|
||||
export function AuthContext({
|
||||
children,
|
||||
config,
|
||||
auth: initialAuth,
|
||||
}: AuthContextProps) {
|
||||
const client = useDjangoCSRClient(Auth.SESSION)
|
||||
const { refreshAllContexts } = useMizan()
|
||||
const [auth, setAuth] = useState(initialAuth)
|
||||
const [event, setEvent] = useState('')
|
||||
const prevAuth = useRef(initialAuth)
|
||||
|
||||
// Create API for refresh operations
|
||||
const baseAPI = useMemo(() => {
|
||||
const authRequest = async (method: string, path: string, data?: any, headers?: Record<string, string>) => {
|
||||
const resp = await client.request(method, `/_allauth/browser/v1${path}`, data, headers)
|
||||
if (resp.status >= 500) {
|
||||
throw new Error(`Allauth request failed: ${resp.status} ${resp.statusText}`)
|
||||
}
|
||||
return resp.json()
|
||||
}
|
||||
return createAPI((method, path, data?, headers?) =>
|
||||
authRequest(method, path, { ...(data as object), client: 'browser' }, headers)
|
||||
)
|
||||
}, [client])
|
||||
|
||||
const refresh = useCallback(async (newAuth?: AllauthResponse): Promise<AllauthResponse> => {
|
||||
const authState = newAuth ?? await baseAPI.session.getStatus()
|
||||
setAuth(authState)
|
||||
|
||||
// Refresh all Django contexts (user data, permissions, etc.)
|
||||
await refreshAllContexts()
|
||||
|
||||
return authState
|
||||
}, [baseAPI, refreshAllContexts])
|
||||
|
||||
useEffect(() => {
|
||||
if (prevAuth.current && auth) {
|
||||
setEvent(getAuthChangeEvent(prevAuth.current, auth))
|
||||
}
|
||||
prevAuth.current = auth
|
||||
}, [auth])
|
||||
|
||||
const contextValue = useMemo(() => ({
|
||||
config, auth, event, refresh
|
||||
}), [config, auth, event, refresh])
|
||||
|
||||
return (
|
||||
<Context value={contextValue}>
|
||||
{children}
|
||||
</Context>
|
||||
)
|
||||
}
|
||||
|
||||
export function useAuthContext(): AuthState {
|
||||
const ctx = useContext(Context)
|
||||
if (!ctx) throw new Error('useAuthContext must be used within AuthContext')
|
||||
return ctx
|
||||
}
|
||||
|
||||
export function useAuth() {
|
||||
return getAuthDetails(useAuthContext().auth)
|
||||
}
|
||||
|
||||
/**
|
||||
* Base user interface expected by Allauth.
|
||||
* Products can extend this with additional fields.
|
||||
*/
|
||||
export interface AllauthUser {
|
||||
email?: string
|
||||
first_name?: string
|
||||
last_name?: string
|
||||
is_staff?: boolean
|
||||
is_superuser?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current user from MizanProvider.
|
||||
*
|
||||
* This uses the generic mizan hook to access the 'user' context.
|
||||
* The backend defines this context in lib/mizan/allauth/contexts.py:
|
||||
*
|
||||
* @client(context='global')
|
||||
* def user(request) -> UserOutput | None:
|
||||
* ...
|
||||
*
|
||||
* @typeParam T - User type (defaults to AllauthUser, products can use more specific types)
|
||||
*/
|
||||
export function useUser<T extends AllauthUser = AllauthUser>(): T {
|
||||
const user = useMizanContext<T>('user')
|
||||
// Return empty object cast to T if user is undefined (not loaded)
|
||||
// This matches the previous behavior and allows optional chaining
|
||||
return (user ?? {}) as T
|
||||
}
|
||||
|
||||
export function useConfig() {
|
||||
return useAuthContext().config
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access backend feature flags from the allauth configuration.
|
||||
*/
|
||||
export function useFeatures() {
|
||||
const config = useConfig()
|
||||
const data = config?.data as {
|
||||
account?: {
|
||||
is_open_for_signup?: boolean
|
||||
login_by_code_enabled?: boolean
|
||||
email_verification_by_code_enabled?: boolean
|
||||
}
|
||||
mfa?: {
|
||||
supported_types?: string[]
|
||||
}
|
||||
socialaccount?: {
|
||||
providers?: any[]
|
||||
}
|
||||
} | undefined
|
||||
|
||||
return {
|
||||
signupEnabled: data?.account?.is_open_for_signup ?? true,
|
||||
loginByCodeEnabled: data?.account?.login_by_code_enabled ?? false,
|
||||
emailVerificationByCodeEnabled: data?.account?.email_verification_by_code_enabled ?? false,
|
||||
mfaEnabled: (data?.mfa?.supported_types?.length ?? 0) > 0,
|
||||
mfaTypes: data?.mfa?.supported_types ?? [],
|
||||
socialLoginEnabled: (data?.socialaccount?.providers?.length ?? 0) > 0,
|
||||
socialProviders: data?.socialaccount?.providers ?? [],
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
'use client'
|
||||
|
||||
import { createContext, ReactNode, useContext, useMemo } from 'react'
|
||||
import { AllauthConfig, defaultConfig, createAllauthConfig } from '../config'
|
||||
|
||||
const Context = createContext<AllauthConfig>(defaultConfig)
|
||||
|
||||
interface ConfigContextProps {
|
||||
children: ReactNode
|
||||
config?: Partial<AllauthConfig>
|
||||
}
|
||||
|
||||
export function ConfigContext({ children, config }: ConfigContextProps) {
|
||||
// Memoize the merged config to prevent creating new objects on every render
|
||||
const mergedConfig = useMemo(
|
||||
() => config ? createAllauthConfig(config) : defaultConfig,
|
||||
[config]
|
||||
)
|
||||
|
||||
return (
|
||||
<Context value={mergedConfig}>
|
||||
{children}
|
||||
</Context>
|
||||
)
|
||||
}
|
||||
|
||||
export function useAllauthConfig(): AllauthConfig {
|
||||
return useContext(Context)
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
'use client'
|
||||
|
||||
import { createContext, useContext, type ReactNode } from 'react'
|
||||
import type { RouterAdapter } from '../adapters/router'
|
||||
|
||||
const Context = createContext<RouterAdapter | null>(null)
|
||||
|
||||
interface RouterContextProps {
|
||||
children: ReactNode
|
||||
router: RouterAdapter
|
||||
}
|
||||
|
||||
export function RouterContext({ children, router }: RouterContextProps) {
|
||||
return (
|
||||
<Context value={router}>
|
||||
{children}
|
||||
</Context>
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access the router adapter.
|
||||
* Must be used within AllauthContext.
|
||||
*/
|
||||
export function useRouter(): RouterAdapter {
|
||||
const router = useContext(Context)
|
||||
if (!router) {
|
||||
throw new Error('useRouter must be used within AllauthContext')
|
||||
}
|
||||
return router
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
'use client'
|
||||
|
||||
import { createContext, useContext, ReactNode } from 'react'
|
||||
import { AuthClassNames, emptyClassNames } from '../styles/types'
|
||||
|
||||
const Context = createContext<AuthClassNames>(emptyClassNames)
|
||||
|
||||
interface StylesContextProps {
|
||||
children: ReactNode
|
||||
classNames?: AuthClassNames
|
||||
}
|
||||
|
||||
export function StylesContext({ children, classNames }: StylesContextProps) {
|
||||
return (
|
||||
<Context value={classNames ?? emptyClassNames}>
|
||||
{children}
|
||||
</Context>
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access auth component class names.
|
||||
*
|
||||
* Returns the class names provided to AllauthProvider, or empty strings if none provided.
|
||||
* Use this to style custom components consistently with the auth UI.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* function MyAuthComponent() {
|
||||
* const styles = useStyles()
|
||||
* return (
|
||||
* <div className={styles.card}>
|
||||
* <h1 className={styles.title}>Custom Auth View</h1>
|
||||
* </div>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function useStyles(): AuthClassNames {
|
||||
return useContext(Context)
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility to get a class name, returning empty string if undefined.
|
||||
* Useful for conditional class application.
|
||||
*/
|
||||
export function cx(...classNames: (string | undefined | false | null)[]): string {
|
||||
return classNames.filter(Boolean).join(' ')
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
export { AllauthContext, type AllauthContextProps } from './AllauthContext'
|
||||
export { AuthContext, useAuthContext, useAuth, useUser, useConfig, useFeatures } from './AuthContext'
|
||||
export { useAllauthAPI } from './APIContext'
|
||||
export { ConfigContext, useAllauthConfig } from './ConfigContext'
|
||||
export { StylesContext, useStyles, cx } from './StylesContext'
|
||||
export { RouterContext, useRouter } from './RouterContext'
|
||||
71
packages/mizan-csr/adapters/react/src/allauth/defines.ts
Normal file
71
packages/mizan-csr/adapters/react/src/allauth/defines.ts
Normal file
@@ -0,0 +1,71 @@
|
||||
export const OAuthProcess = {
|
||||
LOGIN: 'login',
|
||||
CONNECT: 'connect'
|
||||
}
|
||||
|
||||
export const AuthenticatorType = {
|
||||
TOTP: 'totp',
|
||||
RECOVERY_CODES: 'recovery_codes',
|
||||
WEBAUTHN: 'webauthn'
|
||||
}
|
||||
|
||||
export const Flows = {
|
||||
LOGIN: 'login',
|
||||
LOGIN_BY_CODE: 'login_by_code',
|
||||
MFA_AUTHENTICATE: 'mfa_authenticate',
|
||||
MFA_REAUTHENTICATE: 'mfa_reauthenticate',
|
||||
MFA_TRUST: 'mfa_trust',
|
||||
MFA_WEBAUTHN_SIGNUP: 'mfa_signup_webauthn',
|
||||
PASSWORD_RESET_BY_CODE: 'password_reset_by_code',
|
||||
PROVIDER_REDIRECT: 'provider_redirect',
|
||||
PROVIDER_SIGNUP: 'provider_signup',
|
||||
REAUTHENTICATE: 'reauthenticate',
|
||||
SIGNUP: 'signup',
|
||||
VERIFY_EMAIL: 'verify_email',
|
||||
}
|
||||
|
||||
export const apiURL = {
|
||||
// Meta
|
||||
CONFIG: '/config',
|
||||
|
||||
// Account management
|
||||
CHANGE_PASSWORD: '/account/password/change',
|
||||
EMAIL: '/account/email',
|
||||
PROVIDERS: '/account/providers',
|
||||
|
||||
// Account management: 2FA
|
||||
AUTHENTICATORS: '/account/authenticators',
|
||||
RECOVERY_CODES: '/account/authenticators/recovery-codes',
|
||||
TOTP_AUTHENTICATOR: '/account/authenticators/totp',
|
||||
|
||||
// Auth: Basics
|
||||
LOGIN: '/auth/login',
|
||||
REQUEST_LOGIN_CODE: '/auth/code/request',
|
||||
CONFIRM_LOGIN_CODE: '/auth/code/confirm',
|
||||
SESSION: '/auth/session',
|
||||
REAUTHENTICATE: '/auth/reauthenticate',
|
||||
REQUEST_PASSWORD_RESET: '/auth/password/request',
|
||||
RESET_PASSWORD: '/auth/password/reset',
|
||||
SIGNUP: '/auth/signup',
|
||||
VERIFY_EMAIL: '/auth/email/verify',
|
||||
|
||||
// Auth: 2FA
|
||||
MFA_AUTHENTICATE: '/auth/2fa/authenticate',
|
||||
MFA_REAUTHENTICATE: '/auth/2fa/reauthenticate',
|
||||
MFA_TRUST: '/auth/2fa/trust',
|
||||
|
||||
// Auth: Social
|
||||
PROVIDER_SIGNUP: '/auth/provider/signup',
|
||||
REDIRECT_TO_PROVIDER: '/auth/provider/redirect',
|
||||
PROVIDER_TOKEN: '/auth/provider/token',
|
||||
|
||||
// Auth: Sessions
|
||||
SESSIONS: '/auth/sessions',
|
||||
|
||||
// Auth: WebAuthn
|
||||
REAUTHENTICATE_WEBAUTHN: '/auth/webauthn/reauthenticate',
|
||||
AUTHENTICATE_WEBAUTHN: '/auth/webauthn/authenticate',
|
||||
LOGIN_WEBAUTHN: '/auth/webauthn/login',
|
||||
SIGNUP_WEBAUTHN: '/auth/webauthn/signup',
|
||||
WEBAUTHN_AUTHENTICATOR: '/account/authenticators/webauthn'
|
||||
}
|
||||
51
packages/mizan-csr/adapters/react/src/allauth/events.ts
Normal file
51
packages/mizan-csr/adapters/react/src/allauth/events.ts
Normal file
@@ -0,0 +1,51 @@
|
||||
import { getAuthDetails } from './api'
|
||||
import type { AllauthResponse, AuthenticationMethod } from './types'
|
||||
|
||||
export const AuthChangeEvent = {
|
||||
LOGGED_OUT: 'LOGGED_OUT',
|
||||
LOGGED_IN: 'LOGGED_IN',
|
||||
REAUTHENTICATED: 'REAUTHENTICATED',
|
||||
REAUTHENTICATION_REQUIRED: 'REAUTHENTICATION_REQUIRED',
|
||||
FLOW_UPDATED: 'FLOW_UPDATED'
|
||||
}
|
||||
|
||||
export default function getAuthChangeEvent(fromAuth: AllauthResponse, toAuth: AllauthResponse): string {
|
||||
let before = getAuthDetails(fromAuth)
|
||||
const after = getAuthDetails(toAuth)
|
||||
|
||||
if (toAuth.status === 410) {
|
||||
return AuthChangeEvent.LOGGED_OUT
|
||||
}
|
||||
|
||||
const shouldReauth = () => {
|
||||
const fromMethods = (fromAuth.data?.methods as AuthenticationMethod[] | undefined) ?? []
|
||||
const toMethods = (toAuth.data?.methods as AuthenticationMethod[] | undefined) ?? []
|
||||
return (before.requiresReauthentication) || (fromMethods.length < toMethods.length)
|
||||
}
|
||||
|
||||
// Corner case: user ID change. Treat as if we're transitioning from anonymous state.
|
||||
if (before.user && after.user && before.user?.id !== after.user?.id) {
|
||||
before = { isAuthenticated: false, requiresReauthentication: false, user: null, pendingFlow: undefined }
|
||||
}
|
||||
|
||||
if (!before.isAuthenticated && after.isAuthenticated) {
|
||||
return AuthChangeEvent.LOGGED_IN
|
||||
} else if (before.isAuthenticated && !after.isAuthenticated) {
|
||||
return AuthChangeEvent.LOGGED_OUT
|
||||
} else if (before.isAuthenticated && after.isAuthenticated) {
|
||||
if (after.requiresReauthentication) {
|
||||
return AuthChangeEvent.REAUTHENTICATION_REQUIRED
|
||||
} else if (shouldReauth()) {
|
||||
return AuthChangeEvent.REAUTHENTICATED
|
||||
}
|
||||
} else if (!before.isAuthenticated && !after.isAuthenticated) {
|
||||
const fromFlow = before.pendingFlow
|
||||
const toFlow = after.pendingFlow
|
||||
if (toFlow?.id && fromFlow?.id !== toFlow.id) {
|
||||
return AuthChangeEvent.FLOW_UPDATED
|
||||
}
|
||||
}
|
||||
|
||||
// No change.
|
||||
return ''
|
||||
}
|
||||
48
packages/mizan-csr/adapters/react/src/allauth/hydration.ts
Normal file
48
packages/mizan-csr/adapters/react/src/allauth/hydration.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import type { DjangoHTTPClient } from 'mizan/client'
|
||||
import { createAPI } from './api'
|
||||
import type { AllauthResponse } from './types'
|
||||
|
||||
export interface InitialAuth {
|
||||
config: AllauthResponse
|
||||
auth: AllauthResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch initial allauth state using an SSR client.
|
||||
* Call this in a server component and pass the result to AllauthContext.
|
||||
*
|
||||
* Note: User data comes from DjangoContext (which should wrap AllauthContext).
|
||||
* Use getDjangoHydration() from generated.contexts for that.
|
||||
*
|
||||
* @param ssrClient - A server-side Django HTTP client (e.g., createDjangoSSRClient)
|
||||
*/
|
||||
export async function getInitialAuth(
|
||||
ssrClient: DjangoHTTPClient,
|
||||
): Promise<InitialAuth> {
|
||||
const authRequest = async (method: string, path: string, data?: any, headers?: Record<string, string>) => {
|
||||
const resp = await ssrClient.request(method, `/_allauth/browser/v1${path}`, data, headers)
|
||||
if (resp.status >= 500) {
|
||||
throw new Error(`Allauth request failed: ${resp.status} ${resp.statusText}`)
|
||||
}
|
||||
return resp.json()
|
||||
}
|
||||
|
||||
const api = createAPI((method, path, data?, headers?) =>
|
||||
authRequest(method, path, { ...(data as object), client: 'browser' }, headers)
|
||||
)
|
||||
|
||||
try {
|
||||
const [config, auth] = await Promise.all([
|
||||
api.getConfig(),
|
||||
api.session.getStatus(),
|
||||
])
|
||||
|
||||
return { config, auth }
|
||||
} catch (e) {
|
||||
console.error('[getInitialAuth] Failed to fetch initial auth:', e)
|
||||
return {
|
||||
config: { status: 200, data: {} },
|
||||
auth: { status: 401, data: {} },
|
||||
}
|
||||
}
|
||||
}
|
||||
213
packages/mizan-csr/adapters/react/src/allauth/index.ts
Normal file
213
packages/mizan-csr/adapters/react/src/allauth/index.ts
Normal file
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* mizan/allauth
|
||||
*
|
||||
* React integration for django-allauth headless API.
|
||||
* Framework-agnostic - works with Next.js, Remix, React Router, etc.
|
||||
*
|
||||
* ## Quick Start (Next.js)
|
||||
*
|
||||
* ```tsx
|
||||
* // layout.tsx
|
||||
* import { cookies } from 'next/headers'
|
||||
* import { createDjangoSSRClient } from 'mizan/client'
|
||||
* import { getInitialAuth } from 'mizan/allauth'
|
||||
* import { NextAllauthContext } from 'mizan/allauth/nextjs'
|
||||
*
|
||||
* export default async function RootLayout({ children }) {
|
||||
* const ssrClient = createDjangoSSRClient({ cookies: await cookies() })
|
||||
* const hydration = await getInitialAuth(ssrClient)
|
||||
*
|
||||
* return (
|
||||
* <NextAllauthContext hydration={hydration}>
|
||||
* {children}
|
||||
* </NextAllauthContext>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* ## Without SSR (pure client-side)
|
||||
*
|
||||
* ```tsx
|
||||
* // Just omit hydration - AllauthContext will fetch client-side
|
||||
* <NextAllauthContext>
|
||||
* {children}
|
||||
* </NextAllauthContext>
|
||||
* ```
|
||||
*/
|
||||
|
||||
// Configuration
|
||||
export { createAllauthConfig, defaultConfig, DjangoFlowPaths } from './config'
|
||||
export type { AllauthConfig } from './config'
|
||||
|
||||
// Hydration
|
||||
export { getInitialAuth } from './hydration'
|
||||
export type { InitialAuth } from './hydration'
|
||||
|
||||
// Providers
|
||||
export { AllauthContext } from './contexts/AllauthContext'
|
||||
export type { AllauthContextProps } from './contexts/AllauthContext'
|
||||
|
||||
// Router adapter
|
||||
export type { RouterAdapter } from './adapters/router'
|
||||
export { useRouter } from './contexts/RouterContext'
|
||||
|
||||
// Hooks
|
||||
export { useAuthContext, useAuth, useUser, useConfig, useFeatures } from './contexts/AuthContext'
|
||||
export { useAllauthAPI } from './contexts/APIContext'
|
||||
export { useAllauthConfig } from './contexts/ConfigContext'
|
||||
export { useStyles, cx } from './contexts/StylesContext'
|
||||
|
||||
// Styling
|
||||
export type { AuthClassNames } from './styles/types'
|
||||
|
||||
// Components
|
||||
export {
|
||||
// Main UI component (SPA - handles login, signup, MFA, settings, logout)
|
||||
AllauthUI,
|
||||
// Django-initiated flow handler (email verification, password reset links, OAuth)
|
||||
AllauthRouter,
|
||||
// Settings
|
||||
AuthSettings,
|
||||
ProfileSection,
|
||||
EmailsSection,
|
||||
PasswordSection,
|
||||
PasskeysSection,
|
||||
ConnectionsSection,
|
||||
MFASection,
|
||||
SessionsSection,
|
||||
SettingsSection,
|
||||
SettingsItem,
|
||||
SettingsList,
|
||||
Badge,
|
||||
Button,
|
||||
// Individual auth views
|
||||
LoginView,
|
||||
SignupView,
|
||||
MFAChooserView,
|
||||
MFAWebAuthnView,
|
||||
MFATOTPView,
|
||||
MFARecoveryCodesView,
|
||||
// Building blocks
|
||||
AuthCard,
|
||||
AuthFormPage,
|
||||
AuthDjangoForm,
|
||||
PasskeyLogin,
|
||||
ProviderList,
|
||||
// Form utilities
|
||||
useAuthForm,
|
||||
AuthField,
|
||||
} from './components'
|
||||
export type { AllauthUIView, AllauthUIMode } from './components'
|
||||
|
||||
// Routing guards
|
||||
export { UserRoute, StaffRoute, AnonymousRoute, FeatureRoute } from './routing'
|
||||
|
||||
// API
|
||||
export { createAPI, getAuthDetails } from './api'
|
||||
export type { AuthResponse, AuthDetails, AllauthAPI, BrowserFormAction } from './api'
|
||||
|
||||
// Types (re-exported from types.ts)
|
||||
export type {
|
||||
// Primitive types
|
||||
Timestamp,
|
||||
Email,
|
||||
Phone,
|
||||
Username,
|
||||
Password,
|
||||
Code,
|
||||
AuthenticatorCode,
|
||||
ProviderID,
|
||||
ProviderAccountID,
|
||||
AuthenticatorID,
|
||||
ClientID,
|
||||
// Enums
|
||||
AuthenticatorType as AuthenticatorTypeEnum,
|
||||
FlowID,
|
||||
LoginMethod,
|
||||
OAuthProcess,
|
||||
ProviderFlow,
|
||||
// User & Session
|
||||
User,
|
||||
Session,
|
||||
EmailAddress,
|
||||
PhoneNumber,
|
||||
// Authentication
|
||||
Flow,
|
||||
AuthenticationMethod,
|
||||
Authenticated,
|
||||
ReauthenticationRequired,
|
||||
// Provider
|
||||
Provider,
|
||||
ProviderAccount,
|
||||
// MFA / Authenticator
|
||||
BaseAuthenticator,
|
||||
TOTPAuthenticator,
|
||||
RecoveryCodesAuthenticator,
|
||||
SensitiveRecoveryCodesAuthenticator,
|
||||
WebAuthnAuthenticator,
|
||||
Authenticator,
|
||||
// Configuration
|
||||
AccountConfiguration,
|
||||
SocialAccountConfiguration,
|
||||
MFAConfiguration,
|
||||
UserSessionsConfiguration,
|
||||
AllauthConfiguration,
|
||||
// WebAuthn
|
||||
WebAuthnPublicKeyCredentialCreationOptions,
|
||||
WebAuthnPublicKeyCredentialRequestOptions,
|
||||
WebAuthnCreationOptions,
|
||||
WebAuthnRequestOptions,
|
||||
// TOTP
|
||||
TOTPStatus,
|
||||
// Meta
|
||||
BaseAuthenticationMeta,
|
||||
AuthenticationMeta,
|
||||
AuthenticatedMeta,
|
||||
// Response types
|
||||
AuthError,
|
||||
AllauthResponse,
|
||||
AuthenticatedResponse,
|
||||
ConfigurationResponse,
|
||||
EmailListResponse,
|
||||
SessionListResponse,
|
||||
AuthenticatorListResponse,
|
||||
ProviderAccountListResponse,
|
||||
TOTPStatusResponse,
|
||||
RecoveryCodesResponse,
|
||||
WebAuthnCreationOptionsResponse,
|
||||
WebAuthnRequestOptionsResponse,
|
||||
EmailVerificationInfoResponse,
|
||||
AuthenticationRequiredResponse,
|
||||
ReauthenticationRequiredResponse,
|
||||
ErrorResponse,
|
||||
ForbiddenResponse,
|
||||
ConflictResponse,
|
||||
SessionGoneResponse,
|
||||
// Request types
|
||||
LoginRequest,
|
||||
SignupRequest,
|
||||
ProviderSignupRequest,
|
||||
ReauthenticateRequest,
|
||||
RequestLoginCodeRequest,
|
||||
ConfirmLoginCodeRequest,
|
||||
MFAAuthenticateRequest,
|
||||
MFATrustRequest,
|
||||
RequestPasswordResetRequest,
|
||||
ResetPasswordRequest,
|
||||
VerifyEmailRequest,
|
||||
ChangePasswordRequest,
|
||||
AddEmailRequest,
|
||||
ProviderRedirectRequest,
|
||||
ProviderTokenRequest,
|
||||
WebAuthnAddRequest,
|
||||
WebAuthnAuthenticateRequest,
|
||||
WebAuthnUpdateRequest,
|
||||
WebAuthnDeleteRequest,
|
||||
EndSessionsRequest,
|
||||
// Union types
|
||||
AuthResponse as AuthResponseUnion,
|
||||
SessionStatusResponse,
|
||||
} from './types'
|
||||
|
||||
// Constants
|
||||
export { Flows, AuthenticatorType } from './defines'
|
||||
96
packages/mizan-csr/adapters/react/src/allauth/nextjs.tsx
Normal file
96
packages/mizan-csr/adapters/react/src/allauth/nextjs.tsx
Normal file
@@ -0,0 +1,96 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* Next.js adapter for mizan/allauth.
|
||||
*
|
||||
* Usage:
|
||||
* ```tsx
|
||||
* // In layout.tsx (server component)
|
||||
* import { createDjangoSSRClient } from 'mizan/client'
|
||||
* import { getInitialAuth } from 'mizan/allauth'
|
||||
* import { NextAllauthContext } from 'mizan/allauth/nextjs'
|
||||
*
|
||||
* export default async function RootLayout({ children }) {
|
||||
* const ssrClient = createDjangoSSRClient({ cookies: await cookies() })
|
||||
* const hydration = await getInitialAuth(ssrClient)
|
||||
*
|
||||
* return (
|
||||
* <NextAllauthContext hydration={hydration}>
|
||||
* {children}
|
||||
* </NextAllauthContext>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { ReactNode } from 'react'
|
||||
import { useRouter, usePathname, useSearchParams, useParams } from 'next/navigation'
|
||||
import type { RouterAdapter } from './adapters/router'
|
||||
import type { InitialAuth } from './hydration'
|
||||
import { AllauthContext } from './contexts/AllauthContext'
|
||||
import { AllauthConfig } from './config'
|
||||
import { AuthClassNames } from './styles/types'
|
||||
|
||||
/**
|
||||
* Create a RouterAdapter from Next.js App Router hooks.
|
||||
*/
|
||||
export function useNextRouter(): RouterAdapter {
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
const params = useParams()
|
||||
|
||||
return {
|
||||
push: (path: string) => router.push(path),
|
||||
replace: (path: string) => router.replace(path),
|
||||
pathname,
|
||||
searchParams: new URLSearchParams(searchParams.toString()),
|
||||
getParam: (name: string) => params[name] as string | string[] | undefined,
|
||||
}
|
||||
}
|
||||
|
||||
export interface NextAllauthContextProps {
|
||||
children: ReactNode
|
||||
|
||||
/** Optional initial auth state from getInitialAuth() - if not provided, fetches client-side */
|
||||
hydration?: InitialAuth
|
||||
|
||||
/** Library configuration (basePath, routes) */
|
||||
allauthConfig?: Partial<AllauthConfig>
|
||||
|
||||
/** CSS class names for styling components */
|
||||
classNames?: AuthClassNames
|
||||
}
|
||||
|
||||
/**
|
||||
* Next.js-specific AllauthContext that handles the router automatically.
|
||||
*
|
||||
* IMPORTANT: Must be wrapped by DjangoContext which provides user data.
|
||||
*
|
||||
* ```tsx
|
||||
* <DjangoContext client={client} hydration={djangoHydration}>
|
||||
* <NextAllauthContext hydration={allauthHydration}>
|
||||
* {children}
|
||||
* </NextAllauthContext>
|
||||
* </DjangoContext>
|
||||
* ```
|
||||
*/
|
||||
export function NextAllauthContext({
|
||||
children,
|
||||
hydration,
|
||||
allauthConfig,
|
||||
classNames,
|
||||
}: NextAllauthContextProps) {
|
||||
const router = useNextRouter()
|
||||
|
||||
return (
|
||||
<AllauthContext
|
||||
hydration={hydration}
|
||||
router={router}
|
||||
allauthConfig={allauthConfig}
|
||||
classNames={classNames}
|
||||
>
|
||||
{children}
|
||||
</AllauthContext>
|
||||
)
|
||||
}
|
||||
110
packages/mizan-csr/adapters/react/src/allauth/routing.tsx
Normal file
110
packages/mizan-csr/adapters/react/src/allauth/routing.tsx
Normal file
@@ -0,0 +1,110 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect } from 'react'
|
||||
import { useRouter } from './contexts/RouterContext'
|
||||
import { useAllauthConfig } from './contexts/ConfigContext'
|
||||
import { useAuth, useUser, useConfig } from './contexts/AuthContext'
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is authenticated.
|
||||
* Redirects to login page if not authenticated.
|
||||
*/
|
||||
export function UserRoute({ children }: { children: React.ReactNode }) {
|
||||
const router = useRouter()
|
||||
const config = useAllauthConfig()
|
||||
const { isAuthenticated } = useAuth()
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAuthenticated) {
|
||||
const next = encodeURIComponent(router.pathname + router.searchParams.toString())
|
||||
router.replace(`${config.routes.login}?next=${next}`)
|
||||
}
|
||||
}, [isAuthenticated, router, config.routes.login])
|
||||
|
||||
if (!isAuthenticated) return null
|
||||
return children
|
||||
}
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is authenticated AND is staff.
|
||||
* Redirects to login if not authenticated, or to authenticated route if not staff.
|
||||
*/
|
||||
export function StaffRoute({ children }: { children: React.ReactNode }) {
|
||||
const router = useRouter()
|
||||
const config = useAllauthConfig()
|
||||
const { isAuthenticated } = useAuth()
|
||||
const user = useUser()
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAuthenticated) {
|
||||
const next = encodeURIComponent(router.pathname + router.searchParams.toString())
|
||||
router.replace(`${config.routes.login}?next=${next}`)
|
||||
} else if (!user.is_staff) {
|
||||
router.replace(config.routes.authenticated)
|
||||
}
|
||||
}, [isAuthenticated, user.is_staff, router, config.routes])
|
||||
|
||||
if (!isAuthenticated || !user.is_staff) return null
|
||||
return children
|
||||
}
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is NOT authenticated.
|
||||
* Redirects to authenticated route if already logged in.
|
||||
*/
|
||||
export function AnonymousRoute({ children }: { children: React.ReactNode }) {
|
||||
const router = useRouter()
|
||||
const config = useAllauthConfig()
|
||||
const { isAuthenticated } = useAuth()
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
router.replace(config.routes.authenticated)
|
||||
}
|
||||
}, [isAuthenticated, config.routes.authenticated, router])
|
||||
|
||||
if (isAuthenticated) return null
|
||||
return children
|
||||
}
|
||||
|
||||
/**
|
||||
* Route guard that checks if a feature is enabled in the allauth config.
|
||||
* Redirects to fallback if feature is disabled.
|
||||
*/
|
||||
type FeatureKey = 'signup' | 'login_by_code' | 'mfa' | 'socialaccount'
|
||||
|
||||
function isFeatureEnabled(config: any, feature: FeatureKey): boolean | undefined {
|
||||
if (!config?.data) return undefined
|
||||
switch (feature) {
|
||||
case 'signup': return config.data.account?.is_open_for_signup
|
||||
case 'login_by_code': return config.data.account?.login_by_code_enabled
|
||||
case 'mfa': return config.data.mfa !== undefined
|
||||
case 'socialaccount': return (config.data.socialaccount?.providers?.length ?? 0) > 0
|
||||
}
|
||||
}
|
||||
|
||||
export function FeatureRoute({
|
||||
children,
|
||||
feature,
|
||||
redirectTo,
|
||||
}: {
|
||||
children: React.ReactNode
|
||||
feature: FeatureKey
|
||||
redirectTo?: string
|
||||
}) {
|
||||
const router = useRouter()
|
||||
const allauthConfig = useConfig()
|
||||
const config = useAllauthConfig()
|
||||
|
||||
const enabled = isFeatureEnabled(allauthConfig, feature)
|
||||
const fallback = redirectTo ?? config.routes.login
|
||||
|
||||
useEffect(() => {
|
||||
if (allauthConfig && enabled === false) {
|
||||
router.replace(fallback)
|
||||
}
|
||||
}, [allauthConfig, enabled, fallback, router])
|
||||
|
||||
if (!allauthConfig || enabled === false) return null
|
||||
return children
|
||||
}
|
||||
122
packages/mizan-csr/adapters/react/src/allauth/styles/types.ts
Normal file
122
packages/mizan-csr/adapters/react/src/allauth/styles/types.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Class names for styling the auth components.
|
||||
*
|
||||
* All properties are optional - components will use empty strings as defaults.
|
||||
* Pass your own CSS module or Tailwind classes to customize the appearance.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // With CSS Modules
|
||||
* import styles from './auth.module.css'
|
||||
* <AllauthProvider classNames={styles}>
|
||||
*
|
||||
* // With Tailwind
|
||||
* const classNames = {
|
||||
* container: 'max-w-md mx-auto p-4',
|
||||
* card: 'bg-white rounded-lg shadow-lg p-6',
|
||||
* title: 'text-2xl font-bold',
|
||||
* // ...
|
||||
* }
|
||||
* <AllauthProvider classNames={classNames}>
|
||||
* ```
|
||||
*/
|
||||
export interface AuthClassNames {
|
||||
// Layout
|
||||
container?: string
|
||||
card?: string
|
||||
|
||||
// Typography
|
||||
title?: string
|
||||
subtitle?: string
|
||||
|
||||
// Form elements
|
||||
form?: string
|
||||
fieldsContainer?: string
|
||||
field?: string
|
||||
fieldLabel?: string
|
||||
fieldInput?: string
|
||||
fieldHelp?: string
|
||||
fieldError?: string
|
||||
required?: string
|
||||
|
||||
// Buttons
|
||||
submit?: string
|
||||
link?: string
|
||||
smallButton?: string
|
||||
smallButtonPrimary?: string
|
||||
smallButtonSecondary?: string
|
||||
smallButtonDanger?: string
|
||||
|
||||
// Feedback
|
||||
error?: string
|
||||
success?: string
|
||||
loading?: string
|
||||
spinner?: string
|
||||
emptyState?: string
|
||||
|
||||
// Divider
|
||||
divider?: string
|
||||
dividerText?: string
|
||||
|
||||
// Footer
|
||||
footer?: string
|
||||
|
||||
// Code input (for TOTP/login codes)
|
||||
codeInput?: string
|
||||
|
||||
// OAuth providers
|
||||
providersContainer?: string
|
||||
providerButtons?: string
|
||||
providerButton?: string
|
||||
|
||||
// Passkey
|
||||
passkeyContainer?: string
|
||||
passkeyButton?: string
|
||||
|
||||
// Settings page
|
||||
settingsContainer?: string
|
||||
settingsPageTitle?: string
|
||||
settingsCard?: string
|
||||
settingsSection?: string
|
||||
settingsSectionTitle?: string
|
||||
settingsSubtitle?: string
|
||||
settingsList?: string
|
||||
settingsItem?: string
|
||||
settingsItemInfo?: string
|
||||
settingsItemLabel?: string
|
||||
settingsItemMeta?: string
|
||||
settingsItemActions?: string
|
||||
|
||||
// Badges
|
||||
badge?: string
|
||||
badgePrimary?: string
|
||||
badgeUnverified?: string
|
||||
badgeSuccess?: string
|
||||
badgeDanger?: string
|
||||
|
||||
// Inline form
|
||||
inlineForm?: string
|
||||
|
||||
// TOTP setup
|
||||
totpSetup?: string
|
||||
qrCode?: string
|
||||
|
||||
// Recovery codes
|
||||
recoveryCodes?: string
|
||||
|
||||
// Form controls
|
||||
checkbox?: string
|
||||
radioGroup?: string
|
||||
radioItem?: string
|
||||
|
||||
// Navigation
|
||||
navLinks?: string
|
||||
navLink?: string
|
||||
navLinkActive?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Empty class names - used as default when no styles provided.
|
||||
* Components will render without any styling classes.
|
||||
*/
|
||||
export const emptyClassNames: AuthClassNames = {}
|
||||
546
packages/mizan-csr/adapters/react/src/allauth/types.ts
Normal file
546
packages/mizan-csr/adapters/react/src/allauth/types.ts
Normal file
@@ -0,0 +1,546 @@
|
||||
/**
|
||||
* TypeScript types for django-allauth headless API
|
||||
* Generated from OpenAPI specification
|
||||
*/
|
||||
|
||||
// =============================================================================
|
||||
// Primitive Types
|
||||
// =============================================================================
|
||||
|
||||
/** Epoch-based timestamp (use: new Date(value * 1000)) */
|
||||
export type Timestamp = number
|
||||
|
||||
/** Email address */
|
||||
export type Email = string
|
||||
|
||||
/** Phone number */
|
||||
export type Phone = string
|
||||
|
||||
/** Username */
|
||||
export type Username = string
|
||||
|
||||
/** Password */
|
||||
export type Password = string
|
||||
|
||||
/** One-time code */
|
||||
export type Code = string
|
||||
|
||||
/** Authenticator code (e.g., TOTP) */
|
||||
export type AuthenticatorCode = string
|
||||
|
||||
/** Provider ID (e.g., "google", "github") */
|
||||
export type ProviderID = string
|
||||
|
||||
/** Provider-specific account ID */
|
||||
export type ProviderAccountID = string
|
||||
|
||||
/** Authenticator ID */
|
||||
export type AuthenticatorID = number
|
||||
|
||||
/** OAuth client ID */
|
||||
export type ClientID = string
|
||||
|
||||
// =============================================================================
|
||||
// Enums
|
||||
// =============================================================================
|
||||
|
||||
export type AuthenticatorType = 'recovery_codes' | 'totp' | 'webauthn'
|
||||
|
||||
export type FlowID =
|
||||
| 'login'
|
||||
| 'login_by_code'
|
||||
| 'mfa_authenticate'
|
||||
| 'mfa_reauthenticate'
|
||||
| 'provider_redirect'
|
||||
| 'provider_signup'
|
||||
| 'provider_token'
|
||||
| 'reauthenticate'
|
||||
| 'signup'
|
||||
| 'verify_email'
|
||||
| 'verify_phone'
|
||||
|
||||
export type LoginMethod = 'email' | 'username'
|
||||
|
||||
export type OAuthProcess = 'login' | 'connect'
|
||||
|
||||
export type ProviderFlow = 'provider_redirect' | 'provider_token'
|
||||
|
||||
// =============================================================================
|
||||
// User & Session Types
|
||||
// =============================================================================
|
||||
|
||||
export interface User {
|
||||
id?: number
|
||||
display: string
|
||||
email?: string
|
||||
username?: string
|
||||
has_usable_password: boolean
|
||||
}
|
||||
|
||||
export interface Session {
|
||||
id: number
|
||||
user_agent: string
|
||||
ip: string
|
||||
created_at: Timestamp
|
||||
last_seen_at?: Timestamp
|
||||
is_current: boolean
|
||||
}
|
||||
|
||||
export interface EmailAddress {
|
||||
email: Email
|
||||
primary: boolean
|
||||
verified: boolean
|
||||
}
|
||||
|
||||
export interface PhoneNumber {
|
||||
phone: Phone
|
||||
verified: boolean
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Authentication Types
|
||||
// =============================================================================
|
||||
|
||||
export interface Flow {
|
||||
id: FlowID
|
||||
is_pending?: true
|
||||
provider?: Provider
|
||||
/** MFA types available (for mfa_authenticate/mfa_reauthenticate flows) */
|
||||
types?: AuthenticatorType[]
|
||||
}
|
||||
|
||||
export interface AuthenticationMethod {
|
||||
method: 'password' | 'password_reset' | 'code' | 'socialaccount' | 'mfa'
|
||||
at: Timestamp
|
||||
email?: Email
|
||||
phone?: Phone
|
||||
username?: Username
|
||||
provider?: ProviderID
|
||||
uid?: ProviderAccountID
|
||||
type?: AuthenticatorType
|
||||
reauthenticated?: boolean
|
||||
}
|
||||
|
||||
export interface Authenticated {
|
||||
user: User
|
||||
methods: AuthenticationMethod[]
|
||||
}
|
||||
|
||||
export interface ReauthenticationRequired {
|
||||
flows: Flow[]
|
||||
user: User
|
||||
methods: AuthenticationMethod[]
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Provider Types
|
||||
// =============================================================================
|
||||
|
||||
export interface Provider {
|
||||
id: ProviderID
|
||||
name: string
|
||||
client_id?: ClientID
|
||||
openid_configuration_url?: string
|
||||
flows: ProviderFlow[]
|
||||
}
|
||||
|
||||
export interface ProviderAccount {
|
||||
uid: ProviderAccountID
|
||||
display: string
|
||||
provider: Provider
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MFA / Authenticator Types
|
||||
// =============================================================================
|
||||
|
||||
export interface BaseAuthenticator {
|
||||
created_at: Timestamp
|
||||
last_used_at: Timestamp | null
|
||||
}
|
||||
|
||||
export interface TOTPAuthenticator extends BaseAuthenticator {
|
||||
type: 'totp'
|
||||
}
|
||||
|
||||
export interface RecoveryCodesAuthenticator extends BaseAuthenticator {
|
||||
type: 'recovery_codes'
|
||||
total_code_count: number
|
||||
unused_code_count: number
|
||||
}
|
||||
|
||||
export interface SensitiveRecoveryCodesAuthenticator extends RecoveryCodesAuthenticator {
|
||||
unused_codes: AuthenticatorCode[]
|
||||
}
|
||||
|
||||
export interface WebAuthnAuthenticator extends BaseAuthenticator {
|
||||
type: 'webauthn'
|
||||
id: AuthenticatorID
|
||||
name: string
|
||||
is_passwordless?: boolean
|
||||
}
|
||||
|
||||
export type Authenticator = TOTPAuthenticator | RecoveryCodesAuthenticator | WebAuthnAuthenticator
|
||||
|
||||
// =============================================================================
|
||||
// Configuration Types
|
||||
// =============================================================================
|
||||
|
||||
export interface AccountConfiguration {
|
||||
login_methods?: LoginMethod[]
|
||||
is_open_for_signup: boolean
|
||||
email_verification_by_code_enabled: boolean
|
||||
login_by_code_enabled: boolean
|
||||
password_reset_by_code_enabled?: boolean
|
||||
}
|
||||
|
||||
export interface SocialAccountConfiguration {
|
||||
providers: Provider[]
|
||||
}
|
||||
|
||||
export interface MFAConfiguration {
|
||||
supported_types: AuthenticatorType[]
|
||||
passkey_login_enabled?: boolean
|
||||
}
|
||||
|
||||
export interface UserSessionsConfiguration {
|
||||
track_activity: boolean
|
||||
}
|
||||
|
||||
export interface AllauthConfiguration {
|
||||
account: AccountConfiguration
|
||||
socialaccount?: SocialAccountConfiguration
|
||||
mfa?: MFAConfiguration
|
||||
usersessions?: UserSessionsConfiguration
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// WebAuthn Types
|
||||
// =============================================================================
|
||||
|
||||
export interface WebAuthnPublicKeyCredentialCreationOptions {
|
||||
challenge: string
|
||||
rp: {
|
||||
name: string
|
||||
id: string
|
||||
}
|
||||
user: {
|
||||
id: string
|
||||
name: string
|
||||
displayName: string
|
||||
}
|
||||
pubKeyCredParams: Array<{
|
||||
type: 'public-key'
|
||||
alg: number
|
||||
}>
|
||||
timeout?: number
|
||||
excludeCredentials?: Array<{
|
||||
type: 'public-key'
|
||||
id: string
|
||||
}>
|
||||
authenticatorSelection?: {
|
||||
authenticatorAttachment?: 'platform' | 'cross-platform'
|
||||
requireResidentKey?: boolean
|
||||
residentKey?: 'discouraged' | 'preferred' | 'required'
|
||||
userVerification?: 'required' | 'preferred' | 'discouraged'
|
||||
}
|
||||
attestation?: 'none' | 'indirect' | 'direct' | 'enterprise'
|
||||
}
|
||||
|
||||
export interface WebAuthnPublicKeyCredentialRequestOptions {
|
||||
challenge: string
|
||||
rpId: string
|
||||
allowCredentials?: Array<{
|
||||
type: 'public-key'
|
||||
id: string
|
||||
}>
|
||||
userVerification?: 'required' | 'preferred' | 'discouraged'
|
||||
timeout?: number
|
||||
}
|
||||
|
||||
export interface WebAuthnCreationOptions {
|
||||
creation_options: {
|
||||
publicKey: WebAuthnPublicKeyCredentialCreationOptions
|
||||
}
|
||||
}
|
||||
|
||||
export interface WebAuthnRequestOptions {
|
||||
request_options: {
|
||||
publicKey: WebAuthnPublicKeyCredentialRequestOptions
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TOTP Types
|
||||
// =============================================================================
|
||||
|
||||
export interface TOTPStatus {
|
||||
type: 'totp'
|
||||
created_at: Timestamp
|
||||
last_used_at: Timestamp | null
|
||||
/** Base32-encoded secret (only present when not yet activated) */
|
||||
secret?: string
|
||||
/** TOTP URI for QR code generation */
|
||||
totp_url?: string
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Response Meta Types
|
||||
// =============================================================================
|
||||
|
||||
export interface BaseAuthenticationMeta {
|
||||
/** Session token (app clients only) */
|
||||
session_token?: string
|
||||
/** Access token (app clients only) */
|
||||
access_token?: string
|
||||
}
|
||||
|
||||
export interface AuthenticationMeta extends BaseAuthenticationMeta {
|
||||
is_authenticated: boolean
|
||||
}
|
||||
|
||||
export interface AuthenticatedMeta extends BaseAuthenticationMeta {
|
||||
is_authenticated: true
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Response Types
|
||||
// =============================================================================
|
||||
|
||||
export interface AuthError {
|
||||
code: string
|
||||
message: string
|
||||
param?: string
|
||||
}
|
||||
|
||||
/** Base response structure - uses `any` for data/meta to maintain flexibility in generic use */
|
||||
export interface AllauthResponse<TData = any, TMeta = any> {
|
||||
status: number
|
||||
data?: TData
|
||||
meta?: TMeta
|
||||
errors?: AuthError[]
|
||||
}
|
||||
|
||||
/** 200 OK - Authenticated */
|
||||
export interface AuthenticatedResponse extends AllauthResponse<Authenticated, AuthenticationMeta> {
|
||||
status: 200
|
||||
data: Authenticated
|
||||
meta: AuthenticationMeta
|
||||
}
|
||||
|
||||
/** 200 OK - Configuration */
|
||||
export interface ConfigurationResponse extends AllauthResponse<AllauthConfiguration> {
|
||||
status: 200
|
||||
data: AllauthConfiguration
|
||||
}
|
||||
|
||||
/** 200 OK - Email list */
|
||||
export interface EmailListResponse extends AllauthResponse<EmailAddress[]> {
|
||||
status: 200
|
||||
data: EmailAddress[]
|
||||
}
|
||||
|
||||
/** 200 OK - Session list */
|
||||
export interface SessionListResponse extends AllauthResponse<Session[]> {
|
||||
status: 200
|
||||
data: Session[]
|
||||
}
|
||||
|
||||
/** 200 OK - Authenticator list */
|
||||
export interface AuthenticatorListResponse extends AllauthResponse<Authenticator[]> {
|
||||
status: 200
|
||||
data: Authenticator[]
|
||||
}
|
||||
|
||||
/** 200 OK - Provider account list */
|
||||
export interface ProviderAccountListResponse extends AllauthResponse<ProviderAccount[]> {
|
||||
status: 200
|
||||
data: ProviderAccount[]
|
||||
}
|
||||
|
||||
/** 200 OK - TOTP status */
|
||||
export interface TOTPStatusResponse extends AllauthResponse<TOTPStatus> {
|
||||
status: 200
|
||||
data: TOTPStatus
|
||||
}
|
||||
|
||||
/** 200 OK - Recovery codes */
|
||||
export interface RecoveryCodesResponse extends AllauthResponse<SensitiveRecoveryCodesAuthenticator> {
|
||||
status: 200
|
||||
data: SensitiveRecoveryCodesAuthenticator
|
||||
}
|
||||
|
||||
/** 200 OK - WebAuthn creation options */
|
||||
export interface WebAuthnCreationOptionsResponse extends AllauthResponse<WebAuthnCreationOptions> {
|
||||
status: 200
|
||||
data: WebAuthnCreationOptions
|
||||
}
|
||||
|
||||
/** 200 OK - WebAuthn request options */
|
||||
export interface WebAuthnRequestOptionsResponse extends AllauthResponse<WebAuthnRequestOptions> {
|
||||
status: 200
|
||||
data: WebAuthnRequestOptions
|
||||
}
|
||||
|
||||
/** 200 OK - Email verification info */
|
||||
export interface EmailVerificationInfoResponse extends AllauthResponse<{ email: Email; user: User }> {
|
||||
status: 200
|
||||
data: { email: Email; user: User }
|
||||
}
|
||||
|
||||
/** 401 - Authentication required (not authenticated) */
|
||||
export interface AuthenticationRequiredResponse extends AllauthResponse<{ flows: Flow[] }, AuthenticationMeta> {
|
||||
status: 401
|
||||
data: { flows: Flow[] }
|
||||
meta: AuthenticationMeta & { is_authenticated: false }
|
||||
}
|
||||
|
||||
/** 401 - Reauthentication required (authenticated but needs reauthentication) */
|
||||
export interface ReauthenticationRequiredResponse extends AllauthResponse<ReauthenticationRequired, AuthenticatedMeta> {
|
||||
status: 401
|
||||
data: ReauthenticationRequired
|
||||
meta: AuthenticatedMeta
|
||||
}
|
||||
|
||||
/** 400 - Bad request / validation error */
|
||||
export interface ErrorResponse extends AllauthResponse<never> {
|
||||
status: 400
|
||||
errors: AuthError[]
|
||||
}
|
||||
|
||||
/** 403 - Forbidden */
|
||||
export interface ForbiddenResponse extends AllauthResponse<never> {
|
||||
status: 403
|
||||
}
|
||||
|
||||
/** 409 - Conflict */
|
||||
export interface ConflictResponse extends AllauthResponse<never> {
|
||||
status: 409
|
||||
}
|
||||
|
||||
/** 410 - Session gone/expired */
|
||||
export interface SessionGoneResponse extends AllauthResponse<Record<string, never>, AuthenticationMeta> {
|
||||
status: 410
|
||||
data: Record<string, never>
|
||||
meta: AuthenticationMeta
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Request Types
|
||||
// =============================================================================
|
||||
|
||||
export interface LoginRequest {
|
||||
email?: Email
|
||||
username?: Username
|
||||
phone?: Phone
|
||||
password: Password
|
||||
}
|
||||
|
||||
export interface SignupRequest {
|
||||
email: Email
|
||||
password: Password
|
||||
[key: string]: unknown // Additional custom signup fields
|
||||
}
|
||||
|
||||
export interface ProviderSignupRequest {
|
||||
email: Email
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export interface ReauthenticateRequest {
|
||||
password: Password
|
||||
}
|
||||
|
||||
export interface RequestLoginCodeRequest {
|
||||
email?: Email
|
||||
phone?: Phone
|
||||
}
|
||||
|
||||
export interface ConfirmLoginCodeRequest {
|
||||
code: Code
|
||||
}
|
||||
|
||||
export interface MFAAuthenticateRequest {
|
||||
code: AuthenticatorCode
|
||||
}
|
||||
|
||||
export interface MFATrustRequest {
|
||||
trust: boolean
|
||||
}
|
||||
|
||||
export interface RequestPasswordResetRequest {
|
||||
email: Email
|
||||
}
|
||||
|
||||
export interface ResetPasswordRequest {
|
||||
key: string
|
||||
password: Password
|
||||
}
|
||||
|
||||
export interface VerifyEmailRequest {
|
||||
key: string
|
||||
}
|
||||
|
||||
export interface ChangePasswordRequest {
|
||||
current_password?: Password
|
||||
new_password: Password
|
||||
}
|
||||
|
||||
export interface AddEmailRequest {
|
||||
email: Email
|
||||
}
|
||||
|
||||
export interface ProviderRedirectRequest {
|
||||
provider: ProviderID
|
||||
process: OAuthProcess
|
||||
callback_url: string
|
||||
}
|
||||
|
||||
export interface ProviderTokenRequest {
|
||||
provider: ProviderID
|
||||
process: OAuthProcess
|
||||
token: {
|
||||
client_id: ClientID
|
||||
id_token?: string
|
||||
access_token?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface WebAuthnAddRequest {
|
||||
name: string
|
||||
credential: unknown // WebAuthn RegistrationResponseJSON
|
||||
}
|
||||
|
||||
export interface WebAuthnAuthenticateRequest {
|
||||
credential: unknown // WebAuthn AuthenticationResponseJSON
|
||||
}
|
||||
|
||||
export interface WebAuthnUpdateRequest {
|
||||
id: AuthenticatorID
|
||||
name?: string
|
||||
}
|
||||
|
||||
export interface WebAuthnDeleteRequest {
|
||||
authenticators: AuthenticatorID[]
|
||||
}
|
||||
|
||||
export interface EndSessionsRequest {
|
||||
sessions: number[]
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Union Types for Responses
|
||||
// =============================================================================
|
||||
|
||||
/** Possible responses from authentication endpoints */
|
||||
export type AuthResponse =
|
||||
| AuthenticatedResponse
|
||||
| AuthenticationRequiredResponse
|
||||
| ReauthenticationRequiredResponse
|
||||
| ErrorResponse
|
||||
|
||||
/** Possible responses from session status endpoint */
|
||||
export type SessionStatusResponse =
|
||||
| AuthenticatedResponse
|
||||
| AuthenticationRequiredResponse
|
||||
| SessionGoneResponse
|
||||
@@ -0,0 +1,165 @@
|
||||
/**
|
||||
* Tests for ChannelConnection
|
||||
*
|
||||
* These tests verify the ChannelConnection class API.
|
||||
* Unit tests for class structure don't require a real backend.
|
||||
* Integration tests for actual WebSocket connections require the backend.
|
||||
*
|
||||
* Backend must be running for integration tests: docker-compose up
|
||||
*/
|
||||
|
||||
import { ChannelConnection, RPCError } from '../connection'
|
||||
import { describeIntegration, WS_URL } from '../../testing'
|
||||
|
||||
describe('ChannelConnection (unit tests)', () => {
|
||||
describe('construction', () => {
|
||||
it('should start in disconnected state', () => {
|
||||
const connection = new ChannelConnection({ url: 'ws://localhost/ws/' })
|
||||
expect(connection.status).toBe('disconnected')
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
describe('status change handlers', () => {
|
||||
it('should allow subscribing to status changes', () => {
|
||||
const connection = new ChannelConnection({ url: 'ws://localhost/ws/' })
|
||||
const handler = jest.fn()
|
||||
|
||||
const unsubscribe = connection.onStatusChange(handler)
|
||||
|
||||
expect(typeof unsubscribe).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('message handlers', () => {
|
||||
it('should allow subscribing to messages', () => {
|
||||
const connection = new ChannelConnection({ url: 'ws://localhost/ws/' })
|
||||
const handler = jest.fn()
|
||||
|
||||
const unsubscribe = connection.onMessage(handler)
|
||||
|
||||
expect(typeof unsubscribe).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('send queueing', () => {
|
||||
it('should queue messages when not connected', () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: 'ws://localhost/ws/',
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
// This shouldn't throw
|
||||
connection.send({
|
||||
action: 'subscribe',
|
||||
channel: 'test',
|
||||
params: {},
|
||||
})
|
||||
|
||||
// Status should still be disconnected (or connecting if it auto-connected)
|
||||
expect(['disconnected', 'connecting']).toContain(connection.status)
|
||||
})
|
||||
})
|
||||
|
||||
describe('rpc', () => {
|
||||
it('should queue rpc messages when not connected', () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: 'ws://localhost/ws/',
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
const promise = connection.rpc('test_fn', { arg: 'value' })
|
||||
|
||||
expect(promise).toBeInstanceOf(Promise)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describeIntegration('ChannelConnection (integration)', () => {
|
||||
describe('real WebSocket connection', () => {
|
||||
it('should connect to real backend WebSocket', async () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
const statusChanges: string[] = []
|
||||
connection.onStatusChange((status) => {
|
||||
statusChanges.push(status)
|
||||
})
|
||||
|
||||
// Connect
|
||||
connection.connect()
|
||||
|
||||
// Wait for connection
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => {
|
||||
reject(new Error('Connection timeout'))
|
||||
}, 5000)
|
||||
|
||||
const unsubscribe = connection.onStatusChange((status) => {
|
||||
if (status === 'connected') {
|
||||
clearTimeout(timeout)
|
||||
unsubscribe()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
expect(connection.status).toBe('connected')
|
||||
|
||||
// Cleanup
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
it('should disconnect cleanly', async () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
// Connect first
|
||||
connection.connect()
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = connection.onStatusChange((status) => {
|
||||
if (status === 'connected') {
|
||||
unsubscribe()
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// Now disconnect
|
||||
connection.disconnect()
|
||||
|
||||
// Should be disconnected
|
||||
expect(connection.status).toBe('disconnected')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('RPCError', () => {
|
||||
it('should be an Error subclass', () => {
|
||||
const error = new RPCError('TEST_CODE', 'Test message')
|
||||
|
||||
expect(error).toBeInstanceOf(Error)
|
||||
expect(error).toBeInstanceOf(RPCError)
|
||||
})
|
||||
|
||||
it('should have correct properties', () => {
|
||||
const error = new RPCError('VALIDATION_ERROR', 'Field is required', { field: 'email' })
|
||||
|
||||
expect(error.code).toBe('VALIDATION_ERROR')
|
||||
expect(error.message).toBe('Field is required')
|
||||
expect(error.details).toEqual({ field: 'email' })
|
||||
expect(error.name).toBe('RPCError')
|
||||
})
|
||||
|
||||
it('should work without details', () => {
|
||||
const error = new RPCError('NOT_FOUND', 'Function not found')
|
||||
|
||||
expect(error.code).toBe('NOT_FOUND')
|
||||
expect(error.message).toBe('Function not found')
|
||||
expect(error.details).toBeUndefined()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,207 @@
|
||||
/**
|
||||
* Tests for ChannelProvider context
|
||||
*
|
||||
* Unit tests run without backend.
|
||||
* Integration tests require: docker-compose up
|
||||
*
|
||||
* Run integration tests with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { ReactNode } from 'react'
|
||||
import { ChannelProvider, useChannelContext, useChannelStatus } from '../context'
|
||||
import { ChannelConnection } from '../connection'
|
||||
import { describeIntegration, WS_URL } from '../../testing'
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests (no backend required)
|
||||
// ============================================================================
|
||||
|
||||
describe('ChannelProvider (unit)', () => {
|
||||
describe('useChannelContext', () => {
|
||||
it('should throw when used outside ChannelProvider', () => {
|
||||
const consoleSpy = jest.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
expect(() => {
|
||||
renderHook(() => useChannelContext())
|
||||
}).toThrow('useChannelContext must be used within a ChannelProvider')
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should return connection and status when inside provider', () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: 'ws://localhost/ws/',
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<ChannelProvider connection={connection} autoConnect={false}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useChannelContext(), { wrapper })
|
||||
|
||||
expect(result.current.connection).toBe(connection)
|
||||
expect(result.current.status).toBe('disconnected')
|
||||
|
||||
connection.disconnect()
|
||||
})
|
||||
})
|
||||
|
||||
describe('useChannelStatus', () => {
|
||||
it('should return disconnected when autoConnect is false', () => {
|
||||
const connection = new ChannelConnection({
|
||||
url: 'ws://localhost/ws/',
|
||||
reconnect: false,
|
||||
})
|
||||
|
||||
const wrapper = ({ children }: { children: ReactNode }) => (
|
||||
<ChannelProvider connection={connection} autoConnect={false}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useChannelStatus(), { wrapper })
|
||||
expect(result.current).toBe('disconnected')
|
||||
|
||||
connection.disconnect()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests (require running backend)
|
||||
// ============================================================================
|
||||
|
||||
describeIntegration('ChannelProvider (integration)', () => {
|
||||
describe('with real WebSocket connection', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(() => {
|
||||
connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
const createWrapper = (autoConnect = true) => {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ChannelProvider
|
||||
connection={connection}
|
||||
autoConnect={autoConnect}
|
||||
>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
it('should auto-connect when autoConnect is true', async () => {
|
||||
const { result } = renderHook(() => useChannelContext(), {
|
||||
wrapper: createWrapper(true),
|
||||
})
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(result.current.status).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
})
|
||||
|
||||
it('should not auto-connect when autoConnect is false', () => {
|
||||
const { result } = renderHook(() => useChannelContext(), {
|
||||
wrapper: createWrapper(false),
|
||||
})
|
||||
|
||||
expect(result.current.status).toBe('disconnected')
|
||||
})
|
||||
|
||||
it('should update status when connection status changes', async () => {
|
||||
const { result } = renderHook(() => useChannelStatus(), {
|
||||
wrapper: createWrapper(true),
|
||||
})
|
||||
|
||||
// Should start connecting then become connected
|
||||
await waitFor(() => {
|
||||
expect(result.current).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
})
|
||||
|
||||
it('should disconnect on unmount', async () => {
|
||||
const { result, unmount } = renderHook(() => useChannelContext(), {
|
||||
wrapper: createWrapper(true),
|
||||
})
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(result.current.status).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Unmount
|
||||
unmount()
|
||||
|
||||
// Connection should be disconnected
|
||||
expect(connection.status).toBe('disconnected')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describeIntegration('useChannelStatus (integration)', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(() => {
|
||||
connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
const createWrapper = (autoConnect: boolean) => {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ChannelProvider connection={connection} autoConnect={autoConnect}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
it('should return current connection status', () => {
|
||||
const { result } = renderHook(() => useChannelStatus(), {
|
||||
wrapper: createWrapper(false),
|
||||
})
|
||||
|
||||
expect(result.current).toBe('disconnected')
|
||||
})
|
||||
|
||||
it('should track status through connection lifecycle', async () => {
|
||||
const { result } = renderHook(() => useChannelStatus(), {
|
||||
wrapper: createWrapper(true),
|
||||
})
|
||||
|
||||
// Wait for connected
|
||||
await waitFor(() => {
|
||||
expect(result.current).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Disconnect manually
|
||||
act(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
// Should become disconnected
|
||||
await waitFor(() => {
|
||||
expect(result.current).toBe('disconnected')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,158 @@
|
||||
/**
|
||||
* Integration tests for channel hooks
|
||||
*
|
||||
* These tests call the REAL backend - no mocks.
|
||||
* Backend must be running: docker-compose up
|
||||
*
|
||||
* Run with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*/
|
||||
|
||||
import { renderHook, waitFor } from '@testing-library/react'
|
||||
import { ReactNode } from 'react'
|
||||
import { ChannelProvider } from '../context'
|
||||
import { useChannel, useChannelLatest, useRPC } from '../hooks'
|
||||
import { ChannelConnection } from '../connection'
|
||||
import { describeIntegration, WS_URL } from '../../testing'
|
||||
|
||||
describeIntegration('useChannel (integration)', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(() => {
|
||||
connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ChannelProvider connection={connection} autoConnect={true}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
describe('subscription', () => {
|
||||
it('should subscribe to channel when connection is ready', async () => {
|
||||
const { result } = renderHook(
|
||||
() => useChannel<{ room: string }, { text: string }, { text: string }>('chat', { room: 'test' }),
|
||||
{ wrapper: createWrapper() }
|
||||
)
|
||||
|
||||
// Wait for connection to establish
|
||||
await waitFor(() => {
|
||||
// Status should progress from connecting
|
||||
expect(['connecting', 'connected', 'subscribed']).toContain(result.current.status)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Should have expected API
|
||||
expect(typeof result.current.send).toBe('function')
|
||||
expect(typeof result.current.clearMessages).toBe('function')
|
||||
expect(typeof result.current.unsubscribe).toBe('function')
|
||||
expect(Array.isArray(result.current.messages)).toBe(true)
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
describeIntegration('useChannelLatest (integration)', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(() => {
|
||||
connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ChannelProvider connection={connection} autoConnect={true}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
describeIntegration('useRPC (integration)', () => {
|
||||
let connection: ChannelConnection
|
||||
|
||||
beforeEach(() => {
|
||||
connection = new ChannelConnection({
|
||||
url: WS_URL,
|
||||
reconnect: false,
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
connection.disconnect()
|
||||
})
|
||||
|
||||
const createWrapper = () => {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<ChannelProvider connection={connection} autoConnect={true}>
|
||||
{children}
|
||||
</ChannelProvider>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
it('should track connection status', async () => {
|
||||
const { result } = renderHook(() => useRPC(), { wrapper: createWrapper() })
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(result.current.status).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
})
|
||||
|
||||
it('should call backend echo function via RPC', async () => {
|
||||
const { result } = renderHook(() => useRPC(), { wrapper: createWrapper() })
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(result.current.status).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Call echo function
|
||||
const response = await result.current.call<{ text: string }, { message: string }>(
|
||||
'echo',
|
||||
{ text: 'rpc test' }
|
||||
)
|
||||
|
||||
expect(response).toHaveProperty('message')
|
||||
expect(response.message).toContain('rpc test')
|
||||
})
|
||||
|
||||
it('should call backend add function via RPC', async () => {
|
||||
const { result } = renderHook(() => useRPC(), { wrapper: createWrapper() })
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(result.current.status).toBe('connected')
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Call add function
|
||||
const response = await result.current.call<{ a: number; b: number }, { result: number }>(
|
||||
'add',
|
||||
{ a: 7, b: 8 }
|
||||
)
|
||||
|
||||
expect(response).toHaveProperty('result', 15)
|
||||
})
|
||||
})
|
||||
|
||||
299
packages/mizan-csr/adapters/react/src/channels/connection.ts
Normal file
299
packages/mizan-csr/adapters/react/src/channels/connection.ts
Normal file
@@ -0,0 +1,299 @@
|
||||
/**
|
||||
* WebSocket connection manager for mizan/channels
|
||||
*
|
||||
* Supports both pub/sub channels AND RPC calls over the same connection.
|
||||
*/
|
||||
|
||||
import type {
|
||||
ConnectionStatus,
|
||||
OutgoingMessage,
|
||||
IncomingPayload,
|
||||
SubscribeOptions,
|
||||
} from './types'
|
||||
|
||||
type MessageHandler = (payload: IncomingPayload) => void
|
||||
type StatusHandler = (status: ConnectionStatus) => void
|
||||
|
||||
/** RPC request message */
|
||||
export interface RPCRequest<T = unknown> {
|
||||
action: 'rpc'
|
||||
id: string
|
||||
fn: string
|
||||
args: T
|
||||
}
|
||||
|
||||
/** RPC response - success */
|
||||
export interface RPCSuccessResponse<T = unknown> {
|
||||
id: string
|
||||
ok: true
|
||||
data: T
|
||||
}
|
||||
|
||||
/** RPC response - error */
|
||||
export interface RPCErrorResponse {
|
||||
id: string
|
||||
ok: false
|
||||
error: {
|
||||
code: string
|
||||
message: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
}
|
||||
|
||||
export type RPCResponse<T = unknown> = RPCSuccessResponse<T> | RPCErrorResponse
|
||||
|
||||
/** RPC error thrown on failure */
|
||||
export class RPCError extends Error {
|
||||
code: string
|
||||
details?: Record<string, unknown>
|
||||
|
||||
constructor(code: string, message: string, details?: Record<string, unknown>) {
|
||||
super(message)
|
||||
this.name = 'RPCError'
|
||||
this.code = code
|
||||
this.details = details
|
||||
}
|
||||
}
|
||||
|
||||
export interface ChannelConnectionOptions {
|
||||
/** WebSocket URL (default: /ws/) */
|
||||
url?: string
|
||||
|
||||
/** Reconnect on disconnect (default: true) */
|
||||
reconnect?: boolean
|
||||
|
||||
/** Reconnection delay in ms (default: 1000) */
|
||||
reconnectDelay?: number
|
||||
|
||||
/** Maximum reconnection attempts (default: 10) */
|
||||
maxReconnectAttempts?: number
|
||||
}
|
||||
|
||||
export class ChannelConnection {
|
||||
private ws: WebSocket | null = null
|
||||
private url: string
|
||||
private reconnect: boolean
|
||||
private reconnectDelay: number
|
||||
private maxReconnectAttempts: number
|
||||
private reconnectAttempts = 0
|
||||
private reconnectTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
private messageHandlers: Set<MessageHandler> = new Set()
|
||||
private statusHandlers: Set<StatusHandler> = new Set()
|
||||
|
||||
private _status: ConnectionStatus = 'disconnected'
|
||||
private pendingMessages: OutgoingMessage[] = []
|
||||
|
||||
// RPC state
|
||||
private rpcIdCounter = 0
|
||||
private pendingRPCs: Map<string, {
|
||||
resolve: (data: unknown) => void
|
||||
reject: (error: RPCError) => void
|
||||
}> = new Map()
|
||||
|
||||
constructor(options: ChannelConnectionOptions = {}) {
|
||||
// Build WebSocket URL
|
||||
const baseUrl = options.url || '/ws/'
|
||||
if (typeof window !== 'undefined') {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
|
||||
this.url = baseUrl.startsWith('ws')
|
||||
? baseUrl
|
||||
: `${protocol}//${window.location.host}${baseUrl}`
|
||||
} else {
|
||||
this.url = baseUrl
|
||||
}
|
||||
|
||||
this.reconnect = options.reconnect ?? true
|
||||
this.reconnectDelay = options.reconnectDelay ?? 1000
|
||||
this.maxReconnectAttempts = options.maxReconnectAttempts ?? 10
|
||||
}
|
||||
|
||||
get status(): ConnectionStatus {
|
||||
return this._status
|
||||
}
|
||||
|
||||
private setStatus(status: ConnectionStatus) {
|
||||
this._status = status
|
||||
this.statusHandlers.forEach(handler => handler(status))
|
||||
}
|
||||
|
||||
connect(): void {
|
||||
if (this.ws?.readyState === WebSocket.OPEN || this.ws?.readyState === WebSocket.CONNECTING) {
|
||||
return
|
||||
}
|
||||
|
||||
this.setStatus('connecting')
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(this.url)
|
||||
|
||||
this.ws.onopen = () => {
|
||||
this.reconnectAttempts = 0
|
||||
this.setStatus('connected')
|
||||
|
||||
// Send any pending messages
|
||||
this.pendingMessages.forEach(msg => this.send(msg))
|
||||
this.pendingMessages = []
|
||||
}
|
||||
|
||||
this.ws.onclose = (event) => {
|
||||
this.setStatus('disconnected')
|
||||
|
||||
// Attempt reconnection if enabled
|
||||
if (this.reconnect && this.reconnectAttempts < this.maxReconnectAttempts) {
|
||||
this.scheduleReconnect()
|
||||
}
|
||||
}
|
||||
|
||||
this.ws.onerror = () => {
|
||||
// WebSocket errors don't provide useful details (browser security)
|
||||
// The onclose handler will fire next and trigger reconnection
|
||||
console.warn('[ChannelConnection] WebSocket error (will reconnect)')
|
||||
}
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
try {
|
||||
const payload = JSON.parse(event.data)
|
||||
|
||||
// Check if this is an RPC response (has 'id' and 'ok' fields)
|
||||
if ('id' in payload && 'ok' in payload) {
|
||||
this.handleRPCResponse(payload as RPCResponse)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, it's a channel message
|
||||
this.messageHandlers.forEach(handler => handler(payload as IncomingPayload))
|
||||
} catch (e) {
|
||||
console.error('[ChannelConnection] Failed to parse message:', e)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[ChannelConnection] Failed to connect:', error)
|
||||
this.setStatus('disconnected')
|
||||
|
||||
if (this.reconnect) {
|
||||
this.scheduleReconnect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
disconnect(): void {
|
||||
this.reconnect = false
|
||||
this.clearReconnectTimer()
|
||||
|
||||
if (this.ws) {
|
||||
this.ws.close()
|
||||
this.ws = null
|
||||
}
|
||||
|
||||
this.setStatus('disconnected')
|
||||
}
|
||||
|
||||
send(message: OutgoingMessage): void {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
this.ws.send(JSON.stringify(message))
|
||||
} else {
|
||||
// Queue message to send when connected
|
||||
this.pendingMessages.push(message)
|
||||
|
||||
// Ensure we're trying to connect
|
||||
if (this._status === 'disconnected') {
|
||||
this.connect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onMessage(handler: MessageHandler): () => void {
|
||||
this.messageHandlers.add(handler)
|
||||
return () => this.messageHandlers.delete(handler)
|
||||
}
|
||||
|
||||
onStatusChange(handler: StatusHandler): () => void {
|
||||
this.statusHandlers.add(handler)
|
||||
return () => this.statusHandlers.delete(handler)
|
||||
}
|
||||
|
||||
private scheduleReconnect(): void {
|
||||
this.clearReconnectTimer()
|
||||
|
||||
const delay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts)
|
||||
this.reconnectAttempts++
|
||||
|
||||
console.log(`[ChannelConnection] Reconnecting in ${delay}ms (attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts})`)
|
||||
|
||||
this.reconnectTimer = setTimeout(() => {
|
||||
this.connect()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
private clearReconnectTimer(): void {
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer)
|
||||
this.reconnectTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// RPC Methods
|
||||
// =========================================================================
|
||||
|
||||
/**
|
||||
* Call a server function via RPC.
|
||||
*
|
||||
* @param fn - Function name (as registered on backend)
|
||||
* @param args - Function arguments
|
||||
* @returns Promise resolving to function output
|
||||
* @throws RPCError on failure
|
||||
*/
|
||||
rpc<TInput, TOutput>(fn: string, args: TInput): Promise<TOutput> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const id = `rpc_${++this.rpcIdCounter}_${Date.now()}`
|
||||
|
||||
// Store pending RPC
|
||||
this.pendingRPCs.set(id, {
|
||||
resolve: resolve as (data: unknown) => void,
|
||||
reject,
|
||||
})
|
||||
|
||||
// Send RPC request
|
||||
const request: RPCRequest<TInput> = {
|
||||
action: 'rpc',
|
||||
id,
|
||||
fn,
|
||||
args,
|
||||
}
|
||||
|
||||
this.send(request as unknown as OutgoingMessage)
|
||||
})
|
||||
}
|
||||
|
||||
private handleRPCResponse(response: RPCResponse): void {
|
||||
const pending = this.pendingRPCs.get(response.id)
|
||||
if (!pending) {
|
||||
console.warn(`[ChannelConnection] Received RPC response for unknown id: ${response.id}`)
|
||||
return
|
||||
}
|
||||
|
||||
this.pendingRPCs.delete(response.id)
|
||||
|
||||
if (response.ok) {
|
||||
pending.resolve(response.data)
|
||||
} else {
|
||||
pending.reject(new RPCError(
|
||||
response.error.code,
|
||||
response.error.message,
|
||||
response.error.details,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Singleton connection instance
|
||||
let defaultConnection: ChannelConnection | null = null
|
||||
|
||||
export function getDefaultConnection(options?: ChannelConnectionOptions): ChannelConnection {
|
||||
if (!defaultConnection) {
|
||||
defaultConnection = new ChannelConnection(options)
|
||||
}
|
||||
return defaultConnection
|
||||
}
|
||||
102
packages/mizan-csr/adapters/react/src/channels/context.tsx
Normal file
102
packages/mizan-csr/adapters/react/src/channels/context.tsx
Normal file
@@ -0,0 +1,102 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* React context for mizan/channels
|
||||
*/
|
||||
|
||||
import { createContext, useContext, useEffect, useMemo, useRef, useState, type ReactNode } from 'react'
|
||||
import { ChannelConnection, type ChannelConnectionOptions } from './connection'
|
||||
import type { ConnectionStatus } from './types'
|
||||
|
||||
interface ChannelContextValue {
|
||||
connection: ChannelConnection
|
||||
status: ConnectionStatus
|
||||
}
|
||||
|
||||
const ChannelContext = createContext<ChannelContextValue | null>(null)
|
||||
|
||||
export interface ChannelProviderProps {
|
||||
children: ReactNode
|
||||
|
||||
/** WebSocket URL (default: /ws/) */
|
||||
url?: string
|
||||
|
||||
/** Reconnect on disconnect (default: true) */
|
||||
reconnect?: boolean
|
||||
|
||||
/** Reconnection delay in ms (default: 1000) */
|
||||
reconnectDelay?: number
|
||||
|
||||
/** Maximum reconnection attempts (default: 10) */
|
||||
maxReconnectAttempts?: number
|
||||
|
||||
/** Connect automatically on mount (default: true) */
|
||||
autoConnect?: boolean
|
||||
|
||||
/** Custom connection instance (for testing) */
|
||||
connection?: ChannelConnection
|
||||
}
|
||||
|
||||
export function ChannelProvider({
|
||||
children,
|
||||
url,
|
||||
reconnect,
|
||||
reconnectDelay,
|
||||
maxReconnectAttempts,
|
||||
autoConnect = true,
|
||||
connection: providedConnection,
|
||||
}: ChannelProviderProps) {
|
||||
const connectionRef = useRef<ChannelConnection | null>(null)
|
||||
|
||||
// Use provided connection or create one
|
||||
if (!connectionRef.current) {
|
||||
connectionRef.current = providedConnection ?? new ChannelConnection({
|
||||
url,
|
||||
reconnect,
|
||||
reconnectDelay,
|
||||
maxReconnectAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
const connection = connectionRef.current
|
||||
|
||||
// Track status for context value
|
||||
const [status, setStatus] = useState<ConnectionStatus>(connection.status)
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribe = connection.onStatusChange(setStatus)
|
||||
|
||||
if (autoConnect) {
|
||||
connection.connect()
|
||||
}
|
||||
|
||||
return () => {
|
||||
unsubscribe()
|
||||
connection.disconnect()
|
||||
}
|
||||
}, [connection, autoConnect])
|
||||
|
||||
const value = useMemo(() => ({
|
||||
connection,
|
||||
status,
|
||||
}), [connection, status])
|
||||
|
||||
return (
|
||||
<ChannelContext value={value}>
|
||||
{children}
|
||||
</ChannelContext>
|
||||
)
|
||||
}
|
||||
|
||||
export function useChannelContext(): ChannelContextValue {
|
||||
const context = useContext(ChannelContext)
|
||||
if (!context) {
|
||||
throw new Error('useChannelContext must be used within a ChannelProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
export function useChannelStatus(): ConnectionStatus {
|
||||
const { status } = useChannelContext()
|
||||
return status
|
||||
}
|
||||
256
packages/mizan-csr/adapters/react/src/channels/hooks.ts
Normal file
256
packages/mizan-csr/adapters/react/src/channels/hooks.ts
Normal file
@@ -0,0 +1,256 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* React hooks for mizan/channels
|
||||
*
|
||||
* Includes pub/sub channel hooks AND RPC hooks.
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useChannelContext } from './context'
|
||||
import { RPCError } from './connection'
|
||||
import type {
|
||||
ConnectionStatus,
|
||||
ChannelSubscription,
|
||||
IncomingPayload,
|
||||
} from './types'
|
||||
|
||||
export interface UseChannelOptions<TDjangoMessage> {
|
||||
/** Called when subscribed successfully */
|
||||
onSubscribed?: () => void
|
||||
|
||||
/** Called when a message is received */
|
||||
onMessage?: (message: TDjangoMessage) => void
|
||||
|
||||
/** Called on error */
|
||||
onError?: (error: string) => void
|
||||
|
||||
/** Called when unsubscribed */
|
||||
onUnsubscribed?: () => void
|
||||
|
||||
/** Maximum messages to keep in history (default: 100) */
|
||||
maxMessages?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to a channel and receive typed messages.
|
||||
*
|
||||
* @param channelName - The registered channel name
|
||||
* @param params - Channel parameters (if required)
|
||||
* @param options - Subscription options
|
||||
*/
|
||||
export function useChannel<
|
||||
TParams = undefined,
|
||||
TDjangoMessage = unknown,
|
||||
TReactMessage = unknown,
|
||||
>(
|
||||
channelName: string,
|
||||
params?: TParams,
|
||||
options: UseChannelOptions<TDjangoMessage> = {},
|
||||
): ChannelSubscription<TParams, TDjangoMessage, TReactMessage> {
|
||||
const { connection, status: connectionStatus } = useChannelContext()
|
||||
|
||||
const [messages, setMessages] = useState<TDjangoMessage[]>([])
|
||||
const [subscribed, setSubscribed] = useState(false)
|
||||
|
||||
const optionsRef = useRef(options)
|
||||
optionsRef.current = options
|
||||
|
||||
const maxMessages = options.maxMessages ?? 100
|
||||
|
||||
// Stable params reference for effect dependencies
|
||||
const paramsJson = JSON.stringify(params ?? {})
|
||||
const paramsRef = useRef(params)
|
||||
paramsRef.current = params
|
||||
|
||||
// Subscribe on mount / when params change
|
||||
useEffect(() => {
|
||||
if (connectionStatus !== 'connected') {
|
||||
return
|
||||
}
|
||||
|
||||
const currentParams = paramsRef.current ?? {}
|
||||
|
||||
// Subscribe
|
||||
connection.send({
|
||||
action: 'subscribe',
|
||||
channel: channelName,
|
||||
params: currentParams as Record<string, unknown>,
|
||||
})
|
||||
|
||||
// Handle incoming messages
|
||||
const unsubscribeMessages = connection.onMessage((payload: IncomingPayload) => {
|
||||
// Check for subscription confirmation
|
||||
if ('subscribed' in payload && payload.channel === channelName) {
|
||||
setSubscribed(true)
|
||||
optionsRef.current.onSubscribed?.()
|
||||
return
|
||||
}
|
||||
|
||||
// Check for unsubscription confirmation
|
||||
if ('unsubscribed' in payload && payload.channel === channelName) {
|
||||
setSubscribed(false)
|
||||
optionsRef.current.onUnsubscribed?.()
|
||||
return
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
if ('error' in payload) {
|
||||
if (!payload.channel || payload.channel === channelName) {
|
||||
optionsRef.current.onError?.(payload.error)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle data messages
|
||||
if ('type' in payload && 'data' in payload) {
|
||||
const message = payload.data as TDjangoMessage
|
||||
setMessages(prev => {
|
||||
const next = [...prev, message]
|
||||
// Trim to max messages
|
||||
if (next.length > maxMessages) {
|
||||
return next.slice(-maxMessages)
|
||||
}
|
||||
return next
|
||||
})
|
||||
optionsRef.current.onMessage?.(message)
|
||||
}
|
||||
})
|
||||
|
||||
// Cleanup: unsubscribe
|
||||
return () => {
|
||||
unsubscribeMessages()
|
||||
|
||||
connection.send({
|
||||
action: 'unsubscribe',
|
||||
channel: channelName,
|
||||
params: currentParams as Record<string, unknown>,
|
||||
})
|
||||
}
|
||||
}, [connection, connectionStatus, channelName, paramsJson, maxMessages])
|
||||
|
||||
// Send function
|
||||
const send = useCallback((message: TReactMessage) => {
|
||||
if (!subscribed) {
|
||||
console.warn(`[useChannel] Cannot send: not subscribed to ${channelName}`)
|
||||
return
|
||||
}
|
||||
|
||||
connection.send({
|
||||
action: 'message',
|
||||
channel: channelName,
|
||||
params: (paramsRef.current ?? {}) as Record<string, unknown>,
|
||||
data: message,
|
||||
})
|
||||
}, [connection, channelName, subscribed])
|
||||
|
||||
// Unsubscribe function
|
||||
const unsubscribe = useCallback(() => {
|
||||
connection.send({
|
||||
action: 'unsubscribe',
|
||||
channel: channelName,
|
||||
params: (paramsRef.current ?? {}) as Record<string, unknown>,
|
||||
})
|
||||
}, [connection, channelName])
|
||||
|
||||
// Clear messages
|
||||
const clearMessages = useCallback(() => {
|
||||
setMessages([])
|
||||
}, [])
|
||||
|
||||
// Derive status
|
||||
const status: ConnectionStatus = !subscribed
|
||||
? connectionStatus === 'connected' ? 'connecting' : connectionStatus
|
||||
: 'connected'
|
||||
|
||||
return {
|
||||
status,
|
||||
messages,
|
||||
send: send as ChannelSubscription<TParams, TDjangoMessage, TReactMessage>['send'],
|
||||
unsubscribe,
|
||||
clearMessages,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only the latest message from a channel (useful for presence, typing indicators)
|
||||
*/
|
||||
export function useChannelLatest<
|
||||
TParams = undefined,
|
||||
TDjangoMessage = unknown,
|
||||
TReactMessage = unknown,
|
||||
>(
|
||||
channelName: string,
|
||||
params?: TParams,
|
||||
options: UseChannelOptions<TDjangoMessage> = {},
|
||||
): Omit<ChannelSubscription<TParams, TDjangoMessage, TReactMessage>, 'messages'> & { latest: TDjangoMessage | null } {
|
||||
const [latest, setLatest] = useState<TDjangoMessage | null>(null)
|
||||
|
||||
const channel = useChannel<TParams, TDjangoMessage, TReactMessage>(
|
||||
channelName,
|
||||
params,
|
||||
{
|
||||
...options,
|
||||
onMessage: (msg) => {
|
||||
setLatest(msg)
|
||||
options.onMessage?.(msg)
|
||||
},
|
||||
maxMessages: 1,
|
||||
},
|
||||
)
|
||||
|
||||
// Explicitly exclude messages to match the documented API
|
||||
const { messages: _, ...rest } = channel
|
||||
|
||||
return {
|
||||
...rest,
|
||||
latest,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// =============================================================================
|
||||
// RPC Hooks
|
||||
// =============================================================================
|
||||
|
||||
export interface RPCClient {
|
||||
/**
|
||||
* Call a server function.
|
||||
*
|
||||
* @param fn - Function name
|
||||
* @param args - Function arguments
|
||||
* @returns Promise resolving to function output
|
||||
* @throws RPCError on failure
|
||||
*/
|
||||
call<TInput, TOutput>(fn: string, args: TInput): Promise<TOutput>
|
||||
|
||||
/** Connection status */
|
||||
status: ConnectionStatus
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an RPC client for calling server functions.
|
||||
*
|
||||
* Usage:
|
||||
* const rpc = useRPC()
|
||||
* const result = await rpc.call('update_profile', { name: 'New Name' })
|
||||
*
|
||||
* The generated code wraps this with typed functions:
|
||||
* const { updateProfile } = useDjango()
|
||||
* const result = await updateProfile({ name: 'New Name' })
|
||||
*/
|
||||
export function useRPC(): RPCClient {
|
||||
const { connection, status } = useChannelContext()
|
||||
|
||||
const call = useCallback(<TInput, TOutput>(fn: string, args: TInput): Promise<TOutput> => {
|
||||
return connection.rpc<TInput, TOutput>(fn, args)
|
||||
}, [connection])
|
||||
|
||||
return {
|
||||
call,
|
||||
status,
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export RPCError for convenience
|
||||
export { RPCError }
|
||||
76
packages/mizan-csr/adapters/react/src/channels/index.ts
Normal file
76
packages/mizan-csr/adapters/react/src/channels/index.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* mizan/channels
|
||||
*
|
||||
* Real-time WebSocket communication with Django Channels.
|
||||
* Type-safe bidirectional messaging.
|
||||
*
|
||||
* ## Setup
|
||||
*
|
||||
* ```tsx
|
||||
* // layout.tsx
|
||||
* import { ChannelProvider } from 'mizan/channels'
|
||||
*
|
||||
* export default function Layout({ children }) {
|
||||
* return (
|
||||
* <ChannelProvider>
|
||||
* {children}
|
||||
* </ChannelProvider>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* ## Usage
|
||||
*
|
||||
* ```tsx
|
||||
* // Using generated hooks (recommended)
|
||||
* import { useChatChannel } from '@/api/generated.channels'
|
||||
*
|
||||
* function Chat({ room }) {
|
||||
* const chat = useChatChannel({ room })
|
||||
*
|
||||
* chat.status // 'connecting' | 'connected' | 'disconnected'
|
||||
* chat.messages // DjangoMessage[]
|
||||
* chat.send({ text: 'Hello' }) // Send ReactMessage
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* ```tsx
|
||||
* // Using raw hook (for custom channels)
|
||||
* import { useChannel } from 'mizan/channels'
|
||||
*
|
||||
* function CustomChannel() {
|
||||
* const channel = useChannel<
|
||||
* { room: string }, // Params
|
||||
* { user: string; text: string }, // DjangoMessage
|
||||
* { text: string } // ReactMessage
|
||||
* >('chat', { room: 'general' })
|
||||
*
|
||||
* // ...
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
||||
// Context
|
||||
export { ChannelProvider, useChannelContext, useChannelStatus } from './context'
|
||||
export type { ChannelProviderProps } from './context'
|
||||
|
||||
// Hooks
|
||||
export { useChannel, useChannelLatest, useRPC, RPCError } from './hooks'
|
||||
export type { UseChannelOptions, RPCClient } from './hooks'
|
||||
|
||||
// Connection (for advanced use)
|
||||
export { ChannelConnection, getDefaultConnection } from './connection'
|
||||
export type {
|
||||
ChannelConnectionOptions,
|
||||
RPCRequest,
|
||||
RPCResponse,
|
||||
RPCSuccessResponse,
|
||||
RPCErrorResponse,
|
||||
} from './connection'
|
||||
|
||||
// Types
|
||||
export type {
|
||||
ConnectionStatus,
|
||||
ChannelSubscription,
|
||||
SubscribeOptions,
|
||||
} from './types'
|
||||
84
packages/mizan-csr/adapters/react/src/channels/types.ts
Normal file
84
packages/mizan-csr/adapters/react/src/channels/types.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Types for mizan/channels
|
||||
*/
|
||||
|
||||
export type ConnectionStatus = 'connecting' | 'connected' | 'disconnected'
|
||||
|
||||
export interface ChannelSubscription<TParams = unknown, TDjangoMessage = unknown, TReactMessage = unknown> {
|
||||
/** Current connection status */
|
||||
status: ConnectionStatus
|
||||
|
||||
/** Received messages */
|
||||
messages: TDjangoMessage[]
|
||||
|
||||
/** Send a message (if channel accepts ReactMessage) */
|
||||
send: TReactMessage extends never ? never : (message: TReactMessage) => void
|
||||
|
||||
/** Unsubscribe from the channel */
|
||||
unsubscribe: () => void
|
||||
|
||||
/** Clear accumulated messages */
|
||||
clearMessages: () => void
|
||||
}
|
||||
|
||||
export interface SubscribeOptions {
|
||||
/** Called when subscribed successfully */
|
||||
onSubscribed?: () => void
|
||||
|
||||
/** Called when a message is received */
|
||||
onMessage?: (message: unknown) => void
|
||||
|
||||
/** Called on error */
|
||||
onError?: (error: string) => void
|
||||
|
||||
/** Called when unsubscribed */
|
||||
onUnsubscribed?: () => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Protocol messages sent over the WebSocket
|
||||
*/
|
||||
export interface SubscribeAction {
|
||||
action: 'subscribe'
|
||||
channel: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface UnsubscribeAction {
|
||||
action: 'unsubscribe'
|
||||
channel: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface MessageAction {
|
||||
action: 'message'
|
||||
channel: string
|
||||
params: Record<string, unknown>
|
||||
data: unknown
|
||||
}
|
||||
|
||||
export type OutgoingMessage = SubscribeAction | UnsubscribeAction | MessageAction
|
||||
|
||||
export interface IncomingSubscribed {
|
||||
subscribed: true
|
||||
channel: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface IncomingUnsubscribed {
|
||||
unsubscribed: true
|
||||
channel: string
|
||||
params: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface IncomingMessage {
|
||||
type: string
|
||||
data: unknown
|
||||
}
|
||||
|
||||
export interface IncomingError {
|
||||
error: string
|
||||
channel?: string
|
||||
}
|
||||
|
||||
export type IncomingPayload = IncomingSubscribed | IncomingUnsubscribed | IncomingMessage | IncomingError
|
||||
142
packages/mizan-csr/adapters/react/src/client/AuthContext.tsx
Normal file
142
packages/mizan-csr/adapters/react/src/client/AuthContext.tsx
Normal file
@@ -0,0 +1,142 @@
|
||||
'use client'
|
||||
|
||||
import {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useCallback,
|
||||
useMemo,
|
||||
type ReactNode,
|
||||
} from 'react'
|
||||
import { createDjangoCSRClient, Auth } from './index'
|
||||
import type { BaseUser, AuthDetails, AuthRoutes } from './types'
|
||||
|
||||
/**
|
||||
* Auth state provided by AuthContext.
|
||||
*/
|
||||
export interface AuthState<TUser extends BaseUser = BaseUser> {
|
||||
/** Current user (null if not authenticated) */
|
||||
user: TUser | null
|
||||
/** Whether auth state is loading */
|
||||
isLoading: boolean
|
||||
/** Refresh user from server */
|
||||
refresh: () => Promise<TUser | null>
|
||||
}
|
||||
|
||||
const Context = createContext<AuthState | null>(null)
|
||||
|
||||
/**
|
||||
* Default routes configuration.
|
||||
*/
|
||||
export const defaultRoutes: AuthRoutes = {
|
||||
login: '/auth/login',
|
||||
authenticated: '/dashboard',
|
||||
}
|
||||
|
||||
const RoutesContext = createContext<AuthRoutes>(defaultRoutes)
|
||||
|
||||
export interface AuthContextProps<TUser extends BaseUser = BaseUser> {
|
||||
children: ReactNode
|
||||
/** Initial user from SSR hydration (null if not authenticated) */
|
||||
user?: TUser | null
|
||||
/** API endpoint to fetch user data (default: '/api/auth/me/') */
|
||||
userEndpoint?: string
|
||||
/** Route configuration for guards */
|
||||
routes?: Partial<AuthRoutes>
|
||||
}
|
||||
|
||||
/**
|
||||
* Base auth context for Django-React apps.
|
||||
*
|
||||
* Provides user state from a simple /me endpoint.
|
||||
* For allauth integration, use AllauthContext instead.
|
||||
*/
|
||||
// Create client once at module level (session auth, no dynamic config needed)
|
||||
const client = createDjangoCSRClient(Auth.SESSION)
|
||||
|
||||
export function AuthContext<TUser extends BaseUser = BaseUser>({
|
||||
children,
|
||||
user: initialUser = null,
|
||||
userEndpoint = '/api/auth/me/',
|
||||
routes,
|
||||
}: AuthContextProps<TUser>) {
|
||||
const [user, setUser] = useState<TUser | null>(initialUser)
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const refresh = useCallback(async (): Promise<TUser | null> => {
|
||||
setIsLoading(true)
|
||||
try {
|
||||
const resp = await client.request('GET', userEndpoint)
|
||||
if (resp.ok) {
|
||||
const userData = await resp.json()
|
||||
setUser(userData)
|
||||
return userData
|
||||
} else if (resp.status === 401 || resp.status === 403) {
|
||||
setUser(null)
|
||||
return null
|
||||
}
|
||||
throw new Error(`Failed to fetch user: ${resp.status}`)
|
||||
} catch (e) {
|
||||
console.error('[AuthContext] Failed to fetch user:', e)
|
||||
return null
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [userEndpoint])
|
||||
|
||||
const authState = useMemo<AuthState<TUser>>(() => ({
|
||||
user,
|
||||
isLoading,
|
||||
refresh,
|
||||
}), [user, isLoading, refresh])
|
||||
|
||||
const routesValue = useMemo(() => ({
|
||||
...defaultRoutes,
|
||||
...routes,
|
||||
}), [routes])
|
||||
|
||||
return (
|
||||
<RoutesContext value={routesValue}>
|
||||
<Context value={authState}>
|
||||
{children}
|
||||
</Context>
|
||||
</RoutesContext>
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access auth state.
|
||||
* Throws if used outside AuthContext.
|
||||
*/
|
||||
export function useAuthState<TUser extends BaseUser = BaseUser>(): AuthState<TUser> {
|
||||
const ctx = useContext(Context)
|
||||
if (!ctx) throw new Error('useAuthState must be used within AuthContext')
|
||||
return ctx as AuthState<TUser>
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access current user.
|
||||
* Returns null if not authenticated.
|
||||
*/
|
||||
export function useUser<TUser extends BaseUser = BaseUser>(): TUser | null {
|
||||
return useAuthState<TUser>().user
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access auth details (isAuthenticated, isStaff, etc.)
|
||||
*/
|
||||
export function useAuth(): AuthDetails {
|
||||
const user = useUser()
|
||||
return {
|
||||
isAuthenticated: user !== null,
|
||||
isStaff: user?.is_staff ?? false,
|
||||
isSuperuser: user?.is_superuser ?? false,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access route configuration.
|
||||
*/
|
||||
export function useAuthRoutes(): AuthRoutes {
|
||||
return useContext(RoutesContext)
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
'use client'
|
||||
|
||||
import { createContext, useContext, type ReactNode } from 'react'
|
||||
|
||||
/**
|
||||
* Framework-agnostic router adapter.
|
||||
* Implement this interface for your framework (Next.js, Remix, etc.)
|
||||
*/
|
||||
export interface RouterAdapter {
|
||||
/** Navigate to a path (adds to history) */
|
||||
push: (path: string) => void
|
||||
/** Replace current path (no history entry) */
|
||||
replace: (path: string) => void
|
||||
/** Current pathname (e.g., "/account/login") */
|
||||
pathname: string
|
||||
/** Current search params */
|
||||
searchParams: URLSearchParams
|
||||
/** Get a specific route param (e.g., from /auth/[...path]) - optional */
|
||||
getParam?: (name: string) => string | string[] | undefined
|
||||
}
|
||||
|
||||
const Context = createContext<RouterAdapter | null>(null)
|
||||
|
||||
interface RouterContextProps {
|
||||
children: ReactNode
|
||||
router: RouterAdapter
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides router adapter to route guards.
|
||||
*/
|
||||
export function RouterContext({ children, router }: RouterContextProps) {
|
||||
return <Context value={router}>{children}</Context>
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access router adapter.
|
||||
*/
|
||||
export function useRouter(): RouterAdapter {
|
||||
const ctx = useContext(Context)
|
||||
if (!ctx) throw new Error('useRouter must be used within RouterContext')
|
||||
return ctx
|
||||
}
|
||||
585
packages/mizan-csr/adapters/react/src/client/index.ts
Normal file
585
packages/mizan-csr/adapters/react/src/client/index.ts
Normal file
@@ -0,0 +1,585 @@
|
||||
/**
|
||||
* mizan/client
|
||||
*
|
||||
* HTTP client factories for Django backends.
|
||||
* Framework-agnostic - works with vanilla JS, React, Vue, etc.
|
||||
*
|
||||
* ## Quick Start
|
||||
*
|
||||
* ### Client-Side (CSR)
|
||||
* ```ts
|
||||
* import { createDjangoCSRClient, Auth } from 'mizan/client'
|
||||
*
|
||||
* // Session-based (cookies + CSRF)
|
||||
* const client = createDjangoCSRClient(Auth.SESSION)
|
||||
*
|
||||
* // JWT-based (Bearer token)
|
||||
* const client = createDjangoCSRClient(Auth.JWT, { getAccessToken })
|
||||
*
|
||||
* const user = await client.json('GET', '/api/accounts/me/')
|
||||
* ```
|
||||
*
|
||||
* ### Server-Side (SSR)
|
||||
* ```ts
|
||||
* import { createDjangoSSRClient } from 'mizan/client'
|
||||
*
|
||||
* const client = createDjangoSSRClient({
|
||||
* cookies: await cookies() // Next.js cookies()
|
||||
* })
|
||||
*
|
||||
* const user = await client.json('GET', '/api/accounts/me/')
|
||||
* ```
|
||||
*
|
||||
* ## React Hooks
|
||||
*
|
||||
* For React, import from `/react`:
|
||||
* ```tsx
|
||||
* import { useDjangoCSRClient, Auth } from 'mizan/client/react'
|
||||
*
|
||||
* const client = useDjangoCSRClient(Auth.SESSION)
|
||||
* ```
|
||||
*/
|
||||
|
||||
// =============================================================================
|
||||
// Types
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Authentication strategy for client-side requests.
|
||||
*/
|
||||
export enum Auth {
|
||||
/** Session cookies with CSRF token */
|
||||
SESSION = 'session',
|
||||
/** JWT Bearer token */
|
||||
JWT = 'jwt',
|
||||
}
|
||||
|
||||
/**
|
||||
* Cookie getter interface (matches Next.js cookies() return type).
|
||||
*/
|
||||
export interface CookieGetter {
|
||||
get(name: string): { name: string; value: string } | undefined
|
||||
getAll(): { name: string; value: string }[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Cookie configuration for SSR requests.
|
||||
* Can be either a cookie getter (like Next.js cookies()) or pre-extracted values.
|
||||
*/
|
||||
export type SSRCookies = CookieGetter | {
|
||||
/** CSRF token value */
|
||||
csrf: string
|
||||
/** Full cookie header string */
|
||||
cookieHeader: string
|
||||
}
|
||||
|
||||
/**
|
||||
* The core HTTP client interface for Django requests.
|
||||
*/
|
||||
export interface DjangoHTTPClient {
|
||||
/**
|
||||
* Make an HTTP request, returning the raw Response.
|
||||
*/
|
||||
request(method: string, path: string, data?: unknown, headers?: Record<string, string>): Promise<Response>
|
||||
|
||||
/**
|
||||
* Make an HTTP request, parsing the response as JSON.
|
||||
* @throws {HttpError} When response is not ok
|
||||
*/
|
||||
json<T>(method: string, path: string, data?: unknown, headers?: Record<string, string>): Promise<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for CSR client.
|
||||
*/
|
||||
export interface CSRClientConfig {
|
||||
/** Base URL for the Django backend */
|
||||
baseUrl?: string | (() => string)
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for JWT-authenticated CSR client.
|
||||
*/
|
||||
export interface JWTClientConfig extends CSRClientConfig {
|
||||
/** Async function that returns the current access token */
|
||||
getAccessToken: () => Promise<string | null>
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for SSR client.
|
||||
*/
|
||||
export interface SSRClientConfig {
|
||||
/** Cookies for authentication forwarding */
|
||||
cookies: SSRCookies
|
||||
/** Internal backend URL override (defaults to http://${INTERNAL_BACKEND_HOSTNAME}:8000) */
|
||||
baseUrl?: string
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Errors
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Details about an HTTP error.
|
||||
*/
|
||||
export interface HttpErrorDetails {
|
||||
status: number
|
||||
statusText: string
|
||||
url: string
|
||||
bodyJson?: unknown
|
||||
bodySnippet?: string
|
||||
bodyIsHtml?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when an HTTP request fails.
|
||||
*/
|
||||
export class HttpError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly details: HttpErrorDetails
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'HttpError'
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Internal Utilities
|
||||
// =============================================================================
|
||||
|
||||
function getCookie(name: string): string | null {
|
||||
if (typeof document === 'undefined') return null
|
||||
const value = `; ${document.cookie}`
|
||||
const parts = value.split(`; ${name}=`)
|
||||
if (parts.length === 2) {
|
||||
return parts.pop()?.split(';').shift() ?? null
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function getCSRFToken(): string | null {
|
||||
return getCookie('csrftoken')
|
||||
}
|
||||
|
||||
interface RequestBuild {
|
||||
request: RequestInit
|
||||
hasBody: boolean
|
||||
}
|
||||
|
||||
function buildRequest(method: string, data?: unknown, headers?: Record<string, string>): RequestBuild {
|
||||
const isBodyMethod = !['GET', 'HEAD'].includes(method.toUpperCase())
|
||||
const hasBody = isBodyMethod && data !== undefined
|
||||
|
||||
const requestHeaders: Record<string, string> = {
|
||||
'Accept': 'application/json',
|
||||
...headers,
|
||||
}
|
||||
|
||||
let body: BodyInit | undefined
|
||||
if (hasBody) {
|
||||
if (data instanceof FormData) {
|
||||
body = data
|
||||
} else {
|
||||
body = JSON.stringify(data)
|
||||
requestHeaders['Content-Type'] = 'application/json'
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
request: {
|
||||
method: method.toUpperCase(),
|
||||
headers: requestHeaders,
|
||||
body,
|
||||
credentials: 'include',
|
||||
},
|
||||
hasBody,
|
||||
}
|
||||
}
|
||||
|
||||
async function buildHttpError(resp: Response, url: URL | string): Promise<HttpError> {
|
||||
const urlStr = typeof url === 'string' ? url : url.toString()
|
||||
const details: HttpErrorDetails = {
|
||||
status: resp.status,
|
||||
statusText: resp.statusText,
|
||||
url: urlStr,
|
||||
}
|
||||
|
||||
try {
|
||||
const contentType = resp.headers.get('content-type') ?? ''
|
||||
if (contentType.includes('application/json')) {
|
||||
details.bodyJson = await resp.clone().json()
|
||||
} else {
|
||||
const text = await resp.clone().text()
|
||||
details.bodyIsHtml = contentType.includes('text/html')
|
||||
details.bodySnippet = text.slice(0, 500)
|
||||
}
|
||||
} catch {
|
||||
// Ignore body parsing errors
|
||||
}
|
||||
|
||||
return new HttpError(`Request failed: ${resp.status} ${resp.statusText}`, details)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CSR Client Factory
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Create a client-side HTTP client for Django.
|
||||
*
|
||||
* @param auth - Authentication strategy (Auth.SESSION or Auth.JWT)
|
||||
* @param config - Client configuration
|
||||
* @returns DjangoHTTPClient
|
||||
*
|
||||
* @example
|
||||
* // Session-based
|
||||
* const client = createDjangoCSRClient(Auth.SESSION)
|
||||
*
|
||||
* @example
|
||||
* // JWT-based
|
||||
* const client = createDjangoCSRClient(Auth.JWT, {
|
||||
* getAccessToken: async () => localStorage.getItem('token')
|
||||
* })
|
||||
*/
|
||||
export function createDjangoCSRClient(auth: Auth.SESSION, config?: CSRClientConfig): DjangoHTTPClient
|
||||
export function createDjangoCSRClient(auth: Auth.JWT, config: JWTClientConfig): DjangoHTTPClient
|
||||
export function createDjangoCSRClient(
|
||||
auth: Auth,
|
||||
config?: CSRClientConfig | JWTClientConfig
|
||||
): DjangoHTTPClient {
|
||||
if (!config?.baseUrl) {
|
||||
throw new Error(
|
||||
'baseUrl is required. Pass it via config or use MizanProvider which provides it automatically.'
|
||||
)
|
||||
}
|
||||
|
||||
const getBaseUrl = () => typeof config.baseUrl === 'function' ? config.baseUrl() : config.baseUrl!
|
||||
|
||||
const getHeaders = async (): Promise<Record<string, string>> => {
|
||||
if (auth === Auth.JWT) {
|
||||
const jwtConfig = config as JWTClientConfig
|
||||
const token = await jwtConfig.getAccessToken()
|
||||
if (token) {
|
||||
return { Authorization: `Bearer ${token}` }
|
||||
}
|
||||
return {}
|
||||
}
|
||||
// Session auth uses CSRF
|
||||
return { 'X-CSRFToken': getCSRFToken() ?? '' }
|
||||
}
|
||||
|
||||
function resolveUrl(path: string): string {
|
||||
const base = getBaseUrl()
|
||||
// Absolute base URL — use URL constructor
|
||||
if (base.startsWith('http://') || base.startsWith('https://')) {
|
||||
return new URL(path, base).toString()
|
||||
}
|
||||
// Relative base — path is already usable by fetch in a browser
|
||||
return path
|
||||
}
|
||||
|
||||
return {
|
||||
request: async (method, path, data?, headers?) => {
|
||||
const url = resolveUrl(path)
|
||||
const configHeaders = await getHeaders()
|
||||
const build = buildRequest(method, data, { ...configHeaders, ...headers })
|
||||
return fetch(url, build.request)
|
||||
},
|
||||
|
||||
json: async <T>(method: string, path: string, data?: unknown, headers?: Record<string, string>): Promise<T> => {
|
||||
const url = resolveUrl(path)
|
||||
const configHeaders = await getHeaders()
|
||||
const build = buildRequest(method, data, { ...configHeaders, ...headers })
|
||||
const resp = await fetch(url, build.request)
|
||||
|
||||
if (!resp.ok) {
|
||||
throw await buildHttpError(resp, url)
|
||||
}
|
||||
|
||||
return resp.json()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Internal Backend URL Resolution
|
||||
// =============================================================================
|
||||
|
||||
function getInternalBackendUrl(override?: string): string {
|
||||
if (override) return override
|
||||
throw new Error(
|
||||
'baseUrl is required for SSR client. Pass it via config.'
|
||||
)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SSR Client Factory
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Check if cookies is a CookieGetter interface.
|
||||
*/
|
||||
function isCookieGetter(cookies: SSRCookies): cookies is CookieGetter {
|
||||
return typeof (cookies as CookieGetter).get === 'function'
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract CSRF token and cookie header from SSRCookies.
|
||||
*/
|
||||
function extractCookies(cookies: SSRCookies): { csrf: string; cookieHeader: string } {
|
||||
if (isCookieGetter(cookies)) {
|
||||
return {
|
||||
csrf: cookies.get('csrftoken')?.value ?? '',
|
||||
cookieHeader: cookies.getAll().map(c => `${c.name}=${c.value}`).join('; ')
|
||||
}
|
||||
}
|
||||
return cookies
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a server-side HTTP client for Django.
|
||||
* Used in SSR contexts (Next.js server components, server actions, etc.)
|
||||
*
|
||||
* @param config - SSR client configuration with cookies
|
||||
* @returns DjangoHTTPClient
|
||||
*
|
||||
* @example
|
||||
* // Next.js server component
|
||||
* import { cookies } from 'next/headers'
|
||||
*
|
||||
* const client = createDjangoSSRClient({ cookies: await cookies() })
|
||||
*/
|
||||
// Re-export auth types for non-React usage
|
||||
export type {
|
||||
BaseUser,
|
||||
AuthDetails,
|
||||
AuthRoutes,
|
||||
JWTTokens,
|
||||
JWTConfig,
|
||||
JWTState,
|
||||
} from './types'
|
||||
|
||||
// Re-export RouterAdapter for libraries that extend it
|
||||
export type { RouterAdapter } from './RouterContext'
|
||||
|
||||
export function createDjangoSSRClient(config: SSRClientConfig): DjangoHTTPClient {
|
||||
const baseUrl = getInternalBackendUrl(config.baseUrl)
|
||||
const { csrf, cookieHeader } = extractCookies(config.cookies)
|
||||
|
||||
return {
|
||||
request: async (method, path, data?, headers?) => {
|
||||
const url = new URL(path, baseUrl)
|
||||
|
||||
const requestHeaders: Record<string, string> = {
|
||||
'Accept': 'application/json',
|
||||
'X-CSRFToken': csrf,
|
||||
'Cookie': cookieHeader,
|
||||
...headers,
|
||||
}
|
||||
|
||||
let body: BodyInit | undefined
|
||||
if (data && !['GET', 'HEAD'].includes(method.toUpperCase())) {
|
||||
if (data instanceof FormData) {
|
||||
body = data
|
||||
} else {
|
||||
body = JSON.stringify(data)
|
||||
requestHeaders['Content-Type'] = 'application/json'
|
||||
}
|
||||
}
|
||||
|
||||
return fetch(url, {
|
||||
method: method.toUpperCase(),
|
||||
headers: requestHeaders,
|
||||
body,
|
||||
credentials: 'include',
|
||||
cache: 'no-store',
|
||||
})
|
||||
},
|
||||
|
||||
json: async <T>(method: string, path: string, data?: unknown, headers?: Record<string, string>): Promise<T> => {
|
||||
const url = new URL(path, baseUrl)
|
||||
|
||||
const requestHeaders: Record<string, string> = {
|
||||
'Accept': 'application/json',
|
||||
'X-CSRFToken': csrf,
|
||||
'Cookie': cookieHeader,
|
||||
...headers,
|
||||
}
|
||||
|
||||
let body: BodyInit | undefined
|
||||
if (data && !['GET', 'HEAD'].includes(method.toUpperCase())) {
|
||||
if (data instanceof FormData) {
|
||||
body = data
|
||||
} else {
|
||||
body = JSON.stringify(data)
|
||||
requestHeaders['Content-Type'] = 'application/json'
|
||||
}
|
||||
}
|
||||
|
||||
const resp = await fetch(url, {
|
||||
method: method.toUpperCase(),
|
||||
headers: requestHeaders,
|
||||
body,
|
||||
credentials: 'include',
|
||||
cache: 'no-store',
|
||||
})
|
||||
|
||||
if (!resp.ok) {
|
||||
throw await buildHttpError(resp, url)
|
||||
}
|
||||
|
||||
const contentType = resp.headers.get('content-type') ?? ''
|
||||
if (!contentType.includes('application/json')) {
|
||||
throw new Error(`Expected JSON response but got ${contentType}`)
|
||||
}
|
||||
|
||||
return resp.json()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SSR Session Initialization
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Response from the session initialization endpoint.
|
||||
*/
|
||||
interface SessionInitResponse {
|
||||
csrfToken: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure a Django session exists before making SSR requests.
|
||||
*
|
||||
* On first visit, the user has no cookies. This function pings Django to
|
||||
* establish a session and get a CSRF token, which can then be used for
|
||||
* subsequent hydration requests in the same SSR request chain.
|
||||
*
|
||||
* Note: Browser cookie forwarding is handled by Next.js middleware, not this
|
||||
* function. This function only ensures cookies exist for SSR data fetching.
|
||||
*
|
||||
* @param config - SSR client configuration with cookies
|
||||
* @returns Object with csrf token and cookie header for use in SSR requests
|
||||
*
|
||||
* @example
|
||||
* // In layout.tsx
|
||||
* const cookieStore = await cookies()
|
||||
* const session = await ensureDjangoSession({ cookies: cookieStore })
|
||||
* const client = createDjangoSSRClient({
|
||||
* cookies: { csrf: session.csrf, cookieHeader: session.cookieHeader }
|
||||
* })
|
||||
*/
|
||||
export async function ensureDjangoSession(config: SSRClientConfig): Promise<{
|
||||
csrf: string
|
||||
cookieHeader: string
|
||||
}> {
|
||||
const baseUrl = getInternalBackendUrl(config.baseUrl)
|
||||
const { csrf: existingCsrf, cookieHeader: existingCookies } = extractCookies(config.cookies)
|
||||
|
||||
// If we already have a CSRF token, just return existing cookies
|
||||
if (existingCsrf) {
|
||||
return { csrf: existingCsrf, cookieHeader: existingCookies }
|
||||
}
|
||||
|
||||
// No CSRF token - need to initialize session
|
||||
const url = new URL('/api/mizan/session/', baseUrl)
|
||||
const resp = await fetch(url, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Accept': 'application/json',
|
||||
'Cookie': existingCookies,
|
||||
},
|
||||
credentials: 'include',
|
||||
cache: 'no-store',
|
||||
})
|
||||
|
||||
if (!resp.ok) {
|
||||
console.error('[mizan] Failed to initialize session:', resp.status, resp.statusText)
|
||||
return { csrf: '', cookieHeader: existingCookies }
|
||||
}
|
||||
|
||||
// Extract CSRF token from response body
|
||||
const data: SessionInitResponse = await resp.json()
|
||||
|
||||
// Extract Set-Cookie headers to build updated cookie string for SSR chain
|
||||
const setCookieHeaders = resp.headers.getSetCookie?.() ?? []
|
||||
const newCookies = setCookieHeaders.map(c => c.split(';')[0]).join('; ')
|
||||
const combinedCookies = existingCookies
|
||||
? `${existingCookies}; ${newCookies}`
|
||||
: newCookies
|
||||
|
||||
return {
|
||||
csrf: data.csrfToken,
|
||||
cookieHeader: combinedCookies,
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Server Function HTTP Call
|
||||
// =============================================================================
|
||||
|
||||
// Re-export error types from the canonical location
|
||||
export type { FunctionErrorResponse } from '../errors'
|
||||
import { DjangoError, type FunctionErrorResponse } from '../errors'
|
||||
|
||||
/**
|
||||
* Success response from a server function
|
||||
*/
|
||||
export interface FunctionSuccessResponse<T> {
|
||||
error: false
|
||||
data: T
|
||||
}
|
||||
|
||||
/**
|
||||
* Union type for server function responses
|
||||
*/
|
||||
export type FunctionResponse<T> = FunctionSuccessResponse<T> | FunctionErrorResponse
|
||||
|
||||
// Cached CSR client for server function calls
|
||||
let _csrClient: DjangoHTTPClient | null = null
|
||||
|
||||
function getCSRClient(): DjangoHTTPClient {
|
||||
if (!_csrClient) {
|
||||
_csrClient = createDjangoCSRClient(Auth.SESSION)
|
||||
}
|
||||
return _csrClient
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a Django server function via HTTP.
|
||||
* Used as fallback when WebSocket is unavailable.
|
||||
*
|
||||
* Uses the standard CSR client with session-based auth.
|
||||
*
|
||||
* @param baseUrl - Base URL for the API (e.g., '/api/mizan')
|
||||
* @param functionName - Name of the server function
|
||||
* @param input - Input data for the function
|
||||
* @returns Promise resolving to the function output
|
||||
* @throws FunctionErrorResponse on failure
|
||||
*/
|
||||
export async function httpFunctionCall<TInput = unknown, TOutput = unknown>(
|
||||
baseUrl: string,
|
||||
functionName: string,
|
||||
input?: TInput
|
||||
): Promise<TOutput> {
|
||||
const client = getCSRClient()
|
||||
|
||||
// Use request() not json() because server functions return { error: true/false }
|
||||
// in the body, not HTTP status codes for business errors
|
||||
const response = await client.request(
|
||||
'POST',
|
||||
`${baseUrl}/call/`,
|
||||
{ fn: functionName, args: input }
|
||||
)
|
||||
|
||||
const data: FunctionResponse<TOutput> = await response.json()
|
||||
|
||||
if (data.error) {
|
||||
throw new DjangoError(data)
|
||||
}
|
||||
|
||||
return data.data
|
||||
}
|
||||
|
||||
72
packages/mizan-csr/adapters/react/src/client/nextjs.tsx
Normal file
72
packages/mizan-csr/adapters/react/src/client/nextjs.tsx
Normal file
@@ -0,0 +1,72 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* Next.js adapter for mizan/jwt.
|
||||
*
|
||||
* Usage:
|
||||
* ```tsx
|
||||
* // In layout.tsx
|
||||
* import { NextAuthContext } from 'mizan/jwt/nextjs'
|
||||
*
|
||||
* export default function RootLayout({ children }) {
|
||||
* return (
|
||||
* <NextAuthContext user={user}>
|
||||
* {children}
|
||||
* </NextAuthContext>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { type ReactNode } from 'react'
|
||||
import { useRouter, usePathname, useSearchParams } from 'next/navigation'
|
||||
import type { RouterAdapter } from './RouterContext'
|
||||
import { RouterContext } from './RouterContext'
|
||||
import { AuthContext, type AuthContextProps } from './AuthContext'
|
||||
import type { BaseUser, AuthRoutes } from './types'
|
||||
|
||||
/**
|
||||
* Create a RouterAdapter from Next.js App Router hooks.
|
||||
*/
|
||||
export function useNextRouter(): RouterAdapter {
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
return {
|
||||
push: (path: string) => router.push(path),
|
||||
replace: (path: string) => router.replace(path),
|
||||
pathname,
|
||||
searchParams: new URLSearchParams(searchParams.toString()),
|
||||
}
|
||||
}
|
||||
|
||||
export interface NextAuthContextProps<TUser extends BaseUser = BaseUser> {
|
||||
children: ReactNode
|
||||
/** Initial user from SSR hydration */
|
||||
user?: TUser | null
|
||||
/** API endpoint to fetch user data (default: '/api/auth/me/') */
|
||||
userEndpoint?: string
|
||||
/** Route configuration for guards */
|
||||
routes?: Partial<AuthRoutes>
|
||||
}
|
||||
|
||||
/**
|
||||
* Next.js-specific AuthContext that handles the router automatically.
|
||||
*/
|
||||
export function NextAuthContext<TUser extends BaseUser = BaseUser>({
|
||||
children,
|
||||
user,
|
||||
userEndpoint,
|
||||
routes,
|
||||
}: NextAuthContextProps<TUser>) {
|
||||
const router = useNextRouter()
|
||||
|
||||
return (
|
||||
<RouterContext router={router}>
|
||||
<AuthContext user={user} userEndpoint={userEndpoint} routes={routes}>
|
||||
{children}
|
||||
</AuthContext>
|
||||
</RouterContext>
|
||||
)
|
||||
}
|
||||
63
packages/mizan-csr/adapters/react/src/client/react.ts
Normal file
63
packages/mizan-csr/adapters/react/src/client/react.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
'use client'
|
||||
|
||||
import { useMemo } from 'react'
|
||||
import { useJWT } from '../jwt/JWTContext'
|
||||
import {
|
||||
createDjangoCSRClient,
|
||||
Auth,
|
||||
type DjangoHTTPClient,
|
||||
type CSRClientConfig,
|
||||
} from './index'
|
||||
|
||||
// Re-export everything from main entry for convenience
|
||||
export * from './index'
|
||||
|
||||
// Re-export auth components for React users
|
||||
export * from './AuthContext'
|
||||
export * from '../jwt/JWTContext'
|
||||
export * from './RouterContext'
|
||||
export * from './routing'
|
||||
export type * from './types'
|
||||
|
||||
/**
|
||||
* React hook that returns a client-side Django HTTP client.
|
||||
*
|
||||
* For SESSION auth, creates a session-based client with CSRF handling.
|
||||
* For JWT auth, automatically wires up the JWTContext from mizan/jwt.
|
||||
*
|
||||
* @param auth - Authentication strategy (Auth.SESSION or Auth.JWT)
|
||||
* @param config - Optional client configuration
|
||||
* @returns DjangoHTTPClient
|
||||
*
|
||||
* @example
|
||||
* // Session-based
|
||||
* const client = useDjangoCSRClient(Auth.SESSION)
|
||||
* const user = await client.json('GET', '/api/accounts/me/')
|
||||
*
|
||||
* @example
|
||||
* // JWT-based (requires JWTContext from mizan/jwt)
|
||||
* const client = useDjangoCSRClient(Auth.JWT)
|
||||
* const user = await client.json('GET', '/api/accounts/me/')
|
||||
*/
|
||||
export function useDjangoCSRClient(auth: Auth, config?: CSRClientConfig): DjangoHTTPClient {
|
||||
// Always call useJWT (React hooks must be unconditional)
|
||||
// Returns null when outside JWTContext
|
||||
const jwtContext = useJWT()
|
||||
|
||||
return useMemo(() => {
|
||||
if (auth === Auth.JWT) {
|
||||
if (!jwtContext?.getAccessToken) {
|
||||
throw new Error(
|
||||
'useDjangoCSRClient(Auth.JWT) requires JWTContext from mizan/jwt. ' +
|
||||
'Wrap your component in JWTContext to use JWT authentication.'
|
||||
)
|
||||
}
|
||||
return createDjangoCSRClient(Auth.JWT, {
|
||||
...config,
|
||||
getAccessToken: jwtContext.getAccessToken,
|
||||
})
|
||||
}
|
||||
|
||||
return createDjangoCSRClient(Auth.SESSION, config)
|
||||
}, [auth, config, jwtContext?.getAccessToken])
|
||||
}
|
||||
74
packages/mizan-csr/adapters/react/src/client/routing.tsx
Normal file
74
packages/mizan-csr/adapters/react/src/client/routing.tsx
Normal file
@@ -0,0 +1,74 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, type ReactNode } from 'react'
|
||||
import { useRouter } from './RouterContext'
|
||||
import { useAuth, useAuthRoutes } from './AuthContext'
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is authenticated.
|
||||
* Redirects to login page if not authenticated.
|
||||
*/
|
||||
export function UserRoute({ children }: { children: ReactNode }) {
|
||||
const router = useRouter()
|
||||
const routes = useAuthRoutes()
|
||||
const { isAuthenticated } = useAuth()
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAuthenticated) {
|
||||
const searchParams = router.searchParams.toString()
|
||||
const currentPath = searchParams
|
||||
? `${router.pathname}?${searchParams}`
|
||||
: router.pathname
|
||||
const next = encodeURIComponent(currentPath)
|
||||
router.replace(`${routes.login}?next=${next}`)
|
||||
}
|
||||
}, [isAuthenticated, router, routes.login])
|
||||
|
||||
if (!isAuthenticated) return null
|
||||
return children
|
||||
}
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is authenticated AND is staff.
|
||||
* Redirects to login if not authenticated, or to authenticated route if not staff.
|
||||
*/
|
||||
export function StaffRoute({ children }: { children: ReactNode }) {
|
||||
const router = useRouter()
|
||||
const routes = useAuthRoutes()
|
||||
const { isAuthenticated, isStaff } = useAuth()
|
||||
|
||||
useEffect(() => {
|
||||
if (!isAuthenticated) {
|
||||
const searchParams = router.searchParams.toString()
|
||||
const currentPath = searchParams
|
||||
? `${router.pathname}?${searchParams}`
|
||||
: router.pathname
|
||||
const next = encodeURIComponent(currentPath)
|
||||
router.replace(`${routes.login}?next=${next}`)
|
||||
} else if (!isStaff) {
|
||||
router.replace(routes.authenticated)
|
||||
}
|
||||
}, [isAuthenticated, isStaff, router, routes])
|
||||
|
||||
if (!isAuthenticated || !isStaff) return null
|
||||
return children
|
||||
}
|
||||
|
||||
/**
|
||||
* Route guard that only renders children if the user is NOT authenticated.
|
||||
* Redirects to authenticated route if already logged in.
|
||||
*/
|
||||
export function AnonymousRoute({ children }: { children: ReactNode }) {
|
||||
const router = useRouter()
|
||||
const routes = useAuthRoutes()
|
||||
const { isAuthenticated } = useAuth()
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
router.replace(routes.authenticated)
|
||||
}
|
||||
}, [isAuthenticated, routes.authenticated, router])
|
||||
|
||||
if (isAuthenticated) return null
|
||||
return children
|
||||
}
|
||||
66
packages/mizan-csr/adapters/react/src/client/types.ts
Normal file
66
packages/mizan-csr/adapters/react/src/client/types.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Base user type - extend this for your app's user model.
|
||||
*/
|
||||
export interface BaseUser {
|
||||
id?: number | string
|
||||
email?: string
|
||||
username?: string
|
||||
is_staff?: boolean
|
||||
is_superuser?: boolean
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
/**
|
||||
* Auth state derived from user.
|
||||
*/
|
||||
export interface AuthDetails {
|
||||
isAuthenticated: boolean
|
||||
isStaff: boolean
|
||||
isSuperuser: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for route guards.
|
||||
*/
|
||||
export interface AuthRoutes {
|
||||
login: string
|
||||
authenticated: string
|
||||
}
|
||||
|
||||
/**
|
||||
* JWT token structure.
|
||||
*/
|
||||
export interface JWTTokens {
|
||||
accessToken: string
|
||||
refreshToken: string
|
||||
expiresAt: number // Unix timestamp in ms
|
||||
}
|
||||
|
||||
/**
|
||||
* JWT endpoint configuration.
|
||||
*/
|
||||
export interface JWTConfig {
|
||||
/** Base URL for API calls (default: '' - use relative URLs) */
|
||||
baseUrl?: string
|
||||
/** mizan server function endpoint (default: /api/mizan/call/) */
|
||||
endpoint?: string
|
||||
/** Seconds before expiry to trigger refresh (default: 30) */
|
||||
refreshBuffer?: number
|
||||
/** Auto-obtain tokens on mount (default: true) */
|
||||
autoObtain?: boolean
|
||||
/** Auto-refresh tokens before expiry (default: true) */
|
||||
autoRefresh?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* JWT state and methods.
|
||||
*/
|
||||
export interface JWTState {
|
||||
tokens: JWTTokens | null
|
||||
isLoading: boolean
|
||||
error: Error | null
|
||||
obtainTokens: () => Promise<JWTTokens | null>
|
||||
refreshTokens: () => Promise<JWTTokens | null>
|
||||
clearTokens: () => void
|
||||
getAccessToken: () => Promise<string | null>
|
||||
}
|
||||
751
packages/mizan-csr/adapters/react/src/context.tsx
Normal file
751
packages/mizan-csr/adapters/react/src/context.tsx
Normal file
@@ -0,0 +1,751 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* mizan React Context
|
||||
*
|
||||
* Provides server function calls via HTTP (default) or WebSocket RPC (opt-in).
|
||||
* This is the core React integration for Django server functions.
|
||||
*
|
||||
* Transport Model:
|
||||
* - HTTP-first: Functions use HTTP by default (transport='http' or undefined)
|
||||
* - WebSocket opt-in: Functions with transport='websocket' use WebSocket RPC
|
||||
* when connected, falling back to HTTP when disconnected
|
||||
*
|
||||
* Two layers:
|
||||
* 1. MizanProvider (this file) - Generic provider with name-based API
|
||||
* - Libraries like Allauth use this: useMizan(), useContext('current_user')
|
||||
*
|
||||
* 2. Generated DjangoContext (in @/api) - Typed wrapper around MizanProvider
|
||||
* - Product code uses this: useCurrentUser(), useUpdateProfile()
|
||||
*
|
||||
* The generated code wraps MizanProvider and adds type-safe hooks.
|
||||
*/
|
||||
|
||||
import {
|
||||
createContext,
|
||||
useContext as useReactContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
useCallback,
|
||||
type ReactNode,
|
||||
} from 'react'
|
||||
import { ChannelConnection, RPCError } from 'mizan/channels'
|
||||
import {
|
||||
createDjangoCSRClient,
|
||||
Auth,
|
||||
type FunctionResponse,
|
||||
} from 'mizan/client'
|
||||
import { useJWT } from './jwt'
|
||||
import { DjangoError, type ErrorCode, type FunctionErrorResponse } from './errors'
|
||||
|
||||
// ============================================================================
|
||||
// Utilities
|
||||
// ============================================================================
|
||||
|
||||
function getCSRFToken(): string | null {
|
||||
if (typeof document === 'undefined') return null
|
||||
const match = document.cookie.match(/csrftoken=([^;]+)/)
|
||||
return match?.[1] ?? null
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export type ConnectionStatus = 'connected' | 'connecting' | 'disconnected'
|
||||
|
||||
/** Push message received from server */
|
||||
export interface PushMessage<T = unknown> {
|
||||
topic: string
|
||||
data: T
|
||||
}
|
||||
|
||||
/** Listener for push messages */
|
||||
export type PushListener<T = unknown> = (message: PushMessage<T>) => void
|
||||
|
||||
/** Context data store - maps context names to their data */
|
||||
export type ContextStore = Record<string, unknown>
|
||||
|
||||
/** Hydration data for SSR - maps context names to their initial data */
|
||||
export type MizanHydration = Record<string, unknown>
|
||||
|
||||
/** Transport mode for server function calls */
|
||||
export type Transport = 'http' | 'websocket'
|
||||
|
||||
export interface MizanContextValue {
|
||||
/**
|
||||
* Call a server function by name.
|
||||
*
|
||||
* Transport behavior:
|
||||
* - 'http' (default): Always use HTTP POST /api/mizan/call/
|
||||
* - 'websocket': Use WebSocket RPC when connected, HTTP fallback when not
|
||||
*
|
||||
* @param functionName - The server function name (e.g., 'echo', 'update_profile')
|
||||
* @param input - Optional input data for the function
|
||||
* @param transport - Transport mode ('http' or 'websocket', defaults to 'http')
|
||||
*/
|
||||
call: <TInput = unknown, TOutput = unknown>(
|
||||
functionName: string,
|
||||
input?: TInput,
|
||||
transport?: Transport
|
||||
) => Promise<TOutput>
|
||||
|
||||
/**
|
||||
* Get cached context data by name.
|
||||
* Returns undefined if the context hasn't been loaded yet.
|
||||
*/
|
||||
getContext: <T = unknown>(name: string) => T | undefined
|
||||
|
||||
/**
|
||||
* Refresh a specific context by name.
|
||||
* Fetches fresh data from the server and updates the cache.
|
||||
*/
|
||||
refreshContext: (name: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Refresh all registered contexts.
|
||||
*/
|
||||
refreshAllContexts: () => Promise<void>
|
||||
|
||||
/**
|
||||
* Current WebSocket connection status.
|
||||
*/
|
||||
status: ConnectionStatus
|
||||
|
||||
/**
|
||||
* Whether WebSocket RPC is available.
|
||||
*/
|
||||
isRPCAvailable: boolean
|
||||
|
||||
/**
|
||||
* Subscribe to push messages for a topic.
|
||||
* Returns an unsubscribe function.
|
||||
*/
|
||||
onPush: <T = unknown>(topic: string, listener: PushListener<T>) => () => void
|
||||
|
||||
/**
|
||||
* Subscribe to context changes.
|
||||
* Returns an unsubscribe function.
|
||||
*/
|
||||
onContextChange: (name: string, listener: (data: unknown) => void) => () => void
|
||||
|
||||
/**
|
||||
* Promise that resolves when the session is initialized (CSRF cookie set).
|
||||
* Await this before making HTTP calls in contexts where timing matters
|
||||
* (e.g., calling a server function immediately on mount).
|
||||
*/
|
||||
whenReady: Promise<void>
|
||||
|
||||
/**
|
||||
* Invalidate a named context, triggering a refetch.
|
||||
* Only refetches if the context is currently mounted (has a registered provider).
|
||||
* No-op if the context is not mounted.
|
||||
*/
|
||||
invalidateContext: (name: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Invalidate specific functions within their contexts.
|
||||
* Groups by context and calls invalidateContext per group.
|
||||
*/
|
||||
invalidateFunctions: (names: string[]) => Promise<void>
|
||||
|
||||
/**
|
||||
* Register a named context provider for invalidation support.
|
||||
* Called by generated context providers on mount.
|
||||
* Returns an unregister function (call on unmount).
|
||||
*/
|
||||
registerContextProvider: (
|
||||
name: string,
|
||||
refetch: () => Promise<void>,
|
||||
) => () => void
|
||||
|
||||
/**
|
||||
* Base URL for HTTP calls (for use by generated context providers).
|
||||
*/
|
||||
baseUrl: string
|
||||
|
||||
/**
|
||||
* Set context data directly without triggering a network request.
|
||||
* Used by generated providers that fetch bundled responses.
|
||||
*/
|
||||
setContextData: (name: string, data: unknown) => void
|
||||
|
||||
/**
|
||||
* Make an authenticated HTTP request.
|
||||
* Handles JWT Bearer or session cookie auth automatically.
|
||||
* Waits for session init before making the request.
|
||||
*/
|
||||
request: (method: string, path: string, data?: unknown) => Promise<Response>
|
||||
}
|
||||
|
||||
export interface MizanProviderProps {
|
||||
children: ReactNode
|
||||
|
||||
/**
|
||||
* Initial hydration data for contexts (from SSR).
|
||||
* Keys are context names, values are the data.
|
||||
*/
|
||||
hydration?: MizanHydration
|
||||
|
||||
/**
|
||||
* List of context names to auto-fetch if not in hydration.
|
||||
* These will be fetched on mount.
|
||||
*/
|
||||
contexts?: string[]
|
||||
|
||||
/**
|
||||
* Base URL for HTTP fallback calls.
|
||||
* @default '/api/mizan'
|
||||
*/
|
||||
baseUrl?: string
|
||||
|
||||
/**
|
||||
* WebSocket URL for RPC calls.
|
||||
* @default '/ws/'
|
||||
*/
|
||||
wsUrl?: string
|
||||
|
||||
/**
|
||||
* Whether to connect WebSocket automatically.
|
||||
* @default true
|
||||
*/
|
||||
autoConnect?: boolean
|
||||
|
||||
/**
|
||||
* WebSocket reconnection options.
|
||||
*/
|
||||
reconnect?: boolean
|
||||
reconnectDelay?: number
|
||||
maxReconnectAttempts?: number
|
||||
|
||||
/**
|
||||
* Custom connection instance (for testing).
|
||||
*/
|
||||
connection?: ChannelConnection
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Context
|
||||
// ============================================================================
|
||||
|
||||
const MizanContextInternal = createContext<MizanContextValue | null>(null)
|
||||
|
||||
// ============================================================================
|
||||
// Provider
|
||||
// ============================================================================
|
||||
|
||||
export function MizanProvider({
|
||||
children,
|
||||
hydration,
|
||||
contexts: contextNames = [],
|
||||
baseUrl = '/api/mizan',
|
||||
wsUrl = '/ws/',
|
||||
autoConnect = true,
|
||||
reconnect = true,
|
||||
reconnectDelay = 1000,
|
||||
maxReconnectAttempts = 10,
|
||||
connection: providedConnection,
|
||||
}: MizanProviderProps) {
|
||||
const connectionRef = useRef<ChannelConnection | null>(null)
|
||||
|
||||
// Push listeners: Map<topic, Set<listener>>
|
||||
const pushListenersRef = useRef<Map<string, Set<PushListener>>>(new Map())
|
||||
|
||||
// Context change listeners: Map<name, Set<listener>>
|
||||
const contextListenersRef = useRef<Map<string, Set<(data: unknown) => void>>>(new Map())
|
||||
|
||||
// Context data store
|
||||
const [contextStore, setContextStore] = useState<ContextStore>(() => {
|
||||
// Initialize from hydration if provided
|
||||
return hydration ?? {}
|
||||
})
|
||||
|
||||
// Check if JWT is available - use JWT auth if so, otherwise session auth
|
||||
const jwt = useJWT()
|
||||
const hasJWT = jwt !== null && jwt.tokens !== null
|
||||
const [sessionReady, setSessionReady] = useState(false)
|
||||
|
||||
// Promise that resolves when session is initialized.
|
||||
// Exposed via context so any code that needs to wait for CSRF can await it.
|
||||
const sessionRef = useRef<{ promise: Promise<void>; resolve: () => void } | null>(null)
|
||||
if (!sessionRef.current) {
|
||||
let resolve!: () => void
|
||||
const promise = new Promise<void>(r => { resolve = r })
|
||||
sessionRef.current = { promise, resolve }
|
||||
}
|
||||
|
||||
// Create HTTP client with appropriate auth method
|
||||
const httpClient = useMemo(() => {
|
||||
if (jwt?.getAccessToken) {
|
||||
return createDjangoCSRClient(Auth.JWT, {
|
||||
baseUrl,
|
||||
getAccessToken: jwt.getAccessToken,
|
||||
})
|
||||
}
|
||||
return createDjangoCSRClient(Auth.SESSION, { baseUrl })
|
||||
}, [hasJWT, jwt?.getAccessToken, baseUrl])
|
||||
|
||||
// Create or use provided connection
|
||||
if (!connectionRef.current) {
|
||||
connectionRef.current = providedConnection ?? new ChannelConnection({
|
||||
url: wsUrl,
|
||||
reconnect,
|
||||
reconnectDelay,
|
||||
maxReconnectAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
const connection = connectionRef.current
|
||||
|
||||
// Track connection status
|
||||
const [status, setStatus] = useState<ConnectionStatus>(
|
||||
connection.status as ConnectionStatus
|
||||
)
|
||||
|
||||
// The core call function: HTTP-first, WebSocket opt-in
|
||||
const call = useCallback(
|
||||
async <TInput = unknown, TOutput = unknown>(
|
||||
functionName: string,
|
||||
input?: TInput,
|
||||
transport: Transport = 'http'
|
||||
): Promise<TOutput> => {
|
||||
// Only attempt WebSocket if explicitly requested AND connected
|
||||
if (transport === 'websocket' && connection.status === 'connected') {
|
||||
try {
|
||||
return await connection.rpc<TInput, TOutput>(functionName, input as TInput)
|
||||
} catch (e) {
|
||||
// If it's an RPC error (function error), re-throw as DjangoError
|
||||
if (e instanceof RPCError) {
|
||||
throw new DjangoError({
|
||||
error: true,
|
||||
code: e.code as ErrorCode,
|
||||
message: e.message,
|
||||
details: e.details,
|
||||
})
|
||||
}
|
||||
|
||||
// Connection error - fall through to HTTP
|
||||
console.warn(
|
||||
`[mizan] WebSocket RPC failed for '${functionName}', falling back to HTTP:`,
|
||||
e
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for session init (CSRF cookie) before making HTTP requests
|
||||
await sessionRef.current!.promise
|
||||
|
||||
const response = await httpClient.request(
|
||||
'POST',
|
||||
`${baseUrl}/call/`,
|
||||
{ fn: functionName, args: input }
|
||||
)
|
||||
|
||||
const data: FunctionResponse<TOutput> = await response.json()
|
||||
|
||||
if (data.error) {
|
||||
throw new DjangoError(data as FunctionErrorResponse)
|
||||
}
|
||||
|
||||
return data.data
|
||||
},
|
||||
[connection, baseUrl, httpClient]
|
||||
)
|
||||
|
||||
// Get cached context data
|
||||
const getContext = useCallback(
|
||||
<T = unknown>(name: string): T | undefined => {
|
||||
return contextStore[name] as T | undefined
|
||||
},
|
||||
[contextStore]
|
||||
)
|
||||
|
||||
// Refresh a specific context
|
||||
const refreshContext = useCallback(
|
||||
async (name: string): Promise<void> => {
|
||||
try {
|
||||
const data = await call(name, {})
|
||||
setContextStore(prev => {
|
||||
const next = { ...prev, [name]: data }
|
||||
// Notify listeners
|
||||
const listeners = contextListenersRef.current.get(name)
|
||||
if (listeners) {
|
||||
listeners.forEach(listener => {
|
||||
try {
|
||||
listener(data)
|
||||
} catch (e) {
|
||||
console.error(`[mizan] Context listener error for '${name}':`, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
return next
|
||||
})
|
||||
} catch (e) {
|
||||
console.error(`[mizan] Failed to refresh context '${name}':`, e)
|
||||
throw e
|
||||
}
|
||||
},
|
||||
[call]
|
||||
)
|
||||
|
||||
// Refresh all registered contexts
|
||||
const refreshAllContexts = useCallback(
|
||||
async (): Promise<void> => {
|
||||
await Promise.all(contextNames.map(name => refreshContext(name)))
|
||||
},
|
||||
[contextNames, refreshContext]
|
||||
)
|
||||
|
||||
// Subscribe to context changes
|
||||
const onContextChange = useCallback(
|
||||
(name: string, listener: (data: unknown) => void): (() => void) => {
|
||||
const listeners = contextListenersRef.current.get(name) ?? new Set()
|
||||
listeners.add(listener)
|
||||
contextListenersRef.current.set(name, listeners)
|
||||
|
||||
return () => {
|
||||
const nameListeners = contextListenersRef.current.get(name)
|
||||
if (nameListeners) {
|
||||
nameListeners.delete(listener)
|
||||
if (nameListeners.size === 0) {
|
||||
contextListenersRef.current.delete(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
// Subscribe to push messages
|
||||
const onPush = useCallback(
|
||||
<T = unknown>(topic: string, listener: PushListener<T>): (() => void) => {
|
||||
const listeners = pushListenersRef.current.get(topic) ?? new Set()
|
||||
listeners.add(listener as PushListener)
|
||||
pushListenersRef.current.set(topic, listeners)
|
||||
|
||||
return () => {
|
||||
const topicListeners = pushListenersRef.current.get(topic)
|
||||
if (topicListeners) {
|
||||
topicListeners.delete(listener as PushListener)
|
||||
if (topicListeners.size === 0) {
|
||||
pushListenersRef.current.delete(topic)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
// Connect on mount and listen for push messages
|
||||
useEffect(() => {
|
||||
const unsubscribeStatus = connection.onStatusChange((newStatus) => {
|
||||
setStatus(newStatus as ConnectionStatus)
|
||||
})
|
||||
|
||||
// Listen for all messages (including push)
|
||||
const unsubscribeMessages = connection.onMessage((payload) => {
|
||||
if (payload && typeof payload === 'object' && 'type' in payload && payload.type === 'push') {
|
||||
const topic = (payload as { topic?: string }).topic
|
||||
const data = (payload as { data?: unknown }).data
|
||||
|
||||
if (topic) {
|
||||
const listeners = pushListenersRef.current.get(topic)
|
||||
if (listeners) {
|
||||
const message: PushMessage = { topic, data }
|
||||
listeners.forEach(listener => {
|
||||
try {
|
||||
listener(message)
|
||||
} catch (e) {
|
||||
console.error('[mizan] Push listener error:', e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if (autoConnect) {
|
||||
connection.connect()
|
||||
}
|
||||
|
||||
return () => {
|
||||
unsubscribeStatus()
|
||||
unsubscribeMessages()
|
||||
}
|
||||
}, [connection, autoConnect])
|
||||
|
||||
// Session init for CSR (fallback if proxy didn't run)
|
||||
useEffect(() => {
|
||||
if (hasJWT || getCSRFToken()) {
|
||||
setSessionReady(true)
|
||||
sessionRef.current?.resolve()
|
||||
return
|
||||
}
|
||||
fetch(`${baseUrl}/session/`, { credentials: 'include' })
|
||||
.catch(e => console.error('[MizanProvider] Session init failed:', e))
|
||||
.finally(() => {
|
||||
setSessionReady(true)
|
||||
sessionRef.current?.resolve()
|
||||
})
|
||||
}, [hasJWT, baseUrl])
|
||||
|
||||
// Auto-fetch contexts that weren't hydrated
|
||||
useEffect(() => {
|
||||
if (!sessionReady) return
|
||||
if (!hydration) {
|
||||
refreshAllContexts()
|
||||
} else {
|
||||
const missing = contextNames.filter(name => !(name in hydration))
|
||||
if (missing.length > 0) {
|
||||
Promise.all(missing.map(name => refreshContext(name)))
|
||||
}
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [sessionReady])
|
||||
|
||||
const isRPCAvailable = status === 'connected'
|
||||
|
||||
// Named context provider registry for invalidation
|
||||
const contextProvidersRef = useRef<Map<string, { refetch: () => Promise<void> }>>(new Map())
|
||||
|
||||
const registerContextProvider = useCallback(
|
||||
(name: string, refetch: () => Promise<void>): (() => void) => {
|
||||
contextProvidersRef.current.set(name, { refetch })
|
||||
return () => {
|
||||
contextProvidersRef.current.delete(name)
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const invalidateContext = useCallback(
|
||||
async (name: string): Promise<void> => {
|
||||
const provider = contextProvidersRef.current.get(name)
|
||||
if (provider) {
|
||||
await provider.refetch()
|
||||
}
|
||||
// If not mounted, no-op — no wasted request
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const invalidateFunctions = useCallback(
|
||||
async (names: string[]): Promise<void> => {
|
||||
// Each function belongs to a context. Invalidating a function
|
||||
// means refetching its entire context (since the bundling endpoint
|
||||
// returns all functions). Dedupe by context name.
|
||||
const contexts = new Set<string>()
|
||||
for (const name of names) {
|
||||
// The context name for each function is known at codegen time
|
||||
// and baked into the generated hook. Here we just invalidate
|
||||
// whatever contexts are registered that contain these functions.
|
||||
for (const [ctxName] of contextProvidersRef.current) {
|
||||
contexts.add(ctxName)
|
||||
}
|
||||
}
|
||||
await Promise.all(
|
||||
Array.from(contexts).map(ctx => invalidateContext(ctx))
|
||||
)
|
||||
},
|
||||
[invalidateContext]
|
||||
)
|
||||
|
||||
// Set context data directly (used by generated providers that fetch bundles)
|
||||
const setContextData = useCallback(
|
||||
(name: string, data: unknown) => {
|
||||
setContextStore(prev => {
|
||||
const next = { ...prev, [name]: data }
|
||||
const listeners = contextListenersRef.current.get(name)
|
||||
if (listeners) {
|
||||
listeners.forEach(listener => {
|
||||
try {
|
||||
listener(data)
|
||||
} catch (e) {
|
||||
console.error(`[mizan] Context listener error for '${name}':`, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
return next
|
||||
})
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
// Auth-transparent HTTP request (used by generated context providers)
|
||||
const request = useCallback(
|
||||
async (method: string, path: string, data?: unknown): Promise<Response> => {
|
||||
await sessionRef.current!.promise
|
||||
return httpClient.request(method, path, data)
|
||||
},
|
||||
[httpClient]
|
||||
)
|
||||
|
||||
const value = useMemo<MizanContextValue>(
|
||||
() => ({
|
||||
call,
|
||||
getContext,
|
||||
refreshContext,
|
||||
refreshAllContexts,
|
||||
status,
|
||||
isRPCAvailable,
|
||||
onPush,
|
||||
onContextChange,
|
||||
whenReady: sessionRef.current!.promise,
|
||||
invalidateContext,
|
||||
invalidateFunctions,
|
||||
registerContextProvider,
|
||||
baseUrl,
|
||||
setContextData,
|
||||
request,
|
||||
}),
|
||||
[call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange, invalidateContext, invalidateFunctions, registerContextProvider, baseUrl, setContextData, request]
|
||||
)
|
||||
|
||||
return (
|
||||
<MizanContextInternal value={value}>
|
||||
{children}
|
||||
</MizanContextInternal>
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hooks
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Access the mizan context.
|
||||
*
|
||||
* Provides generic name-based API for server functions and contexts.
|
||||
* Libraries should use this hook, not the typed generated hooks.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Library code (e.g., Allauth)
|
||||
* import { useMizan } from 'mizan'
|
||||
*
|
||||
* function useUser() {
|
||||
* const { getContext } = useMizan()
|
||||
* return getContext('current_user')
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function useMizan(): MizanContextValue {
|
||||
const context = useReactContext(MizanContextInternal)
|
||||
if (!context) {
|
||||
throw new Error('useMizan must be used within a MizanProvider')
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cached context data by name.
|
||||
*
|
||||
* For use by libraries that need to access context data without knowing types.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // In Allauth library
|
||||
* function useUser() {
|
||||
* return useMizanContext('current_user')
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function useMizanContext<T = unknown>(name: string): T | undefined {
|
||||
const { getContext } = useMizan()
|
||||
return getContext<T>(name)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a function caller by name with transport control.
|
||||
*
|
||||
* For use by libraries that need to call functions without knowing types.
|
||||
* The transport parameter is baked into the returned function.
|
||||
*
|
||||
* @param functionName - The server function name
|
||||
* @param transport - Transport mode ('http' or 'websocket', defaults to 'http')
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // HTTP-only function (default)
|
||||
* function useUpdateProfile() {
|
||||
* return useMizanCall('update_profile')
|
||||
* }
|
||||
*
|
||||
* // WebSocket-enabled function
|
||||
* function useSendMessage() {
|
||||
* return useMizanCall('send_message', 'websocket')
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function useMizanCall<TInput = unknown, TOutput = unknown>(
|
||||
functionName: string,
|
||||
transport: Transport = 'http'
|
||||
): (input?: TInput) => Promise<TOutput> {
|
||||
const { call } = useMizan()
|
||||
return useCallback(
|
||||
(input?: TInput) => call<TInput, TOutput>(functionName, input, transport),
|
||||
[call, functionName, transport]
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current WebSocket connection status.
|
||||
*/
|
||||
export function useMizanStatus(): ConnectionStatus {
|
||||
const { status } = useMizan()
|
||||
return status
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to push messages for a topic.
|
||||
* Automatically unsubscribes when the component unmounts.
|
||||
*/
|
||||
export function usePush<T = unknown>(
|
||||
topic: string,
|
||||
callback: PushListener<T>
|
||||
): void {
|
||||
const { onPush } = useMizan()
|
||||
const callbackRef = useRef(callback)
|
||||
|
||||
useEffect(() => {
|
||||
callbackRef.current = callback
|
||||
}, [callback])
|
||||
|
||||
useEffect(() => {
|
||||
const listener: PushListener<T> = (message) => {
|
||||
callbackRef.current(message)
|
||||
}
|
||||
|
||||
return onPush(topic, listener)
|
||||
}, [topic, onPush])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Legacy Aliases (for backwards compatibility during migration)
|
||||
// ============================================================================
|
||||
|
||||
/** @deprecated Use MizanProvider instead */
|
||||
export const DjangoContext = MizanProvider
|
||||
|
||||
/** @deprecated Use useMizan instead */
|
||||
export const useDjango = useMizan
|
||||
|
||||
/** @deprecated Use useMizanStatus instead */
|
||||
export const useDjangoStatus = useMizanStatus
|
||||
|
||||
/** @deprecated Use useMizanCall instead */
|
||||
export function useServerFunction<TInput = unknown, TOutput = unknown>(
|
||||
functionName: string
|
||||
): (input: TInput) => Promise<TOutput> {
|
||||
const { call } = useMizan()
|
||||
return useCallback(
|
||||
(input: TInput) => call<TInput, TOutput>(functionName, input),
|
||||
[call, functionName]
|
||||
)
|
||||
}
|
||||
|
||||
// Re-export types for the legacy API
|
||||
export type DjangoContextValue = MizanContextValue
|
||||
export type DjangoContextProps = MizanProviderProps
|
||||
107
packages/mizan-csr/adapters/react/src/errors.ts
Normal file
107
packages/mizan-csr/adapters/react/src/errors.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
/**
|
||||
* Django Server Error Types
|
||||
*
|
||||
* Typed errors for server function failures.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Error codes returned by the server
|
||||
*/
|
||||
export type ErrorCode =
|
||||
| 'NOT_FOUND'
|
||||
| 'VALIDATION_ERROR'
|
||||
| 'UNAUTHORIZED'
|
||||
| 'FORBIDDEN'
|
||||
| 'BAD_REQUEST'
|
||||
| 'INTERNAL_ERROR'
|
||||
| 'NOT_IMPLEMENTED'
|
||||
|
||||
/**
|
||||
* Error response structure from the server
|
||||
*/
|
||||
export interface FunctionErrorResponse {
|
||||
error: true
|
||||
code: ErrorCode
|
||||
message: string
|
||||
details?: {
|
||||
fields?: Record<string, string[]>
|
||||
required?: string[]
|
||||
type?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when a server function call fails
|
||||
*/
|
||||
export class DjangoError extends Error {
|
||||
/**
|
||||
* Error code from the server
|
||||
*/
|
||||
readonly code: ErrorCode
|
||||
|
||||
/**
|
||||
* Additional error details
|
||||
*/
|
||||
readonly details?: FunctionErrorResponse['details']
|
||||
|
||||
/**
|
||||
* The original error response
|
||||
*/
|
||||
readonly response: FunctionErrorResponse
|
||||
|
||||
constructor(response: FunctionErrorResponse) {
|
||||
super(response.message)
|
||||
this.name = 'DjangoError'
|
||||
this.code = response.code
|
||||
this.details = response.details
|
||||
this.response = response
|
||||
|
||||
// Maintains proper stack trace for where error was thrown
|
||||
if (Error.captureStackTrace) {
|
||||
Error.captureStackTrace(this, DjangoError)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this is a validation error
|
||||
*/
|
||||
isValidationError(): boolean {
|
||||
return this.code === 'VALIDATION_ERROR'
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this is an authentication error
|
||||
*/
|
||||
isAuthError(): boolean {
|
||||
return this.code === 'UNAUTHORIZED' || this.code === 'FORBIDDEN'
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this is a not found error
|
||||
*/
|
||||
isNotFound(): boolean {
|
||||
return this.code === 'NOT_FOUND'
|
||||
}
|
||||
|
||||
/**
|
||||
* Get field-level validation errors (if this is a validation error)
|
||||
*/
|
||||
getFieldErrors(): Record<string, string[]> | null {
|
||||
if (this.code === 'VALIDATION_ERROR' && this.details?.fields) {
|
||||
return this.details.fields
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get error for a specific field
|
||||
*/
|
||||
getFieldError(field: string): string | null {
|
||||
const errors = this.getFieldErrors()
|
||||
if (errors && errors[field]?.length > 0) {
|
||||
return errors[field][0]
|
||||
}
|
||||
return null
|
||||
}
|
||||
}
|
||||
1163
packages/mizan-csr/adapters/react/src/forms.ts
Normal file
1163
packages/mizan-csr/adapters/react/src/forms.ts
Normal file
File diff suppressed because it is too large
Load Diff
115
packages/mizan-csr/adapters/react/src/index.ts
Normal file
115
packages/mizan-csr/adapters/react/src/index.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
/**
|
||||
* mizan - Django Server Functions Client
|
||||
*
|
||||
* Frontend client for Django server functions.
|
||||
* Server functions are the core primitive - accessed via React hooks.
|
||||
*
|
||||
* Two-layer architecture:
|
||||
*
|
||||
* 1. Library layer (this package) - Generic name-based API
|
||||
* Used by libraries like Allauth that need to call functions by name.
|
||||
*
|
||||
* import { useMizan, useMizanContext, useMizanCall } from 'mizan'
|
||||
* const user = useMizanContext('current_user')
|
||||
* const call = useMizanCall('update_profile')
|
||||
*
|
||||
* 2. Generated layer (@/api) - Typed project-specific API
|
||||
* Used by product code for type-safe hooks.
|
||||
*
|
||||
* import { useCurrentUser, useUpdateProfile } from '@/api'
|
||||
* const user = useCurrentUser()
|
||||
* const updateProfile = useUpdateProfile()
|
||||
*
|
||||
* The generated code wraps MizanProvider and adds type-safe hooks.
|
||||
*/
|
||||
|
||||
// ============================================================================
|
||||
// React Context & Hooks (primary API)
|
||||
// ============================================================================
|
||||
|
||||
export {
|
||||
// Provider
|
||||
MizanProvider,
|
||||
type MizanProviderProps,
|
||||
type MizanHydration,
|
||||
|
||||
// Hooks (generic name-based API for libraries)
|
||||
useMizan,
|
||||
useMizanContext,
|
||||
useMizanCall,
|
||||
useMizanStatus,
|
||||
usePush,
|
||||
|
||||
// Types
|
||||
type MizanContextValue,
|
||||
type ConnectionStatus,
|
||||
type PushMessage,
|
||||
type PushListener,
|
||||
type ContextStore,
|
||||
type Transport,
|
||||
|
||||
// Legacy aliases (deprecated, for migration)
|
||||
DjangoContext,
|
||||
useDjango,
|
||||
useDjangoStatus,
|
||||
useServerFunction,
|
||||
type DjangoContextValue,
|
||||
type DjangoContextProps,
|
||||
} from './context'
|
||||
|
||||
// ============================================================================
|
||||
// HTTP Client (for SSR or non-React usage)
|
||||
// ============================================================================
|
||||
|
||||
export {
|
||||
httpFunctionCall,
|
||||
createDjangoCSRClient,
|
||||
createDjangoSSRClient,
|
||||
ensureDjangoSession,
|
||||
Auth,
|
||||
type DjangoHTTPClient,
|
||||
type CSRClientConfig,
|
||||
type JWTClientConfig,
|
||||
type SSRClientConfig,
|
||||
} from './client/'
|
||||
|
||||
// ============================================================================
|
||||
// Errors
|
||||
// ============================================================================
|
||||
|
||||
export {
|
||||
DjangoError,
|
||||
type FunctionErrorResponse,
|
||||
type ErrorCode,
|
||||
} from './errors'
|
||||
|
||||
// ============================================================================
|
||||
// Forms (typed form hooks core)
|
||||
// ============================================================================
|
||||
|
||||
export {
|
||||
// Single form
|
||||
useMizanFormCore,
|
||||
// Legacy alias
|
||||
useMizanFormCore as useDjangoFormCore,
|
||||
type DjangoFormState,
|
||||
type FormSchema,
|
||||
type FormErrors,
|
||||
type FormOptions,
|
||||
type FormSubmitResult,
|
||||
type FormCoreConfig,
|
||||
// Formset
|
||||
useMizanFormsetCore,
|
||||
// Legacy alias
|
||||
useMizanFormsetCore as useDjangoFormsetCore,
|
||||
type DjangoFormsetState,
|
||||
type FormsetSchema,
|
||||
type FormsetErrors,
|
||||
type FormsetCoreConfig,
|
||||
type FormsetSubmitResult,
|
||||
// Shared types
|
||||
type FieldSchema,
|
||||
type FieldChoice,
|
||||
type FieldError,
|
||||
type FormMeta,
|
||||
} from './forms'
|
||||
235
packages/mizan-csr/adapters/react/src/jwt/JWTContext.tsx
Normal file
235
packages/mizan-csr/adapters/react/src/jwt/JWTContext.tsx
Normal file
@@ -0,0 +1,235 @@
|
||||
'use client'
|
||||
|
||||
import {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback,
|
||||
useRef,
|
||||
type ReactNode,
|
||||
} from 'react'
|
||||
import type { JWTTokens, JWTConfig, JWTState } from '../client/types'
|
||||
|
||||
function getCSRFToken(): string | null {
|
||||
if (typeof document === 'undefined') return null
|
||||
const match = document.cookie.match(/csrftoken=([^;]+)/)
|
||||
return match?.[1] ?? null
|
||||
}
|
||||
|
||||
const Context = createContext<JWTState | null>(null)
|
||||
|
||||
const DEFAULT_CONFIG: Required<JWTConfig> = {
|
||||
baseUrl: '',
|
||||
endpoint: '/api/mizan/call/',
|
||||
refreshBuffer: 30,
|
||||
autoObtain: true,
|
||||
autoRefresh: true,
|
||||
}
|
||||
|
||||
interface JWTContextProps {
|
||||
children: ReactNode
|
||||
config?: JWTConfig
|
||||
}
|
||||
|
||||
export function JWTContext({ children, config }: JWTContextProps) {
|
||||
const cfg = { ...DEFAULT_CONFIG, ...config }
|
||||
|
||||
const [tokens, setTokens] = useState<JWTTokens | null>(null)
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
|
||||
const refreshTimeoutRef = useRef<NodeJS.Timeout | null>(null)
|
||||
|
||||
// Helper to call server functions
|
||||
const callServerFunction = useCallback(async (fn: string, args: Record<string, unknown> = {}) => {
|
||||
const url = cfg.baseUrl ? `${cfg.baseUrl}${cfg.endpoint}` : cfg.endpoint
|
||||
const csrfToken = getCSRFToken()
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(csrfToken ? { 'X-CSRFToken': csrfToken } : {}),
|
||||
},
|
||||
credentials: 'include', // Include session cookie for CSRF
|
||||
body: JSON.stringify({ fn, args }),
|
||||
})
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
if (data.error) {
|
||||
const err = new Error(data.message || 'Server function failed')
|
||||
;(err as any).code = data.code
|
||||
;(err as any).details = data.details
|
||||
throw err
|
||||
}
|
||||
|
||||
return data.data
|
||||
}, [cfg.baseUrl, cfg.endpoint])
|
||||
|
||||
// Obtain tokens from session
|
||||
const obtainTokens = useCallback(async (): Promise<JWTTokens | null> => {
|
||||
setIsLoading(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const result = await callServerFunction('jwt_obtain')
|
||||
const newTokens: JWTTokens = {
|
||||
accessToken: result.access_token,
|
||||
refreshToken: result.refresh_token,
|
||||
expiresAt: Date.now() + result.expires_in * 1000,
|
||||
}
|
||||
|
||||
setTokens(newTokens)
|
||||
return newTokens
|
||||
} catch (err: any) {
|
||||
// FORBIDDEN means not authenticated - expected, not an error
|
||||
if (err.code === 'FORBIDDEN') {
|
||||
setTokens(null)
|
||||
return null
|
||||
}
|
||||
const error = err instanceof Error ? err : new Error(String(err))
|
||||
setError(error)
|
||||
return null
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [callServerFunction])
|
||||
|
||||
// Refresh tokens
|
||||
const refreshTokens = useCallback(async (): Promise<JWTTokens | null> => {
|
||||
if (!tokens?.refreshToken) {
|
||||
return null
|
||||
}
|
||||
|
||||
setIsLoading(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const result = await callServerFunction('jwt_refresh', {
|
||||
refresh_token: tokens.refreshToken,
|
||||
})
|
||||
const newTokens: JWTTokens = {
|
||||
accessToken: result.access_token,
|
||||
refreshToken: result.refresh_token,
|
||||
expiresAt: Date.now() + result.expires_in * 1000,
|
||||
}
|
||||
|
||||
setTokens(newTokens)
|
||||
return newTokens
|
||||
} catch (err: any) {
|
||||
// FORBIDDEN means refresh token invalid/expired - clear tokens
|
||||
if (err.code === 'FORBIDDEN') {
|
||||
setTokens(null)
|
||||
return null
|
||||
}
|
||||
const error = err instanceof Error ? err : new Error(String(err))
|
||||
setError(error)
|
||||
return null
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [tokens?.refreshToken, callServerFunction])
|
||||
|
||||
// Clear tokens
|
||||
const clearTokens = useCallback(() => {
|
||||
setTokens(null)
|
||||
setError(null)
|
||||
if (refreshTimeoutRef.current) {
|
||||
clearTimeout(refreshTimeoutRef.current)
|
||||
refreshTimeoutRef.current = null
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Get access token (refresh if needed)
|
||||
const getAccessToken = useCallback(async (): Promise<string | null> => {
|
||||
if (!tokens) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
const bufferMs = cfg.refreshBuffer * 1000
|
||||
if (tokens.expiresAt - Date.now() < bufferMs) {
|
||||
const newTokens = await refreshTokens()
|
||||
return newTokens?.accessToken ?? null
|
||||
}
|
||||
|
||||
return tokens.accessToken
|
||||
}, [tokens, cfg.refreshBuffer, refreshTokens])
|
||||
|
||||
// Auto-obtain on mount
|
||||
useEffect(() => {
|
||||
if (!cfg.autoObtain) return
|
||||
obtainTokens()
|
||||
}, [cfg.autoObtain, obtainTokens])
|
||||
|
||||
// Auto-refresh before expiry
|
||||
useEffect(() => {
|
||||
if (!cfg.autoRefresh || !tokens) {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear existing timeout
|
||||
if (refreshTimeoutRef.current) {
|
||||
clearTimeout(refreshTimeoutRef.current)
|
||||
}
|
||||
|
||||
// Schedule refresh
|
||||
const bufferMs = cfg.refreshBuffer * 1000
|
||||
const timeUntilRefresh = tokens.expiresAt - Date.now() - bufferMs
|
||||
|
||||
if (timeUntilRefresh > 0) {
|
||||
refreshTimeoutRef.current = setTimeout(() => {
|
||||
refreshTokens()
|
||||
}, timeUntilRefresh)
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (refreshTimeoutRef.current) {
|
||||
clearTimeout(refreshTimeoutRef.current)
|
||||
}
|
||||
}
|
||||
}, [cfg.autoRefresh, cfg.refreshBuffer, tokens, refreshTokens])
|
||||
|
||||
const value: JWTState = {
|
||||
tokens,
|
||||
isLoading,
|
||||
error,
|
||||
obtainTokens,
|
||||
refreshTokens,
|
||||
clearTokens,
|
||||
getAccessToken,
|
||||
}
|
||||
|
||||
return <Context value={value}>{children}</Context>
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access JWT state and methods.
|
||||
*
|
||||
* When used outside JWTContext, returns null. This allows
|
||||
* conditional JWT usage (e.g., useDjangoApi({ jwt: true }))
|
||||
* without requiring JWTContext to always be present.
|
||||
*/
|
||||
export function useJWT(): JWTState | null {
|
||||
return useContext(Context)
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to access JWT state with a guarantee it exists.
|
||||
* Throws if used outside JWTContext.
|
||||
*/
|
||||
export function useJWTRequired(): JWTState {
|
||||
const context = useContext(Context)
|
||||
if (!context) {
|
||||
throw new Error('useJWTRequired must be used within JWTContext')
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
/** Check if JWT is available (tokens obtained) */
|
||||
export function useJWTReady(): boolean {
|
||||
const jwt = useJWT()
|
||||
if (!jwt) return false
|
||||
return !jwt.isLoading && jwt.tokens !== null
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
/**
|
||||
* Integration tests for JWT Context
|
||||
*
|
||||
* These tests call the REAL backend - no mocks.
|
||||
* Backend must be running: docker-compose up
|
||||
*
|
||||
* Run with: RUN_INTEGRATION_TESTS=true npm run test
|
||||
*
|
||||
* Note: Most JWT operations require an authenticated session.
|
||||
* Tests that require authentication verify 401 handling (expected for anonymous users).
|
||||
*/
|
||||
|
||||
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||
import { ReactNode } from 'react'
|
||||
import { JWTContext, useJWT, useJWTRequired } from '../JWTContext'
|
||||
import { describeIntegration, BACKEND_URL } from '../../testing'
|
||||
|
||||
function createWrapper(config?: Parameters<typeof JWTContext>[0]['config']) {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<JWTContext config={{
|
||||
autoObtain: false,
|
||||
baseUrl: BACKEND_URL,
|
||||
...config
|
||||
}}>
|
||||
{children}
|
||||
</JWTContext>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
describeIntegration('JWTContext (integration)', () => {
|
||||
describe('Hook behavior outside provider', () => {
|
||||
it('should return null when useJWT used outside JWTContext', () => {
|
||||
const { result } = renderHook(() => useJWT())
|
||||
expect(result.current).toBeNull()
|
||||
})
|
||||
|
||||
it('should throw when useJWTRequired used outside JWTContext', () => {
|
||||
// Suppress console.error for this test
|
||||
const consoleSpy = jest.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
expect(() => {
|
||||
renderHook(() => useJWTRequired())
|
||||
}).toThrow('useJWTRequired must be used within JWTContext')
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Token operations with real backend', () => {
|
||||
it('should handle 401 when obtaining tokens as anonymous user', async () => {
|
||||
const { result } = renderHook(() => useJWTRequired(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
const tokens = await result.current.obtainTokens()
|
||||
// Anonymous users get 401, which returns null (not an error)
|
||||
expect(tokens).toBeNull()
|
||||
})
|
||||
|
||||
expect(result.current.tokens).toBeNull()
|
||||
// 401 is graceful - not an error state
|
||||
expect(result.current.error).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle 401 when refreshing tokens without valid refresh token', async () => {
|
||||
const { result } = renderHook(() => useJWTRequired(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
// Try to refresh without any tokens - should fail gracefully
|
||||
await act(async () => {
|
||||
const tokens = await result.current.refreshTokens()
|
||||
expect(tokens).toBeNull()
|
||||
})
|
||||
|
||||
expect(result.current.tokens).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('clearTokens (no backend needed)', () => {
|
||||
it('should be safe to call clearTokens when no tokens are set', async () => {
|
||||
// Verifies clearTokens doesn't throw or corrupt state when called
|
||||
// with no tokens present (e.g., during logout when already logged out)
|
||||
const { result } = renderHook(() => useJWTRequired(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.clearTokens()
|
||||
})
|
||||
|
||||
expect(result.current.tokens).toBeNull()
|
||||
expect(result.current.error).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAccessToken', () => {
|
||||
it('should return null when no tokens available', async () => {
|
||||
const { result } = renderHook(() => useJWTRequired(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
|
||||
let token: string | null = 'not-null'
|
||||
await act(async () => {
|
||||
token = await result.current.getAccessToken()
|
||||
})
|
||||
|
||||
expect(token).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('autoObtain with real backend', () => {
|
||||
it('should attempt auto-obtain and handle 401 for anonymous', async () => {
|
||||
const { result } = renderHook(() => useJWTRequired(), {
|
||||
wrapper: createWrapper({ autoObtain: true }),
|
||||
})
|
||||
|
||||
// Auto-obtain will attempt to get tokens but fail for anonymous user
|
||||
await waitFor(() => {
|
||||
// After auto-obtain completes, tokens should be null (401 response)
|
||||
// or the loading state should be done
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
}, { timeout: 5000 })
|
||||
|
||||
// Anonymous user won't have tokens
|
||||
expect(result.current.tokens).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* Note on authenticated JWT tests:
|
||||
*
|
||||
* To fully test JWT token obtain/refresh, we would need:
|
||||
* 1. An authenticated user session (login first)
|
||||
* 2. Valid CSRF token handling
|
||||
*
|
||||
* These scenarios are better tested in E2E tests (Playwright/Cypress)
|
||||
* where we can:
|
||||
* 1. Navigate to login page
|
||||
* 2. Submit credentials
|
||||
* 3. Then test JWT token flows
|
||||
*
|
||||
* The tests above verify:
|
||||
* - Hook API contract (throws/returns null outside provider)
|
||||
* - Graceful 401 handling (anonymous users)
|
||||
* - State management (clearTokens)
|
||||
* - Integration with real backend (network calls happen)
|
||||
*/
|
||||
@@ -0,0 +1,79 @@
|
||||
/**
|
||||
* Contract Tests for mizan JWT Server Functions
|
||||
*
|
||||
* Validates that the backend schema exports the expected JWT functions.
|
||||
* These tests catch frontend/backend contract mismatches early.
|
||||
*/
|
||||
|
||||
import mizanSchema from '@/api/generated.mizan.schema.json'
|
||||
|
||||
type mizanFunction = {
|
||||
name: string
|
||||
camelName: string
|
||||
hasInput: boolean
|
||||
inputType: string | null
|
||||
outputType: string
|
||||
transport: string
|
||||
}
|
||||
|
||||
function getFunctions(): mizanFunction[] {
|
||||
return (mizanSchema as any)['x-mizan-functions'] ?? []
|
||||
}
|
||||
|
||||
function findFunction(name: string): mizanFunction | undefined {
|
||||
return getFunctions().find(fn => fn.name === name)
|
||||
}
|
||||
|
||||
describe('JWT Server Functions Contract', () => {
|
||||
describe('jwt_obtain', () => {
|
||||
it('should be registered as a server function', () => {
|
||||
const fn = findFunction('jwt_obtain')
|
||||
expect(fn).toBeDefined()
|
||||
expect(fn?.transport).toBe('http')
|
||||
})
|
||||
|
||||
it('should have no input (session-based)', () => {
|
||||
const fn = findFunction('jwt_obtain')
|
||||
expect(fn?.hasInput).toBe(false)
|
||||
})
|
||||
|
||||
it('should return token pair with expected fields', () => {
|
||||
const schemas = (mizanSchema as any).components?.schemas
|
||||
const output = schemas?.jwtObtainOutput
|
||||
|
||||
expect(output).toBeDefined()
|
||||
expect(output.properties.access_token).toBeDefined()
|
||||
expect(output.properties.refresh_token).toBeDefined()
|
||||
expect(output.properties.expires_in).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('jwt_refresh', () => {
|
||||
it('should be registered as a server function', () => {
|
||||
const fn = findFunction('jwt_refresh')
|
||||
expect(fn).toBeDefined()
|
||||
expect(fn?.transport).toBe('http')
|
||||
})
|
||||
|
||||
it('should accept refresh_token as input', () => {
|
||||
const fn = findFunction('jwt_refresh')
|
||||
expect(fn?.hasInput).toBe(true)
|
||||
|
||||
const schemas = (mizanSchema as any).components?.schemas
|
||||
const input = schemas?.jwtRefreshInput
|
||||
|
||||
expect(input).toBeDefined()
|
||||
expect(input.properties.refresh_token).toBeDefined()
|
||||
})
|
||||
|
||||
it('should return token pair with expected fields', () => {
|
||||
const schemas = (mizanSchema as any).components?.schemas
|
||||
const output = schemas?.jwtRefreshOutput
|
||||
|
||||
expect(output).toBeDefined()
|
||||
expect(output.properties.access_token).toBeDefined()
|
||||
expect(output.properties.refresh_token).toBeDefined()
|
||||
expect(output.properties.expires_in).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Unit Tests for JWT Hooks
|
||||
*
|
||||
* Tests hook behavior in isolation (no backend required).
|
||||
*/
|
||||
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import { ReactNode } from 'react'
|
||||
import { JWTContext, useJWTReady } from '../JWTContext'
|
||||
|
||||
// Wrapper that provides JWTContext
|
||||
function createWrapper(config?: Parameters<typeof JWTContext>[0]['config']) {
|
||||
return function Wrapper({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<JWTContext config={{ autoObtain: false, ...config }}>
|
||||
{children}
|
||||
</JWTContext>
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
describe('useJWTReady', () => {
|
||||
it('returns false outside JWTContext', () => {
|
||||
const { result } = renderHook(() => useJWTReady())
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false when no tokens', () => {
|
||||
const { result } = renderHook(() => useJWTReady(), {
|
||||
wrapper: createWrapper(),
|
||||
})
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
})
|
||||
79
packages/mizan-csr/adapters/react/src/jwt/index.ts
Normal file
79
packages/mizan-csr/adapters/react/src/jwt/index.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
/**
|
||||
* mizan/jwt
|
||||
*
|
||||
* JWT token management via mizan server functions.
|
||||
* Handles token lifecycle: obtain, refresh, clear.
|
||||
*
|
||||
* ## Quick Start
|
||||
*
|
||||
* Use JWTContext in authenticated areas (e.g., inside UserRoute):
|
||||
*
|
||||
* ```tsx
|
||||
* import { JWTContext } from 'mizan/jwt'
|
||||
* import { UserRoute } from 'mizan/allauth'
|
||||
*
|
||||
* function ProtectedPage() {
|
||||
* return (
|
||||
* <UserRoute>
|
||||
* <JWTContext>
|
||||
* <MyProtectedContent />
|
||||
* </JWTContext>
|
||||
* </UserRoute>
|
||||
* )
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Then use JWT-authenticated requests:
|
||||
*
|
||||
* ```tsx
|
||||
* import { useDjangoCSRClient, Auth } from 'mizan/client/react'
|
||||
*
|
||||
* function MyProtectedContent() {
|
||||
* const client = useDjangoCSRClient(Auth.JWT)
|
||||
*
|
||||
* const fetchData = async () => {
|
||||
* const response = await client.request('GET', '/api/protected/')
|
||||
* return response.json()
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* ## How It Works
|
||||
*
|
||||
* 1. JWTContext calls jwt_obtain server function (via /api/mizan/call/)
|
||||
* 2. If not authenticated, returns FORBIDDEN (tokens stay null)
|
||||
* 3. Client uses getAccessToken() for Bearer token injection
|
||||
* 4. Tokens auto-refresh via jwt_refresh server function
|
||||
* 5. On logout, call clearTokens()
|
||||
*
|
||||
* ## Configuration
|
||||
*
|
||||
* ```tsx
|
||||
* <JWTContext
|
||||
* config={{
|
||||
* endpoint: '/api/mizan/call/', // default
|
||||
* refreshBuffer: 30, // refresh 30s before expiry
|
||||
* autoObtain: true, // obtain on mount
|
||||
* autoRefresh: true, // auto-refresh before expiry
|
||||
* }}
|
||||
* >
|
||||
* ```
|
||||
*
|
||||
* ## Manual Token Management
|
||||
*
|
||||
* ```tsx
|
||||
* import { useJWT } from 'mizan/jwt'
|
||||
*
|
||||
* function LogoutButton() {
|
||||
* const jwt = useJWT()
|
||||
*
|
||||
* const handleLogout = async () => {
|
||||
* await fetch('/api/logout/', { method: 'POST' })
|
||||
* jwt?.clearTokens()
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
||||
export { JWTContext, useJWT, useJWTRequired, useJWTReady } from './JWTContext'
|
||||
export type { JWTTokens, JWTConfig, JWTState } from '../client/types'
|
||||
42
packages/mizan-csr/adapters/react/src/testing.ts
Normal file
42
packages/mizan-csr/adapters/react/src/testing.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Integration Test Helper
|
||||
*
|
||||
* Integration tests require a running backend: docker-compose up
|
||||
*
|
||||
* To run integration tests:
|
||||
* RUN_INTEGRATION_TESTS=true npm run test
|
||||
*
|
||||
* By default, integration tests are skipped in the regular test run.
|
||||
*/
|
||||
|
||||
export const runIntegrationTests = process.env.RUN_INTEGRATION_TESTS === 'true'
|
||||
|
||||
// Type for Jest's describe function (simplied, avoids needing @types/jest at build time)
|
||||
type DescribeFn = {
|
||||
(name: string, fn: () => void): void
|
||||
skip: (name: string, fn: () => void) => void
|
||||
}
|
||||
|
||||
// Declare global describe from Jest (only available in test environment)
|
||||
declare const describe: DescribeFn
|
||||
|
||||
/**
|
||||
* Use this instead of `describe` for integration test suites that require a backend.
|
||||
* Tests will be skipped unless RUN_INTEGRATION_TESTS=true.
|
||||
*/
|
||||
export const describeIntegration = runIntegrationTests ? describe : describe.skip
|
||||
|
||||
/**
|
||||
* Backend URL from environment or default localhost
|
||||
*/
|
||||
export const BACKEND_URL = (() => {
|
||||
if (!process.env.NEXT_PUBLIC_HOST_URL) {
|
||||
console.warn('[mizan/testing] NEXT_PUBLIC_HOST_URL not set, falling back to http://localhost')
|
||||
}
|
||||
return process.env.NEXT_PUBLIC_HOST_URL || 'http://localhost'
|
||||
})()
|
||||
|
||||
/**
|
||||
* WebSocket URL derived from backend URL
|
||||
*/
|
||||
export const WS_URL = BACKEND_URL.replace(/^http/, 'ws') + '/ws/'
|
||||
23
packages/mizan-csr/adapters/react/tsconfig.build.json
Normal file
23
packages/mizan-csr/adapters/react/tsconfig.build.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": "src",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"jsx": "react-jsx",
|
||||
"skipLibCheck": true,
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"mizan": ["./src/index.ts"],
|
||||
"mizan/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["src"],
|
||||
"exclude": ["src/**/__tests__/**", "src/**/*.test.*"]
|
||||
}
|
||||
14
packages/mizan-csr/adapters/react/tsconfig.json
Normal file
14
packages/mizan-csr/adapters/react/tsconfig.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"sourceMap": true,
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"jsx": "react-jsx",
|
||||
"skipLibCheck": true,
|
||||
"types": ["vitest/globals"]
|
||||
},
|
||||
"include": ["src"]
|
||||
}
|
||||
27
packages/mizan-csr/adapters/react/vitest.config.ts
Normal file
27
packages/mizan-csr/adapters/react/vitest.config.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
import path from 'path'
|
||||
|
||||
export default defineConfig({
|
||||
resolve: {
|
||||
alias: {
|
||||
'mizan/channels': path.resolve(__dirname, 'src/channels/index.ts'),
|
||||
'mizan/client/react': path.resolve(__dirname, 'src/client/react.ts'),
|
||||
'mizan/client/nextjs': path.resolve(__dirname, 'src/client/nextjs.tsx'),
|
||||
'mizan/client': path.resolve(__dirname, 'src/client/index.ts'),
|
||||
'mizan/jwt': path.resolve(__dirname, 'src/jwt/index.ts'),
|
||||
'mizan/allauth/nextjs': path.resolve(__dirname, 'src/allauth/nextjs.tsx'),
|
||||
'mizan/allauth': path.resolve(__dirname, 'src/allauth/index.ts'),
|
||||
'mizan': path.resolve(__dirname, 'src/index.ts'),
|
||||
},
|
||||
},
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
include: ['src/**/*.test.{ts,tsx}'],
|
||||
exclude: [
|
||||
// Requires @/api/generated.mizan.schema.json from consuming project
|
||||
'src/jwt/__tests__/contract.test.ts',
|
||||
],
|
||||
},
|
||||
})
|
||||
6
packages/mizan-csr/adapters/react/vitest.setup.ts
Normal file
6
packages/mizan-csr/adapters/react/vitest.setup.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { vi } from 'vitest'
|
||||
import '@testing-library/jest-dom/vitest'
|
||||
|
||||
// Jest compatibility: existing tests use jest.fn(), jest.spyOn(), jest.mock()
|
||||
// Vitest's `vi` object has the same API, so we alias it globally.
|
||||
;(globalThis as any).jest = vi
|
||||
4
packages/mizan-rpc/adapters/django/.gitignore
vendored
Normal file
4
packages/mizan-rpc/adapters/django/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
105
packages/mizan-rpc/adapters/django/README.md
Normal file
105
packages/mizan-rpc/adapters/django/README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# mizan (Python)
|
||||
|
||||
Django server functions framework. See the [monorepo root](../README.md) for full documentation.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
uv add "mizan[channels,allauth] @ git+https://git.impactsoundworks.com/isw/mizan.git#subdirectory=django"
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
```python
|
||||
# settings.py
|
||||
INSTALLED_APPS = ["mizan", ...]
|
||||
|
||||
# urls.py
|
||||
path("api/mizan/", include("mizan.urls"))
|
||||
|
||||
# asgi.py (optional, for WebSocket)
|
||||
from mizan import wrap_asgi
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
```
|
||||
|
||||
## Define Functions
|
||||
|
||||
```python
|
||||
from mizan.client import client
|
||||
from mizan.setup.registry import register
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Output(BaseModel):
|
||||
message: str
|
||||
|
||||
@client
|
||||
def echo(request, text: str) -> Output:
|
||||
return Output(message=text)
|
||||
|
||||
register(echo, "echo")
|
||||
```
|
||||
|
||||
Register in `apps.py`:
|
||||
|
||||
```python
|
||||
def ready(self):
|
||||
import myapp.mizan_clients
|
||||
```
|
||||
|
||||
## Auth
|
||||
|
||||
```python
|
||||
@client(auth=True) # requires authentication
|
||||
@client(auth='staff') # requires is_staff
|
||||
@client(auth='superuser') # requires is_superuser
|
||||
@client(auth=my_callable) # custom check
|
||||
```
|
||||
|
||||
## Contexts
|
||||
|
||||
```python
|
||||
@client(context='global') # fetched once, SSR-hydrated, becomes useCurrentUser()
|
||||
@client(context='local') # fetched with params, becomes <GreetProvider>
|
||||
```
|
||||
|
||||
## Forms
|
||||
|
||||
```python
|
||||
from mizan.forms import mizanFormMixin, mizanFormMeta
|
||||
|
||||
class ContactForm(mizanFormMixin, forms.Form):
|
||||
mizan = mizanFormMeta(name="contact", title="Contact Us")
|
||||
name = forms.CharField()
|
||||
email = forms.EmailField()
|
||||
|
||||
def on_submit_success(self, request):
|
||||
return {"sent": True}
|
||||
```
|
||||
|
||||
Auto-registers `contact.schema`, `contact.validate`, `contact.submit`. Generates `useContactForm()` with Zod validation.
|
||||
|
||||
## Channels
|
||||
|
||||
```python
|
||||
from mizan.channels import ReactChannel
|
||||
|
||||
class ChatChannel(ReactChannel):
|
||||
class Params(BaseModel):
|
||||
room: str
|
||||
class DjangoMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
def authorize(self, params):
|
||||
return self.user.is_authenticated
|
||||
def group(self, params):
|
||||
return f"chat_{params.room}"
|
||||
```
|
||||
|
||||
Generates `useChatChannel({ room })`.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
uv sync --extra dev --extra channels
|
||||
uv run pytest
|
||||
```
|
||||
46
packages/mizan-rpc/adapters/django/pyproject.toml
Normal file
46
packages/mizan-rpc/adapters/django/pyproject.toml
Normal file
@@ -0,0 +1,46 @@
|
||||
[project]
|
||||
name = "mizan"
|
||||
version = "1.0.1"
|
||||
description = "Django + React server functions framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"django>=5.0",
|
||||
"django-ninja>=1.0",
|
||||
"django-readers>=2.0",
|
||||
"pydantic>=2.0",
|
||||
"PyJWT>=2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
channels = [
|
||||
"channels>=4.0",
|
||||
"channels-redis>=4.0",
|
||||
]
|
||||
allauth = [
|
||||
"django-allauth>=65.0",
|
||||
]
|
||||
webauthn = [
|
||||
"fido2>=2.0",
|
||||
]
|
||||
shapes = [
|
||||
"django-readers>=2.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-django>=4.9",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/mizan"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
DJANGO_SETTINGS_MODULE = "tests.settings"
|
||||
pythonpath = ["src", "."]
|
||||
testpaths = ["src/mizan/tests"]
|
||||
python_classes = ["*Tests", "*Test", "Test*"]
|
||||
python_functions = ["test_*"]
|
||||
187
packages/mizan-rpc/adapters/django/src/mizan/__init__.py
Normal file
187
packages/mizan-rpc/adapters/django/src/mizan/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
mizan - Django + React unified framework
|
||||
|
||||
Server functions are the core primitive. Everything else builds on them.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. urls.py - HTTP endpoint
|
||||
```python
|
||||
from mizan import urls as mizan_urls
|
||||
|
||||
urlpatterns = [
|
||||
path('api/mizan/', include(mizan_urls)),
|
||||
]
|
||||
```
|
||||
|
||||
### 2. asgi.py - WebSocket support (optional)
|
||||
```python
|
||||
from mizan import wrap_asgi
|
||||
from django.core.asgi import get_asgi_application
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
```
|
||||
|
||||
### 3. Define server functions
|
||||
```python
|
||||
# apps/myapp/clients.py
|
||||
from mizan import client
|
||||
from pydantic import BaseModel
|
||||
|
||||
class EchoOutput(BaseModel):
|
||||
message: str
|
||||
|
||||
# HTTP-only function (default)
|
||||
@client
|
||||
def echo(request, text: str) -> EchoOutput:
|
||||
return EchoOutput(message=f"Echo: {text}")
|
||||
|
||||
# Global context (singleton, SSR-hydrated)
|
||||
@client(context='global')
|
||||
def current_user(request) -> UserOutput:
|
||||
return UserOutput(email=request.user.email)
|
||||
|
||||
# WebSocket-enabled for real-time
|
||||
@client(websocket=True)
|
||||
def send_message(request, room_id: int, text: str) -> MessageOutput:
|
||||
return MessageOutput(...)
|
||||
```
|
||||
|
||||
### 4. Auto-discover in apps.py
|
||||
```python
|
||||
class MyAppConfig(AppConfig):
|
||||
def ready(self):
|
||||
from mizan.setup import mizan_clients
|
||||
mizan_clients('apps')
|
||||
```
|
||||
|
||||
### 5. Frontend - generate types and use
|
||||
```bash
|
||||
npm run schemas
|
||||
```
|
||||
```tsx
|
||||
import { useEcho, useCurrentUser } from '@/api'
|
||||
|
||||
const user = useCurrentUser()
|
||||
const echo = useEcho()
|
||||
await echo({ text: 'hello' })
|
||||
```
|
||||
|
||||
## What You Get
|
||||
|
||||
| Backend | Frontend | Transport |
|
||||
|------------------------------------|-----------------------|------------|
|
||||
| `@client` | `useXxx()` hook | HTTP |
|
||||
| `@client(context='global')` | `useXxx()` + SSR | HTTP |
|
||||
| `@client(context='local')` | `<XxxProvider>` + hook| HTTP |
|
||||
| `@client(websocket=True)` | `useXxx()` hook | WebSocket |
|
||||
| `@compose(...)` | `<XxxProvider>` combined | varies |
|
||||
| `mizanFormMixin` | `useXxxForm()` + Zod | HTTP |
|
||||
| `ReactChannel` | `useXxxChannel()` | WebSocket |
|
||||
"""
|
||||
|
||||
# All imports at module level (sorted)
|
||||
from . import channels
|
||||
from . import client as client_module
|
||||
from . import export
|
||||
from . import forms
|
||||
from . import setup
|
||||
from .channels import ReactChannel
|
||||
from .channels import register as register_channel
|
||||
from .client import ComposedContext, GlobalContext, ReactContext, ServerFunction, client, compose
|
||||
|
||||
# Shape is lazy-loaded via __getattr__ because django_readers
|
||||
# imports contenttypes, which can't happen during apps.populate()
|
||||
from .setup import (
|
||||
mizan_clients,
|
||||
mizan_module,
|
||||
get_channel,
|
||||
get_function,
|
||||
register,
|
||||
register_as,
|
||||
)
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy loading for modules that can't be imported at app load time."""
|
||||
if name == "urls":
|
||||
from .urls import urlpatterns as mizan_patterns
|
||||
|
||||
return mizan_patterns
|
||||
if name == "Shape":
|
||||
from .shapes import Shape
|
||||
|
||||
return Shape
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def wrap_asgi(http_application):
|
||||
"""
|
||||
Wrap an ASGI application with mizan WebSocket support.
|
||||
|
||||
Usage in asgi.py:
|
||||
from django.core.asgi import get_asgi_application
|
||||
from mizan import wrap_asgi
|
||||
|
||||
application = wrap_asgi(get_asgi_application())
|
||||
|
||||
This adds:
|
||||
- WebSocket routing at /ws/ for RPC and channels
|
||||
- Authentication middleware for WebSocket connections
|
||||
"""
|
||||
try:
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
from channels.routing import ProtocolTypeRouter, URLRouter
|
||||
from django.urls import path
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"django-channels is required for WebSocket support.\n"
|
||||
"Install with: pip install channels channels-redis\n"
|
||||
"Add 'channels' to INSTALLED_APPS and configure CHANNEL_LAYERS."
|
||||
)
|
||||
|
||||
from .channels.connection import DjangoReactConsumer
|
||||
|
||||
return ProtocolTypeRouter(
|
||||
{
|
||||
"http": http_application,
|
||||
"websocket": AuthMiddlewareStack(
|
||||
URLRouter(
|
||||
[
|
||||
path("ws/", DjangoReactConsumer.as_asgi()),
|
||||
]
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Decorators & Contexts
|
||||
"client",
|
||||
"compose",
|
||||
"ReactContext",
|
||||
"GlobalContext",
|
||||
"ServerFunction",
|
||||
"ComposedContext",
|
||||
# Setup
|
||||
"mizan_clients",
|
||||
"mizan_module",
|
||||
"register",
|
||||
"register_as",
|
||||
"get_function",
|
||||
"get_channel",
|
||||
# ASGI
|
||||
"wrap_asgi",
|
||||
# Channels
|
||||
"ReactChannel",
|
||||
"register_channel",
|
||||
# Shapes
|
||||
"Shape",
|
||||
# Submodules
|
||||
"client_module",
|
||||
"setup",
|
||||
"forms",
|
||||
"channels",
|
||||
"export",
|
||||
]
|
||||
@@ -0,0 +1,91 @@
|
||||
import inspect
|
||||
from importlib import import_module
|
||||
from inspect import isclass
|
||||
from typing import Protocol, Any
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def get_members(path):
|
||||
try:
|
||||
module = import_module(path)
|
||||
except ModuleNotFoundError:
|
||||
print('Could not import module "{}"'.format(path))
|
||||
return []
|
||||
|
||||
members = [
|
||||
(name, member)
|
||||
for name, member in inspect.getmembers(module)
|
||||
if not isclass(member) or (member.__module__ == module.__name__)
|
||||
]
|
||||
|
||||
return members
|
||||
|
||||
|
||||
class DjangoAppVisitorHandler(Protocol):
|
||||
def on_module(
|
||||
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class DjangoAppVisitor:
|
||||
"""
|
||||
Discovers Python modules under each Django app following conventions:
|
||||
- <app>/<module>.py -> url_prefix "<renamed>/"
|
||||
- <app>/<module>/**/*.py -> url_prefix "<renamed>/<subdirs...>/<module>/"
|
||||
|
||||
Example:
|
||||
<app>/<module>/forms/nksn.py -> url_prefix "<renamed>/forms/nksn/"
|
||||
module_path "<app>.module.forms.nksn"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
layer: str,
|
||||
apps_root: str = "",
|
||||
):
|
||||
self.apps_root = apps_root
|
||||
self.layer = layer
|
||||
|
||||
def visit(self, handler: DjangoAppVisitorHandler) -> None:
|
||||
apps_dir = (
|
||||
settings.BASE_DIR / self.apps_root if self.apps_root else settings.BASE_DIR
|
||||
)
|
||||
if not apps_dir.is_dir():
|
||||
apps_dir = settings.BASE_DIR
|
||||
|
||||
module_prefix = f"{self.apps_root}." if self.apps_root else ""
|
||||
|
||||
for app_name in settings.INSTALLED_APPS:
|
||||
if app_name.startswith(self.apps_root + "."):
|
||||
app_name = app_name[(len(self.apps_root) + 1) :]
|
||||
|
||||
app_dir = apps_dir / app_name
|
||||
if not app_dir.exists():
|
||||
continue
|
||||
|
||||
app_module = f"{module_prefix}{app_name}"
|
||||
|
||||
# 1) Visit package: <app>/<module>/**/*.py
|
||||
layer_dir = app_dir / self.layer
|
||||
if layer_dir.is_dir():
|
||||
for py_file in layer_dir.rglob("*.py"):
|
||||
if py_file.name == "__init__.py":
|
||||
continue
|
||||
|
||||
relative_path = py_file.relative_to(layer_dir).with_suffix("")
|
||||
parts = list(relative_path.parts)
|
||||
|
||||
dotted = ".".join(parts)
|
||||
handler.on_module(
|
||||
app_name,
|
||||
parts,
|
||||
get_members(f"{app_module}.{self.layer}.{dotted}"),
|
||||
)
|
||||
|
||||
# 2) Visit module module file: <app>/module.py
|
||||
layer_file = app_dir / f"{self.layer}.py"
|
||||
if layer_file.is_file():
|
||||
handler.on_module(
|
||||
app_name, [], get_members(f"{app_module}.{self.layer}")
|
||||
)
|
||||
@@ -0,0 +1,543 @@
|
||||
"""
|
||||
mizan.channels - Real-time WebSocket communication.
|
||||
|
||||
Type-safe bidirectional messaging between Django and React via WebSockets.
|
||||
Hooks are auto-generated with full TypeScript types.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```python
|
||||
# channels.py
|
||||
from pydantic import BaseModel
|
||||
from mizan import channels
|
||||
|
||||
class ChatChannel(channels.ReactChannel):
|
||||
|
||||
class Params(BaseModel):
|
||||
room: str
|
||||
|
||||
class ReactMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class DjangoMessage(BaseModel):
|
||||
user: str
|
||||
text: str
|
||||
timestamp: datetime
|
||||
|
||||
def authorize(self, params: Params) -> bool:
|
||||
return self.user.is_authenticated
|
||||
|
||||
def group(self, params: Params) -> str:
|
||||
return f'chat_{params.room}'
|
||||
|
||||
def receive(self, params: Params, msg: ReactMessage) -> DjangoMessage | None:
|
||||
return self.DjangoMessage(
|
||||
user=self.user.email,
|
||||
text=msg.text,
|
||||
timestamp=now(),
|
||||
)
|
||||
|
||||
channels.register(ChatChannel, 'chat')
|
||||
```
|
||||
|
||||
```python
|
||||
# asgi.py
|
||||
from mizan import channels
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": get_asgi_application(),
|
||||
"websocket": channels.get_websocket_application(),
|
||||
})
|
||||
```
|
||||
|
||||
## Frontend Usage (auto-generated)
|
||||
|
||||
```tsx
|
||||
import { useChatChannel } from '@/api/generated.channels'
|
||||
|
||||
function Chat({ room }) {
|
||||
const chat = useChatChannel({ room })
|
||||
|
||||
chat.status // 'connecting' | 'connected' | 'disconnected'
|
||||
chat.messages // DjangoMessage[]
|
||||
chat.send({ text: 'Hello' }) // ReactMessage
|
||||
}
|
||||
```
|
||||
|
||||
## Server Push
|
||||
|
||||
```python
|
||||
await ChatChannel.push(room='general', message=ChatChannel.DjangoMessage(...))
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
|
||||
from ninja import NinjaAPI
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ReactChannel:
|
||||
"""
|
||||
Base class for WebSocket channels.
|
||||
|
||||
Define nested Pydantic classes for typed messaging:
|
||||
- Params: Query parameters for subscribing (optional)
|
||||
- ReactMessage: Messages from browser to server (optional)
|
||||
- DjangoMessage: Messages from server to browser (optional)
|
||||
|
||||
Implement required methods:
|
||||
- authorize(): Permission check for connection
|
||||
- group(): Which group to broadcast to
|
||||
|
||||
Optionally implement:
|
||||
- receive(): Handle incoming ReactMessage, return DjangoMessage to broadcast
|
||||
- on_connect(): Called after successful connection
|
||||
- on_disconnect(): Called when connection closes
|
||||
"""
|
||||
|
||||
# Nested classes (optional, defined by subclasses)
|
||||
Params: ClassVar[Type[BaseModel] | None] = None
|
||||
ReactMessage: ClassVar[Type[BaseModel] | None] = None
|
||||
DjangoMessage: ClassVar[Type[BaseModel] | None] = None
|
||||
|
||||
# Set by the framework when handling a connection
|
||||
user: "AbstractBaseUser | AnonymousUser"
|
||||
_channel_layer: Any = None
|
||||
_channel_name: str = ""
|
||||
_registered_name: ClassVar[str] = ""
|
||||
_params_dict: dict = {}
|
||||
_groups: set[str]
|
||||
|
||||
def __init__(self):
|
||||
self._groups = set()
|
||||
self._params_dict = {}
|
||||
|
||||
def authorize(self, params: BaseModel | None = None) -> bool:
|
||||
"""
|
||||
Permission check. Return True to allow connection, False to reject.
|
||||
|
||||
Override this to implement custom authorization logic.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} must implement authorize()"
|
||||
)
|
||||
|
||||
def group(self, params: BaseModel | None = None) -> str:
|
||||
"""
|
||||
Return the group name for broadcasting.
|
||||
|
||||
Messages returned from receive() are broadcast to this group.
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} must implement group()")
|
||||
|
||||
def receive(self, params: BaseModel | None, msg: BaseModel) -> BaseModel | None:
|
||||
"""
|
||||
Handle incoming ReactMessage.
|
||||
|
||||
Return a DjangoMessage to broadcast to the group, or None to skip.
|
||||
Override this to implement message handling.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def on_connect(self, params: BaseModel | None = None) -> None:
|
||||
"""Called after successful connection and group join."""
|
||||
pass
|
||||
|
||||
async def on_disconnect(self) -> None:
|
||||
"""Called when the connection closes."""
|
||||
pass
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal Methods (used by the consumer)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _join_group(self, group_name: str) -> None:
|
||||
"""Join a channel layer group."""
|
||||
if self._channel_layer:
|
||||
await self._channel_layer.group_add(group_name, self._channel_name)
|
||||
self._groups.add(group_name)
|
||||
|
||||
async def _leave_group(self, group_name: str) -> None:
|
||||
"""Leave a channel layer group."""
|
||||
if self._channel_layer and group_name in self._groups:
|
||||
await self._channel_layer.group_discard(group_name, self._channel_name)
|
||||
self._groups.discard(group_name)
|
||||
|
||||
async def _leave_all_groups(self) -> None:
|
||||
"""Leave all joined groups."""
|
||||
for group_name in list(self._groups):
|
||||
await self._leave_group(group_name)
|
||||
|
||||
async def _broadcast(self, group_name: str, message: BaseModel) -> None:
|
||||
"""Broadcast a message to a group."""
|
||||
if self._channel_layer:
|
||||
await self._channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "channel.message",
|
||||
"channel": self._registered_name,
|
||||
"params": self._params_dict,
|
||||
"data": message.model_dump(mode="json"),
|
||||
"message_type": message.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Class Methods for Server Push
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def push(cls, message: BaseModel, **params) -> None:
|
||||
"""
|
||||
Push a message from server code (views, tasks, signals).
|
||||
|
||||
Usage:
|
||||
await ChatChannel.push(
|
||||
room='general',
|
||||
message=ChatChannel.DjangoMessage(user='system', text='Hello')
|
||||
)
|
||||
"""
|
||||
from channels.layers import get_channel_layer
|
||||
|
||||
channel_layer = get_channel_layer()
|
||||
if not channel_layer:
|
||||
logger.warning(
|
||||
f"No channel layer configured, cannot push to {cls.__name__}"
|
||||
)
|
||||
return
|
||||
|
||||
# Build params model if defined
|
||||
params_obj = None
|
||||
if cls.Params:
|
||||
params_obj = cls.Params(**params)
|
||||
|
||||
# Get group name
|
||||
instance = cls()
|
||||
group_name = instance.group(params_obj)
|
||||
|
||||
# Send to group
|
||||
await channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "channel.message",
|
||||
"channel": cls._registered_name,
|
||||
"params": params,
|
||||
"data": message.model_dump(mode="json"),
|
||||
"message_type": message.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
|
||||
_registry: dict[str, Type[ReactChannel]] = {}
|
||||
|
||||
|
||||
def register(channel_class: Type[ReactChannel], name: str) -> None:
|
||||
"""
|
||||
Register a channel.
|
||||
|
||||
Args:
|
||||
channel_class: The ReactChannel subclass to register
|
||||
name: URL-friendly name (used in subscriptions)
|
||||
"""
|
||||
if name in _registry:
|
||||
raise ValueError(f"Channel '{name}' is already registered")
|
||||
|
||||
channel_class._registered_name = name
|
||||
|
||||
# Validate the channel class
|
||||
if not hasattr(channel_class, "authorize"):
|
||||
raise ValueError(f"{channel_class.__name__} must implement authorize()")
|
||||
if not hasattr(channel_class, "group"):
|
||||
raise ValueError(f"{channel_class.__name__} must implement group()")
|
||||
|
||||
_registry[name] = channel_class
|
||||
logger.debug(f"Registered channel: {name} -> {channel_class.__name__}")
|
||||
|
||||
|
||||
def get_channel(name: str) -> Type[ReactChannel] | None:
|
||||
"""Get a registered channel class by name."""
|
||||
return _registry.get(name)
|
||||
|
||||
|
||||
def get_registered_channels() -> dict[str, Type[ReactChannel]]:
|
||||
"""Get all registered channel classes."""
|
||||
return dict(_registry)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket Consumer
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_websocket_application():
|
||||
"""
|
||||
Get the WebSocket application for ASGI.
|
||||
|
||||
Usage in asgi.py:
|
||||
from mizan import channels
|
||||
|
||||
application = ProtocolTypeRouter({
|
||||
"http": get_asgi_application(),
|
||||
"websocket": channels.get_websocket_application(),
|
||||
})
|
||||
"""
|
||||
try:
|
||||
from channels.routing import URLRouter
|
||||
from channels.auth import AuthMiddlewareStack
|
||||
from django.urls import path
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"django-channels is required for WebSocket support. "
|
||||
"Install it with: pip install channels channels-redis"
|
||||
)
|
||||
|
||||
from .connection import DjangoReactConsumer
|
||||
|
||||
return AuthMiddlewareStack(
|
||||
URLRouter(
|
||||
[
|
||||
path("ws/", DjangoReactConsumer.as_asgi()),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Export (for TypeScript generation)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_channels_schema() -> dict:
|
||||
"""
|
||||
Get schema for all registered channels (for TypeScript generation).
|
||||
|
||||
Returns a dict suitable for the frontend code generator.
|
||||
"""
|
||||
schema = {"channels": {}}
|
||||
|
||||
for name, channel_class in _registry.items():
|
||||
channel_schema = {
|
||||
"name": name,
|
||||
"params": None,
|
||||
"reactMessage": None,
|
||||
"djangoMessage": None,
|
||||
}
|
||||
|
||||
# Extract Params schema
|
||||
if hasattr(channel_class, "Params") and channel_class.Params:
|
||||
channel_schema["params"] = channel_class.Params.model_json_schema()
|
||||
|
||||
# Extract ReactMessage schema
|
||||
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
|
||||
channel_schema[
|
||||
"reactMessage"
|
||||
] = channel_class.ReactMessage.model_json_schema()
|
||||
|
||||
# Extract DjangoMessage schema
|
||||
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
|
||||
channel_schema[
|
||||
"djangoMessage"
|
||||
] = channel_class.DjangoMessage.model_json_schema()
|
||||
|
||||
schema["channels"][name] = channel_schema
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _register_channel_schema_endpoint(
|
||||
api: "NinjaAPI",
|
||||
path: str,
|
||||
operation_id: str,
|
||||
summary: str,
|
||||
input_cls: type | None,
|
||||
output_cls: type,
|
||||
) -> None:
|
||||
"""Register a dummy endpoint for schema generation (avoids closure issues)."""
|
||||
if input_cls is not None:
|
||||
|
||||
def endpoint(request, data):
|
||||
pass
|
||||
|
||||
endpoint.__annotations__ = {"data": input_cls}
|
||||
else:
|
||||
|
||||
def endpoint(request):
|
||||
pass
|
||||
|
||||
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
|
||||
endpoint
|
||||
)
|
||||
|
||||
|
||||
def get_channels_openapi_schema() -> dict:
|
||||
"""
|
||||
Get OpenAPI schema for all registered channels.
|
||||
|
||||
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
|
||||
This schema is consumed by openapi-typescript for type generation.
|
||||
"""
|
||||
from ninja import NinjaAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Create temporary Ninja API for schema generation only
|
||||
schema_api = NinjaAPI(
|
||||
title="mizan Channels",
|
||||
version="1.0.0",
|
||||
description="Auto-generated schema for mizan channels",
|
||||
docs_url=None,
|
||||
openapi_url=None,
|
||||
)
|
||||
|
||||
# Store dynamically created classes
|
||||
schema_classes: dict[str, type] = {}
|
||||
channel_metadata: list[dict] = []
|
||||
|
||||
for name, channel_class in _registry.items():
|
||||
pascal_name = name.replace("_", " ").title().replace(" ", "")
|
||||
|
||||
channel_meta = {
|
||||
"name": name,
|
||||
"pascalName": pascal_name,
|
||||
"hasParams": False,
|
||||
"hasReactMessage": False,
|
||||
"hasDjangoMessage": False,
|
||||
}
|
||||
|
||||
# Register Params type
|
||||
if hasattr(channel_class, "Params") and channel_class.Params:
|
||||
params_name = f"{pascal_name}Params"
|
||||
schema_classes[params_name] = type(params_name, (channel_class.Params,), {})
|
||||
channel_meta["hasParams"] = True
|
||||
channel_meta["paramsType"] = params_name
|
||||
|
||||
# Create dummy endpoint to include in schema
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/params",
|
||||
operation_id=f"{name}Params",
|
||||
summary=f"{pascal_name} channel params",
|
||||
input_cls=schema_classes[params_name],
|
||||
output_cls=BaseModel,
|
||||
)
|
||||
|
||||
# Register ReactMessage type
|
||||
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
|
||||
react_name = f"{pascal_name}ReactMessage"
|
||||
schema_classes[react_name] = type(
|
||||
react_name, (channel_class.ReactMessage,), {}
|
||||
)
|
||||
channel_meta["hasReactMessage"] = True
|
||||
channel_meta["reactMessageType"] = react_name
|
||||
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/react",
|
||||
operation_id=f"{name}ReactMessage",
|
||||
summary=f"{pascal_name} React→Django message",
|
||||
input_cls=schema_classes[react_name],
|
||||
output_cls=BaseModel,
|
||||
)
|
||||
|
||||
# Register DjangoMessage type
|
||||
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
|
||||
django_name = f"{pascal_name}DjangoMessage"
|
||||
schema_classes[django_name] = type(
|
||||
django_name, (channel_class.DjangoMessage,), {}
|
||||
)
|
||||
channel_meta["hasDjangoMessage"] = True
|
||||
channel_meta["djangoMessageType"] = django_name
|
||||
|
||||
_register_channel_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/channels/{name}/django",
|
||||
operation_id=f"{name}DjangoMessage",
|
||||
summary=f"{pascal_name} Django→React message",
|
||||
input_cls=None,
|
||||
output_cls=schema_classes[django_name],
|
||||
)
|
||||
|
||||
channel_metadata.append(channel_meta)
|
||||
|
||||
# Get OpenAPI schema from Ninja
|
||||
# path_prefix="" avoids URL reverse() — this API is never mounted
|
||||
schema = schema_api.get_openapi_schema(path_prefix="")
|
||||
|
||||
# Add channel metadata extension
|
||||
schema["x-mizan-channels"] = channel_metadata
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Endpoint (for TypeScript generation)
|
||||
# =============================================================================
|
||||
|
||||
_schema_router = None
|
||||
|
||||
|
||||
def _get_schema_router():
|
||||
"""Get the Ninja router for the channels schema endpoint."""
|
||||
global _schema_router
|
||||
if _schema_router is None:
|
||||
from ninja import Router
|
||||
|
||||
_schema_router = Router(tags=["channels"])
|
||||
|
||||
@_schema_router.get("/schema/")
|
||||
def channels_schema(request):
|
||||
"""Get schema for all registered channels (for TypeScript generation)."""
|
||||
return get_channels_schema()
|
||||
|
||||
return _schema_router
|
||||
|
||||
|
||||
def get_urls():
|
||||
"""Get URL patterns for channels schema endpoint."""
|
||||
from ninja import NinjaAPI
|
||||
|
||||
api = NinjaAPI(urls_namespace="django_react_channels")
|
||||
api.add_router("/", _get_schema_router())
|
||||
return api.urls
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if name == "urls":
|
||||
return get_urls()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
|
||||
__all__ = [
|
||||
# URLs
|
||||
"urls",
|
||||
# Base class
|
||||
"ReactChannel",
|
||||
# Registration
|
||||
"register",
|
||||
"get_channel",
|
||||
"get_registered_channels",
|
||||
# ASGI application
|
||||
"get_websocket_application",
|
||||
# Schema export
|
||||
"get_channels_schema",
|
||||
]
|
||||
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
WebSocket consumer for mizan.channels.
|
||||
|
||||
Handles multiplexed channel subscriptions AND RPC calls over a single WebSocket connection.
|
||||
|
||||
Protocol:
|
||||
Browser sends:
|
||||
# Channel subscriptions
|
||||
{"action": "subscribe", "channel": "chat", "params": {"room": "general"}}
|
||||
{"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}}
|
||||
{"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}}
|
||||
|
||||
# RPC calls (server functions)
|
||||
{"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
|
||||
|
||||
Server sends:
|
||||
# Channel messages
|
||||
{"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}}
|
||||
|
||||
# RPC responses
|
||||
{"id": "request-id", "ok": true, "data": {...}}
|
||||
{"id": "request-id", "ok": false, "error": {...}}
|
||||
|
||||
{"error": "..."}
|
||||
|
||||
Authentication:
|
||||
Supports both session (cookie) and JWT authentication:
|
||||
- Session: Handled automatically via AuthMiddlewareStack (cookies in handshake)
|
||||
- JWT: Pass token as query parameter: ws://...?token=<jwt>
|
||||
|
||||
The WebSocket URL for JWT auth would be: ws://localhost/ws/?token=<access_token>
|
||||
|
||||
Security:
|
||||
- Functions must be explicitly registered (no arbitrary code execution)
|
||||
- Pydantic validation runs BEFORE any function code
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from channels.generic.websocket import AsyncJsonWebsocketConsumer
|
||||
from asgiref.sync import sync_to_async
|
||||
from . import get_channel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketRequest:
|
||||
"""
|
||||
Minimal request adapter for WebSocket context.
|
||||
|
||||
Provides the interface expected by ServerFunction without full HttpRequest.
|
||||
This is intentionally minimal - only expose what's needed.
|
||||
|
||||
Note: Some Django libraries (e.g., allauth rate limiting) check request.method.
|
||||
We set method="POST" since WebSocket RPC calls are semantically similar to POST.
|
||||
"""
|
||||
|
||||
# WebSocket RPC is semantically similar to POST (sends data, expects response)
|
||||
method = "POST"
|
||||
|
||||
def __init__(self, scope: dict, channel_name: str = None):
|
||||
self.user = scope.get("user")
|
||||
self.session = scope.get("session", {})
|
||||
self.channel_name = channel_name # For push subscriptions
|
||||
self._scope = scope
|
||||
|
||||
@property
|
||||
def META(self) -> dict:
|
||||
"""HTTP headers from WebSocket handshake."""
|
||||
headers = dict(self._scope.get("headers", []))
|
||||
return {
|
||||
"HTTP_" + k.decode().upper().replace("-", "_"): v.decode()
|
||||
for k, v in headers.items()
|
||||
}
|
||||
|
||||
|
||||
class DjangoReactConsumer(AsyncJsonWebsocketConsumer):
|
||||
"""
|
||||
Multiplexed WebSocket consumer for django_react channels.
|
||||
|
||||
Manages multiple channel subscriptions over a single WebSocket connection.
|
||||
|
||||
Authentication:
|
||||
- Session auth via cookies (handled by AuthMiddlewareStack)
|
||||
- JWT auth via query parameter: ws://...?token=<jwt>
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Track subscriptions: {(channel_name, params_json): channel_instance}
|
||||
self._subscriptions: dict[tuple[str, str], Any] = {}
|
||||
|
||||
async def connect(self):
|
||||
"""Accept the WebSocket connection, authenticating via JWT if provided."""
|
||||
# Check for JWT token in query parameters
|
||||
await self._try_jwt_auth()
|
||||
|
||||
await self.accept()
|
||||
logger.debug(
|
||||
f"WebSocket connected: {self.channel_name}, user={self.scope.get('user')}"
|
||||
)
|
||||
|
||||
async def _try_jwt_auth(self):
|
||||
"""
|
||||
Attempt JWT authentication from query parameter.
|
||||
|
||||
If a valid JWT token is provided via ?token=<jwt>, authenticate the user
|
||||
using JWTUser (no database query).
|
||||
|
||||
Security: If JWT is provided but invalid, we log it but don't reject
|
||||
the connection - the session auth may still be valid. However, if JWT
|
||||
IS valid, it takes precedence over session auth.
|
||||
"""
|
||||
# Parse query string for token
|
||||
query_string = self.scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
token_list = params.get("token", [])
|
||||
|
||||
if not token_list:
|
||||
return # No JWT provided, use session auth
|
||||
|
||||
token = token_list[0]
|
||||
if not token:
|
||||
return
|
||||
|
||||
# Validate JWT and create JWTUser (no DB query)
|
||||
try:
|
||||
from mizan.client.jwt import decode_token
|
||||
from mizan.jwt.tokens import JWTUser
|
||||
|
||||
payload = await sync_to_async(decode_token)(token, expected_type="access")
|
||||
if payload is None:
|
||||
logger.debug("JWT token invalid or expired")
|
||||
return # Fall back to session auth
|
||||
|
||||
# Create JWTUser from token claims - NO DATABASE QUERY
|
||||
self.scope["user"] = JWTUser(payload)
|
||||
logger.debug(f"JWT auth successful for user {payload.user_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"JWT auth failed: {e}")
|
||||
|
||||
async def disconnect(self, close_code):
|
||||
"""Clean up all subscriptions on disconnect."""
|
||||
for key, instance in list(self._subscriptions.items()):
|
||||
try:
|
||||
await instance.on_disconnect()
|
||||
await instance._leave_all_groups()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect cleanup: {e}")
|
||||
|
||||
self._subscriptions.clear()
|
||||
logger.debug(f"WebSocket disconnected: {self.channel_name}")
|
||||
|
||||
async def receive_json(self, content: dict):
|
||||
"""Handle incoming JSON messages."""
|
||||
action = content.get("action")
|
||||
|
||||
if action == "subscribe":
|
||||
await self._handle_subscribe(content)
|
||||
elif action == "unsubscribe":
|
||||
await self._handle_unsubscribe(content)
|
||||
elif action == "message":
|
||||
await self._handle_message(content)
|
||||
elif action == "rpc":
|
||||
await self._handle_rpc(content)
|
||||
else:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Unknown action: {action}",
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_subscribe(self, content: dict):
|
||||
"""Handle subscription request."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
|
||||
# Get channel class
|
||||
channel_class = get_channel(channel_name)
|
||||
if not channel_class:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Unknown channel: {channel_name}",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Create subscription key
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
# Check if already subscribed
|
||||
if sub_key in self._subscriptions:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Already subscribed to {channel_name}",
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Create channel instance
|
||||
instance = channel_class()
|
||||
instance.user = self.scope.get("user")
|
||||
instance._channel_layer = self.channel_layer
|
||||
instance._channel_name = self.channel_name
|
||||
instance._registered_name = channel_name
|
||||
instance._params_dict = params_dict
|
||||
|
||||
# Parse params
|
||||
params_obj = None
|
||||
if channel_class.Params:
|
||||
try:
|
||||
params_obj = channel_class.Params(**params_dict)
|
||||
except Exception as e:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Invalid params: {e}",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Check authorization
|
||||
try:
|
||||
if params_obj:
|
||||
authorized = instance.authorize(params_obj)
|
||||
else:
|
||||
authorized = instance.authorize()
|
||||
except Exception as e:
|
||||
logger.error(f"Authorization error for {channel_name}: {e}")
|
||||
await self.send_json(
|
||||
{
|
||||
"error": "Authorization failed",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not authorized:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": "Not authorized",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Get group and join
|
||||
try:
|
||||
if params_obj:
|
||||
group_name = instance.group(params_obj)
|
||||
else:
|
||||
group_name = instance.group()
|
||||
await instance._join_group(group_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to join group for {channel_name}: {e}")
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Failed to subscribe: {e}",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Store subscription
|
||||
self._subscriptions[sub_key] = instance
|
||||
|
||||
# Call on_connect hook
|
||||
try:
|
||||
await instance.on_connect(params_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"on_connect error for {channel_name}: {e}")
|
||||
|
||||
# Confirm subscription
|
||||
await self.send_json(
|
||||
{
|
||||
"subscribed": True,
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Subscribed to {channel_name} with params {params_dict}")
|
||||
|
||||
async def _handle_unsubscribe(self, content: dict):
|
||||
"""Handle unsubscription request."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
instance = self._subscriptions.pop(sub_key, None)
|
||||
if instance:
|
||||
try:
|
||||
await instance.on_disconnect()
|
||||
await instance._leave_all_groups()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during unsubscribe: {e}")
|
||||
|
||||
await self.send_json(
|
||||
{
|
||||
"unsubscribed": True,
|
||||
"channel": channel_name,
|
||||
"params": params_dict,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Unsubscribed from {channel_name}")
|
||||
|
||||
async def _handle_message(self, content: dict):
|
||||
"""Handle incoming message from browser."""
|
||||
channel_name = content.get("channel")
|
||||
params_dict = content.get("params", {})
|
||||
data = content.get("data", {})
|
||||
|
||||
params_json = json.dumps(params_dict, sort_keys=True)
|
||||
sub_key = (channel_name, params_json)
|
||||
|
||||
instance = self._subscriptions.get(sub_key)
|
||||
if not instance:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Not subscribed to {channel_name}",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
channel_class = instance.__class__
|
||||
|
||||
# Check if channel accepts messages
|
||||
if not channel_class.ReactMessage:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Channel {channel_name} does not accept messages",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
msg = channel_class.ReactMessage(**data)
|
||||
except Exception as e:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Invalid message: {e}",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Parse params
|
||||
params_obj = None
|
||||
if channel_class.Params:
|
||||
params_obj = channel_class.Params(**params_dict)
|
||||
|
||||
# Handle message
|
||||
try:
|
||||
response = instance.receive(params_obj, msg)
|
||||
|
||||
# If handler returned a message, broadcast it
|
||||
if response is not None:
|
||||
if params_obj:
|
||||
group_name = instance.group(params_obj)
|
||||
else:
|
||||
group_name = instance.group()
|
||||
|
||||
await instance._broadcast(group_name, response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message for {channel_name}: {e}")
|
||||
await self.send_json(
|
||||
{
|
||||
"error": f"Message handling failed: {e}",
|
||||
"channel": channel_name,
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_rpc(self, content: dict):
|
||||
"""
|
||||
Handle RPC (server function) call.
|
||||
|
||||
Protocol:
|
||||
Request: {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
|
||||
Response: {"id": "request-id", "ok": true, "data": {...}}
|
||||
or: {"id": "request-id", "ok": false, "error": {...}}
|
||||
|
||||
Security:
|
||||
- Only functions with @client(websocket=True) are allowed
|
||||
- Pydantic validation happens BEFORE any function code runs
|
||||
- Function must be explicitly registered (no arbitrary code execution)
|
||||
- User context from WebSocket session is passed to function
|
||||
"""
|
||||
from mizan.client.executor import execute_function, FunctionError
|
||||
from mizan.setup.registry import get_function
|
||||
|
||||
request_id = content.get("id")
|
||||
fn_name = content.get("fn")
|
||||
args = content.get("args", {})
|
||||
|
||||
# Validate request structure
|
||||
if not request_id:
|
||||
await self.send_json(
|
||||
{
|
||||
"error": "RPC request missing 'id' field",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not fn_name:
|
||||
await self.send_json(
|
||||
{
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "BAD_REQUEST",
|
||||
"message": "Missing 'fn' field",
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Check if function exists and has websocket=True
|
||||
fn_class = get_function(fn_name)
|
||||
if fn_class is None:
|
||||
await self.send_json(
|
||||
{
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "NOT_FOUND",
|
||||
"message": f"Function '{fn_name}' not found",
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Only allow functions explicitly marked with websocket=True
|
||||
fn_meta = getattr(fn_class, "_meta", {})
|
||||
if not fn_meta.get("websocket"):
|
||||
await self.send_json(
|
||||
{
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": "FORBIDDEN",
|
||||
"message": "This function is HTTP-only. Use POST /api/mizan/call/ instead.",
|
||||
},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Create request adapter from WebSocket scope
|
||||
ws_request = WebSocketRequest(
|
||||
self.scope, channel_name=getattr(self, "channel_name", None)
|
||||
)
|
||||
|
||||
# Execute function (Pydantic validation happens inside execute_function)
|
||||
# This is sync, so we need to run it in a thread pool
|
||||
result = await sync_to_async(execute_function, thread_sensitive=True)(
|
||||
ws_request,
|
||||
fn_name,
|
||||
args,
|
||||
)
|
||||
|
||||
# Send response
|
||||
if isinstance(result, FunctionError):
|
||||
await self.send_json(
|
||||
{
|
||||
"id": request_id,
|
||||
"ok": False,
|
||||
"error": {
|
||||
"code": result.code.value,
|
||||
"message": result.message,
|
||||
**({"details": result.details} if result.details else {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
await self.send_json(
|
||||
{
|
||||
"id": request_id,
|
||||
"ok": True,
|
||||
"data": result.data,
|
||||
}
|
||||
)
|
||||
|
||||
async def channel_message(self, event: dict):
|
||||
"""
|
||||
Handle messages broadcast to a group.
|
||||
|
||||
Called when channel_layer.group_send() is used.
|
||||
Includes channel name and params so the client can route the message.
|
||||
"""
|
||||
await self.send_json(
|
||||
{
|
||||
"channel": event.get("channel"),
|
||||
"params": event.get("params", {}),
|
||||
"type": event.get("message_type", "message"),
|
||||
"data": event.get("data", {}),
|
||||
}
|
||||
)
|
||||
|
||||
async def push_message(self, event: dict):
|
||||
"""
|
||||
Handle push messages from server functions.
|
||||
|
||||
Called when push("topic", data) is used from a server function.
|
||||
The client receives this to update its local state.
|
||||
|
||||
Protocol:
|
||||
Server sends: {"type": "push", "topic": "room:42", "data": {...}}
|
||||
"""
|
||||
await self.send_json(
|
||||
{
|
||||
"type": "push",
|
||||
"topic": event.get("topic"),
|
||||
"data": event.get("data", {}),
|
||||
}
|
||||
)
|
||||
153
packages/mizan-rpc/adapters/django/src/mizan/channels/push.py
Normal file
153
packages/mizan-rpc/adapters/django/src/mizan/channels/push.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
mizan Push - Server-initiated messages to clients.
|
||||
|
||||
Simple API for pushing data to subscribed WebSocket connections.
|
||||
|
||||
Usage:
|
||||
# In a server function - push to all subscribers
|
||||
from mizan.push import push
|
||||
|
||||
push("room:42", {"type": "new_message", "data": {...}})
|
||||
|
||||
# Subscribe a connection to a topic (call during context fetch)
|
||||
from mizan.push import subscribe
|
||||
|
||||
subscribe(request, "room:42")
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Lazy import to avoid import errors when channels is not installed
|
||||
# (e.g., during schema generation)
|
||||
if TYPE_CHECKING:
|
||||
from channels.layers import BaseChannelLayer
|
||||
|
||||
|
||||
def _get_channel_layer() -> "BaseChannelLayer | None":
|
||||
"""Get channel layer, returning None if channels is not installed."""
|
||||
try:
|
||||
from channels.layers import get_channel_layer
|
||||
|
||||
return get_channel_layer()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _async_to_sync(coro):
|
||||
"""Wrapper for async_to_sync that handles missing channels."""
|
||||
from asgiref.sync import async_to_sync
|
||||
|
||||
return async_to_sync(coro)
|
||||
|
||||
|
||||
def get_topic_group_name(topic: str) -> str:
|
||||
"""Convert a topic string to a valid channel layer group name."""
|
||||
# Channel layer group names must be valid ASCII alphanumeric + hyphens/underscores/periods
|
||||
# Replace colons with underscores
|
||||
return topic.replace(":", "_")
|
||||
|
||||
|
||||
def subscribe(request, topic: str) -> None:
|
||||
"""
|
||||
Subscribe this WebSocket connection to a topic.
|
||||
|
||||
Call this in a context or server function to register the connection
|
||||
for push notifications on the given topic.
|
||||
|
||||
Args:
|
||||
request: The Django request (must have channel_name attribute from WebSocket)
|
||||
topic: Topic string, e.g., "room:42", "user:123:notifications"
|
||||
"""
|
||||
channel_name = getattr(request, "channel_name", None)
|
||||
if not channel_name:
|
||||
# HTTP request, not WebSocket - can't subscribe
|
||||
return
|
||||
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
_async_to_sync(channel_layer.group_add)(group_name, channel_name)
|
||||
|
||||
|
||||
def unsubscribe(request, topic: str) -> None:
|
||||
"""
|
||||
Unsubscribe this WebSocket connection from a topic.
|
||||
|
||||
Args:
|
||||
request: The Django request (must have channel_name attribute from WebSocket)
|
||||
topic: Topic string to unsubscribe from
|
||||
"""
|
||||
channel_name = getattr(request, "channel_name", None)
|
||||
if not channel_name:
|
||||
return
|
||||
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
_async_to_sync(channel_layer.group_discard)(group_name, channel_name)
|
||||
|
||||
|
||||
def push(topic: str, data: dict | BaseModel) -> None:
|
||||
"""
|
||||
Push data to all connections subscribed to a topic.
|
||||
|
||||
Args:
|
||||
topic: Topic string, e.g., "room:42"
|
||||
data: Data to send (dict or Pydantic model)
|
||||
|
||||
Example:
|
||||
push("room:42", {
|
||||
"type": "new_message",
|
||||
"message": {"id": 1, "text": "Hello", "user": "alice@example.com"}
|
||||
})
|
||||
"""
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"No channel layer configured, cannot push to topic '%s'", topic
|
||||
)
|
||||
return
|
||||
|
||||
# Convert Pydantic model to dict if needed
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
|
||||
_async_to_sync(channel_layer.group_send)(
|
||||
group_name,
|
||||
{
|
||||
"type": "push.message", # Maps to push_message handler in consumer
|
||||
"topic": topic,
|
||||
"data": data,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def push_async(topic: str, data: dict | BaseModel) -> None:
|
||||
"""Async version of push for use in async contexts."""
|
||||
channel_layer = _get_channel_layer()
|
||||
if not channel_layer:
|
||||
return
|
||||
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
group_name = get_topic_group_name(topic)
|
||||
|
||||
await channel_layer.group_send(
|
||||
group_name,
|
||||
{
|
||||
"type": "push.message",
|
||||
"topic": topic,
|
||||
"data": data,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
mizan.client - Server function implementation.
|
||||
|
||||
This subpackage contains everything needed to make server functions work:
|
||||
- The @client decorator
|
||||
- ServerFunction base class
|
||||
- Function execution logic
|
||||
- JWT authentication (integral to server functions)
|
||||
|
||||
Usage:
|
||||
from mizan.client import client, ServerFunction, compose
|
||||
"""
|
||||
|
||||
from .function import (
|
||||
# Decorator
|
||||
client,
|
||||
# Context markers
|
||||
ReactContext,
|
||||
GlobalContext,
|
||||
# Base classes
|
||||
ServerFunction,
|
||||
ComposedContext,
|
||||
# Composition
|
||||
compose,
|
||||
# Type aliases
|
||||
ContextMode,
|
||||
# Form helpers
|
||||
FormValidationOutput,
|
||||
FormSchemaField,
|
||||
FormSchemaOutput,
|
||||
create_form_functions,
|
||||
)
|
||||
|
||||
from .executor import (
|
||||
execute_function,
|
||||
function_call_view,
|
||||
ErrorCode,
|
||||
FunctionError,
|
||||
FunctionResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Decorator
|
||||
"client",
|
||||
# Context markers
|
||||
"ReactContext",
|
||||
"GlobalContext",
|
||||
# Base classes
|
||||
"ServerFunction",
|
||||
"ComposedContext",
|
||||
# Composition
|
||||
"compose",
|
||||
# Type aliases
|
||||
"ContextMode",
|
||||
# Execution
|
||||
"execute_function",
|
||||
"function_call_view",
|
||||
"ErrorCode",
|
||||
"FunctionError",
|
||||
"FunctionResult",
|
||||
# Form helpers
|
||||
"FormValidationOutput",
|
||||
"FormSchemaField",
|
||||
"FormSchemaOutput",
|
||||
"create_form_functions",
|
||||
]
|
||||
593
packages/mizan-rpc/adapters/django/src/mizan/client/executor.py
Normal file
593
packages/mizan-rpc/adapters/django/src/mizan/client/executor.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""
|
||||
mizan Function Executor
|
||||
|
||||
Handles execution of server functions.
|
||||
This is the core of the "Server Functions" feature - callable from React
|
||||
without REST boilerplate.
|
||||
|
||||
Security model:
|
||||
- All input validated against Pydantic schema BEFORE execution
|
||||
- Authentication: JWT (stateless) or Session (stateful) - auto-detected
|
||||
- JWT: Authorization header with Bearer token (no CSRF needed)
|
||||
- Session: Cookie-based with CSRF token (via X-CSRFToken header)
|
||||
- WebSocket RPC uses Origin header checking instead
|
||||
- No implicit function exposure - must be explicitly registered
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from django.http import HttpRequest, JsonResponse
|
||||
from django.views.decorators.csrf import csrf_protect
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from mizan.setup.registry import get_function, get_context_groups
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Standard error codes for function execution."""
|
||||
|
||||
# Client errors (4xx)
|
||||
NOT_FOUND = "NOT_FOUND" # Function not registered
|
||||
VALIDATION_ERROR = "VALIDATION_ERROR" # Input failed Pydantic validation
|
||||
UNAUTHORIZED = "UNAUTHORIZED" # User not authenticated (when required)
|
||||
FORBIDDEN = "FORBIDDEN" # User lacks permission
|
||||
BAD_REQUEST = "BAD_REQUEST" # Malformed request
|
||||
|
||||
# Server errors (5xx)
|
||||
INTERNAL_ERROR = "INTERNAL_ERROR" # Unhandled exception
|
||||
NOT_IMPLEMENTED = "NOT_IMPLEMENTED" # Function exists but not implemented
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionError:
|
||||
"""Structured error response from function execution."""
|
||||
|
||||
code: ErrorCode
|
||||
message: str
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-serializable dict."""
|
||||
result = {
|
||||
"error": True,
|
||||
"code": self.code.value,
|
||||
"message": self.message,
|
||||
}
|
||||
if self.details:
|
||||
result["details"] = self.details
|
||||
return result
|
||||
|
||||
def to_response(self, status: int = 400) -> JsonResponse:
|
||||
"""Convert to Django JsonResponse."""
|
||||
return JsonResponse(self.to_dict(), status=status)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionResult:
|
||||
"""Successful result from function execution."""
|
||||
|
||||
data: dict[str, Any]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to JSON-serializable dict."""
|
||||
return {
|
||||
"error": False,
|
||||
"data": self.data,
|
||||
}
|
||||
|
||||
def to_response(self) -> JsonResponse:
|
||||
"""Convert to Django JsonResponse."""
|
||||
return JsonResponse(self.to_dict())
|
||||
|
||||
|
||||
def _check_auth_requirement(
|
||||
request: HttpRequest,
|
||||
auth_requirement: str | Callable | None,
|
||||
) -> FunctionError | None:
|
||||
"""
|
||||
Check if the request meets the auth requirement.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest (with user set)
|
||||
auth_requirement: 'required', 'staff', 'superuser', callable, or None
|
||||
|
||||
Returns:
|
||||
FunctionError if auth check fails, None if it passes.
|
||||
|
||||
Note: This uses request.user which may be a JWTUser (stateless) or
|
||||
Django User (from session). Either way, no additional DB query is made
|
||||
for the built-in checks. Custom callables may query DB if they choose.
|
||||
"""
|
||||
if auth_requirement is None:
|
||||
return None
|
||||
|
||||
user = request.user
|
||||
|
||||
# Handle callable auth
|
||||
if callable(auth_requirement):
|
||||
try:
|
||||
result = auth_requirement(request)
|
||||
if result:
|
||||
return None # Authorized
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Access denied",
|
||||
)
|
||||
except PermissionError as e:
|
||||
# Custom error message from the callable
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message=str(e) or "Access denied",
|
||||
)
|
||||
|
||||
# Check authentication (required for all string-based auth)
|
||||
if not getattr(user, "is_authenticated", False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Authentication required",
|
||||
)
|
||||
|
||||
# Check staff requirement
|
||||
if auth_requirement == "staff":
|
||||
if not getattr(user, "is_staff", False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Staff access required",
|
||||
)
|
||||
|
||||
# Check superuser requirement
|
||||
elif auth_requirement == "superuser":
|
||||
if not getattr(user, "is_superuser", False):
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Superuser access required",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def execute_function(
|
||||
request: HttpRequest,
|
||||
fn_name: str,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> FunctionResult | FunctionError:
|
||||
"""
|
||||
Execute a registered server function.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest
|
||||
fn_name: Name of the registered function
|
||||
input_data: Input data to pass to the function
|
||||
|
||||
Returns:
|
||||
FunctionResult on success, FunctionError on failure
|
||||
"""
|
||||
from django.conf import settings
|
||||
|
||||
# Look up the function by name
|
||||
view_class = get_function(fn_name)
|
||||
if view_class is None:
|
||||
# In DEBUG mode, include the name for easier debugging
|
||||
if settings.DEBUG:
|
||||
message = f"Function '{fn_name}' not found"
|
||||
else:
|
||||
message = "Function not found"
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message=message,
|
||||
)
|
||||
|
||||
# Check auth requirement BEFORE executing
|
||||
meta = getattr(view_class, "_meta", {})
|
||||
auth_requirement = meta.get("auth")
|
||||
auth_error = _check_auth_requirement(request, auth_requirement)
|
||||
if auth_error is not None:
|
||||
return auth_error
|
||||
|
||||
# Instantiate the view with the request
|
||||
view = view_class(request)
|
||||
|
||||
# Check if this is a form function that handles input specially
|
||||
meta = getattr(view_class, "_meta", {})
|
||||
is_form_multipart = meta.get("multipart", False)
|
||||
|
||||
# For form functions with Input=None, skip Pydantic validation
|
||||
# The form itself handles validation
|
||||
input_cls = view.Input
|
||||
if input_cls is None and is_form_multipart:
|
||||
# Form function - pass input_data directly (already parsed by view or will be)
|
||||
validated_input = input_data
|
||||
elif input_cls is BaseModel:
|
||||
has_input = False
|
||||
validated_input = None
|
||||
else:
|
||||
# Check if it has any fields defined
|
||||
has_input = bool(input_cls.model_fields) if input_cls else False
|
||||
|
||||
# Validate input against Pydantic schema
|
||||
try:
|
||||
if input_data:
|
||||
# Ensure input_data is a dict (not array or other type)
|
||||
if not isinstance(input_data, dict):
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Input must be an object, not "
|
||||
+ type(input_data).__name__,
|
||||
)
|
||||
validated_input = input_cls(**input_data)
|
||||
elif has_input:
|
||||
# Check if function requires input fields
|
||||
input_schema = input_cls.model_json_schema()
|
||||
required_fields = input_schema.get("required", [])
|
||||
if required_fields:
|
||||
# Format as field errors for consistency
|
||||
errors = {field: ["Field required"] for field in required_fields}
|
||||
return FunctionError(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message="Input validation failed",
|
||||
details={"fields": errors},
|
||||
)
|
||||
validated_input = input_cls()
|
||||
else:
|
||||
# No input expected, create empty model
|
||||
validated_input = None
|
||||
except ValidationError as e:
|
||||
# Convert Pydantic errors to our format
|
||||
errors = {}
|
||||
for error in e.errors():
|
||||
field = ".".join(str(loc) for loc in error["loc"])
|
||||
if field not in errors:
|
||||
errors[field] = []
|
||||
errors[field].append(error["msg"])
|
||||
|
||||
return FunctionError(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message="Input validation failed",
|
||||
details={"fields": errors},
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
try:
|
||||
output = view.call(validated_input)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"Function {fn_name} not implemented: {e}")
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_IMPLEMENTED,
|
||||
message=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# Functions can raise PermissionError for auth issues
|
||||
return FunctionError(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message=str(e) or "Permission denied",
|
||||
)
|
||||
except Exception as e:
|
||||
# Log the full exception for debugging
|
||||
logger.exception(f"Error executing function {fn_name}")
|
||||
return FunctionError(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="An internal error occurred",
|
||||
# Don't expose internal details in production
|
||||
details={"type": type(e).__name__}
|
||||
if logger.isEnabledFor(logging.DEBUG)
|
||||
else None,
|
||||
)
|
||||
|
||||
# Serialize output (handle None for Optional return types)
|
||||
if output is None:
|
||||
return FunctionResult(data=None)
|
||||
return FunctionResult(data=output.model_dump())
|
||||
|
||||
|
||||
def _try_jwt_auth(request: HttpRequest) -> bool:
|
||||
"""
|
||||
Attempt to authenticate the request using JWT.
|
||||
|
||||
If Authorization header contains a valid Bearer token, authenticates
|
||||
the request and sets request.user to a JWTUser. Returns True if JWT
|
||||
auth succeeded.
|
||||
|
||||
IMPORTANT: This is stateless - no database query is made. The JWTUser
|
||||
object is created from the token claims. If you need the full User
|
||||
object, query it explicitly in your function.
|
||||
|
||||
Security: If JWT is provided but invalid, we return False and do NOT
|
||||
fall back to session auth. The caller should reject the request.
|
||||
"""
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return False
|
||||
|
||||
token = auth_header[7:] # Strip "Bearer "
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
from mizan.client.jwt import decode_token
|
||||
from mizan.jwt.tokens import JWTUser
|
||||
|
||||
payload = decode_token(token, expected_type="access")
|
||||
if payload is None:
|
||||
return False
|
||||
|
||||
# Create JWTUser from token claims - NO DATABASE QUERY
|
||||
request.user = JWTUser(payload)
|
||||
request._mizan_jwt_authenticated = True
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _has_jwt_header(request: HttpRequest) -> bool:
|
||||
"""Check if request has a JWT Authorization header."""
|
||||
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
|
||||
return auth_header.startswith("Bearer ")
|
||||
|
||||
|
||||
def _csrf_protect_unless_jwt(view_func):
|
||||
"""
|
||||
Decorator that applies CSRF protection unless JWT auth is used.
|
||||
|
||||
JWT tokens are self-authenticating (the token itself proves the request
|
||||
is legitimate), so CSRF protection is not needed.
|
||||
|
||||
Security: If JWT is provided but invalid, reject the request - do NOT
|
||||
fall back to session auth. This prevents attacks where an invalid token
|
||||
is sent alongside a valid session cookie.
|
||||
"""
|
||||
csrf_protected_view = csrf_protect(view_func)
|
||||
|
||||
@wraps(view_func)
|
||||
def wrapper(request: HttpRequest, *args, **kwargs):
|
||||
# Check if JWT header is present
|
||||
has_jwt = _has_jwt_header(request)
|
||||
|
||||
if has_jwt:
|
||||
# JWT header present - try to authenticate
|
||||
if _try_jwt_auth(request):
|
||||
# JWT valid - skip CSRF, proceed
|
||||
return view_func(request, *args, **kwargs)
|
||||
else:
|
||||
# JWT invalid - reject (do NOT fall back to session)
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or expired JWT token",
|
||||
).to_response(status=401)
|
||||
else:
|
||||
# No JWT - use session auth with CSRF
|
||||
return csrf_protected_view(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_csrf_protect_unless_jwt
|
||||
def function_call_view(request: HttpRequest) -> JsonResponse:
|
||||
"""
|
||||
Django view for handling function calls (HTTP fallback for WebSocket RPC).
|
||||
|
||||
Authentication (auto-detected):
|
||||
- JWT: Authorization: Bearer <token> (stateless, no CSRF needed)
|
||||
- Session: Cookie-based with X-CSRFToken header (CSRF required)
|
||||
|
||||
Endpoint: POST /api/mizan/call/
|
||||
|
||||
Request body (JSON):
|
||||
{
|
||||
"fn": "function_name", // Function name
|
||||
"args": { ... } // Optional, depending on function
|
||||
}
|
||||
|
||||
Request body (multipart/form-data for form submit functions):
|
||||
fn: function_name
|
||||
<field>: <value>
|
||||
...
|
||||
|
||||
Response on success:
|
||||
{
|
||||
"error": false,
|
||||
"data": { ... } // Function output
|
||||
}
|
||||
|
||||
Response on error:
|
||||
{
|
||||
"error": true,
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Input validation failed",
|
||||
"details": { ... }
|
||||
}
|
||||
"""
|
||||
# Only allow POST
|
||||
if request.method != "POST":
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Only POST method allowed",
|
||||
).to_response(status=405)
|
||||
|
||||
# Check content type to determine parsing method
|
||||
content_type = request.content_type or ""
|
||||
is_multipart = content_type.startswith("multipart/form-data")
|
||||
|
||||
if is_multipart:
|
||||
# Multipart form data - used by form submit functions
|
||||
fn_name = request.POST.get("fn")
|
||||
if not fn_name:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Missing 'fn' field",
|
||||
).to_response()
|
||||
|
||||
# Get form data (excluding 'fn')
|
||||
input_data = {k: v for k, v in request.POST.dict().items() if k != "fn"}
|
||||
|
||||
# Attach parsed form data and files to request for form functions
|
||||
request._mizan_form_data = input_data
|
||||
request._mizan_form_files = request.FILES
|
||||
|
||||
else:
|
||||
# JSON body - standard RPC
|
||||
try:
|
||||
if request.body:
|
||||
body = json.loads(request.body)
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Request body required",
|
||||
).to_response()
|
||||
except json.JSONDecodeError:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Invalid JSON in request body",
|
||||
).to_response()
|
||||
|
||||
# Extract function name and args
|
||||
fn_name = body.get("fn")
|
||||
if not fn_name:
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Missing 'fn' field",
|
||||
).to_response()
|
||||
|
||||
input_data = body.get("args")
|
||||
|
||||
# Execute the function
|
||||
result = execute_function(request, fn_name, input_data)
|
||||
|
||||
# Return appropriate response
|
||||
if isinstance(result, FunctionError):
|
||||
status = {
|
||||
ErrorCode.NOT_FOUND: 404,
|
||||
ErrorCode.VALIDATION_ERROR: 422,
|
||||
ErrorCode.UNAUTHORIZED: 401,
|
||||
ErrorCode.FORBIDDEN: 403,
|
||||
ErrorCode.BAD_REQUEST: 400,
|
||||
ErrorCode.INTERNAL_ERROR: 500,
|
||||
ErrorCode.NOT_IMPLEMENTED: 501,
|
||||
}.get(result.code, 400)
|
||||
return result.to_response(status=status)
|
||||
|
||||
return result.to_response()
|
||||
|
||||
|
||||
def execute_context(
|
||||
request: HttpRequest,
|
||||
context_name: str,
|
||||
params: dict[str, str],
|
||||
) -> FunctionResult | FunctionError:
|
||||
"""
|
||||
Execute all functions in a named context with merged params.
|
||||
|
||||
Each function receives only the params it declares in its Input schema.
|
||||
If any function fails (auth, validation, execution), the entire request fails.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest
|
||||
context_name: Name of the context (e.g., 'user', 'global')
|
||||
params: Query parameters (strings — Pydantic coerces types)
|
||||
|
||||
Returns:
|
||||
FunctionResult with bundled data, or FunctionError
|
||||
"""
|
||||
groups = get_context_groups()
|
||||
fn_names = groups.get(context_name)
|
||||
if not fn_names:
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message=f"Context '{context_name}' not found",
|
||||
)
|
||||
|
||||
results = {}
|
||||
for fn_name in fn_names:
|
||||
view_class = get_function(fn_name)
|
||||
if view_class is None:
|
||||
continue
|
||||
|
||||
# Filter params to only those in this function's Input schema
|
||||
input_cls = getattr(view_class, "Input", None)
|
||||
if input_cls and input_cls is not BaseModel and input_cls.model_fields:
|
||||
fn_params = {
|
||||
k: v for k, v in params.items()
|
||||
if k in input_cls.model_fields
|
||||
}
|
||||
else:
|
||||
fn_params = None
|
||||
|
||||
result = execute_function(request, fn_name, fn_params)
|
||||
if isinstance(result, FunctionError):
|
||||
return result
|
||||
results[fn_name] = result.data
|
||||
|
||||
return FunctionResult(data=results)
|
||||
|
||||
|
||||
def _jwt_auth_only(view_func):
|
||||
"""
|
||||
Decorator that handles JWT auth for GET endpoints (no CSRF needed for GET).
|
||||
"""
|
||||
@wraps(view_func)
|
||||
def wrapper(request: HttpRequest, *args, **kwargs):
|
||||
has_jwt = _has_jwt_header(request)
|
||||
if has_jwt:
|
||||
if _try_jwt_auth(request):
|
||||
return view_func(request, *args, **kwargs)
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or expired JWT token",
|
||||
).to_response(status=401)
|
||||
# No JWT — session auth (no CSRF needed for GET)
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_jwt_auth_only
|
||||
def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse:
|
||||
"""
|
||||
Fetch all functions in a named context in a single bundled GET request.
|
||||
|
||||
Endpoint: GET /api/mizan/ctx/<context_name>/?param1=val1¶m2=val2
|
||||
|
||||
Response on success:
|
||||
{
|
||||
"error": false,
|
||||
"data": {
|
||||
"user_profile": { ... },
|
||||
"user_orders": [ ... ]
|
||||
}
|
||||
}
|
||||
"""
|
||||
if request.method != "GET":
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Only GET method allowed",
|
||||
).to_response(status=405)
|
||||
|
||||
params = dict(request.GET)
|
||||
result = execute_context(request, context_name, params)
|
||||
|
||||
if isinstance(result, FunctionError):
|
||||
status = {
|
||||
ErrorCode.NOT_FOUND: 404,
|
||||
ErrorCode.VALIDATION_ERROR: 422,
|
||||
ErrorCode.UNAUTHORIZED: 401,
|
||||
ErrorCode.FORBIDDEN: 403,
|
||||
ErrorCode.BAD_REQUEST: 400,
|
||||
ErrorCode.INTERNAL_ERROR: 500,
|
||||
ErrorCode.NOT_IMPLEMENTED: 501,
|
||||
}.get(result.code, 400)
|
||||
return result.to_response(status=status)
|
||||
|
||||
return result.to_response()
|
||||
816
packages/mizan-rpc/adapters/django/src/mizan/client/function.py
Normal file
816
packages/mizan-rpc/adapters/django/src/mizan/client/function.py
Normal file
@@ -0,0 +1,816 @@
|
||||
"""
|
||||
mizan Server Functions - Core Primitive
|
||||
|
||||
Server functions are the core primitive. Everything else builds on them.
|
||||
|
||||
Two styles supported:
|
||||
|
||||
1. Function-based (recommended, Django Ninja style):
|
||||
@client("update-profile")
|
||||
def update_profile(request, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
return UpdateProfileOutput(success=True)
|
||||
|
||||
2. Class-based (for complex cases):
|
||||
class UpdateProfile(ServerFunction):
|
||||
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
return UpdateProfileOutput(success=True)
|
||||
register(UpdateProfile, 'update-profile')
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# REACT CONTEXT - Named context marker
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ReactContext:
|
||||
"""
|
||||
A named context that groups server functions into one provider and one fetch.
|
||||
|
||||
Usage:
|
||||
UserContext = ReactContext('user')
|
||||
|
||||
@client(context=UserContext)
|
||||
def user_profile(request, user_id: int) -> ProfileShape: ...
|
||||
|
||||
@client(context=UserContext)
|
||||
def user_orders(request, user_id: int) -> list[OrderShape]: ...
|
||||
|
||||
@client(affects=UserContext)
|
||||
def edit_profile(request, name: str) -> dict: ...
|
||||
|
||||
@client(affects=[UserContext, OrderContext])
|
||||
def change_plan(request) -> dict: ...
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
if not name or not isinstance(name, str):
|
||||
raise ValueError("ReactContext name must be a non-empty string")
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ReactContext({self.name!r})"
|
||||
|
||||
|
||||
# Built-in global context (auto-mounted at root, SSR-hydrated)
|
||||
GlobalContext = ReactContext("global")
|
||||
|
||||
|
||||
# Context parameter type: a ReactContext instance, a raw string, or False
|
||||
ContextMode = ReactContext | str | Literal[False]
|
||||
|
||||
|
||||
TInput = TypeVar("TInput", bound=BaseModel)
|
||||
TOutput = TypeVar("TOutput", bound=BaseModel)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SERVER FUNCTION - The Core Primitive
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ServerFunction(ABC, Generic[TInput, TOutput]):
|
||||
"""
|
||||
Class-based server function (for complex cases).
|
||||
|
||||
For simple functions, use the @client decorator instead.
|
||||
|
||||
Usage:
|
||||
class UpdateProfile(ServerFunction):
|
||||
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
|
||||
self.user.name = input.name
|
||||
self.user.save()
|
||||
return UpdateProfileOutput(success=True)
|
||||
|
||||
register(UpdateProfile, 'update-profile')
|
||||
"""
|
||||
|
||||
# Registration name (set by register())
|
||||
name: ClassVar[str]
|
||||
|
||||
# Metadata for code generation
|
||||
_meta: ClassVar[dict[str, Any]] = {}
|
||||
|
||||
# Schema classes (set automatically from type hints or explicitly)
|
||||
Input: ClassVar[type[BaseModel]] = BaseModel
|
||||
Output: ClassVar[type[BaseModel]] = BaseModel
|
||||
|
||||
def __init__(self, request: HttpRequest):
|
||||
"""Initialize with the Django request."""
|
||||
self.request = request
|
||||
|
||||
@property
|
||||
def user(self):
|
||||
"""Shortcut to request.user."""
|
||||
return self.request.user
|
||||
|
||||
@abstractmethod
|
||||
def call(self, input: TInput) -> TOutput:
|
||||
"""
|
||||
Execute the function.
|
||||
|
||||
Args:
|
||||
input: Validated input data
|
||||
|
||||
Returns:
|
||||
Output instance
|
||||
"""
|
||||
raise NotImplementedError(f"{self.__class__.__name__} must implement call()")
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
export = {
|
||||
"name": getattr(cls, "name", cls.__name__),
|
||||
"type": "function",
|
||||
"meta": getattr(cls, "_meta", {}),
|
||||
}
|
||||
|
||||
# Get Input/Output from class attributes
|
||||
input_cls = getattr(cls, "Input", BaseModel)
|
||||
output_cls = getattr(cls, "Output", BaseModel)
|
||||
|
||||
# Check if Input has fields
|
||||
input_schema = input_cls.model_json_schema()
|
||||
has_input = bool(input_schema.get("properties"))
|
||||
if has_input:
|
||||
export["input"] = input_schema
|
||||
export["has_input"] = has_input
|
||||
|
||||
export["output"] = output_cls.model_json_schema()
|
||||
|
||||
return export
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FUNCTION DECORATOR - Django Ninja Style
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class _FunctionWrapper(ServerFunction):
|
||||
"""Internal wrapper that makes a plain function behave like a ServerFunction."""
|
||||
|
||||
# Will be set per-wrapper instance
|
||||
_wrapped_fn: ClassVar[Callable]
|
||||
_input_cls: ClassVar[type[BaseModel] | None]
|
||||
_output_cls: ClassVar[type[BaseModel]]
|
||||
_param_names: ClassVar[list[str]] = []
|
||||
_is_primitive_output: ClassVar[bool] = False
|
||||
|
||||
def call(self, input):
|
||||
"""Execute the wrapped function, unpacking input into individual args."""
|
||||
if input is not None and self._param_names:
|
||||
# Unpack validated model into keyword arguments
|
||||
kwargs = {name: getattr(input, name) for name in self._param_names}
|
||||
result = self._wrapped_fn(self.request, **kwargs)
|
||||
else:
|
||||
result = self._wrapped_fn(self.request)
|
||||
|
||||
# Wrap primitive returns in the generated output model
|
||||
if self._is_primitive_output:
|
||||
return self._output_cls(result=result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
export = {
|
||||
"name": getattr(cls, "name", cls.__name__),
|
||||
"type": "function",
|
||||
"meta": getattr(cls, "_meta", {}),
|
||||
}
|
||||
|
||||
# Use stored schema classes
|
||||
if cls._input_cls is not None:
|
||||
input_schema = cls._input_cls.model_json_schema()
|
||||
has_input = bool(input_schema.get("properties"))
|
||||
if has_input:
|
||||
export["input"] = input_schema
|
||||
export["has_input"] = has_input
|
||||
else:
|
||||
export["has_input"] = False
|
||||
|
||||
export["output"] = cls._output_cls.model_json_schema()
|
||||
|
||||
return export
|
||||
|
||||
|
||||
# Valid string values for auth parameter
|
||||
_VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"})
|
||||
|
||||
|
||||
def _resolve_context(context: ContextMode) -> str | Literal[False]:
|
||||
"""Resolve a context parameter to its name string."""
|
||||
if context is False:
|
||||
return False
|
||||
if isinstance(context, ReactContext):
|
||||
return context.name
|
||||
if isinstance(context, str):
|
||||
if not context.strip():
|
||||
raise ValueError("context must be a non-empty string, ReactContext, or False.")
|
||||
if context == "local":
|
||||
warnings.warn(
|
||||
"context='local' is deprecated. Use ReactContext('name') instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return context
|
||||
raise ValueError(
|
||||
f"context must be a ReactContext, a string, or False. Got {type(context).__name__}."
|
||||
)
|
||||
|
||||
|
||||
# Affects parameter type
|
||||
AffectsTarget = ReactContext | str | type["ServerFunction"]
|
||||
AffectsMode = AffectsTarget | list[AffectsTarget] | None
|
||||
|
||||
|
||||
def client(
|
||||
fn: Callable = None,
|
||||
*,
|
||||
context: ContextMode = False,
|
||||
affects: AffectsMode = None,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | Callable[[Any], bool] | None = None,
|
||||
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
|
||||
"""
|
||||
Register a function as a server function.
|
||||
|
||||
Type annotations define the schema - just like Django Ninja/FastAPI.
|
||||
Function parameters become input fields automatically.
|
||||
|
||||
Args:
|
||||
context: Named context for React state management.
|
||||
- False (default): Not a context, just a callable function.
|
||||
- ReactContext instance: groups functions into a named context.
|
||||
- GlobalContext: reserved, auto-mounted at root, SSR-hydrated.
|
||||
- Raw string: also accepted (e.g., 'user'), but ReactContext preferred.
|
||||
|
||||
affects: Declare which contexts this mutation invalidates.
|
||||
- A ReactContext instance
|
||||
- A list of ReactContext instances
|
||||
- Also accepts strings or function references for backwards compat
|
||||
Mutually exclusive with context=.
|
||||
|
||||
websocket: Enable WebSocket RPC transport (default: False).
|
||||
|
||||
auth: Authentication requirement.
|
||||
- None (default): No auth required
|
||||
- True or 'required': Must be authenticated
|
||||
- 'staff': Must have is_staff=True
|
||||
- 'superuser': Must have is_superuser=True
|
||||
- callable(request) -> bool: Custom check function
|
||||
|
||||
Usage:
|
||||
UserContext = ReactContext('user')
|
||||
|
||||
@client(context=GlobalContext)
|
||||
def current_user(request) -> UserOutput: ...
|
||||
|
||||
@client(context=UserContext)
|
||||
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
||||
|
||||
@client(affects=UserContext)
|
||||
def edit_profile(request, name: str) -> dict: ...
|
||||
|
||||
@client(affects=[UserContext, OrderContext])
|
||||
def change_plan(request) -> dict: ...
|
||||
|
||||
Returns:
|
||||
A ServerFunction class that wraps the function
|
||||
"""
|
||||
# Resolve context to name string
|
||||
resolved_context = _resolve_context(context)
|
||||
|
||||
# Validate affects parameter
|
||||
if affects is not None:
|
||||
if resolved_context is not False:
|
||||
raise ValueError(
|
||||
"context= and affects= are mutually exclusive. "
|
||||
"A function cannot be both a context reader and a mutation."
|
||||
)
|
||||
|
||||
# Validate auth parameter
|
||||
if auth is not None:
|
||||
if isinstance(auth, str) and auth not in _VALID_AUTH_STRINGS:
|
||||
raise ValueError(
|
||||
f"Invalid auth value '{auth}'. "
|
||||
f"Must be one of: {', '.join(sorted(_VALID_AUTH_STRINGS))}, True, or a callable."
|
||||
)
|
||||
|
||||
def decorator(fn: Callable) -> type[ServerFunction]:
|
||||
return _create_server_function(
|
||||
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
|
||||
)
|
||||
|
||||
# Support both @client and @client(...)
|
||||
if fn is not None:
|
||||
return _create_server_function(
|
||||
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def _normalize_affects(affects: AffectsMode) -> list[dict[str, str]] | None:
|
||||
"""Normalize the affects parameter into a list of target descriptors."""
|
||||
if affects is None:
|
||||
return None
|
||||
|
||||
items = affects if isinstance(affects, list) else [affects]
|
||||
result = []
|
||||
for item in items:
|
||||
if isinstance(item, ReactContext):
|
||||
result.append({"type": "context", "name": item.name})
|
||||
elif isinstance(item, str):
|
||||
result.append({"type": "context", "name": item})
|
||||
elif isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
fn_meta = getattr(item, "_meta", {})
|
||||
fn_ctx = fn_meta.get("context")
|
||||
result.append({
|
||||
"type": "function",
|
||||
"name": getattr(item, "__name__", str(item)),
|
||||
"context": fn_ctx or None,
|
||||
})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"affects items must be ReactContext instances, context name strings, "
|
||||
f"or @client function references. Got {type(item)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _create_server_function(
|
||||
fn: Callable,
|
||||
*,
|
||||
context: str | Literal[False] = False,
|
||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | None = None,
|
||||
) -> type[ServerFunction]:
|
||||
"""Internal helper that creates a ServerFunction from a decorated function."""
|
||||
from pydantic import create_model
|
||||
|
||||
# Use function name directly
|
||||
name = fn.__name__
|
||||
|
||||
# Extract type hints and signature
|
||||
hints = get_type_hints(fn)
|
||||
sig = inspect.signature(fn)
|
||||
params = list(sig.parameters.items())
|
||||
|
||||
# Skip 'request' parameter (first param)
|
||||
input_params = params[1:] if params else []
|
||||
|
||||
# Build input schema from function parameters
|
||||
if input_params:
|
||||
# Build field definitions for create_model
|
||||
# Format: {field_name: (type, default) or (type, ...)}
|
||||
fields = {}
|
||||
for param_name, param in input_params:
|
||||
param_type = hints.get(param_name, Any)
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
# Required field
|
||||
fields[param_name] = (param_type, ...)
|
||||
else:
|
||||
# Optional field with default
|
||||
fields[param_name] = (param_type, param.default)
|
||||
|
||||
# Create dynamic Pydantic model
|
||||
input_cls = create_model(f"{fn.__name__}_Input", **fields)
|
||||
else:
|
||||
input_cls = None
|
||||
|
||||
# Get output type from return annotation
|
||||
output_type = hints.get("return")
|
||||
if output_type is None:
|
||||
raise TypeError(f"Server function '{name}' must have a return type annotation")
|
||||
|
||||
# Support primitive return types by wrapping in a model with 'result' field
|
||||
# Also handle Optional[X] / X | None by extracting the non-None type
|
||||
import types
|
||||
|
||||
def is_basemodel_type(t: Any) -> bool:
|
||||
"""Check if type is a BaseModel subclass, handling Optional/Union."""
|
||||
if isinstance(t, type) and issubclass(t, BaseModel):
|
||||
return True
|
||||
# Handle Union types: typing.Union (Optional[X]) and types.UnionType (X | None)
|
||||
origin = get_origin(t)
|
||||
if origin is Union or isinstance(t, types.UnionType):
|
||||
args = get_args(t)
|
||||
# Check if any non-None arg is a BaseModel
|
||||
for arg in args:
|
||||
if (
|
||||
arg is not type(None)
|
||||
and isinstance(arg, type)
|
||||
and issubclass(arg, BaseModel)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
if is_basemodel_type(output_type):
|
||||
output_cls = output_type
|
||||
is_primitive_output = False
|
||||
else:
|
||||
# Create model wrapper for primitive types (int, str, list, etc.)
|
||||
output_cls = create_model(f"{fn.__name__}_Output", result=(output_type, ...))
|
||||
is_primitive_output = True
|
||||
|
||||
# Store param names for unpacking validated input
|
||||
param_names = [p[0] for p in input_params]
|
||||
|
||||
# Create a unique wrapper class for this function
|
||||
class FunctionWrapper(_FunctionWrapper):
|
||||
_param_names: ClassVar[list[str]] = param_names
|
||||
|
||||
FunctionWrapper.__name__ = fn.__name__
|
||||
FunctionWrapper.__doc__ = fn.__doc__
|
||||
FunctionWrapper.__module__ = fn.__module__ # Critical for discovery
|
||||
FunctionWrapper._wrapped_fn = staticmethod(fn)
|
||||
FunctionWrapper._input_cls = input_cls
|
||||
FunctionWrapper._output_cls = output_cls
|
||||
FunctionWrapper._is_primitive_output = is_primitive_output
|
||||
|
||||
# Set Input/Output class attributes for compatibility
|
||||
if input_cls is not None:
|
||||
FunctionWrapper.Input = input_cls
|
||||
FunctionWrapper.Output = output_cls
|
||||
|
||||
# Build metadata
|
||||
meta = {}
|
||||
|
||||
# Context name (any non-empty string)
|
||||
if context:
|
||||
meta["context"] = context
|
||||
|
||||
# Affects: mutation invalidation targets
|
||||
normalized_affects = _normalize_affects(affects)
|
||||
if normalized_affects:
|
||||
meta["affects"] = normalized_affects
|
||||
|
||||
# WebSocket: enable WebSocket transport
|
||||
if websocket:
|
||||
meta["websocket"] = True
|
||||
|
||||
# Auth requirement
|
||||
if auth is not None:
|
||||
if auth is True:
|
||||
meta["auth"] = "required"
|
||||
elif callable(auth):
|
||||
meta["auth"] = auth
|
||||
else:
|
||||
meta["auth"] = auth
|
||||
|
||||
if meta:
|
||||
FunctionWrapper._meta = {**FunctionWrapper._meta, **meta}
|
||||
|
||||
# Note: Registration happens via discovery (mizan_clients), not here.
|
||||
# This allows the decorator to be used without import-time side effects.
|
||||
|
||||
return FunctionWrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# COMPOSE - Combine multiple contexts into a single provider
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ComposedContext:
|
||||
"""
|
||||
Marker class for composed contexts.
|
||||
|
||||
Stores metadata about the composition for schema export.
|
||||
"""
|
||||
|
||||
name: str
|
||||
_meta: dict[str, Any]
|
||||
_children: list[type[ServerFunction] | "ComposedContext"]
|
||||
_leaves: list[type[ServerFunction]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
children: list,
|
||||
leaves: list,
|
||||
on_server: bool,
|
||||
websocket: bool,
|
||||
):
|
||||
self.name = name
|
||||
self._children = children
|
||||
self._leaves = leaves
|
||||
self._meta = {
|
||||
"compose": True,
|
||||
"on_server": on_server,
|
||||
"websocket": websocket,
|
||||
"children": [c.name for c in children],
|
||||
"leaves": [leaf.name for leaf in leaves],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_schema_export(cls) -> dict[str, Any]:
|
||||
"""Export schema for TypeScript generation."""
|
||||
return {
|
||||
"name": cls.name,
|
||||
"type": "compose",
|
||||
"meta": cls._meta,
|
||||
"children": cls._meta.get("children", []),
|
||||
"leaves": cls._meta.get("leaves", []),
|
||||
}
|
||||
|
||||
|
||||
def _get_leaves(item) -> list[type[ServerFunction]]:
|
||||
"""Recursively collect all leaf contexts from a context or composition."""
|
||||
if isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
return [item]
|
||||
elif isinstance(item, ComposedContext):
|
||||
return item._leaves.copy()
|
||||
elif hasattr(item, "_leaves"):
|
||||
# Duck typing for composed contexts
|
||||
return item._leaves.copy()
|
||||
else:
|
||||
raise TypeError(f"Expected ServerFunction or ComposedContext, got {type(item)}")
|
||||
|
||||
|
||||
def _is_context_enabled(item) -> bool:
|
||||
"""Check if an item is a context-enabled function or composition."""
|
||||
if isinstance(item, ComposedContext) or hasattr(item, "_leaves"):
|
||||
return True
|
||||
if isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
meta = getattr(item, "_meta", {})
|
||||
return bool(meta.get("context"))
|
||||
return False
|
||||
|
||||
|
||||
def compose(
|
||||
*children,
|
||||
on_server: bool = False,
|
||||
websocket: bool = False,
|
||||
):
|
||||
"""
|
||||
Compose multiple contexts into a single provider.
|
||||
|
||||
Args:
|
||||
*children: Context functions (@client with a context name)
|
||||
or other @compose functions. All must be unique after flattening.
|
||||
|
||||
on_server: Bundle all calls into a single server request (default: False).
|
||||
- False: Frontend makes individual calls (mixed HTTP/WS OK)
|
||||
- True: Single bundled call. Requires transport consistency:
|
||||
all children must be HTTP-only XOR all must be websocket=True.
|
||||
|
||||
websocket: Transport for bundled call when on_server=True (default: False).
|
||||
- False: Bundled call over HTTP. All children must be HTTP-only.
|
||||
- True: Bundled call over WebSocket. All children must have websocket=True.
|
||||
|
||||
Usage:
|
||||
@client(context='local')
|
||||
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
||||
|
||||
@client(context='local')
|
||||
def user_posts(request, user_id: int) -> PostsOutput: ...
|
||||
|
||||
@compose(user_profile, user_posts)
|
||||
def user_page():
|
||||
pass
|
||||
|
||||
# Frontend generates:
|
||||
# <UserPageProvider user_id={123}>
|
||||
# <App />
|
||||
# </UserPageProvider>
|
||||
|
||||
Nesting:
|
||||
@compose(ctx_a, ctx_b)
|
||||
def ab(): pass
|
||||
|
||||
@compose(ab, ctx_c) # Flattens to [ctx_a, ctx_b, ctx_c]
|
||||
def abc(): pass
|
||||
|
||||
Returns:
|
||||
A ComposedContext that can be used in other compositions.
|
||||
"""
|
||||
|
||||
def decorator(fn: Callable) -> ComposedContext:
|
||||
from mizan.setup.registry import register_compose
|
||||
|
||||
name = fn.__name__
|
||||
|
||||
# Validate: all children must be context-enabled
|
||||
for i, child in enumerate(children):
|
||||
if not _is_context_enabled(child):
|
||||
child_name = getattr(
|
||||
child, "name", getattr(child, "__name__", str(child))
|
||||
)
|
||||
raise ValueError(
|
||||
f"@compose argument {i} ({child_name}) is not context-enabled. "
|
||||
f"All children must have @client(context=...) or be @compose."
|
||||
)
|
||||
|
||||
# Flatten to collect all leaves
|
||||
leaves = []
|
||||
for child in children:
|
||||
leaves.extend(_get_leaves(child))
|
||||
|
||||
# Validate: no duplicate leaves (by identity)
|
||||
seen = set()
|
||||
for leaf in leaves:
|
||||
if id(leaf) in seen:
|
||||
raise ValueError(
|
||||
f"Duplicate context '{leaf.name}' in @compose({name}). "
|
||||
f"Each context can only appear once. Use named kwargs for reuse (future feature)."
|
||||
)
|
||||
seen.add(id(leaf))
|
||||
|
||||
# Validate transport consistency when on_server=True
|
||||
if on_server:
|
||||
has_websocket = [
|
||||
getattr(leaf, "_meta", {}).get("websocket", False) for leaf in leaves
|
||||
]
|
||||
|
||||
if websocket:
|
||||
# All must have websocket=True
|
||||
if not all(has_websocket):
|
||||
non_ws = [
|
||||
leaf.name for leaf, ws in zip(leaves, has_websocket) if not ws
|
||||
]
|
||||
raise ValueError(
|
||||
f"@compose({name}, on_server=True, websocket=True) requires all children "
|
||||
f"to have websocket=True. These are HTTP-only: {non_ws}"
|
||||
)
|
||||
else:
|
||||
# All must be HTTP-only
|
||||
if any(has_websocket):
|
||||
ws_enabled = [
|
||||
leaf.name for leaf, ws in zip(leaves, has_websocket) if ws
|
||||
]
|
||||
raise ValueError(
|
||||
f"@compose({name}, on_server=True, websocket=False) requires all children "
|
||||
f"to be HTTP-only. These have websocket=True: {ws_enabled}"
|
||||
)
|
||||
|
||||
# Create composed context
|
||||
composed = ComposedContext(
|
||||
name=name,
|
||||
children=list(children),
|
||||
leaves=leaves,
|
||||
on_server=on_server,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
# Make it a class-like object for consistency
|
||||
composed.__name__ = name
|
||||
composed.__doc__ = fn.__doc__
|
||||
|
||||
# Register the composition
|
||||
register_compose(composed, name)
|
||||
|
||||
return composed
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FORM HELPERS - Output types used by form server functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FormValidationOutput(BaseModel):
|
||||
"""Standard output for form validation."""
|
||||
|
||||
valid: bool
|
||||
errors: dict[str, list[str]]
|
||||
|
||||
|
||||
class FormSchemaField(BaseModel):
|
||||
"""Schema for a single form field."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
required: bool
|
||||
label: str
|
||||
help_text: str | None = None
|
||||
choices: list[tuple[str, str]] | None = None
|
||||
initial: Any = None
|
||||
|
||||
|
||||
class FormSchemaOutput(BaseModel):
|
||||
"""Standard output for form schema."""
|
||||
|
||||
fields: list[FormSchemaField]
|
||||
|
||||
|
||||
def create_form_functions(
|
||||
form_class: type,
|
||||
name: str,
|
||||
submit_handler: Callable[[HttpRequest, dict], BaseModel] | None = None,
|
||||
) -> tuple[type[ServerFunction], type[ServerFunction], type[ServerFunction] | None]:
|
||||
"""
|
||||
Generate server functions for a Django Form.
|
||||
|
||||
Args:
|
||||
form_class: Django Form class
|
||||
name: Base name for the functions
|
||||
submit_handler: Optional handler for form submission
|
||||
|
||||
Returns:
|
||||
Tuple of (SchemaFunction, ValidateFunction, SubmitFunction or None)
|
||||
|
||||
Usage:
|
||||
SchemaFn, ValidateFn, SubmitFn = create_form_functions(
|
||||
ContactForm,
|
||||
'contact',
|
||||
submit_handler=lambda req, data: ContactSubmitOutput(success=True),
|
||||
)
|
||||
register(SchemaFn, 'contact-schema')
|
||||
register(ValidateFn, 'contact-validate')
|
||||
register(SubmitFn, 'contact-submit')
|
||||
|
||||
Or use the helper:
|
||||
register_form(ContactForm, 'contact', submit_handler=...)
|
||||
"""
|
||||
from mizan.forms.schema_utils import build_form_schema
|
||||
|
||||
# Schema function - returns field definitions
|
||||
class FormSchema(ServerFunction):
|
||||
class Output(FormSchemaOutput):
|
||||
pass
|
||||
|
||||
def call(self, input):
|
||||
schema = build_form_schema(form_class)
|
||||
fields = [
|
||||
FormSchemaField(
|
||||
name=field.name,
|
||||
type=field.type,
|
||||
required=field.required,
|
||||
label=field.label or field.name,
|
||||
help_text=field.help_text or None,
|
||||
choices=[(c.value, c.label) for c in field.choices]
|
||||
if field.choices
|
||||
else None,
|
||||
initial=field.initial,
|
||||
)
|
||||
for field in schema.fields
|
||||
]
|
||||
return self.Output(fields=fields)
|
||||
|
||||
FormSchema.__name__ = f"{name.title().replace('-', '')}Schema"
|
||||
FormSchema._meta = {"form": True, "form_name": name, "form_role": "schema"}
|
||||
|
||||
# Validation function
|
||||
class FormDataInput(BaseModel):
|
||||
data: dict[str, Any]
|
||||
|
||||
class FormValidate(ServerFunction):
|
||||
Input = FormDataInput
|
||||
|
||||
class Output(FormValidationOutput):
|
||||
pass
|
||||
|
||||
def call(self, input):
|
||||
form = form_class(data=input.data)
|
||||
if form.is_valid():
|
||||
return self.Output(valid=True, errors={})
|
||||
return self.Output(valid=False, errors=dict(form.errors))
|
||||
|
||||
FormValidate.__name__ = f"{name.title().replace('-', '')}Validate"
|
||||
FormValidate._meta = {"form": True, "form_name": name, "form_role": "validate"}
|
||||
|
||||
# Submit function (optional)
|
||||
FormSubmit = None
|
||||
if submit_handler:
|
||||
|
||||
class FormSubmit(ServerFunction):
|
||||
Input = FormDataInput
|
||||
|
||||
def call(self, input):
|
||||
# Validate first
|
||||
form = form_class(data=input.data)
|
||||
if not form.is_valid():
|
||||
raise ValueError("Form validation failed")
|
||||
# Call handler
|
||||
return submit_handler(self.request, form.cleaned_data)
|
||||
|
||||
FormSubmit.__name__ = f"{name.title().replace('-', '')}Submit"
|
||||
FormSubmit._meta = {"form": True, "form_name": name, "form_role": "submit"}
|
||||
|
||||
return FormSchema, FormValidate, FormSubmit
|
||||
44
packages/mizan-rpc/adapters/django/src/mizan/client/jwt.py
Normal file
44
packages/mizan-rpc/adapters/django/src/mizan/client/jwt.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
mizan.client.jwt - JWT authentication for server functions.
|
||||
|
||||
Provides:
|
||||
- Server functions for obtaining/refreshing JWT tokens
|
||||
- JWT authentication utilities for validating tokens
|
||||
|
||||
Server Functions:
|
||||
- jwt_obtain: Convert authenticated session to JWT tokens
|
||||
- jwt_refresh: Refresh tokens using a refresh token
|
||||
|
||||
Note: This module is purpose-built for mizan server functions.
|
||||
For Django Ninja API authentication, use mizan.jwt.security directly.
|
||||
"""
|
||||
|
||||
# Token utilities (re-exports from django_jwt_session)
|
||||
from mizan.jwt.tokens import (
|
||||
create_token_pair,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
refresh_tokens,
|
||||
TokenPair,
|
||||
TokenPayload,
|
||||
JWTUser,
|
||||
)
|
||||
|
||||
# Settings
|
||||
from mizan.jwt.settings import get_settings, JWTSettings
|
||||
|
||||
__all__ = [
|
||||
# Token utilities
|
||||
"create_token_pair",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_token",
|
||||
"refresh_tokens",
|
||||
"TokenPair",
|
||||
"TokenPayload",
|
||||
"JWTUser",
|
||||
# Settings
|
||||
"get_settings",
|
||||
"JWTSettings",
|
||||
]
|
||||
352
packages/mizan-rpc/adapters/django/src/mizan/export/__init__.py
Normal file
352
packages/mizan-rpc/adapters/django/src/mizan/export/__init__.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
mizan OpenAPI Schema Generator
|
||||
|
||||
Generates OpenAPI 3.0 compatible schema from registered server functions.
|
||||
Uses Django Ninja's battle-tested schema generation for robust Pydantic→OpenAPI conversion.
|
||||
|
||||
This schema is consumed by the frontend generator which uses openapi-typescript
|
||||
for robust type generation.
|
||||
|
||||
NOTE: Schema export is only available via management command for security.
|
||||
HTTP endpoint has been removed to prevent function enumeration.
|
||||
|
||||
Usage:
|
||||
python manage.py export_mizan_schema
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
# Lazy imports to avoid Django settings access at module load time
|
||||
# (asgi.py imports mizan before Django is fully configured)
|
||||
if TYPE_CHECKING:
|
||||
from django import forms
|
||||
from ninja import NinjaAPI
|
||||
|
||||
from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function
|
||||
|
||||
|
||||
__all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"]
|
||||
|
||||
|
||||
def _extract_form_fields(form_class: type) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract field definitions with constraints from a Django Form class.
|
||||
|
||||
Returns a list of field metadata suitable for Zod schema generation:
|
||||
- name: field name
|
||||
- zodType: base Zod type ("string", "number", "boolean", "array")
|
||||
- required: whether field is required
|
||||
- constraints: dict of Zod-compatible constraints
|
||||
|
||||
Constraints include:
|
||||
- min/max: for string length or number range
|
||||
- email/url: for format validation
|
||||
- regex: for pattern validation
|
||||
- choices: for enum validation
|
||||
"""
|
||||
try:
|
||||
# Try to instantiate form to get bound fields
|
||||
form = form_class()
|
||||
fields_dict = form.fields
|
||||
except TypeError:
|
||||
# Form requires extra args - use base_fields
|
||||
fields_dict = getattr(form_class, "base_fields", {})
|
||||
|
||||
result = []
|
||||
|
||||
for name, field in fields_dict.items():
|
||||
field_meta = _extract_field_constraints(name, field)
|
||||
result.append(field_meta)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_field_constraints(name: str, field: "forms.Field") -> dict[str, Any]:
|
||||
"""
|
||||
Extract Zod-compatible constraints from a single Django form field.
|
||||
"""
|
||||
from django import forms # Lazy import
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
"name": name,
|
||||
"required": field.required,
|
||||
"constraints": {},
|
||||
}
|
||||
|
||||
# Determine base Zod type
|
||||
if isinstance(field, forms.BooleanField):
|
||||
meta["zodType"] = "boolean"
|
||||
elif isinstance(field, (forms.IntegerField, forms.FloatField, forms.DecimalField)):
|
||||
meta["zodType"] = "number"
|
||||
if isinstance(field, forms.IntegerField):
|
||||
meta["constraints"]["int"] = True
|
||||
elif isinstance(field, forms.MultipleChoiceField):
|
||||
meta["zodType"] = "array"
|
||||
meta["constraints"]["items"] = "string"
|
||||
elif isinstance(field, forms.FileField):
|
||||
meta["zodType"] = "file"
|
||||
else:
|
||||
# Default to string (CharField, EmailField, URLField, etc.)
|
||||
meta["zodType"] = "string"
|
||||
|
||||
# Extract string constraints
|
||||
if hasattr(field, "max_length") and field.max_length is not None:
|
||||
meta["constraints"]["max"] = field.max_length
|
||||
if hasattr(field, "min_length") and field.min_length is not None:
|
||||
meta["constraints"]["min"] = field.min_length
|
||||
|
||||
# Extract number constraints
|
||||
if hasattr(field, "max_value") and field.max_value is not None:
|
||||
meta["constraints"]["max"] = field.max_value
|
||||
if hasattr(field, "min_value") and field.min_value is not None:
|
||||
meta["constraints"]["min"] = field.min_value
|
||||
|
||||
# Email/URL format
|
||||
if isinstance(field, forms.EmailField):
|
||||
meta["constraints"]["email"] = True
|
||||
elif isinstance(field, forms.URLField):
|
||||
meta["constraints"]["url"] = True
|
||||
|
||||
# Choices (for enum validation)
|
||||
if hasattr(field, "choices") and field.choices:
|
||||
# Extract choice values (not labels)
|
||||
choices = []
|
||||
for choice in field.choices:
|
||||
if isinstance(choice, (list, tuple)) and len(choice) >= 1:
|
||||
# Skip empty/blank choices
|
||||
if choice[0] != "":
|
||||
choices.append(str(choice[0]))
|
||||
else:
|
||||
choices.append(str(choice))
|
||||
if choices:
|
||||
meta["constraints"]["choices"] = choices
|
||||
|
||||
# Regex validators
|
||||
for validator in field.validators:
|
||||
if hasattr(validator, "regex"):
|
||||
# RegexValidator - extract pattern
|
||||
pattern = validator.regex.pattern
|
||||
meta["constraints"]["regex"] = pattern
|
||||
if hasattr(validator, "message"):
|
||||
meta["constraints"]["regexMessage"] = validator.message
|
||||
break # Only use first regex validator
|
||||
|
||||
return meta
|
||||
|
||||
|
||||
def snake_to_camel(name: str) -> str:
|
||||
"""Convert snake_case or dotted.name to camelCase.
|
||||
|
||||
Examples:
|
||||
- login -> login
|
||||
- login.schema -> loginSchema
|
||||
- activate_totp -> activateTotp
|
||||
- activate_totp.schema -> activateTotpSchema
|
||||
"""
|
||||
# Split on both underscores and dots
|
||||
components = re.split(r"[._]", name)
|
||||
return components[0] + "".join(x.title() for x in components[1:])
|
||||
|
||||
|
||||
def _register_schema_endpoint(
|
||||
api: "NinjaAPI",
|
||||
path: str,
|
||||
operation_id: str,
|
||||
summary: str,
|
||||
input_cls: type | None,
|
||||
output_cls: type,
|
||||
) -> None:
|
||||
"""
|
||||
Register a dummy endpoint on the API for schema generation.
|
||||
|
||||
Sets __annotations__ directly to avoid closure capture issues
|
||||
and exec() security concerns.
|
||||
"""
|
||||
if input_cls is not None:
|
||||
|
||||
def endpoint(request, data):
|
||||
pass
|
||||
|
||||
# Set annotations directly to the actual type objects (not strings)
|
||||
endpoint.__annotations__ = {"data": input_cls}
|
||||
else:
|
||||
|
||||
def endpoint(request):
|
||||
pass
|
||||
|
||||
# Register with Ninja
|
||||
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
|
||||
endpoint
|
||||
)
|
||||
|
||||
|
||||
def generate_openapi_schema() -> dict[str, Any]:
|
||||
"""
|
||||
Generate OpenAPI 3.0 schema for all registered mizan functions.
|
||||
|
||||
Uses Django Ninja's schema generation internally to ensure proper
|
||||
Pydantic→OpenAPI conversion (handling $refs, nested types, etc.).
|
||||
|
||||
Returns a complete OpenAPI document that can be processed by openapi-typescript.
|
||||
"""
|
||||
from ninja import NinjaAPI # Lazy import
|
||||
from pydantic import BaseModel, create_model # Lazy import
|
||||
|
||||
registry = get_registry()
|
||||
functions = registry.get("functions", {})
|
||||
|
||||
# Create a temporary Ninja API for schema generation only
|
||||
# This is NOT exposed as an HTTP endpoint - purely for leveraging Ninja's
|
||||
# battle-tested Pydantic→OpenAPI conversion
|
||||
schema_api = NinjaAPI(
|
||||
title="mizan Server Functions",
|
||||
version="1.0.0",
|
||||
description="Auto-generated schema for mizan server functions",
|
||||
docs_url=None, # No docs endpoint
|
||||
openapi_url=None, # No openapi endpoint
|
||||
)
|
||||
|
||||
function_metadata: list[dict[str, Any]] = []
|
||||
|
||||
# Store dynamically created classes so they persist for schema generation
|
||||
schema_classes: dict[str, type] = {}
|
||||
|
||||
for name, fn_class in functions.items():
|
||||
camel_name = snake_to_camel(name)
|
||||
meta = getattr(fn_class, "_meta", {})
|
||||
|
||||
# Get Input/Output classes
|
||||
input_cls = getattr(fn_class, "Input", None)
|
||||
output_cls = getattr(fn_class, "Output", None) or BaseModel
|
||||
|
||||
# Check if input_cls is a valid Pydantic model with fields
|
||||
has_input = (
|
||||
input_cls is not None
|
||||
and input_cls is not BaseModel
|
||||
and hasattr(input_cls, "model_fields")
|
||||
and bool(input_cls.model_fields)
|
||||
)
|
||||
|
||||
# Determine type names for metadata
|
||||
input_type_name = f"{camel_name}Input" if has_input else None
|
||||
output_type_name = f"{camel_name}Output"
|
||||
|
||||
# Create renamed Pydantic classes for cleaner schema names
|
||||
# Store them in schema_classes so they persist beyond loop scope
|
||||
# Uses create_model to avoid metaclass conflicts with custom base classes
|
||||
if has_input:
|
||||
schema_classes[input_type_name] = create_model(
|
||||
input_type_name, __base__=input_cls
|
||||
)
|
||||
schema_classes[output_type_name] = create_model(
|
||||
output_type_name, __base__=output_cls
|
||||
)
|
||||
|
||||
# Register endpoint using helper to avoid closure capture issues
|
||||
_register_schema_endpoint(
|
||||
api=schema_api,
|
||||
path=f"/mizan/{name}",
|
||||
operation_id=camel_name,
|
||||
summary=fn_class.__doc__ or f"Call {name}",
|
||||
input_cls=schema_classes.get(input_type_name),
|
||||
output_cls=schema_classes[output_type_name],
|
||||
)
|
||||
|
||||
# Collect function metadata for provider generation
|
||||
fn_meta_entry: dict[str, Any] = {
|
||||
"name": name,
|
||||
"camelName": camel_name,
|
||||
"hasInput": has_input,
|
||||
"inputType": input_type_name,
|
||||
"outputType": output_type_name,
|
||||
"transport": "websocket" if meta.get("websocket") else "http",
|
||||
"isContext": meta.get("context", False),
|
||||
# Form metadata
|
||||
"isForm": meta.get("form", False),
|
||||
"formName": meta.get("form_name"),
|
||||
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
|
||||
}
|
||||
|
||||
# Affects metadata (mutation invalidation)
|
||||
if meta.get("affects"):
|
||||
fn_meta_entry["affects"] = meta["affects"]
|
||||
|
||||
# For form schema functions, extract field definitions for Zod generation
|
||||
if meta.get("form") and meta.get("form_role") == "schema":
|
||||
form_class = meta.get("form_class")
|
||||
if form_class is not None:
|
||||
try:
|
||||
fn_meta_entry["formFields"] = _extract_form_fields(form_class)
|
||||
except Exception as e:
|
||||
# Don't fail schema generation if field extraction fails
|
||||
fn_meta_entry["formFields"] = []
|
||||
fn_meta_entry["formFieldsError"] = str(e)
|
||||
|
||||
function_metadata.append(fn_meta_entry)
|
||||
|
||||
# Get the OpenAPI schema from Ninja (handles all Pydantic conversion properly)
|
||||
schema = schema_api.get_openapi_schema(path_prefix="")
|
||||
|
||||
# Add custom extension with function metadata for provider generation
|
||||
schema["x-mizan-functions"] = function_metadata
|
||||
|
||||
# Add x-mizan-contexts: grouped context metadata with param elevation
|
||||
context_groups = get_context_groups()
|
||||
if context_groups:
|
||||
contexts_meta: dict[str, Any] = {}
|
||||
for ctx_name, fn_names in context_groups.items():
|
||||
# Analyze params across all functions in the context
|
||||
param_info: dict[str, dict[str, Any]] = {}
|
||||
for fn_name in fn_names:
|
||||
fn_cls = get_function(fn_name)
|
||||
if fn_cls is None:
|
||||
continue
|
||||
input_cls = getattr(fn_cls, "Input", None)
|
||||
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
|
||||
for field_name, field_info in input_cls.model_fields.items():
|
||||
if field_name not in param_info:
|
||||
annotation = field_info.annotation
|
||||
# Map Python types to JSON schema types
|
||||
type_name = "string"
|
||||
if annotation in (int,):
|
||||
type_name = "integer"
|
||||
elif annotation in (float,):
|
||||
type_name = "number"
|
||||
elif annotation in (bool,):
|
||||
type_name = "boolean"
|
||||
param_info[field_name] = {
|
||||
"type": type_name,
|
||||
"sharedBy": [],
|
||||
}
|
||||
param_info[field_name]["sharedBy"].append(fn_name)
|
||||
|
||||
# A param is required if ALL functions in the context declare it
|
||||
for p_name, p_meta in param_info.items():
|
||||
p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names)
|
||||
|
||||
contexts_meta[ctx_name] = {
|
||||
"functions": fn_names,
|
||||
"params": param_info,
|
||||
}
|
||||
schema["x-mizan-contexts"] = contexts_meta
|
||||
|
||||
# Add x-mizan metadata to each operation
|
||||
for fn_meta in function_metadata:
|
||||
path = f"/mizan/{fn_meta['name']}"
|
||||
if path in schema.get("paths", {}):
|
||||
schema["paths"][path]["post"]["x-mizan"] = {
|
||||
"transport": fn_meta["transport"],
|
||||
"isContext": fn_meta["isContext"],
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def generate_openapi_json(indent: int = 2) -> str:
|
||||
"""Generate OpenAPI schema as formatted JSON string."""
|
||||
schema = generate_openapi_schema()
|
||||
return json.dumps(schema, indent=indent)
|
||||
632
packages/mizan-rpc/adapters/django/src/mizan/forms/__init__.py
Normal file
632
packages/mizan-rpc/adapters/django/src/mizan/forms/__init__.py
Normal file
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
mizanFormMixin - Turn Django Forms into server functions.
|
||||
|
||||
This mixin transforms any Django Form into mizan server functions,
|
||||
preserving full Django Form functionality (validation, widgets, ModelChoiceField, etc.)
|
||||
while exposing them through the unified server function API.
|
||||
|
||||
Usage:
|
||||
from django import forms
|
||||
from mizan.forms import mizanFormMixin, mizanFormMeta
|
||||
|
||||
class ContactForm(mizanFormMixin, forms.Form):
|
||||
mizan = mizanFormMeta(
|
||||
name="contact",
|
||||
title="Contact Us",
|
||||
submit_label="Send",
|
||||
)
|
||||
|
||||
name = forms.CharField()
|
||||
email = forms.EmailField()
|
||||
message = forms.CharField(widget=forms.Textarea)
|
||||
|
||||
def on_submit_success(self, request):
|
||||
send_email(self.cleaned_data)
|
||||
return {"sent": True}
|
||||
|
||||
Auto-registers server functions:
|
||||
- contact.schema
|
||||
- contact.validate
|
||||
- contact.submit
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from django import forms
|
||||
from django.http import HttpRequest
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .schemas import FormValidation
|
||||
|
||||
|
||||
def _django_field_to_python_type(field: forms.Field) -> type:
|
||||
"""
|
||||
Map a Django form field to a Python type for Pydantic schema generation.
|
||||
|
||||
This provides TypeScript with proper field types instead of generic `any`.
|
||||
"""
|
||||
# Handle common Django field types
|
||||
if isinstance(field, forms.BooleanField):
|
||||
return bool
|
||||
elif isinstance(field, forms.IntegerField):
|
||||
return int
|
||||
elif isinstance(field, forms.FloatField):
|
||||
return float
|
||||
elif isinstance(field, forms.DecimalField):
|
||||
return str # Decimals serialize as strings for precision
|
||||
elif isinstance(field, forms.DateTimeField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.DateField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.TimeField):
|
||||
return str # ISO format string
|
||||
elif isinstance(field, forms.JSONField):
|
||||
return dict | list | str | int | float | bool | None
|
||||
elif isinstance(field, forms.MultipleChoiceField):
|
||||
return list[str]
|
||||
elif isinstance(field, forms.FileField):
|
||||
return str # File path/name as string
|
||||
elif isinstance(field, forms.ImageField):
|
||||
return str # File path/name as string
|
||||
else:
|
||||
# Default to string (covers CharField, EmailField, URLField, etc.)
|
||||
return str
|
||||
|
||||
|
||||
def _create_form_input_schema(
|
||||
form_class: type[forms.BaseForm],
|
||||
schema_name: str,
|
||||
) -> type[BaseModel]:
|
||||
"""
|
||||
Create a Pydantic model from Django Form fields.
|
||||
|
||||
This generates a typed schema for the form's input data, giving TypeScript
|
||||
full LSP support (autocomplete, type checking) for form fields.
|
||||
|
||||
Args:
|
||||
form_class: Django Form class to introspect
|
||||
schema_name: Name for the generated Pydantic model (e.g., "ContactFormData")
|
||||
|
||||
Returns:
|
||||
A Pydantic BaseModel subclass with fields matching the form
|
||||
"""
|
||||
# Instantiate form without data to get field definitions
|
||||
try:
|
||||
form = form_class()
|
||||
except TypeError:
|
||||
# Form requires extra args (like request) - use form_class.base_fields instead
|
||||
fields_dict = getattr(form_class, "base_fields", {})
|
||||
else:
|
||||
fields_dict = form.fields
|
||||
|
||||
# Build Pydantic field definitions
|
||||
pydantic_fields: dict[str, Any] = {}
|
||||
|
||||
for field_name, field in fields_dict.items():
|
||||
python_type = _django_field_to_python_type(field)
|
||||
|
||||
# Optional fields (not required or has initial value)
|
||||
if not field.required:
|
||||
python_type = python_type | None
|
||||
default = None
|
||||
elif field.initial is not None:
|
||||
default = field.initial
|
||||
else:
|
||||
default = ... # Required field
|
||||
|
||||
pydantic_fields[field_name] = (python_type, default)
|
||||
|
||||
# Create the model with a unique name
|
||||
model = create_model(schema_name, **pydantic_fields)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class mizanFormMeta(BaseModel):
|
||||
"""
|
||||
Configuration for a mizan form.
|
||||
|
||||
This Pydantic model provides type-safe configuration with full LSP support,
|
||||
and serializes to JSON for the frontend schema.
|
||||
|
||||
Required:
|
||||
name: API identifier (e.g., "contact" → contact.schema, contact.validate, contact.submit)
|
||||
|
||||
Display options:
|
||||
title: Display title (default: derived from class name)
|
||||
subtitle: Display subtitle
|
||||
submit_label: Submit button text (default: "Submit")
|
||||
|
||||
Frontend behavior:
|
||||
live_validation: Enable live validation as user types (default: True)
|
||||
live_form_errors: Show form-level errors during live validation (default: False)
|
||||
refetch_schema_on_validate: Refetch schema on each validation - useful for
|
||||
dynamic choice fields (default: False)
|
||||
|
||||
Features:
|
||||
enable_formset: Generate formset endpoints (default: False)
|
||||
"""
|
||||
|
||||
# Required
|
||||
name: str
|
||||
|
||||
# Display
|
||||
title: str | None = None
|
||||
subtitle: str | None = None
|
||||
submit_label: str = "Submit"
|
||||
|
||||
# Frontend behavior
|
||||
live_validation: bool = True
|
||||
live_form_errors: bool = False
|
||||
refetch_schema_on_validate: bool = False
|
||||
|
||||
# Features
|
||||
enable_formset: bool = False
|
||||
|
||||
|
||||
class mizanFormMixin:
|
||||
"""
|
||||
Mixin that exposes a Django Form as mizan server functions.
|
||||
|
||||
Add this mixin to any Django Form class along with a `mizan` configuration:
|
||||
|
||||
class ContactForm(mizanFormMixin, forms.Form):
|
||||
mizan = mizanFormMeta(
|
||||
name="contact",
|
||||
title="Contact Us",
|
||||
)
|
||||
|
||||
name = forms.CharField()
|
||||
email = forms.EmailField()
|
||||
|
||||
def on_submit_success(self, request):
|
||||
return {"sent": True}
|
||||
|
||||
This auto-registers:
|
||||
- contact.schema - Get form field definitions
|
||||
- contact.validate - Validate form data
|
||||
- contact.submit - Submit form
|
||||
|
||||
Overridable methods:
|
||||
get_init_kwargs(cls, request) -> dict: Extra kwargs for form instantiation
|
||||
on_submit_success(self, request) -> dict | None: Handle successful submission
|
||||
on_submit_failure(self, request, errors) -> None: Handle failed submission
|
||||
"""
|
||||
|
||||
# Configuration - subclasses must define this
|
||||
mizan: ClassVar[mizanFormMeta]
|
||||
|
||||
# Track registered forms to avoid duplicate registration
|
||||
_mizan_registered: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
|
||||
"""
|
||||
Override to provide extra kwargs for form instantiation.
|
||||
|
||||
Common use: pass request or user to forms that need them.
|
||||
|
||||
Example:
|
||||
@classmethod
|
||||
def get_init_kwargs(cls, request):
|
||||
return {"request": request, "user": request.user}
|
||||
"""
|
||||
return {}
|
||||
|
||||
def on_submit_success(self, request: HttpRequest) -> dict | None:
|
||||
"""
|
||||
Called after successful form validation and submission.
|
||||
|
||||
Override to handle the form submission logic.
|
||||
Return a dict to include data in the response.
|
||||
|
||||
Example:
|
||||
def on_submit_success(self, request):
|
||||
self.save()
|
||||
return {"id": self.instance.pk}
|
||||
"""
|
||||
# Default: call save() if available
|
||||
if hasattr(self, "save"):
|
||||
result = self.save()
|
||||
# If save returns something serializable, include it
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
|
||||
def on_submit_failure(self, request: HttpRequest, errors: "FormValidation") -> None:
|
||||
"""
|
||||
Called after form validation fails.
|
||||
|
||||
Override to add custom error handling, logging, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Auto-register when a concrete form class is defined."""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# Only register concrete forms with mizan config defined
|
||||
if _is_concrete_mizan_form(cls):
|
||||
_register_form_as_server_functions(cls)
|
||||
|
||||
|
||||
def _is_concrete_mizan_form(cls: type) -> bool:
|
||||
"""
|
||||
Check if a class is a concrete mizan form ready for registration.
|
||||
|
||||
A form is concrete if:
|
||||
1. It has a `mizan` attribute that is a mizanFormMeta instance
|
||||
2. It inherits from Django's BaseForm
|
||||
3. It hasn't been registered yet (for this class definition)
|
||||
"""
|
||||
# Must have mizan config (check cls.__dict__ to avoid inheriting)
|
||||
mizan_config = cls.__dict__.get("mizan")
|
||||
if not isinstance(mizan_config, mizanFormMeta):
|
||||
return False
|
||||
|
||||
# Must be a Django form
|
||||
if not issubclass(cls, forms.BaseForm):
|
||||
return False
|
||||
|
||||
# Check if already registered (handle re-imports gracefully)
|
||||
if cls.__dict__.get("_mizan_registered", False):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _register_form_as_server_functions(form_class: type) -> None:
|
||||
"""
|
||||
Register a Django Form class as mizan server functions.
|
||||
|
||||
Creates and registers:
|
||||
- {name}.schema - Returns form field definitions
|
||||
- {name}.validate - Validates form data
|
||||
- {name}.submit - Validates and submits form
|
||||
|
||||
Each function gets a unique typed schema for better TypeScript LSP support.
|
||||
"""
|
||||
from .schemas import FormSchema, FormSubmitFail, FormSubmitPass, FormValidation
|
||||
from .schema_utils import build_form_schema
|
||||
from .validation_utils import validate_form_instance
|
||||
from mizan.setup.registry import register
|
||||
from mizan.client.function import ServerFunction
|
||||
|
||||
config: mizanFormMeta = form_class.mizan
|
||||
form_name = config.name
|
||||
|
||||
# Mark as registered
|
||||
form_class._mizan_registered = True
|
||||
|
||||
# Generate PascalCase name for schemas (e.g., "contact" -> "Contact")
|
||||
pascal_name = "".join(
|
||||
word.capitalize()
|
||||
for word in form_name.replace(".", "_").replace("-", "_").split("_")
|
||||
)
|
||||
|
||||
# NOTE: We cannot create FormDataSchema here because form fields aren't
|
||||
# populated yet during __init_subclass__. We use lazy creation instead.
|
||||
_form_data_schema_cache: dict[str, type[BaseModel]] = {}
|
||||
|
||||
def get_form_data_schema() -> type[BaseModel]:
|
||||
"""Lazily create the form data schema (form fields aren't available at registration time)."""
|
||||
if "schema" not in _form_data_schema_cache:
|
||||
_form_data_schema_cache["schema"] = _create_form_input_schema(
|
||||
form_class, f"{pascal_name}FormData"
|
||||
)
|
||||
return _form_data_schema_cache["schema"]
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Schema Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Schema input wraps the form data for pre-populating dynamic fields
|
||||
FormSchemaInput = create_model(
|
||||
f"{pascal_name}SchemaInput",
|
||||
data=(dict[str, Any], {}),
|
||||
)
|
||||
|
||||
class SchemaFunction(ServerFunction):
|
||||
Input = FormSchemaInput
|
||||
Output = FormSchema
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "schema",
|
||||
"form_class": form_class, # Store reference for schema generation
|
||||
}
|
||||
|
||||
def call(self, input) -> FormSchema:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
schema = build_form_schema(
|
||||
form_class,
|
||||
data=input.data if input else {},
|
||||
**init_kwargs,
|
||||
)
|
||||
# Override with mizanFormMeta values
|
||||
if config.title is not None:
|
||||
schema.title = config.title
|
||||
if config.subtitle is not None:
|
||||
schema.subtitle = config.subtitle
|
||||
schema.submit_label = config.submit_label
|
||||
# Behavior settings are nested in schema.meta
|
||||
schema.meta.live_validation = config.live_validation
|
||||
schema.meta.live_form_errors = config.live_form_errors
|
||||
schema.meta.refetch_schema_on_validate = config.refetch_schema_on_validate
|
||||
return schema
|
||||
|
||||
SchemaFunction.__name__ = f"{form_name}_schema"
|
||||
SchemaFunction.__qualname__ = f"{form_name}_schema"
|
||||
register(SchemaFunction, f"{form_name}.schema")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Validate Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Use generic dict input - form fields aren't available during __init_subclass__
|
||||
FormValidateInput = create_model(
|
||||
f"{pascal_name}ValidateInput",
|
||||
data=(dict[str, Any], ...),
|
||||
)
|
||||
|
||||
class ValidateFunction(ServerFunction):
|
||||
Input = FormValidateInput
|
||||
Output = FormValidation
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "validate",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormValidation:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
# Input data is already a dict
|
||||
data = input.data
|
||||
_, validation = validate_form_instance(
|
||||
form_class,
|
||||
data=data,
|
||||
files=None,
|
||||
**init_kwargs,
|
||||
)
|
||||
return validation
|
||||
|
||||
ValidateFunction.__name__ = f"{form_name}_validate"
|
||||
ValidateFunction.__qualname__ = f"{form_name}_validate"
|
||||
register(ValidateFunction, f"{form_name}.validate")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Submit Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class SubmitFunction(ServerFunction):
|
||||
"""
|
||||
Submit function handles both JSON and multipart/form-data.
|
||||
|
||||
The executor detects form functions and parses the request appropriately.
|
||||
"""
|
||||
|
||||
# Use dict for input - form fields unknown at registration time
|
||||
Input = None # Signals executor to pass raw dict
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "submit",
|
||||
"multipart": True, # Signal that this function accepts multipart
|
||||
}
|
||||
|
||||
def call(self, input) -> FormSubmitPass | FormSubmitFail:
|
||||
"""Execute form submission."""
|
||||
request = self.request
|
||||
|
||||
# Check if we have multipart data from executor
|
||||
if hasattr(request, "_mizan_form_data"):
|
||||
data = request._mizan_form_data
|
||||
files = request._mizan_form_files
|
||||
elif input is not None:
|
||||
# JSON input - already a dict
|
||||
data = input if isinstance(input, dict) else input.model_dump()
|
||||
files = None
|
||||
else:
|
||||
data = {}
|
||||
files = None
|
||||
|
||||
init_kwargs = form_class.get_init_kwargs(request)
|
||||
|
||||
# Create and validate form
|
||||
form, validation = validate_form_instance(
|
||||
form_class,
|
||||
data=data,
|
||||
files=files,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
if form.is_valid():
|
||||
# Call the form's on_submit_success
|
||||
result_data = form.on_submit_success(request)
|
||||
return FormSubmitPass(success=True, data=result_data)
|
||||
|
||||
# Call the form's on_submit_failure
|
||||
form.on_submit_failure(request, validation)
|
||||
return FormSubmitFail(success=False, errors=validation)
|
||||
|
||||
SubmitFunction.__name__ = f"{form_name}_submit"
|
||||
SubmitFunction.__qualname__ = f"{form_name}_submit"
|
||||
SubmitFunction.Output = FormSubmitPass # For schema generation
|
||||
register(SubmitFunction, f"{form_name}.submit")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Functions (if enabled)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
if config.enable_formset:
|
||||
_register_formset_functions(form_class, form_name)
|
||||
|
||||
|
||||
def _register_formset_functions(
|
||||
form_class: type,
|
||||
form_name: str,
|
||||
) -> None:
|
||||
"""Register formset server functions for a form."""
|
||||
from django.forms import formset_factory
|
||||
|
||||
from .schemas import (
|
||||
FormsetSchema,
|
||||
FormsetSubmitFail,
|
||||
FormsetSubmitPass,
|
||||
FormsetValidation,
|
||||
)
|
||||
from .schema_utils import build_form_schema
|
||||
from .validation_utils import build_formset_validation
|
||||
from .formset_utils import forms_to_formset_post_data
|
||||
from mizan.setup.registry import register
|
||||
from mizan.client.function import ServerFunction
|
||||
|
||||
formset_class = formset_factory(form_class)
|
||||
|
||||
# Generate PascalCase name for schemas
|
||||
pascal_name = "".join(
|
||||
word.capitalize()
|
||||
for word in form_name.replace(".", "_").replace("-", "_").split("_")
|
||||
)
|
||||
|
||||
# NOTE: We cannot create typed schemas here because form fields aren't
|
||||
# populated yet during __init_subclass__. We use generic dict inputs.
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Schema Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
FormsetSchemaInput = create_model(
|
||||
f"{pascal_name}FormsetSchemaInput",
|
||||
forms=(list[dict[str, Any]], []),
|
||||
)
|
||||
|
||||
class FormsetSchemaFunction(ServerFunction):
|
||||
Input = FormsetSchemaInput
|
||||
Output = FormsetSchema
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_schema",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetSchema:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
forms_data = input.forms if input else []
|
||||
|
||||
formset_data = forms_to_formset_post_data(forms_data)
|
||||
formset = formset_class(formset_data)
|
||||
|
||||
return FormsetSchema(
|
||||
forms=[
|
||||
build_form_schema(form_class, data=fd, **init_kwargs)
|
||||
for fd in forms_data
|
||||
],
|
||||
min_num=formset.min_num,
|
||||
max_num=formset.max_num,
|
||||
can_delete=formset.can_delete,
|
||||
can_order=formset.can_order,
|
||||
)
|
||||
|
||||
FormsetSchemaFunction.__name__ = f"{form_name}_formset_schema"
|
||||
register(FormsetSchemaFunction, f"{form_name}.formset.schema")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Validate Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Generic dict input - form fields aren't available during __init_subclass__
|
||||
FormsetValidateInput = create_model(
|
||||
f"{pascal_name}FormsetValidateInput",
|
||||
forms=(list[dict[str, Any]], ...),
|
||||
)
|
||||
|
||||
class FormsetValidateFunction(ServerFunction):
|
||||
Input = FormsetValidateInput
|
||||
Output = FormsetValidation
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_validate",
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetValidation:
|
||||
init_kwargs = form_class.get_init_kwargs(self.request)
|
||||
# Input.forms is already a list of dicts
|
||||
forms_data = input.forms
|
||||
|
||||
formset_data = forms_to_formset_post_data(forms_data)
|
||||
formset = formset_class(formset_data, form_kwargs=init_kwargs)
|
||||
|
||||
for form in formset:
|
||||
form.empty_permitted = False
|
||||
|
||||
return build_formset_validation(formset)
|
||||
|
||||
FormsetValidateFunction.__name__ = f"{form_name}_formset_validate"
|
||||
register(FormsetValidateFunction, f"{form_name}.formset.validate")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Formset Submit Function
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Generic dict input - form fields aren't available during __init_subclass__
|
||||
FormsetSubmitInput = create_model(
|
||||
f"{pascal_name}FormsetSubmitInput",
|
||||
forms=(list[dict[str, Any]], ...),
|
||||
)
|
||||
|
||||
class FormsetSubmitFunction(ServerFunction):
|
||||
Input = FormsetSubmitInput
|
||||
|
||||
_meta: ClassVar[dict] = {
|
||||
"form": True,
|
||||
"form_name": form_name,
|
||||
"form_role": "formset_submit",
|
||||
"multipart": True,
|
||||
}
|
||||
|
||||
def call(self, input) -> FormsetSubmitPass | FormsetSubmitFail:
|
||||
request = self.request
|
||||
init_kwargs = form_class.get_init_kwargs(request)
|
||||
|
||||
# Handle multipart vs JSON
|
||||
if hasattr(request, "_mizan_form_data"):
|
||||
post_data = request._mizan_form_data
|
||||
files = request._mizan_form_files
|
||||
elif input and hasattr(input, "forms"):
|
||||
# Input.forms is already a list of dicts
|
||||
forms_data = input.forms
|
||||
post_data = forms_to_formset_post_data(forms_data)
|
||||
files = None
|
||||
else:
|
||||
post_data = {}
|
||||
files = None
|
||||
|
||||
formset = formset_class(post_data, files=files, form_kwargs=init_kwargs)
|
||||
|
||||
if formset.is_valid():
|
||||
for form in formset.forms:
|
||||
if form.cleaned_data:
|
||||
form.on_submit_success(request)
|
||||
return FormsetSubmitPass(success=True)
|
||||
|
||||
validation = build_formset_validation(formset)
|
||||
# Call failure handler on each form
|
||||
for form in formset.forms:
|
||||
if hasattr(form, "on_submit_failure"):
|
||||
form.on_submit_failure(request, validation)
|
||||
|
||||
return FormsetSubmitFail(success=False, errors=validation)
|
||||
|
||||
FormsetSubmitFunction.__name__ = f"{form_name}_formset_submit"
|
||||
FormsetSubmitFunction.Output = FormsetSubmitPass
|
||||
register(FormsetSubmitFunction, f"{form_name}.formset.submit")
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
def forms_to_formset_post_data(forms_data: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a list of form dicts into Django formset-compatible POST data.
|
||||
"""
|
||||
formset_data: dict[str, Any] = {
|
||||
"form-TOTAL_FORMS": str(len(forms_data)),
|
||||
"form-INITIAL_FORMS": "0",
|
||||
}
|
||||
for i, form_data in enumerate(forms_data):
|
||||
formset_data.update(
|
||||
{f"form-{i}-{key}": value for key, value in form_data.items()}
|
||||
)
|
||||
return formset_data
|
||||
@@ -0,0 +1,187 @@
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
from django import forms
|
||||
from django.forms import Field
|
||||
|
||||
from .schemas import FieldChoice, FieldSchema, FormMeta, FormSchema
|
||||
|
||||
|
||||
def create_form_instance(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: Optional[dict] = None,
|
||||
files: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> forms.BaseForm:
|
||||
"""
|
||||
Create a form instance, gracefully handling kwargs that the form doesn't accept.
|
||||
|
||||
Some Django forms (like allauth's) accept `request` in __init__, others don't.
|
||||
This function tries with all kwargs first, then progressively removes kwargs
|
||||
that cause TypeErrors until instantiation succeeds.
|
||||
"""
|
||||
# Common kwargs that forms may or may not accept
|
||||
optional_kwargs = ['request', 'user', 'instance']
|
||||
|
||||
# Build init kwargs
|
||||
init_kwargs = dict(kwargs)
|
||||
if data is not None:
|
||||
init_kwargs['data'] = data
|
||||
if files is not None:
|
||||
init_kwargs['files'] = files
|
||||
|
||||
while True:
|
||||
try:
|
||||
return form_class(**init_kwargs)
|
||||
except TypeError as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Check if it's an unexpected keyword argument error
|
||||
if "unexpected keyword argument" not in error_msg:
|
||||
raise
|
||||
|
||||
# Find which kwarg caused the problem and remove it
|
||||
removed = False
|
||||
for kwarg in optional_kwargs:
|
||||
if f"'{kwarg}'" in error_msg and kwarg in init_kwargs:
|
||||
init_kwargs.pop(kwarg)
|
||||
removed = True
|
||||
break
|
||||
|
||||
# If we couldn't identify/remove the problematic kwarg, re-raise
|
||||
if not removed:
|
||||
raise
|
||||
|
||||
|
||||
def _get_choices(field: Field) -> Optional[list[FieldChoice]]:
|
||||
"""
|
||||
Extract choices from a field, handling ModelChoiceField properly.
|
||||
ModelChoiceField returns ModelChoiceIteratorValue which is not JSON serializable.
|
||||
"""
|
||||
if not hasattr(field, "choices"):
|
||||
return None
|
||||
|
||||
choices: list[FieldChoice] = []
|
||||
for raw_value, label in field.choices:
|
||||
value = getattr(
|
||||
raw_value, "value", raw_value
|
||||
) # ModelChoiceIteratorValue -> .value
|
||||
choices.append(FieldChoice(value=str(value), label=str(label)))
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def _get_initial(value: Any) -> Any:
|
||||
"""Convert initial value to JSON-serializable format."""
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "isoformat"):
|
||||
return value.isoformat()
|
||||
if hasattr(value, "pk"):
|
||||
return value.pk
|
||||
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||
return [item.pk if hasattr(item, "pk") else item for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _class_name_to_title(name: str) -> str:
|
||||
"""
|
||||
Convert a class name to a human-readable title.
|
||||
e.g., 'LoginForm' -> 'Login', 'ResetPasswordForm' -> 'Reset Password'
|
||||
"""
|
||||
# Remove 'Form' suffix
|
||||
name = re.sub(r"Form$", "", name)
|
||||
# Insert spaces before capital letters
|
||||
name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
return name
|
||||
|
||||
|
||||
def _class_name_to_slug(name: str) -> str:
|
||||
"""
|
||||
Convert a class name to a slug.
|
||||
e.g., 'LoginForm' -> 'login', 'ResetPasswordForm' -> 'reset_password'
|
||||
"""
|
||||
# Remove 'Form' suffix
|
||||
name = re.sub(r"Form$", "", name)
|
||||
# Insert underscores before capital letters and lowercase
|
||||
name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
|
||||
return name.lower()
|
||||
|
||||
|
||||
def build_form_schema(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> FormSchema:
|
||||
"""
|
||||
Produce a FormSchema for the given Django form class and (optional) data.
|
||||
|
||||
The form class can define metadata via an inner Meta class:
|
||||
|
||||
class MyForm(forms.Form):
|
||||
class Meta:
|
||||
form_name = "my_form"
|
||||
title = "My Form Title"
|
||||
subtitle = "Optional description"
|
||||
submit_label = "Submit"
|
||||
|
||||
# Frontend behavior (optional)
|
||||
refetch_schema_on_validate = False # Set True for dynamic choice fields
|
||||
live_validation = True # Set False to disable live validation
|
||||
live_form_errors = False # Set True to show form errors live
|
||||
|
||||
If not provided, sensible defaults are derived from the class name.
|
||||
"""
|
||||
form = create_form_instance(form_class, data=data, **kwargs)
|
||||
|
||||
# Extract metadata from form's Meta class
|
||||
form_meta = getattr(form_class, "Meta", None)
|
||||
|
||||
# Get form name (used as identifier)
|
||||
name = getattr(form_meta, "form_name", None)
|
||||
if name is None:
|
||||
name = _class_name_to_slug(form_class.__name__)
|
||||
|
||||
# Get title (human-readable heading)
|
||||
title = getattr(form_meta, "title", None)
|
||||
if title is None:
|
||||
title = _class_name_to_title(form_class.__name__)
|
||||
|
||||
# Get optional subtitle
|
||||
subtitle = getattr(form_meta, "subtitle", None)
|
||||
|
||||
# Get submit button label
|
||||
submit_label = getattr(form_meta, "submit_label", None)
|
||||
if submit_label is None:
|
||||
submit_label = "Submit"
|
||||
|
||||
# Build frontend behavior metadata
|
||||
frontend_meta = FormMeta(
|
||||
refetch_schema_on_validate=getattr(form_meta, "refetch_schema_on_validate", False),
|
||||
live_validation=getattr(form_meta, "live_validation", True),
|
||||
live_form_errors=getattr(form_meta, "live_form_errors", False),
|
||||
)
|
||||
|
||||
return FormSchema(
|
||||
name=name,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
submit_label=submit_label,
|
||||
fields=[
|
||||
FieldSchema(
|
||||
name=name,
|
||||
label=str(field.label or name.replace("_", " ").title()),
|
||||
type=getattr(field.widget, "input_type", "text"),
|
||||
widget=field.widget.__class__.__name__,
|
||||
required=field.required,
|
||||
disabled=field.disabled,
|
||||
help_text=str(field.help_text) if field.help_text else "",
|
||||
initial=_get_initial(field.initial),
|
||||
max_length=getattr(field, "max_length", None),
|
||||
min_length=getattr(field, "min_length", None),
|
||||
choices=_get_choices(field),
|
||||
)
|
||||
for name, field in form.fields.items()
|
||||
],
|
||||
meta=frontend_meta,
|
||||
)
|
||||
103
packages/mizan-rpc/adapters/django/src/mizan/forms/schemas.py
Normal file
103
packages/mizan-rpc/adapters/django/src/mizan/forms/schemas.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from ninja import Schema
|
||||
|
||||
|
||||
# Form metadata schema
|
||||
class FormMeta(Schema):
|
||||
"""
|
||||
Metadata controlling frontend form behavior.
|
||||
|
||||
Attributes:
|
||||
refetch_schema_on_validate: If True, frontend should refetch schema on each
|
||||
validation (useful for dynamic choice fields). Default False.
|
||||
live_validation: If False, frontend should disable live validation entirely.
|
||||
Useful for sensitive forms like login. Default True.
|
||||
live_form_errors: If True, show form-level errors during live validation.
|
||||
Form errors are things like "Invalid credentials" vs field errors like
|
||||
"This field is required". Default False for security.
|
||||
"""
|
||||
refetch_schema_on_validate: bool = False
|
||||
live_validation: bool = True
|
||||
live_form_errors: bool = False
|
||||
|
||||
|
||||
# Field-level schemas
|
||||
class FieldChoice(Schema):
|
||||
value: str
|
||||
label: str
|
||||
|
||||
|
||||
class FieldError(Schema):
|
||||
message: str
|
||||
code: Optional[str]
|
||||
|
||||
|
||||
class FieldErrorList(Schema):
|
||||
field: str
|
||||
errors: list[FieldError]
|
||||
|
||||
|
||||
class FieldSchema(Schema):
|
||||
name: str
|
||||
label: str
|
||||
type: str
|
||||
widget: str
|
||||
required: bool
|
||||
disabled: bool
|
||||
help_text: str
|
||||
initial: Any
|
||||
max_length: Optional[int]
|
||||
min_length: Optional[int]
|
||||
choices: Optional[list[FieldChoice]]
|
||||
|
||||
|
||||
# Form-level schemas
|
||||
class FormSchema(Schema):
|
||||
"""Schema returned by /schema endpoint with form metadata and fields."""
|
||||
# Form metadata
|
||||
name: str
|
||||
title: str
|
||||
subtitle: Optional[str]
|
||||
submit_label: str
|
||||
# Fields
|
||||
fields: list[FieldSchema]
|
||||
# Frontend behavior metadata
|
||||
meta: FormMeta = FormMeta()
|
||||
|
||||
|
||||
class FormValidation(Schema):
|
||||
errors: list[FieldErrorList]
|
||||
|
||||
|
||||
class FormSubmitPass(Schema):
|
||||
success: bool
|
||||
data: Optional[dict] = None
|
||||
|
||||
|
||||
class FormSubmitFail(Schema):
|
||||
success: bool
|
||||
errors: FormValidation
|
||||
|
||||
|
||||
# Formset-level schemas
|
||||
class FormsetSchema(Schema):
|
||||
forms: list[FormSchema]
|
||||
min_num: int
|
||||
max_num: int
|
||||
can_delete: bool
|
||||
can_order: bool
|
||||
|
||||
|
||||
class FormsetValidation(Schema):
|
||||
general: list[str]
|
||||
per_form: list[FormValidation]
|
||||
|
||||
|
||||
class FormsetSubmitPass(Schema):
|
||||
success: bool
|
||||
|
||||
|
||||
class FormsetSubmitFail(Schema):
|
||||
success: bool
|
||||
errors: FormsetValidation
|
||||
@@ -0,0 +1,72 @@
|
||||
from typing import Any
|
||||
|
||||
from django import forms
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from django.utils.datastructures import MultiValueDict
|
||||
|
||||
from .schemas import (
|
||||
FieldError,
|
||||
FieldErrorList,
|
||||
FormValidation,
|
||||
FormsetValidation,
|
||||
)
|
||||
from .schema_utils import create_form_instance
|
||||
|
||||
|
||||
def validate_form_instance(
|
||||
form_class: type[forms.BaseForm],
|
||||
data: dict,
|
||||
files: MultiValueDict[str, UploadedFile] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[forms.BaseForm, FormValidation]:
|
||||
"""
|
||||
Build a form instance and return (form, structured_validation_errors).
|
||||
"""
|
||||
form = create_form_instance(form_class, data=data, files=files, initial=data, **kwargs)
|
||||
|
||||
# Run validation
|
||||
form.is_valid()
|
||||
|
||||
validation = FormValidation(
|
||||
errors=[
|
||||
FieldErrorList(
|
||||
field=field_name,
|
||||
errors=[
|
||||
FieldError(
|
||||
message=str(e.message) if hasattr(e, 'message') else str(e),
|
||||
code=getattr(e, "code", None),
|
||||
)
|
||||
for e in field_errors.as_data()
|
||||
],
|
||||
)
|
||||
for field_name, field_errors in form.errors.items()
|
||||
]
|
||||
)
|
||||
return form, validation
|
||||
|
||||
|
||||
def build_formset_validation(formset: forms.BaseFormSet) -> FormsetValidation:
|
||||
"""
|
||||
Turn a Django formset into a FormsetValidation structure.
|
||||
"""
|
||||
return FormsetValidation(
|
||||
general=[str(e) if e else "" for e in formset.non_form_errors()],
|
||||
per_form=[
|
||||
FormValidation(
|
||||
errors=[
|
||||
FieldErrorList(
|
||||
field=field_name,
|
||||
errors=[
|
||||
FieldError(
|
||||
message=str(e.message) if hasattr(e, 'message') else str(e),
|
||||
code=getattr(e, "code", None),
|
||||
)
|
||||
for e in field_errors.as_data()
|
||||
],
|
||||
)
|
||||
for field_name, field_errors in form.errors.items()
|
||||
]
|
||||
)
|
||||
for form in formset
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
mizan Allauth Integration
|
||||
|
||||
Backend support for django-allauth with mizan server functions.
|
||||
|
||||
Provides:
|
||||
- Auth contexts (auth_status, user) - required by frontend allauth module
|
||||
- Allauth form wrappers - expose allauth forms as server functions
|
||||
|
||||
Usage:
|
||||
# In your app's apps.py
|
||||
class MyAppConfig(AppConfig):
|
||||
def ready(self):
|
||||
import mizan.allauth.forms # noqa - registers forms
|
||||
import mizan.allauth.contexts # noqa - registers contexts
|
||||
"""
|
||||
|
||||
from .contexts import auth_status, user, AuthStatusOutput, UserOutput
|
||||
|
||||
__all__ = [
|
||||
"auth_status",
|
||||
"user",
|
||||
"AuthStatusOutput",
|
||||
"UserOutput",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user