J'ai travaillé sur un problème de régression où l'entrée est une image et l'étiquette est une valeur continue entre 80 et 350. Les images sont de certains produits chimiques après qu'une réaction ait lieu. La couleur qui apparaît indique la concentration d'un autre produit chimique qui reste, et c'est ce que le modèle doit produire - la concentration de ce produit chimique. Les images peuvent être pivotées, retournées, mises en miroir et la sortie attendue doit toujours être la même. Ce type d'analyse est effectué dans de vrais laboratoires (des machines très spécialisées produisent la concentration des produits chimiques en utilisant l'analyse des couleurs, tout comme j'entraîne ce modèle à le faire).
Jusqu'à présent, je n'ai expérimenté que des modèles basés à peu près sur VGG (plusieurs séquences de blocs conv-conv-conv-pool). Avant d'expérimenter avec des architectures plus récentes (Inception, ResNets, etc.), je pensais rechercher s'il existe d'autres architectures plus couramment utilisées pour la régression à l'aide d'images.
L'ensemble de données ressemble à ceci:
L'ensemble de données contient environ 5 000 échantillons 250 x 250, que j'ai redimensionnés à 64 x 64, ce qui facilite la formation. Une fois que j'ai trouvé une architecture prometteuse, je vais expérimenter avec des images de plus grande résolution.
Jusqu'à présent, mes meilleurs modèles ont une erreur quadratique moyenne sur les ensembles de formation et de validation d'environ 0,3, ce qui est loin d'être acceptable dans mon cas d'utilisation.
Jusqu'à présent, mon meilleur modèle ressemble à ceci:
// pseudo code
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = dropout()->conv2d(x, filters=128, kernel=[1, 1])->batch_norm()->relu()
x = dropout()->conv2d(x, filters=32, kernel=[1, 1])->batch_norm()->relu()
y = dense(x, units=1)
// loss = mean_squared_error(y, labels)
Question
Quelle est une architecture appropriée pour la sortie de régression d'une entrée d'image?
modifier
J'ai reformulé mon explication et supprimé les mentions d'exactitude.
Modifier 2
J'ai restructuré ma question donc j'espère qu'il est clair ce que je recherche
la source
Réponses:
Tout d'abord une suggestion générale: faites une recherche documentaire avant de commencer à faire des expériences sur un sujet que vous ne connaissez pas. Vous vous épargnerez beaucoup de temps.
Dans ce cas, en regardant les documents existants, vous avez peut-être remarqué que
La régression avec les CNN n'est pas un problème trivial. En examinant à nouveau le premier article, vous verrez qu'ils ont un problème où ils peuvent essentiellement générer des données infinies. Leur objectif est de prédire l'angle de rotation nécessaire pour rectifier les images 2D. Cela signifie que je peux essentiellement prendre mon ensemble d'entraînement et l'augmenter en tournant chaque image selon des angles arbitraires, et j'obtiendrai un ensemble d'entraînement valide et plus grand. Ainsi, le problème semble relativement simple, en ce qui concerne les problèmes de Deep Learning. Soit dit en passant, notez les autres astuces d'augmentation de données qu'ils utilisent:
Sur un problème beaucoup plus simple (MNIST tourné), vous pouvez obtenir quelque chose de mieux , mais vous ne descendez toujours pas en dessous d'une erreur RMSE qui est de de l'erreur maximale possible.2.6%
Alors, que pouvons-nous apprendre de cela? Tout d'abord, ces 5000 images sont un petit ensemble de données pour votre tâche. Le premier article a utilisé un réseau qui a été pré-formé sur des images similaires à celles pour lesquelles ils voulaient apprendre la tâche de régression: non seulement vous devez apprendre une tâche différente de celle pour laquelle l'architecture a été conçue (classification), mais votre ensemble d'entraînement ne fonctionne pas ne ressemblent en rien aux ensembles de formation sur lesquels ces réseaux sont généralement formés (CIFAR-10/100 ou ImageNet). Vous n'obtiendrez donc probablement aucun avantage de l'apprentissage par transfert. L'exemple MATLAB avait 5000 images, mais elles étaient en noir et blanc et sémantiquement très similaires (eh bien, cela pourrait être votre cas aussi).
Alors, comment réaliste fait mieux que 0,3? Nous devons tout d'abord comprendre ce que vous entendez par 0,3 perte moyenne. Voulez-vous dire que l'erreur RMSE est de 0,3,
où est la taille de votre ensemble d'entraînement (donc, ), est la sortie de votre CNN pour l'image et est la concentration correspondante du produit chimique? Depuis , puis en supposant que vous coupez les prédictions de votre CNN entre 80 et 350 (ou que vous utilisez simplement un logit pour les faire tenir dans cet intervalle), vous obtenez moins de erreur. Sérieusement, à quoi vous attendez-vous? cela ne me semble pas du tout une grosse erreur.N N<5000 h(xi) xi yi yi∈[80,350] 0.12%
Essayez également de calculer le nombre de paramètres dans votre réseau: je suis pressé et je fais peut-être des erreurs stupides, alors vérifiez bien mes calculs avec une
summary
fonction de n'importe quel cadre que vous utilisez. Cependant, grosso modo, je dirais que vous avez(notez que j'ai ignoré les paramètres des couches normalisées par lot, mais ce ne sont que 4 paramètres pour la couche, donc ils ne font pas de différence). Vous avez un demi-million de paramètres et 5000 exemples ... à quoi vous attendriez-vous? Bien sûr, le nombre de paramètres n'est pas un bon indicateur de la capacité d'un réseau neuronal (c'est un modèle non identifiable), mais quand même ... Je ne pense pas que vous puissiez faire beaucoup mieux que cela, mais vous pouvez essayer un quelques choses:
la source