diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7581cbf --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "backend" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/backend/entities/__init__.py b/backend/entities/__init__.py index 0b82067..546a7b5 100644 --- a/backend/entities/__init__.py +++ b/backend/entities/__init__.py @@ -1,10 +1,9 @@ from .entity_base import EntityBase -from .sample_entity import SampleEntity from .tag_entity import TagEntity from .user_entity import UserEntity from .resource_entity import ResourceEntity from .resource_tag_entity import ResourceTagEntity from .service_entity import ServiceEntity from .service_tag_entity import ServiceTagEntity -from .program_enum import ProgramEnum -from .user_enum import RoleEnum +from .program_enum import Program_Enum +from .user_enum import Role_Enum diff --git a/backend/entities/program_enum.py b/backend/entities/program_enum.py index 3a207bd..c2823c2 100644 --- a/backend/entities/program_enum.py +++ b/backend/entities/program_enum.py @@ -1,10 +1,7 @@ -from sqlalchemy import Enum +from enum import Enum -class ProgramEnum(Enum): - ECONOMIC = "economic" - DOMESTIC = "domestic" - COMMUNITY = "community" - - def __init__(self): - super().__init__(name="program_enum") +class Program_Enum(Enum): + ECONOMIC = "ECONOMIC" + DOMESTIC = "DOMESTIC" + COMMUNITY = "COMMUNITY" diff --git a/backend/entities/resource_entity.py b/backend/entities/resource_entity.py index b38e625..cbe5d7c 100644 --- a/backend/entities/resource_entity.py +++ b/backend/entities/resource_entity.py @@ -1,7 +1,7 @@ """ Defines the table for storing resources """ # Import our mapped SQL types from SQLAlchemy -from sqlalchemy import Integer, String, DateTime +from sqlalchemy import Integer, String, DateTime, Enum # Import mapping capabilities from the SQLAlchemy ORM from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -14,7 +14,7 @@ from datetime import datetime # Import self for to model from typing import Self -from backend.entities.program_enum import ProgramEnum +from backend.entities.program_enum import Program_Enum class ResourceEntity(EntityBase): @@ -28,8 +28,7 @@ class ResourceEntity(EntityBase): name: Mapped[str] = mapped_column(String(32), nullable=False) summary: Mapped[str] = mapped_column(String(100), nullable=False) link: Mapped[str] = mapped_column(String, nullable=False) - program: Mapped[ProgramEnum] = mapped_column(ProgramEnum, nullable=False) - + program: Mapped[Program_Enum] = mapped_column(Enum(Program_Enum), nullable=False) # relationships resourceTags: Mapped[list["ResourceTagEntity"]] = relationship( back_populates="resource", cascade="all,delete" diff --git a/backend/entities/resource_tag_entity.py b/backend/entities/resource_tag_entity.py index f1de522..e6d863b 100644 --- a/backend/entities/resource_tag_entity.py +++ b/backend/entities/resource_tag_entity.py @@ -19,7 +19,7 @@ from typing import Self class ResourceTagEntity(EntityBase): # set table name to user in the database - __tablename__ = "resourceTag" + __tablename__ = "resource_tag" # set fields or 'columns' for the user table id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) diff --git a/backend/entities/sample_entity.py b/backend/entities/sample_entity.py deleted file mode 100644 index 5372899..0000000 --- a/backend/entities/sample_entity.py +++ /dev/null @@ -1,12 +0,0 @@ -from sqlalchemy import create_engine, Column, Integer, String -from sqlalchemy.orm import Mapped, mapped_column, relationship -from .entity_base import EntityBase - - -class SampleEntity(EntityBase): - __tablename__ = "persons" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - name: Mapped[str] = mapped_column(String, nullable=False) - age: Mapped[int] = mapped_column(Integer) - email: Mapped[str] = mapped_column(String, unique=True, nullable=False) diff --git a/backend/entities/service_entity.py b/backend/entities/service_entity.py index e927be7..b6f999a 100644 --- a/backend/entities/service_entity.py +++ b/backend/entities/service_entity.py @@ -13,18 +13,10 @@ from .entity_base import EntityBase from datetime import datetime # Import enums for Program -import enum +from .program_enum import Program_Enum from sqlalchemy import Enum -class ProgramEnum(enum.Enum): - """Determine program for Service""" - - DOMESTIC = "DOMESTIC" - ECONOMIC = "ECONOMIC" - COMMUNITY = "COMMUNITY" - - class ServiceEntity(EntityBase): # set table name @@ -36,7 +28,7 @@ class ServiceEntity(EntityBase): name: Mapped[str] = mapped_column(String(32), nullable=False) summary: Mapped[str] = mapped_column(String(100), nullable=False) requirements: Mapped[list[str]] = mapped_column(ARRAY(String)) - program: Mapped[ProgramEnum] = mapped_column(Enum(ProgramEnum), nullable=False) + program: Mapped[Program_Enum] = mapped_column(Enum(Program_Enum), nullable=False) # relationships serviceTags: Mapped[list["ServiceTagEntity"]] = relationship( diff --git a/backend/entities/service_tag_entity.py b/backend/entities/service_tag_entity.py index c1dbdc7..0f05738 100644 --- a/backend/entities/service_tag_entity.py +++ b/backend/entities/service_tag_entity.py @@ -13,7 +13,7 @@ from .entity_base import EntityBase class ServiceTagEntity(EntityBase): # set table name to user in the database - __tablename__ = "serviceTag" + __tablename__ = "service_tag" # set fields or 'columns' for the user table id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -21,5 +21,5 @@ class ServiceTagEntity(EntityBase): tagId: Mapped[int] = mapped_column(ForeignKey("tag.id")) # relationships - service: Mapped["ServiceEntity"] = relationship(back_populates="resourceTags") - tag: Mapped["TagEntity"] = relationship(back_populates="resourceTags") + service: Mapped["ServiceEntity"] = relationship(back_populates="serviceTags") + tag: Mapped["TagEntity"] = relationship(back_populates="serviceTags") diff --git a/backend/entities/tag_entity.py b/backend/entities/tag_entity.py index e61f1ee..0d1548b 100644 --- a/backend/entities/tag_entity.py +++ b/backend/entities/tag_entity.py @@ -12,6 +12,10 @@ from .entity_base import EntityBase # Import datetime for created_at type from datetime import datetime +from ..models.tag_model import Tag + +from typing import Self + class TagEntity(EntityBase): #set table name @@ -27,17 +31,17 @@ class TagEntity(EntityBase): serviceTags: Mapped[list["ServiceTagEntity"]] = relationship(back_populates="tag", cascade="all,delete") - """ + @classmethod def from_model(cls, model: Tag) -> Self: - + """ Create a user entity from model Args: model (User): the model to create the entity from Returns: self: The entity - + """ return cls( id=model.id, @@ -45,18 +49,17 @@ class TagEntity(EntityBase): ) def to_model(self) -> Tag: - + """ Create a user model from entity Returns: User: A User model for API usage - + """ return Tag( id=self.id, - content=self.id, + content=self.content, ) - """ diff --git a/backend/entities/user_entity.py b/backend/entities/user_entity.py index 8f82e4d..ebc7a87 100644 --- a/backend/entities/user_entity.py +++ b/backend/entities/user_entity.py @@ -1,24 +1,25 @@ """ Defines the table for storing users """ # Import our mapped SQL types from SQLAlchemy -from sqlalchemy import Integer, String, DateTime, ARRAY - +from sqlalchemy import Integer, String, DateTime, ARRAY, Enum # Import mapping capabilities from the SQLAlchemy ORM from sqlalchemy.orm import Mapped, mapped_column - # Import the EntityBase that we are extending from .entity_base import EntityBase - # Import datetime for created_at type from datetime import datetime - # Import enums for Role and Program -from backend.entities.program_enum import ProgramEnum -from .user_enum import RoleEnum +from .program_enum import Program_Enum +from .user_enum import Role_Enum + +# Import models for User methods +from ..models.user_model import User + +from typing import Self class UserEntity(EntityBase): @@ -33,35 +34,28 @@ class UserEntity(EntityBase): username: Mapped[str] = mapped_column( String(32), nullable=False, default="", unique=True ) - role: Mapped[RoleEnum] = mapped_column(RoleEnum, nullable=False) - username: Mapped[str] = mapped_column( - String(32), nullable=False, default="", unique=True - ) - role: Mapped[RoleEnum] = mapped_column(RoleEnum, nullable=False) + role: Mapped[Role_Enum] = mapped_column(Enum(Role_Enum), nullable=False) email: Mapped[str] = mapped_column(String(50), nullable=False, unique=True) - program: Mapped[list[ProgramEnum]] = mapped_column( - ARRAY(ProgramEnum), nullable=False - ) - program: Mapped[list[ProgramEnum]] = mapped_column( - ARRAY(ProgramEnum), nullable=False + program: Mapped[list[Program_Enum]] = mapped_column( + ARRAY(Enum(Program_Enum)), nullable=False ) experience: Mapped[int] = mapped_column(Integer, nullable=False) group: Mapped[str] = mapped_column(String(50)) - """ @classmethod def from_model(cls, model: User) -> Self: - + """ Create a user entity from model Args: model (User): the model to create the entity from Returns: self: The entity - + """ return cls( id=model.id, + created_at=model.created_at, username=model.username, role=model.role, email=model.email, @@ -71,20 +65,22 @@ class UserEntity(EntityBase): ) def to_model(self) -> User: - + """ + Create a user model from entity Returns: User: A User model for API usage - + + """ return User( id=self.id, - username=self.id, - role=self.role, + username=self.username, email=self.email, - program=self.program, experience=self.experience, group=self.group, + program=self.program, + role=self.role, + created_at=self.created_at, ) - """ diff --git a/backend/entities/user_enum.py b/backend/entities/user_enum.py index 99594ec..af5aba6 100644 --- a/backend/entities/user_enum.py +++ b/backend/entities/user_enum.py @@ -1,12 +1,9 @@ -from sqlalchemy import Enum +from enum import Enum -class RoleEnum(Enum): +class Role_Enum(Enum): """Determine role for User""" ADMIN = "ADMIN" EMPLOYEE = "EMPLOYEE" VOLUNTEER = "VOLUNTEER" - - def __init__(self): - super().__init__(name="role_enum") diff --git a/backend/models/enum_for_models.py b/backend/models/enum_for_models.py index 8e6cdfe..ce2eb30 100644 --- a/backend/models/enum_for_models.py +++ b/backend/models/enum_for_models.py @@ -1,8 +1,4 @@ -from pydantic import BaseModel, Field from enum import Enum -from typing import List -from datetime import datetime -from typing import Optional class ProgramTypeEnum(str, Enum): diff --git a/backend/models/user_model.py b/backend/models/user_model.py index c881d54..1ba4e18 100644 --- a/backend/models/user_model.py +++ b/backend/models/user_model.py @@ -12,6 +12,6 @@ class User(BaseModel): email: str = Field(..., description="The e-mail of the user") experience: int = Field(..., description="Years of Experience of the User") group: str - programtype: List[ProgramTypeEnum] - usertype: UserTypeEnum + program: List[ProgramTypeEnum] + role: UserTypeEnum created_at: Optional[datetime] diff --git a/backend/script/create_database.py b/backend/script/create_database.py index f969f0b..54babaf 100644 --- a/backend/script/create_database.py +++ b/backend/script/create_database.py @@ -6,9 +6,7 @@ engine = create_engine(_engine_str(database=""), echo=True) """Application-level SQLAlchemy database engine.""" with engine.connect() as connection: - connection.execute( - text("COMMIT") - ) + connection.execute(text("COMMIT")) database = getenv("POSTGRES_DATABASE") stmt = text(f"CREATE DATABASE {database}") - connection.execute(stmt) \ No newline at end of file + connection.execute(stmt) diff --git a/backend/script/delete_database.py b/backend/script/delete_database.py index 513aec4..707c63a 100644 --- a/backend/script/delete_database.py +++ b/backend/script/delete_database.py @@ -6,9 +6,7 @@ engine = create_engine(_engine_str(database=""), echo=True) """Application-level SQLAlchemy database engine.""" with engine.connect() as connection: - connection.execute( - text("COMMIT") - ) + connection.execute(text("COMMIT")) database = getenv("POSTGRES_DATABASE") stmt = text(f"DROP DATABASE IF EXISTS {database}") - connection.execute(stmt) \ No newline at end of file + connection.execute(stmt) diff --git a/backend/services/__init__.py b/backend/services/__init__.py index e69de29..4067973 100644 --- a/backend/services/__init__.py +++ b/backend/services/__init__.py @@ -0,0 +1,4 @@ +from .user import UserService +from .resource import ResourceService +from .tag import TagService +from .service import ServiceService \ No newline at end of file diff --git a/backend/services/resouce.py b/backend/services/resource.py similarity index 100% rename from backend/services/resouce.py rename to backend/services/resource.py diff --git a/backend/services/tag.py b/backend/services/tag.py index b66f1fe..dfc369a 100644 --- a/backend/services/tag.py +++ b/backend/services/tag.py @@ -1,9 +1,20 @@ from fastapi import Depends from ..database import db_session from sqlalchemy.orm import Session +from ..models.tag_model import Tag +from ..entities.tag_entity import TagEntity +from sqlalchemy import select class TagService: def __init__(self, session: Session = Depends(db_session)): self._session = session + + def all(self) -> list[Tag]: + """Returns a list of all Tags""" + + query = select(TagEntity) + entities = self._session.scalars(query).all() + + return [entity.to_model() for entity in entities] diff --git a/backend/services/user.py b/backend/services/user.py index 629e22b..8e0b541 100644 --- a/backend/services/user.py +++ b/backend/services/user.py @@ -1,9 +1,59 @@ from fastapi import Depends from ..database import db_session from sqlalchemy.orm import Session +from ..entities.user_entity import UserEntity +from ..models.user_model import User +from sqlalchemy import select class UserService: def __init__(self, session: Session = Depends(db_session)): self._session = session + + def get_user_by_id(self, id: int) -> User: + """ + Gets a user by id from the database + + Returns: A User Pydantic model + + """ + query = select(UserEntity).where(UserEntity.id == id) + user_entity: UserEntity | None = self._session.scalar(query) + + if user_entity is None: + raise Exception(f"No user found with matching id: {id}") + + return user_entity.to_model() + + def all(self) -> list[User]: + """ + Returns a list of all Users + + """ + query = select(UserEntity) + entities = self._session.scalars(query).all() + + return [entity.to_model() for entity in entities] + + def create(self, user: User) -> User: + """ + Creates a new User Entity and adds to database + + Args: User model + + Returns: User model + + """ + try: + user = self.get_user_by_id(user.id) + except: + # if does not exist, create new object + user_entity = UserEntity.from_model(user) + + # add new user to table + self._session.add(user_entity) + self._session.commit() + finally: + # return added object + return user diff --git a/backend/test/entities/conftest.py b/backend/test/conftest.py similarity index 78% rename from backend/test/entities/conftest.py rename to backend/test/conftest.py index 63e15a5..777b219 100644 --- a/backend/test/entities/conftest.py +++ b/backend/test/conftest.py @@ -4,14 +4,16 @@ import pytest from sqlalchemy import Engine, create_engine, text from sqlalchemy.orm import Session from sqlalchemy.exc import OperationalError +from .services import user_test_data, tag_test_data -from ...database import _engine_str -from ...env import getenv -from ... import entities +from ..database import _engine_str +from ..env import getenv +from .. import entities POSTGRES_DATABASE = f'{getenv("POSTGRES_DATABASE")}_test' POSTGRES_USER = getenv("POSTGRES_USER") + def reset_database(): engine = create_engine(_engine_str(database="")) with engine.connect() as connection: @@ -48,3 +50,11 @@ def session(test_engine: Engine): yield session finally: session.close() + + +@pytest.fixture(autouse=True) +def setup_insert_data_fixture(session: Session): + user_test_data.insert_fake_data(session) + tag_test_data.insert_fake_data(session) + session.commit() + yield diff --git a/backend/test/entities/tag_entity_test.py b/backend/test/entities/tag_entity_test.py index b2b2f7f..7f58b58 100644 --- a/backend/test/entities/tag_entity_test.py +++ b/backend/test/entities/tag_entity_test.py @@ -1,19 +1,4 @@ -""" Testing Tag Entity """ -from sqlalchemy import Engine -from ... import entities -from ...entities.tag_entity import TagEntity - - -def test_add_sample_data_tag(session: Engine): - - """Inserts a sample data point and verifies it is in the database""" - entity = TagEntity(content="Test tag") - session.add(entity) - session.commit() - data = session.get(TagEntity, 1) - assert data.id == 1 - assert data.content == "Test tag" \ No newline at end of file diff --git a/backend/test/entities/user_entity_test.py b/backend/test/entities/user_entity_test.py index 1f775ce..e69de29 100644 --- a/backend/test/entities/user_entity_test.py +++ b/backend/test/entities/user_entity_test.py @@ -1,24 +0,0 @@ -""" Testing User Entity """ - -from sqlalchemy import Engine -from ... import entities -from ...entities.user_entity import UserEntity -from ...entities.user_entity import RoleEnum -from ...entities.user_entity import ProgramEnum - -def test_add_sample_data_user(session: Engine): - - - """Inserts a sample data point and verifies it is in the database""" - entity = UserEntity(id=1, username="emmalynf", role=RoleEnum.ADMIN, email="efoster@unc.edu", program=[ProgramEnum.COMMUNITY, ProgramEnum.DOMESTIC, ProgramEnum.ECONOMIC], experience=10, group="group") - session.add(entity) - session.commit() - data = session.get(UserEntity, 1) - assert data.id == 1 - assert data.username == "emmalynf" - assert data.email == "efoster@unc.edu" - assert data.experience == 10 - assert data.role == RoleEnum.ADMIN - assert data.program == [ProgramEnum.COMMUNITY, ProgramEnum.DOMESTIC, ProgramEnum.ECONOMIC] - - \ No newline at end of file diff --git a/backend/test/services/__init__.py b/backend/test/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/test/services/fixtures.py b/backend/test/services/fixtures.py new file mode 100644 index 0000000..e6fd67b --- /dev/null +++ b/backend/test/services/fixtures.py @@ -0,0 +1,20 @@ +"""Fixtures used for testing the core services.""" + +import pytest +from unittest.mock import create_autospec +from sqlalchemy.orm import Session +from ...services import UserService +from ...services import TagService + + + + +@pytest.fixture() +def user_svc(session: Session): + """This fixture is used to test the UserService class""" + return UserService(session) + +@pytest.fixture() +def tag_svc(session: Session): + """This fixture is used to test the TagService class""" + return TagService(session) diff --git a/backend/test/services/tag_test.py b/backend/test/services/tag_test.py index e69de29..fe7597c 100644 --- a/backend/test/services/tag_test.py +++ b/backend/test/services/tag_test.py @@ -0,0 +1,14 @@ +"""Tests for the TagService class.""" + +# PyTest +import pytest +from ...services.tag import TagService +from .fixtures import tag_svc +from .tag_test_data import tag1, tag2, tag3 +from . import tag_test_data + + +def test_get_all(tag_svc: TagService): + """Test that all tags can be retrieved.""" + tags = tag_svc.all() + assert len(tags) == 3 \ No newline at end of file diff --git a/backend/test/services/tag_test_data.py b/backend/test/services/tag_test_data.py new file mode 100644 index 0000000..cb16e5c --- /dev/null +++ b/backend/test/services/tag_test_data.py @@ -0,0 +1,72 @@ +import pytest +from sqlalchemy.orm import Session +from ...models.tag_model import Tag + +from ...entities.tag_entity import TagEntity +from datetime import datetime + +tag1 = Tag(id=1, content="Tag 1", created_at=datetime.now()) + +tag2 = Tag(id=2, content="Tag 2", created_at=datetime.now()) + +tag3 = Tag(id=3, content="Tag 3", created_at=datetime.now()) + +tagToCreate = Tag(id=4, content="Tag 4", created_at=datetime.now()) + +tags = [tag1, tag2, tag3] + + +from sqlalchemy import text +from sqlalchemy.orm import Session, DeclarativeBase, InstrumentedAttribute + + +def reset_table_id_seq( + session: Session, + entity: type[DeclarativeBase], + entity_id_column: InstrumentedAttribute[int], + next_id: int, +) -> None: + """Reset the ID sequence of an entity table. + + Args: + session (Session) - A SQLAlchemy Session + entity (DeclarativeBase) - The SQLAlchemy Entity table to target + entity_id_column (MappedColumn) - The ID column (should be an int column) + next_id (int) - Where the next inserted, autogenerated ID should begin + + Returns: + None""" + table = entity.__table__ + id_column_name = entity_id_column.name + sql = text(f"ALTER SEQUENCe {table}_{id_column_name}_seq RESTART WITH {next_id}") + session.execute(sql) + + +def insert_fake_data(session: Session): + """Inserts fake organization data into the test session.""" + + global tags + + # Create entities for test organization data + entities = [] + for tag in tags: + entity = TagEntity.from_model(tag) + session.add(entity) + entities.append(entity) + + # Reset table IDs to prevent ID conflicts + reset_table_id_seq(session, TagEntity, TagEntity.id, len(tags) + 1) + + # Commit all changes + session.commit() + + +@pytest.fixture(autouse=True) +def fake_data_fixture(session: Session): + """Insert fake data the session automatically when test is run. + Note: + This function runs automatically due to the fixture property `autouse=True`. + """ + insert_fake_data(session) + session.commit() + yield diff --git a/backend/test/services/user_test.py b/backend/test/services/user_test.py index e69de29..34795a5 100644 --- a/backend/test/services/user_test.py +++ b/backend/test/services/user_test.py @@ -0,0 +1,46 @@ +"""Tests for the UserService class.""" + +# PyTest +import pytest + +from ...services import UserService +from .fixtures import user_svc +from ...models.enum_for_models import ProgramTypeEnum + +from .user_test_data import employee, volunteer, admin, newUser +from . import user_test_data + + +def test_create(user_svc: UserService): + """Test creating a user""" + user1 = user_svc.create(admin) + + print(user1) + assert user1 is not None + assert user1.id is not None + + +def test_create_id_exists(user_svc: UserService): + """Test creating a user with id conflict""" + user1 = user_svc.create(volunteer) + assert user1 is not None + assert user1.id is not None + + +def test_get_all(user_svc: UserService): + """Test that all users can be retrieved.""" + users = user_svc.all() + assert len(users) == 3 + + +def test_get_user_by_id(user_svc: UserService): + """Test getting a user by an id""" + user = user_svc.get_user_by_id(volunteer.id) + assert user is not None + assert user.id is not None + + +def test_get_user_by_id_nonexistent(user_svc: UserService): + """Test getting a user by id that does not exist""" + with pytest.raises(Exception): + user_svc.get_by_id(5) diff --git a/backend/test/services/user_test_data.py b/backend/test/services/user_test_data.py new file mode 100644 index 0000000..e050bda --- /dev/null +++ b/backend/test/services/user_test_data.py @@ -0,0 +1,118 @@ +import pytest +from sqlalchemy.orm import Session +from ...models.user_model import User + +# import model enums instead +from ...models.enum_for_models import UserTypeEnum, ProgramTypeEnum +from ...entities.user_entity import UserEntity +from datetime import datetime + + +programs = ProgramTypeEnum +roles = UserTypeEnum + +volunteer = User( + id=1, + username="volunteer", + email="volunteer@compass.com", + experience=1, + group="volunteers", + program=[programs.COMMUNITY], + created_at=datetime.now(), + role=UserTypeEnum.VOLUNTEER, +) + +employee = User( + id=2, + username="employee", + email="employee@compass.com", + experience=5, + group="employees", + program=[programs.DOMESTIC, programs.ECONOMIC], + created_at=datetime.now(), + role=roles.EMPLOYEE, +) + +admin = User( + id=3, + username="admin", + email="admin@compass.com", + experience=10, + group="admin", + program=[ + programs.ECONOMIC, + programs.DOMESTIC, + programs.COMMUNITY, + ], + created_at=datetime.now(), + role=roles.ADMIN, +) + +newUser = User( + id=4, + username="new", + email="new@compass.com", + experience=1, + group="volunteer", + program=[programs.ECONOMIC], + created_at=datetime.now(), + role=roles.VOLUNTEER, +) + +users = [volunteer, employee, admin] + + +from sqlalchemy import text +from sqlalchemy.orm import Session, DeclarativeBase, InstrumentedAttribute + + +def reset_table_id_seq( + session: Session, + entity: type[DeclarativeBase], + entity_id_column: InstrumentedAttribute[int], + next_id: int, +) -> None: + """Reset the ID sequence of an entity table. + + Args: + session (Session) - A SQLAlchemy Session + entity (DeclarativeBase) - The SQLAlchemy Entity table to target + entity_id_column (MappedColumn) - The ID column (should be an int column) + next_id (int) - Where the next inserted, autogenerated ID should begin + + Returns: + None""" + table = entity.__table__ + id_column_name = entity_id_column.name + sql = text(f"ALTER SEQUENCe {table}_{id_column_name}_seq RESTART WITH {next_id}") + session.execute(sql) + + +def insert_fake_data(session: Session): + """Inserts fake organization data into the test session.""" + + global users + + # Create entities for test organization data + entities = [] + for user in users: + entity = UserEntity.from_model(user) + session.add(entity) + entities.append(entity) + + # Reset table IDs to prevent ID conflicts + reset_table_id_seq(session, UserEntity, UserEntity.id, len(users) + 1) + + # Commit all changes + session.commit() + + +@pytest.fixture(autouse=True) +def fake_data_fixture(session: Session): + """Insert fake data the session automatically when test is run. + Note: + This function runs automatically due to the fixture property `autouse=True`. + """ + # insert_fake_data(session) + session.commit() + yield