Hi, great work!
We have three inputs, `i1`, `i2`, and `i3`, which are processed by LLaMA-7B. For input `i1`, I extract two hidden states at two distinct positions and label them `p11` and `p12`, respectively. For each of the remaining inputs, `i2` and `i3`, I select a single hidden state, denoted `n21` and `n31`, respectively.
In this setup, (`p11`, `n21`) is a positive pair and (`p11`, `n31`) is a negative pair. Likewise, (`p12`, `n31`) is a positive pair and (`p12`, `n21`) is a negative pair. My objective is to compute the InfoNCE loss over these pairs.
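To make the intended objective concrete, here is a minimal pure-Python sketch of InfoNCE for this pairing, assuming dot-product similarity and a temperature of 1.0 (both are assumptions, not taken from the issue). The queries are `p11` and `p12`, the candidate keys are `n21` and `n31`, and `targets[i]` gives the index of the positive key for query `i`; every other key acts as a negative. The helper names `dot` and `info_nce` and the toy vectors are hypothetical placeholders for the real hidden states.

```python
import math


def dot(a, b):
    """Dot-product similarity between two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def info_nce(queries, keys, targets, temperature=1.0):
    """Mean InfoNCE loss: cross-entropy over each query's similarities to all keys,
    where targets[i] is the index of the positive key for queries[i]."""
    total = 0.0
    for q, t in zip(queries, targets):
        logits = [dot(q, k) / temperature for k in keys]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_sum_exp - logits[t]  # -log softmax(logits)[t]
    return total / len(queries)


# Hypothetical toy vectors standing in for the LLaMA-7B hidden states.
p11, p12 = [1.0, 0.0], [0.0, 1.0]
n21, n31 = [1.0, 0.0], [0.0, 1.0]

# p11: positive n21 (index 0), negative n31; p12: positive n31 (index 1), negative n21.
loss = info_nce([p11, p12], [n21, n31], targets=[0, 1])
print(loss)
```

Note that each query here has exactly one positive among the shared keys, so the two hidden states from `i1` behave like two separate rows of an ordinary in-batch contrastive loss.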
So I set `get_rep_fn` in the `GradCache` class to handle the two cases. Here is a sample snippet:
```python
def get_rep_fn(x):
    # Inputs with label 2 (e.g. i1) return two representations; all others return one.
    if x.label == 2:
        return [x.e1, x.e2]
    else:
        return [x.e1]
```
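At the same time, I changed the following two lines from `append` to `extend`:

```python
# GradCache/src/grad_cache/grad_cache.py, line 187 in 0c33638
model_reps.append(self.get_reps(y))  # changed to model_reps.extend(...)

# GradCache/src/grad_cache/grad_cache.py, line 270 in 0c33638
all_reps.append(model_reps)  # changed to all_reps.extend(...)
```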
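As a sketch of what the `append` to `extend` change does to the list structure, the snippet below runs a stand-in `get_rep_fn` over three hypothetical inputs. The `SimpleNamespace` objects, the label value `1` for `i2`/`i3`, and the string placeholders for hidden states are all assumptions for illustration. It only shows the resulting list shapes; it does not model GradCache's cached-gradient bookkeeping, which is what the question below asks about.

```python
from types import SimpleNamespace


def get_rep_fn(x):
    # Inputs with label 2 return two representations; all others return one.
    if x.label == 2:
        return [x.e1, x.e2]
    return [x.e1]


# Hypothetical inputs; strings stand in for hidden-state tensors.
batch = [
    SimpleNamespace(label=2, e1="p11", e2="p12"),  # i1
    SimpleNamespace(label=1, e1="n21"),            # i2
    SimpleNamespace(label=1, e1="n31"),            # i3
]

appended, extended = [], []
for x in batch:
    appended.append(get_rep_fn(x))  # one nested list per input
    extended.extend(get_rep_fn(x))  # one flat list of all representations

print(appended)  # [['p11', 'p12'], ['n21'], ['n31']]
print(extended)  # ['p11', 'p12', 'n21', 'n31']
```

With `extend`, three inputs yield four representation entries, so any later step that pairs cached gradients back to representations has to use the same flat order.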
Is the gradient computation still correct with these changes? Could you please confirm?
Thanks!