feat: add mandatory tag support and update related functionality in tags and items

This commit is contained in:
Felix Zett 2025-08-31 22:25:11 +02:00
parent ff4d243532
commit 22a34d7ad8
5 changed files with 37 additions and 7 deletions

View file

@ -1,4 +1,3 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from uuid import UUID as UUID_t from uuid import UUID as UUID_t
@ -50,10 +49,11 @@ def render_name(name_template: str, start: Optional[date], end: Optional[date])
def items_for_trip(db: Session, user_id: UUID_t, selected_tag_ids: List[UUID_t]) -> List[models.Item]: def items_for_trip(db: Session, user_id: UUID_t, selected_tag_ids: List[UUID_t]) -> List[models.Item]:
# Items without tags (always) + items with any of the selected_tags # Items without tags (always) + items with any of the selected_tags,
# but: if an item has a mandatory tag, it is only included if at least one of its mandatory tags is selected.
q = ( q = (
db.query(models.Item) db.query(models.Item)
.options(joinedload(models.Item.tags)) # Load tags directly .options(joinedload(models.Item.tags))
.filter(models.Item.user_id == user_id) .filter(models.Item.user_id == user_id)
) )
items = q.all() items = q.all()
@ -61,9 +61,14 @@ def items_for_trip(db: Session, user_id: UUID_t, selected_tag_ids: List[UUID_t])
selected_set = set(selected_tag_ids) selected_set = set(selected_tag_ids)
result: List[models.Item] = [] result: List[models.Item] = []
for it in items: for it in items:
item_tag_ids = {tag.id for tag in it.tags} # Tag objects now item_tag_ids = {tag.id for tag in it.tags}
mandatory_tag_ids = {tag.id for tag in it.tags if getattr(tag, "mandatory", False)}
if not item_tag_ids: if not item_tag_ids:
result.append(it) result.append(it)
elif mandatory_tag_ids:
# Only include if at least one mandatory tag is selected
if selected_set & mandatory_tag_ids:
result.append(it)
elif selected_set & item_tag_ids: elif selected_set & item_tag_ids:
result.append(it) result.append(it)
return result return result

View file

@ -48,6 +48,7 @@ class Tag(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False)
name = Column(String, nullable=False) name = Column(String, nullable=False)
mandatory = Column(Boolean, nullable=False, default=False)
user = relationship("User", backref="tags") user = relationship("User", backref="tags")

View file

@ -28,7 +28,12 @@ def dev_seed(db: Session = Depends(get_db)):
.first() .first()
) )
if not tag: if not tag:
tag = models.Tag(id=uuid4(), user_id=user.id, name=name) tag = models.Tag(
id=uuid4(),
user_id=user.id,
name=name,
mandatory=(name == "sommer")
)
db.add(tag) db.add(tag)
db.flush() db.flush()
name_to_tag[name] = tag name_to_tag[name] = tag

View file

@ -1,4 +1,3 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID from uuid import UUID
@ -24,8 +23,27 @@ def create_tag(payload: TagCreate, db: Session = Depends(get_db)):
existing = db.query(models.Tag).filter(models.Tag.user_id == user.id, models.Tag.name == payload.name).first() existing = db.query(models.Tag).filter(models.Tag.user_id == user.id, models.Tag.name == payload.name).first()
if existing: if existing:
raise HTTPException(status_code=400, detail="Tag already exists") raise HTTPException(status_code=400, detail="Tag already exists")
tag = models.Tag(user_id=user.id, name=payload.name) tag = models.Tag(user_id=user.id, name=payload.name, mandatory=payload.mandatory)
db.add(tag) db.add(tag)
db.commit() db.commit()
db.refresh(tag) db.refresh(tag)
return tag return tag
@router.put("/{tag_id}", response_model=TagOut)
def update_tag(tag_id: UUID, payload: TagCreate, db: Session = Depends(get_db)):
tag = db.get(models.Tag, tag_id)
if not tag:
raise HTTPException(status_code=404, detail="Tag not found")
tag.name = payload.name
tag.mandatory = payload.mandatory
db.commit()
db.refresh(tag)
return tag
@router.delete("/{tag_id}", status_code=204)
def delete_tag(tag_id: UUID, db: Session = Depends(get_db)):
tag = db.get(models.Tag, tag_id)
if not tag:
raise HTTPException(status_code=404, detail="Tag not found")
db.delete(tag)
db.commit()

View file

@ -5,6 +5,7 @@ from pydantic import BaseModel
class TagBase(BaseModel): class TagBase(BaseModel):
name: str name: str
mandatory: bool = False
class TagCreate(TagBase): class TagCreate(TagBase):
pass pass