It would be great if entmax worked with `torch.float16` and `torch.bfloat16`. Unfortunately, it currently does not. There are bugs in both the bisection-based and the exact algorithms. Here I'll document a numerical stability problem that affects the bisection-based algorithm under both `torch.float16` and `torch.bfloat16` (don't believe the propaganda that says bf16 is a drop-in replacement for float32).
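To ground the discussion, here is a toy sketch of the bisection idea for alpha = 1.5 — my own illustration, not the library's implementation. It solves for the threshold tau by bisecting on the total probability mass:

```python
import torch

def entmax15_bisect_sketch(x, n_iter=50):
    # Toy bisection for entmax with alpha = 1.5 (not entmax's actual code):
    # find tau such that sum_i max(x_i / 2 - tau, 0)^2 == 1,
    # then return p_i = max(x_i / 2 - tau, 0)^2.
    x = x / 2
    lo = x.max() - 1.0  # at this tau the max element alone contributes mass 1
    hi = x.max()        # at this tau the total mass is 0
    for _ in range(n_iter):
        tau = (lo + hi) / 2
        mass = torch.clamp(x - tau, min=0).pow(2).sum()
        if mass >= 1:   # tau too small -> too much mass -> move lo up
            lo = tau
        else:           # tau too large -> too little mass -> move hi down
            hi = tau
    return torch.clamp(x - (lo + hi) / 2, min=0).pow(2)
```

The relevant detail for what follows is that the mapping squares shifted scores, which matters once the inputs are large in magnitude.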
Let's say you have a 32-bit vector of logits whose largest element is sufficiently negative.
```python
import torch
import entmax

a = torch.zeros(128, device="cuda").fill_(-5)  # torch.float32
a[0] = 0
a -= 1000
```
With `alpha=1.5`, the correct output for this vector is a one-hot distribution peaked on index 0. We get this behavior from both `entmax.entmax15` and `entmax.entmax_bisect`.
```python
p1 = entmax.entmax15(a)
p2 = entmax.entmax_bisect(a, alpha=1.5)
p1[0] == p2[0] == 1  # True
```
Ok, great. But what happens if we use `torch.float16`?
```python
b = a.to(torch.float16)
p3 = entmax.entmax_bisect(b, alpha=1.5)
p3.isnan().all()  # True
```
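One plausible culprit for the float16 case (my guess, not a confirmed diagnosis): the alpha = 1.5 mapping squares shifted scores, and squaring anything much beyond ±256 overflows float16, whose largest finite value is about 65504:

```python
import torch

# float16 tops out at ~65504, so squaring a half-precision value
# of magnitude ~500 already overflows to inf.
x = torch.tensor(-1000.0, dtype=torch.float16)
print((x / 2) ** 2)  # tensor(inf, dtype=torch.float16)
```

Note this cannot be the whole story for bfloat16, which shares float32's exponent range and represents 250000 just fine; there the 8-bit mantissa (and the precision lost when comparing nearly-equal large magnitudes) is the more likely suspect.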
And what about `torch.bfloat16`?
```python
c = a.to(torch.bfloat16)
p4 = entmax.entmax_bisect(c, alpha=1.5)
p4.isnan().all()  # True
```
Well that's not good! (solution after this commercial break)