Image Classification#
This tutorial shows how to use DRAGON for an image classification task. We create a search space made of two graphs: a first one processing the 2D input data, followed by a second one processing the flattened 1D data.
Loading the dataset#
[1]:
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt

digits = load_digits()
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Training: %i" % label)

[2]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    digits.images, digits.target, test_size=0.5, shuffle=False
)
# Scale pixel values to [0, 1]
X_train = X_train / digits.images.max()
X_test = X_test / digits.images.max()
print(f"X_train: {X_train.shape}, y_train: {y_train.shape}, X_test: {X_test.shape}, y_test: {y_test.shape}")
X_train: (898, 8, 8), y_train: (898,), X_test: (899, 8, 8), y_test: (899,)
Defining the Loss function#
DNN definition#
[3]:
import torch
import torch.nn as nn
import numpy as np
import os

class ClassificationDNN(nn.Module):
    def __init__(self, args, input_shape) -> None:
        super().__init__()
        self.input_shape = input_shape
        # First graph: processes the 2D input images
        self.dag_2d = args['2D Dag']
        self.dag_2d.set(self.input_shape)
        # The 2D output is flattened before entering the 1D graph
        flat_shape = (np.prod(self.dag_2d.output_shape),)
        self.dag_1d = args['1D Dag']
        self.dag_1d.set(flat_shape)
        # Final layer mapping the 1D output to the 10 digit classes
        self.output = args["Out"]
        self.output.set(self.dag_1d.output_shape)

    def forward(self, X, **kwargs):
        out_2d = self.dag_2d(X)
        flat = nn.Flatten()(out_2d)
        out_1d = self.dag_1d(flat)
        out = self.output(out_1d)
        return out

    def save(self, path):
        if not os.path.exists(path):
            os.makedirs(path)
        full_path = os.path.join(path, "best_model.pth")
        torch.save(self.state_dict(), full_path)
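The forward pass collapses the 2D graph's output before handing it to the 1D graph. As a quick illustration of what nn.Flatten does at that point (a standalone PyTorch sketch, independent of DRAGON):

import torch
import torch.nn as nn

x = torch.zeros(32, 8, 8, 1)   # a batch shaped like our input images
print(nn.Flatten()(x).shape)   # torch.Size([32, 64]): all non-batch dimensions collapsed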
Search Space Definition#
[4]:
from dragon.search_space.bricks_variables import mlp_var, dropout, identity_var, operations_var, mlp_const_var, conv_2d, pooling_2d, dag_var, node_var
from dragon.search_space.base_variables import ArrayVar
from dragon.search_operators.base_neighborhoods import ArrayInterval

# Candidate operations for the graph handling 2D data
candidate_operations_2d = operations_var("2D Candidate operations", size=10,
                                         candidates=[mlp_var("MLP"), identity_var("Identity"), dropout('Dropout'),
                                                     conv_2d('Conv 2d', max_out=8), pooling_2d("Pooling")])
dag_2d = dag_var("2D Dag", candidate_operations_2d)

# Candidate operations for the graph handling the flattened 1D data
candidate_operations_1d = operations_var("1D Candidate operations", size=10,
                                         candidates=[mlp_var("MLP"), identity_var("Identity"), dropout('Dropout')])
dag_1d = dag_var("1D Dag", candidate_operations_1d)

# Output layer: an MLP with 10 units, one per digit class
out = node_var("Out", operation=mlp_const_var('Operation', 10), activation_function=nn.Identity())

search_space = ArrayVar(dag_2d, dag_1d, out, label="Search Space", neighbor=ArrayInterval())
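To get a feel for the search space, we can sample a configuration and instantiate the corresponding DNN by hand. This is a minimal sketch of what train_and_predict does in the next section (it assumes the cells above have been run):

# Sample two random configurations, as in the training cell below,
# and build the DNN corresponding to the first one.
p, _ = search_space.random(2)
labels = [e.label for e in search_space]
model = ClassificationDNN(dict(zip(labels, p)), input_shape=(8, 8, 1))
print(model)  # inspect the randomly drawn architecture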
DNN Training#
[5]:
import numpy as np
from skorch import NeuralNetClassifier
from sklearn.metrics import accuracy_score
from dragon.utils.tools import set_seed

device = "cuda" if torch.cuda.is_available() else "cpu"

def train_and_predict(args, idx, verbose=False):
    set_seed(0)
    # Map the sampled values to the labelled variables of the search space
    labels = [e.label for e in search_space]
    args = dict(zip(labels, args))
    model = ClassificationDNN(args, input_shape=(8, 8, 1))
    trainer = NeuralNetClassifier(
        model,
        max_epochs=20,
        lr=0.01,
        optimizer=torch.optim.Adam,
        criterion=nn.CrossEntropyLoss,
        iterator_train__shuffle=True,
        verbose=verbose,
        device=device
    )
    # Add a channel dimension and cast to the dtypes expected by skorch
    X_train_torch = torch.tensor(np.expand_dims(X_train.astype(np.float32), axis=-1)).to(device)
    y_train_torch = torch.tensor(y_train.astype(np.int64)).to(device)
    X_test_torch = torch.tensor(np.expand_dims(X_test.astype(np.float32), axis=-1)).to(device)
    trainer.fit(X_train_torch, y_train_torch)
    y_pred = trainer.predict(X_test_torch)
    acc = accuracy_score(y_test, y_pred)
    print(f"With idx = {idx}, accuracy = {acc}")
    return -acc, model  # The search algorithm minimizes the loss, so we return the negated accuracy

p1, p2 = search_space.random(2)
loss_1, model_1 = train_and_predict(p1, idx="p1", verbose=True)
loss_2, model_2 = train_and_predict(p2, idx="p2", verbose=True)
print("P1 ==> accuracy: ", np.round(-loss_1*100, 2), "%\n")
print("P2 ==> accuracy: ", np.round(-loss_2*100, 2), "%")
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 10.4300 0.1000 35.2483 0.6397
2 17.3778 0.1389 21.8782 0.5451
3 10.0463 0.3056 11.8703 0.5422
4 3.7305 0.3222 7.0649 0.5363
5 2.4646 0.6556 2.2861 0.5357
6 1.1765 0.5500 3.7344 0.5302
7 0.8276 0.6778 2.2579 0.5375
8 0.4357 0.7778 2.0746 0.5252
9 0.3695 0.7889 1.5015 0.5352
10 0.2913 0.7444 2.1330 0.5405
11 0.2722 0.7611 1.4074 0.5306
12 0.2230 0.7167 1.9679 0.5648
13 0.1919 0.8056 1.2097 0.5308
14 0.1777 0.7833 1.6146 0.5378
15 0.1509 0.8222 1.0433 0.5347
16 0.1600 0.7500 1.5645 0.5452
17 0.1115 0.8389 1.0201 0.5333
18 0.1309 0.7889 1.3314 0.5282
19 0.1284 0.8167 1.2493 0.5428
20 0.1314 0.8444 0.9768 0.5454
With idx = p1, accuracy = 0.8776418242491657
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 2.3037 0.1000 2.3025 0.5547
2 2.3032 0.1000 2.3026 0.5334
3 2.3028 0.1000 2.3025 0.7201
4 2.3027 0.1000 2.3026 0.6273
5 2.3027 0.1000 2.3026 0.5953
6 2.3027 0.1000 2.3026 0.5357
7 2.3027 0.1000 2.3026 0.5499
8 2.3026 0.1000 2.3026 0.5495
9 2.3027 0.1000 2.3026 0.5550
10 2.3025 0.1000 2.3026 0.5345
11 2.3027 0.1000 2.3026 0.5682
12 2.3026 0.1000 2.3026 0.5566
13 2.3027 0.1000 2.3026 0.5360
14 2.3026 0.1000 2.3026 0.5511
15 2.3026 0.1000 2.3026 0.5505
16 2.3027 0.1000 2.3026 0.5496
17 2.3027 0.1000 2.3026 0.5618
18 2.3026 0.1000 2.3026 0.5442
19 2.3027 0.1000 2.3025 0.5380
20 2.3026 0.1000 2.3026 0.5555
With idx = p2, accuracy = 0.10122358175750834
P1 ==> accuracy: 87.76 %
P2 ==> accuracy: 10.12 %
Implementing an optimization strategy#
[6]:
import time
from dragon.search_algorithm.ssea import SteadyStateEA
search_algorithm = SteadyStateEA(search_space, n_iterations=20, population_size=5, selection_size=3, evaluation=train_and_predict, save_dir="save/test_image/")
start_time = time.time()
search_algorithm.run()
min_loss = search_algorithm.min_loss
end_time = time.time() - start_time
print(f"Best score: {np.round(-min_loss*100,2)}%\nComputation time: {np.round(end_time,2)} seconds")
2025-02-24 11:02:28,574 | WARNING | Install mpi4py if you want to use the distributed version.
2025-02-24 11:02:28,578 | INFO | save/test_image/ already exists. Deleting it.
2025-02-24 11:02:28,609 | INFO | The whole population has been created (size = 5), 5 have been randomy initialized.
With idx = 0, accuracy = 0.10122358175750834
2025-02-24 11:02:33,894 | INFO | Best found! -0.10122358175750834 < inf
With idx = 1, accuracy = 0.9365962180200222
2025-02-24 11:02:35,145 | INFO | Best found! -0.9365962180200222 < -0.10122358175750834
With idx = 2, accuracy = 0.8787541713014461
With idx = 3, accuracy = 0.09788654060066741
With idx = 4, accuracy = 0.10122358175750834
2025-02-24 11:03:03,601 | INFO | All models have been at least evaluated once, t = 5 < 20.
2025-02-24 11:03:03,603 | INFO | After initialisation, it remains 15 iterations.
2025-02-24 11:03:03,673 | INFO | Evolving 1 and 2 to 6 and 7
With idx = 7, accuracy = 0.10122358175750834
2025-02-24 11:03:05,516 | INFO | Replacing 3 by 7, removing save/test_image//x_3.pkl
With idx = 6, accuracy = 0.932146829810901
2025-02-24 11:03:07,120 | INFO | Replacing 0 by 6, removing save/test_image//x_0.pkl
2025-02-24 11:03:07,211 | INFO | Evolving 6 and 1 to 8 and 9
With idx = 9, accuracy = 0.932146829810901
2025-02-24 11:03:08,246 | INFO | Replacing 4 by 9, removing save/test_image//x_4.pkl
With idx = 8, accuracy = 0.9210233592880979
2025-02-24 11:03:09,271 | INFO | Replacing 7 by 8, removing save/test_image//x_7.pkl
2025-02-24 11:03:09,336 | INFO | Evolving 9 and 1 to 10 and 11
With idx = 11, accuracy = 0.9299221357063404
2025-02-24 11:03:10,365 | INFO | Replacing 2 by 11, removing save/test_image//x_2.pkl
With idx = 10, accuracy = 0.9165739710789766
2025-02-24 11:03:11,411 | INFO | 10 is the worst element, removing save/test_image//x_10.pkl.
2025-02-24 11:03:11,454 | INFO | Evolving 1 and 6 to 12 and 13
With idx = 13, accuracy = 0.9210233592880979
2025-02-24 11:03:12,486 | INFO | 13 is the worst element, removing save/test_image//x_13.pkl.
With idx = 12, accuracy = 0.8553948832035595
2025-02-24 11:03:13,415 | INFO | 12 is the worst element, removing save/test_image//x_12.pkl.
2025-02-24 11:03:13,475 | INFO | Evolving 6 and 9 to 14 and 15
With idx = 15, accuracy = 0.917686318131257
2025-02-24 11:03:14,570 | INFO | 15 is the worst element, removing save/test_image//x_15.pkl.
With idx = 14, accuracy = 0.9254727474972191
2025-02-24 11:03:15,599 | INFO | Replacing 8 by 14, removing save/test_image//x_8.pkl
2025-02-24 11:03:15,660 | INFO | Evolving 1 and 6 to 16 and 17
With idx = 17, accuracy = 0.9210233592880979
2025-02-24 11:03:16,881 | INFO | 17 is the worst element, removing save/test_image//x_17.pkl.
With idx = 16, accuracy = 0.932146829810901
2025-02-24 11:03:18,062 | INFO | Replacing 14 by 16, removing save/test_image//x_14.pkl
2025-02-24 11:03:18,183 | INFO | Evolving 11 and 6 to 18 and 19
With idx = 19, accuracy = 0.9210233592880979
2025-02-24 11:03:19,621 | INFO | 19 is the worst element, removing save/test_image//x_19.pkl.
With idx = 18, accuracy = 0.9454949944382648
2025-02-24 11:03:22,338 | INFO | Replacing 11 by 18, removing save/test_image//x_11.pkl
2025-02-24 11:03:22,340 | INFO | Best found! -0.9454949944382648 < -0.9365962180200222
2025-02-24 11:03:22,547 | INFO | Evolving 18 and 1 to 20 and 21
With idx = 21, accuracy = 0.9299221357063404
2025-02-24 11:03:23,936 | INFO | 21 is the worst element, removing save/test_image//x_21.pkl.
2025-02-24 11:03:23,958 | INFO | Search algorithm is done. Min Loss = -0.9454949944382648
Best score: 94.55%
Computation time: 55.39 seconds
Starting from a completely random set of DNNs, the search converged in about a minute to an accuracy above 94%.
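Everything the search kept is stored under save_dir. As a quick sanity check, we can list the artifacts (a sketch assuming the directory layout suggested by the logs above and used by the next cell):

import os

# Inspect what the search left behind (paths as used in the logs and the next cell)
print(os.listdir("save/test_image"))             # remaining population pickles, best_model/, ...
print(os.listdir("save/test_image/best_model"))  # x.pkl (architecture) and best_model.pth (weights)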
[7]:
from dragon.utils.plot_functions import load_archi

set_seed(0)
# Reload the architecture and weights of the best model found by the search
best_args = load_archi('save/test_image/best_model/x.pkl')
labels = [e.label for e in search_space]
best_args = dict(zip(labels, best_args))
model = ClassificationDNN(best_args, (8, 8, 1))
model.load_state_dict(torch.load('save/test_image/best_model/best_model.pth'))
# Wrap in skorch and fine-tune for a single epoch with a small learning rate
model = NeuralNetClassifier(
    model,
    max_epochs=1,
    lr=0.0001,
    iterator_train__shuffle=True,
    verbose=False,
)
model.fit(np.expand_dims(X_train.astype(np.float32), axis=-1), y_train.astype(np.int64))
y_pred = model.predict(np.expand_dims(X_test.astype(np.float32), axis=-1))
acc = accuracy_score(y_test, y_pred)
print("Final accuracy: ", np.round(acc*100, 2), "%\n")
Final accuracy: 94.55 %
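Accuracy alone can hide class-level differences. As a complement to the confusion matrix below, scikit-learn's classification_report prints per-class precision, recall and F1-score (a short sketch, not part of the original notebook):

from sklearn.metrics import classification_report

# Per-class precision, recall and F1-score for the best model's predictions
print(classification_report(y_test, y_pred))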
[8]:
from sklearn import metrics
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
disp.figure_.suptitle("Confusion Matrix")
plt.show()

[9]:
import graphviz
from dragon.utils.plot_functions import draw_cell, str_operations

def draw_graph(n_dag2, m_dag2, n_dag1, m_dag1, output_file, act="Identity()", name="MNIST"):
    G = graphviz.Digraph(output_file, format='pdf',
                         node_attr={'nodesep': '0.02', 'shape': 'box', 'ranksep': '0.02', 'fontsize': '20', "fontname": "sans-serif"})
    # Draw the 2D graph, with the input image as entry point
    G, g_nodes = draw_cell(G, n_dag2, m_dag2, "#ffa600", [], name_input=name,
                           color_input="#ef5675")
    # Flatten node connecting the two graphs
    G.node("Flatten", style="rounded,filled", color="black", fillcolor="#CE1C4E", fontcolor="#ECECEC")
    G.edge(g_nodes[-1], "Flatten")
    # Draw the 1D graph, with the flattened output as entry point
    G, g_nodes = draw_cell(G, n_dag1, m_dag1, "#ffa600", g_nodes, name_input=["Flatten"],
                           color_input="#ef5675")
    # Output node: the final 10-unit MLP
    G.node(','.join(["MLP", "10", act]), style="rounded,filled", color="black", fillcolor="#ef5675", fontcolor="#ECECEC")
    G.edge(g_nodes[-1], ','.join(["MLP", "10", act]))
    return G
m_dag2 = best_args['2D Dag'].matrix
n_dag2 = str_operations(best_args["2D Dag"].operations)
m_dag1 = best_args['1D Dag'].matrix
n_dag1 = str_operations(best_args["1D Dag"].operations)
graph = draw_graph(n_dag2, m_dag2, n_dag1, m_dag1, "save/test_image/best_archi")
print(f'Model giving a score of {np.round(acc*100,2)}%:')
graph
Model giving a score of 94.55%:
[9]:
[Graphviz rendering of the best architecture: the 2D graph, a Flatten node, the 1D graph and the final 10-unit MLP output]
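To write the drawing to disk instead of (or in addition to) displaying it inline, graphviz's render method can be used (a sketch relying only on the standard graphviz Python API):

# Render the architecture to save/test_image/best_archi.pdf and remove the DOT source
graph.render("save/test_image/best_archi", cleanup=True)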