Use OAuth2 to authorize endpoints

This commit is contained in:
Csaba 2024-03-01 15:11:16 +01:00
parent 7016ba22bf
commit 3f25239533

View file

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, HTTPException, status, Depends
from amarillo.models.Carpool import Region
from amarillo.routers.agencyconf import verify_admin_api_key
from amarillo.services.oauth2 import get_current_agency
from amarillo.services.regions import RegionService
from amarillo.utils.container import container
from fastapi.responses import FileResponse
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/export-grfs")
async def post_agency_conf(admin_api_key: str = Depends(verify_admin_api_key)):
async def post_agency_conf(admin_api_key: str = Depends(get_current_agency)):
#import is here to avoid circular import
from amarillo.plugins.grfs_export.gtfs_generator import generate_gtfs
generate_gtfs()
@ -40,7 +40,7 @@ def _assert_region_exists(region_id: str) -> Region:
status.HTTP_404_NOT_FOUND: {"description": "Region not found"},
}
)
async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)):
async def get_file(region_id: str, user: str = Depends(get_current_agency)):
_assert_region_exists(region_id)
try:
from amarillo.plugins.metrics import increment_grfs_download_counter
@ -57,7 +57,7 @@ async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)):
status.HTTP_404_NOT_FOUND: {"description": "Region not found"},
}
)
async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)):
async def get_file(region_id: str, user: str = Depends(get_current_agency)):
_assert_region_exists(region_id)
try:
from amarillo.plugins.metrics import increment_grfs_download_counter
@ -76,7 +76,7 @@ async def get_file(region_id: str, user: str = Depends(verify_admin_api_key)):
status.HTTP_400_BAD_REQUEST: {"description": "Bad request, e.g. because format is not supported, i.e. neither protobuf nor json."}
}
)
async def get_file(region_id: str, format: str = 'protobuf', user: str = Depends(verify_admin_api_key)):
async def get_file(region_id: str, format: str = 'protobuf', user: str = Depends(get_current_agency)):
_assert_region_exists(region_id)
if format == 'json':
return FileResponse(f'data/grfs/amarillo.{region_id}.gtfsrt.json')