Store JWT in a HTTP only cookie and add methods/endpoints for refresh tokens

This commit is contained in:
Aidan Kim 2024-11-19 10:23:15 -05:00
parent 61dcfde469
commit 5312e0d64a
2 changed files with 49 additions and 15 deletions

View File

@ -1,11 +1,11 @@
import jwt import jwt
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status, APIRouter from fastapi import Cookie, Depends, HTTPException, Response, status, APIRouter
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from backend.env import getenv
from backend.models.user_model import User from backend.models.user_model import User
from ..services import UserService from ..services import UserService
auth_router = APIRouter()
api = APIRouter(prefix="/api/authentication") api = APIRouter(prefix="/api/authentication")
openapi_tags = { openapi_tags = {
@ -13,12 +13,19 @@ openapi_tags = {
"description": "Authentication of users and distributes bearer tokens", "description": "Authentication of users and distributes bearer tokens",
} }
JWT_SECRET = "Sample Secret" JWT_SECRET = getenv("JWT_SECRET")
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = getenv("JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
REFRESH_TOKEN_EXPIRE_DAYS = getenv("REFRESH_TOKEN_EXPIRE_DAYS")
def create_access_token(user_id: str) -> str: def create_access_token(user_id: str) -> str:
expiration = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expiration = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {"user_id": user_id, "exp": expiration}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token
def create_refresh_token(user_id: str) -> str:
expiration = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
payload = {"user_id": user_id, "exp": expiration} payload = {"user_id": user_id, "exp": expiration}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token return token
@ -49,8 +56,8 @@ def registered_user(
detail="Invalid token" detail="Invalid token"
) )
@auth_router.post("/api/authentication", tags=["Authentication"]) @api.post("", include_in_schema=False, tags=["Authentication"])
def return_bearer_token(user_id: str, user_service: UserService = Depends()): def return_bearer_token(user_id: str, response: Response, user_service: UserService = Depends()):
user = user_service.get_user_by_id(user_id) user = user_service.get_user_by_id(user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(
@ -58,9 +65,38 @@ def return_bearer_token(user_id: str, user_service: UserService = Depends()):
detail="Invalid user ID" detail="Invalid user ID"
) )
access_token = create_access_token(user_id=user_id) access_token = create_access_token(user_id)
return {"access_token": access_token} refresh_token = create_refresh_token(user_id)
@auth_router.get("/api/authentication", tags=["Authentication"]) response.set_cookie(
key="access_token", value=access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
)
return {"message": "Tokens set as cookies"}
@api.post("/refresh", tags=["Authentication"])
def refresh_access_token(response: Response, refresh_token: str = Depends(Cookie(None))):
try:
payload = jwt.decode(refresh_token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
new_access_token = create_access_token(user_id)
response.set_cookie(
key="access_token", value=new_access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
return {"message": "Access token refreshed"}
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired")
except jwt.PyJWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
@api.get("", include_in_schema=False, tags=["Authentication"])
def get_user_id(user_service: UserService = Depends()): def get_user_id(user_service: UserService = Depends()):
return user_service.all() return user_service.all()

View File

@ -26,13 +26,11 @@ app = FastAPI(
app.add_middleware(GZipMiddleware) app.add_middleware(GZipMiddleware)
feature_apis = [user, health, service, resource, tag] feature_apis = [user, health, service, resource, tag, authentication]
for feature_api in feature_apis: for feature_api in feature_apis:
app.include_router(feature_api.api) app.include_router(feature_api.api)
app.include_router(authentication.auth_router)
# Add application-wide exception handling middleware for commonly encountered API Exceptions # Add application-wide exception handling middleware for commonly encountered API Exceptions
@app.exception_handler(Exception) @app.exception_handler(Exception)
def permission_exception_handler(request: Request, e: Exception): def permission_exception_handler(request: Request, e: Exception):