mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
test const pattern [pr] (#8304)
* test const pattern [pr] * add model to test_tiny
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# basic self-contained tests of the external functionality of tinygrad
|
||||
import unittest, random
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device, nn
|
||||
from tinygrad.helpers import IMAGE
|
||||
|
||||
class TestTiny(unittest.TestCase):
|
||||
@@ -79,6 +79,25 @@ class TestTiny(unittest.TestCase):
|
||||
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
|
||||
self.assertEqual(ret.item(), s)
|
||||
|
||||
# *** a model ***
|
||||
|
||||
def test_mnist_model(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
# pre-realize random weights
|
||||
for p in nn.state.get_parameters(layers): p.realize()
|
||||
|
||||
# run model inference
|
||||
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
|
||||
self.assertEqual(len(probs[0]), 10)
|
||||
|
||||
# *** image ***
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
|
||||
|
||||
Reference in New Issue
Block a user