Store JWT in a HTTP only cookie and add methods/endpoints for refresh tokens

This commit is contained in:
Aidan Kim 2024-11-19 10:23:15 -05:00
parent 61dcfde469
commit 5312e0d64a
2 changed files with 49 additions and 15 deletions

View File

@ -1,11 +1,11 @@
import jwt
from datetime import datetime, timedelta
from fastapi import Depends, HTTPException, status, APIRouter
from datetime import datetime, timedelta, timezone
from fastapi import Cookie, Depends, HTTPException, Response, status, APIRouter
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from backend.env import getenv
from backend.models.user_model import User
from ..services import UserService
auth_router = APIRouter()
api = APIRouter(prefix="/api/authentication")
openapi_tags = {
@ -13,12 +13,19 @@ openapi_tags = {
"description": "Authentication of users and distributes bearer tokens",
}
JWT_SECRET = "Sample Secret"
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
JWT_SECRET = getenv("JWT_SECRET")
JWT_ALGORITHM = getenv("JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_MINUTES = getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
REFRESH_TOKEN_EXPIRE_DAYS = getenv("REFRESH_TOKEN_EXPIRE_DAYS")
def create_access_token(user_id: str) -> str:
expiration = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expiration = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {"user_id": user_id, "exp": expiration}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token
def create_refresh_token(user_id: str) -> str:
expiration = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
payload = {"user_id": user_id, "exp": expiration}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
return token
@ -49,8 +56,8 @@ def registered_user(
detail="Invalid token"
)
@auth_router.post("/api/authentication", tags=["Authentication"])
def return_bearer_token(user_id: str, user_service: UserService = Depends()):
@api.post("", include_in_schema=False, tags=["Authentication"])
def return_bearer_token(user_id: str, response: Response, user_service: UserService = Depends()):
user = user_service.get_user_by_id(user_id)
if not user:
raise HTTPException(
@ -58,9 +65,38 @@ def return_bearer_token(user_id: str, user_service: UserService = Depends()):
detail="Invalid user ID"
)
access_token = create_access_token(user_id=user_id)
return {"access_token": access_token}
access_token = create_access_token(user_id)
refresh_token = create_refresh_token(user_id)
@auth_router.get("/api/authentication", tags=["Authentication"])
response.set_cookie(
key="access_token", value=access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
)
return {"message": "Tokens set as cookies"}
@api.post("/refresh", tags=["Authentication"])
def refresh_access_token(response: Response, refresh_token: str = Depends(Cookie(None))):
try:
payload = jwt.decode(refresh_token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
new_access_token = create_access_token(user_id)
response.set_cookie(
key="access_token", value=new_access_token, httponly=True, secure=True, max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60
)
return {"message": "Access token refreshed"}
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired")
except jwt.PyJWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")
@api.get("", include_in_schema=False, tags=["Authentication"])
def get_user_id(user_service: UserService = Depends()):
return user_service.all()

View File

@ -26,13 +26,11 @@ app = FastAPI(
app.add_middleware(GZipMiddleware)
feature_apis = [user, health, service, resource, tag]
feature_apis = [user, health, service, resource, tag, authentication]
for feature_api in feature_apis:
app.include_router(feature_api.api)
app.include_router(authentication.auth_router)
# Add application-wide exception handling middleware for commonly encountered API Exceptions
@app.exception_handler(Exception)
def permission_exception_handler(request: Request, e: Exception):