python-telegram-bot 的自定义持久化类

Mar*_*lli 4 python mongoengine telegram python-telegram-bot

我正在使用python-telegram-bot库开发一个简单的 Telegram 聊天机器人。我的机器人目前正在使用ConversationHandler来跟踪对话状态。

我想把对话状态存储在 MongoDB 数据库中,从而让对话持久化。

我使用 mongoengine 这个 Python 库与数据库通信。

阅读 BasePersistence(https://python-telegram-bot.readthedocs.io/en/stable/telegram.ext.basepersistence.html)的文档后,我了解到需要编写一个继承该类的自定义类(这里称之为 MongoPersistence),并重写以下方法:

  • get_conversations(name)
  • update_conversation(name, key, new_state)

文档没有说明 get_conversations(name) 返回的 dict 是什么结构,因此也很难理解该如何实现 update_conversation(name, key, new_state)。

假设我有如下的类(store_user_data、store_chat_data 和 store_bot_data 都设置为 False,因为我不想存储这些数据):

from telegram.ext import BasePersistence


class MongoPersistence(BasePersistence):
    """Skeleton persistence meant to keep conversation states in MongoDB.

    User, chat and bot data are explicitly not stored; only the two
    conversation hooks still need an implementation.
    """

    def __init__(self):
        super().__init__(
            store_user_data=False,
            store_chat_data=False,
            store_bot_data=False,
        )

    def get_conversations(self, name):
        """Not implemented yet: should return the stored conversations of handler *name*."""
        return None

    def update_conversation(self, name, key, new_state):
        """Not implemented yet: should store *new_state* under *key* for handler *name*."""
        return None
Run Code Online (Sandbox Code Playgroud)

我怎样才能实现这个类,以便从数据库中获取和保存我的对话状态?

Ser*_*ioR 7

对话持久性

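我认为最简单的实现方式是参考 PicklePersistence()。我见过的唯一一个字典示例是 conversations = { name : { (user_id, user_id): state } }。其中 name 是传给 ConversationHandler() 的名称,作为键的元组 (user_id, user_id) 中的 user_id 是与你的机器人对话的用户,state 是对话的状态。当然,这里也许不是 user_id 而是 chat_id,我不能肯定,还需要更多测试。(补充:在 ConversationHandler 的默认设置 per_chat=True、per_user=True 下,键是 (chat_id, user_id);在私聊中 chat_id 与 user_id 相同,因此看起来像 (user_id, user_id)。)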

为了处理以元组作为键的情况,python-telegram-bot 提供了两个辅助函数:encode_conversations_to_json 和 decode_conversations_from_json。

这里的 on_flush 是一个标志:设置为 False 时,每次调用 update_conversation() 都会保存到数据库;设置为 True 时,只在程序退出、调用 flush() 时才保存。

最后一个细节:目前下面的代码只负责向数据库保存数据和从数据库读取数据,不会替换或删除已有的文档。

from telegram.ext import BasePersistence
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    """MongoDB document whose ``obj`` dict holds the persisted conversation states."""

    obj = mongoengine.DictField()
    # Newest document first, so ``objects().first()`` yields the latest save.
    meta = dict(collection='Conversations', ordering=['-id'])

class MongoPersistence(BasePersistence):
    """Persist only ConversationHandler states in MongoDB (via mongoengine).

    Conversations are cached in ``self.conversations`` after the first load and
    stored as the ``obj`` field of a single ``Conversations`` document, which is
    updated in place instead of a new document being inserted on every save.
    With ``on_flush = False`` every state change is written immediately; with
    ``on_flush = True`` only ``flush()`` writes.
    """

    def __init__(self):
        super(MongoPersistence, self).__init__(store_user_data=False,
                                               store_chat_data=False,
                                               store_bot_data=False)
        dbname = "persistencedb"
        mongoengine.connect(host=mongo_URI, db=dbname)
        self.conversation_collection = "Conversations"
        self.conversations = None
        self.on_flush = False

    def get_conversations(self, name):
        """Return a copy of the conversations stored for handler *name*.

        Keys are the tuples restored by ``decode_conversations_from_json``.
        The database is read only once; later calls use the in-memory cache,
        even when that cache is empty.
        """
        if self.conversations is None:
            latest = Conversations.objects().first()
            stored = {} if latest is None else latest['obj']
            self.conversations = decode_conversations_from_json(json_util.dumps(stored))
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        """Record *new_state* for *key* in handler *name*; save unless ``on_flush``."""
        if self.conversations is None:
            # Load what is already stored so the first write does not discard it.
            self.get_conversations(name)
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            self._save_conversations()

    def _save_conversations(self):
        """Write the cached conversations into the latest stored document (create it if missing)."""
        # Tuple keys are not valid document keys, so encode them as JSON strings first.
        data = json_util.loads(encode_conversations_to_json(self.conversations))
        document = Conversations.objects().first()
        if document is None:
            document = Conversations(obj=data)
        else:
            document.obj = data
        document.save()

    def flush(self):
        """Save the cached conversations (if any were loaded) and disconnect."""
        try:
            if self.conversations is not None:
                self._save_conversations()
        finally:
            mongoengine.disconnect()
Run Code Online (Sandbox Code Playgroud)

注意!有时对话需要用户事先设置好 user_data,而这段代码并没有持久化 user_data。

完整持久化(所有数据)

下面是更完整的代码(仍然缺少替换数据库中已有文档的功能)。

from telegram.ext import BasePersistence
from collections import defaultdict
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_user_chat_data_from_json, decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    """MongoDB document whose ``obj`` dict holds the persisted conversation states."""

    obj = mongoengine.DictField()
    # Newest document first, so ``objects().first()`` yields the latest save.
    meta = dict(collection='Conversations', ordering=['-id'])

class UserData(mongoengine.Document):
    """MongoDB document whose ``obj`` dict holds the persisted user_data."""

    obj = mongoengine.DictField()
    # Newest document first, so ``objects().first()`` yields the latest save.
    meta = dict(collection='UserData', ordering=['-id'])

class ChatData(mongoengine.Document):
    """MongoDB document whose ``obj`` dict holds the persisted chat_data."""

    obj = mongoengine.DictField()
    # Newest document first, so ``objects().first()`` yields the latest save.
    meta = dict(collection='ChatData', ordering=['-id'])

class BotData(mongoengine.Document):
    """MongoDB document whose ``obj`` dict holds the persisted bot_data."""

    obj = mongoengine.DictField()
    # Newest document first, so ``objects().first()`` yields the latest save.
    meta = dict(collection='BotData', ordering=['-id'])

class DBHelper:
    """Store and load one dict per collection in MongoDB using mongoengine.

    Each collection keeps its data in the ``obj`` field of a document. Saving
    updates the latest stored document in place, or creates it if the
    collection is empty, instead of inserting a new document on every call.
    """

    def __init__(self, dbname="persistencedb"):
        mongoengine.connect(host=mongo_URI, db=dbname)

    @staticmethod
    def _document_class(collection):
        """Map a collection name to its Document class; unknown names fall back to BotData."""
        return {
            "Conversations": Conversations,
            "UserData": UserData,
            # Must match the value DBPersistence uses for chat_data_collection;
            # comparing against "chat_data_collection" sent chat data to BotData.
            "ChatData": ChatData,
        }.get(collection, BotData)

    def add_item(self, data, collection):
        """Save *data* as the content of *collection*, replacing the latest stored document."""
        document_class = self._document_class(collection)
        document = document_class.objects().first()
        if document is None:
            document = document_class(obj=data)
        else:
            document.obj = data
        document.save()

    def get_item(self, collection):
        """Return the stored dict for *collection*, or ``{}`` if nothing was saved yet."""
        document = self._document_class(collection).objects().first()
        return {} if document is None else document['obj']

    def close(self):
        """Close the mongoengine connection opened in ``__init__``."""
        mongoengine.disconnect()

class DBPersistence(BasePersistence):
    """Uses DBHelper to make the bot persistent in a MongoDB database.

    Heavily inspired by PicklePersistence from python-telegram-bot.
    Conversations, user_data, chat_data and bot_data are cached in memory
    after the first load. While ``on_flush`` is False every change is written
    to the database immediately; when it is True, only ``flush()`` writes.
    """

    def __init__(self):
        super(DBPersistence, self).__init__(store_user_data=True,
                                            store_chat_data=True,
                                            store_bot_data=True)
        # Same name as DBHelper's default database; passed explicitly so the
        # attribute and the database actually used cannot drift apart.
        self.persistdb = "persistencedb"
        self.conversation_collection = "Conversations"
        self.user_data_collection = "UserData"
        self.chat_data_collection = "ChatData"
        self.bot_data_collection = "BotData"
        self.db = DBHelper(self.persistdb)
        self.user_data = None
        self.chat_data = None
        self.bot_data = None
        self.conversations = None
        self.on_flush = False

    def _load_json(self, collection):
        """Return the dict stored in *collection* serialised as a JSON string."""
        return json_util.dumps(self.db.get_item(collection))

    def _save_json(self, data, collection):
        """Write *data* to *collection* after a JSON round trip (json.dumps turns int keys into str)."""
        self.db.add_item(json_util.loads(json.dumps(data)), collection)

    def _save_conversations(self):
        """Write the cached conversations (tuple keys encoded as strings) to the database."""
        conversations_json = json_util.loads(encode_conversations_to_json(self.conversations))
        self.db.add_item(conversations_json, self.conversation_collection)

    def _load_user_chat_data(self, collection):
        """Load user_data/chat_data stored in *collection* as a defaultdict(dict)."""
        data_json = self._load_json(collection)
        if data_json != '{}':
            return decode_user_chat_data_from_json(data_json)
        return defaultdict(dict)

    def get_conversations(self, name):
        """Return a copy of the conversations of handler *name*, loading them once."""
        if self.conversations is None:
            self.conversations = decode_conversations_from_json(
                self._load_json(self.conversation_collection))
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        """Record *new_state* for *key* in handler *name*; save unless ``on_flush``."""
        if self.conversations is None:
            # Load what is already stored so the first write does not discard it.
            self.get_conversations(name)
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            self._save_conversations()

    def get_user_data(self):
        """Return a deep copy of all user_data, loading it from the database once."""
        if self.user_data is None:
            self.user_data = self._load_user_chat_data(self.user_data_collection)
        return deepcopy(self.user_data)

    def update_user_data(self, user_id, data):
        """Record *data* for *user_id*; save unless ``on_flush``."""
        if self.user_data is None:
            # Load what is already stored so the first write does not discard it.
            self.get_user_data()
        # Skip the write when nothing changed; comment next line if you want to save to db every time this function is called
        if self.user_data.get(user_id) == data:
            return
        # Store a snapshot: keeping a reference to the caller's dict would make
        # the check above succeed whenever that dict is later mutated in place.
        self.user_data[user_id] = deepcopy(data)
        if not self.on_flush:
            self._save_json(self.user_data, self.user_data_collection)

    def get_chat_data(self):
        """Return a deep copy of all chat_data, loading it from the database once."""
        if self.chat_data is None:
            self.chat_data = self._load_user_chat_data(self.chat_data_collection)
        return deepcopy(self.chat_data)

    def update_chat_data(self, chat_id, data):
        """Record *data* for *chat_id*; save unless ``on_flush``."""
        if self.chat_data is None:
            # Load what is already stored so the first write does not discard it.
            self.get_chat_data()
        # Skip the write when nothing changed; comment next line if you want to save to db every time this function is called
        if self.chat_data.get(chat_id) == data:
            return
        # Snapshot for the same reason as in update_user_data.
        self.chat_data[chat_id] = deepcopy(data)
        if not self.on_flush:
            self._save_json(self.chat_data, self.chat_data_collection)

    def get_bot_data(self):
        """Return a deep copy of bot_data, loading it from the database once."""
        if self.bot_data is None:
            self.bot_data = json.loads(self._load_json(self.bot_data_collection))
        return deepcopy(self.bot_data)

    def update_bot_data(self, data):
        """Record *data* as bot_data; save unless ``on_flush``."""
        if self.bot_data == data:
            return
        # Snapshot so later in-place changes to *data* are still detected.
        self.bot_data = deepcopy(data)
        if not self.on_flush:
            self._save_json(self.bot_data, self.bot_data_collection)

    def flush(self):
        """Save every non-empty cache to the database, then close the connection."""
        try:
            if self.conversations:
                self._save_conversations()
            if self.user_data:
                self._save_json(self.user_data, self.user_data_collection)
            if self.chat_data:
                self._save_json(self.chat_data, self.chat_data_collection)
            if self.bot_data:
                self._save_json(self.bot_data, self.bot_data_collection)
        finally:
            self.db.close()
Run Code Online (Sandbox Code Playgroud)

两个细节:

  1. 目前 chat_data 的持久化数据并没有保存到数据库中,还需要更多测试,也许那部分代码有 bug。
  2. 目前,on_flush = False 只在对话(conversations)部分真正起作用。对于其他所有更新,调用似乎发生在赋值之后,因此在保存到数据库之前,if variable[key] == data 总是为 True,代码直接返回。这就是代码中有一条注释 # comment next line if you want to save to db every time this function is called 的原因;不过保留这一行可以省下大量的数据库写入。如果设置 on_flush = True 而程序提前停止(例如进程被杀死),数据库中将不会保存任何内容。