87 lines
3.2 KiB
Python
87 lines
3.2 KiB
Python
import asyncio
|
|
import collections
|
|
import enum
|
|
import logging
|
|
|
|
import slixmpp
|
|
from slixmpp.exceptions import IqError, IqTimeout
|
|
|
|
# Module-level logger, named explicitly rather than derived from __name__.
logger = logging.getLogger(name='forwardxmpp.bot')
|
|
|
|
class XEP(enum.Enum):
    """XMPP extension plugins used by the bot.

    Each member's value is the slixmpp plugin name passed to
    ``register_plugin``.  Members are iterated in declaration order, so
    keep new entries appended rather than inserted.
    """

    # XEP-0030: Service Discovery (used to probe hosts for MUC support).
    DISCO = 'xep_0030'

    # XEP-0045: Multi-User Chat.
    MUC = 'xep_0045'

    # XEP-0199: XMPP Ping.
    PING = 'xep_0199'
|
|
|
|
|
|
class ForwardBot(slixmpp.ClientXMPP):
    """XMPP client that forwards matching incoming messages to other JIDs.

    Everything is read once from *config* at construction time:
    - the bot's JID, password and nickname;
    - the maximum retry back-off;
    - the named forwards.

    Each forward must provide three attributes:
    - ``from_re``: searched against the sender's full JID (presumably a
      compiled regex; confirm against the config loader);
    - ``to_jid``: the destination JID;
    - ``body_fmt``: a ``str.format_map`` template filled from the incoming
      message stanza.
    """

    MUC_FEATURE = 'http://jabber.org/protocol/muc'

    # Delay (seconds) before the first disco#info retry after a timeout.
    # Each later retry doubles it.  Every delay is capped at max_retry_wait.
    INITIAL_RETRY_WAIT = 10

    def __init__(self, config):
        super().__init__(config.bot_jid().full, config.bot_password())

        self.nick = config.bot_nick()
        # Upper bound (seconds) for the retry back-off.  300 is passed as a
        # fallback; presumably it is used when the config leaves this unset.
        # TODO confirm.
        self.max_retry_wait = config.get_max_retry_wait(300)
        self.forwards = dict(config.forwards())
        # Destination JID -> message type ('groupchat' or 'chat').
        # Filled in by handle_queryhost once the destination's host is probed.
        self.jid_mtype_map = {}

        for xep in XEP:
            self.register_plugin(xep.value)
            # Expose each plugin as self.disco / self.muc / self.ping.
            setattr(self, xep.name.lower(), self.plugin[xep.value])

        self.add_event_handler('session_start', self.handle_start)
        self.add_event_handler('forward::queryhost', self.handle_queryhost)
        self.add_event_handler('message', self.handle_message)

    async def handle_start(self, event):
        """Announce presence, fetch the roster, then probe destination hosts.

        Destination JIDs are grouped by host so that each host receives
        exactly one ``forward::queryhost`` event, and therefore one
        disco#info query.
        """
        self.send_presence()
        # get_roster() returns an awaitable.  Awaiting it means errors surface
        # here rather than on a dangling future.  A roster failure must not
        # stop the host probes below, so it is only logged.
        try:
            await self.get_roster()
        except (IqError, IqTimeout) as error:
            logger.warning("start: could not fetch roster: %r", error)

        jids_by_host = collections.defaultdict(set)
        for forward in self.forwards.values():
            # Key by to_jid: it is the same value handle_message uses to
            # look up jid_mtype_map.
            jids_by_host[forward.to_jid.host].add(forward.to_jid)

        for host, host_jids in jids_by_host.items():
            self.event('forward::queryhost', {'host': host, 'jids': host_jids})

    async def handle_queryhost(self, query):
        """Probe ``query['host']`` and record how to message its JIDs.

        *query* is a dict with these keys:
        - ``'host'``: the host to probe;
        - ``'jids'``: the destination JIDs on that host;
        - ``'wait_secs'`` (retries only): the delay used before this attempt.

        Outcomes:
        - The host advertises the MUC feature: join every JID as a room and
          record ``'groupchat'`` for it.
        - The host does not advertise it: record ``'chat'``.
        - IQ timeout: re-dispatch the query after a doubling delay, capped
          at ``max_retry_wait``.
        - IQ error: log it and give up.  Those JIDs get no entry in
          ``jid_mtype_map``.
        """
        host = query['host']
        try:
            result = await self.disco.get_info(jid=host)
        except IqError as error:
            logger.error("queryhost: error querying %r: %s",
                         host, error.iq['error']['condition'])
        except IqTimeout:
            previous_wait = query.get('wait_secs')
            if previous_wait is None:
                wait_secs = self.INITIAL_RETRY_WAIT
            else:
                wait_secs = previous_wait * 2
            # Clamp every delay, including the first one, to the configured max.
            wait_secs = min(wait_secs, self.max_retry_wait)
            logger.warning("queryhost: timeout querying %r: will retry in %s seconds",
                           host, wait_secs)
            new_query = query.copy()
            new_query['wait_secs'] = wait_secs
            await asyncio.sleep(wait_secs)
            self.event('forward::queryhost', new_query)
        else:
            if self.MUC_FEATURE in result['disco_info']['features']:
                logmsg = "is MUC host, joining rooms"
                msgtype = 'groupchat'
                for jid in query['jids']:
                    self.muc.join_muc(jid, self.nick)
            else:
                logmsg = "is not MUC host"
                msgtype = 'chat'
            logger.info("queryhost: %r: %s", host, logmsg)
            for jid in query['jids']:
                self.jid_mtype_map[jid] = msgtype

    async def handle_message(self, msg):
        """Forward *msg* to every forward whose ``from_re`` matches its sender.

        Some messages are skipped because forwarding them would send empty
        messages or could create a forwarding loop:
        - error bounces;
        - stanzas without a body (chat states, receipts, MUC subject
          messages);
        - the MUC echo of the bot's own groupchat messages.
        """
        if msg['type'] == 'error' or not msg['body']:
            return
        if msg['type'] == 'groupchat' and msg['mucnick'] == self.nick:
            return

        msgfrom = msg['from'].full
        for fname, forward in self.forwards.items():
            if forward.from_re.search(msgfrom) is None:
                continue
            logger.info("message: %s from %r", fname, msgfrom)
            self.send_message(
                mto=forward.to_jid,
                mfrom=self.boundjid,
                mnick=self.nick,
                # None if the destination host has not been probed
                # successfully yet.  slixmpp then uses its default type.
                mtype=self.jid_mtype_map.get(forward.to_jid),
                mbody=forward.body_fmt.format_map(msg),
            )
|