MetaQNN


About

In our paper, Designing Neural Network Architectures Using Reinforcement Learning (arXiv, OpenReview), we propose a meta-modeling approach based on reinforcement learning that automatically generates high-performing CNN architectures for a given learning task. The learning agent is trained to sequentially choose CNN layers using Q-learning with an ε-greedy exploration strategy and experience replay. The agent explores a large but finite space of possible architectures and iteratively discovers designs with improved performance on the learning task. On image classification benchmarks, the agent-designed networks (consisting of only standard convolution, pooling, and fully-connected layers) beat existing networks designed with the same layer types and are competitive with state-of-the-art methods that use more complex layer types. We also outperform existing network design meta-modeling approaches on image classification.
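To make the search loop concrete, here is a minimal, self-contained sketch of tabular Q-learning over layer choices with ε-greedy exploration and experience replay. It is not the released MetaQNN code: the four-layer vocabulary, MAX_DEPTH, ALPHA, GAMMA, the ε schedule, replay_size, and especially toy_reward (a made-up stand-in for the validation accuracy a real run obtains by training each sampled network) are all hypothetical choices for illustration only.

```python
import random
from collections import defaultdict

# Hypothetical toy search space; the real MetaQNN space is far larger.
LAYERS = ["C(64, 3, 1)", "C(128, 5, 1)", "P(2, 2)", "SM(10)"]
TERMINAL = "SM(10)"
MAX_DEPTH = 4            # hypothetical cap on layers per network
ALPHA, GAMMA = 0.1, 1.0  # illustrative learning rate and discount

# Q-values keyed by (state, action); state = (depth, previous layer).
Q = defaultdict(lambda: 0.5)


def legal_actions(state):
    depth, _ = state
    # At the depth cap only the terminal softmax layer is allowed.
    return [TERMINAL] if depth == MAX_DEPTH - 1 else LAYERS


def step(state, action):
    depth, _ = state
    return (depth + 1, action)


def toy_reward(arch):
    # Hypothetical stand-in for validation accuracy; a real run trains each net.
    convs = sum(layer.startswith("C") for layer in arch)
    pools = sum(layer.startswith("P") for layer in arch)
    return min(1.0, 0.2 + 0.25 * convs + 0.1 * pools)


def sample_arch(eps, rng):
    """Roll out one architecture with an epsilon-greedy policy."""
    state, arch = (0, "start"), []
    while True:
        acts = legal_actions(state)
        if rng.random() < eps:
            action = rng.choice(acts)                         # explore
        else:
            action = max(acts, key=lambda a: Q[(state, a)])  # exploit
        arch.append(action)
        if action == TERMINAL:
            return arch
        state = step(state, action)


def update(arch, reward):
    """One-step Q-learning along a trajectory; reward arrives only at the end."""
    state = (0, "start")
    for i, action in enumerate(arch):
        nxt = step(state, action)
        if i == len(arch) - 1:
            target = reward
        else:
            target = GAMMA * max(Q[(nxt, a)] for a in legal_actions(nxt))
        Q[(state, action)] += ALPHA * (target - Q[(state, action)])
        state = nxt


def search(episodes=300, replay_size=20, seed=0):
    rng = random.Random(seed)
    memory = []  # replay memory of (architecture, reward) pairs
    for ep in range(episodes):
        eps = max(0.1, 1.0 - ep / episodes)  # anneal exploration 1.0 -> 0.1
        arch = sample_arch(eps, rng)
        memory.append((arch, toy_reward(arch)))
        # Experience replay: re-apply updates for a random batch of past nets.
        for past_arch, past_reward in rng.sample(memory, min(replay_size, len(memory))):
            update(past_arch, past_reward)
    return max(memory, key=lambda item: item[1])


best_arch, best_reward = search()
print(" - ".join(best_arch), best_reward)
```

Storing whole (architecture, reward) pairs and replaying them lets the cheap Q-table updates reuse expensive evaluations many times, which is the motivation for experience replay when every reward costs a full training run.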

We have just released the full code to run MetaQNN! Find it here.

CIFAR10

6.92% test error
C(128, 5, 1) - C(512, 3, 1) - P(2, 2) - C(128, 1, 1) - C(128, 5, 1) - P(3, 2) - C(512, 3, 1) - SM(10)
8.88% test error
C(256, 3, 1) - C(128, 1, 1) - C(128, 3, 1) - C(128, 3, 1) - P(5, 3) - C(128, 3, 1) - SM(10)
11.63% test error
C(64, 3, 1) - C(128, 3, 1) - P(5, 3) - C(256, 3, 1) - P(3, 2) - FC(512) - FC(128) - SM(10)
9.24% test error
C(64, 5, 1) - C(512, 3, 1) - C(512, 1, 1) - C(64, 1, 1) - C(128, 5, 1) - P(3, 2) - C(512, 5, 1) - C(512, 3, 1) - P(5, 3) - SM(10)
8.78% test error
C(128, 5, 1) - C(512, 3, 1) - P(2, 2) - C(128, 1, 1) - C(128, 5, 1) - P(3, 2) - SM(10)
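The architecture strings above are compact enough to parse mechanically. The sketch below turns one into a list of layer records. It assumes the reading used in the paper's tables, which this page does not spell out: C(n, f, s) is a convolution with n filters, an f×f receptive field and stride s; P(f, s) is max pooling with receptive field f and stride s; FC(n) is a fully-connected layer with n units; GAP(n) is global average pooling; SM(n) is a softmax over n classes. The names ARITY, Layer and parse_arch are hypothetical helpers, not part of the released code.

```python
import re
from typing import List, NamedTuple, Tuple

# Assumed number of values per layer code (our reading of the notation).
ARITY = {"C": 3, "P": 2, "FC": 1, "GAP": 1, "SM": 1}
SPEC = re.compile(r"([A-Z]+)\(([^)]*)\)")


class Layer(NamedTuple):
    kind: str                # "C", "P", "FC", "GAP" or "SM"
    params: Tuple[int, ...]  # e.g. (filters, receptive field, stride) for "C"


def parse_arch(text: str) -> List[Layer]:
    """Parse a string such as 'C(128, 5, 1) - P(2, 2) - SM(10)' into layers."""
    layers = []
    for chunk in text.split(" - "):
        match = SPEC.fullmatch(chunk.strip())
        if match is None:
            raise ValueError(f"unrecognised layer spec: {chunk!r}")
        kind, args = match.groups()
        params = tuple(int(x) for x in args.split(","))
        if ARITY.get(kind) != len(params):
            raise ValueError(f"{kind} expects {ARITY.get(kind)} values, got {params}")
        layers.append(Layer(kind, params))
    return layers


net = parse_arch("C(128, 5, 1) - C(512, 3, 1) - P(2, 2) - C(128, 1, 1) - "
                 "C(128, 5, 1) - P(3, 2) - C(512, 3, 1) - SM(10)")
print(len(net), net[0], net[-1])
```

Validating the arity of each layer code catches copy errors early, which is handy before feeding the parsed layers into a model builder for whichever framework you use.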

SVHN

2.29% test error
C(64, 1, 1) - C(128, 3, 1) - C(64, 5, 1) - C(512, 5, 1) - C(256, 1, 1) - C(256, 5, 1) - C(128, 1, 1) - C(256, 5, 1) - P(3, 2) - C(512, 5, 1) - C(256, 3, 1) - C(128, 3, 1) - SM(10)
2.33% test error
C(128, 1, 1) - C(256, 5, 1) - C(128, 5, 1) - P(2, 2) - C(256, 5, 1) - C(256, 1, 1) - C(256, 3, 1) - C(256, 3, 1) - C(256, 5, 1) - C(512, 5, 1) - C(256, 3, 1) - C(128, 3, 1) - SM(10)
2.35% test error
C(128, 5, 1) - C(128, 3, 1) - C(64, 5, 1) - P(5, 3) - C(128, 3, 1) - C(512, 5, 1) - C(256, 5, 1) - C(128, 5, 1) - C(128, 5, 1) - C(128, 3, 1) - SM(10)
2.24% test error
C(128, 3, 1) - P(2, 2) - C(64, 1, 1) - C(256, 1, 1) - C(256, 5, 1) - C(128, 1, 1) - C(128, 5, 1) - C(512, 3, 1) - C(256, 5, 1) - C(256, 1, 1) - C(128, 3, 1) - C(64, 1, 1) - SM(10)
2.36% test error
C(128, 1, 1) - C(256, 5, 1) - C(128, 5, 1) - C(512, 5, 1) - C(256, 1, 1) - C(256, 5, 1) - P(5, 3) - C(128, 5, 1) - C(128, 5, 1) - C(128, 5, 1) - C(64, 1, 1) - C(128, 5, 1) - SM(10)

MNIST

0.44% test error
C(512, 5, 1) - C(128, 5, 1) - C(128, 5, 1) - C(128, 3, 1) - C(256, 3, 1) - C(512, 5, 1) - C(256, 3, 1) - C(128, 3, 1) - SM(10)
0.44% test error
C(64, 1, 1) - C(256, 3, 1) - P(2, 2) - C(512, 3, 1) - C(256, 1, 1) - P(5, 3) - C(256, 3, 1) - C(512, 3, 1) - FC(512) - SM(10)
0.38% test error
C(64, 1, 1) - C(256, 5, 1) - C(256, 5, 1) - C(512, 1, 1) - C(64, 3, 1) - P(5, 3) - C(256, 5, 1) - C(256, 5, 1) - C(512, 5, 1) - C(64, 1, 1) - C(128, 5, 1) - C(512, 5, 1) - SM(10)
0.46% test error
C(512, 5, 1) - C(128, 5, 1) - C(128, 5, 1) - C(128, 1, 1) - P(2, 2) - C(512, 5, 1) - C(256, 3, 1) - C(128, 3, 1) - SM(10)
0.55% test error
C(256, 3, 1) - C(256, 5, 1) - C(512, 3, 1) - C(256, 5, 1) - C(512, 1, 1) - P(5, 3) - C(256, 3, 1) - C(64, 3, 1) - C(256, 5, 1) - C(512, 3, 1) - C(128, 5, 1) - C(512, 5, 1) - SM(10)
0.43% test error
C(64, 3, 1) - C(128, 3, 1) - C(512, 1, 1) - C(256, 1, 1) - C(256, 5, 1) - C(128, 3, 1) - P(5, 3) - C(512, 1, 1) - C(512, 3, 1) - C(128, 5, 1) - SM(10)
0.41% test error
C(128, 3, 1) - C(64, 1, 1) - C(64, 3, 1) - C(64, 5, 1) - P(2, 2) - C(128, 3, 1) - P(3, 2) - C(512, 3, 1) - FC(512) - FC(128) - SM(10)
0.35% test error
C(128, 3, 1) - C(512, 3, 1) - P(2, 2) - C(256, 3, 1) - C(128, 5, 1) - C(64, 1, 1) - C(64, 5, 1) - C(512, 5, 1) - GAP(10) - SM(10)
0.40% test error
C(64, 5, 1) - C(512, 5, 1) - P(3, 2) - C(256, 5, 1) - C(256, 3, 1) - C(256, 3, 1) - C(128, 1, 1) - C(256, 3, 1) - C(256, 5, 1) - C(64, 1, 1) - C(256, 3, 1) - C(64, 3, 1) - SM(10)
0.56% test error
C(512, 1, 1) - C(128, 3, 1) - C(128, 5, 1) - C(64, 1, 1) - C(256, 5, 1) - C(64, 1, 1) - P(5, 3) - C(512, 1, 1) - C(512, 3, 1) - C(256, 3, 1) - C(256, 5, 1) - C(256, 5, 1) - SM(10)
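When reimplementing one of these networks, the spatial size of the feature map after each pooling layer depends on padding and rounding conventions that the strings do not encode. The sketch below traces the side length through the conv and pool layers of the 0.41% MNIST network on a 28×28 input. Its conventions are assumptions, not taken from this page: convolutions use "same" padding (a stride-s conv yields ceil(size / s)), and pooling uses no padding and rounds down by default. Caffe rounds pooling output sizes up, while PyTorch's `nn.MaxPool2d` rounds down unless `ceil_mode=True`; which convention the released code relies on is not stated here, so the hypothetical helper spatial_sizes exposes both.

```python
import re

SPEC = re.compile(r"([A-Z]+)\(([^)]*)\)")


def spatial_sizes(arch, input_size, ceil_mode=False):
    """Trace the side length of the feature map through conv/pool layers.

    Assumptions (hypothetical, not stated on this page): convolutions use
    'same' padding, so a stride-s conv gives ceil(size / s); pooling has no
    padding and rounds down by default, or up when ceil_mode=True.
    """
    size, trace = input_size, []
    for kind, args in SPEC.findall(arch):
        nums = [int(x) for x in args.split(",")]
        if kind == "C":
            stride = nums[2]
            size = -(-size // stride)  # ceil division
        elif kind == "P":
            field, stride = nums
            if field > size:
                raise ValueError(f"P({args}) does not fit a {size}x{size} map")
            span = size - field
            steps = -(-span // stride) if ceil_mode else span // stride
            size = steps + 1
        else:
            break  # FC, GAP and SM layers end the spatial part of the network
        trace.append((f"{kind}({args})", size))
    return trace


mnist_net = ("C(128, 3, 1) - C(64, 1, 1) - C(64, 3, 1) - C(64, 5, 1) - P(2, 2) - "
             "C(128, 3, 1) - P(3, 2) - C(512, 3, 1) - FC(512) - FC(128) - SM(10)")
print(spatial_sizes(mnist_net, 28))
print(spatial_sizes(mnist_net, 28, ceil_mode=True))
```

For this network the two conventions agree after P(2, 2) (14×14) but diverge after P(3, 2): rounding down gives a 6×6 map and rounding up gives 7×7, which changes the input width of FC(512).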

Team

Bowen Baker

bowen@mit.edu

Otkrist Gupta

otkrist@mit.edu

Nikhil Naik

naik@mit.edu

Ramesh Raskar

raskar@media.mit.edu

Peter Downs

downs@mit.edu