PyTorch: CrossEntropyLoss, changing class weight does not change the computed loss
According to the documentation for cross entropy loss, the weighted loss is calculated by multiplying the original loss by the weight for each class.
However, in the PyTorch implementation, the class weight seems to have no effect on the final loss value unless it is set to zero. The following code demonstrates this:
from torch import nn
import torch
logits = torch.FloatTensor([
    [0.1, 0.9],
])
label = torch.LongTensor([0])
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711
# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, should be 0.11711
# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0
As illustrated in the code, the class weight seems to have no effect unless it is set to 0. This behavior contradicts the documentation.
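A possible explanation, sketched here in plain Python rather than with PyTorch itself: if the averaged loss is normalized by the sum of the weights of the target classes (which is what the comments on this question arrive at) instead of by the batch size, then for a batch of one sample the weight appears in both the numerator and the denominator and cancels. The helper name `cross_entropy` is made up for this illustration.

```python
import math

def cross_entropy(logits, target):
    """Unweighted cross entropy of one sample: -x[target] + log(sum(exp(x)))."""
    return -logits[target] + math.log(sum(math.exp(x) for x in logits))

logits, target = [0.1, 0.9], 0
base = cross_entropy(logits, target)

for w in (1.0, 0.1):
    # Weighted mean over a batch of one: (w * loss) / w, so w cancels out.
    weighted_mean = (w * base) / w
    print(w, round(weighted_mean, 4))
```

Both iterations print 1.1711, matching the first two results above. A weight of exactly 0 is the one case where this cancellation breaks down, because the denominator becomes zero.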
Updates
I implemented a version of weighted cross entropy that is, in my eyes, the "correct" way to do it.
import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    assert len(logits.size()) == 2
    batch_size, label_num = logits.size()
    assert (batch_size == label.size(0))
    if weight is None:
        weight = torch.ones(label_num).float()
    assert (label_num == weight.size(0))
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))
    weights = torch.gather(weight, 0, label).float()
    return torch.mean((x_terms+log_terms)*weights)

logits = torch.FloatTensor([
    [0.1, 0.9],
    [0.0, 0.1],
])
label = torch.LongTensor([0, 1])
neg_weight = 0.1
weight = torch.FloatTensor([neg_weight, 1])
criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)
print(loss.item()) # results: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # results: 0.38075
What I did was multiply the loss of each instance in the batch by its associated class weight. The result is still different from the original PyTorch implementation, which makes me wonder how PyTorch actually implements this.
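One way to see where the two numbers come from is to ask `nn.CrossEntropyLoss` for the unreduced, per-sample losses with `reduction="none"` and then reduce them by hand in both ways. This is only a sketch for the two-sample batch above; the variable names are chosen for illustration.

```python
import torch
from torch import nn

logits = torch.tensor([[0.1, 0.9], [0.0, 0.1]])
label = torch.tensor([0, 1])
weight = torch.tensor([0.1, 1.0])

# Per-sample losses, already multiplied by the weight of each sample's class.
per_sample = nn.CrossEntropyLoss(weight=weight, reduction="none")(logits, label)

# Default reduction ("mean") for comparison.
builtin_mean = nn.CrossEntropyLoss(weight=weight)(logits, label)

# Dividing by the summed per-sample weights reproduces the built-in value,
# while dividing by the batch size gives the plain average.
weighted_mean = per_sample.sum() / weight[label].sum()
plain_mean = per_sample.mean()

print(builtin_mean.item(), weighted_mean.item(), plain_mean.item())
```

Here `weighted_mean` agrees with the built-in result (about 0.6923), and `plain_mean` agrees with the hand-written function (about 0.3808), so the difference lies only in the denominator: the sum of the weights (0.1 + 1 = 1.1) versus the batch size (2).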
machine-learning pytorch
discuss.pytorch.org/t/… Here is an explanation of how it works. Cheers, and sorry for the confusion.
– Szymon Maszke
Jan 19 at 12:57
It turns out that changing the final line of weighted_cross_entropy to
return torch.sum((x_terms+log_terms)*weights)/torch.sum(weights)
gives the correct behavior. Thanks for your efforts!
– AveryLiu
Jan 19 at 14:00
Though in the zero weight case the denominator would become zero, giving nan.
– AveryLiu
Jan 19 at 14:08
The implementation probably checks for this edge case and returns zero.
– Szymon Maszke
Jan 19 at 14:11
Yep, the PyTorch implementation is correct. I commented on my own implementation.
– AveryLiu
Jan 19 at 14:13
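Putting the comments above together, a revised version of the function might look like the sketch below. The normalization by `torch.sum(weights)` is the fix from the comment; the zero-denominator guard is an assumption added for illustration, chosen to mirror the 0 reported in the question, and is not a statement about how PyTorch handles that case internally. `torch.logsumexp` and `squeeze(1)` are substitutions for the original `log(sum(exp(...)))` and `squeeze()` calls.

```python
import torch

def weighted_cross_entropy(logits, label, weight=None):
    """Cross entropy normalized by the summed weights of the target classes."""
    if weight is None:
        weight = torch.ones(logits.size(1))
    # Per-sample loss: -x[label] + log(sum(exp(x))).
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze(1)
    log_terms = torch.logsumexp(logits, dim=1)
    weights = weight[label].float()
    total = torch.sum(weights)
    if total == 0:
        # Assumption: return 0 rather than nan when every sample has weight 0.
        return torch.zeros((), dtype=logits.dtype)
    return torch.sum((x_terms + log_terms) * weights) / total

logits = torch.tensor([[0.1, 0.9], [0.0, 0.1]])
label = torch.tensor([0, 1])
two_samples = weighted_cross_entropy(logits, label, torch.tensor([0.1, 1.0]))
zero_weight = weighted_cross_entropy(logits[:1], label[:1], torch.tensor([0.0, 1.0]))
print(two_samples.item(), zero_weight.item())
```

With the two-sample batch from the question this gives roughly 0.6923, in line with `nn.CrossEntropyLoss`, and the zero-weight case returns 0 by construction.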
asked Jan 19 at 9:22 by AveryLiu; edited Jan 19 at 11:58