diff --git a/scripts/attention.py b/scripts/attention.py index 520458d..9c1101c 100644 --- a/scripts/attention.py +++ b/scripts/attention.py @@ -61,7 +61,7 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t global pmaskshw,pmasks - if inhr and not hiresfinished: hiresscaler(height,width,attn) + if inhr and not hiresfinished: hiresscaler(height,width,attn,h) if userpp and step > 0: for b in range(attn.shape[0] // 8):