use more _uop for cleanup [PR] (#16531)

`t.uop if isinstance(t, Tensor) else t` -> `t._uop`
This commit is contained in:
chenyu
2026-06-07 17:41:36 -04:00
committed by GitHub
parent 937aeaec60
commit 03943cd1a0
2 changed files with 2 additions and 2 deletions

View File

@@ -40,7 +40,7 @@ class _function(Generic[ReturnType]):
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
# deduplicate input_uops, keeping the first occurrence index for each unique uop
call_uops: list[UOp] = dedup([u for t in params if (u:=(t.uop if isinstance(t, Tensor) else t)).device is not None])
call_uops: list[UOp] = dedup([u for t in params if (u:=t._uop).device is not None])
# disable realize/schedule while this is running
# run it and do surgery later

View File

@@ -209,7 +209,7 @@ class Tensor(OpMixin):
param = UOp.param(slot, self.dtype, self.shape, self.device)
return Tensor(param)
def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
fret = fxn._uop.call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
return Tensor(fret.gettuple(0))
def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: