from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session, joinedload, selectinload

from database import get_db
import models
from schemas import TripCreate, TripOut, TripUpdate, TripRegenerationResult
from crud import generate_trip_items

router = APIRouter(prefix="/trips", tags=["trips"])


def _tag_load_options():
    """Return loader options that eagerly fetch a trip's tag links and their tags.

    The two link collections use ``selectinload``: JOIN-loading two sibling
    collections in one query multiplies rows (selected x marked per trip).
    Each link's many-to-one ``tag`` is then joined into that collection's query.
    """
    return [
        selectinload(models.Trip.selected_tags).joinedload(models.TripTagSelected.tag),
        selectinload(models.Trip.marked_tags).joinedload(models.TripTagMarked.tag),
    ]


def _to_trip_out(trip) -> TripOut:
    """Build the ``TripOut`` response for ``trip``, flattening link rows to their tags."""
    return TripOut(
        id=trip.id,
        name=trip.name,
        start_date=trip.start_date,
        end_date=trip.end_date,
        selected_tags=[link.tag for link in trip.selected_tags],
        marked_tags=[link.tag for link in trip.marked_tags],
    )


def _unique_ids(tag_ids) -> list:
    """Return ``tag_ids`` de-duplicated (first occurrence wins); ``None`` becomes ``[]``.

    De-duplicating avoids inserting the same (trip, tag) link twice when a
    client repeats an id in its payload.
    """
    return list(dict.fromkeys(tag_ids or []))


def _add_tag_links(db: Session, link_model, trip_id, tag_ids) -> None:
    """Stage one ``link_model`` row per id in ``tag_ids`` for ``trip_id`` (no flush)."""
    db.add_all([link_model(trip_id=trip_id, tag_id=tid) for tid in tag_ids])


def _linked_tag_ids(db: Session, link_model, trip_id) -> list:
    """Return the ``tag_id`` of every ``link_model`` row currently attached to ``trip_id``."""
    return [row.tag_id for row in db.query(link_model).filter_by(trip_id=trip_id)]


def _get_or_create_demo_user(db: Session):
    """Return the first ``User`` row, creating a placeholder "Demo" user if none exists.

    NOTE(review): there is no authentication on these routes; every trip is
    attached to whichever user row the database returns first.
    """
    user = db.query(models.User).first()
    if not user:
        user = models.User(id=uuid4(), name="Demo")
        db.add(user)
        db.flush()
    return user


@router.get("/", response_model=list[TripOut])
def list_trips(db: Session = Depends(get_db)):
    """Return every trip together with its selected and marked tags."""
    trips = db.query(models.Trip).options(*_tag_load_options()).all()
    return [_to_trip_out(trip) for trip in trips]


@router.post("/", response_model=TripOut)
def create_trip(payload: TripCreate, db: Session = Depends(get_db)):
    """Create a trip, attach its selected/marked tags, generate its items and return it.

    Tag id lists that are missing/``None`` are treated as empty, and repeated
    ids are collapsed. The same normalized lists are handed to
    ``generate_trip_items``.
    """
    user = _get_or_create_demo_user(db)

    trip = models.Trip(
        user_id=user.id,
        name=payload.name,
        start_date=payload.start_date,
        end_date=payload.end_date,
    )
    db.add(trip)
    # Flush so the trip row exists and trip.id is populated before link rows use it.
    db.flush()

    selected_ids = _unique_ids(payload.selected_tag_ids)
    marked_ids = _unique_ids(payload.marked_tag_ids)
    _add_tag_links(db, models.TripTagSelected, trip.id, selected_ids)
    _add_tag_links(db, models.TripTagMarked, trip.id, marked_ids)
    db.flush()

    # Generate items per rules; the returned ids are not needed for this response.
    generate_trip_items(
        db,
        trip=trip,
        selected_tag_ids=selected_ids,
        marked_tag_ids=marked_ids,
    )
    # Capture the key before commit: with expire-on-commit, reading it afterwards
    # would trigger an extra refresh query.
    trip_id = trip.id
    db.commit()

    # Reload with relationships. populate_existing makes the eager-load options
    # apply even if the instance is still present in the identity map.
    trip = db.get(
        models.Trip,
        trip_id,
        options=_tag_load_options(),
        populate_existing=True,
    )
    return _to_trip_out(trip)


@router.put("/{trip_id}/reconfigure", response_model=TripRegenerationResult)
def reconfigure_trip(trip_id: UUID, payload: TripUpdate, db: Session = Depends(get_db)):
    """Partially update a trip, optionally replace its tag links, and regenerate its items.

    Base fields and tag lists left as ``None`` in the payload are kept unchanged.
    A provided tag list fully replaces the existing links of that kind
    (duplicates collapsed).

    Raises:
        HTTPException: 404 if no trip with ``trip_id`` exists.
    """
    trip = db.get(models.Trip, trip_id)
    if not trip:
        raise HTTPException(status_code=404, detail="Trip not found")

    # Update base fields that were supplied.
    for field in ("name", "start_date", "end_date"):
        value = getattr(payload, field)
        if value is not None:
            setattr(trip, field, value)
    db.flush()

    # Replace join-table rows only for the lists that were supplied.
    if payload.selected_tag_ids is not None:
        db.query(models.TripTagSelected).filter_by(trip_id=trip.id).delete()
        _add_tag_links(db, models.TripTagSelected, trip.id, _unique_ids(payload.selected_tag_ids))
    if payload.marked_tag_ids is not None:
        db.query(models.TripTagMarked).filter_by(trip_id=trip.id).delete()
        _add_tag_links(db, models.TripTagMarked, trip.id, _unique_ids(payload.marked_tag_ids))
    db.flush()

    # Read back the effective lists (supplied or untouched) for regeneration.
    sel_ids = _linked_tag_ids(db, models.TripTagSelected, trip.id)
    mrk_ids = _linked_tag_ids(db, models.TripTagMarked, trip.id)

    created_ids, deleted_checked = generate_trip_items(
        db, trip=trip, selected_tag_ids=sel_ids, marked_tag_ids=mrk_ids
    )
    # Capture before commit to avoid a post-commit refresh of the expired instance.
    result_trip_id = trip.id
    db.commit()
    return {
        "trip_id": result_trip_id,
        "deleted_checked_trip_item_ids": deleted_checked,
        "created_trip_item_ids": created_ids,
    }