"""Sanity-check PyTorch's cross-entropy against a hand-computed value.

The target puts all of its mass on class 1, so the cross-entropy between the
target and the predicted distribution reduces to -log(p[1]). The script prints
that value in nats and in bits, then reproduces it with
``torch.nn.functional.cross_entropy``.
"""
import math

import numpy as np
import torch


def cross_entropy_from_probs(probs: torch.Tensor, target: torch.Tensor) -> float:
    """Return the cross-entropy, in nats, of predicted ``probs`` against ``target``.

    ``torch.nn.functional.cross_entropy`` expects *unnormalized logits* and
    applies ``log_softmax`` to them internally. Passing raw probabilities would
    therefore score ``softmax(probs)`` rather than ``probs`` itself. Passing
    ``log(probs)`` turns that ``log_softmax`` into a no-op, apart from
    renormalizing when ``probs`` does not sum exactly to 1.

    Args:
        probs: Predicted class probabilities, in any input shape accepted by
            ``cross_entropy`` (e.g. ``(C,)`` or ``(N, C)``).
        target: Class-probability targets with the same shape as ``probs``,
            or class indices, as accepted by ``cross_entropy``.

    Returns:
        The mean-reduced loss as a Python float.

    Note:
        A zero entry in ``probs`` becomes ``-inf`` after ``log``. With
        probability targets, this can turn the loss into ``nan``.
    """
    return torch.nn.functional.cross_entropy(probs.log(), target).item()


# Class 1 is the true class; the prediction is close to uniform over 3 classes.
TARGET_PROBS = [0.0, 1.0, 0.0]
PREDICTED_PROBS = [0.333, 0.333, 0.334]

true_class = TARGET_PROBS.index(1.0)
p_true = PREDICTED_PROBS[true_class]

# Hand-computed cross-entropy: with a one-hot target, only -log(p_true) survives.
print("nats: ", -math.log(p_true))
print("bits: ", -math.log(p_true, 2))

target = torch.tensor(TARGET_PROBS)
# Not named `input`, to avoid shadowing the builtin.
probs = torch.tensor(PREDICTED_PROBS)
b = cross_entropy_from_probs(probs, target)
print(b)
print(b / math.log(2))  # nats -> bits
|