From the Udemy course on LLM engineering.
https://www.udemy.com/course/llm-engineering-master-ai-and-large-language-models
WEBVTT

00:00.710 --> 00:06.650
And now let me make this real for you by showing you some diagrams, particularly now looking

00:06.650 --> 00:11.810
at how training works with QLoRA, which is how we're actually doing it in practice.

00:12.440 --> 00:16.280
So first of all, let's talk about the forward pass.

00:16.280 --> 00:20.780
So here is a diagram that should be familiar to you, because it's the same one I used before.

00:20.810 --> 00:27.350
That shows our Llama 3.1 base model that we've quantized all the way down to four bits, with its 8 billion

00:27.380 --> 00:31.340
parameters, and in white, all of it is frozen.

00:31.340 --> 00:36.350
We're not going to be changing its weights as part of training, because it would be way too much work,

00:36.380 --> 00:43.400
way too much memory, and too slow to try and shift and tweak and optimize these 8 billion parameters.
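
To make that concrete, here is a minimal sketch of loading a 4-bit quantized base model with bitsandbytes. The model id, variable names and quantization settings are illustrative assumptions rather than the course's exact configuration:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B"   # assumed model id, for illustration

    # Quantize the 8 billion base parameters all the way down to 4 bits
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_config,
        device_map="auto",
    )
    # The quantized base weights stay frozen; only the LoRA adapters added later get trained.
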
00:43.850 --> 00:50.630
So what we're seeing here is a bunch of frozen rows of weights, and the ones in yellow are also frozen,

00:50.630 --> 00:53.180
but they're representing our target modules,

00:53.180 --> 00:55.940
which is what LoRA is going to be applied to.

00:56.330 --> 00:59.210
And let's bring in our LoRA adapters.

00:59.210 --> 00:59.790
There they are.

00:59.820 --> 01:06.090
You may remember that, to be technically correct, there are in fact two adapters, called A and B, for

01:06.090 --> 01:10.380
each of the different target modules. They have

01:10.440 --> 01:15.420
dimensionality given by r, and you may remember that alpha is the scaling factor.

01:15.420 --> 01:21.090
And it's as simple as this: the way these are applied to their target module is that it's alpha

01:21.120 --> 01:22.380
times A times B.

01:22.710 --> 01:24.450
Simple as that.
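
As a minimal sketch of that idea in plain PyTorch, with made-up dimensions: each target module keeps its frozen weight matrix, and the two low-rank adapter matrices add a small, scaled correction on top. (Strictly, the PEFT library scales the adapter path by alpha divided by r; "alpha times A times B" is the simplified picture.)

    import torch

    d, r, alpha = 4096, 32, 64        # illustrative sizes; r and alpha are the LoRA hyperparameters

    W = torch.randn(d, d)             # frozen weight matrix of one target module
    A = torch.randn(r, d) * 0.01      # LoRA matrix A (trainable), rank r
    B = torch.zeros(d, r)             # LoRA matrix B (trainable), starts at zero so the adapter begins as a no-op

    x = torch.randn(d)                # an activation flowing through this module

    scaling = alpha / r               # PEFT's actual scaling factor
    y = W @ x + scaling * (B @ (A @ x))   # frozen path plus the low-rank adapter path
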
01:24.450 --> 01:27.750
So that is our neural network that we know well.

01:27.780 --> 01:33.270
And we are going to be training the weights in these LoRA adapters, which in our case is about

01:33.300 --> 01:36.090
109 MB worth of weights.
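
In PEFT, attaching those adapters to the quantized base model from the earlier sketch and checking that only the adapter weights are trainable might look roughly like this; the r, alpha, dropout and target module names here are illustrative, not necessarily the course's choices:

    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=32,                                                      # adapter rank
        lora_alpha=64,                                             # scaling factor alpha
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],   # assumed target modules
        lora_dropout=0.1,
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()   # only the LoRA weights are trainable - a tiny fraction of the 8 billion
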
01:36.780 --> 01:38.280
Okay.

01:38.280 --> 01:40.290
So then what happens?

01:40.290 --> 01:42.900
We have an input prompt over on the left.

01:42.900 --> 01:47.760
It is something like "Price is $" and then it's the next token

01:47.760 --> 01:50.070
that we want the model to get good at predicting.

01:50.520 --> 01:57.780
The forward pass is when we take that and, basically in inference mode, we put it through

01:57.780 --> 02:02.400
the model to say: please predict the next token.

02:03.360 --> 02:09.780
And so that goes through the model, and what comes out the other side is a predicted next token: price

02:09.780 --> 02:10.260
is,

02:10.290 --> 02:11.820
and then 99.

02:11.820 --> 02:17.970
And again, because we're taking advantage of this simplicity of Llama 3.1, in fact 99 is

02:17.970 --> 02:19.170
just one token.

02:20.430 --> 02:23.730
And that will always be the case for any three-digit number.

02:23.730 --> 02:27.300
That's not critical, but it does simplify things a bit for us.

02:27.300 --> 02:29.190
So that's the forward pass.
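
As a rough sketch of that forward pass for a single prompt, continuing from the model and tokenizer above (the prompt text is illustrative, and a real trainer does this on whole batches):

    prompt = "Price is $"                                   # illustrative input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model(**inputs)                               # the forward pass (the graph is kept so we can backprop later)

    next_token_logits = outputs.logits[0, -1, :]            # one score per vocabulary entry for the next position
    predicted_id = next_token_logits.argmax().item()
    print(tokenizer.decode([predicted_id]))                 # e.g. "99" - a three-digit price is a single token
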
02:29.220 --> 02:29.820
All right.

02:29.850 --> 02:32.040
Now onto the loss calculation.

02:32.610 --> 02:33.780
So here we are again.

02:33.780 --> 02:35.400
We've predicted the next token.

02:35.400 --> 02:41.400
So now the model is able to look up, or not the model, the training process,

02:41.400 --> 02:44.010
the SFT trainer in our case, looks up

02:44.010 --> 02:45.960
what the actual next token was.

02:45.960 --> 02:49.080
Because we've got the training data, we know the actual next token.

02:49.080 --> 02:50.100
And what was it?

02:50.100 --> 02:51.810
Let's say it was 89.

02:51.810 --> 02:52.590
It was lower.

02:52.590 --> 02:54.000
So we were wrong.

02:54.000 --> 02:55.800
Wrong by $10.

02:56.070 --> 02:57.370
Or wrong

02:57.370 --> 03:00.280
by a different token; it doesn't know that this represents $10.

03:00.280 --> 03:02.350
It just knows it's a different token.

03:02.710 --> 03:05.650
And so there is some kind of a loss.

03:05.650 --> 03:10.180
And in just a moment I'm going to explain what that loss is, and why it's not quite as simple as it just

03:10.180 --> 03:11.500
being a different token.

03:11.500 --> 03:16.300
There's a technicality there that we'll get to, but for now you can just think of it as: it predicted

03:16.300 --> 03:16.960
99,

03:16.990 --> 03:18.670
the actual value is 89,

03:18.700 --> 03:20.050
we have a loss.

03:20.680 --> 03:23.140
So that's the loss calculation.
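
Continuing the same sketch, the loss compares what came out of the forward pass with the actual next token from the training data; the example target price and the way it is looked up here are just for illustration:

    import torch
    import torch.nn.functional as F

    # The actual next token according to the training data (say the true price was 89)
    actual_id = tokenizer(" 89", add_special_tokens=False)["input_ids"][-1]

    # Cross-entropy compares the model's whole probability distribution over the vocabulary
    # against the one correct token - it only knows the token was different, not that it was $10 off.
    loss = F.cross_entropy(
        next_token_logits.unsqueeze(0),
        torch.tensor([actual_id], device=next_token_logits.device),
    )
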
03:23.140 --> 03:29.350
Step three is the backward pass, which you hear people calling backprop or backward propagation.

03:29.350 --> 03:30.580
Backpropagation.

03:30.760 --> 03:38.470
And in backprop, basically, we look back through the network, back we go, and we say: all right,

03:38.470 --> 03:43.930
if we were to tweak these weights by a little bit, how much would that affect the

03:43.930 --> 03:44.650
loss?

03:44.650 --> 03:48.850
How sensitive is the loss to those weights?

03:48.880 --> 03:54.280
It gives us what we call the gradients of the weights, of the parameters.

03:54.280 --> 03:57.320
So when I say weights, that's synonymous with parameters.

04:00.800 --> 04:02.120
What are the gradients?

04:02.120 --> 04:05.240
If we were to change those weights, what would it do to the loss?

04:05.240 --> 04:08.030
Because we want to try and improve things a little bit.

04:08.360 --> 04:14.120
And so those red triangles are meant to represent delta, for showing a gradient

04:14.210 --> 04:16.100
calculation that's happened there.
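
Still in the same sketch, the backward pass is a single call; gradients only land on the parameters that are trainable, which here means just the LoRA adapter matrices:

    loss.backward()   # backpropagation: gradient of the loss with respect to every trainable weight

    # Only the LoRA adapter parameters have requires_grad=True, so only they receive gradients;
    # the frozen 4-bit base weights are untouched.
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            print(name, tuple(param.grad.shape))
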
04:17.180 --> 04:22.730
And then finally we get to the last step, step four: optimization.

04:22.940 --> 04:23.960
Here it comes.

04:23.960 --> 04:26.360
So we've got these gradients.

04:26.420 --> 04:30.770
And what we now need to do is take a tiny step in the right direction.

04:30.770 --> 04:36.830
So we want to update the parameters in our LoRA matrices a little tiny bit,

04:36.830 --> 04:41.720
so that next time, if it gets the same input prompt, the loss will be a bit lower.

04:41.720 --> 04:42.890
It would do a bit better.

04:42.890 --> 04:48.200
So we're taking a step in the right direction, and we use the learning rate to decide how much of a

04:48.200 --> 04:48.980
step to take.

04:48.980 --> 04:53.810
Because, as I said before, there are pros and cons of taking smaller or larger steps.

04:54.090 --> 05:00.690
And you may remember that the optimizer we're using, the AdamW optimizer, does something quite

05:00.690 --> 05:06.720
cunning, where it doesn't just use these gradients, it keeps a kind of rolling average of prior gradients,

05:06.720 --> 05:12.540
so that it's being really smart about how to take that step in a way that's most likely to improve things.

05:12.660 --> 05:17.700
And it's also trying to make sure that we don't do things like overfitting and other dangers that we've

05:17.730 --> 05:18.810
talked about before.
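
And the optimization step itself, as a sketch with an illustrative learning rate. AdamW's rolling averages of past gradients and its weight decay (the part that helps guard against overfitting) are handled inside the optimizer; in a real loop the optimizer is created once, before training starts:

    import torch

    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),   # only the LoRA adapter weights
        lr=1e-4,                                              # illustrative learning rate
    )

    optimizer.step()        # nudge the LoRA weights a tiny step in the direction that lowers the loss
    optimizer.zero_grad()   # clear the gradients, ready for the next forward pass
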
05:19.200 --> 05:20.370
So we do that.

05:20.370 --> 05:26.430
And what happens is that the LoRA adapters are then improved, and as they are applied in the future

05:26.430 --> 05:29.700
to our base model, it will do slightly better.

05:29.730 --> 05:36.390
The loss will be a bit lower next time, and that is why you see the training loss coming down

05:36.390 --> 05:40.740
in our wiggly charts, because it's learning and getting better and better.

05:41.070 --> 05:46.050
And so it's worth noting, of course, that the weights, the parameters, that get changed

05:46.050 --> 05:51.180
are the parameters in the LoRA adapters, the green ones; we don't actually change any parameters

05:51.180 --> 05:55.620
in the Llama 3.1 base. That is just too big, too many parameters.

05:55.620 --> 06:02.280
If you were doing ordinary training, fine-tuning, not LoRA-based but ordinary, then it would

06:02.280 --> 06:09.120
be the whole Llama 3.1 model that would need to have gradients calculated and need to be shifted

06:09.120 --> 06:10.320
during optimization.

06:10.320 --> 06:13.650
And that, of course, is what these big companies like Meta do.

06:13.680 --> 06:19.170
That's how they've trained Llama 3.1 in the first place, and they've spent significant amounts of money

06:19.200 --> 06:25.620
doing it, which we don't have easy access to, which is why we're using our LoRA adapters instead.

06:26.190 --> 06:32.160
So that's the quick summary of optimization, or sorry, of the whole of the training process.
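
In practice, all four of these steps, forward pass, loss, backprop and optimization, are looped over every batch for us by TRL's SFTTrainer. A rough sketch, where the toy dataset and every hyperparameter value are assumptions for illustration:

    from datasets import Dataset
    from trl import SFTConfig, SFTTrainer

    # Tiny made-up dataset of prompts ending in the true price
    train_dataset = Dataset.from_dict({"text": [
        "How much is this camera? Price is $99.00",
        "How much is this kettle? Price is $23.00",
    ]})

    train_config = SFTConfig(
        output_dir="price-model",            # assumed output directory
        num_train_epochs=1,
        per_device_train_batch_size=4,
        learning_rate=1e-4,
        optim="adamw_torch",                 # the AdamW optimizer discussed above
    )

    trainer = SFTTrainer(
        model=base_model,                    # the quantized, frozen base model
        train_dataset=train_dataset,
        peft_config=lora_config,             # SFTTrainer attaches the LoRA adapters for us
        args=train_config,
    )
    trainer.train()                          # forward pass, loss, backward pass and optimization, batch after batch
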
06:32.160 --> 06:37.530
And I imagine that was probably mostly clear to you already, but I'm hoping that these diagrams have

06:37.530 --> 06:41.310
crystallized it for you and mean that everything has fallen into place.

06:41.310 --> 06:47.550
And in the next video, I'm just going to explain one more technicality, a very important technicality,

06:47.550 --> 06:51.570
about what we mean by the prediction and the loss calculation.

06:51.600 --> 06:52.380
See you then.