diff --git a/my_modules/app/logger.py b/my_modules/app/logger.py index 7b89922..dac2c47 100644 --- a/my_modules/app/logger.py +++ b/my_modules/app/logger.py @@ -1,12 +1,4 @@ -from aiologger.formatters.base import Formatter -from aiologger.handlers.streams import AsyncStreamHandler -from aiologger import Logger -import os -import sys +from quart_common.web.logger import build_logger -formatter = Formatter(fmt="%(levelname)s %(module)s: %(message)s") -handler = AsyncStreamHandler(stream=sys.stdout) -handler.formatter = formatter -logger = Logger(name="simple-picoshare", level="DEBUG" if os.getenv("WEB_DEBUG", False) == "true" else "INFO") -logger.handlers = [handler] +logger = build_logger(name="simple-picoshare") diff --git a/my_modules/decoratory/header.py b/my_modules/decoratory/header.py index a1bbae3..1683854 100644 --- a/my_modules/decoratory/header.py +++ b/my_modules/decoratory/header.py @@ -1,164 +1,27 @@ from my_modules.app.constens import THE_IP_BOT_MANAGER from my_modules.app.logger import logger from my_modules.app.setup import LIMITER -from my_modules.functions import get_ip +from quart_common.web.auth import ( + get_auth_token, + build_verify_token, + build_token_required, +) +from quart_common.web.feature_flags import build_feature_flag_required +from quart_common.web.decorators import ( + parse_request_data, + format_response, + build_apply_limit, + login_required, +) -from quart import jsonify, request, url_for, Response, current_app, session, abort -from functools import wraps -from datetime import datetime -import asyncio, msgpack, json +verify_token = build_verify_token(logger=logger) +token_required = build_token_required(logger=logger, verify_token=verify_token) -def encode_object_default(obj): - if isinstance(obj, datetime): - return obj.strftime('%a, %d %b %Y %H:%M:%S %Z') - raise TypeError(f"Type {type(obj)} not serializable") +apply_limit = build_apply_limit( + limiter=LIMITER, + ip_bot_manager=THE_IP_BOT_MANAGER, +) -# Helper function to extract the token -async def get_auth_token(): - auth_header = request.headers.get('Authorization') - if auth_header: - try: - return auth_header.split(" ")[1] - except IndexError: - pass - - return None - -async def verify_token(token:str): - decoded_payload = await current_app.convex.decode_access_token_payload(access_token=token) - decoded_payload_error_state = decoded_payload.get('state', None) - - if decoded_payload is None: - return {'error': "No Data from Database"}, 504 - elif decoded_payload_error_state == 1: - await logger.error(decoded_payload.get('error')) - return {'error': 'Invalid access token'}, 401 - elif decoded_payload_error_state == 2: - await logger.error(decoded_payload.get('error')) - return {'error': 'Wrong access token type'}, 401 - elif decoded_payload_error_state == 3: - await logger.error(decoded_payload.get('error')) - return {'error': 'Refresh token not found', 'msg': 'Please login again and generate a new Token', 'url': url_for('auth_login.login')}, 403 - elif decoded_payload_error_state == 4: - await logger.error(decoded_payload.get('error')) - return {'error': 'Refresh token expired'}, 401 - - return decoded_payload, None - -# Custom decorator for token validation -def token_required(func): - @wraps(func) - async def wrapper(*args, **kwargs): - token = await get_auth_token() - if not token: - await logger.error('API Token is missing') - return jsonify(error='Token is missing'), 400 - - decoded_payload, status_code = await verify_token(token) - decoded_payload_error = decoded_payload.get('error', None) - if decoded_payload_error: - return jsonify(decoded_payload), status_code - - return await func(user=decoded_payload, *args, **kwargs) - return wrapper - -# Custom decorator for content type reading, convertig dict to response -def parse_request_data(func): - @wraps(func) - async def wrapper(*args, **kwargs): - content_type = request.headers.get('Content-Type', '').lower() - data = None - body = await request.body - - if body: - if 'application/msgpack' in content_type or 'application/x-msgpack' in content_type: - try: - data = await asyncio.to_thread(msgpack.unpackb, body, raw=False) - except Exception: - return jsonify({'error': 'Invalid MessagePack'}), 400 - elif 'application/json' in content_type: - data = await request.get_json(silent=True) - if data is None: - return jsonify({'error': 'Invalid JSON'}), 400 - else: - if request.method in ['POST', 'PUT', 'PATCH', 'DELETE']: - return jsonify({'error': 'Unsupported Content-Type'}), 415 - # else: - # if request.method in ['POST', 'PUT', 'PATCH']: - # return jsonify({'error': 'Empty request body'}), 400 - - return await func(data=data, *args, **kwargs) - return wrapper - -def format_response(func): - @wraps(func) - async def wrapper(*args, **kwargs): - result = await func(*args, **kwargs) - - # Unpack result: (data), (data, status), (data, headers), (data, status, headers) - data = None - status = 200 - headers = {} - - if isinstance(result, tuple): - data = result[0] - if len(result) == 2: - if isinstance(result[1], dict): - headers = result[1] - else: - status = result[1] - elif len(result) == 3: - status = result[1] - headers = result[2] - else: - data = result - - accept = request.headers.get('Accept', '').lower() - if 'application/msgpack' in accept or 'application/x-msgpack' in accept: - packed = await asyncio.to_thread(msgpack.packb, data, default=encode_object_default, use_bin_type=True) - return Response(packed, content_type='application/msgpack', status=status, headers=headers) - else: - json_str = await asyncio.to_thread(json.dumps, data, ensure_ascii=False, default=encode_object_default) - response = Response(json_str, status=status, content_type='application/json') - response.headers.update(headers) - return response - - return wrapper - -# Custom decorator for adding limits for spezific methodes by endpoint -def apply_limit(endpoint_name, limits:dict=None): - def make_key_func(endpoint): - def key_func(): - ip = get_ip() - if THE_IP_BOT_MANAGER.is_client_ip_always_allowed(ip): - return None # No key, no increment, no enforcement - - # Combine endpoint name and HTTP method (and client IP) into the rate-limit key - return f":{ip}:{endpoint}:{request.method}:" - return key_func - - def decorator(func): - @wraps(func) - async def wrapped(*args, **kwargs): - return await func(*args, **kwargs) - - rules = limits.get(endpoint_name) - def dynamic_limit(): - if isinstance(rules, dict): - return rules.get(request.method.upper(), "10000 per second") - return rules or "10000 per second" - - key_fn = make_key_func(endpoint_name) - return LIMITER.limit(dynamic_limit, key_func=key_fn)(wrapped) - return decorator - -# Check if User is loggedin -def login_required(func): - @wraps(func) - async def decorated_function(*args, **kwargs): - user_session = session.get('user') - if user_session is None: - abort(401) - - return await func(user=user_session, *args, **kwargs) - return decorated_function +feature_flag_required = build_feature_flag_required( + logger=logger, +) diff --git a/my_modules/functions.py b/my_modules/functions.py index 05fd0bf..94a91f0 100644 --- a/my_modules/functions.py +++ b/my_modules/functions.py @@ -1,38 +1,15 @@ -from quart import has_request_context, request, has_websocket_context, websocket +from quart_common.web.request import ( + get_ip, + get_my_ip_address, + get_local_ip_addresses, + generate_all_ips, + replace_last_ip_segment, + get_request_context, + is_valid_uuid, +) +from quart_common.web.env import is_development_environment, is_testing_environment + from flask_limiter import Limiter -from uuid import UUID -import subprocess, aiohttp - -# Get IPs -def get_ip(): - context = get_request_context() - if context: - xff = context.headers.get("X-Forwarded-For", "") - return xff.split(",")[0].strip() if xff else context.remote_addr - return None # No active request or websocket context - -async def get_my_ip_address(): - async with aiohttp.ClientSession() as session: - async with session.get("https://ipinfo.io/ip") as response: - if response.status == 200: - return await response.text() - raise aiohttp.ClientError(f'Could not get IP: {response.status} {await response.text()}') - -def get_local_ip_addresses(): - try: - result = subprocess.run(['hostname', '-I'], capture_output=True, text=True) - first_ip = result.stdout.strip().split()[0] - return first_ip - except subprocess.CalledProcessError as e: - return None - except IndexError: - return None - -def generate_all_ips(base_ip:str) -> set: - ips = set() - for i in range(1, 255): # 1 to 254 inclusive - ips.add(replace_last_ip_segment(base_ip, i)) - return ips # Limiter Key Gen def custom_limit_key(): @@ -52,25 +29,3 @@ def enforce_custom_limit(limiter:Limiter, key:str, limit_count: int = 3, window_ current = limiter.storage.incr(key, expiry=window_sec) if current > limit_count: raise LookupError("To Many 404 Requests") - -## Helper -def replace_last_ip_segment(ip:str, new_value:str="1") -> str: - parts = ip.strip().split('.') - if len(parts) == 4: - parts[-1] = str(new_value) - return '.'.join(parts) - raise ValueError("Invalid IP address format") - -def get_request_context(): - if has_request_context(): - return request - elif has_websocket_context(): - return websocket - return None - -def is_valid_uuid(value: str) -> bool: - try: - UUID(value) - return True - except ValueError: - return False diff --git a/quart_common b/quart_common index 80d6d12..58d4d80 160000 --- a/quart_common +++ b/quart_common @@ -1 +1 @@ -Subproject commit 80d6d123df3640eed3364122509c46826ea86473 +Subproject commit 58d4d8043240b3b95e5767b6e3f1f45c23aec8b3