mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
use more _uop for cleanup [PR] (#16531)
`t.uop if isinstance(t, Tensor) else t` -> `t._uop`
This commit is contained in:
@@ -40,7 +40,7 @@ class _function(Generic[ReturnType]):
|
||||
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
|
||||
|
||||
# deduplicate input_uops, keeping the first occurrence index for each unique uop
|
||||
call_uops: list[UOp] = dedup([u for t in params if (u:=(t.uop if isinstance(t, Tensor) else t)).device is not None])
|
||||
call_uops: list[UOp] = dedup([u for t in params if (u:=t._uop).device is not None])
|
||||
|
||||
# disable realize/schedule while this is running
|
||||
# run it and do surgery later
|
||||
|
||||
@@ -209,7 +209,7 @@ class Tensor(OpMixin):
|
||||
param = UOp.param(slot, self.dtype, self.shape, self.device)
|
||||
return Tensor(param)
|
||||
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
|
||||
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
||||
fret = fxn._uop.call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
|
||||
return Tensor(fret.gettuple(0))
|
||||
|
||||
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user