mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
touchup function.py [pr] (#8220)
* touchup function.py [pr] * remove ALLOWED_READ_IMAGE * eh, keep it, just change it
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -295,7 +295,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model kernel count and gate usage
|
||||
run: |
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2131 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot alt model correctness (float32)
|
||||
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
|
||||
|
||||
@@ -41,7 +41,7 @@ class Sin(Function):
|
||||
|
||||
class Relu(Function):
|
||||
def forward(self, x:UOp) -> UOp:
|
||||
self.ret = x.maximum(0)
|
||||
self.ret = (x>0).where(x, 0)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output
|
||||
@@ -79,7 +79,8 @@ class Sigmoid(Function):
|
||||
return (self.ret * (1 - self.ret)) * grad_output
|
||||
|
||||
class Sign(Function):
|
||||
def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0))
|
||||
# NOTE: the x*0 is to match torch behavior without function.py
|
||||
def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
|
||||
# backward always return 0 to match torch
|
||||
def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0)
|
||||
|
||||
|
||||
@@ -62,7 +62,8 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
for u in x.toposort:
|
||||
if u.op is Ops.CONST: continue
|
||||
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
argst = ("\n".join([f"{v.shape} / {v.strides}"+(f" / {v.offset}" if v.offset else "") for v in u.arg.views])) if u.op is Ops.VIEW else str(u.arg)
|
||||
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
for idx,x in enumerate(u.src):
|
||||
if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}"
|
||||
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
|
||||
|
||||
Reference in New Issue
Block a user