154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
from datetime import date
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session, joinedload
|
|
from uuid import UUID
|
|
from backend.database import get_db
|
|
from backend import models
|
|
from backend.schemas import TripCreate, TripOut, TripUpdate, TripRegenerationResult
|
|
from backend.crud import generate_trip_items
|
|
|
|
# Every endpoint in this module is mounted under /trips and grouped under the
# "trips" tag in the generated OpenAPI documentation.
router = APIRouter(tags=["trips"], prefix="/trips")
|
|
|
|
@router.get("/", response_model=list[TripOut])
def list_trips(db: Session = Depends(get_db)):
    """Return every stored trip, with its selected and marked tags.

    The query applies no filtering and no pagination. Both tag collections are
    requested via ``joinedload`` so they arrive with the trips themselves.
    """
    # NOTE(review): joined-eager-loading two collections in one query yields
    # (selected x marked) rows per trip; selectinload may scale better if
    # trips carry many tags.
    eager_tags = (
        joinedload(models.Trip.selected_tags),
        joinedload(models.Trip.marked_tags),
    )
    all_trips = db.query(models.Trip).options(*eager_tags).all()

    def _to_out(trip):
        # Build the response schema explicitly from the ORM row.
        return TripOut(
            id=trip.id,
            name=trip.name,
            start_date=trip.start_date,
            end_date=trip.end_date,
            selected_tags=trip.selected_tags,
            marked_tags=trip.marked_tags,
        )

    return [_to_out(trip) for trip in all_trips]
|
|
|
|
@router.post("/", response_model=TripOut)
def create_trip(payload: TripCreate, db: Session = Depends(get_db)):
    """Create a trip, attach its tags, generate its items and return it.

    The trip is assigned to the first ``User`` row found; if there is none,
    a user named "Demo" is created on the fly.

    Tag ids that match no ``Tag`` row are skipped. An id repeated in the
    payload is attached only once, because appending the same tag twice would
    insert a duplicate association row.

    ``generate_trip_items`` is then called with the payload's id lists, and
    everything is committed in a single ``db.commit()`` before the trip is
    re-fetched to build the response.
    """
    # NOTE(review): single-user bootstrap -- every trip goes to the first
    # user found. Presumably a placeholder until real authentication exists.
    user = db.query(models.User).first()
    if not user:
        from uuid import uuid4

        user = models.User(id=uuid4(), name="Demo")
        db.add(user)
        db.flush()

    trip = models.Trip(
        user_id=user.id,
        name=payload.name,
        start_date=payload.start_date,
        end_date=payload.end_date,
    )

    # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0)
    # and matches the lookups used by the other endpoints in this module.
    for tag_id in payload.selected_tag_ids:
        tag = db.get(models.Tag, tag_id)
        if tag and tag not in trip.selected_tags:
            trip.selected_tags.append(tag)

    for tag_id in payload.marked_tag_ids:
        tag = db.get(models.Tag, tag_id)
        if tag and tag not in trip.marked_tags:
            trip.marked_tags.append(tag)

    db.add(trip)
    # Flush so the trip row and its tag links exist in the transaction before
    # items are generated (presumably generate_trip_items relies on them).
    db.flush()

    # Generate items per rules; the created item ids are not needed here.
    generate_trip_items(
        db,
        trip=trip,
        selected_tag_ids=payload.selected_tag_ids,
        marked_tag_ids=payload.marked_tag_ids,
    )

    db.commit()

    # Re-fetch the committed trip for the response. Session.get serves it from
    # the identity map when possible, in which case the tag collections may be
    # lazy-loaded below instead of joined; either way they reflect the
    # committed state.
    trip = db.get(
        models.Trip,
        trip.id,
        options=[
            joinedload(models.Trip.selected_tags),
            joinedload(models.Trip.marked_tags),
        ],
    )
    return TripOut(
        id=trip.id,
        name=trip.name,
        start_date=trip.start_date,
        end_date=trip.end_date,
        selected_tags=trip.selected_tags,
        marked_tags=trip.marked_tags,
    )
|
|
|
|
@router.put("/{trip_id}/reconfigure", response_model=TripRegenerationResult)
def reconfigure_trip(trip_id: UUID, payload: TripUpdate, db: Session = Depends(get_db)):
    """Update a trip's fields and tag sets, then regenerate its items.

    ``name``, ``start_date`` and ``end_date`` change only when the payload
    provides a non-None value.

    The two tag id lists are treated as the complete desired state:
    - tags not already attached are added;
    - attached tags absent from the list are detached;
    - a None list is treated as empty.

    Returns a mapping with the trip id, the ids reported by
    ``generate_trip_items`` as deleted checked items, and the ids of the
    items it created.

    Raises:
        HTTPException: 404 if no trip with ``trip_id`` exists.
    """
    trip = db.get(models.Trip, trip_id)
    if not trip:
        raise HTTPException(status_code=404, detail="Trip not found")

    # Update base fields; None means "leave unchanged".
    if payload.name is not None:
        trip.name = payload.name
    if payload.start_date is not None:
        trip.start_date = payload.start_date
    if payload.end_date is not None:
        trip.end_date = payload.end_date
    db.flush()

    # Always use a list, never None.
    # NOTE(review): unlike the scalar fields above, an omitted tag list clears
    # every tag of that kind instead of leaving it unchanged -- confirm this
    # is what clients expect.
    selected_tag_ids = payload.selected_tag_ids or []
    marked_tag_ids = payload.marked_tag_ids or []
    # Sets give O(1) membership for the keep-filters below.
    selected_wanted = set(selected_tag_ids)
    marked_wanted = set(marked_tag_ids)

    # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0),
    # consistent with the trip lookup above.
    for tag_id in selected_tag_ids:
        tag = db.get(models.Tag, tag_id)
        if tag and tag not in trip.selected_tags:
            trip.selected_tags.append(tag)

    for tag_id in marked_tag_ids:
        tag = db.get(models.Tag, tag_id)
        if tag and tag not in trip.marked_tags:
            trip.marked_tags.append(tag)

    # Remove tags not in the new lists.
    trip.selected_tags = [tag for tag in trip.selected_tags if tag.id in selected_wanted]
    trip.marked_tags = [tag for tag in trip.marked_tags if tag.id in marked_wanted]

    db.flush()

    created_ids, deleted_checked = generate_trip_items(
        db,
        trip=trip,
        selected_tag_ids=selected_tag_ids,
        marked_tag_ids=marked_tag_ids,
    )
    db.commit()
    return {
        "trip_id": trip.id,
        "deleted_checked_trip_item_ids": deleted_checked,
        "created_trip_item_ids": created_ids,
    }
|
|
|
|
@router.get("/next-id", response_model=UUID)
def get_next_trip_id(db: Session = Depends(get_db)):
    """Return the id of the trip with the earliest start date on or after today.

    When several trips share that earliest start date, this query does not
    specify which one is returned.

    Raises:
        HTTPException: 404 when no trip starts today or later.
    """
    # NOTE(review): date.today() is the server's local date -- confirm that
    # matches how trip start dates are meant to be interpreted.
    cutoff = date.today()
    upcoming = db.query(models.Trip).filter(models.Trip.start_date >= cutoff)
    soonest = upcoming.order_by(models.Trip.start_date.asc()).first()

    if not soonest:
        raise HTTPException(status_code=404, detail="No upcoming trip found")
    return soonest.id
|
|
|
|
@router.delete("/{trip_id}", status_code=204)
def delete_trip(trip_id: UUID, db: Session = Depends(get_db)):
    """Delete the trip identified by ``trip_id`` and commit.

    Responds with 204 and no body on success.

    Raises:
        HTTPException: 404 if no trip with ``trip_id`` exists.
    """
    doomed = db.get(models.Trip, trip_id)
    if not doomed:
        raise HTTPException(status_code=404, detail="Trip not found")

    db.delete(doomed)
    db.commit()
|