Use get_current_user

This commit is contained in:
Csaba 2024-04-22 13:11:42 +02:00
parent 40ff3354d0
commit 1218098ca5

View file

@ -3,8 +3,9 @@ import logging
from fastapi import APIRouter, HTTPException, status, Depends from fastapi import APIRouter, HTTPException, status, Depends
from amarillo.models.Carpool import Region from amarillo.models.Carpool import Region
from amarillo.routers.agencyconf import verify_admin_api_key
from amarillo.services.regions import RegionService from amarillo.services.regions import RegionService
from amarillo.services.oauth2 import get_current_user, verify_permission
from amarillo.models.User import User
from amarillo.utils.container import container from amarillo.utils.container import container
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -13,7 +14,8 @@ logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.post("/export") @router.post("/export")
async def post_agency_conf(admin_api_key: str = Depends(verify_admin_api_key)): async def trigger_export(requesting_user: User = Depends(get_current_user)):
verify_permission("admin", requesting_user)
#import is here to avoid circular import #import is here to avoid circular import
from amarillo.plugins.gtfs_export.gtfs_generator import generate_gtfs from amarillo.plugins.gtfs_export.gtfs_generator import generate_gtfs
generate_gtfs() generate_gtfs()
@ -40,7 +42,7 @@ def _assert_region_exists(region_id: str) -> Region:
status.HTTP_404_NOT_FOUND: {"description": "Region not found"}, status.HTTP_404_NOT_FOUND: {"description": "Region not found"},
} }
) )
async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)): async def get_file(region_id: str, requesting_user: User = Depends(get_current_user)):
_assert_region_exists(region_id) _assert_region_exists(region_id)
return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfs.zip') return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfs.zip')
@ -53,7 +55,7 @@ async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)):
status.HTTP_400_BAD_REQUEST: {"description": "Bad request, e.g. because format is not supported, i.e. neither protobuf nor json."} status.HTTP_400_BAD_REQUEST: {"description": "Bad request, e.g. because format is not supported, i.e. neither protobuf nor json."}
} }
) )
async def get_file(region_id: str, format: str = 'protobuf', user: str = Depends(verify_admin_api_key)): async def get_file(region_id: str, format: str = 'protobuf', requesting_user: User = Depends(get_current_user)):
_assert_region_exists(region_id) _assert_region_exists(region_id)
if format == 'json': if format == 'json':
return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfsrt.json') return FileResponse(f'data/gtfs/amarillo.{region_id}.gtfsrt.json')