Source code for semstr.validation

from itertools import groupby

from ucca import layer0, layer1, validation as ucca_validations
from ucca.normalization import normalize

from .constraints import Direction


def ucca_constraints(*args, **kwargs):
    """Instantiate ``UccaConstraints``, forwarding every argument unchanged.

    The import lives inside the function, so ``.constraint.ucca`` is only
    loaded when this factory is actually called.
    """
    from .constraint.ucca import UccaConstraints as constraints_cls

    return constraints_cls(*args, **kwargs)
def sdp_constraints(*args, **kwargs):
    """Instantiate ``SdpConstraints``, forwarding every argument unchanged.

    The import lives inside the function, so ``.constraint.sdp`` is only
    loaded when this factory is actually called.
    """
    from .constraint.sdp import SdpConstraints as constraints_cls

    return constraints_cls(*args, **kwargs)
def conllu_constraints(*args, **kwargs):
    """Instantiate ``ConlluConstraints``, forwarding every argument unchanged.

    The import lives inside the function, so ``.constraint.conllu`` is only
    loaded when this factory is actually called.
    """
    from .constraint.conllu import ConlluConstraints as constraints_cls

    return constraints_cls(*args, **kwargs)
def amr_constraints(*args, **kwargs):
    """Instantiate ``AmrConstraints``, forwarding every argument unchanged.

    The import lives inside the function, so ``.constraint.amr`` is only
    loaded when this factory is actually called.
    """
    from .constraint.amr import AmrConstraints as constraints_cls

    return constraints_cls(*args, **kwargs)
# Lookup table from a format key to the factory function that builds the
# matching constraints object.  The ``None`` key maps to the UCCA factory.
CONSTRAINTS = {
    None: ucca_constraints,
    "amr": amr_constraints,
    "sdp": sdp_constraints,
    "conllu": conllu_constraints,
}
def detect_cycles(passage):
    """Yield a message for each back edge met in a depth-first walk of layer 1.

    The walk starts at the heads of ``passage.layer(layer1.LAYER_ID)`` and
    follows ``node.children``.  When a child is already on the current
    depth-first path, a ``"Detected cycle (...)"`` message is yielded that
    lists the IDs of the nodes on that path, joined by ``->``.

    :param passage: object whose ``layer(layer1.LAYER_ID)`` has ``heads``;
        every reachable node must be hashable and expose ``ID`` and
        ``children``
    :return: generator of message strings
    """
    # Every stack frame is an *iterator* over one list of siblings.  Coming
    # back to a frame therefore resumes right after the child that was just
    # explored.  Stacking the plain lists instead restarts the scan at the
    # first sibling each time, which re-reports an already reported cycle
    # once per remaining sibling and makes the walk quadratic.
    stack = [iter(list(passage.layer(layer1.LAYER_ID).heads))]
    visited = set()
    path = []  # nodes on the current depth-first path, outermost first
    path_set = set()  # the same nodes, kept for constant-time membership tests
    while stack:
        for node in stack[-1]:
            if node in path_set:
                yield "Detected cycle (%s)" % "->".join(n.ID for n in path)
            elif node not in visited:
                visited.add(node)
                path.append(node)
                path_set.add(node)
                stack.append(iter(node.children))
                break  # descend into this node before looking at its siblings
        else:
            # Sibling iterator exhausted: step back out of the owning node.
            # The bottom frame (the heads) has no owning node, hence the guard.
            if path:
                path_set.remove(path.pop())
            stack.pop()
def join(edges):
    """Render edges as one comma-separated string.

    Each edge becomes ``<parent ID>-[<tag>]-><child ID>``; a ``*`` is appended
    to the tag when the edge's ``attrib`` has a truthy ``"remote"`` entry.

    :param edges: iterable of objects with ``parent.ID``, ``tag``, ``attrib``
        and ``child.ID``
    :return: the rendered string (empty for no edges)
    """
    rendered = []
    for edge in edges:
        source_id = edge.parent.ID
        label = edge.tag
        remote_mark = "*" if edge.attrib.get("remote") else ""
        rendered.append("%s-[%s%s]->%s" % (source_id, label, remote_mark, edge.child.ID))
    return ", ".join(rendered)
def check_orphan_terminals(constraints, terminal):
    """Yield a message when ``terminal`` has no incoming edges.

    Nothing is yielded if ``constraints.allow_orphan_terminals`` is truthy.

    :param constraints: object with an ``allow_orphan_terminals`` flag
    :param terminal: object with ``incoming``, ``tag`` and ``ID``
    :return: generator of at most one message string
    """
    if constraints.allow_orphan_terminals or terminal.incoming:
        return
    yield "Orphan %s terminal (%s) '%s'" % (terminal.tag, terminal.ID, terminal)
def check_root_terminal_children(constraints, l1, terminal):
    """Yield a message when one of ``terminal``'s parents is a head of ``l1``.

    Nothing is yielded if ``constraints.allow_root_terminal_children`` is
    truthy.

    :param constraints: object with an ``allow_root_terminal_children`` flag
    :param l1: object with ``heads`` (hashable items)
    :param terminal: object with ``parents`` and ``ID``
    :return: generator of at most one message string
    """
    if constraints.allow_root_terminal_children:
        return
    root_nodes = set(l1.heads)
    if not root_nodes.isdisjoint(terminal.parents):
        yield "Terminal child of root (%s) '%s'" % (terminal.ID, terminal)
def check_top_level_allowed(constraints, l1):
    """Yield a message for every edge of a head whose tag is not allowed.

    Each head in ``l1.heads`` is iterated to obtain its edges.  The check is
    skipped entirely when ``constraints.top_level_allowed`` is falsy.

    :param constraints: object with a ``top_level_allowed`` container of tags
    :param l1: object with ``heads``, each iterable over edges with ``tag``
    :return: generator of message strings
    """
    permitted = constraints.top_level_allowed
    if not permitted:
        return
    top_edges = (edge for head in l1.heads for edge in head)
    for edge in top_edges:
        if edge.tag in permitted:
            continue
        yield "Top level %s edge (%s)" % (edge.tag, edge)
def check_multigraph(constraints, node):
    """Yield a message for every parent with more than one edge into ``node``.

    Nothing is checked when ``constraints.multigraph`` is truthy.

    :param constraints: object with a ``multigraph`` flag
    :param node: object with ``ID`` and ``incoming`` edges, each edge having
        ``parent.ID`` (hashable)
    :return: generator of message strings, one per offending parent, in order
        of the parent's first appearance in ``node.incoming``
    """
    if constraints.multigraph:
        return
    # Bucket the edges by parent ID explicitly.  itertools.groupby only merges
    # *adjacent* items, so on the unsorted incoming list two parallel edges
    # separated by an edge from another parent would go unnoticed.
    edges_by_parent = {}
    for edge in node.incoming:
        edges_by_parent.setdefault(edge.parent.ID, []).append(edge)
    for parent_id, edges in edges_by_parent.items():
        if len(edges) > 1:
            yield "Multiple edges from %s to %s (%s)" % (parent_id, node.ID, join(edges))
def check_implicit_children(constraints, node):
    """Yield a message when an implicit node has any outgoing edge.

    The check only runs when ``constraints.require_implicit_childless`` is
    truthy and ``node.attrib`` has a truthy ``"implicit"`` entry.

    :param constraints: object with a ``require_implicit_childless`` flag
    :param node: object with ``attrib``, ``outgoing`` and ``ID``
    :return: generator of at most one message string
    """
    if not constraints.require_implicit_childless:
        return
    if not node.attrib.get("implicit"):
        return
    # "Childless" means zero outgoing edges; the previous ``> 1`` comparison
    # let an implicit node with exactly one child slip through.
    if len(node.outgoing) > 0:
        yield "Implicit node with children (%s)" % node.ID
def check_multiple_incoming(constraints, node):
    """Yield a message when ``node`` has several non-exempt incoming edges.

    An incoming edge is exempt when its ``attrib`` has a truthy ``"remote"``
    entry or its tag is in ``constraints.possible_multiple_incoming``.  The
    whole check is skipped when that container is falsy.

    :param constraints: object with a ``possible_multiple_incoming`` container
    :param node: object with ``incoming`` edges (each with ``attrib``, ``tag``)
    :return: generator of at most one message string
    """
    exempt_tags = constraints.possible_multiple_incoming
    if not exempt_tags:
        return
    counted = []
    for edge in node.incoming:
        if edge.attrib.get("remote"):
            continue
        if edge.tag in exempt_tags:
            continue
        counted.append(edge)
    if len(counted) > 1:
        yield "Multiple incoming non-remote (%s)" % join(counted)
def check_top_level_only(constraints, l1, node):
    """Yield a message for each top-level-only edge leaving a non-head node.

    ``node`` is iterated to obtain its edges.  Nothing is yielded when
    ``constraints.top_level_only`` is falsy or ``node`` is in ``l1.heads``.

    :param constraints: object with a ``top_level_only`` container of tags
    :param l1: object with ``heads``
    :param node: iterable over edges, each with ``tag``
    :return: generator of message strings
    """
    restricted = constraints.top_level_only
    if not restricted or node in l1.heads:
        return
    for edge in node:
        if edge.tag in restricted:
            yield "Non-top level %s edge (%s)" % (edge.tag, edge)
def check_required_outgoing(constraints, node):
    """Yield a message when ``node`` lacks an outgoing edge with a required tag.

    The message is produced only if all of the following hold:
    ``constraints.required_outgoing`` is truthy, every child of ``node`` has
    the tag ``layer1.NodeTags.Foundational`` (trivially true with no
    children), and no edge obtained by iterating ``node`` has a tag in
    ``constraints.required_outgoing``.

    :param constraints: object with a ``required_outgoing`` container of tags
    :param node: iterable over edges, with ``children`` and ``ID``
    :return: generator of at most one message string
    """
    required = constraints.required_outgoing
    if not required:
        return
    if not all(child.tag == layer1.NodeTags.Foundational for child in node.children):
        return
    for edge in node:
        if edge.tag in required:
            return
    yield "Non-terminal without outgoing %s (%s)" % (required, node.ID)
def check_tag_rules(constraints, node):
    """Yield a message for every tag-rule or edge violation on ``node``'s edges.

    For each edge obtained by iterating ``node``:

    * every rule in ``constraints.tag_rules`` is asked for a violation twice,
      once for ``node`` with ``Direction.outgoing`` and once for
      ``edge.child`` with ``Direction.incoming``;
    * ``constraints.allow_parent``, ``constraints.allow_child`` and
      ``constraints.allow_edge`` are consulted, and a falsy result is
      reported together with that result.

    :param constraints: object with ``tag_rules``, ``allow_parent``,
        ``allow_child`` and ``allow_edge``
    :param node: iterable over edges, with ``ID`` and ``incoming``
    :return: generator of message strings
    """
    for edge in node:
        for rule in constraints.tag_rules:
            # Check the rule from both ends of the edge.
            for violation in (rule.violation(node, edge, Direction.outgoing, message=True),
                              rule.violation(edge.child, edge, Direction.incoming, message=True)):
                if violation:
                    yield "%s (%s)" % (violation, join([edge]))
        valid = constraints.allow_parent(node, edge.tag)
        if not valid:
            yield "%s may not be a '%s' parent (%s, %s): %s" % (
                node.ID, edge.tag, join(node.incoming), join(node), valid)
        valid = constraints.allow_child(edge.child, edge.tag)
        if not valid:
            yield "%s may not be a '%s' child (%s, %s): %s" % (
                edge.child.ID, edge.tag, join(edge.child.incoming), join(edge.child), valid)
        valid = constraints.allow_edge(edge)
        if not valid:
            # Without ``yield`` this message was built and silently discarded,
            # so illegal edges were never reported.
            yield "Illegal edge: %s (%s)" % (join([edge]), valid)
def validate(passage, normalization=False, extra_normalization=False, ucca_validation=False, output_format=None,
             **kwargs):
    """Yield validation messages for ``passage``.

    This is a generator: nothing (including normalization and the format
    lookup) happens until it is iterated.

    :param passage: passage to check; its ``extra`` mapping may hold a
        ``"format"`` entry, which takes precedence over ``output_format``
    :param normalization: if truthy, call ``normalize`` on the passage first
    :param extra_normalization: passed to ``normalize`` as ``extra``
    :param ucca_validation: if truthy, delegate entirely to
        ``ucca_validations.validate`` and skip the generic checks
    :param output_format: key into ``CONSTRAINTS`` used when the passage does
        not name its own format
    :param kwargs: accepted and ignored
    :return: generator of message strings
    :raises ValueError: if the selected format has no entry in ``CONSTRAINTS``
    """
    del kwargs
    if normalization:
        normalize(passage, extra=extra_normalization)
    if ucca_validation:
        yield from ucca_validations.validate(passage)
        return

    # Generic validations depending on format-specific constraints
    selected_format = passage.extra.get("format", output_format)
    try:
        # Only the lookup is guarded, so a KeyError raised while building the
        # constraints object is not mistaken for an unknown format.
        factory = CONSTRAINTS[selected_format]
    except KeyError as e:
        # Report the key that was actually looked up; it may come from the
        # passage rather than from ``output_format``.
        raise ValueError("No validations defined for '%s' format" % selected_format) from e
    constraints = factory()

    yield from detect_cycles(passage)
    l0 = passage.layer(layer0.LAYER_ID)
    l1 = passage.layer(layer1.LAYER_ID)
    for terminal in l0.all:
        yield from check_orphan_terminals(constraints, terminal)
        yield from check_root_terminal_children(constraints, l1, terminal)
        yield from check_multiple_incoming(constraints, terminal)
    yield from check_top_level_allowed(constraints, l1)
    for node in l1.all:
        yield from check_multigraph(constraints, node)
        yield from check_implicit_children(constraints, node)
        yield from check_multiple_incoming(constraints, node)
        yield from check_top_level_only(constraints, l1, node)
        yield from check_required_outgoing(constraints, node)
        yield from check_tag_rules(constraints, node)