Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions hbmqtt/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,62 @@ def close(self):
try: yield from self._writer.wait_closed() # py37+
except AttributeError: pass

class UnixStreamReaderAdapter(ReaderAdapter):
"""
Asyncio Streams API protocol adapter
This adapter relies on StreamReader to read from a TCP socket.
Because API is very close, this class is trivial
"""
def __init__(self, reader: StreamReader):
self._reader = reader

@asyncio.coroutine
def read(self, n=-1) -> bytes:
if n == -1:
data = yield from self._reader.read(n)
else:
data = yield from self._reader.readexactly(n)
return data

def feed_eof(self):
return self._reader.feed_eof()


class UnixStreamWriterAdapter(WriterAdapter):
"""
Asyncio Streams API protocol adapter
This adapter relies on StreamWriter to write to a TCP socket.
Because API is very close, this class is trivial
"""
def __init__(self, writer: StreamWriter):
self.logger = logging.getLogger(__name__)
self._writer = writer
self.is_closed = False # StreamWriter has no test for closed...we use our own

def write(self, data):
if not self.is_closed:
self._writer.write(data)

@asyncio.coroutine
def drain(self):
if not self.is_closed:
yield from self._writer.drain()

def get_peer_info(self):
extra_info = self._writer.get_extra_info('socket')
return extra_info.getsockname(), 0

@asyncio.coroutine
def close(self):
if not self.is_closed:
self.is_closed = True # we first mark this closed so yields below don't cause races with waiting writes
yield from self._writer.drain()
if self._writer.can_write_eof():
self._writer.write_eof()
self._writer.close()
try: yield from self._writer.wait_closed() # py37+
except AttributeError: pass


class BufferReader(ReaderAdapter):
"""
Expand Down
27 changes: 23 additions & 4 deletions hbmqtt/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
ReaderAdapter,
WriterAdapter,
WebSocketsReader,
WebSocketsWriter)
WebSocketsWriter,
UnixStreamReaderAdapter,
UnixStreamWriterAdapter)
from .plugins.manager import PluginManager, BaseContext


Expand Down Expand Up @@ -260,12 +262,18 @@ def start(self):
raise BrokerException("Can't read cert files '%s' or '%s' : %s" %
(listener['certfile'], listener['keyfile'], fnfe))

address, s_port = listener['bind'].split(':')
port = 0
path = ''
address = ''
s_port = ''
try:
port = int(s_port)
address, s_port = listener['bind'].split(':')
try:
port = int(s_port)
except ValueError as ve:
raise BrokerException("Invalid port value in bind value: %s" % listener['bind'])
except ValueError as ve:
raise BrokerException("Invalid port value in bind value: %s" % listener['bind'])
path = listener['bind']

if listener['type'] == 'tcp':
cb_partial = partial(self.stream_connected, listener_name=listener_name)
Expand All @@ -281,6 +289,13 @@ def start(self):
instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop,
subprotocols=['mqtt'])
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
elif listener['type'] == 'unix':
cb_partial = partial(self.unix_stream_connected, listener_name=listener_name)
instance = yield from asyncio.start_unix_server(cb_partial,
path,
ssl=sc,
loop=self._loop)
self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)

self.logger.info("Listener '%s' bind to %s (max_connections=%d)" %
(listener_name, listener['bind'], max_connections))
Expand Down Expand Up @@ -343,6 +358,10 @@ def ws_connected(self, websocket, uri, listener_name):
def stream_connected(self, reader, writer, listener_name):
yield from self.client_connected(listener_name, StreamReaderAdapter(reader), StreamWriterAdapter(writer))

@asyncio.coroutine
def unix_stream_connected(self, reader, writer, listener_name):
yield from self.client_connected(listener_name, UnixStreamReaderAdapter(reader), UnixStreamWriterAdapter(writer))

@asyncio.coroutine
def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
# Wait for connection available on listener
Expand Down
10 changes: 10 additions & 0 deletions hbmqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def _connect_coro(self):
uri = (scheme, self.session.remote_address + ":" + str(self.session.remote_port), uri_attributes[2],
uri_attributes[3], uri_attributes[4], uri_attributes[5])
self.session.broker_uri = urlunparse(uri)
if scheme in ('unix'):
self.session.broker_uri = '/' + self.session.remote_address + uri_attributes.path
# Init protocol handler
#if not self._handler:
self._handler = ClientProtocolHandler(self.plugins_manager, loop=self._loop)
Expand Down Expand Up @@ -419,6 +421,14 @@ def _connect_coro(self):
**kwargs)
reader = WebSocketsReader(websocket)
writer = WebSocketsWriter(websocket)
elif scheme in ('unix'):
conn_reader, conn_writer = \
yield from asyncio.open_unix_connection(
path=self.session.broker_uri,
loop=self._loop, **kwargs)
reader = StreamReaderAdapter(conn_reader)
writer = StreamWriterAdapter(conn_writer)

# Start MQTT protocol
self._handler.attach(self.session, reader, writer)
return_code = yield from self._handler.mqtt_connect()
Expand Down
2 changes: 1 addition & 1 deletion scripts/default_broker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ plugins:
- auth_file
- auth_anonymous
topic-check:
enabled: False
enabled: False
12 changes: 12 additions & 0 deletions scripts/default_unix_socket_broker.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
listeners:
default:
type: unix
bind: /tmp/test-mqtt
sys_interval: 20
auth:
allow-anonymous: true
plugins:
- auth_anonymous
topic-check:
enabled: true
plugins: topic_taboo