packlist/backend/crud.py

119 lines
4.1 KiB
Python

from __future__ import annotations

import ast
import operator
import re
from datetime import date
from typing import List, Optional, Tuple
from uuid import UUID as UUID_t

from sqlalchemy.orm import Session, joinedload

from backend import models
ALLOWED_NAMES = {"days", "nights"}
ALLOWED_NODES = (
ast.Expression, ast.BinOp, ast.UnaryOp, ast.Num, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv,
ast.Mod, ast.Pow, ast.USub, ast.Load, ast.Name, ast.Constant, ast.Call
)
# Very small safe-eval for placeholders like {days * 10}
def _safe_eval(expr: str, ctx: dict) -> str:
node = ast.parse(expr, mode="eval")
for n in ast.walk(node):
if not isinstance(n, ALLOWED_NODES):
raise ValueError("disallowed expression")
if isinstance(n, ast.Name) and n.id not in ALLOWED_NAMES:
raise ValueError("unknown name")
if isinstance(n, ast.Call):
raise ValueError("calls not allowed")
val = eval(compile(node, "<expr>", "eval"), {"__builtins__": {}}, ctx)
return str(int(val)) if isinstance(val, (int, float)) and float(val).is_integer() else str(val)
_placeholder_re = re.compile(r"\{([^{}]+)\}")
def render_name(name_template: str, start: Optional[date], end: Optional[date]) -> str:
if not name_template:
return ""
days = nights = 0
if start and end:
days = (end - start).days + 1
nights = max(days - 1, 0)
ctx = {"days": days, "nights": nights}
def repl(m):
expr = m.group(1).strip()
try:
return _safe_eval(expr, ctx)
except Exception:
# if not evaluable, leave placeholder as-is
return m.group(0)
return _placeholder_re.sub(repl, name_template)
def items_for_trip(db: Session, user_id: UUID_t, trip: models.Trip, selected_tag_ids: List[UUID_t]) -> List[models.Item]:
    """Return the user's items that belong on the packing list of ``trip``.

    An item owned by ``user_id`` is included when it is either

    * a trip-specific item of this trip (``item.trip_id == trip.id``), or
    * a general item (``item.trip_id is None``) that
      - has no tags at all (always packed), or
      - has mandatory tags, ALL of which are in ``selected_tag_ids``, or
      - has no mandatory tags and shares at least one tag with
        ``selected_tag_ids``.

    Items attached to other trips are never returned.
    """
    # NOTE(review): an earlier header comment said an item with mandatory tags
    # needs only "at least one" of them selected, but the code has always
    # required ALL of them (subset test below). Confirm the intended rule.
    query = (
        db.query(models.Item)
        .options(joinedload(models.Item.tags))
        .filter(models.Item.user_id == user_id)
        # Drop other trips' items in SQL instead of loading them only to
        # discard them in Python.
        .filter(models.Item.trip_id.is_(None) | (models.Item.trip_id == trip.id))
    )
    selected_set = set(selected_tag_ids)
    result: List[models.Item] = []
    for item in query.all():
        if item.trip_id is not None:
            # Trip-specific items are included regardless of tags.
            if item.trip_id == trip.id:
                result.append(item)
            continue
        item_tag_ids = {tag.id for tag in item.tags}
        # Tags are treated as non-mandatory when they lack a `mandatory` attribute.
        mandatory_tag_ids = {tag.id for tag in item.tags if getattr(tag, "mandatory", False)}
        if not item_tag_ids:
            result.append(item)
        elif mandatory_tag_ids:
            # Include only if ALL mandatory tags are selected.
            if mandatory_tag_ids <= selected_set:
                result.append(item)
        elif item_tag_ids & selected_set:
            result.append(item)
    return result
def generate_trip_items(
    db: Session,
    *,
    trip: models.Trip,
    selected_tag_ids: List[UUID_t],
    marked_tag_ids: List[UUID_t],
) -> Tuple[List[UUID_t], List[UUID_t]]:
    """Rebuild the TripItem rows of ``trip`` from the current item selection.

    All existing TripItems of the trip are deleted and recreated from
    items_for_trip(). Each item gets one TripItem per tag in
    ``marked_tag_ids`` that it carries (in sorted tag-id order). If it carries
    none of them, it gets a single TripItem with ``tag_id=None``. The
    ``checked`` flag of an old row with the same (item_id, tag_id) pair is
    carried over; otherwise it is False. Changes are flushed but not committed.

    Returns:
        ``(created_ids, [])``: the ids of the new TripItems in creation
        order, and a second list that this function always leaves empty.
    """
    # Remember checked state per (item_id, tag_id) before dropping old rows.
    previous_checked = {}
    for old in list(trip.trip_items):
        previous_checked[(old.item_id, old.tag_id)] = old.checked
        db.delete(old)
    db.flush()

    items = items_for_trip(db, trip.user_id, trip, selected_tag_ids)
    marked_set = set(marked_tag_ids)
    new_trip_items = []
    for item in items:
        # The rendered name depends only on the item and the trip dates, so
        # compute it once per item rather than once per tag.
        calc = render_name(item.name, trip.start_date, trip.end_date)
        marked_on_item = {tag.id for tag in item.tags} & marked_set
        per_tags = sorted(marked_on_item) if marked_on_item else [None]
        for tag_id in per_tags:
            trip_item = models.TripItem(
                trip_id=trip.id,
                item_id=item.id,
                name_calculated=calc,
                checked=previous_checked.get((item.id, tag_id), False),
                tag_id=tag_id,
            )
            db.add(trip_item)
            new_trip_items.append(trip_item)

    # One flush inserts every new row and populates its primary key.
    db.flush()
    created_ids: List[UUID_t] = [trip_item.id for trip_item in new_trip_items]
    # A flush neither removes deleted rows from an in-memory collection nor
    # adds rows that were never appended to it. Expire the collection so the
    # next access to trip.trip_items reloads the current rows.
    db.expire(trip, ["trip_items"])
    return created_ids, []