"""Module containing directed loop graph class"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Generic, TypeVar, cast, overload
from networkx import DiGraph, ancestors, dfs_edges, dfs_preorder_nodes, has_path
from knit_graphs.Loop import Loop
from knit_graphs.Pull_Direction import Pull_Direction
EdgeT = TypeVar("EdgeT")
LoopT = TypeVar("LoopT", bound=Loop)
[docs]
class Directed_Loop_Graph(Generic[LoopT, EdgeT]):
"""
Wrapper for networkx.DiGraphs with directed edges between loops (i.e., floats in yarns, stitches in knitgraph).
"""
_DATA_ATTRIBUTE_NAME = "data"
[docs]
def __init__(self) -> None:
self._loop_graph: DiGraph = DiGraph()
self._loops_by_loop_id: dict[int, LoopT] = {}
@property
def loop_count(self) -> int:
"""
Returns:
int: The number of loops in the graph.
"""
return len(self._loops_by_loop_id)
@property
def edge_count(self) -> int:
"""
Returns:
int: The number of edges in the graph.
"""
return len(self._loop_graph.edges)
@property
def contains_loops(self) -> bool:
"""
Returns:
bool: True if the graph has at least one loop, False otherwise.
"""
return len(self) > 0
@property
def sorted_loops(self) -> list[LoopT]:
"""
Returns:
list[Loop]: The list of loops in the graph sorted from the earliest formed loop to the latest formed loop.
"""
return sorted(self)
@property
def edge_iter(self) -> Iterator[tuple[LoopT, LoopT, EdgeT]]:
"""
Returns:
Iterator[tuple[LoopT, LoopT, EdgeT]]: Iterator over the edges and edge-data in the graph.
Notes:
No guarantees about the order of the edges.
"""
return iter((cast(LoopT, u), cast(LoopT, v), self.get_edge(u, v)) for u, v in self._loop_graph.edges)
@property
def terminal_loops(self) -> Iterator[LoopT]:
"""
Returns:
Iterator[Loop]: An iterator over all terminal loops in the graph.
"""
return iter(loop for loop in self if self.is_terminal_loop(loop))
[docs]
def has_loop(self, loop: int | LoopT) -> bool:
"""
Args:
loop (int | LoopT): The loop or loop id to check for in the graph.
Returns:
bool: True if the loop id is in the graph. False, otherwise.
"""
if isinstance(loop, int):
return loop in self._loops_by_loop_id
else:
return bool(self._loop_graph.has_node(loop))
[docs]
def get_loop(self, loop_id: int) -> LoopT:
"""
Args:
loop_id (int): The loop id of the loop to get from the graph.
Returns:
LoopT: The loop node in the graph.
"""
return self._loops_by_loop_id[loop_id]
[docs]
def successors(self, loop: int | LoopT) -> set[LoopT]:
"""
Args:
loop (int | LoopT): The loop to get the successors of from the graph.
Returns:
set[LoopT]: The successors of the loop.
"""
if isinstance(loop, int):
loop = self.get_loop(loop)
return cast(set[LoopT], set(self._loop_graph.successors(loop)))
[docs]
def has_child_loop(self, loop: LoopT) -> bool:
"""
Args:
loop (Loop): The loop to check for child connections.
Returns:
bool: True if the loop has a child loop, False otherwise.
"""
return len(self.successors(loop)) > 0
[docs]
def is_terminal_loop(self, loop: LoopT) -> bool:
"""
Args:
loop (LoopT): The loop to check for terminal status.
Returns:
bool: True if the loop has no child loops, False otherwise.
"""
return not self.has_child_loop(loop)
[docs]
def get_child_loop(self, loop: LoopT) -> LoopT | None:
"""
Args:
loop (LoopT): The loop to look for a child loop from.
Returns:
LoopT | None: The child loop if one exists, or None if no child loop is found.
"""
successors = self.successors(loop)
if len(successors) == 0:
return None
return successors.pop()
[docs]
def predecessors(self, loop: int | LoopT) -> set[LoopT]:
"""
Args:
loop (int | LoopT): The loop to get the predecessors of from the graph.
Returns:
set[LoopT]: The successors of the loop.
"""
if isinstance(loop, int):
loop = self.get_loop(loop)
return cast(set[LoopT], set(self._loop_graph.predecessors(loop)))
[docs]
def ancestors(self, loop: LoopT) -> set[LoopT]:
"""
Args:
loop (LoopT): The loop to get the ancestors of from the graph.
Returns:
set[LoopT]: The ancestors of the given loop.
"""
return cast(set[LoopT], ancestors(self._loop_graph, loop))
[docs]
def is_descendant(self, ancestor: LoopT, descendant: LoopT) -> bool:
"""
Args:
ancestor (LoopT): The loop to check if it is an ancestor of the other loop.
descendant (LoopT): The loop to check if it is a descendant of the other loop.
Returns:
bool: True if there is a directed path from the ancestor to the descendant, False otherwise.
"""
return bool(has_path(self._loop_graph, ancestor, descendant))
[docs]
def source_loops(self, loop: LoopT) -> set[LoopT]:
"""
Args:
loop (LoopT): The loop to get the sources of.
Returns:
set[LoopT]: The source loops of the given loop. These are the loops that are the source of all ancestors to the given loop.
"""
ancestor_loops = self.ancestors(loop)
if len(ancestor_loops) == 0:
return set()
sources = set()
while len(ancestor_loops) > 0:
ancestor = ancestor_loops.pop()
if len(ancestor_loops) == 0:
sources.add(ancestor)
elif not any(self.is_descendant(other, ancestor) for other in ancestor_loops):
sources.add(ancestor)
non_sources = {descendant for descendant in ancestor_loops if self.is_descendant(ancestor, descendant)}
ancestor_loops.difference_update(non_sources) # remove all ancestors that descend from the source.
return sources
[docs]
def dfs_edges(self, loop: LoopT) -> Iterator[tuple[LoopT, LoopT]]:
"""
Args:
loop (LoopT): The loop to start iteration over edges from.
Returns:
The depth-first-search ordering of edges starting from the given loop.
"""
return cast(Iterator[tuple[LoopT, LoopT]], dfs_edges(self._loop_graph, loop))
[docs]
def dfs_preorder_loops(self, loop: LoopT) -> Iterator[LoopT]:
"""
Args:
loop (LoopT): The loop to start iteration from.
Returns:
Iterator[LoopT]: The depth-first-search ordering of loops in the graph starting from the given loop.
"""
return cast(Iterator[LoopT], dfs_preorder_nodes(self._loop_graph, loop))
[docs]
def has_edge(self, u: LoopT | int, v: LoopT | int) -> bool:
"""
Args:
u (LoopT | int): The loop or id of the first loop in the edge.
v (LoopT | int): The loop or id of the second loop in the edge.
Returns:
bool: True if the graph has an edge from loop u to v. False, otherwise.
"""
if not self.has_loop(u) or not self.has_loop(v):
return False
if isinstance(u, int):
u = self.get_loop(u)
if isinstance(v, int):
v = self.get_loop(v)
return bool(self._loop_graph.has_edge(u, v))
[docs]
def get_edge(self, u: LoopT | int, v: LoopT | int) -> EdgeT:
"""
Args:
u (LoopT | int): The loop or id of the first loop in the edge.
v (LoopT | int): The loop or id of the second loop in the edge.
Returns:
EdgeT: The data about the edge from loop u to v.
"""
return cast(EdgeT, self._loop_graph.edges[u, v][self._DATA_ATTRIBUTE_NAME])
[docs]
def add_loop(self, loop: LoopT) -> None:
"""Add the given loop to the loop graph.
Args:
loop (LoopT): The loop to add to the graph.
"""
self._loop_graph.add_node(loop)
self._loops_by_loop_id[loop.loop_id] = loop
[docs]
def remove_loop(self, loop: LoopT | int) -> None:
"""
Remove the given loop from the graph.
Args:
loop (LoopT | int): The loop or id of the loop to remove.
Raises:
KeyError: The given loop does not exist in the graph.
"""
if isinstance(loop, int):
if loop not in self._loops_by_loop_id:
raise KeyError(f"No loop with id {loop} in graph")
loop = self._loops_by_loop_id[loop]
if not self._loop_graph.has_node(loop):
raise KeyError(f"No loop {loop} in graph")
self._loop_graph.remove_node(loop)
del self._loops_by_loop_id[loop.loop_id]
[docs]
def add_edge(self, u: LoopT | int, v: LoopT | int, edge_data: EdgeT) -> None:
"""
Connect the given loop u to the loop v with the given edge data.
Args:
u (LoopT | int): The loop to connect to.
v (LoopT | int): The loop to connect to.
edge_data (EdgeT): The edge data to associate with the connection.
Raises:
KeyError: Either of the given loops does not exist in the graph.
"""
if u not in self:
raise KeyError(f"parent loop {u} is not in graph")
if v not in self:
raise KeyError(f"child loop {v} i not in graph")
if isinstance(u, int):
u = self.get_loop(u)
if isinstance(v, int):
v = self.get_loop(v)
self._loop_graph.add_edge(u, v, data=edge_data)
[docs]
def remove_edge(self, u: LoopT | int, v: LoopT | int) -> None:
"""
Removes the edge from loop u to the loop v from the graph.
Args:
u (LoopT | int): The loop to connect to.
v (LoopT | int): The loop to connect to.
Raises:
KeyError: The given edge does not exist in the graph.
"""
if (u, v) not in self:
raise KeyError(f"Edge from {u} to {v} is not in graph")
if isinstance(u, int):
u = self.get_loop(u)
if isinstance(v, int):
v = self.get_loop(v)
self._loop_graph.remove_edge(u, v)
[docs]
def __contains__(self, item: LoopT | int | tuple[LoopT | int, LoopT | int]) -> bool:
"""
Args:
item (LoopT | int | tuple[LoopT | int, LoopT | int]): The loop, loop-id, or a pair of loops in a directed edge to search for.
Returns:
bool: True if the graph contains the given loop or edge.
"""
if isinstance(item, tuple):
return self.has_edge(item[0], item[1])
else:
return self.has_loop(item)
@overload
def __getitem__(self, item: int) -> LoopT: ...
@overload
def __getitem__(self, item: tuple[LoopT | int, LoopT | int]) -> EdgeT: ...
[docs]
def __getitem__(self, item: int | tuple[LoopT | int, LoopT | int]) -> LoopT | EdgeT:
"""
Args:
item (int | tuple[LoopT | int, LoopT | int]): The id of the loop or the pair of loops that form an edge in the graph.
Returns:
LoopT: The loop associated with the given id.
EdgeT: The edge data associated with the given pair of loops/ loop_ids.
Raises:
KeyError: The given loop or edge does not exist in the graph.
"""
if item not in self:
raise KeyError(f"{item} is not in graph")
elif isinstance(item, int):
return self.get_loop(item)
else:
return self.get_edge(item[0], item[1])
[docs]
def __iter__(self) -> Iterator[LoopT]:
"""
Returns:
Iterator[LoopT]: Iterator over all the loops in the graph.
Notes:
No guarantees about order of loops. Expected to be insertion order.
"""
return iter(self._loops_by_loop_id.values())
[docs]
def __len__(self) -> int:
"""
Returns:
int: The number of loops in the graph.
"""
return len(self._loops_by_loop_id)
[docs]
@dataclass
class Stitch_Edge:
"""Common data about stitch edges."""
pull_direction: Pull_Direction # The direction of the stitch edge.
[docs]
@dataclass
class Float_Edge(Generic[LoopT]):
"""The edge data for float edges between loops on a yarn."""
front_loops: set[LoopT] = field(default_factory=set) # The set of loops that sit in front of this float. Defaults to empty set.
back_loops: set[LoopT] = field(default_factory=set) # THe set of loops that sit behind this float. Defaults to the empty set.
[docs]
def loop_in_front_of_float(self, loop: LoopT) -> bool:
"""
Args:
loop (LoopT): The loop to find relative to this float.
Returns:
bool: True if the loop is in the front of the float.
"""
return loop in self.front_loops
[docs]
def loop_behind_float(self, loop: LoopT) -> bool:
"""
Args:
loop (LoopT): The loop to find relative to this float.
Returns:
bool: True if the loop is behind the float.
"""
return loop in self.front_loops
[docs]
def add_loop_in_front_of_float(self, loop: LoopT) -> None:
"""
Adds the given loop to the set of loops in front of this float.
If the loop was behind the float, it is swapped to be in front.
Args:
loop (LoopT): The loop to put in front of this float.
"""
if loop in self.back_loops:
self.back_loops.remove(loop)
self.front_loops.add(loop)
[docs]
def add_loop_behind_float(self, back_loop: LoopT) -> None:
"""
Adds the given loop to the set of loops behind this float.
If the loop was in front of this float, it is swapped to be behind.
Args:
back_loop (LoopT): The loop to put behind this float.
"""
if back_loop in self.front_loops:
self.front_loops.remove(back_loop)
self.back_loops.add(back_loop)
[docs]
def remove_loop_relative_to_floats(self, loop: LoopT) -> None:
"""
Removes the given loop from the edge data (if present), noting that it is neither in front of nor behind the float.
Args:
loop (LoopT): The loop to remove.
"""
if loop in self.front_loops:
self.front_loops.remove(loop)
elif loop in self.back_loops:
self.back_loops.remove(loop)
[docs]
def __contains__(self, item: LoopT) -> bool:
"""
Args:
item (LoopT): The loop to find relative to this float.
Returns:
bool: True if the loop is in front or behind this float.
"""
return item in self.front_loops or item in self.back_loops