137 lines
3.8 KiB
Python
137 lines
3.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
quart_session.redis_trio.connection
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
|
|
A simple Redis Trio client.
|
|
|
|
:copyright: (c) 2017 by Bogdan Paul Popa.
|
|
:copyright: (c) 2019 by Oleksii Aleksieiev.
|
|
:copyright: (c) 2020 by dsc.
|
|
:license: BSD, see LICENSE for more details.
|
|
"""
|
|
|
|
from typing import Union
|
|
|
|
import trio
|
|
|
|
from .serialization import atom, serialize
|
|
from .errors import ProtocolError, ResponseError, ResponseTypeError
|
|
|
|
|
|
# First byte of each reply type understood by ``RedisConnection.parse``.
# They are kept as ints because indexing a ``bytes`` object yields ints,
# so ``data[0]`` can be compared against them directly.
#
#   SP  "+"  simple string
#   EP  "-"  error
#   IP  ":"  integer
#   BP  "$"  bulk string
#   AP  "*"  array
SP, EP, IP, BP, AP = b"+-:$*"

#: The set of known Redis response prefixes.
known_prefixes = {AP, BP, EP, IP, SP}
|
|
|
|
|
|
class ReadMore(Exception):
    """Signal that the buffered bytes end in the middle of a reply.

    ``parse`` raises this when it cannot decode a complete item from the
    data it was given; the caller is expected to receive more bytes from
    the socket, append them to the buffer, and call ``parse`` again.
    """
|
|
|
|
|
|
class RedisConnection:
    """This class facilitates all communication with Redis via a trio socket.

    Commands are written with :meth:`send_command`; replies are read from
    the socket and decoded with :meth:`process_response`.

    Warning:
      The interface of this class may change at any time, without
      notice, due to the experimental nature of Trio.
    """

    def __init__(self, addr: Union[bytes, str], port: int, bufsize: int = 16384):
        """Create an unconnected socket for the server at ``(addr, port)``.

        :param addr: host of the Redis server.
        :param port: TCP port of the Redis server.
        :param bufsize: maximum number of bytes requested per ``recv`` call.
        """
        self.addr = (addr, port)
        self.sock = trio.socket.socket()
        self.bufsize = bufsize

    async def connect(self):
        """Connect the underlying socket to ``self.addr``."""
        await self.sock.connect(self.addr)

    def close(self):
        """Close the underlying socket."""
        self.sock.close()

    async def _send_all(self, data):
        """Send every byte of ``data``, retrying after partial sends.

        ``socket.send`` may transmit only a prefix of the buffer and
        returns how many bytes it actually wrote, so keep sending the
        remainder until nothing is left.
        """
        with memoryview(data) as view:
            total_sent = 0
            while total_sent < len(view):
                total_sent += await self.sock.send(view[total_sent:])

    async def _recv_chunk(self):
        """Receive up to ``self.bufsize`` bytes from the socket.

        :raises ConnectionError: if the peer closed the connection
          (``recv`` returned no data).  Without this check the read loops
          would spin forever waiting for a reply that can never arrive.
        """
        chunk = await self.sock.recv(self.bufsize)
        if not chunk:
            raise ConnectionError(
                "Connection closed while waiting for a Redis response."
            )
        return chunk

    async def send_command(self, command, *args):
        """Serialize ``command`` and ``args`` and send them as one line.

        ``command`` is wrapped with ``atom`` first; every piece is then
        passed through ``serialize``, the pieces are joined with single
        spaces and the line is terminated with CRLF.
        """
        command_and_args = (serialize(arg) for arg in (atom(command),) + args)
        data = b" ".join(command_and_args) + b"\r\n"
        await self._send_all(data)

    async def process_command(self, *command_and_args):
        """Send a command and return its decoded reply."""
        await self.send_command(*command_and_args)
        return await self.process_response()

    async def process_command_ok(self, *command_and_args):
        """Send a command and return whether the reply was ``b"OK"``."""
        return (await self.process_command(*command_and_args)) == b"OK"

    async def process_response(self):
        """Read from the socket until one complete reply can be decoded.

        Bytes left over after the first reply are discarded.

        :returns: the decoded reply (see :meth:`parse`).
        :raises ConnectionError: if the connection is closed before a
          complete reply has been received.
        """
        data = await self._recv_chunk()
        while True:
            try:
                item, _ = await self.parse(data)
            except ReadMore:
                # Incomplete reply: grow the buffer and parse from scratch.
                data += await self._recv_chunk()
            else:
                return item

    async def parse(self, data):
        """Decode one reply from the front of ``data``.

        :returns: a ``(item, rest)`` tuple where ``item`` is ``bytes`` for
          simple and bulk strings, ``int`` for integers, ``list`` for
          arrays or ``None`` for a bulk string/array of length -1, and
          ``rest`` is the unconsumed remainder of ``data``.
        :raises ReadMore: if ``data`` holds no CRLF-terminated line yet, or
          a bulk string's payload has not fully arrived.
        :raises ProtocolError: if the first byte is not a known prefix.
        :raises ResponseTypeError: for ``WRONGTYPE`` error replies.
        :raises ResponseError: for every other error reply.
        """
        try:
            index = data.index(b"\r\n")
        except ValueError:
            raise ReadMore()

        prefix = data[0]
        if prefix not in known_prefixes:
            raise ProtocolError(f"Unexpected data in response: {data!r}.")

        # ``line`` is the header without prefix and CRLF; ``rest`` is
        # everything after the header's CRLF.
        line, rest = data[1:index], data[index + 2:]

        if prefix == SP:
            return line, rest

        if prefix == EP:
            error = line.decode("ascii")
            if error.startswith("WRONGTYPE"):
                raise ResponseTypeError(error[len("WRONGTYPE "):])
            if error.startswith("ERR"):
                raise ResponseError(error[len("ERR "):])
            raise ResponseError(error)

        if prefix == IP:
            return int(line), rest

        if prefix == BP:
            length = int(line)
            if length == -1:
                return None, rest
            # The payload is followed by its own CRLF, hence the ``+ 2``.
            if len(rest) < length + 2:
                raise ReadMore()
            return rest[:length], rest[length + 2:]

        if prefix == AP:
            length = int(line)
            if length == -1:
                return None, rest
            return await self.parse_array(length, rest)

        # Explicit raise instead of ``assert`` so it survives ``python -O``.
        raise AssertionError("unreachable")  # pragma: no cover

    async def parse_array(self, length, data):
        """Decode ``length`` consecutive replies, reading more as needed.

        :param length: number of items the array holds.
        :param data: bytes already received after the array header.
        :returns: a ``(items, rest)`` tuple with the decoded items and the
          unconsumed remainder of the buffer.
        :raises ConnectionError: if the connection is closed before all
          items have been received.
        """
        items = []
        while len(items) < length:
            try:
                item, data = await self.parse(data)
            except ReadMore:
                # Covers the empty-buffer case too: ``parse`` raises
                # ReadMore when there is no CRLF in ``data``.
                data += await self._recv_chunk()
            else:
                items.append(item)

        return items, data
|