mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
gpu relu is good
This commit is contained in:
@@ -51,6 +51,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu)
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu)
|
||||
def test_relu(self):
|
||||
helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu)
|
||||
|
||||
|
||||
@@ -180,12 +180,14 @@ register('matmul', Dot, gpu=True)
|
||||
|
||||
class ReLU(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return unary_op(ctx, 'res_g[gid] = min(a_g[gid], (float)0.);', x)
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return unary_op(ctx, 'res_g[gid] = max(a_g[gid], (float)0.);', input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
input, = ctx.saved_tensors
|
||||
return binary_op(ctx, 'res_g[gid] = a_g[gid] * (b_g[gid] >= 0);', grad_output, input)
|
||||
register('relu', ReLU, gpu=True)
|
||||
|
||||
class LogSoftmax(Function):
|
||||
|
||||
Reference in New Issue
Block a user