Source code for dragon.search_space.bricks.recurrences

import numpy as np
import torch.nn as nn
from dragon.search_space import Brick
from dragon.utils.tools import logger

class Simple_1DRNN(Brick):
    """Vanilla RNN brick operating on the last axis of its input.

    Wraps an ``nn.RNN`` (``batch_first=True``) whose ``input_size`` is the
    last entry of ``input_shape``.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of one sample; only ``input_shape[-1]`` is read here.
    num_layers : int
        Number of stacked recurrent layers.
    hidden_size : int
        Size of the hidden state.
    """

    def __init__(self, input_shape, num_layers, hidden_size):
        super(Simple_1DRNN, self).__init__(input_shape)
        feature_dim = input_shape[-1]
        self.input_shape = input_shape
        self.input_size = feature_dim
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(
            input_size=feature_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, X, h=None):
        """Run the RNN, optionally resuming from a previous hidden state.

        An input with fewer than three dimensions is reshaped to
        ``(batch, 1, features)`` before the RNN and squeezed back afterwards.

        Parameters
        ----------
        X : tensor
            Input batch.
        h : tensor, optional
            Previous hidden state. Its axis 1 is compared with the batch
            size: extra entries are dropped, and a state that is too small
            is ignored so the RNN starts from its default state.

        Returns
        -------
        tuple
            ``(output, hidden)`` where ``hidden`` is detached from the graph.
        """
        init_shape = X.shape
        try:
            is_flat = len(init_shape) < 3
            if is_flat:
                # (batch, T) -> (batch, T, 1) -> (batch, 1, T)
                X = X.unsqueeze(-1)
                X = X.permute(0, 2, 1)

            batch = X.shape[0]
            if h is not None and h.shape[1] >= batch:
                if h.shape[1] > batch:
                    # Keep only as many states as there are samples.
                    h = h[:, :batch].contiguous()
                X, h = self.rnn(X, h)
            else:
                X, h = self.rnn(X)

            if is_flat:
                # Undo the reshaping done on the way in.
                X = X.permute(0, 2, 1)
                X = X.squeeze(-1)
            return X, h.detach()
        except Exception as e:
            logger.error(f"{e}", exc_info=True)
            raise e

    def modify_operation(self, input_shape):
        """Resize the first layer's input weights to ``input_shape[-1]`` columns.

        The weight matrix is zero-padded when the new size is larger and
        cropped (negative padding) when it is smaller, split between both
        sides of the last axis.
        """
        new_size = input_shape[-1]
        delta = new_size - self.input_size
        magnitude = np.abs(delta)
        direction = 1 if delta == 0 else delta / magnitude
        left = int(direction * np.ceil(magnitude / 2))
        # int() truncates toward zero, so this is the signed floor of |delta| / 2.
        right = int(direction * np.floor(magnitude) / 2)
        self.rnn.weight_ih_l0.data = nn.functional.pad(self.rnn.weight_ih_l0, (left, right))
        self.rnn.input_size = new_size
        self.input_size = new_size
        self.input_shape = input_shape

    def load_state_dict(self, state_dict, **kwargs):
        """Resize to match the checkpoint's input width, then load it."""
        stored_weight = state_dict['rnn.weight_ih_l0']
        self.modify_operation((stored_weight.shape[-1],))
        super(Simple_1DRNN, self).load_state_dict(state_dict, **kwargs)
class Simple_2DLSTM(Brick):
    """LSTM brick for 4-D inputs of shape ``(batch, F, T, d_in)``.

    When ``d_in == 1`` the LSTM consumes the ``F`` axis as features;
    otherwise it consumes ``d_in`` features and the ``F`` axis is folded
    into the batch.

    Parameters
    ----------
    input_shape : tuple of int
        ``(F, T, d_in)`` shape of one sample.
    hidden_size : int
        Size of the LSTM hidden state.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, input_shape, hidden_size, num_layers):
        F, T, d_in = input_shape
        super(Simple_2DLSTM, self).__init__(input_shape)
        self.input_size = F if d_in == 1 else d_in
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, X):
        """Apply the LSTM along the ``T`` axis.

        Parameters
        ----------
        X : tensor
            Input of shape ``(batch, F, T, d_in)``.

        Returns
        -------
        tensor
            ``(batch, hidden_size, T, 1)`` when ``d_in == 1``, otherwise
            ``(batch, F, T, hidden_size)``.
        """
        bs, F, T, d_in = X.shape
        if d_in == 1:
            # (bs, F, T, 1) -> (bs, T, F): F becomes the feature axis.
            X_viewed = X.squeeze(-1).transpose(1, 2)
        else:
            # Fold F into the batch so each row is its own sequence.
            X_viewed = X.reshape(-1, T, d_in)

        X_lstm, _ = self.lstm(X_viewed)

        if d_in == 1:
            # (bs, T, H) -> (bs, T, H, 1) -> (bs, H, T, 1)
            return X_lstm.unsqueeze(-1).transpose(1, 2)
        return X_lstm.reshape(bs, F, *X_lstm.shape[1:])

    def _resize_input(self, new_size):
        """Zero-pad or crop the first layer's input weights to ``new_size`` columns."""
        new_size = int(new_size)
        diff = new_size - self.input_size
        magnitude = abs(diff)
        # Split the change over both sides; the left side takes the extra
        # column when the difference is odd.
        left, right = (magnitude + 1) // 2, magnitude // 2
        if diff < 0:
            # Negative padding makes F.pad crop instead of extend.
            left, right = -left, -right
        weight = self.lstm.weight_ih_l0
        # Pad the raw tensor so the resize is not tracked by autograd.
        weight.data = nn.functional.pad(weight.data, (left, right))
        self.lstm.input_size = new_size
        self.input_size = new_size

    def modify_operation(self, input_shape):
        """Adapt the LSTM input width to a new ``(F, T, d_in)`` sample shape."""
        F, T, d_in = input_shape
        self._resize_input(F if d_in == 1 else d_in)

    def load_state_dict(self, state_dict, **kwargs):
        """Resize to match the checkpoint's input width, then load it.

        The checkpoint only exposes the input width (second axis of
        ``lstm.weight_ih_l0``), so the weights are resized directly rather
        than through ``modify_operation``, which needs a full
        ``(F, T, d_in)`` shape.
        """
        stored_width = state_dict['lstm.weight_ih_l0'].shape[1]
        self._resize_input(stored_width)
        super(Simple_2DLSTM, self).load_state_dict(state_dict, **kwargs)
class Simple_1DLSTM(Brick):
    """LSTM brick operating on the last axis of its input.

    Wraps an ``nn.LSTM`` (``batch_first=True``) whose ``input_size`` is the
    last entry of ``input_shape``.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of one sample; only ``input_shape[-1]`` is read here.
    hidden_size : int
        Size of the hidden and cell states.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, input_shape, hidden_size, num_layers):
        super(Simple_1DLSTM, self).__init__(input_shape)
        self.input_shape = input_shape
        self.input_size = input_shape[-1]
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, X, h=None):
        """Run the LSTM, optionally resuming from a previous ``(h, c)`` state.

        An input with fewer than three dimensions is reshaped to
        ``(batch, 1, features)`` before the LSTM and squeezed back afterwards.

        Parameters
        ----------
        X : tensor
            Input batch.
        h : tuple of tensors, optional
            Previous ``(hidden, cell)`` state. Axis 1 of the hidden state is
            compared with the batch size: extra entries of both tensors are
            dropped, and a state that is too small is ignored.

        Returns
        -------
        tuple
            ``(output, (hidden, cell))`` with both states detached.
        """
        init_shape = X.shape
        c = None  # defined up front so the error handler can always report it
        try:
            is_flat = len(init_shape) < 3
            if is_flat:
                # (batch, T) -> (batch, T, 1) -> (batch, 1, T)
                X = X.unsqueeze(-1)
                X = X.permute(0, 2, 1)

            if h is None:
                X, (h, c) = self.lstm(X)
            else:
                h, c = h
                batch = X.shape[0]
                if h.shape[1] >= batch:
                    if h.shape[1] > batch:
                        # Truncate each state from its own tensor; slicing
                        # c out of h would overwrite the cell state.
                        h = h[:, :batch].contiguous()
                        c = c[:, :batch].contiguous()
                    X, (h, c) = self.lstm(X, (h, c))
                else:
                    # The stored state covers fewer samples than the batch.
                    X, (h, c) = self.lstm(X)
            h = h.detach()
            c = c.detach()

            if is_flat:
                # Undo the reshaping done on the way in.
                X = X.permute(0, 2, 1)
                X = X.squeeze(-1)
            return X, (h, c)
        except Exception as e:
            logger.error(f"Input shape: {self.input_shape}, {self.input_size}, X shape: {X.shape}, init shape: {init_shape} lstm = {self.lstm} \n {e}", exc_info=True)
            if h is not None:
                # h may still be the unpacked (h, c) tuple if the failure came
                # early; getattr keeps this handler from masking the error.
                logger.error(f"h: {getattr(h, 'shape', None)}, c: {getattr(c, 'shape', None)}")
            raise

    def modify_operation(self, input_shape):
        """Resize the first layer's input weights to ``input_shape[-1]`` columns.

        The weight matrix is zero-padded when the new size is larger and
        cropped (negative padding) when it is smaller, split between both
        sides of the last axis.
        """
        new_size = int(input_shape[-1])
        diff = new_size - self.input_size
        magnitude = abs(diff)
        # The left side takes the extra column when the difference is odd.
        left, right = (magnitude + 1) // 2, magnitude // 2
        if diff < 0:
            left, right = -left, -right
        weight = self.lstm.weight_ih_l0
        # Pad the raw tensor so the resize is not tracked by autograd.
        weight.data = nn.functional.pad(weight.data, (left, right))
        self.lstm.input_size = new_size
        self.input_size = new_size
        self.input_shape = input_shape

    def load_state_dict(self, state_dict, **kwargs):
        """Resize to match the checkpoint's input width, then load it."""
        stored_width = state_dict['lstm.weight_ih_l0'].shape[-1]
        self.modify_operation((stored_width,))
        super(Simple_1DLSTM, self).load_state_dict(state_dict, **kwargs)
class Simple_2DGRU(Brick):
    """GRU brick for 4-D inputs of shape ``(batch, F, T, d_in)``.

    When ``d_in == 1`` the GRU consumes the ``F`` axis as features;
    otherwise it consumes ``d_in`` features and the ``F`` axis is folded
    into the batch.

    Parameters
    ----------
    input_shape : tuple of int
        ``(F, T, d_in)`` shape of one sample.
    hidden_size : int
        Size of the GRU hidden state.
    num_layers : int
        Number of stacked GRU layers.
    """

    def __init__(self, input_shape, hidden_size, num_layers):
        F, T, d_in = input_shape
        super(Simple_2DGRU, self).__init__(input_shape)
        self.input_size = F if d_in == 1 else d_in
        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, X):
        """Apply the GRU along the ``T`` axis.

        Parameters
        ----------
        X : tensor
            Input of shape ``(batch, F, T, d_in)``.

        Returns
        -------
        tensor
            ``(batch, hidden_size, T, 1)`` when ``d_in == 1``, otherwise
            ``(batch, F, T, hidden_size)``.
        """
        bs, F, T, d_in = X.shape
        if d_in == 1:
            # (bs, F, T, 1) -> (bs, T, F): F becomes the feature axis.
            X_viewed = X.squeeze(-1).transpose(1, 2)
        else:
            # Fold F into the batch so each row is its own sequence.
            X_viewed = X.reshape(-1, T, d_in)

        X_gru, _ = self.gru(X_viewed)

        if d_in == 1:
            # (bs, T, H) -> (bs, T, H, 1) -> (bs, H, T, 1)
            return X_gru.unsqueeze(-1).transpose(1, 2)
        return X_gru.reshape(bs, F, *X_gru.shape[1:])

    def _resize_input(self, new_size):
        """Zero-pad or crop the first layer's input weights to ``new_size`` columns."""
        new_size = int(new_size)
        diff = new_size - self.input_size
        magnitude = abs(diff)
        # Split the change over both sides; the left side takes the extra
        # column when the difference is odd.
        left, right = (magnitude + 1) // 2, magnitude // 2
        if diff < 0:
            # Negative padding makes F.pad crop instead of extend.
            left, right = -left, -right
        weight = self.gru.weight_ih_l0
        # Pad the raw tensor so the resize is not tracked by autograd.
        weight.data = nn.functional.pad(weight.data, (left, right))
        self.gru.input_size = new_size
        self.input_size = new_size

    def modify_operation(self, input_shape):
        """Adapt the GRU input width to a new ``(F, T, d_in)`` sample shape."""
        F, T, d_in = input_shape
        self._resize_input(F if d_in == 1 else d_in)

    def load_state_dict(self, state_dict, **kwargs):
        """Resize to match the checkpoint's input width, then load it.

        The checkpoint only exposes the input width (second axis of
        ``gru.weight_ih_l0``), so the weights are resized directly rather
        than through ``modify_operation``, which needs a full
        ``(F, T, d_in)`` shape.
        """
        stored_width = state_dict['gru.weight_ih_l0'].shape[1]
        self._resize_input(stored_width)
        super(Simple_2DGRU, self).load_state_dict(state_dict, **kwargs)
class Simple_1DGRU(Brick):
    """GRU brick operating on the last axis of its input.

    Wraps an ``nn.GRU`` (``batch_first=True``) whose ``input_size`` is the
    last entry of ``input_shape``.

    Parameters
    ----------
    input_shape : sequence of int
        Shape of one sample; only ``input_shape[-1]`` is read here.
    num_layers : int
        Number of stacked GRU layers.
    hidden_size : int
        Size of the hidden state.
    """

    def __init__(self, input_shape, num_layers, hidden_size):
        super(Simple_1DGRU, self).__init__(input_shape)
        feature_dim = input_shape[-1]
        self.input_shape = input_shape
        self.input_size = feature_dim
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.gru = nn.GRU(
            input_size=feature_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, X, h=None):
        """Run the GRU, optionally resuming from a previous hidden state.

        An input with fewer than three dimensions is reshaped to
        ``(batch, 1, features)`` before the GRU and squeezed back afterwards.

        Parameters
        ----------
        X : tensor
            Input batch.
        h : tensor, optional
            Previous hidden state. Its axis 1 is compared with the batch
            size: extra entries are dropped, and a state that is too small
            is ignored so the GRU starts from its default state.

        Returns
        -------
        tuple
            ``(output, hidden)`` where ``hidden`` is detached from the graph.
        """
        init_shape = X.shape
        try:
            is_flat = len(init_shape) < 3
            if is_flat:
                # (batch, T) -> (batch, T, 1) -> (batch, 1, T)
                X = X.unsqueeze(-1)
                X = X.permute(0, 2, 1)

            batch = X.shape[0]
            if h is not None and h.shape[1] >= batch:
                if h.shape[1] > batch:
                    # Keep only as many states as there are samples.
                    h = h[:, :batch].contiguous()
                X, h = self.gru(X, h)
            else:
                X, h = self.gru(X)

            if is_flat:
                # Undo the reshaping done on the way in.
                X = X.permute(0, 2, 1)
                X = X.squeeze(-1)
            return X, h.detach()
        except Exception as e:
            logger.error(f"Input shape: {self.input_shape}, {self.input_size}, X shape: {X.shape}, init_shape: {init_shape}, gru = {self.gru}")
            if h is not None:
                logger.error(f"h: {h.shape}")
            raise e

    def modify_operation(self, input_shape):
        """Resize the first layer's input weights to ``input_shape[-1]`` columns.

        The weight matrix is zero-padded when the new size is larger and
        cropped (negative padding) when it is smaller, split between both
        sides of the last axis.
        """
        new_size = input_shape[-1]
        delta = new_size - self.input_size
        magnitude = np.abs(delta)
        direction = 1 if delta == 0 else delta / magnitude
        left = int(direction * np.ceil(magnitude / 2))
        # int() truncates toward zero, so this is the signed floor of |delta| / 2.
        right = int(direction * np.floor(magnitude) / 2)
        self.gru.weight_ih_l0.data = nn.functional.pad(self.gru.weight_ih_l0, (left, right))
        self.gru.input_size = new_size
        self.input_size = new_size
        self.input_shape = input_shape

    def load_state_dict(self, state_dict, **kwargs):
        """Resize to match the checkpoint's input width, then load it."""
        stored_weight = state_dict['gru.weight_ih_l0']
        self.modify_operation((stored_weight.shape[-1],))
        super(Simple_1DGRU, self).load_state_dict(state_dict, **kwargs)