Source code for tsl.transforms.masked_subgraph
from torch_geometric.transforms import BaseTransform
from tsl.data import Data
[docs]class MaskedSubgraph(BaseTransform):
"""Reduce graph in :attr:`sample` removing masked nodes."""
def __call__(self, data: Data) -> Data:
if not data.has_mask:
return data
live_nodes = data.mask.any(0).any(-1)
node_index = live_nodes.nonzero().squeeze(dim=1)
return data.subgraph_(node_index)