Fastformer: Additive Attention Can Be All You Need (Machine Learning Research Paper Explained)

Hello there! Today we'll look at "Fastformer: Additive Attention Can Be All You Need" by Chuhan Wu, Fangzhao Wu, Tao Qi, and Yongfeng Huang. This paper definitely wins the category of most innovative paper title of the last few months: apparently we've gone from "is all you need" to "can be all you need", so a big win on that front.

As you might have guessed from the title, the paper introduces a new kind of attention mechanism. If you don't know what an attention mechanism is and you're in machine learning, you might want to find out; I have a paper video on "Attention Is All You Need". The new attention here is additive attention, which is supposed to be a much, much faster way of doing attention, hence the name Fastformer. This additive attention circumvents the quadratic bottleneck that we usually have in the attention mechanism: instead of doing multiplicative attention, they do what they call additive attention. The naming, in my opinion, is a bit confusing, and the whole concept is a bit confusing too. On a high level, that's what they do: they design a new attention mechanism. My opinion of the paper is that it names things somewhat deceptively to make it appear like an attention mechanism, where in reality it seems to be more of a feed-forward-ish layer that they propose, maybe not even that; we'll go into it. Their promises are that, by circumventing the quadratic bottleneck of attention, you can feed much longer sequences into the context of a transformer, and you can also process sequences of the same length much faster, since everything is additive and not multiplicative. We'll find out. They claim to have a lot of experimental evidence, and if you like content like this, don't hesitate to subscribe if you haven't done so already.

The abstract reads: transformers are very powerful, okay, but the attention mechanism is inefficient due to the quadratic complexity in input sequence length. They say that although there are many methods for transformer acceleration, they are still either inefficient on long sequences or not effective enough; by "effective" I guess they mean that the performance suffers too much. So they propose Fastformer, an efficient transformer model based on additive attention: instead of modeling the pairwise interactions between tokens, which is what attention does, they first use an additive attention mechanism to model global contexts, and then further transform each token representation based on its interaction with the global context representations. If this sounds confusing to you, it does to me too.

They go into a little more detail: this additive attention has linear complexity instead of the quadratic complexity of the usual transformer. Here is their description: "We use additive attention to summarize the input attention query matrix into a global query vector. Then we model the interaction between the attention key and the global query vector via element-wise product to learn the global context-aware key matrix. We further summarize it into a global key vector via additive attention. Then we use element-wise product to aggregate the global key and attention value, which are further processed by a linear transformation to compute the global context-aware attention value. Finally, we add together the original attention query and the global context-aware attention value to form the final output." Still, after this paragraph, it doesn't make too much sense to me.
We'll go to the diagram in just one second, but here is essentially what they promise. First, they propose an additive-attention-based transformer named Fastformer and claim that, to their knowledge, Fastformer is the most efficient transformer architecture. Second, they propose to model the interaction between global context and token representations via element-wise product, which they say can help fully model context information in an efficient way; so the element-wise product is the second component. There is additive attention, there is the element-wise product, and lastly they say their experiments on various datasets validate the approach.

So here is the coveted diagram of the Fastformer. It's a little complicated, but I want to go back to the regular attention mechanism first. I know I've done this a lot, but in this context it is really worth discussing. In a regular attention mechanism, what do you have? You have some sort of input sequence; each element is a vector, some kind of embedding. Essentially it's a set, but we think of it as a sequence of, let's say, tokens in natural language. We want to transform the sequence of one layer into a sequence of equal length in the next layer. If we stack many of these layers together, we want to improve the representations of these tokens layer by layer, such that at the end of the transformer we understand what each token means in the context of all the other tokens. If the sentence is "my house is very green", then at the beginning each word is just an isolated piece of data; at the end of these transformations we want every token to be aware of all the other tokens in the input and to capture its in-context meaning.

So we need to transform one set of representations into the next, and the way we do this is the attention mechanism. The attention mechanism derives three different things from each token. One is called a key: the key is a vector for each token, and that vector describes what the content of this token is so far; the key allows the token to advertise what it has to offer. The second is the query, which is also derived from the same token; the query expresses what this token wants to know about the other tokens in the sequence. This can be different from its content, so the query and the key might be different; there are variants where they are the same, but usually you derive two different vectors from each token. Then we route by inner product: for every single query, you aggregate across the entire input sequence by inner product, which means this token gets routed here a lot, this one a little, these ones not so much, and so on. For each query this gives you a histogram across the sequence saying: this information is mildly relevant, this one is more relevant, this one is slightly relevant, these ones aren't relevant at all for me. This histogram you then normalize via a softmax operation, and that gives you a proper distribution over the input.
So with the query and the key you decide how you want to aggregate the information in the input sequence for one particular element of the output sequence, and you do this for every element: each element gets its own aggregation distribution. In the last step, every single item also emits what's called a value, which is yet another vector; you don't even necessarily have to transform anything, the value could just be the token's own information, but essentially the value is what you ultimately multiply with this distribution, and that becomes the next-layer representation for this particular token. So the whole query-key attention mechanism is simply there to decide how to aggregate the different values of the input sequence for any given token in the next layer. The key advertises what the contents are, which is kind of like the value — the value is the actual content, the key is more of an addressable representation of the content — and the query expresses what a token wants to know about the others. You match the query of a token against the keys of the others, and that determines the aggregation.

In that context, let's look at the Fastformer. We said there are two elements. First of all, there is this additive attention, which is what you can see down here. There is the input, and the input gets transformed into three different things — queries, keys, and values — just like in a regular attention mechanism. These are linear transformations that each token goes through independently: this token independently produces its query, key, and value, and with the same transformation that token produces its query, key, and value; there is no interaction, every token goes through the same transformation. But then, instead of considering the interactions between each of the queries and each of the keys, we don't do that. The reasoning is: this becomes quadratic if we consider the interaction between every query and every key, so let's simply construct one global query, and then consider the interaction of that global query with each of the keys, instead of everything with everything. This is where the linearity, instead of the quadraticness, of the approach comes from: instead of considering pairwise interactions, we construct a single query vector. By the way, this is all one head; a transformer usually has multiple heads, so over here you would have head number two, head number three, head number four, but within a single head we make one query vector.

You immediately see what the shortcomings are: whereas previously every token could dynamically decide by itself how it wants to aggregate information, now it's only the sequence as a whole that gets to decide how to aggregate, because it needs to come up with one combined query vector. So my guess is that this might work quite well for tasks that have a single-minded output — topic classification or something like that — where the global information is usually what matters, whereas more nuanced, language-heavy tasks that depend on specific interactions between individual tokens might fall a lot short with this approach.
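To make the contrast concrete, here is a minimal sketch of the standard single-head scaled dot-product attention described above — the quadratic, pairwise routing that Fastformer is trying to replace with a single global query. This is just an illustration in PyTorch; the tensor names are mine, not the paper's.

```python
import torch
import torch.nn.functional as F

def standard_attention(Q, K, V):
    """Vanilla single-head attention: every query attends to every key.

    Q, K, V: tensors of shape (n, d) for a sequence of n tokens.
    The (n, n) score matrix is the quadratic bottleneck Fastformer avoids.
    """
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d ** 0.5   # (n, n) pairwise inner products
    weights = F.softmax(scores, dim=-1)           # one aggregation distribution per token
    return weights @ V                            # weighted sum of values, shape (n, d)
```

The (n, n) `scores` matrix is exactly the per-token "histogram" from the explanation above, and it is where the quadratic memory and compute come from.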
But how does this single query vector come to be? It is constructed purely from the queries of the individual tokens. There is this funny construction here where the query vector itself is used twice: we construct an alpha value for each query vector, multiply that alpha value by the query vector itself, and then add everything together at the end. So the global query vector is a weighted sum across all of the individual query vectors. The question is how we decide on the weights, and that's where the alpha values come in.

Here are the formulas for the alpha values. Each query vector q_i produces its own alpha_i, and the formula should look familiar: it's a softmax — it's also the formula for logistic regression, if you squint a little. The alpha_i's are the result of a softmax across the queries; more precisely, not across the queries themselves but across the quantity w^T q_i, each query multiplied by a learned transformation. This w is a learned parameter vector: take its inner product with each of the queries, which gives one number per query, push those numbers through the exponential function, and normalize across all queries. This really is a logistic-regression-style scoring with w as the weight vector.

What does that mean? We construct the final global query vector as an aggregate over all query vectors, with the weightings given by a softmax over the inner products with this learned vector w, which is the same for every query. I can make sense of that if I think of w the way I think of the weight vector in logistic regression, as a sort of classification boundary: w is essentially a little classifier that cares about one particular learned thing. It can encode some intermediate feature that is useful and is learned via backpropagation, and the weighting in this particular head and layer is then according to that feature. So somewhere in here there is a w vector, and in this particular layer, for this particular head, it refers to some kind of useful feature — say, "is there the name of a country somewhere in the sentence?" — and that is what we use to aggregate the queries. You can immediately see that if a token's query contains that country information, this classifier will say that this particular query has a lot of the information it looks for in this layer; therefore the inner product will be high, the alpha will be high, and that particular query will be strongly represented in the global query vector.
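Here is a sketch of that pooling step as I understand it — not the authors' code. The function name is mine, and I've added a division by √d, which I believe matches the paper's formula even though the video doesn't dwell on it.

```python
import torch
import torch.nn.functional as F

def additive_pool(X, w):
    """Summarize a sequence of vectors into one global vector ("additive attention").

    X: (n, d) per-token vectors (here: the query matrix of one head).
    w: (d,) learned parameter vector -- the static "feature detector" from the text.
    Returns a single (d,) vector: a softmax-weighted sum of the rows of X.
    """
    scores = X @ w / X.shape[-1] ** 0.5   # one scalar score per token (scaling is my assumption)
    alpha = F.softmax(scores, dim=0)      # (n,) weights -- the alphas
    return alpha @ X                      # (d,) global vector, e.g. the global query
```

Note that `w` is a parameter, not something computed from the input, which is exactly the "static" aspect discussed below.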
So you can think of the global query vector as: I select, among all the query vectors, the ones I care about in this particular layer and this particular head. However, what you care about in this layer and this head is static — it is statically learned and the same for every single sample. So this is a weighting by one particular learned feature.

Now, once we have the global query vector, how do we let it interact with the keys? Usually we would take the inner product of the query and the key, and that defines the aggregation distribution. However, since we only have a single query, that would not give us a sequence of length n in the next layer — it would only give us a sequence of length one — so we can't do that. Instead, they do almost an inner product, except they don't sum: they simply do an element-wise multiplication of the global query and the keys. With element-wise multiplication, if both elements are small the result is very small, and if both are high the result is very high, so there are some interesting dynamics within each dimension, but there is no aggregation across dimensions. They do this element-wise multiplication to obtain the p vectors: p_i is the element-wise product of the i-th key vector with the global query vector.

The global query vector is itself a weighted sum across all of the queries, so if I pull that sum in, I still have the alpha_j's: I have n p vectors, and for each one I also have n query vectors, and I consider products of the form (i, j) — so the quadratic number of products is still in there. However, I don't have quadratic complexity. Why? Because there is no softmax in between aggregating the queries and aggregating the keys, so linearity (distributivity) applies, and I can get away with first aggregating the queries into one vector and then multiplying that whole thing into the keys. Of course, that means we have two linear operations in sequence, whereas in the normal attention mechanism we have a linear operation, then a non-linear one (the softmax), then again a linear one — and arguably the non-linearities are what bring the whole power to deep learning. So you can see how this circumvents the quadratic bottleneck by simply saying: if everything is linear, we can just add it all together first. That's the trick, essentially.
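In symbols (my notation, with q denoting the global query and ⊙ the element-wise product), the linearity argument reads roughly like this:

```latex
p_i \;=\; \mathbf{q} \odot k_i
    \;=\; \Big( \sum_{j=1}^{n} \alpha_j\, q_j \Big) \odot k_i
    \;=\; \sum_{j=1}^{n} \alpha_j\, \big( q_j \odot k_i \big)
```

The right-hand side still contains all n² pairwise products q_j ⊙ k_i, but because no softmax sits between the two aggregation steps, we are free to evaluate the middle expression instead: sum the queries once in O(nd), then multiply the single global query into each of the n keys in another O(nd), rather than paying O(n²d) as in standard attention.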
But then you realize we're not done yet. What do we do with the p vectors? This should look familiar: we do another one of these additive attentions. From each p_i we produce a beta value, computed exactly the same way as the alpha values, at least as far as I can tell: for each p we take the inner product with a learned feature vector, w_k this time, push it through the exponential function, and normalize by the sum over all of them. Then we aggregate a global key via a weighted sum of all the p vectors — additive attention again — in order to get a single global key vector. And now exactly the same trick: the global key vector is element-wise multiplied with the value vectors, which gives us these u vectors, and those apparently go through yet another linear transformation to give us the r vectors. You can stack as many linear transformations as you want.

So essentially, what we've done in the end is: we take the values, which are the information we want to forward-propagate, and for each value we element-wise multiply it with this global key vector, which is a result of the keys and of the global query vector, which in turn is a result of the queries. Crucially, there is no aggregation of information the way there is in a regular transformer: I don't aggregate the values from the sequence in a weighted fashion, I simply leave each value as it is. As I said, these are transformations that don't depend on the other sequence elements — value one depends purely on token one. The only way information from other tokens can get into any particular token is via these aggregation steps, via the normalization: for example, key n could be strongly represented in the global key, and that global key is then multiplied into my vector one. And we're still not done: after we obtain the r vectors, we add the query vectors to them again. Why? I don't know, but we just do: we simply add the query vectors to the r vectors, and that is the final output.

This is stupidly complex, and I don't think for any particular reason. There are multiple problems here. For example, this transformation here is a linear transformation — okay, maybe it makes sense, but it seems like you just had a linear transformation, and this whole sum is also a linear aggregation, so maybe you can justify it. Second, this connection right here, where you add the query back in: if this is not ablated in the experiments, I don't believe anything. This is clearly not something you do from the beginning; it's clearly something you add after the other stuff doesn't work. I want to see an experiment where this connection is missing, and an experiment where only this connection is present, to figure out where the actual work is being done. Another thing: the middle column is structurally redundant. The lower part here is a repetition of the left column, and then it is repeated again; in fact, you could stack as many of these columns as you want. They call them query, key, and value, but I could just call them column one, column two, and the final column, and I could insert column three, column four, column five — as many as I want — because the pattern just repeats. There is no qualitative difference that distinguishes the queries from the keys in this model; only the values are treated a bit differently, because at the end they are not aggregated into a global vector with this additive attention.
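Putting the pieces together, here is what I understand one head of this forward pass to look like, written out as a sketch from the description above — not the authors' implementation. The module and parameter names, the per-head treatment, and the placement of the √d scaling are my assumptions, and I leave out the multi-head split, any weight sharing, and the final output projection across heads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FastformerHead(nn.Module):
    """One head of Fastformer-style additive attention, as described above (a sketch)."""

    def __init__(self, d_model, d_head):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_head)
        self.k_proj = nn.Linear(d_model, d_head)
        self.v_proj = nn.Linear(d_model, d_head)
        self.w_q = nn.Parameter(torch.randn(d_head))  # static scoring vector of the query column
        self.w_k = nn.Parameter(torch.randn(d_head))  # static scoring vector of the key column
        self.r_proj = nn.Linear(d_head, d_head)       # the extra linear map on the u vectors

    def forward(self, x):                             # x: (n, d_model)
        Q, K, V = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        scale = Q.shape[-1] ** 0.5

        alpha = F.softmax(Q @ self.w_q / scale, dim=0)   # (n,) weights over queries
        q_global = alpha @ Q                             # (d_head,) global query

        P = q_global * K                                 # (n, d_head) element-wise "interaction"
        beta = F.softmax(P @ self.w_k / scale, dim=0)    # (n,) weights over p vectors
        k_global = beta @ P                              # (d_head,) global key

        U = k_global * V                                 # (n, d_head) element-wise again
        R = self.r_proj(U)                               # the extra linear transformation
        return R + Q                                     # add the queries back (the odd residual)
```

For example, `FastformerHead(256, 64)(torch.randn(10, 256))` returns a `(10, 64)` tensor. Written out like this, the two "columns" really are structurally identical — a static scoring vector, a softmax over tokens, a weighted sum, then an element-wise product — which is exactly the repetition the criticism above is pointing at, and the final `+ Q` is the residual connection I would want to see ablated.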
But in essence, you could do away completely with, for example, the key column and directly multiply the queries into the values — completely possible. So the key column is completely unnecessary. Now you might think: if the key column is unnecessary, or if I can insert fifty key columns in between that each take the previous global vector, multiply it in, and do another additive attention, is this really an attention mechanism? The answer is kind of, but not in the way you'd expect; it's a bit sneaky, honestly.

Attention — and who am I to define this, but arguably — is when I create one of these weightings dynamically: a weighting of how I aggregate information from an input sequence. That, in essence, is an attention mechanism: dynamically creating this weighting. The only place where that actually happens here is around this w vector. This is in fact the attention mechanism — not the element-wise products, which are just weighted sums. It is essentially a self-attention mechanism: the alphas are how we aggregate information, the things they call queries play the role of both the values (the things being aggregated) and the keys (the things being addressed), and the actual query is this w vector. But that query, as you can see, is not dynamic; it is statically learned, which turns the whole thing into something like a feed-forward network, or at best an attention mechanism with a single learned query. So instead of n queries, we now have one query per head.

That's why I said at the very beginning: if this is applied to a task that largely relies on global information, such as sequence classification, it may be that you only need a couple of genuinely different intermediate features per layer — after all, they are vector-valued. If I have eight heads, I have eight different w vectors; and to be fair, there are two w vectors per head and layer, one here and another one in the key column, so every column gives me one new feature to extract. The number of heads times the number of these columns is essentially the number of static features I can extract from such a sequence. For global-information tasks that might in fact be enough, and in that case, fine. However, I could probably have achieved the same thing by simply constructing fewer queries than keys and reducing the sequence length that way, or something along those lines — there are many ways to do this. My point is that the whole thing is framed in the words of an attention mechanism, while the actual attention mechanism is the thing that happens inside the query column: a self-attention over the queries with not a dynamic query but one single, fixed, learned one. The same goes for column two. Column three is just kind of weird: it's like a strange residual connection where this product with something incoming happens — kind of like a feed-forward layer again, a dynamic feed-forward layer per token.
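To make the "single static learned query" reading from above concrete: as far as I can tell, the additive pooling step is exactly ordinary attention in which there is only one query and that query is a learned parameter rather than something derived from the input. The small check below is my own illustration (names and the √d scaling are assumptions); the pooled sequence serves as both keys and values.

```python
import torch
import torch.nn.functional as F

n, d = 5, 8
X = torch.randn(n, d)          # the per-token vectors being pooled (e.g. the queries of one head)
w = torch.randn(d)             # Fastformer's learned scoring vector

# Fastformer-style additive pooling
alpha = F.softmax(X @ w / d ** 0.5, dim=0)
pooled_additive = alpha @ X

# The same computation phrased as ordinary attention with a single fixed query w,
# where X plays both the key and the value role
attn = F.softmax(w @ X.T / d ** 0.5, dim=-1)
pooled_attention = attn @ X

assert torch.allclose(pooled_additive, pooled_attention)
```

So the "attention" that remains is real, but its query never depends on the input, which is what makes the feed-forward reading plausible.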
Yes, so that's why I find the naming a bit deceptive, as well as the formulation in terms of query, key, and value, and all the talk about how "we model the interaction between" this and that.

But what about the experiments? The experiments I find relatively lacking. They do have a lot of baseline comparisons, which is respectable, but the datasets appear to be things like sentiment classification and topic classification tasks. They do perform well — experimental results are experimental results — and the best numbers are achieved by ensembles, which is also fine, but even the regular numbers appear to be quite competitive, so I don't exactly know what to make of it. The complexity analysis is also a bit shaky, because they sort of leave out the linear operations and so on. And as I said, there are no ablations for most of these things. There is no ablation, for example, of this residual connection where you just add the query back in — why would you do that? If you call this thing a query, then by itself it should carry no information to pass on, by nature of being a query, so why add it to the output? What is the effect of the individual columns, and of how many there are? There are many things to ablate here to really show why this model performs well. What they do show is a comparison of runtime as the sequence length increases, and they are quite fast — "Fastformer" being, I guess, the fast transformer line in those plots — and also a constant factor faster than the other efficient transformers. But are you a constant factor faster because you don't actually do any sort of attention? I don't know.

So those are my two cents on this paper. Again, this might be a neat model for certain tasks: it is certainly fast, and it certainly doesn't make you run out of memory the way a regular transformer does; for a given set of tasks it might in fact work better than a transformer. My main problem is with the whole framing in terms of attention, using the same language and trying to pass this off as a faster transformer, which it is not. All right, let me know what you think in the comments, and thanks for listening. Bye-bye.