Cute math trick and some tips for proofs

July 14, 2019   

I had the privilege of attending ICML 2019 with some colleagues last month, and I’ve started working through some of the papers that stood out to me. First on the docket: Combating Label Noise in Deep Learning Using Abstention. Key idea: when classifying into k classes, add a (k+1)-th class that means “choosing to abstain”, and then train your model to abstain explicitly rather than deriving abstention after the fact from classification scores.

The trick

I was hitting some issues with the very first formula in the paper (I’m new to research if that wasn’t obvious), and I wanted to see how the authors coded it up.

The formula

$$L(x_j) = (1 - p_{k+1}) \left( -\sum_{i=1}^{k} t_i \log \frac{p_i}{1 - p_{k+1}} \right) + \alpha \log \frac{1}{1 - p_{k+1}}$$

Gets translated to

loss = (1. - p_out_abstain)*h_c - \
    self.alpha_var*torch.log(1. - p_out_abstain)

This seems okay (the rightmost term is just rewritten using $\alpha \log \frac{1}{1 - p_{k+1}} = -\alpha \log(1 - p_{k+1})$, but whatever), but what is h_c?

h_c = F.cross_entropy(input_batch[:,0:-1],target_batch,reduce=False)

Huh? What happened to $\log \frac{p_i}{1 - p_{k+1}}$? It took me a few tries to convince myself, but these are actually equivalent.
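Here is a minimal sketch of the kind of numeric check that makes the equivalence concrete. It uses plain Python rather than PyTorch so it runs anywhere; the logits, the target index, and the `softmax` helper are made up for illustration, and the target is assumed to be one-hot, so the sum over classes collapses to a single term.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits: k = 3 real classes plus one abstention output (last).
logits = [2.0, -1.0, 0.5, 1.5]
target = 2  # index of the true class among the first k

p = softmax(logits)   # probabilities over all k + 1 outputs
p_abstain = p[-1]     # p_{k+1}

# Formula's form: renormalize the first k probabilities by 1 - p_{k+1}.
renormalized = [p_i / (1.0 - p_abstain) for p_i in p[:-1]]

# Code's form: softmax over the first k logits only.
p_first_k = softmax(logits[:-1])

h_c_formula = -math.log(renormalized[target])
h_c_code = -math.log(p_first_k[target])

print(h_c_formula, h_c_code)  # the two values should agree up to rounding
```

Dropping the abstention logit before the softmax and dividing by the leftover probability mass after it land on the same numbers, which is what the cross-entropy call over `input_batch[:,0:-1]` relies on.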

Even after I was convinced, I had to chew on the proof for a bit. I revisited the problem after a long weekend off, and it’s pretty slick.

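$$
\begin{aligned}
\frac{p_i}{1 - p_{k+1}} &= \frac{e^{v_i} / \sum_{j=1}^{k+1} e^{v_j}}{1 - p_{k+1}} && \text{given output } v \text{, softmax function} \\
&= \frac{e^{v_i}}{\sum_{j=1}^{k+1} e^{v_j}} \cdot \frac{1}{1 - p_{k+1}} \\
&= \frac{e^{v_i}}{\sum_{j=1}^{k+1} e^{v_j}} \cdot \frac{1}{\sum_{j=1}^{k} p_j} && (1) \\
&= \frac{e^{v_i}}{\sum_{j=1}^{k+1} e^{v_j}} \cdot \frac{1}{\sum_{j=1}^{k} \left( e^{v_j} / \sum_{l=1}^{k+1} e^{v_l} \right)} \\
&= \frac{e^{v_i}}{\sum_{j=1}^{k+1} e^{v_j}} \cdot \frac{\sum_{l=1}^{k+1} e^{v_l}}{\sum_{j=1}^{k} e^{v_j}} \\
&= \frac{e^{v_i}}{\sum_{j=1}^{k} e^{v_j}} \\
&= p^{k}_i && \text{by softmax definition, over the first } k \text{ outputs}
\end{aligned}
$$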

This means that

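$$h_c = -\sum_{i=1}^{k} t_i \log \frac{p_i}{1 - p_{k+1}} = -\sum_{i=1}^{k} t_i \log p^{k}_i = \text{cross\_entropy}(v_{1:k}, t) \quad \text{for outputs, targets } v, t$$

Here $p^{k}$ is the softmax over only the first $k$ outputs.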

Pretty cool stuff, but definitely deserved a comment!
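For illustration, here is a sketch of what a commented version of the per-sample loss could look like. It is written in plain Python, not taken from the authors' repository; `abstention_loss`, `paper_loss`, `log_softmax`, and the example values are hypothetical names and numbers, and the target is assumed to be a single class index.

```python
import math

def log_softmax(values):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(values)
    log_total = peak + math.log(sum(math.exp(v - peak) for v in values))
    return [v - log_total for v in values]

def abstention_loss(logits, target, alpha):
    """Per-sample loss in the shape of the quoted code (hypothetical helper)."""
    # Probability of the abstention class, taken over all k + 1 outputs.
    p_abstain = math.exp(log_softmax(logits)[-1])
    # Cross-entropy over the first k logits only. This equals
    # -sum_i t_i * log(p_i / (1 - p_{k+1})), because dividing by
    # 1 - p_{k+1} renormalizes the first k probabilities.
    h_c = -log_softmax(logits[:-1])[target]
    # -alpha * log(1 - p) is the same as alpha * log(1 / (1 - p)).
    return (1.0 - p_abstain) * h_c - alpha * math.log(1.0 - p_abstain)

def paper_loss(logits, target, alpha):
    """The same loss written term by term, following the formula."""
    log_p = log_softmax(logits)
    p_abstain = math.exp(log_p[-1])
    ratio = math.exp(log_p[target]) / (1.0 - p_abstain)
    return ((1.0 - p_abstain) * (-math.log(ratio))
            + alpha * math.log(1.0 / (1.0 - p_abstain)))

logits = [0.3, 2.1, -0.7, 0.9]  # made-up values: k = 3 classes + abstain
print(abstention_loss(logits, target=1, alpha=0.5))
print(paper_loss(logits, target=1, alpha=0.5))
```

The two functions should print the same value up to floating-point rounding. In newer PyTorch versions the `reduce=False` argument in the quoted call is spelled `reduction='none'`.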

The tips

As mentioned above, it’s been a few years since I tried a proof (and to be honest, I don’t think I ever successfully worked with this many series; apologies to my Calc II teacher). Here are some tips for future me the next time I try this. Maybe someone else can find them useful, too.

  • Demonstrate, then prove. When I originally saw h_c, I assumed it was a typo or a mistake in the calculation. It wasn’t until I’d shown for myself that the trick works in practice that I could approach demonstrating why it works.
  • A little understanding goes a long way. I spent many hours on this little proof. If I had to give a “eureka” moment, it’d definitely be understanding what $1 - p_{k+1}$ represents (see (1) in the proof). It wasn’t until I asked myself out loud “Well, what is this?” and answered with “The remaining probability mass, i.e. the sum of the first k probabilities” that I saw a path through the proof.
  • Don’t stress subscripts at first, just focus on what’s a scalar and what’s a vector. This might be a little too specific, but I found labeling the dimension of each term immensely helpful. My first pass through this proof eschewed subscripts entirely. Instead, I just noted whether each value represented a scalar or a vector. I haven’t used many vector-valued functions in proofs, and this explicit labeling approach helped me a lot more than pining for intuition. (Note: This technique was especially helpful, I think, because $p_i$ is overloaded in the paper; see my introduction of $p^{k}$ in the proof.)