124 lines
4.3 KiB
Python
124 lines
4.3 KiB
Python
from __future__ import annotations
|
|
from typing import List, Optional, Tuple
|
|
from uuid import UUID as UUID_t
|
|
from datetime import date
|
|
import re, ast
|
|
from sqlalchemy.orm import Session, joinedload
|
|
from backend import models
|
|
|
|
ALLOWED_NAMES = {"days", "nights"}
|
|
ALLOWED_NODES = (
|
|
ast.Expression, ast.BinOp, ast.UnaryOp, ast.Num, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv,
|
|
ast.Mod, ast.Pow, ast.USub, ast.Load, ast.Name, ast.Constant, ast.Call
|
|
)
|
|
|
|
# Very small safe-eval for placeholders like {days * 10}
|
|
|
|
def _safe_eval(expr: str, ctx: dict) -> str:
|
|
node = ast.parse(expr, mode="eval")
|
|
for n in ast.walk(node):
|
|
if not isinstance(n, ALLOWED_NODES):
|
|
raise ValueError("disallowed expression")
|
|
if isinstance(n, ast.Name) and n.id not in ALLOWED_NAMES:
|
|
raise ValueError("unknown name")
|
|
if isinstance(n, ast.Call):
|
|
raise ValueError("calls not allowed")
|
|
val = eval(compile(node, "<expr>", "eval"), {"__builtins__": {}}, ctx)
|
|
return str(int(val)) if isinstance(val, (int, float)) and float(val).is_integer() else str(val)
|
|
|
|
_placeholder_re = re.compile(r"\{([^{}]+)\}")


def render_name(name_template: str, start: Optional[date], end: Optional[date]) -> str:
    """Expand ``{expr}`` placeholders in an item name template.

    Each placeholder is evaluated with ``_safe_eval`` against two values:
    ``days``, the inclusive number of days from ``start`` to ``end``, and
    ``nights``, which is ``days - 1``. Both are 0 when either date is missing.
    Neither is ever negative, even if ``end`` lies before ``start``.

    Placeholders that cannot be evaluated (syntax errors, disallowed
    constructs, division by zero, ...) are left in the output unchanged.

    Args:
        name_template: Template text, e.g. ``"{days * 2} socks"``. An empty
            (or ``None``) template yields ``""``.
        start: First day of the trip, or ``None``.
        end: Last day of the trip, or ``None``.

    Returns:
        The template with every evaluable placeholder replaced by its value.
    """
    if not name_template:
        return ""

    days = nights = 0
    if start is not None and end is not None:
        # Clamp so a reversed date range yields 0 rather than a negative count,
        # consistent with how ``nights`` is clamped.
        days = max((end - start).days + 1, 0)
        nights = max(days - 1, 0)
    ctx = {"days": days, "nights": nights}

    def _replace(match):
        try:
            return _safe_eval(match.group(1).strip(), ctx)
        except Exception:
            # Best effort: a placeholder we cannot evaluate is kept verbatim.
            return match.group(0)

    return _placeholder_re.sub(_replace, name_template)
|
|
|
|
|
|
def _include_in_trip(item, trip_id, selected_set) -> bool:
    """Return whether a single item belongs on the packing list of trip ``trip_id``."""
    if item.trip_id is not None:
        # Items bound to a trip only appear on that trip.
        return item.trip_id == trip_id
    tag_ids = {tag.id for tag in item.tags}
    if not tag_ids:
        # Untagged generic items are always packed.
        return True
    mandatory_ids = {tag.id for tag in item.tags if getattr(tag, "mandatory", False)}
    if mandatory_ids:
        # Only include if ALL mandatory tags are selected; other tags are ignored.
        return mandatory_ids <= selected_set
    return bool(selected_set & tag_ids)


def items_for_trip(db: Session, user_id: UUID_t, trip: models.Trip, selected_tag_ids: List[UUID_t]) -> List[models.Item]:
    """Return the items of ``user_id`` that belong on the packing list of ``trip``.

    An item is included when it is one of:

    * assigned to this trip (``item.trip_id == trip.id``);
    * not assigned to any trip and without tags;
    * not assigned to any trip, with at least one mandatory tag, and with *all*
      of its mandatory tags in ``selected_tag_ids`` (its non-mandatory tags are
      then ignored);
    * not assigned to any trip, without mandatory tags, and sharing at least
      one tag with ``selected_tag_ids``.

    Items assigned to another trip are never included. Each item's ``tags``
    collection is eagerly loaded.

    NOTE(review): an earlier comment stated that "at least one" mandatory tag
    must be selected, whereas the code requires all of them. Confirm which
    rule is intended.

    Args:
        db: Active SQLAlchemy session.
        user_id: Owner whose items are considered.
        trip: The trip being packed for.
        selected_tag_ids: Tags chosen for this trip.

    Returns:
        The matching items, in the order returned by the query.
    """
    selected_set = set(selected_tag_ids)
    items = (
        db.query(models.Item)
        .options(joinedload(models.Item.tags))
        .filter(models.Item.user_id == user_id)
        # Only unassigned items and items of this trip can match. Filtering in
        # SQL avoids loading the user's items for other trips (with their
        # joined tags) just to discard them.
        .filter(models.Item.trip_id.is_(None) | (models.Item.trip_id == trip.id))
        .all()
    )
    return [it for it in items if _include_in_trip(it, trip.id, selected_set)]
|
|
|
|
|
|
def generate_trip_items(
    db: Session,
    *,
    trip: models.Trip,
    selected_tag_ids: List[UUID_t],
    marked_tag_ids: List[UUID_t],
) -> Tuple[List[UUID_t], List[UUID_t]]:
    """Regenerate the TripItems of ``trip``: delete all old ones, create new ones.

    The items to include come from ``items_for_trip``. An item that carries
    one or more of ``marked_tag_ids`` gets one TripItem per such tag, in
    sorted tag-id order. Every other item gets a single TripItem with
    ``tag_id=None``. Names are rendered with ``render_name`` from the trip's
    start and end dates. New TripItems start unchecked. The session is flushed
    but not committed.

    NOTE(review): checked state is not carried over to the regenerated rows,
    so every previously checked TripItem is reported in ``deleted_checked_ids``.

    Args:
        db: Active SQLAlchemy session.
        trip: Trip whose TripItems are regenerated.
        selected_tag_ids: Tags used to select items (see ``items_for_trip``).
        marked_tag_ids: Tags for which an item gets one TripItem per tag.

    Returns:
        ``(created_ids, deleted_checked_ids)``: the ids of the newly created
        TripItems in creation order, and the ids of the deleted TripItems
        that were checked.
    """
    # Delete every existing TripItem, remembering the checked ones.
    deleted_checked: List[UUID_t] = []
    for old in list(trip.trip_items):
        if old.checked:
            deleted_checked.append(old.id)
        db.delete(old)
    db.flush()

    items = items_for_trip(db, trip.user_id, trip, selected_tag_ids)
    marked_set = set(marked_tag_ids)

    new_trip_items = []
    for it in items:
        # The rendered name depends only on the item and the trip dates.
        name = render_name(it.name, trip.start_date, trip.end_date)
        item_tag_ids = {tag.id for tag in it.tags}
        # One TripItem per marked tag the item carries, otherwise one untagged.
        for tag_id in sorted(item_tag_ids & marked_set) or [None]:
            ti = models.TripItem(
                trip_id=trip.id,
                item_id=it.id,
                name_calculated=name,
                checked=False,
                tag_id=tag_id,
            )
            db.add(ti)
            new_trip_items.append(ti)

    # A single flush inserts all new rows and populates their primary keys.
    db.flush()
    created_ids: List[UUID_t] = [ti.id for ti in new_trip_items]

    # The ORM neither drops deleted rows from ``trip.trip_items`` nor adds rows
    # created via ``trip_id=``. Expire the collection so the next access
    # reloads it, instead of returning stale objects.
    db.expire(trip, ["trip_items"])
    return created_ids, deleted_checked
|