use new quart_common functions
This commit is contained in:
@@ -1,12 +1,4 @@
|
|||||||
from aiologger.formatters.base import Formatter
|
from quart_common.web.logger import build_logger
|
||||||
from aiologger.handlers.streams import AsyncStreamHandler
|
|
||||||
from aiologger import Logger
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
formatter = Formatter(fmt="%(levelname)s %(module)s: %(message)s")
|
|
||||||
handler = AsyncStreamHandler(stream=sys.stdout)
|
|
||||||
handler.formatter = formatter
|
|
||||||
|
|
||||||
logger = Logger(name="simple-picoshare", level="DEBUG" if os.getenv("WEB_DEBUG", False) == "true" else "INFO")
|
logger = build_logger(name="simple-picoshare")
|
||||||
logger.handlers = [handler]
|
|
||||||
|
|||||||
+21
-158
@@ -1,164 +1,27 @@
|
|||||||
from my_modules.app.constens import THE_IP_BOT_MANAGER
|
from my_modules.app.constens import THE_IP_BOT_MANAGER
|
||||||
from my_modules.app.logger import logger
|
from my_modules.app.logger import logger
|
||||||
from my_modules.app.setup import LIMITER
|
from my_modules.app.setup import LIMITER
|
||||||
from my_modules.functions import get_ip
|
from quart_common.web.auth import (
|
||||||
|
get_auth_token,
|
||||||
|
build_verify_token,
|
||||||
|
build_token_required,
|
||||||
|
)
|
||||||
|
from quart_common.web.feature_flags import build_feature_flag_required
|
||||||
|
from quart_common.web.decorators import (
|
||||||
|
parse_request_data,
|
||||||
|
format_response,
|
||||||
|
build_apply_limit,
|
||||||
|
login_required,
|
||||||
|
)
|
||||||
|
|
||||||
from quart import jsonify, request, url_for, Response, current_app, session, abort
|
verify_token = build_verify_token(logger=logger)
|
||||||
from functools import wraps
|
token_required = build_token_required(logger=logger, verify_token=verify_token)
|
||||||
from datetime import datetime
|
|
||||||
import asyncio, msgpack, json
|
|
||||||
|
|
||||||
def encode_object_default(obj):
|
apply_limit = build_apply_limit(
|
||||||
if isinstance(obj, datetime):
|
limiter=LIMITER,
|
||||||
return obj.strftime('%a, %d %b %Y %H:%M:%S %Z')
|
ip_bot_manager=THE_IP_BOT_MANAGER,
|
||||||
raise TypeError(f"Type {type(obj)} not serializable")
|
)
|
||||||
|
|
||||||
# Helper function to extract the token
|
feature_flag_required = build_feature_flag_required(
|
||||||
async def get_auth_token():
|
logger=logger,
|
||||||
auth_header = request.headers.get('Authorization')
|
)
|
||||||
if auth_header:
|
|
||||||
try:
|
|
||||||
return auth_header.split(" ")[1]
|
|
||||||
except IndexError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def verify_token(token:str):
|
|
||||||
decoded_payload = await current_app.convex.decode_access_token_payload(access_token=token)
|
|
||||||
decoded_payload_error_state = decoded_payload.get('state', None)
|
|
||||||
|
|
||||||
if decoded_payload is None:
|
|
||||||
return {'error': "No Data from Database"}, 504
|
|
||||||
elif decoded_payload_error_state == 1:
|
|
||||||
await logger.error(decoded_payload.get('error'))
|
|
||||||
return {'error': 'Invalid access token'}, 401
|
|
||||||
elif decoded_payload_error_state == 2:
|
|
||||||
await logger.error(decoded_payload.get('error'))
|
|
||||||
return {'error': 'Wrong access token type'}, 401
|
|
||||||
elif decoded_payload_error_state == 3:
|
|
||||||
await logger.error(decoded_payload.get('error'))
|
|
||||||
return {'error': 'Refresh token not found', 'msg': 'Please login again and generate a new Token', 'url': url_for('auth_login.login')}, 403
|
|
||||||
elif decoded_payload_error_state == 4:
|
|
||||||
await logger.error(decoded_payload.get('error'))
|
|
||||||
return {'error': 'Refresh token expired'}, 401
|
|
||||||
|
|
||||||
return decoded_payload, None
|
|
||||||
|
|
||||||
# Custom decorator for token validation
|
|
||||||
def token_required(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
token = await get_auth_token()
|
|
||||||
if not token:
|
|
||||||
await logger.error('API Token is missing')
|
|
||||||
return jsonify(error='Token is missing'), 400
|
|
||||||
|
|
||||||
decoded_payload, status_code = await verify_token(token)
|
|
||||||
decoded_payload_error = decoded_payload.get('error', None)
|
|
||||||
if decoded_payload_error:
|
|
||||||
return jsonify(decoded_payload), status_code
|
|
||||||
|
|
||||||
return await func(user=decoded_payload, *args, **kwargs)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
# Custom decorator for content type reading, convertig dict to response
|
|
||||||
def parse_request_data(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
content_type = request.headers.get('Content-Type', '').lower()
|
|
||||||
data = None
|
|
||||||
body = await request.body
|
|
||||||
|
|
||||||
if body:
|
|
||||||
if 'application/msgpack' in content_type or 'application/x-msgpack' in content_type:
|
|
||||||
try:
|
|
||||||
data = await asyncio.to_thread(msgpack.unpackb, body, raw=False)
|
|
||||||
except Exception:
|
|
||||||
return jsonify({'error': 'Invalid MessagePack'}), 400
|
|
||||||
elif 'application/json' in content_type:
|
|
||||||
data = await request.get_json(silent=True)
|
|
||||||
if data is None:
|
|
||||||
return jsonify({'error': 'Invalid JSON'}), 400
|
|
||||||
else:
|
|
||||||
if request.method in ['POST', 'PUT', 'PATCH', 'DELETE']:
|
|
||||||
return jsonify({'error': 'Unsupported Content-Type'}), 415
|
|
||||||
# else:
|
|
||||||
# if request.method in ['POST', 'PUT', 'PATCH']:
|
|
||||||
# return jsonify({'error': 'Empty request body'}), 400
|
|
||||||
|
|
||||||
return await func(data=data, *args, **kwargs)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def format_response(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
result = await func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Unpack result: (data), (data, status), (data, headers), (data, status, headers)
|
|
||||||
data = None
|
|
||||||
status = 200
|
|
||||||
headers = {}
|
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
data = result[0]
|
|
||||||
if len(result) == 2:
|
|
||||||
if isinstance(result[1], dict):
|
|
||||||
headers = result[1]
|
|
||||||
else:
|
|
||||||
status = result[1]
|
|
||||||
elif len(result) == 3:
|
|
||||||
status = result[1]
|
|
||||||
headers = result[2]
|
|
||||||
else:
|
|
||||||
data = result
|
|
||||||
|
|
||||||
accept = request.headers.get('Accept', '').lower()
|
|
||||||
if 'application/msgpack' in accept or 'application/x-msgpack' in accept:
|
|
||||||
packed = await asyncio.to_thread(msgpack.packb, data, default=encode_object_default, use_bin_type=True)
|
|
||||||
return Response(packed, content_type='application/msgpack', status=status, headers=headers)
|
|
||||||
else:
|
|
||||||
json_str = await asyncio.to_thread(json.dumps, data, ensure_ascii=False, default=encode_object_default)
|
|
||||||
response = Response(json_str, status=status, content_type='application/json')
|
|
||||||
response.headers.update(headers)
|
|
||||||
return response
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
# Custom decorator for adding limits for spezific methodes by endpoint
|
|
||||||
def apply_limit(endpoint_name, limits:dict=None):
|
|
||||||
def make_key_func(endpoint):
|
|
||||||
def key_func():
|
|
||||||
ip = get_ip()
|
|
||||||
if THE_IP_BOT_MANAGER.is_client_ip_always_allowed(ip):
|
|
||||||
return None # No key, no increment, no enforcement
|
|
||||||
|
|
||||||
# Combine endpoint name and HTTP method (and client IP) into the rate-limit key
|
|
||||||
return f":{ip}:{endpoint}:{request.method}:"
|
|
||||||
return key_func
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapped(*args, **kwargs):
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
rules = limits.get(endpoint_name)
|
|
||||||
def dynamic_limit():
|
|
||||||
if isinstance(rules, dict):
|
|
||||||
return rules.get(request.method.upper(), "10000 per second")
|
|
||||||
return rules or "10000 per second"
|
|
||||||
|
|
||||||
key_fn = make_key_func(endpoint_name)
|
|
||||||
return LIMITER.limit(dynamic_limit, key_func=key_fn)(wrapped)
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
# Check if User is loggedin
|
|
||||||
def login_required(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def decorated_function(*args, **kwargs):
|
|
||||||
user_session = session.get('user')
|
|
||||||
if user_session is None:
|
|
||||||
abort(401)
|
|
||||||
|
|
||||||
return await func(user=user_session, *args, **kwargs)
|
|
||||||
return decorated_function
|
|
||||||
|
|||||||
+11
-56
@@ -1,38 +1,15 @@
|
|||||||
from quart import has_request_context, request, has_websocket_context, websocket
|
from quart_common.web.request import (
|
||||||
|
get_ip,
|
||||||
|
get_my_ip_address,
|
||||||
|
get_local_ip_addresses,
|
||||||
|
generate_all_ips,
|
||||||
|
replace_last_ip_segment,
|
||||||
|
get_request_context,
|
||||||
|
is_valid_uuid,
|
||||||
|
)
|
||||||
|
from quart_common.web.env import is_development_environment, is_testing_environment
|
||||||
|
|
||||||
from flask_limiter import Limiter
|
from flask_limiter import Limiter
|
||||||
from uuid import UUID
|
|
||||||
import subprocess, aiohttp
|
|
||||||
|
|
||||||
# Get IPs
|
|
||||||
def get_ip():
|
|
||||||
context = get_request_context()
|
|
||||||
if context:
|
|
||||||
xff = context.headers.get("X-Forwarded-For", "")
|
|
||||||
return xff.split(",")[0].strip() if xff else context.remote_addr
|
|
||||||
return None # No active request or websocket context
|
|
||||||
|
|
||||||
async def get_my_ip_address():
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get("https://ipinfo.io/ip") as response:
|
|
||||||
if response.status == 200:
|
|
||||||
return await response.text()
|
|
||||||
raise aiohttp.ClientError(f'Could not get IP: {response.status} {await response.text()}')
|
|
||||||
|
|
||||||
def get_local_ip_addresses():
|
|
||||||
try:
|
|
||||||
result = subprocess.run(['hostname', '-I'], capture_output=True, text=True)
|
|
||||||
first_ip = result.stdout.strip().split()[0]
|
|
||||||
return first_ip
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
return None
|
|
||||||
except IndexError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def generate_all_ips(base_ip:str) -> set:
|
|
||||||
ips = set()
|
|
||||||
for i in range(1, 255): # 1 to 254 inclusive
|
|
||||||
ips.add(replace_last_ip_segment(base_ip, i))
|
|
||||||
return ips
|
|
||||||
|
|
||||||
# Limiter Key Gen
|
# Limiter Key Gen
|
||||||
def custom_limit_key():
|
def custom_limit_key():
|
||||||
@@ -52,25 +29,3 @@ def enforce_custom_limit(limiter:Limiter, key:str, limit_count: int = 3, window_
|
|||||||
current = limiter.storage.incr(key, expiry=window_sec)
|
current = limiter.storage.incr(key, expiry=window_sec)
|
||||||
if current > limit_count:
|
if current > limit_count:
|
||||||
raise LookupError("To Many 404 Requests")
|
raise LookupError("To Many 404 Requests")
|
||||||
|
|
||||||
## Helper
|
|
||||||
def replace_last_ip_segment(ip:str, new_value:str="1") -> str:
|
|
||||||
parts = ip.strip().split('.')
|
|
||||||
if len(parts) == 4:
|
|
||||||
parts[-1] = str(new_value)
|
|
||||||
return '.'.join(parts)
|
|
||||||
raise ValueError("Invalid IP address format")
|
|
||||||
|
|
||||||
def get_request_context():
|
|
||||||
if has_request_context():
|
|
||||||
return request
|
|
||||||
elif has_websocket_context():
|
|
||||||
return websocket
|
|
||||||
return None
|
|
||||||
|
|
||||||
def is_valid_uuid(value: str) -> bool:
|
|
||||||
try:
|
|
||||||
UUID(value)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|||||||
+1
-1
Submodule quart_common updated: 80d6d123df...58d4d80432
Reference in New Issue
Block a user