MIT 6.S191: Taming Dataset Bias via Domain Adaptation
Transcription for the video titled "MIT 6.S191: Taming Dataset Bias via Domain Adaptation".
Note: This transcription is split and grouped by topics and subtopics. Paragraph timestamps (e.g., 01:53) refer to the corresponding portions of the original video.
I'm really happy to be here today to talk to you about something I'm very excited and interested in, because it's my research area, and it's always fun to give a talk about your own research. The topic of my talk is taming dataset bias via domain adaptation. I believe you've already had some material in this course on bias issues and perhaps fairness, so this should dovetail with that pretty well. Okay, so I probably don't need to spend a lot of time on how successful deep learning has been for various applications. I'll be focusing mostly on computer vision applications, because that's my primary research area. We know that in computer vision, deep learning has gotten to the point where we can detect different objects pretty accurately in a variety of scenes, and we can even detect objects that are not real people, or could even be cartoon characters. As long as we have training data, we can train models to do this. And we can do things like face recognition and emotion recognition. So there are a lot of applications where deep learning has been super successful. But there have also been some problems with it, and the one I want to talk about is dataset bias. So what happens with dataset bias? Say you're training a computer vision model to detect pedestrians, and you want to put it on a self-driving car. So you went and collected some data, labeled the pedestrians with bounding boxes, trained your deep neural network, and it worked really well on a test set held out from that same dataset. But now you put that same model on your car and have it try to recognize pedestrians in a different environment, like in New England, which is where I am right now. And in fact, if I look out my window, that's exactly what it looks like: there's snow.
There could be people wearing heavy coats, and it looks different from my training data, which, let's say, I collected in California, where we don't see a lot of snow. So visually, this new data that I'm supposed to label with my model looks quite different from my training data. This is what we refer to as dataset shift, and it leads to problems in terms of missed detections and generally lower accuracy of our trained model. So it's called dataset bias; it's also referred to as domain shift. The primary issue here, again, is that the training data looks different from the target test data. For now I'll just say "looks different," but I'll define it in a more specific way later. So, when does dataset bias happen?
In-Depth Discussion On Data Bias And Methods To Overcome It
When does dataset bias occur? (03:20)
Well, it happens in a few different scenarios. I'll just show a few here, and I will argue that it actually happens with every dataset you collect. One example, as I already mentioned, is that you collect a dataset in one city and you want to test in a different city. Or maybe you collect the dataset from the web, and then you want to put your model on a robot that gets images from its environment, where the angle and the background and the lighting are all different. Another very common issue is with simulated training data that we then want to transfer to the real world; that's a sim-to-real dataset shift, very common in robotics. Another way this could happen is if your training dataset is primarily of a particular demographic. If we're dealing with people, it could be mostly light-skinned faces, and then at test time you're given darker-skinned faces that you didn't train on, so again you have a dataset bias. Or perhaps you're classifying weddings, but your training data comes from images of weddings in Western culture, and then at test time you have other cultures, so you again have a dataset bias. My point is that it can happen in many different ways, and in fact, I believe that no matter what dataset you collect, it will have dataset bias, especially in the visual domain: our visual world is so complex that it's very hard to collect enough variety to cover all possible situations. So let's talk more specifically about why this is a problem, and I'll give you an example that I'll use throughout my talk, just to put some real numbers on this. We probably all know by now the MNIST dataset: it's just handwritten digits, very common for benchmarking neural networks.
So if we train a neural network on MNIST and then we test it on the same domain, on MNIST, we know that we'll have very good performance, upwards of 99% accuracy; this is more or less a solved task. But suppose we train our network on the Street View House Numbers (SVHN) dataset, which has the same 10 digits, but visually is a different domain, since it comes from Street View imagery. Now, when we test that model on the MNIST dataset, performance drops considerably. This is really, really bad performance for a 10-digit classification task. And in fact, even if we train with a much smaller shift, from USPS to MNIST (visually, these two datasets look very similar to the human eye, with only small differences between them), performance still drops almost as much as before. And if we swap them, we still have bad performance when training on MNIST and testing on USPS. So that's just to put some numbers on it: even for such a simple task, one we should have solved a long time ago in deep learning, it doesn't work. If we have this dataset shift and we test the model on a new domain, it pretty much breaks. Okay, but this is a very academic dataset.
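To make the train-on-one-domain, test-on-another failure concrete, here is a minimal toy sketch (my own illustration, not from the talk): a logistic-regression "classifier" trained on a clean source distribution and evaluated on a shifted copy of it. The data, the model, and the shift are all invented stand-ins for the real digit experiments.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_logreg(X, y, lr=0.1, steps=500):
    """Plain gradient descent on the logistic-regression loss."""
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = sigmoid(X @ w + b)
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * (p - y).mean()
    return w, b

def accuracy(w, b, X, y):
    return float(((sigmoid(X @ w + b) > 0.5) == y).mean())

rng = np.random.default_rng(0)
# "Source" domain: two well-separated classes (toy 2-D features).
X_src = np.vstack([rng.normal([0.0, 0.0], 0.5, (200, 2)),
                   rng.normal([3.0, 0.0], 0.5, (200, 2))])
y = np.array([0] * 200 + [1] * 200)
# "Target" domain: the SAME classes, but the whole input distribution
# is shifted -- a crude stand-in for snow, new cameras, new styles, etc.
X_tgt = X_src + np.array([1.5, 0.0])

w, b = train_logreg(X_src, y)
acc_src = accuracy(w, b, X_src, y)
acc_tgt = accuracy(w, b, X_tgt, y)
print(acc_src > 0.95, acc_tgt < acc_src)
```

The classifier is near-perfect on held-out source-like data, but the shifted target pushes one class onto the decision boundary and accuracy collapses, which is exactly the SVHN-to-MNIST story in miniature.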
Implications in the real-world (07:00)
What about the real world? Have we seen any actual implications of dataset bias there? I would argue that yes, we have. One example: there have been several studies of commercial face recognition and gender recognition software deployed in the real world, and these studies show that facial recognition algorithms are far less accurate at identifying African-American and Asian faces compared to Caucasian faces. A big part of the reason is dataset shift, because the training datasets used for these models are biased towards Caucasian faces. Another real-world example I want to show is a very sad one, actually. A while back there was a self-driving car accident, the first time a robot killed a person: a fatal accident involving an Uber self-driving car. According to some reports, the reason they think the car failed to stop is that its algorithm was not designed to detect pedestrians outside of a crosswalk. So you could actually think of this as a dataset bias problem. If your training data contains only pedestrians on a crosswalk, which is reasonable to assume, because the majority of the time pedestrians follow the rules and cross at the crosswalk and only occasionally do you see people jaywalking, then you probably would not see a lot of examples like that in your dataset. So you might be wondering at this point: wait a minute, can't we just fix this problem by collecting more data and labeling it? Yes, theoretically we could. However, it gets very, very expensive very quickly, and to illustrate why, let's take this example, again from the self-driving domain.
These are images from the Berkeley BDD dataset, which actually contains quite a variety of domains already: it has nighttime images as well as daytime images. The labels here are semantic segmentation labels, so each pixel is labeled, for instance as road or pedestrian. If we wanted to label 1,000 pedestrians with these polygons, that would cost around $1,000; that's just the standard market price. However, if we now want to multiply that by how many variations we want in pose, times variations in gender, times variations in age, race, clothing style, and so on, we very quickly see how much data we would have to collect. And somewhere in there, we also want people who ride bicycles. So this blows up very quickly and becomes very expensive. Instead, maybe what we want to do is design models that can use unlabeled data rather than labeled data, and that's what I'm talking about today. So now let's think about what causes this poor performance that we've already seen. There are basically two main reasons, I think. The first reason is that the training and test data distributions are different, and you can see that in this picture. Here, the blue points are feature vectors extracted from the MNIST digit domain using a network that was trained on that domain: we train a network, take the second-to-last layer activations, and plot them using t-SNE embeddings so that we can visualize them in 2D. So these blue points are the training MNIST points. Then we take that same network and extract features from our target domain, which is basically MNIST but with different colors, as opposed to black and white. Those are the red points, and you can see very clearly that the distribution over the inputs is very different between the training and test domains.
So that's one issue: our classifier, which is trained on the blue points, will not generalize to the red points because of this distribution shift. Another problem is a little more subtle. If you look at the blue points, they are much better clustered together, with spaces between the clusters, and these clusters correspond to the digit categories. The red points, on the other hand, are a lot more spread out and not as well clustered into categories. This is because the model learned discriminative features for the source domain, and those features are not very discriminative for the target domain. The target test points are not being grouped into classes using those features, because the kinds of features the model would need for the target just weren't learned from the source domain.
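As a rough sketch of how a plot like this is produced, here is a toy version with made-up feature vectors standing in for real network activations, and a 2-D PCA projection standing in for t-SNE (t-SNE itself is a more involved, non-linear embedding):

```python
import numpy as np

def pca_2d(features):
    """Project feature vectors to 2-D via PCA (a stand-in for t-SNE)."""
    centered = features - features.mean(axis=0)
    # SVD gives the principal directions; keep the top two components.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
# Toy penultimate-layer activations: source and target share structure,
# but the target distribution is shifted (the "domain gap").
source_feats = rng.normal(loc=0.0, scale=1.0, size=(200, 64))
target_feats = rng.normal(loc=2.0, scale=1.0, size=(200, 64))

both = np.vstack([source_feats, target_feats])
embedded = pca_2d(both)
src_2d, tgt_2d = embedded[:200], embedded[200:]

# The projected means land far apart: the two domains form separate clouds,
# just like the blue and red point clouds in the talk's figure.
gap = np.linalg.norm(src_2d.mean(axis=0) - tgt_2d.mean(axis=0))
print(gap > 5.0)
```

Plotting `src_2d` and `tgt_2d` in two colors would reproduce the qualitative picture described above: two visibly separated distributions for what is nominally the same classification task.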
Dealing with data bias (12:41)
All right. So what can we do? Well, we can actually do quite a lot. Here is a list of fairly simple, standard methods you could use to try to deal with dataset shift. For example, if you just use a better backbone for your CNN, like a ResNet-18 as opposed to AlexNet, you will have a smaller performance gap due to domain shift. Batch normalization done per domain is a very good trick, and you can combine it with instance normalization. Of course, you could do data augmentation, or use semi-supervised methods like pseudo-labeling. And then there is what I'll talk about today: domain adaptation techniques. Okay, so let's define domain adaptation. We have a source domain which has a lot of labeled data, so we have inputs x_i and labels y_i in our source domain. And then we have a target domain which has unlabeled data: no labels, just the inputs. Our goal is to learn a classifier F that achieves a low expected loss under the target distribution D_T. So we're learning on the source labels, but we want to have good performance on the target. And a key assumption in domain adaptation that is really important to keep in mind is that we assume we get access to the unlabeled target data, which is important. We don't get the labels, because again, we assume labeling is very expensive, or we just can't label the data for some reason. But we do get the unlabeled data.
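One of the simple tricks listed above, per-domain batch normalization, can be sketched in a few lines. This is a toy numpy version on invented data; a real implementation would keep separate running statistics inside the network's BN layers rather than normalizing a raw feature matrix:

```python
import numpy as np

def batchnorm_per_domain(x, domain_ids, eps=1e-5):
    """Normalize each domain's activations with its own batch statistics.

    x          : (N, C) activations
    domain_ids : (N,) array of 0 (source) / 1 (target)
    """
    out = np.empty_like(x, dtype=float)
    for d in np.unique(domain_ids):
        mask = domain_ids == d
        mu = x[mask].mean(axis=0)
        var = x[mask].var(axis=0)
        out[mask] = (x[mask] - mu) / np.sqrt(var + eps)
    return out

rng = np.random.default_rng(1)
source = rng.normal(0.0, 1.0, size=(128, 8))
target = rng.normal(3.0, 2.0, size=(128, 8))   # shifted and rescaled domain
x = np.vstack([source, target])
ids = np.array([0] * 128 + [1] * 128)

normed = batchnorm_per_domain(x, ids)
# After per-domain normalization, both domains are zero-mean and unit-variance,
# which removes the simple first- and second-moment part of the domain shift.
print(np.allclose(normed[ids == 0].mean(axis=0), 0.0, atol=1e-7),
      np.allclose(normed[ids == 1].mean(axis=0), 0.0, atol=1e-7))
```

The point of the trick is visible here: each domain is whitened with its own statistics, so a downstream classifier no longer sees the gross mean/scale difference between domains.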
Adversarial domain alignment (14:38)
Okay, so what can we do? Here's the outline of the rest of my talk. I'm going to go pretty quickly, and I'll try to leave time at the end for questions, so if you have questions, please note them down. First, I'll talk about a by-now-conventional technique called adversarial domain alignment; then I'll talk about a few more recent techniques that have been applied to this problem, and then we'll wrap up. Okay, so let's start with adversarial domain alignment. Say we have our source domain with labels, and we're trying to train a neural network. Here I've split it into the encoder CNN (a convolutional neural network, because we're dealing with images) and the classifier, which is just the last layer of the network. We can train it in the normal way using a standard classification loss. Then we can extract features from the encoder and plot them here to visualize the two categories (for illustration purposes, I'm showing just two), and we can also visualize the decision boundary between classes that the classifier is learning. Now, we also have unlabeled target data coming from our target domain. We don't have any labels, but we can take the encoder and generate features, and as we've already seen, we'll see a distribution shift between the blue source features and the orange target features. So the goal in adversarial domain alignment is to take these two distributions and align them: update the encoder CNN such that the target features are distributed in the same way as the source features. How can we do this? Well, it involves adding a domain discriminator. Think of it as just another neural network, which is going to take our features from the source and the target domain and try to predict the domain label. So its output is a binary label: source or target domain. Okay?
And so this domain discriminator is trying to distinguish the blue points from the orange points, and we train it with a classification loss on the domain labels. That's our first step. Then we fix the discriminator and instead update the encoder such that the encoder results in poor domain discriminator accuracy. The encoder is trying to fool the domain discriminator by generating features that are essentially indistinguishable between the source and target domains. It's an adversarial approach because of this adversarial back and forth: first, we train the domain discriminator to do a good job of telling the domains apart; then we fix it and train the encoder to fool the discriminator so that it can no longer tell the domains apart. If everything goes well, we have aligned distributions. So does this actually work? Let's take a look at our digits example from before. Here we again have two digit domains, and you saw that before adaptation, the distributions of the red and the blue points were very different. After applying this adversarial domain alignment on the features, we can see that the feature distributions are now very well aligned: you more or less cannot tell the difference between the distributions of the red and the blue points. So it works. And not only does it align the features, it also improves the accuracy of the classifier we trained, because we're still training this classifier using the source domain labels. That classification loss is actually what prevents our alignment from diverging into something really silly, like mapping everything to one point, because the loss still has to be satisfied. So the classifier also improves. Let's see how much. Here I'm going to show results from our CVPR 2017 paper on ADDA, or adversarial discriminative domain adaptation.
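Here is a toy numpy sketch of that adversarial back-and-forth. To be clear, this is my own simplified stand-in, not the actual ADDA code: the "encoder" is just a learnable shift applied to 2-D target features, and the discriminator is a logistic regression. Alternating the two updates pulls the target feature distribution onto the source.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
# Frozen source features and shifted raw target features (toy 2-D stand-ins).
src = rng.normal(0.0, 1.0, size=(256, 2))
tgt_raw = rng.normal(4.0, 1.0, size=(256, 2))

b = np.zeros(2)          # target "encoder": a learnable shift
w, c = np.zeros(2), 0.0  # domain discriminator: P(source) = sigmoid(w.x + c)
lr_d, lr_e = 0.1, 0.1

for step in range(500):
    # Step 1: train the discriminator (source labeled 1, target labeled 0).
    for x, y in [(src, 1.0), (tgt_raw + b, 0.0)]:
        p = sigmoid(x @ w + c)
        w -= lr_d * x.T @ (p - y) / len(x)
        c -= lr_d * (p - y).mean()
    # Step 2: freeze the discriminator and update the encoder to FOOL it:
    # take a gradient step as if the target's domain label were "source" (1).
    p_t = sigmoid((tgt_raw + b) @ w + c)
    b -= lr_e * ((p_t - 1.0)[:, None] * w).mean(axis=0)

initial_gap = np.linalg.norm(tgt_raw.mean(axis=0) - src.mean(axis=0))
final_gap = np.linalg.norm((tgt_raw + b).mean(axis=0) - src.mean(axis=0))
print(final_gap < initial_gap)
```

The encoder update stops having any effect once the discriminator already calls the target features "source" (the `p_t - 1` factor vanishes), which is the sketch-level version of alignment having converged.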
So with this technique, we can improve accuracy quite a lot when training on these source domains and then testing on these target domains. So it's a significant improvement. It's not as good on the SVHN-to-MNIST shift, because that is the hardest of these shifts. Great, so the takeaway so far is that domain adaptation can improve the accuracy of the classifier on target data without requiring any labels at all: we didn't label our target domains here at all, we just trained with unlabeled data. You can think of this as a form of unsupervised fine-tuning. Fine-tuning is something we often do to improve a model on some target task, but it requires labels; this is something we can do when we have no labels, only unlabeled data. Great. So far I've talked about domain alignment in the feature space, because we were updating features. Next I want to talk about pixel-space alignment. The idea in pixel-space alignment is: what if we could take our source data, the images themselves, the pixels, and actually just make them look like they came from the target domain?
Pixel space alignment (20:30)
And we can actually do that thanks to generative adversarial networks, or GANs, which work in a very similar way to what I already described, except the discriminator looks at the whole generated image, not the features. So we can take this idea and train a GAN that will take our source data and translate it in the image domain to make it look like it actually comes from the target domain. We can do this because we have unlabeled target data. There are a few approaches for doing this image-to-image translation; a famous one is called CycleGAN. Essentially, these are conditional GANs that use some kind of loss to align the two domains in the pixel space. So what's the point of this? Well, if we now have this translated source data, we still have labels for it, but it looks like it comes from the target domain. So we can just train on this new, fake, target-like data with the labels, and hopefully that reduces our classifier's error on the target. By the way, we can still apply our previous feature alignment: we can still add a domain discriminator on the features, just like we did before, and do these two things in tandem. On many problems, doing both actually does improve performance. Okay, so let me show you an example. Here we're trying to do semantic pixel labeling: our goal is for a neural network to label each pixel as one of several categories, like road or car or pedestrian or sky. We want to train on this GTA domain, which is from the Grand Theft Auto game. It's a nice source of data because the labels are basically free; we just get them from the game. And then we want to test on the Cityscapes dataset, which is a real-world dataset collected in multiple cities in Germany.
So you can see what it looks like. I'm going to show you the result of doing pixel-to-pixel domain alignment between these two domains. Here we're actually taking the real dataset and translating it to the game: the original video is from Cityscapes, and we're translating it into the GTA game. All right, so what happens if we apply this idea to domain adaptation? Our source domain is GTA; here's an example of an adapted source image that we take from GTA and translate into the real domain. When we then train the classifier on these translated images, our per-pixel accuracy on this task goes up from 54% to 83.6%. That's a really good improvement in accuracy, again without using any additional labels on the target domain. Also, going back to our digit problem: remember that really difficult shift we had from the Street View digits to the MNIST digits? Well, with this pixel-space adaptation, we can take those source images from the Street View domain and make them look like MNIST images. The middle image shows the original SVHN images from the left translated to look like MNIST. If we train on these, we can improve our accuracy on MNIST to 90.4%, and if we compare this with our previous result using only feature-space alignment, where we were getting around 76%, we're improving on that quite a bit. So the takeaway here is that unsupervised image-to-image translation can discover and align the corresponding structures in the two domains. There is corresponding structure: we have digits in both domains, and they have similar structure. What this method does is align the domains by discovering these structures and making them correspond to each other.
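Just to show the shape of the objective behind methods like CycleGAN (this is not their implementation; the two "translators" here are invented toy functions standing in for the learned generators), the cycle-consistency loss checks that translating an image to the other domain and back reproduces the input, which is what keeps the translation from discarding the content:

```python
import numpy as np

# Toy "images": flat vectors. The real method uses conv-net generators; here
# the two translators are simple invertible affine maps standing in for them.
def G(x):  # source -> target-style translator (hypothetical stand-in)
    return 2.0 * x + 1.0

def F(y):  # target -> source-style translator (its inverse)
    return (y - 1.0) / 2.0

def cycle_consistency_loss(x_batch):
    """L1 cycle loss ||F(G(x)) - x||_1: translating there and back
    should reconstruct the original image."""
    reconstructed = F(G(x_batch))
    return np.abs(reconstructed - x_batch).mean()

rng = np.random.default_rng(2)
x = rng.uniform(0.0, 1.0, size=(4, 32 * 32))
# Because F exactly inverts G here, the cycle loss is (numerically) zero.
print(cycle_consistency_loss(x) < 1e-9)
```

In the full method this loss is trained jointly with GAN losses in both directions, so G must produce target-looking images while remaining (approximately) invertible.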
Few-shot pixel alignment (26:03)
Great. So next I want to move on to talk about few-shot pixel alignment. So far I didn't tell you this explicitly, but we actually assume that we have quite a lot of unlabeled target data. In the case of adapting from the game to the real dataset, we took a lot of images from the real world; they weren't labeled, but we had a lot of them, many thousands of images. What happens if we only have a few images from our target domain? Well, it turns out the methods I've talked about can't really handle that case; they need a lot more images. So, with my graduate student and my collaborator at NVIDIA, we came up with a method that can do translation with just one or a few (maybe two or three, up to five) images in the target domain. Suppose we have our source domain with a lot of labeled images. Here we're going to look at an example with different animal species, so the domain will be the species of the animal. We have a particular breed of dog, and now we want to translate this image into a different domain, which is another breed of dog, but we only have one example of that breed; our target domain is given to us in only one image. Our goal is then to output a translated version of our source image that preserves the content of the source image but adds the style of the target domain image. Here, the content is the pose of the animal, and the style is the species of the animal, in this case the breed of the dog. You can see that we're actually able to do that fairly successfully: we've preserved the pose of the dog, but we've changed the breed to the one from the target image. Okay? So this is a pretty cool result.
And the way we've achieved this is by modifying an existing model called FUNIT, which you see on the left here, by updating the style encoder part of the model; we call ours COCO, for content-conditioned style encoder. The way our model works is that it takes the content image and the style image. It encodes the content using an encoder, which is just a convolutional network. Then it takes both the style and the content and encodes them as a style vector. The image decoder takes the content vector and the style vector and combines them to generate the final output image, and there's a GAN loss on that image to make sure we're generating images that look like our target. So the main difference between the previous work, FUNIT, and ours, which we call COCO-FUNIT, is that our style encoder is structured differently: it is conditioned on both the content image and the style image. If we look under the hood of this model in a little more detail, the main difference again is in the style encoder. It takes the style image and encodes it with features, and it also learns a separate style bias vector, which is concatenated with the image encoding. These bias parameters are learned for the entire dataset, so they're constant; they don't depend on the image. Essentially, this helps the model learn how to account for pose variation, because across images the pose can change drastically: in one image we see the whole body of the animal, while in another the animal could be heavily occluded with just the head visible, like in this example. Then the content encoding is combined with these style encodings to produce the final style code, which is used in the adaptive instance normalization (AdaIN) framework, if you're familiar with that. If not, don't worry about it; it's just a way to combine these two vectors to generate an image.
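Since adaptive instance normalization came up: here is a small numpy sketch of AdaIN on toy feature maps and style parameters (my own illustration, not the COCO-FUNIT code), showing how a style code rewrites the per-channel statistics of the content features:

```python
import numpy as np

def adain(content_feat, style_code, eps=1e-5):
    """Adaptive instance normalization: strip the content features' own
    per-channel statistics, then re-impose the scale/shift (gamma, beta)
    supplied by the style code.

    content_feat : (C, H, W) feature map from the content encoder
    style_code   : (2, C) per-channel (gamma, beta) from the style encoder
    """
    gamma, beta = style_code
    mu = content_feat.mean(axis=(1, 2), keepdims=True)
    sigma = content_feat.std(axis=(1, 2), keepdims=True)
    normalized = (content_feat - mu) / (sigma + eps)
    return gamma[:, None, None] * normalized + beta[:, None, None]

rng = np.random.default_rng(3)
content = rng.normal(5.0, 3.0, size=(8, 16, 16))   # toy content feature map
gamma = np.full(8, 2.0)                            # toy style scale
beta = np.full(8, -1.0)                            # toy style shift
out = adain(content, np.stack([gamma, beta]))
# The per-channel statistics now match the style code, not the content.
print(np.allclose(out.mean(axis=(1, 2)), -1.0, atol=1e-3),
      np.allclose(out.std(axis=(1, 2)), 2.0, atol=1e-2))
```

The spatial pattern (the "pose") survives in `normalized`, while the channel statistics, which carry appearance, are taken from the style; that split is what lets a content/style decomposition like the one described above work.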
So here are some example outputs from our model. On top we have a style image, an image of the target species that we want our animal to look like. Below that is the content, which is the source, essentially the pose we want to preserve. And in the bottom row you see the generated result our model produced. You can see that we're able to preserve the pose of the content image pretty well while combining it with the style, that is, the species, of the target style image. Sometimes we even make something that is a cat look more like a dog, because the target domain is a dog breed, while the pose stays the same as in the original cat image. Or here, in the last one, it's actually a bear that's generated to look like a dog. If we compare this to the previous method, FUNIT, that I mentioned before, we see that our model produces significantly better generations; FUNIT in this case often fails to produce realistic images, generating images that are just not convincing or photorealistic. And here I'm going to play a video to show you a few more results. We're taking a whole video and translating it into several target domains, which you see on top, with two example images for each target domain. The first one is actually a fox. And here is another example with different bird species: you can see that the pose of the bird is preserved from the original video, but its species is changed to the target.
So there are varying levels of success there, but overall it's doing better than the previous approach. And here's a final example: we're again taking the content image, combining it with the style, and generating the output. I'm not really sure what species this would be, some kind of strange new species. Okay, so the takeaway is that by conditioning on the content and style image together, we're able to improve the encoding of style and improve the domain translation in this few-shot case.
Moving beyond alignment (33:56)
All right, so I have a little bit of time left. How much time do I have, actually? About 10 minutes? Yes. Okay, so in the last few minutes I want to talk about more recent work we've done that goes beyond, and actually improves on, the alignment techniques I've talked about. The first one uses self-supervised learning. One assumption that all of the methods I've talked about make is that the categories are the same in the source and target, and they actually break if that assumption is violated. So why would we violate this assumption? Suppose we have a source domain of objects, and we want to transfer from the real source to, say, a drawings domain. In the drawings domain, some of the images belong to the same categories we have in the source, but some source categories are missing in our target domain; suppose we don't have cup or cello. We might even have new categories in the target domain that are not present in the source. So here what we have is a case of category shift: not just feature shift or visual domain shift, but an actual shift in the categories. This is a difficult case for domain alignment, because domain alignment always assumes the whole domain should be aligned together, and if we try to do that in this case, we'll have catastrophic adaptation results; we could actually do worse than doing no adaptation at all. So in our recent paper from NeurIPS 2020, we propose a solution that doesn't use domain alignment but instead uses self-supervised learning. The idea is: let's say we have some source data which is labeled; those are the blue points here.
We have some target data which is unlabeled, and some of those points could belong to unknown classes. The first thing we do is find pairs of points that are close together and train a feature extractor in such a way that these points are embedded close together: we're basically trying to cluster neighboring points even closer together while pushing far-away points even further apart. We can do that because we're starting with a pre-trained model, say pre-trained on ImageNet, which already gives us a pretty good initialization. After this neighborhood clustering, which is an unsupervised (or self-supervised) loss, we get a better clustering of our features: the unlabeled target points from the known classes move closer to the source points from those classes, while the yellow unknown-class points in the target cluster away from the known classes. Now what we want to do is add an entropy separation loss, which further pushes points that have high entropy away from the known classes. This is essentially an outlier rejection mechanism: if we look at a point and see that it has very high entropy, it's probably an outlier, so we want to reject it and push it even further away. Finally, what we obtain is an encoder that gives us a feature distribution where points of the known classes are clustered close to the source, but points of novel classes are clustered away from the source. Okay, so we apply this on the VisDA challenge dataset, which involves training on synthetic images and adapting to a target domain of real images, where some of the source categories are missing in the target. And again, we don't know which ones in real life, because the target is unlabeled.
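The entropy-based rejection step can be sketched like this (toy logits and a hand-picked threshold; the actual DANCE loss is differentiable and pushes a point's entropy away from the threshold rather than hard-thresholding, so treat this purely as an illustration of the decision rule):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def entropy(p):
    """Shannon entropy of each row of class probabilities."""
    return -(p * np.log(p + 1e-12)).sum(axis=1)

# Toy target-domain logits over 4 known classes. Confident predictions
# (low entropy) are treated as known-class points and pulled toward the
# source clusters; near-uniform predictions (high entropy) are treated as
# likely "unknown" classes and pushed away.
logits = np.array([
    [8.0, 0.1, 0.2, 0.1],   # confident -> known class
    [0.3, 7.5, 0.1, 0.4],   # confident -> known class
    [1.1, 0.9, 1.0, 1.0],   # near-uniform -> likely unknown class
])
H = entropy(softmax(logits))
threshold = 0.5 * np.log(4)   # a hand-picked fraction of the max entropy log(4)
print((H > threshold).tolist())
```

Only the near-uniform third example crosses the threshold, so it would be flagged as an outlier / unknown-class candidate.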
So if we apply this DANCE approach that I just described, we can improve performance compared to many recent domain adaptation methods, and also compared to just training on the source: we get pretty low performance if we train only on the source data. And if we do domain alignment on the entire domain, which is the DANN method here, we actually see worse accuracy than doing nothing, that is, than just training on the source, again because the shared-category assumption is violated in this problem.
Enforcing consistency (38:59)
OK, but with our method, we're actually able to do much better than just training on source and improve accuracy. And then finally, I want to mention another cool idea that has become more prevalent recently in the semi-supervised literature, and we can apply it here as well. Here we again start with self-supervised pre-training on the source and target domains, but in this case we're using a different self-supervised task: instead of clustering points, we predict the rotation of images. We can rotate an image, and we know exactly what orientation it's in, so we train our feature extractor to predict that orientation, for example, whether it's rotated 90 degrees or 0 degrees. Again, that's just another self-supervised task; it helps us pre-train a better feature encoder that is more discriminative for our source and target domains. Then we apply a consistency loss. What is the consistency loss? Here we do some data augmentation on our unlabeled images: we take our pre-trained model and use it to generate probability distributions over the class labels on the original unlabeled image, and also on an augmented version of it, where the augmentation is cropping, color transformation, adding noise, small rotations, and things like that. The augmentation is designed to preserve the category of the object while changing the image. Then we take these two probability outputs and add a loss which ensures that they're consistent. We're telling our model: if you see an augmented version of this image, you should still predict the same category for it. We don't know what the category is, because the image is unlabeled, but it should be the same as for the original image. So with this idea, we call the combination of this rotation-prediction pre-training and consistency training PAC.
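The consistency idea reduces to a simple loss between two predicted distributions. Here is a toy numpy sketch (invented logits; I'm using a KL divergence, which is one common choice for this kind of loss, not necessarily the exact form used in the paper):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def consistency_loss(logits_orig, logits_aug):
    """KL(p_orig || p_aug): penalize the model for predicting a different
    class distribution on the augmented view of the same unlabeled image."""
    p = softmax(logits_orig)
    q = softmax(logits_aug)
    return (p * (np.log(p + 1e-12) - np.log(q + 1e-12))).sum(axis=-1).mean()

# Toy unlabeled example over 3 classes: an augmentation should not change
# the prediction, so the consistent view incurs a much smaller loss.
orig = np.array([[3.0, 0.5, 0.2]])
aug_consistent = np.array([[2.9, 0.6, 0.1]])   # small perturbation, same class
aug_flipped = np.array([[0.2, 3.0, 0.5]])      # prediction changed class

print(consistency_loss(orig, aug_consistent)
      < consistency_loss(orig, aug_flipped))
```

Minimizing this loss over many unlabeled target images and augmentations smooths the classifier's decision function around the target data, without ever needing a target label.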
And this is just one small taste of the results we got, because I don't have much time. Essentially, here we are again adapting from the synthetic dataset in the VisDA challenge to real images, but now we assume that a few examples are labeled in our target domain. With just this PAC method, we're actually able to improve a lot on the domain alignment method called MME, which is our own previous work in this case. So basically, the point I want you to take away from this is that domain alignment is not the only approach; we can use other approaches, like self-supervised training and consistency training, to improve performance on target data.
Summary and conclusion (42:05)
All right, so I'll stop here. Just to summarize what I talked about: I hope I've convinced you that dataset bias is a major problem, and I've talked about how we can address it using domain adaptation techniques, which try to transfer knowledge using unlabeled data. We can think of this as a form of unsupervised fine-tuning. The techniques I talked about include adversarial alignment, as well as other techniques that rely on self-supervision and consistency training. I hope you enjoyed this talk, and if you have any questions, I'll be very happy to answer them.