253 lines
9.0 KiB
Python
253 lines
9.0 KiB
Python
import asyncio
|
|
import functools
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import socket
|
|
|
|
from gmqtt import Client as MQTTClient
|
|
from gmqtt.mqtt.constants import MQTTv311, MQTTv50
|
|
|
|
LOGGER = logging.getLogger(__name__)


class ChannelsMQTTProxy:
    """Relay messages between an MQTT broker and a channel layer.

    Incoming MQTT messages are wrapped in an event and sent, via
    ``channel_layer.group_send``, to every group that subscribed to a
    matching topic. Outgoing messages are published with :func:`publish`.

    ``self.subscriptions`` maps an MQTT topic filter to the list of group
    names interested in it.
    """

    def __init__(self, channel_layer, settings):
        """Create the (not yet connected) MQTT client and proxy state.

        :param channel_layer: object providing an awaitable
            ``group_send(group, event)`` (presumably a Django Channels
            layer -- confirm against the caller).
        :param settings: object providing ``MQTT_USER``, ``MQTT_PASSWORD``
            and ``MQTT_HOST``; ``MQTT_VERSION`` and ``MQTT_CHANNEL_NAME``
            are optional.
        """
        self.channel_layer = channel_layer

        # MQTTClient takes an identifier which is seen at the broker.
        # Creating the client does not connect.
        self.mqtt = MQTTClient(
            f"ChannelsMQTTProxy@{socket.gethostname()}.{os.getpid()}")
        self.mqtt.set_auth_credentials(username=settings.MQTT_USER,
                                       password=settings.MQTT_PASSWORD)
        self.mqtt_host = settings.MQTT_HOST
        # Optional setting: 311 selects MQTT v3.1.1, anything else v5.0
        # (see connect()).
        self.mqtt_version = getattr(settings, "MQTT_VERSION", None)

        # Hook up the callbacks and some lifecycle management events
        self.mqtt.on_connect = self._on_connect
        self.mqtt.on_disconnect = self._on_disconnect
        self.mqtt.on_subscribe = self._on_subscribe
        self.mqtt.on_message = self._on_message
        self.stop_event = asyncio.Event()
        self.connected = asyncio.Event()

        # Optional setting: prefix of the event type names, default "mqtt".
        self.mqtt_channel_name = getattr(settings, "MQTT_CHANNEL_NAME", "mqtt")
        self.mqtt_channel_publish = f"{self.mqtt_channel_name}.publish"
        self.mqtt_channel_message = f"{self.mqtt_channel_name}.message"
        self.subscriptions = {}

    async def run(self):
        """Connect to the broker and serve until asked to exit.

        Installs SIGINT/SIGTERM handlers that call :func:`ask_exit`, then
        awaits :func:`connect` (which retries until connected) followed by
        :func:`finish` (which waits for the stop event and disconnects).

        Alternatively you can call :func:`connect` and then await
        :func:`finish` yourself.
        """
        loop = asyncio.get_event_loop()
        loop.add_signal_handler(signal.SIGINT, self.ask_exit)
        loop.add_signal_handler(signal.SIGTERM, self.ask_exit)
        await self.connect()
        await self.finish()

    async def connect(self):
        """Connect to the broker, retrying once a second until it works.

        Sets ``self.connected`` once the client reports it is connected.
        """
        # Only an explicit 311 selects MQTT v3.1.1; the default is v5.0.
        if self.mqtt_version == 311:
            version = MQTTv311
        else:
            version = MQTTv50

        while not self.mqtt.is_connected:
            try:
                await self.mqtt.connect(self.mqtt_host, version=version)
            except Exception as e:
                LOGGER.warning(f"Error trying to connect: {e}. Retrying.")
                await asyncio.sleep(1)
        self.connected.set()

    async def finish(self):
        """Wait for the stop event, then disconnect from the broker."""
        # This will wait until the client is signalled
        LOGGER.debug("Waiting for stop event")
        await self.stop_event.wait()
        await self.mqtt.disconnect()
        LOGGER.debug("MQTT client disconnected")

    def _on_connect(self, _client, _flags, _rc, _properties):
        """Connect callback: (re-)subscribe to every known topic."""
        for s in self.subscriptions:
            LOGGER.debug(f"Re-subscribing to {s}")
            self.mqtt.subscribe(s)
        LOGGER.debug('Connected and subscribed')

    def _on_disconnect(self, _client, _packet, _exc=None):
        """Disconnect callback: only logs."""
        LOGGER.debug('Disconnected')

    async def _on_message(self, _client, topic, payload, qos, properties):
        """Message callback: forward an MQTT message to matching groups.

        :param topic: topic the message was published on.
        :param payload: raw message bytes; decoded as UTF-8 (a payload
            that is not valid UTF-8 raises ``UnicodeDecodeError``) and then
            parsed as JSON when possible, otherwise passed on as text.
        :param qos: QoS of the message, copied into the event.
        :param properties: mapping of message properties; a ``retain``
            value of 1 causes the message to be dropped.
        """
        # NOTE(review): some gmqtt releases expect on_message to return a
        # PUBACK reason code (0 = success); this handler returns None --
        # confirm against the installed gmqtt version.
        LOGGER.debug(f"{topic} => '{payload}' props:{properties}")

        # Check properties for 'retain' (which in this context means
        # the message is being sent from retained backing store) and
        # drop those.
        # Eventually find a way to direct these to support code which can
        # build initial state and keep it up-to-date? This is linked to
        # connect-on-startup.
        if properties.get("retain") == 1:
            LOGGER.debug("Dropping replayed retained message")
            return

        # Compose a Channel message
        payload = payload.decode("utf-8")
        try:
            payload = json.loads(payload)
        except ValueError:
            # json.JSONDecodeError is a ValueError: keep the decoded text.
            LOGGER.debug("Payload is not JSON - sending it raw")
        msg = {
            "topic": topic,
            "payload": payload,
            "qos": qos,
        }
        event = {
            "type": self.mqtt_channel_message,  # default "mqtt.message"
            "message": msg,
        }

        tasks = []
        for grp in self.groups_matching_topic(topic):
            tasks.append(self.channel_layer.group_send(grp, event))
            LOGGER.debug(f"Calling {grp} handler for {topic}")
        try:
            await asyncio.gather(*tasks)
        except Exception as e:
            LOGGER.error(f"Cannot send event: {event}")
            LOGGER.exception(e)

    def subscribe(self, topic, group):
        """Subscribe a group to an MQTT topic (passed directly to MQTT).

        The MQTT subscribe packet is only sent the first time a topic is
        seen; further groups are just recorded against the topic.
        Subscribing the same group twice is a no-op.
        """
        if topic not in self.subscriptions:
            LOGGER.debug(f"New subscription for {topic}")
            self.subscriptions[topic] = []
            # We need to mqtt-subscribe now:

            # This actually just sends a subscribe packet. We should
            # store this and the details and handle the setup in
            # _on_subscribe callback when the message id is confirmed.
            self.mqtt.subscribe(topic)
        else:
            if group in self.subscriptions[topic]:
                LOGGER.debug(f"{group} already subscibed to {topic}")
                return

        LOGGER.debug(f"{group} subscribed to {topic}")
        self.subscriptions[topic].append(group)

    def _on_subscribe(self, _client, _mid, _qos, _properties):
        """Subscribe callback: only logs the message id."""
        LOGGER.debug(f'Subscribe callback {_mid}')

    async def unsubscribe(self, topic, group):
        """Unsubscribe a group from an MQTT topic.

        When the last group leaves a topic the topic is forgotten and an
        MQTT unsubscribe is sent, so that a later :func:`subscribe` sends a
        fresh MQTT subscribe and reconnects do not re-subscribe to it.
        Unknown topics and groups are only logged.
        """
        LOGGER.debug(f"unsubscribe {group} from {topic}")
        groups = self.subscriptions.get(topic)
        if groups is None:
            LOGGER.debug(f"{topic} not subscribed")
            return

        if group in groups:
            LOGGER.debug(f"{group} being unsubscribed from {topic}")
            groups.remove(group)
        else:
            LOGGER.debug(f"{group} not subscribed to {topic}")

        if not groups:
            LOGGER.debug(f"No more {group}s, unsubscribing to {topic}")
            del self.subscriptions[topic]
            self.mqtt.unsubscribe(topic)

    def publish(self, topic=None, payload=None, qos=2, retain=True):
        """Publish ``payload`` to ``topic`` through the MQTT client.

        :param qos: QoS passed to the client, 2 by default.
        :param retain: retain flag passed to the client, True by default.
        """
        LOGGER.debug(f"Publishing {topic} = {payload}")
        self.mqtt.publish(topic, payload, qos=qos, retain=retain)

    def ask_exit(self):
        """Signal :func:`finish` to disconnect by setting the stop event."""
        LOGGER.warning(f"{self} received signal asking to exit")
        self.stop_event.set()

    def groups_matching_topic(self, topic):
        """Return the set of groups whose subscription matches ``topic``.

        A subscription matches when it equals the topic or when
        :func:`topic_matches_sub` accepts it as a wildcard match.
        """
        groups = set()
        for sub, gs in self.subscriptions.items():
            if sub == topic:  # simple match
                groups.update(gs)
            elif self.topic_matches_sub(sub, topic):
                LOGGER.debug(f"Found matching groups {gs}")
                groups.update(gs)
        return groups

    # Taken from paho-mqtt - thanks :)
    @staticmethod
    def topic_matches_sub(sub, topic):
        """Check whether a topic matches a subscription.

        For example:

        foo/bar would match the subscription foo/# or +/bar
        non/matching would not match the subscription non/+/+

        :param sub: subscription filter, may contain ``+`` and ``#``.
        :param topic: concrete topic to test.
        :returns: True when the topic matches the filter.
        """
        result = True
        multilevel_wildcard = False

        slen = len(sub)
        tlen = len(topic)

        # A filter and a topic only match if both or neither start
        # with '$'.
        if slen > 0 and tlen > 0:
            if (sub[0] == '$' and topic[0] != '$') or (topic[0] == '$' and sub[0] != '$'):
                return False

        spos = 0
        tpos = 0

        # Walk both strings in lockstep, one character at a time.
        while spos < slen and tpos < tlen:
            if sub[spos] == topic[tpos]:
                if tpos == tlen-1:
                    # Check for e.g. foo matching foo/#
                    if spos == slen-3 and sub[spos+1] == '/' and sub[spos+2] == '#':
                        result = True
                        multilevel_wildcard = True
                        break

                spos += 1
                tpos += 1

                # Topic exhausted and only a trailing '+' is left.
                if tpos == tlen and spos == slen-1 and sub[spos] == '+':
                    spos += 1
                    result = True
                    break
            else:
                if sub[spos] == '+':
                    # '+' swallows the rest of the current topic level.
                    spos += 1
                    while tpos < tlen and topic[tpos] != '/':
                        tpos += 1
                    if tpos == tlen and spos == slen:
                        result = True
                        break

                elif sub[spos] == '#':
                    multilevel_wildcard = True
                    # '#' is only accepted as the last filter character.
                    if spos+1 != slen:
                        result = False
                        break
                    else:
                        result = True
                        break

                else:
                    result = False
                    break

        # Without '#', leftovers on either side mean no match.
        if not multilevel_wildcard and (tpos < tlen or spos < slen):
            result = False

        return result
|