From 1c163f6f23a6af6da6948d8470a27fa4dbe984bf Mon Sep 17 00:00:00 2001 From: Khyber Sen Date: Sun, 23 Jun 2024 18:44:27 -0700 Subject: [PATCH] `fn decode_coefs`: Translate `decode_coefs_class!` macro as a macro. Lots of borrows and captures make it difficult to make this a closure or `fn`, so it's easiest to translate this as a macro. --- src/recon.rs | 388 +++++---------------------------------------------- 1 file changed, 38 insertions(+), 350 deletions(-) diff --git a/src/recon.rs b/src/recon.rs index cf21862f8..a223bf01a 100644 --- a/src/recon.rs +++ b/src/recon.rs @@ -736,30 +736,19 @@ fn decode_coefs( let mut mag: c_uint = 0; let mut scan: &[u16] = &[]; - match tx_class { - TxClass::TwoD => { - let nonsquare_tx: c_uint = (tx >= TxfmSize::R4x8) as c_uint; - let lo_ctx_offsets = Some( - &dav1d_lo_ctx_offsets - [nonsquare_tx.wrapping_add(tx as c_uint & nonsquare_tx) as usize], - ); - scan = dav1d_scans[tx as usize]; - let stride = 4 * sh; - let shift: c_uint = if t_dim.lh < 4 { - t_dim.lh as c_uint + 2 - } else { - 5 - }; - let shift2: c_uint = 0; - let mask: c_uint = 4 * sh as c_uint - 1; - // Optimizes better than `.fill(0)`, - // which doesn't elide the bounds check, inline, or vectorize. - for i in 0..stride as usize * (4 * sw as usize + 2) { - levels[i] = 0; - } + + macro_rules! 
decode_coefs_class { + ($tx_class:expr, $lo_ctx_offsets:expr, $stride:expr, $shift:expr, $shift2:expr, $mask:expr) => {{ + const TX_CLASS: TxClass = $tx_class; + let lo_ctx_offsets: Option<&[[u8; 5]; 5]> = $lo_ctx_offsets; + let stride: u8 = $stride; + let shift: u32 = $shift; + let shift2: u32 = $shift2; + let mask: u32 = $mask; + let mut x: c_uint; let mut y: c_uint; - match tx_class { + match TX_CLASS { TxClass::TwoD => { rc = scan[eob as usize] as c_uint; x = rc >> shift; @@ -926,6 +915,31 @@ fn decode_coefs( ); } } + }}; + } + + match tx_class { + TxClass::TwoD => { + let nonsquare_tx: c_uint = (tx >= TxfmSize::R4x8) as c_uint; + let lo_ctx_offsets = Some( + &dav1d_lo_ctx_offsets + [nonsquare_tx.wrapping_add(tx as c_uint & nonsquare_tx) as usize], + ); + scan = dav1d_scans[tx as usize]; + let stride = 4 * sh; + let shift: c_uint = if t_dim.lh < 4 { + t_dim.lh as c_uint + 2 + } else { + 5 + }; + let shift2: c_uint = 0; + let mask: c_uint = 4 * sh as c_uint - 1; + // Optimizes better than `.fill(0)`, + // which doesn't elide the bounds check, inline, or vectorize. 
+ for i in 0..stride as usize * (4 * sw as usize + 2) { + levels[i] = 0; + } + decode_coefs_class!(TxClass::TwoD, lo_ctx_offsets, stride, shift, shift2, mask); } TxClass::H => { let lo_ctx_offsets = None; @@ -938,170 +952,7 @@ fn decode_coefs( for i in 0..stride as usize * (4 * sh as usize + 2) { levels[i] = 0; } - let mut x: c_uint; - let mut y: c_uint; - match tx_class { - TxClass::TwoD => { - rc = scan[eob as usize] as c_uint; - x = rc >> shift; - y = rc & mask; - } - TxClass::H => { - x = eob as c_uint & mask; - y = (eob >> shift) as c_uint; - rc = eob as c_uint; - } - TxClass::V => { - x = eob as c_uint & mask; - y = (eob >> shift) as c_uint; - rc = x << shift2 | y; - } - } - if dbg { - println!( - "Post-lo_tok[{}][{}][{}][{}={}={}]: r={}", - t_dim.ctx, chroma, ctx, eob, rc, tok, ts_c.msac.rng, - ); - } - if eob_tok == 2 { - ctx = if if tx_class == TxClass::TwoD { - (x | y) > 1 - } else { - y != 0 - } { - 14 - } else { - 7 - }; - tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_int; - level_tok = tok + (3 << 6); - if dbg { - println!( - "Post-hi_tok[{}][{}][{}][{}={}={}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - ctx, - eob, - rc, - tok, - ts_c.msac.rng, - ); - } - } - cf.set::(f, t_cf, rc as usize, (tok << 11).as_::()); - levels[x as usize * stride as usize + y as usize] = level_tok as u8; - let mut i = eob - 1; - while i > 0 { - let rc_i: c_uint; - match tx_class { - TxClass::TwoD => { - rc_i = scan[i as usize] as c_uint; - x = rc_i >> shift; - y = rc_i & mask; - } - TxClass::H => { - x = i as c_uint & mask; - y = (i >> shift) as c_uint; - rc_i = i as c_uint; - } - TxClass::V => { - x = i as c_uint & mask; - y = (i >> shift) as c_uint; - rc_i = x << shift2 | y; - } - } - assert!(x < 32 && y < 32); - let level = &mut levels[x as usize * stride as usize + y as usize..]; - ctx = get_lo_ctx(level, tx_class, &mut mag, lo_ctx_offsets, x, y, stride); - if tx_class == TxClass::TwoD { - y |= x; - } - tok = 
rav1d_msac_decode_symbol_adapt4( - &mut ts_c.msac, - &mut lo_cdf[ctx as usize], - 3, - ) as c_int; - if dbg { - println!( - "Post-lo_tok[{}][{}][{}][{}={}={}]: r={}", - t_dim.ctx, chroma, ctx, i, rc_i, tok, ts_c.msac.rng, - ); - } - if tok == 3 { - mag &= 63; - ctx = if y > (tx_class == TxClass::TwoD) as c_uint { - 14 - } else { - 7 - } + if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 }; - tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_int; - if dbg { - println!( - "Post-hi_tok[{}][{}][{}][{}={}={}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - ctx, - i, - rc_i, - tok, - ts_c.msac.rng, - ); - } - level[0] = (tok + (3 << 6)) as u8; - cf.set::( - f, - t_cf, - rc_i as usize, - ((tok << 11) as c_uint | rc).as_::(), - ); - rc = rc_i; - } else { - tok *= 0x17ff41; - level[0] = tok as u8; - tok = ((tok as c_uint >> 9) & rc.wrapping_add(!(0x7ff as c_uint))) as c_int; - if tok != 0 { - rc = rc_i; - } - cf.set::(f, t_cf, rc_i as usize, tok.as_::()); - } - i -= 1; - } - ctx = if tx_class == TxClass::TwoD { - 0 - } else { - get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride) - }; - dc_tok = - rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut lo_cdf[ctx as usize], 3) - as c_uint; - if dbg { - println!( - "Post-dc_lo_tok[{}][{}][{}][{}]: r={}", - t_dim.ctx, chroma, ctx, dc_tok, ts_c.msac.rng, - ); - } - if dc_tok == 3 { - if tx_class == TxClass::TwoD { - mag = levels[0 * stride as usize + 1] as c_uint - + levels[1 * stride as usize + 0] as c_uint - + levels[1 * stride as usize + 1] as c_uint; - } - mag &= 63; - ctx = if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 }; - dc_tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_uint; - if dbg { - println!( - "Post-dc_hi_tok[{}][{}][0][{}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - dc_tok, - ts_c.msac.rng, - ); - } - } + decode_coefs_class!(TxClass::H, lo_ctx_offsets, stride, shift, shift2, mask); } TxClass::V => { let lo_ctx_offsets = None; @@ 
-1114,170 +965,7 @@ fn decode_coefs( for i in 0..stride as usize * (4 * sw as usize + 2) { levels[i] = 0; } - let mut x: c_uint; - let mut y: c_uint; - match tx_class { - TxClass::TwoD => { - rc = scan[eob as usize] as c_uint; - x = rc >> shift; - y = rc & mask; - } - TxClass::H => { - x = eob as c_uint & mask; - y = (eob >> shift) as c_uint; - rc = eob as c_uint; - } - TxClass::V => { - x = eob as c_uint & mask; - y = (eob >> shift) as c_uint; - rc = x << shift2 | y; - } - } - if dbg { - println!( - "Post-lo_tok[{}][{}][{}][{}={}={}]: r={}", - t_dim.ctx, chroma, ctx, eob, rc, tok, ts_c.msac.rng, - ); - } - if eob_tok == 2 { - ctx = if if tx_class == TxClass::TwoD { - (x | y) > 1 - } else { - y != 0 - } { - 14 - } else { - 7 - }; - tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_int; - level_tok = tok + (3 << 6); - if dbg { - println!( - "Post-hi_tok[{}][{}][{}][{}={}={}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - ctx, - eob, - rc, - tok, - ts_c.msac.rng, - ); - } - } - cf.set::(f, t_cf, rc as usize, (tok << 11).as_::()); - levels[x as usize * stride as usize + y as usize] = level_tok as u8; - let mut i = eob - 1; - while i > 0 { - let rc_i: c_uint; - match tx_class { - TxClass::TwoD => { - rc_i = scan[i as usize] as c_uint; - x = rc_i >> shift; - y = rc_i & mask; - } - TxClass::H => { - x = i as c_uint & mask; - y = (i >> shift) as c_uint; - rc_i = i as c_uint; - } - TxClass::V => { - x = i as c_uint & mask; - y = (i >> shift) as c_uint; - rc_i = x << shift2 | y; - } - } - assert!(x < 32 && y < 32); - let level = &mut levels[x as usize * stride as usize + y as usize..]; - ctx = get_lo_ctx(level, tx_class, &mut mag, lo_ctx_offsets, x, y, stride); - if tx_class == TxClass::TwoD { - y |= x; - } - tok = rav1d_msac_decode_symbol_adapt4( - &mut ts_c.msac, - &mut lo_cdf[ctx as usize], - 3, - ) as c_int; - if dbg { - println!( - "Post-lo_tok[{}][{}][{}][{}={}={}]: r={}", - t_dim.ctx, chroma, ctx, i, rc_i, tok, ts_c.msac.rng, - ); - } - 
if tok == 3 { - mag &= 63; - ctx = if y > (tx_class == TxClass::TwoD) as c_uint { - 14 - } else { - 7 - } + if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 }; - tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_int; - if dbg { - println!( - "Post-hi_tok[{}][{}][{}][{}={}={}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - ctx, - i, - rc_i, - tok, - ts_c.msac.rng, - ); - } - level[0] = (tok + (3 << 6)) as u8; - cf.set::( - f, - t_cf, - rc_i as usize, - ((tok << 11) as c_uint | rc).as_::(), - ); - rc = rc_i; - } else { - tok *= 0x17ff41; - level[0] = tok as u8; - tok = ((tok as c_uint >> 9) & rc.wrapping_add(!(0x7ff as c_uint))) as c_int; - if tok != 0 { - rc = rc_i; - } - cf.set::(f, t_cf, rc_i as usize, tok.as_::()); - } - i -= 1; - } - ctx = if tx_class == TxClass::TwoD { - 0 - } else { - get_lo_ctx(levels, tx_class, &mut mag, lo_ctx_offsets, 0, 0, stride) - }; - dc_tok = - rav1d_msac_decode_symbol_adapt4(&mut ts_c.msac, &mut lo_cdf[ctx as usize], 3) - as c_uint; - if dbg { - println!( - "Post-dc_lo_tok[{}][{}][{}][{}]: r={}", - t_dim.ctx, chroma, ctx, dc_tok, ts_c.msac.rng, - ); - } - if dc_tok == 3 { - if tx_class == TxClass::TwoD { - mag = levels[0 * stride as usize + 1] as c_uint - + levels[1 * stride as usize + 0] as c_uint - + levels[1 * stride as usize + 1] as c_uint; - } - mag &= 63; - ctx = if mag > 12 { 6 } else { (mag as u8 + 1) >> 1 }; - dc_tok = rav1d_msac_decode_hi_tok(&mut ts_c.msac, &mut hi_cdf[ctx as usize]) - as c_uint; - if dbg { - println!( - "Post-dc_hi_tok[{}][{}][0][{}]: r={}", - cmp::min(t_dim.ctx, 3), - chroma, - dc_tok, - ts_c.msac.rng, - ); - } - } + decode_coefs_class!(TxClass::V, lo_ctx_offsets, stride, shift, shift2, mask); } } } else {