mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
243 lines
8.5 KiB
Python
243 lines
8.5 KiB
Python
# Pyrogram - Telegram MTProto API Client Library for Python
|
|
# Copyright (C) 2017-present Dan <https://github.com/delivrance>
|
|
#
|
|
# This file is part of Pyrogram.
|
|
#
|
|
# Pyrogram is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Lesser General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# Pyrogram is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public License
|
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import html
|
|
import logging
|
|
import re
|
|
from html.parser import HTMLParser
|
|
from typing import Optional
|
|
|
|
import pyrogram
|
|
from pyrogram import raw
|
|
from pyrogram.enums import MessageEntityType
|
|
from pyrogram.errors import PeerIdInvalid
|
|
from . import utils
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class Parser(HTMLParser):
    """Streaming HTML parser that collects plain text and raw message entities.

    While markup is fed in, supported tags are turned into ``raw.types``
    entity objects and character data is accumulated in :attr:`text`.

    Attributes:
        client: The client passed to the constructor; stored as-is.
        text: Plain text collected so far (character references decoded).
        entities: Entities whose closing tag has been seen, in closing order.
        tag_entities: Maps a tag name to the stack of entities opened by that
            tag that are still waiting for their closing tag.
    """

    # Matches the inline-mention URL scheme and captures the numeric user id.
    MENTION_RE = re.compile(r"tg://user\?id=(\d+)")

    def __init__(self, client: "pyrogram.Client"):
        super().__init__()

        self.client = client

        self.text = ""
        self.entities = []
        self.tag_entities = {}

    def handle_starttag(self, tag, attrs):
        """Open a new entity for a supported start tag; ignore any other tag."""
        attrs = dict(attrs)
        extra = {}

        if tag in ("b", "strong"):
            entity = raw.types.MessageEntityBold
        elif tag in ("i", "em"):
            entity = raw.types.MessageEntityItalic
        elif tag == "u":
            entity = raw.types.MessageEntityUnderline
        elif tag in ("s", "del", "strike"):
            entity = raw.types.MessageEntityStrike
        elif tag == "blockquote":
            entity = raw.types.MessageEntityBlockquote
        elif tag == "code":
            entity = raw.types.MessageEntityCode
        elif tag == "pre":
            entity = raw.types.MessageEntityPre
            # A valueless attribute (<pre language>) is reported as None
            extra["language"] = attrs.get("language") or ""
        elif tag == "spoiler":
            entity = raw.types.MessageEntitySpoiler
        elif tag == "a":
            # A valueless attribute (<a href>) is reported as None
            url = attrs.get("href") or ""

            mention = Parser.MENTION_RE.match(url)

            if mention:
                entity = raw.types.InputMessageEntityMentionName
                extra["user_id"] = int(mention.group(1))
            else:
                entity = raw.types.MessageEntityTextUrl
                extra["url"] = url
        elif tag == "emoji":
            try:
                custom_emoji_id = int(attrs.get("id"))
            except (TypeError, ValueError):
                # Missing or non-numeric id: skip the tag instead of aborting the whole parse
                line, offset = self.getpos()
                offset += 1

                log.debug("Ignoring <emoji> with missing or invalid id at line %s:%s", line, offset)
                return

            entity = raw.types.MessageEntityCustomEmoji
            extra["document_id"] = custom_emoji_id
        else:
            return

        # The entity starts at the current end of the text and grows in handle_data
        self.tag_entities.setdefault(tag, []).append(
            entity(offset=len(self.text), length=0, **extra)
        )

    def handle_data(self, data):
        """Append character data to the text and extend every open entity."""
        # HTMLParser runs with convert_charrefs=True (its default), so character
        # references in ``data`` are already decoded here. Decoding a second time
        # would turn an escaped "&amp;lt;" into "<" instead of the literal "&lt;".
        for entities in self.tag_entities.values():
            for entity in entities:
                entity.length += len(data)

        self.text += data

    def handle_endtag(self, tag):
        """Close the most recently opened entity for ``tag``, if there is one."""
        try:
            self.entities.append(self.tag_entities[tag].pop())
        except (KeyError, IndexError):
            line, offset = self.getpos()
            offset += 1

            log.debug("Unmatched closing tag </%s> at line %s:%s", tag, line, offset)
        else:
            # Drop the empty stack so only really unclosed tags remain in tag_entities
            if not self.tag_entities[tag]:
                self.tag_entities.pop(tag)

    def error(self, message):
        """Ignore errors reported through this hook."""
        pass
|
|
|
|
|
|
class HTML:
    """Converts between HTML-style markup and ``(text, entities)`` pairs."""

    def __init__(self, client: Optional["pyrogram.Client"]):
        # May be None; in that case mention user ids are left unresolved by parse()
        self.client = client

    async def parse(self, text: str):
        """Parse HTML-style markup into a message text and its entities.

        Parameters:
            text: The markup to parse.

        Returns:
            A dict with ``"message"`` (the plain text) and ``"entities"``
            (the non-empty entities sorted by offset, or ``None`` when there
            are none).
        """
        # Strip whitespaces from the beginning and the end, but preserve closing tags
        text = re.sub(r"^\s*(<[\w<>=\s\"]*>)\s*", r"\1", text)
        text = re.sub(r"\s*(</[\w</>]*>)\s*$", r"\1", text)

        parser = Parser(self.client)
        # NOTE(review): add_surrogates/remove_surrogates presumably make offsets
        # count UTF-16 code units - confirm against utils.
        parser.feed(utils.add_surrogates(text))
        parser.close()

        if parser.tag_entities:
            unclosed_tags = []

            for tag, entities in parser.tag_entities.items():
                unclosed_tags.append(f"<{tag}> (x{len(entities)})")

            log.info("Unclosed tags: %s", ", ".join(unclosed_tags))

        entities = []

        for entity in parser.entities:
            if isinstance(entity, raw.types.InputMessageEntityMentionName):
                try:
                    if self.client is not None:
                        entity.user_id = await self.client.resolve_peer(entity.user_id)
                except PeerIdInvalid:
                    # Mentions whose peer cannot be resolved are dropped
                    continue

            entities.append(entity)

        # Remove zero-length entities
        entities = [entity for entity in entities if entity.length > 0]

        return {
            "message": utils.remove_surrogates(parser.text),
            "entities": sorted(entities, key=lambda e: e.offset) or None
        }

    @staticmethod
    def unparse(text: str, entities: list):
        """Render a text and its entities back into HTML-style markup.

        All text is HTML-escaped, including the parts that are outside of
        every entity, and attribute values are escaped as well, so literal
        ``<``, ``>``, ``&`` and quotes never end up as markup.

        Parameters:
            text: The plain message text.
            entities: The entities describing the text. Each one is expected
                to expose ``type``, ``offset`` and ``length``, plus the
                type-specific attribute read below (``url``, ``user``,
                ``language`` or ``custom_emoji_id``). The list is not modified.

        Returns:
            The markup string.
        """

        def attr(value) -> str:
            """Escapes a value for use inside a double-quoted attribute."""
            return html.escape(str(value), quote=True)

        def parse_one(entity):
            """
            Parses a single entity and returns (start_tag, start), (end_tag, end)
            """
            entity_type = entity.type
            start = entity.offset
            end = start + entity.length

            if entity_type in (
                MessageEntityType.BOLD,
                MessageEntityType.ITALIC,
                MessageEntityType.UNDERLINE,
                MessageEntityType.STRIKETHROUGH,
            ):
                # The tag is the first letter of the type name: b, i, u, s
                name = entity_type.name[0].lower()
                start_tag = f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type == MessageEntityType.PRE:
                name = entity_type.name.lower()
                language = getattr(entity, "language", "") or ""
                start_tag = f'<{name} language="{attr(language)}">' if language else f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type in (
                MessageEntityType.CODE,
                MessageEntityType.BLOCKQUOTE,
                MessageEntityType.SPOILER,
            ):
                name = entity_type.name.lower()
                start_tag = f"<{name}>"
                end_tag = f"</{name}>"
            elif entity_type == MessageEntityType.TEXT_LINK:
                url = entity.url
                start_tag = f'<a href="{attr(url)}">'
                end_tag = "</a>"
            elif entity_type == MessageEntityType.TEXT_MENTION:
                user = entity.user
                start_tag = f'<a href="tg://user?id={user.id}">'
                end_tag = "</a>"
            elif entity_type == MessageEntityType.CUSTOM_EMOJI:
                custom_emoji_id = entity.custom_emoji_id
                start_tag = f'<emoji id="{custom_emoji_id}">'
                end_tag = "</emoji>"
            else:
                # Entity types without a tag representation are skipped
                return

            return (start_tag, start), (end_tag, end)

        def recursive(entity_i: int) -> int:
            """
            Takes the index of the entity to start parsing from, returns the number of parsed entities inside it.
            Uses entities_offsets as a stack, pushing (start_tag, start) first, then parsing nested entities,
            and finally pushing (end_tag, end) to the stack.
            No need to sort at the end.
            """
            this = parse_one(entities[entity_i])
            if this is None:
                return 1
            (start_tag, start), (end_tag, end) = this
            entities_offsets.append((start_tag, start))
            internal_i = entity_i + 1
            # while the next entity is inside the current one, keep parsing
            while internal_i < len(entities) and entities[internal_i].offset < end:
                internal_i += recursive(internal_i)
            entities_offsets.append((end_tag, end))
            return internal_i - entity_i

        text = utils.add_surrogates(text)

        entities_offsets = []

        # Sort a copy (outer entities first) so the caller's list is left untouched
        entities = sorted(entities or (), key=lambda e: (e.offset, -e.length))

        # main loop for first-level entities
        i = 0
        while i < len(entities):
            i += recursive(i)

        # Single forward pass: escaped text up to each tag, then the tag itself
        pieces = []
        cursor = 0

        for tag, offset in entities_offsets:
            pieces.append(html.escape(text[cursor:offset]))
            pieces.append(tag)
            # Offsets only go backwards for partially overlapping entities;
            # clamping keeps the text from being emitted twice in that case
            cursor = max(cursor, offset)

        # Text after the last tag (or the whole text when there are no tags)
        pieces.append(html.escape(text[cursor:]))

        return utils.remove_surrogates("".join(pieces))
|