From 1218098ca5dbc2f24fe432368c966e8d065f82d2 Mon Sep 17 00:00:00 2001 From: Francia Csaba Date: Mon, 22 Apr 2024 13:11:42 +0200 Subject: [PATCH] Use get_current_user --- amarillo/plugins/gtfs_export/router.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/amarillo/plugins/gtfs_export/router.py b/amarillo/plugins/gtfs_export/router.py index a181bfc..53159b8 100644 --- a/amarillo/plugins/gtfs_export/router.py +++ b/amarillo/plugins/gtfs_export/router.py @@ -3,8 +3,9 @@ import logging from fastapi import APIRouter, HTTPException, status, Depends from amarillo.models.Carpool import Region -from amarillo.routers.agencyconf import verify_admin_api_key from amarillo.services.regions import RegionService +from amarillo.services.oauth2 import get_current_user, verify_permission +from amarillo.models.User import User from amarillo.utils.container import container from fastapi.responses import FileResponse @@ -13,7 +14,8 @@ logger = logging.getLogger(__name__) router = APIRouter() @router.post("/export") -async def post_agency_conf(admin_api_key: str = Depends(verify_admin_api_key)): +async def trigger_export(requesting_user: User = Depends(get_current_user)): + verify_permission("admin", requesting_user) #import is here to avoid circular import from amarillo.plugins.gtfs_export.gtfs_generator import generate_gtfs generate_gtfs() @@ -40,7 +42,7 @@ def _assert_region_exists(region_id: str) -> Region: status.HTTP_404_NOT_FOUND: {"description": "Region not found"}, } ) -async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)): +async def get_file(region_id: str, requesting_user: User = Depends(get_current_user)): _assert_region_exists(region_id) return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfs.zip') @@ -53,7 +55,7 @@ async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)): status.HTTP_400_BAD_REQUEST: {"description": "Bad request, e.g. because format is not supported, i.e. neither protobuf nor json."} } ) -async def get_file(region_id: str, format: str = 'protobuf', user: str = Depends(verify_admin_api_key)): +async def get_file(region_id: str, format: str = 'protobuf', requesting_user: User = Depends(get_current_user)): _assert_region_exists(region_id) if format == 'json': return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfsrt.json')