Initial commit
This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
quart_session.sessions
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Server-side Sessions and SessionInterfaces.
|
||||
|
||||
:copyright: (c) 2014 by Shipeng Feng.
|
||||
:copyright: (c) 2020 by Sander.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
import time
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
import asyncio
|
||||
|
||||
from quart import Quart
|
||||
from quart.wrappers import BaseRequestWebsocket, Response
|
||||
from quart.wrappers.response import FileBody
|
||||
from quart.sessions import SessionInterface as QuartSessionInterface, SecureCookieSession
|
||||
from quart.json.tag import TaggedJSONSerializer
|
||||
from itsdangerous import Signer, BadSignature, want_bytes
|
||||
|
||||
|
||||
def total_seconds(td):
|
||||
return td.days * 60 * 60 * 24 + td.seconds
|
||||
|
||||
|
||||
class ServerSideSession(SecureCookieSession):
|
||||
"""Baseclass for server-side based sessions."""
|
||||
|
||||
def __init__(self, initial=None, sid=None, permanent=None, addr=None):
|
||||
super(ServerSideSession, self).__init__(**initial or {})
|
||||
self.sid = sid
|
||||
if permanent:
|
||||
self.permanent = permanent
|
||||
if addr:
|
||||
self.addr = addr
|
||||
self._dirty = False
|
||||
|
||||
@property
|
||||
def addr(self) -> str:
|
||||
return self.get('_addr', False) # type: ignore
|
||||
|
||||
@addr.setter
|
||||
def addr(self, value: str) -> None:
|
||||
self['_addr'] = value # type: ignore
|
||||
|
||||
def dirty(self):
|
||||
"""Marks the session to be written/saved.
|
||||
|
||||
.. note::
|
||||
|
||||
This feature only works if you have set ``SESSION_EXPLICIT``
|
||||
to ``True``, at which point you'll have to explicitly mark
|
||||
each session before they'll get processed and saved.
|
||||
|
||||
Example::
|
||||
|
||||
app.config['SESSION_EXPLICIT'] = True
|
||||
Session(app)
|
||||
|
||||
@app.route('/')
|
||||
def root():
|
||||
session['foo'] = 'bar'
|
||||
session.dirty()
|
||||
return "Hello World!"
|
||||
"""
|
||||
self._dirty = True
|
||||
|
||||
|
||||
class RedisSession(ServerSideSession):
|
||||
pass
|
||||
|
||||
|
||||
class MemcachedSession(ServerSideSession):
|
||||
pass
|
||||
|
||||
|
||||
class NullSession(ServerSideSession):
|
||||
pass
|
||||
|
||||
|
||||
class SessionInterface(QuartSessionInterface):
|
||||
"""Baseclass for session interfaces"""
|
||||
|
||||
serializer = TaggedJSONSerializer()
|
||||
session_class = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key_prefix: str,
|
||||
use_signer: bool = False,
|
||||
permanent: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.key_prefix = key_prefix
|
||||
self.use_signer = use_signer
|
||||
self.permanent = permanent
|
||||
self._config = kwargs
|
||||
|
||||
async def open_session(
|
||||
self, app: Quart, request: BaseRequestWebsocket
|
||||
) -> Optional[SecureCookieSession]:
|
||||
sid = request.cookies.get(app.session_cookie_name)
|
||||
addr = request.headers.get('X-Forwarded-For', request.remote_addr) if \
|
||||
self._config['SESSION_HIJACK_PROTECTION'] else None
|
||||
options = {"sid": sid, "permanent": self.permanent, "addr": addr}
|
||||
|
||||
if not sid:
|
||||
options['sid'] = self._generate_sid()
|
||||
return self.session_class(**options)
|
||||
if self.use_signer:
|
||||
signer = self._get_signer(app)
|
||||
if signer is None:
|
||||
app.logger.warning("Failed to obtain a valid signer.")
|
||||
return None
|
||||
try:
|
||||
sid_as_bytes = signer.unsign(sid)
|
||||
sid = sid_as_bytes.decode()
|
||||
except BadSignature:
|
||||
app.logger.warning(f"Bad signature for sid: {sid}.")
|
||||
options['sid'] = self._generate_sid()
|
||||
return self.session_class(**options)
|
||||
|
||||
val = await self._backend_get(app, self.key_prefix + sid)
|
||||
if val is None:
|
||||
options['sid'] = self._generate_sid()
|
||||
return self.session_class(**options)
|
||||
|
||||
try:
|
||||
data = self.serializer.loads(val)
|
||||
except:
|
||||
app.logger.warning(f"Failed to deserialize session "
|
||||
f"data for sid: {sid}. Generating new sid.")
|
||||
options['sid'] = self._generate_sid()
|
||||
return self.session_class(**options)
|
||||
|
||||
prevent_hijack = self._config['SESSION_HIJACK_PROTECTION']
|
||||
if prevent_hijack is False:
|
||||
pass
|
||||
elif isinstance(prevent_hijack, bool) and \
|
||||
prevent_hijack is True:
|
||||
if self._config['SESSION_HIJACK_REVERSE_PROXY'] is True:
|
||||
addr = request.headers.get('X-Forwarded-For', request.remote_addr)
|
||||
else:
|
||||
addr = request.remote_addr
|
||||
if data.get('_addr', addr) != addr:
|
||||
await self._backend_delete(app, self.key_prefix + sid)
|
||||
options['sid'] = self._generate_sid()
|
||||
return self.session_class(**options)
|
||||
|
||||
res = self.session_class(data, sid)
|
||||
return res
|
||||
|
||||
async def save_session( # type: ignore
|
||||
self,
|
||||
app: "Quart",
|
||||
session: SecureCookieSession,
|
||||
response: Response
|
||||
) -> None:
|
||||
# prevent set-cookie
|
||||
# motivation: https://github.com/fengsp/flask-session/pull/70
|
||||
if self._config['SESSION_EXPLICIT'] is True and \
|
||||
not session._dirty:
|
||||
return
|
||||
|
||||
# prevent set-cookie on (static) file responses
|
||||
if self._config['SESSION_STATIC_FILE'] is False and \
|
||||
isinstance(response.response, FileBody):
|
||||
return
|
||||
session_key = self.key_prefix + session.sid
|
||||
domain = self.get_cookie_domain(app)
|
||||
path = self.get_cookie_path(app)
|
||||
if not session:
|
||||
if session.modified:
|
||||
await self._backend_delete(app=app, key=session_key)
|
||||
response.delete_cookie(app.session_cookie_name,
|
||||
domain=domain, path=path)
|
||||
return
|
||||
httponly = self.get_cookie_httponly(app)
|
||||
secure = self.get_cookie_secure(app)
|
||||
expires = self.get_expiration_time(app, session)
|
||||
|
||||
val = self.serializer.dumps(dict(session))
|
||||
await self._backend_set(app=app, key=session_key, value=val)
|
||||
if self.use_signer:
|
||||
session_id = self._get_signer(app).sign(want_bytes(session.sid))
|
||||
else:
|
||||
session_id = session.sid
|
||||
response.set_cookie(app.session_cookie_name, session_id,
|
||||
expires=expires, httponly=httponly,
|
||||
domain=domain, path=path, secure=secure)
|
||||
|
||||
async def create(self, app: Quart):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _backend_get(self, app: Quart, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _backend_set(self, app: Quart, key: str, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _backend_delete(self, app: Quart, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _generate_sid(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
def _get_signer(self, app) -> Optional[Signer]:
|
||||
if not app.secret_key:
|
||||
return None
|
||||
return Signer(app.secret_key, salt='quart-session',
|
||||
key_derivation='hmac')
|
||||
|
||||
|
||||
class RedisSessionInterface(SessionInterface):
|
||||
"""Uses the Redis key-value store as a session backend.
|
||||
|
||||
:param redis: ``aioredis.Redis`` instance.
|
||||
:param key_prefix: A prefix that is added to all Redis store keys.
|
||||
:param use_signer: Whether to sign the session id cookie or not.
|
||||
:param permanent: Whether to use permanent session or not.
|
||||
:param kwargs: Quart-session config, used internally.
|
||||
"""
|
||||
|
||||
session_class = RedisSession
|
||||
|
||||
def __init__(
|
||||
self, redis, key_prefix: str, use_signer: bool = False,
|
||||
permanent: bool = True, **kwargs):
|
||||
super(RedisSessionInterface, self).__init__(
|
||||
key_prefix=key_prefix, use_signer=use_signer,
|
||||
permanent=permanent, **kwargs)
|
||||
self.redis = redis
|
||||
|
||||
async def create(self, app: Quart) -> None:
|
||||
if self.redis is None:
|
||||
import aioredis
|
||||
self.redis = await aioredis.create_redis("redis://localhost")
|
||||
|
||||
async def _backend_get(self, app: Quart, key: str):
|
||||
return await self.redis.get(key)
|
||||
|
||||
async def _backend_set(self, app: Quart, key: str, value):
|
||||
return await self.redis.setex(
|
||||
key=key, value=value,
|
||||
seconds=total_seconds(app.permanent_session_lifetime))
|
||||
|
||||
async def _backend_delete(self, app: Quart, key: str):
|
||||
return await self.redis.delete(key)
|
||||
|
||||
|
||||
class MemcachedSessionInterface(SessionInterface):
|
||||
"""Uses the Memcached key-value store as a session backend.
|
||||
|
||||
:param client: ``aiomcache.Client`` instance.
|
||||
:param key_prefix: A prefix that is added to all Redis store keys.
|
||||
:param use_signer: Whether to sign the session id cookie or not.
|
||||
:param permanent: Whether to use permanent session or not.
|
||||
:param kwargs: Quart-session config, used internally.
|
||||
"""
|
||||
|
||||
session_class = MemcachedSession
|
||||
|
||||
def __init__(
|
||||
self, memcached, key_prefix: str, use_signer: bool = False,
|
||||
permanent: bool = True, **kwargs):
|
||||
super(MemcachedSessionInterface, self).__init__(
|
||||
key_prefix=key_prefix, use_signer=use_signer,
|
||||
permanent=permanent, **kwargs)
|
||||
self.memcached = memcached
|
||||
|
||||
@asyncio.coroutine
|
||||
def create(self, app: Quart) -> None:
|
||||
if self.memcached is None:
|
||||
import aiomcache
|
||||
loop = asyncio.get_running_loop()
|
||||
self.memcached = aiomcache.Client("127.0.0.1", 11211, loop=loop)
|
||||
|
||||
def _get_memcache_timeout(self, timeout):
|
||||
"""
|
||||
Memcached deals with long (> 30 days) timeouts in a special
|
||||
way. Call this function to obtain a safe value for your timeout.
|
||||
"""
|
||||
if timeout > 2592000: # 60*60*24*30, 30 days
|
||||
# See http://code.google.com/p/memcached/wiki/FAQ
|
||||
# "You can set expire times up to 30 days in the future. After that
|
||||
# memcached interprets it as a date, and will expire the item after
|
||||
# said date. This is a simple (but obscure) mechanic."
|
||||
#
|
||||
# This means that we have to switch to absolute timestamps.
|
||||
timeout += int(time.time())
|
||||
return timeout
|
||||
|
||||
async def _backend_get(self, app: Quart, key: str):
|
||||
key = key.encode()
|
||||
return await self.memcached.get(key)
|
||||
|
||||
async def _backend_set(self, app: Quart, key: str, value):
|
||||
key = key.encode()
|
||||
value = value.encode()
|
||||
expiry = self._get_memcache_timeout(total_seconds(
|
||||
app.permanent_session_lifetime))
|
||||
return await self.memcached.set(key=key, value=value,
|
||||
exptime=expiry)
|
||||
|
||||
async def _backend_delete(self, app: Quart, key: str):
|
||||
key = key.encode()
|
||||
return await self.memcached.delete(key)
|
||||
|
||||
|
||||
class NullSessionInterface(SessionInterface):
|
||||
"""Used to open a :class:`quart.sessions.NullSession` instance.
|
||||
"""
|
||||
|
||||
def open_session(self, app: Quart, request: BaseRequestWebsocket):
|
||||
return None
|
||||
Reference in New Issue
Block a user