I Trained a BTS Image Classifier
BTS has recently been back on the music charts in full force, with their second all-English single, “Butter,” having now spent 3 consecutive weeks at No. 1 on the Billboard Hot 100 chart. With this week also marking the 8th anniversary of BTS’s debut, BTS fans, known as ARMY, have a lot to celebrate. Audiences that don’t follow K-pop, however, may have a hard time keeping track of all of BTS’s members and activities; even late night host James Corden, who played an important role in introducing the K-pop band to American audiences, barely escaped unscathed in a previous segment of Spill Your Guts or Fill Your Guts.
The fact that Western entertainment audiences still struggle at times to identify the members of BTS led me to wonder if a neural network could do a better identification job. Thus, in the spirit of BTS festivities this month, I’ve created a multiclass image classifier that takes in a user-uploaded picture of a single BTS member and will identify which member he is. You can try the classifier here (it takes a moment to load on first launch).
Below I explain the steps I took to create the classifier with fastai v2, what I learned along the way, and areas for future improvement.
Collecting and Preprocessing Images
The first step in generating the initial classifier is collecting data. For this, I used the Bing Image Search API to collect images of each individual member. Some of the search results included GIFs and other non-image objects, so I used the fastai library’s verify_images() function to remove corrupt search results. The remaining dataset included an average of 517 images per BTS member.
In order to train the model, I needed training and validation sets. Splitting the downloaded images randomly worked fine in this case, so I used fastai’s RandomSplitter to allocate 20% of the data to the validation set, passing a seed so that the same training and validation split is reproduced on every run.
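To make the role of the seed concrete, here is a minimal standard-library sketch of a seeded random split. It is not fastai's implementation; the function name random_split, the file names, and the seed value 42 are all assumptions made for illustration.

```python
import random

def random_split(items, valid_pct=0.2, seed=None):
    """Shuffle indices with a seeded RNG and cut off valid_pct of them for validation."""
    rng = random.Random(seed)          # private RNG, so global random state is untouched
    idxs = list(range(len(items)))
    rng.shuffle(idxs)
    cut = int(valid_pct * len(items))  # number of validation items
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return train, valid

# Hypothetical file names standing in for the downloaded images
files = [f"img_{i:04d}.jpg" for i in range(100)]
train_a, valid_a = random_split(files, valid_pct=0.2, seed=42)
train_b, valid_b = random_split(files, valid_pct=0.2, seed=42)
```

Because both calls use the same seed, train_a and train_b contain the same files, which is what keeps error rates comparable from one training run to the next.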
Some image classification models require training datasets orders of magnitude larger than the ~3600 images I have for BTS, but given that I was planning on using transfer learning on an already-trained model, I focused my attention on creating more variation in the data through data augmentation. Standard methods for data augmentation include increasing contrast, stretching or rotating, but in order to preserve the features that would allow the model to learn recognition, I used fastai’s RandomResizedCrop, a method of data augmentation that takes a different random crop of each image in every epoch, which helps the model learn to recognize different features in the images.
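The idea behind a random resized crop can be sketched without any image library by choosing only the crop box. This is an illustrative approximation rather than fastai's code; the function name, the min_scale value of 0.35, and the aspect-ratio range are assumptions.

```python
import math
import random

def random_resized_crop_box(width, height, min_scale=0.35, ratio=(3 / 4, 4 / 3), rng=random):
    """Pick a random crop box (left, top, right, bottom) inside a width x height image."""
    for _ in range(10):
        # Target crop area: a random fraction of the full image area
        area = rng.uniform(min_scale, 1.0) * width * height
        # Sample the aspect ratio log-uniformly so wide and tall crops are equally likely
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = int(round(math.sqrt(area * aspect)))
        crop_h = int(round(math.sqrt(area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, left + crop_w, top + crop_h
    # Fallback when no sampled box fits: use the whole image
    return 0, 0, width, height

rng = random.Random(0)
boxes = [random_resized_crop_box(640, 480, rng=rng) for _ in range(20)]
```

Each box would then be cropped out and resized to the fixed input size the model expects, so the network sees a slightly different view of the same photo every epoch.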
Training a Model Using Transfer Learning
Image classification is one of the more mature areas of deep learning research, and there are now a lot of well-known models readily available for the task. In particular, convolutional neural networks (CNNs) perform well on image classification tasks, so I used the pre-trained CNN ResNet34 to learn classification of my BTS member data.
As with most pre-trained models, ResNet34 does not come fine-tuned for BTS-specific images. Thus, in order to adjust the model to the task at hand, I first trained the layers added to the model (with the pretrained layers frozen) for 3 epochs, then unfroze all the layers and trained them for 12 epochs:
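from fastai.vision.all import *
# Train only the newly added layers for 3 epochs (pretrained layers stay frozen)
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(3, 3e-3)
# Unfreeze all layers and train 12 more epochs with discriminative learning rates
learn.unfreeze()
learn.fit_one_cycle(12, lr_max=slice(1e-5,1e-3), wd=0.1)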
To improve the model, I also included discriminative learning rates, using the Python slice object to set the earliest layers in the model (with their pretrained weights) to a learning rate of 1e-5, and scaling up the learning rates of the later layers to 1e-3. The learning rate values were determined via the fastai learn.lr_find() function, which is based on researcher Leslie N. Smith’s learning rate finder. In addition, I added a weight decay parameter (wd=0.1) to keep the weights from growing too large and overfitting the dataset.
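To show what the slice does to the learning rates, here is a small sketch. My understanding is that fastai spreads the rates between the two ends of the slice across the model's layer groups using a constant multiplier; the helper name and the choice of three layer groups are assumptions for illustration.

```python
def spread_learning_rates(lr_min, lr_max, n_groups):
    """Return n_groups learning rates from lr_min to lr_max, spaced by a constant multiplier."""
    if n_groups == 1:
        return [lr_max]
    step = (lr_max / lr_min) ** (1 / (n_groups - 1))
    return [lr_min * step ** i for i in range(n_groups)]

# Assumed: three layer groups, earliest layers first
lrs = spread_learning_rates(1e-5, 1e-3, n_groups=3)
```

Under these assumptions the earliest group trains at roughly 1e-5, the middle group at roughly 1e-4, and the final group at roughly 1e-3, so the pretrained early features are disturbed the least.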
After 12 epochs, I was able to train the model to a 25.3% error rate, or 74.7% accuracy.
To understand the types of mistakes the model makes, I took a look at the confusion matrix.
It seems that the model made the most errors in classifying J-Hope, both in wrongly classifying a member as J-Hope when he wasn’t, and in wrongly classifying J-Hope as different members.
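The two kinds of mistake described above correspond to the columns and rows of a confusion matrix. The sketch below counts both for each class, assuming rows are actual labels and columns are predictions. The counts are made up for three placeholder classes and are not the results from this classifier.

```python
def per_class_errors(matrix, labels):
    """Count false negatives (row misses) and false positives (column misses) per class."""
    result = {}
    for k, label in enumerate(labels):
        correct = matrix[k][k]
        false_neg = sum(matrix[k]) - correct                 # actual k, predicted as another class
        false_pos = sum(row[k] for row in matrix) - correct  # predicted k, actually another class
        result[label] = {"false_neg": false_neg, "false_pos": false_pos}
    return result

# Made-up counts for three placeholder classes, NOT the results from this post
labels = ["A", "B", "C"]
matrix = [
    [50, 5, 5],    # actual A
    [4, 40, 16],   # actual B
    [6, 14, 40],   # actual C
]
errors = per_class_errors(matrix, labels)
total = sum(sum(row) for row in matrix)
error_rate = 1 - sum(matrix[k][k] for k in range(len(labels))) / total
```

A class that scores high on both counts, as J-Hope did here, is being confused in both directions, which usually points at the data for that class rather than at one specific look-alike.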
Learnings and Future Improvements
Given that many image classifiers today are trained to very high accuracy, the 74.7% accuracy of my initial BTS classifier could definitely be improved on. While getting more data is not always possible, one major area for improvement is better preprocessing of the images for facial recognition. Data augmentation during preprocessing helped to increase the variation in the dataset, but didn’t do much to cut out noise. The images of BTS members found on search engines often have a lot of noise - band members change their hairstyles frequently, often pose with hand gestures close to their faces, and wear facial accessories (such as mics, headbands and, recently, masks). Preprocessing using libraries such as MTCNN or OpenCV would help to perform face detection, rotate non-centered faces and crop out noise, which would result in cleaner data and likely lead to a much lower error rate.
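As a rough sketch of the cropping step, the helper below takes a face box that a detector has already found and turns it into a padded crop region that stays inside the image. It assumes the detector reports boxes as (x, y, w, h) in pixels, the format OpenCV's cascade detector uses; the function name and the 25% margin are hypothetical choices.

```python
def padded_face_crop(face_box, image_size, margin=0.25):
    """Expand an (x, y, w, h) face box by a margin and clamp it to the image bounds."""
    x, y, w, h = face_box
    img_w, img_h = image_size
    pad_w, pad_h = int(w * margin), int(h * margin)
    left = max(0, x - pad_w)
    top = max(0, y - pad_h)
    right = min(img_w, x + w + pad_w)
    bottom = min(img_h, y + h + pad_h)
    return left, top, right, bottom

# Hypothetical detection near the top-left corner of a 640x480 image
box = padded_face_crop((10, 20, 100, 120), (640, 480))
```

Keeping a margin around the face preserves some hair and jawline context while still removing most of the background, hands, and overlaid text.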
In addition to better preprocessing, I found that there was also just a bunch of bad data in the dataset. Search results for a member would sometimes return images of another member, and many images were heavily edited with text and emojis, either by fans or the members themselves. Cleaning up the dataset to only include correctly labeled images would also improve model performance.
Another improvement I would make to the classifier is to allow for multilabel classification, so that if more than one member shows up in an image, the model could identify all of them. This case showed up frequently in the search engine data, where images with multiple members were included. I’d imagine multi-member images would also be common among the pictures users upload to the classifier.
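The main change for multilabel prediction is to score each member independently and keep everyone above a threshold, rather than picking the single highest score. Here is a minimal sketch of that decision rule; the logits are hypothetical raw model outputs and the 0.5 threshold is an assumption that would need tuning.

```python
import math

def sigmoid(x):
    """Map a raw score to an independent probability between 0 and 1."""
    return 1 / (1 + math.exp(-x))

def predict_members(logits, labels, threshold=0.5):
    """Return every label whose sigmoid probability clears the threshold."""
    return [label for label, z in zip(labels, logits) if sigmoid(z) >= threshold]

labels = ["RM", "Jin", "Suga", "J-Hope", "Jimin", "V", "Jungkook"]
logits = [2.1, -3.0, -1.2, 1.4, -0.3, -2.5, -4.0]  # hypothetical raw model outputs
present = predict_members(logits, labels)
```

With these made-up scores the rule returns two members. It can also return an empty list, which is a useful way to flag images that contain no BTS member at all.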
One final issue I’d want to watch out for is distribution differences between the image data used during training and the actual production data. Since my BTS member classifier was trained using internet images that tend to be high resolution and professional, it may not perform well if users mainly upload phone-captured images, which can be more casual and lower resolution. I’d want to pay attention to the distribution of the production data and make sure the production and validation dataset distributions are as similar as possible.
There are of course many other things to consider when improving a deep learning model, and for now, the classifier running in production does a surprisingly decent job of identifying who is who in BTS. You can try the classifier and reach out to me on Twitter or via email for any thoughts, questions, or ideas for improvement - I’d love to hear them!
Code for the BTS Classifier can be found here: https://github.com/gleanawa/bts_classifier