playground/torchtut/rick.py

13 lines
283 B
Python
Raw Normal View History

2024-09-21 21:50:07 -04:00
import torch
import numpy as np
import math
# Compare a hand-computed cross-entropy (in nats and bits) against PyTorch's
# F.cross_entropy for a single unbatched example with a one-hot target.
#
# F.cross_entropy expects *unnormalized logits* and applies log_softmax
# internally, so predicted probabilities must be passed through log() first.
# Because the probabilities below sum to 1, log_softmax(log(p)) == log(p), so
# the torch loss equals -log(p[true_class]) exactly. Passing the raw
# probabilities as logits would instead yield ~ln(3) for this near-uniform
# input.
target = torch.tensor([0, 1.0, 0])  # one-hot class-probability target; class 1 is true
probs = torch.tensor([.333, .333, .334])  # predicted class probabilities (sum to 1)

# Derive the true-class probability from the tensors instead of repeating
# the literal, so the manual and torch computations stay in sync.
true_class = int(target.argmax())
p_true = probs[true_class].item()
print("nats: ", -math.log(p_true))
print("bits: ", -math.log2(p_true))

loss_nats = torch.nn.functional.cross_entropy(probs.log(), target).item()
loss_bits = loss_nats / math.log(2)  # convert nats -> bits
print("torch nats: ", loss_nats)
print("torch bits: ", loss_bits)