compass/backend/api/authentication.py

102 lines
3.9 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Cookie, Depends, HTTPException, Response, status, APIRouter
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from backend.env import getenv
from backend.models.user_model import User
from ..services import UserService
# Router for this module; every route below is registered under this prefix.
api = APIRouter(prefix="/api/authentication")

# Tag metadata (presumably consumed by the app's OpenAPI setup); the "name"
# value matches the tags=["Authentication"] used on the routes in this module.
openapi_tags = dict(
    name="Authentication",
    description="Authentication of users and distributes bearer tokens",
)
# Key and algorithm handed to jwt.encode / jwt.decode throughout this module.
JWT_SECRET = getenv("JWT_SECRET")
JWT_ALGORITHM = getenv("JWT_ALGORITHM")

# Token lifetimes. These are used as numbers below (timedelta arguments and
# cookie max_age arithmetic), so coerce them to int once here. Without the
# coercion a string value would make timedelta() raise TypeError and would turn
# `value * 60` into string repetition.
# NOTE(review): assumes getenv returns the raw environment string holding a
# whole number -- confirm against backend.env.
ACCESS_TOKEN_EXPIRE_MINUTES = int(getenv("ACCESS_TOKEN_EXPIRE_MINUTES"))
REFRESH_TOKEN_EXPIRE_DAYS = int(getenv("REFRESH_TOKEN_EXPIRE_DAYS"))
def create_access_token(user_id: str) -> str:
    """Encode a JWT for ``user_id`` that expires after the access-token lifetime.

    The payload holds a ``user_id`` claim and an ``exp`` claim set
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes past the current UTC time. It is
    encoded with ``JWT_SECRET`` and ``JWT_ALGORITHM``.

    Args:
        user_id: Identifier stored in the token's ``user_id`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "user_id": user_id,
        "exp": datetime.now(timezone.utc) + lifetime,
    }
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
def create_refresh_token(user_id: str) -> str:
    """Encode a JWT for ``user_id`` that expires after the refresh-token lifetime.

    Same payload shape as the access token (``user_id`` and ``exp``), but the
    expiry is ``REFRESH_TOKEN_EXPIRE_DAYS`` days past the current UTC time.

    Args:
        user_id: Identifier stored in the token's ``user_id`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return jwt.encode(
        {"user_id": user_id, "exp": expires_at},
        JWT_SECRET,
        algorithm=JWT_ALGORITHM,
    )
def registered_user(
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
    user_service: UserService = Depends(),
) -> User:
    """Resolve the request's bearer token to a ``User``.

    The token is decoded with ``JWT_SECRET`` / ``JWT_ALGORITHM``, its
    ``user_id`` claim is read, and the user is looked up through
    ``user_service.get_user_by_id``.

    Args:
        token: Bearer credentials extracted by ``HTTPBearer``.
        user_service: Service used to look the user up.

    Returns:
        The user returned by ``user_service.get_user_by_id``.

    Raises:
        HTTPException: 401 with detail ``"Token expired"``, ``"Invalid token"``
            or ``"User not found"``.
    """
    # Only the decode can raise JWT errors, so keep the try limited to it.
    try:
        payload = jwt.decode(token.credentials, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
        ) from err
    except jwt.PyJWTError as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        ) from err

    # A correctly signed token without a user_id claim is still unusable.
    # Reject it here, as the refresh endpoint does, instead of passing None
    # to the service lookup.
    user_id = payload.get("user_id")
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
        )

    user = user_service.get_user_by_id(user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    return user
@api.post("", include_in_schema=False, tags=["Authentication"])
def return_bearer_token(user_id: str, response: Response, user_service: UserService = Depends()):
    """Set access and refresh token cookies for an existing user.

    Looks ``user_id`` up through ``user_service.get_user_by_id``. When a user
    is found, an access token and a refresh token are created and attached to
    ``response`` as ``httponly`` + ``secure`` cookies. Each cookie's
    ``max_age`` is its token lifetime in seconds.

    NOTE(review): the only check visible here is that the user exists; no
    credential is verified before tokens are issued. Confirm this route is
    protected elsewhere or is development-only.

    Raises:
        HTTPException: 401 ``"Invalid user ID"`` when the lookup finds no user.
    """
    if not user_service.get_user_by_id(user_id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user ID"
        )

    # (cookie name, token value, lifetime in seconds), access token first.
    cookies = (
        ("access_token", create_access_token(user_id), ACCESS_TOKEN_EXPIRE_MINUTES * 60),
        ("refresh_token", create_refresh_token(user_id), REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60),
    )
    for cookie_name, cookie_value, lifetime_seconds in cookies:
        response.set_cookie(
            key=cookie_name,
            value=cookie_value,
            httponly=True,
            secure=True,
            max_age=lifetime_seconds,
        )

    return {"message": "Tokens set as cookies"}
@api.post("/refresh", tags=["Authentication"])
def refresh_access_token(response: Response, refresh_token: Optional[str] = Cookie(None)):
    """Issue a new access-token cookie from the ``refresh_token`` cookie.

    The cookie is declared with ``Cookie(None)`` directly. ``Cookie(...)`` is
    parameter metadata rather than a callable dependency, so it must not be
    wrapped in ``Depends``.

    Args:
        response: Response object that receives the new ``access_token`` cookie.
        refresh_token: Value of the ``refresh_token`` cookie, or ``None`` when
            the request carries no such cookie.

    Returns:
        ``{"message": "Access token refreshed"}`` on success.

    Raises:
        HTTPException: 401 ``"Refresh token expired"`` when the token is past
            its expiry. 401 ``"Invalid refresh token"`` when the cookie is
            missing, fails to decode, or has no ``user_id`` claim.
    """
    # Reject a missing cookie explicitly instead of letting the decoder fail on None.
    if not refresh_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    # Only the decode can raise JWT errors, so keep the try limited to it.
    try:
        payload = jwt.decode(refresh_token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expired"
        ) from err
    except jwt.PyJWTError as err:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from err

    user_id = payload.get("user_id")
    if not user_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

    new_access_token = create_access_token(user_id)
    response.set_cookie(
        key="access_token",
        value=new_access_token,
        httponly=True,
        secure=True,
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )
    return {"message": "Access token refreshed"}
@api.get("", include_in_schema=False, tags=["Authentication"])
def get_user_id(user_service: UserService = Depends()):
    """Return the result of ``user_service.all()`` unchanged.

    NOTE(review): despite the name, this does not return a single user ID, and
    no authentication dependency is declared on the route. Confirm that
    exposing it this way is intended.
    """
    result = user_service.all()
    return result