mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
no contigs there
This commit is contained in:
@@ -4,10 +4,8 @@ from tinygrad.nn.state import tar_extract
|
||||
def mnist(device=None, fashion=False):
|
||||
base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
|
||||
def _mnist(file): return Tensor.from_url(base_url+file, gunzip=True)
|
||||
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device).contiguous(), \
|
||||
_mnist("train-labels-idx1-ubyte.gz")[8:].to(device).contiguous(), \
|
||||
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device).contiguous(), \
|
||||
_mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device).contiguous()
|
||||
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
|
||||
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
|
||||
|
||||
def cifar(device=None):
|
||||
tt = tar_extract(Tensor.from_url('https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz', gunzip=True))
|
||||
|
||||
@@ -27,8 +27,8 @@ class Optimizer:
|
||||
|
||||
def _new_optim_param(self) -> list[Tensor]:
|
||||
param_dtype = to_dtype(getenv("OPTIM_DTYPE", "float32"))
|
||||
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False)]
|
||||
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False) for t in self.params]
|
||||
if self.fused: return [Tensor.zeros(self.pos_params[-1], dtype=param_dtype, device=self.device, requires_grad=False).contiguous()]
|
||||
return [Tensor.zeros_like(t, dtype=param_dtype, requires_grad=False).contiguous() for t in self.params]
|
||||
|
||||
def zero_grad(self):
|
||||
"""
|
||||
@@ -154,7 +154,7 @@ class LAMB(Optimizer):
|
||||
def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, fused=FUSE_OPTIM):
|
||||
super().__init__(params, lr, fused)
|
||||
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, weight_decay, adam
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False) for _ in [b1, b2])
|
||||
self.b1_t, self.b2_t = (Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous() for _ in [b1, b2])
|
||||
self.m = self._new_optim_param()
|
||||
self.v = self._new_optim_param()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user