Redis+Trio support

This commit is contained in:
dsc
2020-01-05 15:10:16 +01:00
parent 499e46c93d
commit df20bbff18
9 changed files with 468 additions and 22 deletions
+28 -4
View File
@@ -17,7 +17,7 @@ from typing import Optional
from quart import Quart
from .sessions import RedisSessionInterface, MemcachedSessionInterface, NullSessionInterface
from .sessions import RedisSessionInterface, RedisTrioSessionInterface, MemcachedSessionInterface, NullSessionInterface
class Session(object):
@@ -51,6 +51,7 @@ class Session(object):
"""
def __init__(self, app: Quart = None) -> None:
self._current_async_library = "asyncio"
self.app = app
if app is not None:
self.init_app(app)
@@ -60,6 +61,12 @@ class Session(object):
:param app: the Quart app object with proper configuration.
"""
try:
import quart_trio
if isinstance(app, quart_trio.QuartTrio):
self._current_async_library = "trio"
except ImportError:
pass
app.session_interface = self._get_interface(app)
@app.before_serving
@@ -85,12 +92,29 @@ class Session(object):
config = {k: v for k, v in config.items() if k.startswith('SESSION_')}
if config['SESSION_TYPE'] == 'redis':
session_interface = RedisSessionInterface(
options = {
"redis": config['SESSION_REDIS'],
"key_prefix": config['SESSION_KEY_PREFIX'],
"use_signer": config['SESSION_USE_SIGNER'],
"permanent": config['SESSION_PERMANENT'],
**config
}
if self._current_async_library == "asyncio":
session_interface = RedisSessionInterface(**options)
elif self._current_async_library == "trio":
session_interface = RedisTrioSessionInterface(**options)
else:
raise NotImplementedError("Unknown eventloop")
elif config['SESSION_TYPE'] == 'redis+trio':
session_interface = RedisTrioSessionInterface(
redis=config['SESSION_REDIS'],
key_prefix=config['SESSION_KEY_PREFIX'],
use_signer=config['SESSION_USE_SIGNER'],
permanent=config['SESSION_PERMANENT'],
**config)
premanent=config['SESSION_PERMANENT'],
**config
)
elif config['SESSION_TYPE'] == 'memcached':
session_interface = MemcachedSessionInterface(
memcached=config['SESSION_MEMCACHED'],
+37
View File
@@ -0,0 +1,37 @@
Code here is borrowed from [alekseyev/trio-redis](https://github.com/alekseyev/trio-redis), which was originally developed over at [Bogdanp/trio-redis](https://github.com/Bogdanp/trio-redis).
Since it has no active maintainers and no PyPI package - I am including it as-is.
## Usage
```python3
from quart_session.redis_trio import RedisTrio
cache = RedisTrio(
addr=b"10.0.0.3", port=6379, password=b"foo")
await cache.connect()
await cache.setex(key="foo", value=42, seconds=300)
await cache.get("foo")
```
Or,
```python3
async with RedisTrio() as cache:
await cache.set("foo", 42)
await cache.get("foo")
```
## Future work
If someone makes a Redis+Trio client that supports connection pooling, we can switch to it.
```
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more
```
+14
View File
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
"""
quart_session.redis_trio
~~~~~~~~~~~~~~~~~~~~~~
A simple Redis Trio client.
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more details.
"""
from quart_session.redis_trio.client import RedisTrio
+88
View File
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""
quart_session.redis_trio
~~~~~~~~~~~~~~~~~~~~~~
A simple Redis Trio client.
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more details.
"""
from typing import Union
from .connection import RedisConnection
class RedisTrio:
"""A simple Redis+Trio client.
Parameters:
addr(str): The IP address the Redis server is listening on.
port(int): The port the Redis server is listening on.
Examples:
>>> async with RedisTrio() as redis:
... await redis.set("foo", 42)
... await redis.get("foo")
b'42'
"""
def __init__(self, addr: Union[bytes, str] = b"127.0.0.1", port: int = 6379, password: bytes = b""):
self.conn = RedisConnection(addr, port)
self.password = password
async def connect(self):
"""Open a connection to the Redis server.
Returns:
RedisTrio: This instance.
"""
await self.conn.connect()
if self.password:
await self.auth(self.password)
return self
async def close(self):
"""Close the connection to the Redis server.
"""
await self.quit()
self.conn.close()
async def auth(self, password):
return await self.conn.process_command_ok(b"AUTH", password)
async def delete(self, *keys):
return await self.conn.process_command(b"DEL", *keys)
async def echo(self, message):
return await self.conn.process_command(b"ECHO", message)
async def flushall(self):
return await self.conn.process_command_ok(b"FLUSHALL")
async def get(self, key) -> bytes:
return await self.conn.process_command(b"GET", key)
async def quit(self):
return await self.conn.process_command(b"QUIT")
async def set(self, key, value):
return await self.conn.process_command_ok(b"SET", key, value)
async def setex(self, key: str, value: str, seconds: int):
"""Set the value and expiration of a key.
:raises TypeError: if seconds is not int
"""
if not isinstance(seconds, int):
raise TypeError("milliseconds argument must be int")
return await self.conn.process_command_ok(b"SETEX", key, seconds, value)
async def __aenter__(self):
return await self.connect()
async def __aexit__(self, exc_type, exc_value, traceback):
self.close()
+136
View File
@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
"""
quart_session.redis_trio.connection
~~~~~~~~~~~~~~~~~~~~~~
A simple Redis Trio client.
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more details.
"""
from typing import Union
import trio
from .serialization import atom, serialize
from .errors import ProtocolError, ResponseError, ResponseTypeError
SP = ord("+")
EP = ord("-")
IP = ord(":")
BP = ord("$")
AP = ord("*")
#: The set of known Redis response prefixes.
known_prefixes = {SP, EP, IP, BP, AP}
class ReadMore(Exception):
"""Raised by parse to signal that it needs more data.
"""
class RedisConnection:
"""This class facilitates all communication with Redis via a trio socket.
Warning:
The interface of this class may change at any time, without
notice, due to the experimental nature of Trio.
"""
def __init__(self, addr: Union[bytes, str], port: int, bufsize: int = 16384):
self.addr = (addr, port)
self.sock = trio.socket.socket()
self.bufsize = bufsize
async def connect(self):
await self.sock.connect(self.addr)
def close(self):
self.sock.close()
async def send_command(self, command, *args):
command_and_args = (serialize(arg) for arg in (atom(command),) + args)
data = b" ".join(command_and_args) + b"\r\n"
await self.sock.send(data)
async def process_command(self, *command_and_args):
await self.send_command(*command_and_args)
return await self.process_response()
async def process_command_ok(self, *command_and_args):
await self.send_command(*command_and_args)
return await self.process_response() == b"OK"
async def process_response(self):
data = await self.sock.recv(self.bufsize)
while True:
try:
item, _ = await self.parse(data)
return item
except ReadMore:
data += await self.sock.recv(self.bufsize)
async def parse(self, data):
try:
index = data.index(b"\r\n")
except ValueError:
raise ReadMore()
if data[0] not in known_prefixes:
raise ProtocolError(f"Unexpected data in response: {data!r}.")
elif data[0] == SP:
return data[1:index], data[index + 2:]
elif data[0] == EP:
error = data[1:index].decode("ascii")
if error.startswith("WRONGTYPE"):
raise ResponseTypeError(error[len("WRONGTYPE "):])
elif error.startswith("ERR"):
raise ResponseError(error[len("ERR "):])
else:
raise ResponseError(error)
elif data[0] == IP:
return int(data[1:index]), data[index + 2:]
elif data[0] == BP:
length, data = int(data[1:index]), data[index + 2:]
if length == -1:
return None, data
elif len(data) < length + 2:
raise ReadMore()
return data[:length], data[length + 2:]
elif data[0] == AP:
length, data = int(data[1:index]), data[index + 2:]
if length == -1:
return None, data
return await self.parse_array(length, data)
else: # pragma: no cover
assert False, "unreachable"
async def parse_array(self, length, data):
items = []
while len(items) < length:
if not data:
data += await self.sock.recv(self.bufsize)
continue
try:
item, data = await self.parse(data)
items.append(item)
except ReadMore:
data += await self.sock.recv(self.bufsize)
return items, data
+33
View File
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
"""
quart_session.redis_trio.errors
~~~~~~~~~~~~~~~~~~~~~~
A simple Redis Trio client.
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more details.
"""
class RedisError(Exception):
"""Base class for all Redis-related errors.
"""
class ProtocolError(RedisError):
"""Raised when Redis responds with something that doesn't conform
to the protocol.
"""
class ResponseError(RedisError):
"""Raised when Redis returns an error response.
"""
class ResponseTypeError(ResponseError):
"""Raised when Redis returns an error response with a `WRONGTYPE` prefix.
"""
+60
View File
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""
quart_session.redis_trio.serialization
~~~~~~~~~~~~~~~~~~~~~~
A simple Redis Trio client.
:copyright: (c) 2017 by Bogdan Paul Popa.
:copyright: (c) 2019 by Oleksii Aleksieiev.
:copyright: (c) 2020 by dsc.
:license: BSD, see LICENSE for more details.
"""
from collections import namedtuple
#: Wrapper class for values that don't have to be quoted.
atom = namedtuple("atom", ("value",))
#: The set of characters that must be escaped before being sent as
#: Redis strings.
escapes = {
ord(b"\0"): rb"\x00",
ord(b"\n"): rb"\n",
ord(b"\r"): rb"\r",
ord(b"\\"): rb"\\",
ord(b'"'): rb'\"',
}
def serialize(x):
"""Serialize `x` so that it can safely be sent to Redis.
Parameters:
x(object): The value to serialize.
Returns:
bytes: The serialized value.
"""
if isinstance(x, atom):
return x.value
elif isinstance(x, bytes):
return quote(x)
elif isinstance(x, str):
return quote(x.encode("utf-8"))
elif isinstance(x, (float, int)):
return str(x).encode("ascii")
else:
return serialize(str(x))
def quote(bs):
return b'"' + bytes(escape(bs)) + b'"'
def escape(bs):
for c in bs:
if c in escapes:
yield from escapes[c]
else:
yield c
+57 -12
View File
@@ -100,7 +100,9 @@ class SessionInterface(QuartSessionInterface):
self._config = kwargs
async def open_session(
self, app: Quart, request: BaseRequestWebsocket
self,
app: Quart,
request: BaseRequestWebsocket
) -> Optional[SecureCookieSession]:
sid = request.cookies.get(app.session_cookie_name)
addr = request.headers.get('X-Forwarded-For', request.remote_addr) if \
@@ -137,10 +139,7 @@ class SessionInterface(QuartSessionInterface):
return self.session_class(**options)
prevent_hijack = self._config['SESSION_HIJACK_PROTECTION']
if prevent_hijack is False:
pass
elif isinstance(prevent_hijack, bool) and \
prevent_hijack is True:
if prevent_hijack is True:
if self._config['SESSION_HIJACK_REVERSE_PROXY'] is True:
addr = request.headers.get('X-Forwarded-For', request.remote_addr)
else:
@@ -160,12 +159,12 @@ class SessionInterface(QuartSessionInterface):
response: Response
) -> None:
# prevent set-cookie
# motivation: https://github.com/fengsp/flask-session/pull/70
if self._config['SESSION_EXPLICIT'] is True and \
not session._dirty:
return
# prevent set-cookie on (static) file responses
# https://github.com/fengsp/flask-session/pull/70
if self._config['SESSION_STATIC_FILE'] is False and \
isinstance(response.response, FileBody):
return
@@ -226,15 +225,18 @@ class RedisSessionInterface(SessionInterface):
session_class = RedisSession
def __init__(
self, redis, key_prefix: str, use_signer: bool = False,
permanent: bool = True, **kwargs):
super(RedisSessionInterface, self).__init__(
key_prefix=key_prefix, use_signer=use_signer,
permanent=permanent, **kwargs)
def __init__(self, redis, **kwargs):
super(RedisSessionInterface, self).__init__(**kwargs)
self.redis = redis
async def create(self, app: Quart) -> None:
"""Creates ``aioredis.Redis`` instance.
.. note::
Creates a single Redis connection, you might prefer
pooling instead (see ``aioredis.Redis.create_redis_pool``)
"""
if self.redis is None:
import aioredis
self.redis = await aioredis.create_redis("redis://localhost")
@@ -251,6 +253,49 @@ class RedisSessionInterface(SessionInterface):
return await self.redis.delete(key)
class RedisTrioSessionInterface(SessionInterface):
"""Uses the Redis+Trio key-value store as a session backend.
:param redis: ``quart_session.redis_trio.RedisTrio`` instance.
:param key_prefix: A prefix that is added to all Redis store keys.
:param use_signer: Whether to sign the session id cookie or not.
:param permanent: Whether to use permanent session or not.
:param kwargs: Quart-session config, used internally.
"""
session_class = RedisSession
def __init__(self, redis, **kwargs):
super(RedisTrioSessionInterface, self).__init__(**kwargs)
self.redis_trio = redis
async def create(self, app: Quart) -> None:
"""Creates ``aioredis.Redis`` instance.
.. note::
Creates a single Redis connection. Pooling not
supported yet for ``RedisTrio``.
"""
if self.redis_trio is None:
from quart_session.redis_trio import RedisTrio
self.redis_trio = RedisTrio()
await self.redis_trio.connect()
async def _backend_get(self, app: Quart, key: str):
data = await self.redis_trio.get(key)
if data:
return data.decode()
async def _backend_set(self, app: Quart, key: str, value):
return await self.redis_trio.setex(
key=key, value=value,
seconds=total_seconds(app.permanent_session_lifetime))
async def _backend_delete(self, app: Quart, key: str):
return await self.redis_trio.delete(key)
class MemcachedSessionInterface(SessionInterface):
"""Uses the Memcached key-value store as a session backend.