verify_permission function

This commit is contained in:
Csaba 2024-04-22 12:49:13 +02:00
parent 11d5849290
commit c8acc46382
3 changed files with 65 additions and 11 deletions

View file

@ -1,8 +1,5 @@
from typing import Annotated, Optional, List, Union from typing import Annotated, Optional, List
from pydantic import ConfigDict, BaseModel, Field, constr from pydantic import ConfigDict, BaseModel, Field
MyUrlsType = constr(regex="^[a-z]$")
class User(BaseModel): class User(BaseModel):
#TODO: add attributes admin, permissions, fullname, email #TODO: add attributes admin, permissions, fullname, email
@ -24,7 +21,7 @@ class User(BaseModel):
min_length=8, min_length=8,
max_length=256, max_length=256,
examples=["$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW"]) examples=["$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW"])
permissions: Optional[List[str]] = Field([], permissions: Optional[List[Annotated[str, Field(pattern=r'^[a-z0-9]+(:[a-z]+)?$')]]] = Field([],
description="The permissions of this user, a list of strings in the format <agency:operation> or <operation>", description="The permissions of this user, a list of strings in the format <agency:operation> or <operation>",
max_length=256, max_length=256,
# pattern=r'^[a-zA-Z0-9]+(:[a-zA-Z]+)?$', #TODO # pattern=r'^[a-zA-Z0-9]+(:[a-zA-Z]+)?$', #TODO

View file

@ -3,6 +3,7 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional, Union from typing import Annotated, Optional, Union
import logging import logging
import logging.config
from fastapi import Depends, HTTPException, Header, status, APIRouter from fastapi import Depends, HTTPException, Header, status, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
@ -60,7 +61,6 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
#TODO: function verify_permission(user, permission)
#TODO: rename to get_current_user, agency_from_api_key -> user_from_api_key #TODO: rename to get_current_user, agency_from_api_key -> user_from_api_key
async def get_current_agency(token: str = Depends(oauth2_scheme), agency_from_api_key: str = Depends(verify_optional_api_key)): async def get_current_agency(token: str = Depends(oauth2_scheme), agency_from_api_key: str = Depends(verify_optional_api_key)):
@ -101,15 +101,36 @@ async def get_current_user(token: str = Depends(oauth2_scheme), agency_from_api_
# TODO: use verify_permission("admin", user) # TODO: use verify_permission("admin", user)
def verify_permission(permission: str, user: User):
# permission_exception =
if user.permissions is None or permission not in user.permissions: raise HTTPException( def verify_permission(permission: str, user: User):
def permissions_exception():
return HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User '{user}' does not have the permission '{permission}'", detail=f"User '{user.user_id}' does not have the permission '{permission}'",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
#user is admin
if "admin" in user.permissions: return
#permission is an operation
if ":" not in permission:
if permission not in user.permissions:
raise permissions_exception()
return
def permission_matches(permission, user_permission):
prescribed_agency, prescribed_operation = permission.split(":")
given_agency, given_operation = user_permission.split(":")
return (prescribed_agency == given_agency or given_agency == "all") and (prescribed_operation == given_operation or given_operation == "all")
if any(permission_matches(permission, p) for p in user.permissions if ":" in p): return
raise permissions_exception()
async def verify_admin(agency: str = Depends(get_current_agency)): async def verify_admin(agency: str = Depends(get_current_agency)):

View file

@ -0,0 +1,36 @@
from fastapi import HTTPException
import pytest
from amarillo.services.oauth2 import verify_permission
from amarillo.models.User import User
test_user = User(user_id="test", password="testpassword", permissions=["all:read", "mfdz:write", "ride2go:all", "metrics"])
admin_user = User(user_id="admin", password="testpassword", permissions=["admin"])
def test_operation():
verify_permission("metrics", test_user)
with pytest.raises(HTTPException):
verify_permission("geojson", test_user)
def test_agency_permission():
verify_permission("mvv:read", test_user)
verify_permission("mfdz:read", test_user)
verify_permission("mfdz:write", test_user)
verify_permission("ride2go:write", test_user)
with pytest.raises(HTTPException):
verify_permission("mvv:write", test_user)
verify_permission("mvv:all", test_user)
def test_admin():
verify_permission("admin", admin_user)
verify_permission("all:all", admin_user)
verify_permission("mvv:all", admin_user)
verify_permission("mfdz:read", admin_user)
verify_permission("mfdz:write", admin_user)
verify_permission("ride2go:write", admin_user)
with pytest.raises(HTTPException):
verify_permission("admin", test_user)
verify_permission("all:all", test_user)