import dataclasses
import itertools
import logging
import re
from functools import cached_property
from typing import TYPE_CHECKING, Any, Optional
try:
from typing import override # type: ignore[attr-defined]
except ImportError:
from typing_extensions import override
from rhoknp.cohesion.argument import Argument, EndophoraArgument, ExophoraArgument
from rhoknp.cohesion.coreference import Entity, EntityManager
from rhoknp.cohesion.exophora import ExophoraReferent
from rhoknp.cohesion.pas import CaseInfoFormat, Pas, normalize_case
from rhoknp.cohesion.predicate import Predicate
from rhoknp.cohesion.rel import RelMode, RelTag, RelTagList
from rhoknp.props.dependency import DepType
from rhoknp.props.feature import FeatureDict
from rhoknp.props.memo import MemoTag
from rhoknp.units.morpheme import Morpheme
from rhoknp.units.unit import Unit
if TYPE_CHECKING:
from rhoknp.units.clause import Clause
from rhoknp.units.document import Document
from rhoknp.units.phrase import Phrase
from rhoknp.units.sentence import Sentence
# Module-level logger, named after this module so that handlers and levels
# can be configured per package by the application.
logger = logging.getLogger(__name__)
class BasePhrase(Unit):
    """Base phrase unit.

    A base phrase has a phrase as its parent unit and a list of morphemes as its
    child units. It also carries dependency information, features, rel tags,
    a memo tag, a predicate-argument structure and the entities it refers to.
    """

    # Matches a KNP base phrase header line: a leading "+", optionally followed by
    # " <parent index><dependency type>", optionally followed by " <features>".
    PAT = re.compile(
        rf"^\+( (?P<pid>-1|\d+)(?P<dtype>[{''.join(e.value for e in DepType)}]))?( {FeatureDict.PAT.pattern})?$"
    )
    # Running counter used to assign `index` to each newly created instance.
    # NOTE(review): presumably reset by the code that parses a sentence - confirm against the caller.
    count = 0

    def __init__(
        self,
        parent_index: Optional[int],
        dep_type: Optional[DepType],
        features: Optional[FeatureDict] = None,
        rel_tags: Optional[RelTagList] = None,
        memo_tag: Optional[MemoTag] = None,
    ) -> None:
        """Initialize a base phrase.

        Args:
            parent_index: Index (within the sentence) of the base phrase this one depends on.
            dep_type: Dependency type.
            features: Features. A fresh ``FeatureDict`` is used when omitted or empty.
            rel_tags: Relations between base phrases. A fresh ``RelTagList`` is used when omitted or empty.
            memo_tag: Annotation memo. A fresh ``MemoTag`` is used when omitted or empty.
        """
        super().__init__()

        # parent unit
        self._phrase: Optional["Phrase"] = None

        # child units
        self._morphemes: Optional[list[Morpheme]] = None

        self.parent_index: Optional[int] = parent_index  #: Index of the dependency parent within the sentence.
        self.dep_type: Optional[DepType] = dep_type  #: Dependency type.
        self.features: FeatureDict = features or FeatureDict()  #: Features.
        self.rel_tags: RelTagList = rel_tags or RelTagList()  #: Relations between base phrases.
        self.memo_tag: MemoTag = memo_tag or MemoTag()  #: Annotation memo.
        self.pas: Pas = Pas(Predicate(self))  #: Predicate-argument structure.
        self.entities: set[Entity] = set()  #: Entities this base phrase refers to.
        self.entities_nonidentical: set[Entity] = set()  #: Entities referred to via a "≒" relation.

        self.index = self.count  #: Index within the sentence.
        BasePhrase.count += 1

    def __getstate__(self) -> dict[str, Any]:
        """Return the picklable state, storing entity sets as tuples plus their eids."""
        state = self.__dict__.copy()
        # Dump a tuple instead of a set so that the __hash__ function won't be called.
        # `eids` is used to hash uninitialized Entity objects.
        state["entities"] = tuple(self.entities)
        state["eids"] = tuple(entity.eid for entity in state["entities"])
        state["entities_nonidentical"] = tuple(self.entities_nonidentical)
        state["eids_nonidentical"] = tuple(entity.eid for entity in state["entities_nonidentical"])
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the state produced by ``__getstate__``."""
        # Restore eids to Entity objects for hashing.
        for entity, eid in zip(state["entities"], state.pop("eids")):
            entity.eid = eid
        for entity, eid in zip(state["entities_nonidentical"], state.pop("eids_nonidentical")):
            entity.eid = eid
        self.__dict__.update(state)  # Entity objects are hashed by eid.

    @override
    def __post_init__(self) -> None:
        """Parse the PAS feature and the rel tags once the unit tree has been built."""
        super().__post_init__()

        # Parse the PAS tag. "述語項構造" takes precedence over "格解析結果".
        if "述語項構造" in self.features:
            pas_string = self.features["述語項構造"]
            assert isinstance(pas_string, str)
            self.pas.parse_pas_string(self, pas_string, format_=CaseInfoFormat.PAS)
        elif "格解析結果" in self.features:
            pas_string = self.features["格解析結果"]
            assert isinstance(pas_string, str)
            self.pas.parse_pas_string(self, pas_string, format_=CaseInfoFormat.CASE)

        # Parse the rel tags.
        for rel_tag_orig in self.rel_tags:
            rel_tag = rel_tag_orig
            if rel_tag.sid == "":
                # The target is considered to be in the same sentence.
                rel_tag = dataclasses.replace(rel_tag, sid=self.sentence.sid)
            if rel_tag.is_coreference():
                # Coreference tags in OR / AMBIGUOUS mode are not registered.
                if rel_tag.mode not in (RelMode.OR, RelMode.AMBIGUOUS):
                    self._add_coreference(rel_tag)
            else:
                self._add_argument(rel_tag)

    @override
    def __eq__(self, other: object) -> bool:
        """Return True if `other` has the same type, an equal parent unit and the same index."""
        if not isinstance(other, type(self)):
            return False
        if self.parent_unit != other.parent_unit:
            return False
        return self.index == other.index

    @cached_property
    def global_index(self) -> int:
        """Index of this base phrase within the whole document.

        Falls back to the sentence-internal index when the sentence does not belong to a document.
        """
        sentence = self.sentence
        if not sentence.has_document() or sentence.index == 0:
            return self.index
        if self.index > 0:
            # Reuse the (cached) offset of the first base phrase in this sentence.
            return sentence.base_phrases[0].global_index + self.index
        # First base phrase of a non-initial sentence: its global index equals the number of
        # base phrases in all preceding sentences. Summing iteratively avoids recursing once per
        # sentence (which could hit the recursion limit on long documents) and copes with
        # preceding sentences that contain no base phrases.
        preceding_sentences = itertools.islice(self.document.sentences, sentence.index)
        return sum(len(prev_sentence.base_phrases) for prev_sentence in preceding_sentences)

    @property
    def parent_unit(self) -> Optional["Phrase"]:
        """Parent unit (phrase). None if not registered."""
        return self._phrase

    @property
    def child_units(self) -> Optional[list[Morpheme]]:
        """Child units (morphemes). None if the analysis result is not available."""
        return self._morphemes

    @property
    def document(self) -> "Document":
        """Document.

        Raises:
            AttributeError: If the analysis result is not available.
        """
        return self.phrase.document

    @property
    def sentence(self) -> "Sentence":
        """Sentence."""
        return self.phrase.sentence

    @property
    def clause(self) -> "Clause":
        """Clause.

        Raises:
            AttributeError: If the analysis result is not available.
        """
        return self.phrase.clause

    @property
    def phrase(self) -> "Phrase":
        """Phrase this base phrase belongs to."""
        assert self._phrase is not None
        return self._phrase

    @phrase.setter
    def phrase(self, phrase: "Phrase") -> None:
        """Set the phrase this base phrase belongs to.

        Args:
            phrase: Phrase.
        """
        self._phrase = phrase

    @property
    def morphemes(self) -> list[Morpheme]:
        """List of morphemes."""
        assert self._morphemes is not None
        return self._morphemes

    @morphemes.setter
    def morphemes(self, morphemes: list[Morpheme]) -> None:
        """Set the list of morphemes and register this base phrase on each of them.

        Args:
            morphemes: List of morphemes.
        """
        for morpheme in morphemes:
            morpheme.base_phrase = self
        self._morphemes = morphemes

    @property
    def head(self) -> Morpheme:
        """Head morpheme.

        The morpheme carrying the highest-priority head feature is chosen; on a tie the
        earliest morpheme wins. The first morpheme is returned when no morpheme carries
        any of the features.
        """
        feature_to_priority = {"内容語": 0, "準内容語": 1, "基本句-主辞": 2}
        head = self.morphemes[0]
        current_priority = -1
        for morpheme in self.morphemes:
            if not morpheme.features:
                continue
            for feature, priority in feature_to_priority.items():
                # Strict comparison keeps the earliest morpheme among equal priorities.
                if feature in morpheme.features and priority > current_priority:
                    head = morpheme
                    current_priority = priority
        return head

    @property
    def parent(self) -> Optional["BasePhrase"]:
        """Base phrase this one depends on. None if there is none (parent index is -1).

        Raises:
            AttributeError: If the analysis result is not available.
        """
        if self.parent_index is None:
            raise AttributeError("parent_index has not been set")
        if self.parent_index == -1:
            return None
        return self.sentence.base_phrases[self.parent_index]

    @cached_property
    def children(self) -> list["BasePhrase"]:
        """List of base phrases that depend on this base phrase.

        Raises:
            AttributeError: If the analysis result is not available.
        """
        return [base_phrase for base_phrase in self.sentence.base_phrases if base_phrase.parent == self]

    @property
    def entities_all(self) -> set[Entity]:
        """Set of all referred entities, including nonidentical ones."""
        return self.entities | self.entities_nonidentical

    @classmethod
    def from_knp(cls, knp_text: str) -> "BasePhrase":
        """Create a base phrase from a KNP analysis result.

        Args:
            knp_text: KNP analysis result. The first line is the base phrase header;
                every following non-blank line is parsed as a morpheme.

        Raises:
            ValueError: If the first line does not match the base phrase pattern.
        """
        first_line, *lines = knp_text.split("\n")
        match = cls.PAT.match(first_line)
        if match is None:
            raise ValueError(f"malformed base phrase line: {first_line}")
        pid, dtype = match["pid"], match["dtype"]
        # Features, rel tags and the memo tag are all parsed from the same feature string.
        fstring = match["feats"] or ""
        base_phrase = cls(
            parent_index=int(pid) if pid is not None else None,
            dep_type=DepType(dtype) if dtype is not None else None,
            features=FeatureDict.from_fstring(fstring),
            rel_tags=RelTagList.from_fstring(fstring),
            memo_tag=MemoTag.from_fstring(fstring),
        )
        base_phrase.morphemes = [Morpheme.from_jumanpp(line) for line in lines if line.strip() != ""]
        return base_phrase

    def to_knp(self) -> str:
        """Convert to the KNP format."""
        ret = "+"
        if self.parent_index is not None:
            assert self.dep_type is not None
            ret += f" {self.parent_index}{self.dep_type.value}"
        if self.rel_tags or self.memo_tag or self.features:
            ret += " "
            ret += self.rel_tags.to_fstring()
            if self.memo_tag:
                ret += self.memo_tag.to_fstring()
            ret += self.features.to_fstring()
        ret += "\n"
        ret += "".join(morpheme.to_knp() for morpheme in self.morphemes)
        return ret

    def get_coreferents(self, include_nonidentical: bool = False, include_self: bool = False) -> list["BasePhrase"]:
        """Return the base phrases that corefer with this base phrase.

        Args:
            include_nonidentical: If true, include nonidentical mentions as well.
            include_self: If true, include this base phrase itself.

        Returns:
            List of coreferring base phrases without duplicates.
        """
        # `self` always sits at position 0; the membership test below prevents any
        # other element equal to `self` (or to another mention) from being appended.
        mentions: list["BasePhrase"] = [self]
        entity_groups = [self.entities]
        if include_nonidentical:
            entity_groups.append(self.entities_nonidentical)
        for entity in itertools.chain.from_iterable(entity_groups):
            for mention in entity.mentions:
                if mention not in mentions:
                    mentions.append(mention)
        if not include_self:
            # Drop `self`, which is the first (and only matching) element.
            return mentions[1:]
        return mentions

    def _add_argument(self, rel_tag: RelTag) -> None:
        """Add an argument to the predicate-argument structure whose predicate is this base phrase."""
        case = normalize_case(rel_tag.type)
        argument: Argument
        if rel_tag.sid is not None:
            # Endophora: the argument is a base phrase in the text.
            arg_base_phrase = self._get_target_base_phrase(rel_tag)
            if arg_base_phrase is None:
                return
            if not arg_base_phrase.entities:
                EntityManager.get_or_create_entity().add_mention(arg_base_phrase)
            argument = EndophoraArgument(case, arg_base_phrase, self.pas.predicate)
        else:
            # Exophora: the argument is not a base phrase in the text.
            if rel_tag.target == "なし":
                self.pas.set_arguments_optional(case)
                return
            exophora_referent = ExophoraReferent(rel_tag.target)
            entity = EntityManager.get_or_create_entity(exophora_referent)
            argument = ExophoraArgument(case, exophora_referent, entity.eid)
        self.pas.add_argument(argument, mode=rel_tag.mode)

    def _add_coreference(self, rel_tag: RelTag) -> None:
        """Add a coreference relation."""
        # create source entity
        if not self.entities:
            EntityManager.get_or_create_entity().add_mention(self)

        is_nonidentical: bool = rel_tag.type.endswith("≒")
        if rel_tag.sid is not None:
            target_base_phrase = self._get_target_base_phrase(rel_tag)
            if target_base_phrase is None:
                return
            if target_base_phrase == self:
                logger.warning(f"{self.sentence.sid}: coreference with self found: {self}")
                return
            # create target entity
            if not target_base_phrase.entities:
                EntityManager.get_or_create_entity().add_mention(target_base_phrase)
            for source_entity, target_entity in itertools.product(self.entities_all, target_base_phrase.entities_all):
                # Because entities are dynamically deleted within this loop, we need to check if they exist.
                if source_entity in self.entities_all and target_entity in target_base_phrase.entities_all:
                    EntityManager.merge_entities(
                        self, target_base_phrase, source_entity, target_entity, is_nonidentical
                    )
        else:
            # exophora
            target_entity = EntityManager.get_or_create_entity(exophora_referent=ExophoraReferent(rel_tag.target))
            for source_entity in self.entities_all:
                # Because entities are dynamically deleted within this loop, we need to check if they exist.
                if source_entity in self.entities_all and target_entity in EntityManager.entities.values():
                    EntityManager.merge_entities(self, None, source_entity, target_entity, is_nonidentical)

    def _get_target_base_phrase(self, rel_tag: RelTag) -> Optional["BasePhrase"]:
        """Return the base phrase that `rel_tag` points to, or None if it cannot be found.

        A warning is logged when the sentence ID is unknown, when the base phrase index is
        out of range, or when the tag's target string shares no character with the text of
        the resolved base phrase (in the last case the base phrase is still returned).
        """
        candidates = self.document.sentences if self.sentence.has_document() else [self.sentence]
        # Take the first sentence with a matching ID without building an intermediate list.
        sentence = next((sent for sent in candidates if sent.sid == rel_tag.sid), None)
        if sentence is None:
            logger.warning(f"{self.sentence.sid}: relation with unknown sid found: {rel_tag.sid}")
            return None
        assert rel_tag.base_phrase_index is not None
        if rel_tag.base_phrase_index >= len(sentence.base_phrases):
            logger.warning(f"{self.sentence.sid}: index out of range")
            return None
        target_base_phrase = sentence.base_phrases[rel_tag.base_phrase_index]
        # Sanity check only: warn when target string and base phrase text have no character in common.
        if not (set(rel_tag.target) & set(target_base_phrase.text)):
            logger.warning(
                f"{self.sentence.sid}: rel target mismatch; '{rel_tag.target}' vs '{target_base_phrase.text}'"
            )
        return target_base_phrase

    @staticmethod
    def is_base_phrase_line(line: str) -> bool:
        """Return True if `line` is a base phrase header line."""
        return BasePhrase.PAT.match(line) is not None