from tsl.nn.layers.graph_convs import GraphConv
from .base import GraphGRUCellBase, GraphLSTMCellBase


class GraphConvGRUCell(GraphGRUCellBase):
r"""Gated Recurrent Unit with :class:`~tsl.nn.layers.graph_convs.GraphConv`
as graph convolution in the gates, based on the paper
`"Structured Sequence Modeling with Graph Convolutional Recurrent Networks"
<https://arxiv.org/abs/1612.07659>`_ (Seo et al., ICONIP 2017).
Args:
input_size (int): Size of the input.
hidden_size (int): Number of units in the hidden state.
bias (bool): If :obj:`True`, then the layer will learn an additive bias
for each gate.
(default: :obj:`True`)
norm (str): The normalization used for edges and edge weights. If
:obj:`'mean'`, then edge weights are normalized as
:math:`a_{j \rightarrow i} = \frac{a_{j \rightarrow i}} {deg_{i}}`,
other available options are: :obj:`'gcn'`, :obj:`'asym'` and
:obj:`'none'`.
(default: :obj:`'mean'`)
root_weight (bool): If :obj:`True`, then add a filter (with different
weights) for the root node itself.
(default :obj:`True`)
cached (bool): If :obj:`True`, then cached the normalized edge weights
computed in the first call.
(default :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 bias: bool = True,
                 norm: str = 'mean',
                 root_weight: bool = True,
                 cached: bool = False,
                 **kwargs):
        self.input_size = input_size
        # instantiate gates
        forget_gate = GraphConv(input_size + hidden_size,
                                hidden_size,
                                norm=norm,
                                root_weight=root_weight,
                                bias=bias,
                                cached=cached,
                                **kwargs)
        update_gate = GraphConv(input_size + hidden_size,
                                hidden_size,
                                norm=norm,
                                root_weight=root_weight,
                                bias=bias,
                                cached=cached,
                                **kwargs)
        candidate_gate = GraphConv(input_size + hidden_size,
                                   hidden_size,
                                   norm=norm,
                                   root_weight=root_weight,
                                   bias=bias,
                                   cached=cached,
                                   **kwargs)
        super(GraphConvGRUCell, self).__init__(hidden_size=hidden_size,
                                               forget_gate=forget_gate,
                                               update_gate=update_gate,
                                               candidate_gate=candidate_gate)
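

# A minimal usage sketch (illustrative addition, not part of the original
# module). It assumes the tsl base class implements the standard GRU update
# over the graph-convolutional gates, so that the cell is called as
# ``cell(x, h, edge_index)`` with tsl's [batch, nodes, channels] layout and a
# PyG-style COO ``edge_index``, returning the updated hidden state.
def _demo_graph_conv_gru_cell():
    import torch

    batch_size, num_nodes = 2, 4
    input_size, hidden_size = 8, 16
    cell = GraphConvGRUCell(input_size=input_size, hidden_size=hidden_size)
    # Directed ring over the nodes: row 0 holds source nodes, row 1 targets.
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    x = torch.randn(batch_size, num_nodes, input_size)   # input at step t
    h = torch.zeros(batch_size, num_nodes, hidden_size)  # initial state
    h = cell(x, h, edge_index)  # -> [batch_size, num_nodes, hidden_size]
    return h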


class GraphConvLSTMCell(GraphLSTMCellBase):
r"""LSTM with :class:`~tsl.nn.layers.graph_convs.GraphConv` as graph
convolution in the gates, based on the paper `"Structured Sequence Modeling
with Graph Convolutional Recurrent Networks"
<https://arxiv.org/abs/1612.07659>`_ (Seo et al., ICONIP 2017).
Args:
input_size (int): Size of the input.
hidden_size (int): Number of units in the hidden state.
bias (bool): If :obj:`True`, then the layer will learn an additive bias
for each gate.
(default: :obj:`True`)
norm (str): Normalization used by the graph convolutional layer.
(default :obj:`mean`)
root_weight (bool): If :obj:`True`, then add a filter (with different
weights) for the root node itself.
(default :obj:`True`)
cached (bool): If :obj:`True`, then cached the normalized edge weights
computed in the first call.
(default :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 bias: bool = True,
                 norm: str = 'mean',
                 root_weight: bool = True,
                 cached: bool = False,
                 **kwargs):
        self.input_size = input_size
        # instantiate gates
        input_gate = GraphConv(input_size + hidden_size,
                               hidden_size,
                               norm=norm,
                               root_weight=root_weight,
                               bias=bias,
                               cached=cached,
                               **kwargs)
        forget_gate = GraphConv(input_size + hidden_size,
                                hidden_size,
                                norm=norm,
                                root_weight=root_weight,
                                bias=bias,
                                cached=cached,
                                **kwargs)
        cell_gate = GraphConv(input_size + hidden_size,
                              hidden_size,
                              norm=norm,
                              root_weight=root_weight,
                              bias=bias,
                              cached=cached,
                              **kwargs)
        output_gate = GraphConv(input_size + hidden_size,
                                hidden_size,
                                norm=norm,
                                root_weight=root_weight,
                                bias=bias,
                                cached=cached,
                                **kwargs)
        super(GraphConvLSTMCell, self).__init__(hidden_size=hidden_size,
                                                input_gate=input_gate,
                                                forget_gate=forget_gate,
                                                cell_gate=cell_gate,
                                                output_gate=output_gate)
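

# A companion usage sketch for the LSTM cell (illustrative addition, not part
# of the original module). It assumes the base class follows the standard
# LSTM update, taking the recurrent state as an ``(h, c)`` tuple and
# returning the updated pair, i.e. ``cell(x, (h, c), edge_index)``; shapes
# follow tsl's [batch, nodes, channels] convention as in the GRU sketch.
def _demo_graph_conv_lstm_cell():
    import torch

    batch_size, num_nodes = 2, 4
    input_size, hidden_size = 8, 16
    cell = GraphConvLSTMCell(input_size=input_size, hidden_size=hidden_size)
    edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                               [1, 2, 3, 0]])  # target nodes (directed ring)
    x = torch.randn(batch_size, num_nodes, input_size)   # input at step t
    h = torch.zeros(batch_size, num_nodes, hidden_size)  # initial hidden state
    c = torch.zeros(batch_size, num_nodes, hidden_size)  # initial cell state
    h, c = cell(x, (h, c), edge_index)  # updated hidden and cell states
    return h, c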