PyTorch: CrossEntropyLoss, changing class weight does not change the computed loss



























According to the documentation for CrossEntropyLoss, the weighted loss is calculated by multiplying the original loss by the weight of each class.
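For reference, here is my paraphrase of the formula in the PyTorch documentation, ignoring `ignore_index`. Sample $n$ has logits $x_{n,c}$ over $C$ classes, target class $y_n$, and class weights $w$. The per-sample loss is multiplied by the target-class weight. Under the default `reduction='mean'`, the batch of $N$ samples is then divided by the sum of the target-class weights, not by $N$:

$$
\ell_n = -\,w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C} \exp(x_{n,c})},
\qquad
\ell = \sum_{n=1}^{N} \frac{\ell_n}{\sum_{m=1}^{N} w_{y_m}}
$$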



However, in the PyTorch implementation, the class weight seems to have no effect on the final loss value unless it is set to zero. Here is the code:



```python
from torch import nn
import torch

logits = torch.FloatTensor([
    [0.1, 0.9],
])
label = torch.LongTensor([0])

criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item())  # result: 1.1711

# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item())  # result: 1.1711, should be 0.11711

# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item())  # result: 0
```


As the code illustrates, the class weight seems to have no effect unless it is set to 0. This behavior contradicts the documentation.
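A minimal sketch that inspects the same single-sample case under the three `reduction` modes. It uses `torch.tensor` in place of the legacy `torch.FloatTensor` constructor. As I understand the documented behavior, `reduction='mean'` divides the weighted sum by the sum of the target-class weights. With one sample, that works out to (0.1 × 1.1711) / 0.1 = 1.1711, so the weight cancels. The per-sample value from `reduction='none'` is the weighted 0.11711:

```python
import torch
from torch import nn

logits = torch.tensor([[0.1, 0.9]])
label = torch.tensor([0])
weight = torch.tensor([0.1, 1.0])

# Per-sample losses, each already multiplied by its target-class weight
per_sample = nn.CrossEntropyLoss(weight=weight, reduction="none")(logits, label)
# Sum of the weighted losses, with no normalization
summed = nn.CrossEntropyLoss(weight=weight, reduction="sum")(logits, label)
# Default 'mean': weighted sum divided by the sum of the target-class weights
mean = nn.CrossEntropyLoss(weight=weight)(logits, label)

print(per_sample.item())  # ≈ 0.11711
print(summed.item())      # ≈ 0.11711
print(mean.item())        # ≈ 1.1711, the weight cancels for a single sample
```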



Updates

I implemented a version of weighted cross entropy that is, in my eyes, the "correct" way to do it:



```python
import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    assert len(logits.size()) == 2
    batch_size, label_num = logits.size()
    assert batch_size == label.size(0)

    if weight is None:
        weight = torch.ones(label_num).float()

    assert label_num == weight.size(0)

    # -x[target] term of the per-sample cross entropy
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
    # log(sum_j exp(x[j])) term
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))

    # Weight of each sample's target class
    weights = torch.gather(weight, 0, label).float()

    # Average of the weighted per-sample losses over the batch
    return torch.mean((x_terms + log_terms) * weights)

logits = torch.FloatTensor([
    [0.1, 0.9],
    [0.0, 0.1],
])

label = torch.LongTensor([0, 1])

neg_weight = 0.1

weight = torch.FloatTensor([neg_weight, 1])

criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)

print(loss.item())  # results: 0.69227
print(weighted_cross_entropy(logits, label, weight).item())  # results: 0.38075
```


What I did was multiply the loss of each instance in the batch by its associated class weight. The result still differs from the PyTorch implementation, which makes me wonder how PyTorch actually implements this.
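A minimal sketch of the function with the fix suggested in the comments: divide the weighted sum by the sum of the target-class weights instead of by the batch size. As an extra assumption of mine, it uses `torch.logsumexp` instead of `log(sum(exp(...)))` for numerical stability. It makes no attempt to handle the case where all target weights are zero, which yields nan here:

```python
import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    """Weighted cross entropy normalized by the sum of target-class weights."""
    _, label_num = logits.shape
    if weight is None:
        weight = torch.ones(label_num)
    # Per-sample loss: -x[target] + log(sum_j exp(x[j])), computed stably
    per_sample = -logits.gather(1, label.unsqueeze(1)).squeeze(1) + torch.logsumexp(logits, dim=1)
    # Weight of each sample's target class
    w = weight[label]
    # Normalize by the total target weight rather than by the batch size
    return (per_sample * w).sum() / w.sum()

logits = torch.tensor([[0.1, 0.9], [0.0, 0.1]])
label = torch.tensor([0, 1])
weight = torch.tensor([0.1, 1.0])

ours = weighted_cross_entropy(logits, label, weight)
ref = nn.CrossEntropyLoss(weight=weight)(logits, label)
print(ours.item(), ref.item())  # both ≈ 0.6923
```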










machine-learning pytorch






asked Jan 19 at 9:22 by AveryLiu (edited Jan 19 at 11:58)








  • discuss.pytorch.org/t/… Here is an explanation of how it works. Cheers, and sorry for the confusion.

    – Szymon Maszke
    Jan 19 at 12:57

  • It turns out that changing the final line of weighted_cross_entropy to return torch.sum((x_terms+log_terms)*weights)/torch.sum(weights) gives the correct behavior. Thanks for your efforts!

    – AveryLiu
    Jan 19 at 14:00

  • Though in the zero-weight case the denominator would become zero, giving nan.

    – AveryLiu
    Jan 19 at 14:08

  • The implementation probably checks for this edge case and returns zero.

    – Szymon Maszke
    Jan 19 at 14:11

  • Yep, the PyTorch implementation is correct; my remark about nan referred to my own implementation.

    – AveryLiu
    Jan 19 at 14:13
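A small sketch of the zero-weight edge case discussed in the comments. `weighted_mean_loss` is a hypothetical helper, not a PyTorch API. Returning 0 when every target weight is zero is an assumed convention that mirrors the guess above. I have not verified what every PyTorch version returns in this case:

```python
import math
import torch

def weighted_mean_loss(per_sample, w):
    """Hypothetical helper: weighted mean of per-sample losses."""
    total = w.sum()
    if total == 0:
        # Assumed convention (mirrors the guess in the comments): return 0 instead of 0/0
        return per_sample.new_zeros(())
    return (per_sample * w).sum() / total

per_sample = torch.tensor([1.1711])
nonzero = weighted_mean_loss(per_sample, torch.tensor([0.1])).item()  # weight cancels
guarded = weighted_mean_loss(per_sample, torch.tensor([0.0])).item()  # guard kicks in
# Without the guard, 0/0 produces nan
unguarded = ((per_sample * 0.0).sum() / torch.tensor(0.0)).item()
print(nonzero, guarded, math.isnan(unguarded))
```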













