hotfix: cast mnist to float

This commit is contained in:
George Hotz
2024-04-09 19:30:03 -07:00
parent fea774f669
commit 216eb235e5

View File

@@ -21,6 +21,9 @@ class Model:
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist()
# TODO: remove this when HIP is fixed
X_train, X_test = X_train.float(), X_test.float()
model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model))