diff --git a/backend/crud.py b/backend/crud.py index 0b6ec45..58f7680 100644 --- a/backend/crud.py +++ b/backend/crud.py @@ -1,4 +1,3 @@ - from __future__ import annotations from typing import List, Optional, Tuple from uuid import UUID as UUID_t @@ -50,10 +49,11 @@ def render_name(name_template: str, start: Optional[date], end: Optional[date]) def items_for_trip(db: Session, user_id: UUID_t, selected_tag_ids: List[UUID_t]) -> List[models.Item]: - # Items without tags (always) + items with any of the selected_tags + # Items without tags (always) + items with any of the selected_tags, + # but: if an item has a mandatory tag, it is only included if at least one of its mandatory tags is selected. q = ( db.query(models.Item) - .options(joinedload(models.Item.tags)) # Load tags directly + .options(joinedload(models.Item.tags)) .filter(models.Item.user_id == user_id) ) items = q.all() @@ -61,9 +61,14 @@ def items_for_trip(db: Session, user_id: UUID_t, selected_tag_ids: List[UUID_t]) selected_set = set(selected_tag_ids) result: List[models.Item] = [] for it in items: - item_tag_ids = {tag.id for tag in it.tags} # Tag objects now + item_tag_ids = {tag.id for tag in it.tags} + mandatory_tag_ids = {tag.id for tag in it.tags if getattr(tag, "mandatory", False)} if not item_tag_ids: result.append(it) + elif mandatory_tag_ids: + # Only include if at least one mandatory tag is selected + if selected_set & mandatory_tag_ids: + result.append(it) elif selected_set & item_tag_ids: result.append(it) return result diff --git a/backend/models.py b/backend/models.py index 2b93dcc..d303888 100644 --- a/backend/models.py +++ b/backend/models.py @@ -48,6 +48,7 @@ class Tag(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False) name = Column(String, nullable=False) + mandatory = Column(Boolean, nullable=False, default=False) user = relationship("User", backref="tags") diff --git a/backend/routes/dev_seed.py b/backend/routes/dev_seed.py index 0a7c343..e47526a 100644 --- a/backend/routes/dev_seed.py +++ b/backend/routes/dev_seed.py @@ -28,7 +28,12 @@ def dev_seed(db: Session = Depends(get_db)): .first() ) if not tag: - tag = models.Tag(id=uuid4(), user_id=user.id, name=name) + tag = models.Tag( + id=uuid4(), + user_id=user.id, + name=name, + mandatory=(name == "sommer") + ) db.add(tag) db.flush() name_to_tag[name] = tag diff --git a/backend/routes/tags.py b/backend/routes/tags.py index 70a4a74..a45af40 100644 --- a/backend/routes/tags.py +++ b/backend/routes/tags.py @@ -1,4 +1,3 @@ - from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from uuid import UUID @@ -24,8 +23,27 @@ def create_tag(payload: TagCreate, db: Session = Depends(get_db)): existing = db.query(models.Tag).filter(models.Tag.user_id == user.id, models.Tag.name == payload.name).first() if existing: raise HTTPException(status_code=400, detail="Tag already exists") - tag = models.Tag(user_id=user.id, name=payload.name) + tag = models.Tag(user_id=user.id, name=payload.name, mandatory=payload.mandatory) db.add(tag) db.commit() db.refresh(tag) return tag + +@router.put("/{tag_id}", response_model=TagOut) +def update_tag(tag_id: UUID, payload: TagCreate, db: Session = Depends(get_db)): + tag = db.get(models.Tag, tag_id) + if not tag: + raise HTTPException(status_code=404, detail="Tag not found") + tag.name = payload.name + tag.mandatory = payload.mandatory + db.commit() + db.refresh(tag) + return tag + +@router.delete("/{tag_id}", status_code=204) +def delete_tag(tag_id: UUID, db: Session = Depends(get_db)): + tag = db.get(models.Tag, tag_id) + if not tag: + raise HTTPException(status_code=404, detail="Tag not found") + db.delete(tag) + db.commit() diff --git a/backend/schemas.py b/backend/schemas.py index 16f5130..268f589 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -5,6 +5,7 @@ from pydantic import BaseModel class TagBase(BaseModel): name: str + mandatory: bool = False class TagCreate(TagBase): pass