diff --git a/app/corpora/cqi_over_sio/__init__.py b/app/corpora/cqi_over_sio/__init__.py index 25102c94..90d9719e 100644 --- a/app/corpora/cqi_over_sio/__init__.py +++ b/app/corpora/cqi_over_sio/__init__.py @@ -9,7 +9,7 @@ from inspect import signature from threading import Lock from typing import Callable, Dict, List, Optional from app import db, docker_client, hashids, socketio -from app.decorators import socketio_login_required +from app.extensions.flask_socketio_extras import login_required from app.models import Corpus, CorpusStatus from . import extensions @@ -87,11 +87,11 @@ CQI_API_FUNCTION_NAMES: List[str] = [ class CQiNamespace(Namespace): - @socketio_login_required + @login_required def on_connect(self): pass - @socketio_login_required + @login_required def on_init(self, db_corpus_hashid: str): db_corpus_id: int = hashids.decode(db_corpus_hashid) db_corpus: Optional[Corpus] = Corpus.query.get(db_corpus_id) @@ -134,7 +134,7 @@ class CQiNamespace(Namespace): } return {'code': 200, 'msg': 'OK'} - @socketio_login_required + @login_required def on_exec(self, fn_name: str, fn_args: Dict = {}): try: cqi_client: CQiClient = session['cqi_over_sio']['cqi_client'] diff --git a/app/corpora/events.py b/app/corpora/events.py index 9a0e3a36..cc0baad1 100644 --- a/app/corpora/events.py +++ b/app/corpora/events.py @@ -1,12 +1,12 @@ from flask_login import current_user from flask_socketio import join_room from app import hashids, socketio -from app.decorators import socketio_login_required +from app.extensions.flask_socketio_extras import login_required from app.models import Corpus @socketio.on('GET /corpora/') -@socketio_login_required +@login_required def get_corpus(corpus_hashid): corpus_id = hashids.decode(corpus_hashid) corpus = Corpus.query.get(corpus_id) @@ -29,7 +29,7 @@ def get_corpus(corpus_hashid): @socketio.on('SUBSCRIBE /corpora/') -@socketio_login_required +@login_required def subscribe_corpus(corpus_hashid): corpus_id = hashids.decode(corpus_hashid) corpus = Corpus.query.get(corpus_id) diff --git a/app/decorators.py b/app/decorators.py index 21527233..5fbd7671 100644 --- a/app/decorators.py +++ b/app/decorators.py @@ -22,31 +22,6 @@ def admin_required(f): return permission_required(Permission.ADMINISTRATE)(f) -def socketio_login_required(f): - @wraps(f) - def decorated_function(*args, **kwargs): - if current_user.is_authenticated: - return f(*args, **kwargs) - else: - return {'code': 401, 'msg': 'Unauthorized'} - return decorated_function - - -def socketio_permission_required(permission): - def decorator(f): - @wraps(f) - def decorated_function(*args, **kwargs): - if not current_user.can(permission): - return {'code': 403, 'msg': 'Forbidden'} - return f(*args, **kwargs) - return decorated_function - return decorator - - -def socketio_admin_required(f): - return socketio_permission_required(Permission.ADMINISTRATE)(f) - - def background(f): ''' ' This decorator executes a function in a Thread. diff --git a/app/extensions/flask_socketio_extras/__init__.py b/app/extensions/flask_socketio_extras/__init__.py new file mode 100644 index 00000000..866c0cc9 --- /dev/null +++ b/app/extensions/flask_socketio_extras/__init__.py @@ -0,0 +1,3 @@ +from .decorators import login_required +from .decorators import permission_required +from .decorators import admin_required diff --git a/app/extensions/flask_socketio_extras/decorators.py b/app/extensions/flask_socketio_extras/decorators.py new file mode 100644 index 00000000..8e8d0a05 --- /dev/null +++ b/app/extensions/flask_socketio_extras/decorators.py @@ -0,0 +1,27 @@ +from flask_login import current_user +from functools import wraps +from app.models import Permission as UserPermission + + +def login_required(f): + @wraps(f) + def wrapper(*args, **kwargs): + if current_user.is_authenticated: + return f(*args, **kwargs) + return {'code': 401, 'body': 'Unauthorized'} + return wrapper + + +def permission_required(permission): + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + if not current_user.can(permission): + return {'code': 403, 'body': 'Forbidden'} + return f(*args, **kwargs) + return wrapper + return decorator + + +def admin_required(f): + return permission_required(UserPermission.ADMINISTRATE)(f) diff --git a/app/extensions/sqlalchemy/__init__.py b/app/extensions/sqlalchemy_extras/__init__.py similarity index 100% rename from app/extensions/sqlalchemy/__init__.py rename to app/extensions/sqlalchemy_extras/__init__.py diff --git a/app/extensions/sqlalchemy/types.py b/app/extensions/sqlalchemy_extras/types.py similarity index 100% rename from app/extensions/sqlalchemy/types.py rename to app/extensions/sqlalchemy_extras/types.py diff --git a/app/extensions/wtforms/__init__.py b/app/extensions/wtforms_extras/__init__.py similarity index 100% rename from app/extensions/wtforms/__init__.py rename to app/extensions/wtforms_extras/__init__.py diff --git a/app/extensions/wtforms/validators.py b/app/extensions/wtforms_extras/validators.py similarity index 100% rename from app/extensions/wtforms/validators.py rename to app/extensions/wtforms_extras/validators.py diff --git a/app/models/corpus.py b/app/models/corpus.py index 147fbb02..efe38833 100644 --- a/app/models/corpus.py +++ b/app/models/corpus.py @@ -9,7 +9,7 @@ import shutil import xml.etree.ElementTree as ET from app import db from app.converters.vrt import normalize_vrt_file -from app.extensions.sqlalchemy import IntEnumColumn +from app.extensions.sqlalchemy_extras import IntEnumColumn from .corpus_follower_association import CorpusFollowerAssociation diff --git a/app/models/job.py b/app/models/job.py index bba8ea0e..72f899a6 100644 --- a/app/models/job.py +++ b/app/models/job.py @@ -7,7 +7,7 @@ from typing import Union from pathlib import Path import shutil from app import db -from app.extensions.sqlalchemy import ContainerColumn, IntEnumColumn +from app.extensions.sqlalchemy_extras import ContainerColumn, IntEnumColumn class JobStatus(IntEnum): diff --git a/app/models/spacy_nlp_pipeline_model.py b/app/models/spacy_nlp_pipeline_model.py index e8a2501b..89fabba1 100644 --- a/app/models/spacy_nlp_pipeline_model.py +++ b/app/models/spacy_nlp_pipeline_model.py @@ -5,7 +5,7 @@ from pathlib import Path import requests import yaml from app import db -from app.extensions.sqlalchemy import ContainerColumn +from app.extensions.sqlalchemy_extras import ContainerColumn from .file_mixin import FileMixin from .user import User diff --git a/app/models/tesseract_ocr_pipeline_model.py b/app/models/tesseract_ocr_pipeline_model.py index 43198711..145173ec 100644 --- a/app/models/tesseract_ocr_pipeline_model.py +++ b/app/models/tesseract_ocr_pipeline_model.py @@ -5,7 +5,7 @@ from pathlib import Path import requests import yaml from app import db -from app.extensions.sqlalchemy import ContainerColumn +from app.extensions.sqlalchemy_extras import ContainerColumn from .file_mixin import FileMixin from .user import User diff --git a/app/models/user.py b/app/models/user.py index 0861b737..341e49aa 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -12,7 +12,7 @@ import re import secrets import shutil from app import db, hashids -from app.extensions.sqlalchemy import IntEnumColumn +from app.extensions.sqlalchemy_extras import IntEnumColumn from .corpus import Corpus from .corpus_follower_association import CorpusFollowerAssociation from .corpus_follower_role import CorpusFollowerRole diff --git a/app/users/events.py b/app/users/events.py index 8132f818..e010ce10 100644 --- a/app/users/events.py +++ b/app/users/events.py @@ -1,12 +1,12 @@ from flask_login import current_user from flask_socketio import join_room, leave_room from app import hashids, socketio -from app.decorators import socketio_login_required +from app.extensions.flask_socketio_extras import login_required from app.models import User @socketio.on('GET /users/') -@socketio_login_required +@login_required def get_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.get(user_id) @@ -22,7 +22,7 @@ def get_user(user_hashid): @socketio.on('SUBSCRIBE /users/') -@socketio_login_required +@login_required def subscribe_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.get(user_id) @@ -35,7 +35,7 @@ def subscribe_user(user_hashid): @socketio.on('UNSUBSCRIBE /users/') -@socketio_login_required +@login_required def unsubscribe_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.get(user_id) diff --git a/app/users/nevents.py b/app/users/nevents.py index 59bb5b0d..de79161c 100644 --- a/app/users/nevents.py +++ b/app/users/nevents.py @@ -1,12 +1,12 @@ from flask_login import current_user -from flask_socketio import join_room, leave_room +from flask_socketio import join_room from app import hashids, socketio -from app.decorators import socketio_admin_required, socketio_login_required +from app.extensions.flask_socketio_extras import admin_required, login_required from app.models import User @socketio.on('GET /users') -@socketio_admin_required +@admin_required def get_users(): users = User.query.filter_by().all() return { @@ -20,14 +20,14 @@ def get_users(): @socketio.on('SUBSCRIBE /users') -@socketio_admin_required +@admin_required def subscribe_users(): join_room('/users') return {'options': {'status': 200, 'statusText': 'OK'}} @socketio.on('GET /users/') -@socketio_login_required +@login_required def get_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.get(user_id) @@ -46,7 +46,7 @@ def get_user(user_hashid): @socketio.on('SUBSCRIBE /users/') -@socketio_login_required +@login_required def subscribe_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.get(user_id) @@ -59,7 +59,7 @@ def subscribe_user(user_hashid): @socketio.on('GET /public_users') -@socketio_login_required +@login_required def get_public_users(): users = User.query.filter_by(is_public=True).all() return { @@ -76,14 +76,14 @@ def get_public_users(): @socketio.on('SUBSCRIBE /users') -@socketio_admin_required +@admin_required def subscribe_users(): join_room('/public_users') return {'options': {'status': 200, 'statusText': 'OK'}} @socketio.on('GET /public_users/') -@socketio_login_required +@login_required def get_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.filter_by(id=user_id, is_public=True).first() @@ -102,7 +102,7 @@ def get_user(user_hashid): @socketio.on('SUBSCRIBE /public_users/') -@socketio_login_required +@login_required def subscribe_user(user_hashid): user_id = hashids.decode(user_hashid) user = User.query.filter_by(id=user_id, is_public=True).first() diff --git a/app/users/settings/forms.py b/app/users/settings/forms.py index dc4687c1..1f673d78 100644 --- a/app/users/settings/forms.py +++ b/app/users/settings/forms.py @@ -16,7 +16,7 @@ from wtforms.validators import ( Regexp ) from app.models import User, UserSettingJobStatusMailNotificationLevel -from app.extensions.wtforms.validators import FileSize +from app.extensions.wtforms_extras.validators import FileSize class UpdateAccountInformationForm(FlaskForm):