Source code for dragon.search_space.bricks.basics
import numpy as np
import torch
import torch.nn as nn
from dragon.search_space.cells import Brick
from dragon.utils.tools import logger
[docs]
class Identity(Brick):
def __init__(self, input_shape=None, **args):
super(Identity, self).__init__(input_shape)
[docs]
def forward(self, X):
return X
[docs]
def modify_operation(self, input_shape):
pass
[docs]
class MLP(Brick):
def __init__(self, input_shape, out_channels):
super(MLP, self).__init__(input_shape)
self.in_channels = input_shape[-1]
self.out_channels = out_channels
self.linear = nn.Linear(self.in_channels, out_channels)
[docs]
def forward(self, X):
try:
X = self.linear(X)
except Exception as e:
logger.error(f'linear: {self.linear.weight.get_device()},\nx = {X.get_device()}')
raise e
return X
[docs]
def modify_operation(self, input_shape, hp=None):
if hp is not None:
d_out = hp["out_channels"]
else:
d_out = self.out_channels
d_in = input_shape[-1]
diff = d_in - self.in_channels
diff_out = d_out - self.out_channels
sign = diff / np.abs(diff) if diff !=0 else 1
sign_out = diff_out / np.abs(diff_out) if diff_out !=0 else 1
pad = (int(sign * np.ceil(np.abs(diff)/2)), int(sign * np.floor(np.abs(diff))/2),
int(sign_out * np.ceil(np.abs(diff_out)/2)), int(sign_out * np.floor(np.abs(diff_out))/2))
pad_bias = (int(sign_out * np.ceil(np.abs(diff_out)/2)), int(sign_out * np.floor(np.abs(diff_out))/2))
self.in_channels = d_in
self.out_channels = d_out
self.linear.weight.data = nn.functional.pad(self.linear.weight, pad)
self.linear.bias.data = nn.functional.pad(self.linear.bias, pad_bias)
[docs]
def load_state_dict(self, state_dict, **kwargs):
input_shape = state_dict['linear.weight'].shape[1]
output_shape = state_dict['linear.weight'].shape[0]
self.modify_operation((input_shape,), hp={'out_channels': output_shape})
super(MLP, self).load_state_dict(state_dict, **kwargs)