From 5312e0d64a61f4e6c70bc7c830998a9acd04972f Mon Sep 17 00:00:00 2001 From: Aidan Kim Date: Tue, 19 Nov 2024 10:23:15 -0500 Subject: [PATCH] Store JWT in a HTTP only cookie and add methods/endpoints for refresh tokens --- backend/api/authentication.py | 60 ++++++++++++++++++++++++++++------- backend/main.py | 4 +-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/backend/api/authentication.py b/backend/api/authentication.py index 051a155..e586a8e 100644 --- a/backend/api/authentication.py +++ b/backend/api/authentication.py @@ -1,11 +1,11 @@ import jwt -from datetime import datetime, timedelta -from fastapi import Depends, HTTPException, status, APIRouter +from datetime import datetime, timedelta, timezone +from fastapi import Cookie, Depends, HTTPException, Response, status, APIRouter from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from backend.env import getenv from backend.models.user_model import User from ..services import UserService -auth_router = APIRouter() api = APIRouter(prefix="/api/authentication") openapi_tags = { @@ -13,12 +13,19 @@ openapi_tags = { "description": "Authentication of users and distributes bearer tokens", } -JWT_SECRET = "Sample Secret" -JWT_ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +JWT_SECRET = getenv("JWT_SECRET") +JWT_ALGORITHM = getenv("JWT_ALGORITHM") +ACCESS_TOKEN_EXPIRE_MINUTES = getenv("ACCESS_TOKEN_EXPIRE_MINUTES") +REFRESH_TOKEN_EXPIRE_DAYS = getenv("REFRESH_TOKEN_EXPIRE_DAYS") def create_access_token(user_id: str) -> str: - expiration = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expiration = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + payload = {"user_id": user_id, "exp": expiration} + token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + return token + +def create_refresh_token(user_id: str) -> str: + expiration = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) payload = {"user_id": user_id, "exp": expiration} token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) return token @@ -49,8 +56,8 @@ def registered_user( detail="Invalid token" ) -@auth_router.post("/api/authentication", tags=["Authentication"]) -def return_bearer_token(user_id: str, user_service: UserService = Depends()): +@api.post("", include_in_schema=False, tags=["Authentication"]) +def return_bearer_token(user_id: str, response: Response, user_service: UserService = Depends()): user = user_service.get_user_by_id(user_id) if not user: raise HTTPException( @@ -58,9 +65,38 @@ def return_bearer_token(user_id: str, user_service: UserService = Depends()): detail="Invalid user ID" ) - access_token = create_access_token(user_id=user_id) - return {"access_token": access_token} + access_token = create_access_token(user_id) + refresh_token = create_refresh_token(user_id) -@auth_router.get("/api/authentication", tags=["Authentication"]) + response.set_cookie( + key="access_token", value=access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60 + ) + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True, secure=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 + ) + + return {"message": "Tokens set as cookies"} + +@api.post("/refresh", tags=["Authentication"]) +def refresh_access_token(response: Response, refresh_token: str = Depends(Cookie(None))): + try: + payload = jwt.decode(refresh_token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + user_id = payload.get("user_id") + if not user_id: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") + + new_access_token = create_access_token(user_id) + + response.set_cookie( + key="access_token", value=new_access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60 + ) + return {"message": "Access token refreshed"} + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired") + except jwt.PyJWTError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") + +@api.get("", include_in_schema=False, tags=["Authentication"]) def get_user_id(user_service: UserService = Depends()): return user_service.all() \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 99a1728..4330afd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -26,13 +26,11 @@ app = FastAPI( app.add_middleware(GZipMiddleware) -feature_apis = [user, health, service, resource, tag] +feature_apis = [user, health, service, resource, tag, authentication] for feature_api in feature_apis: app.include_router(feature_api.api) -app.include_router(authentication.auth_router) - # Add application-wide exception handling middleware for commonly encountered API Exceptions @app.exception_handler(Exception) def permission_exception_handler(request: Request, e: Exception):