mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 09:33:03 +08:00
One hot in tensor.py (#3093)
* onehot in Tensor.py * one_hot tests * works for all shapes, not just 1 * pylint * not a static method * moved around, num_classes mandatory * pylint * pylint * space & moving * formatting * moved tests
This commit is contained in:
@@ -1488,6 +1488,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
|
||||
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
|
||||
|
||||
def test_one_hot(self):
|
||||
data = [1, 2, 4]
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
|
||||
data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -893,6 +893,8 @@ class Tensor:
|
||||
if not Tensor.training or p == 0: return self
|
||||
return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
|
||||
|
||||
def one_hot(self, num_classes:int, **kwargs) -> Tensor: return Tensor.where(self[..., None] == Tensor.arange(num_classes), 1, 0, **kwargs)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
|
||||
# NOTE: it works if key, value have symbolic shape
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
|
||||
Reference in New Issue
Block a user