diff --git a/test/models/test_train.py b/test/models/test_train.py index 6020a8e777..605e6f6de1 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -40,6 +40,7 @@ class TestTrain(unittest.TestCase): check_gc() @unittest.skipIf(CI, "slow") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_efficientnet(self): model = EfficientNet(0) X = np.zeros((BS,3,224,224), dtype=np.float32) @@ -56,6 +57,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_transformer(self): # this should be small GPT-2, but the param count is wrong # (real ff_dim is 768*4)