Fix MWT security issues from expert review

Critical:
- Separate MIZAN_MWT_SECRET from MIZAN_CACHE_SECRET — compromising one
  no longer compromises the other (token forgery vs cache poisoning)
- Move kid from JWT payload to JOSE header per RFC 7515 — standard
  libraries use header kid for key selection before payload decode

High:
- Full SHA-256 pkey (64 chars) instead of truncated 16 — no reason to
  reduce collision resistance
- Add nbf (not-before) claim for clock skew protection
- Log warnings in _try_mwt_auth on missing secret and decode failures
  instead of silent swallow
- Rename _csrf_protect_unless_jwt to _csrf_protect_unless_token (accuracy)
- decode_mwt logs at DEBUG level on failures for observability

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-07 00:52:30 -04:00
parent d7ec13c43c
commit 54581d184f
5 changed files with 46 additions and 24 deletions

View File

@@ -470,12 +470,15 @@ def _try_mwt_auth(request: HttpRequest) -> bool:
try: try:
settings = get_settings() settings = get_settings()
if not settings.cache_secret: if not settings.mwt_secret:
logging.getLogger("mizan.mwt").warning(
"X-Mizan-Token header present but MIZAN_MWT_SECRET is not configured"
)
return False return False
from mizan.mwt import decode_mwt, MWTUser from mizan.mwt import decode_mwt, MWTUser
payload = decode_mwt(token, settings.cache_secret) payload = decode_mwt(token, settings.mwt_secret)
if payload is None: if payload is None:
return False return False
@@ -483,6 +486,9 @@ def _try_mwt_auth(request: HttpRequest) -> bool:
request._mizan_mwt_authenticated = True request._mizan_mwt_authenticated = True
return True return True
except Exception: except Exception:
logging.getLogger("mizan.mwt").warning(
"MWT authentication failed unexpectedly", exc_info=True
)
return False return False
@@ -536,7 +542,7 @@ def _has_jwt_header(request: HttpRequest) -> bool:
return auth_header.startswith("Bearer ") return auth_header.startswith("Bearer ")
def _csrf_protect_unless_jwt(view_func): def _csrf_protect_unless_token(view_func):
""" """
Decorator that applies CSRF protection unless token auth is used. Decorator that applies CSRF protection unless token auth is used.
@@ -574,7 +580,7 @@ def _csrf_protect_unless_jwt(view_func):
return wrapper return wrapper
@_csrf_protect_unless_jwt @_csrf_protect_unless_token
def function_call_view(request: HttpRequest) -> JsonResponse: def function_call_view(request: HttpRequest) -> JsonResponse:
""" """
Django view for handling function calls (HTTP fallback for WebSocket RPC). Django view for handling function calls (HTTP fallback for WebSocket RPC).

View File

@@ -132,10 +132,10 @@ def mwt_obtain(request: HttpRequest) -> MWTOutput:
from mizan.setup.settings import get_settings from mizan.setup.settings import get_settings
settings = get_settings() settings = get_settings()
if not settings.cache_secret: if not settings.mwt_secret:
raise ValueError( raise ValueError(
"MIZAN_CACHE_SECRET is not configured. MWT requires a signing secret." "MIZAN_MWT_SECRET is not configured. MWT requires a signing secret."
) )
token = create_mwt(user, settings.cache_secret, ttl=settings.mwt_ttl) token = create_mwt(user, settings.mwt_secret, ttl=settings.mwt_ttl)
return MWTOutput(token=token, expires_in=settings.mwt_ttl) return MWTOutput(token=token, expires_in=settings.mwt_ttl)

View File

@@ -6,8 +6,9 @@ traveling on the `X-Mizan-Token` header. It provides:
- `sub`: user_id for HMAC cache key derivation - `sub`: user_id for HMAC cache key derivation
- `pkey`: permission state hash for staleness detection - `pkey`: permission state hash for staleness detection
- `kid`: key ID for secret rotation - `kid`: key ID in the JOSE header (per RFC 7515) for secret rotation
- `aud`: audience binding to prevent cross-tenant replay - `aud`: audience binding to prevent cross-tenant replay
- `nbf`: not-before to handle clock skew
MWT is issued from an authenticated Django session. The app handles MWT is issued from an authenticated Django session. The app handles
authentication (session, social auth, etc.); Mizan issues MWT from authentication (session, social auth, etc.); Mizan issues MWT from
@@ -18,19 +19,22 @@ Usage:
from mizan.mwt import create_mwt, decode_mwt, MWTUser from mizan.mwt import create_mwt, decode_mwt, MWTUser
Configuration: Configuration:
MIZAN_CACHE_SECRET: signing key (shared with cache key derivation) MIZAN_MWT_SECRET: MWT signing key (separate from MIZAN_CACHE_SECRET)
MIZAN_MWT_TTL: token lifetime in seconds (default: 300) MIZAN_MWT_TTL: token lifetime in seconds (default: 300)
""" """
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import logging
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import jwt import jwt
logger = logging.getLogger("mizan.mwt")
@dataclass @dataclass
class MWTPayload: class MWTPayload:
@@ -38,8 +42,8 @@ class MWTPayload:
sub: str # user_id sub: str # user_id
staff: bool # is_staff staff: bool # is_staff
super: bool # is_superuser super: bool # is_superuser
pkey: str # permission state hash pkey: str # permission state hash (full SHA-256 hex)
kid: str # key ID kid: str # key ID (from JOSE header)
aud: str # audience aud: str # audience
iat: int # issued at iat: int # issued at
exp: int # expiration exp: int # expiration
@@ -78,13 +82,13 @@ def compute_permission_key(user: Any) -> str:
When the MWT expires and is refreshed, the new pkey reflects When the MWT expires and is refreshed, the new pkey reflects
any permission changes. The short TTL controls the staleness window. any permission changes. The short TTL controls the staleness window.
Returns a 16-character hex digest (SHA-256 truncated). Returns the full 64-character SHA-256 hex digest.
""" """
perms = sorted(user.get_all_permissions()) if hasattr(user, "get_all_permissions") else [] perms = sorted(user.get_all_permissions()) if hasattr(user, "get_all_permissions") else []
staff = "1" if getattr(user, "is_staff", False) else "0" staff = "1" if getattr(user, "is_staff", False) else "0"
superuser = "1" if getattr(user, "is_superuser", False) else "0" superuser = "1" if getattr(user, "is_superuser", False) else "0"
blob = f"{staff}:{superuser}:{','.join(perms)}" blob = f"{staff}:{superuser}:{','.join(perms)}"
return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] return hashlib.sha256(blob.encode("utf-8")).hexdigest()
def create_mwt( def create_mwt(
@@ -99,10 +103,10 @@ def create_mwt(
Args: Args:
user: Django user object (must have pk, is_staff, is_superuser). user: Django user object (must have pk, is_staff, is_superuser).
secret: MIZAN_CACHE_SECRET signing key. secret: MIZAN_MWT_SECRET signing key.
ttl: Token lifetime in seconds (default: 300 = 5 minutes). ttl: Token lifetime in seconds (default: 300 = 5 minutes).
audience: Audience claim for cross-tenant protection. audience: Audience claim for cross-tenant protection.
kid: Key ID for future secret rotation. kid: Key ID placed in JOSE header (per RFC 7515) for rotation.
Returns: Returns:
Encoded JWT string. Encoded JWT string.
@@ -114,11 +118,13 @@ def create_mwt(
"super": getattr(user, "is_superuser", False), "super": getattr(user, "is_superuser", False),
"pkey": compute_permission_key(user), "pkey": compute_permission_key(user),
"aud": audience, "aud": audience,
"kid": kid,
"iat": now, "iat": now,
"nbf": now,
"exp": now + ttl, "exp": now + ttl,
} }
return jwt.encode(payload, secret, algorithm="HS256") # kid goes in the JOSE header per RFC 7515, not the payload
headers = {"kid": kid}
return jwt.encode(payload, secret, algorithm="HS256", headers=headers)
def decode_mwt( def decode_mwt(
@@ -130,9 +136,13 @@ def decode_mwt(
Decode and validate an MWT. Decode and validate an MWT.
Returns MWTPayload on success, None on any failure (expired, invalid Returns MWTPayload on success, None on any failure (expired, invalid
signature, wrong audience, malformed). signature, wrong audience, not-yet-valid, malformed).
""" """
try: try:
# Decode header first to extract kid
unverified_header = jwt.get_unverified_header(token)
kid = unverified_header.get("kid", "v1")
data = jwt.decode( data = jwt.decode(
token, token,
secret, secret,
@@ -140,6 +150,7 @@ def decode_mwt(
audience=audience, audience=audience,
) )
except jwt.PyJWTError: except jwt.PyJWTError:
logger.debug("MWT decode failed", exc_info=True)
return None return None
try: try:
@@ -148,10 +159,11 @@ def decode_mwt(
staff=data.get("staff", False), staff=data.get("staff", False),
super=data.get("super", False), super=data.get("super", False),
pkey=data.get("pkey", ""), pkey=data.get("pkey", ""),
kid=data.get("kid", "v1"), kid=kid,
aud=audience, aud=audience,
iat=data["iat"], iat=data["iat"],
exp=data["exp"], exp=data["exp"],
) )
except (KeyError, TypeError): except (KeyError, TypeError):
logger.debug("MWT payload missing required claims", exc_info=True)
return None return None

View File

@@ -17,12 +17,15 @@ class mizanSettings:
# Whether to expose function names in DEBUG mode errors # Whether to expose function names in DEBUG mode errors
debug_expose_names: bool debug_expose_names: bool
# Cache signing secret (required when cache is enabled) # Cache HMAC signing secret (required when cache is enabled)
cache_secret: str | None cache_secret: str | None
# Redis URL for cache backend (None = cache disabled) # Redis URL for cache backend (None = cache disabled)
cache_redis_url: str | None cache_redis_url: str | None
# MWT signing secret (separate from cache secret for blast radius containment)
mwt_secret: str | None
# MWT token lifetime in seconds (default: 300 = 5 minutes) # MWT token lifetime in seconds (default: 300 = 5 minutes)
mwt_ttl: int mwt_ttl: int
@@ -41,6 +44,7 @@ def get_settings() -> mizanSettings:
debug_expose_names=getattr(django_settings, "mizan_DEBUG_EXPOSE_NAMES", True), debug_expose_names=getattr(django_settings, "mizan_DEBUG_EXPOSE_NAMES", True),
cache_secret=getattr(django_settings, "MIZAN_CACHE_SECRET", None), cache_secret=getattr(django_settings, "MIZAN_CACHE_SECRET", None),
cache_redis_url=getattr(django_settings, "MIZAN_CACHE_REDIS_URL", None), cache_redis_url=getattr(django_settings, "MIZAN_CACHE_REDIS_URL", None),
mwt_secret=getattr(django_settings, "MIZAN_MWT_SECRET", None),
mwt_ttl=getattr(django_settings, "MIZAN_MWT_TTL", 300), mwt_ttl=getattr(django_settings, "MIZAN_MWT_TTL", 300),
) )

View File

@@ -3283,7 +3283,7 @@ class MWTCreationTests(TestCase):
self.assertTrue(payload.staff) self.assertTrue(payload.staff)
self.assertFalse(payload.super) self.assertFalse(payload.super)
self.assertEqual(payload.kid, "v1") self.assertEqual(payload.kid, "v1")
self.assertEqual(len(payload.pkey), 16) self.assertEqual(len(payload.pkey), 64)
def test_decode_expired(self): def test_decode_expired(self):
"""Expired MWT returns None.""" """Expired MWT returns None."""
@@ -3323,7 +3323,7 @@ class MWTCreationTests(TestCase):
self.assertEqual(mwt_user.pk, 5) self.assertEqual(mwt_user.pk, 5)
self.assertTrue(mwt_user.is_authenticated) self.assertTrue(mwt_user.is_authenticated)
self.assertEqual(len(mwt_user.pkey), 16) self.assertEqual(len(mwt_user.pkey), 64)
class PermissionKeyTests(TestCase): class PermissionKeyTests(TestCase):
@@ -3408,7 +3408,7 @@ class MWTAuthIntegrationTests(TestCase):
request.META["HTTP_X_MIZAN_TOKEN"] = token request.META["HTTP_X_MIZAN_TOKEN"] = token
request.user = MagicMock(is_authenticated=False) request.user = MagicMock(is_authenticated=False)
with override_settings(MIZAN_CACHE_SECRET=self.SECRET): with override_settings(MIZAN_MWT_SECRET=self.SECRET):
from mizan.setup.settings import clear_settings_cache from mizan.setup.settings import clear_settings_cache
clear_settings_cache() clear_settings_cache()
result = _try_mwt_auth(request) result = _try_mwt_auth(request)
@@ -3421,7 +3421,7 @@ class MWTAuthIntegrationTests(TestCase):
"""Invalid X-Mizan-Token returns 401 on context fetch.""" """Invalid X-Mizan-Token returns 401 on context fetch."""
from django.test import override_settings from django.test import override_settings
with override_settings(MIZAN_CACHE_SECRET=self.SECRET): with override_settings(MIZAN_MWT_SECRET=self.SECRET):
from mizan.setup.settings import clear_settings_cache from mizan.setup.settings import clear_settings_cache
clear_settings_cache() clear_settings_cache()
response = self.client.get( response = self.client.get(