Skip to content

Commit

Permalink
Dropped turbo option because of latency issues;
Browse files Browse the repository at this point in the history
  • Loading branch information
droodman committed Mar 4, 2022
1 parent 1f9b3a3 commit bf6a0c7
Show file tree
Hide file tree
Showing 13 changed files with 441 additions and 539 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ compilation tests.jl
Stata examples.do

src/investigate turbo.jl
src/debug.jl
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.7.0"
julia_version = "1.7.2"
manifest_format = "2.0"

[[deps.Adapt]]
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ SortingAlgorithms = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
[compat]
Distributions = "0.25"
SortingAlgorithms = "1.0"
julia = "1.6"
julia = "1.7"
LoopVectorization = "0.12.99"
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ The interface is low-level: the exported function `wildboottest()` accepts scala
`wildboottest()` accepts many optional arguments. Most correspond to options of the Stata package `boottest`, which are documented in [Roodman et al. (2019), §7](https://www.econ.queensu.ca/sites/econ.queensu.ca/files/qed_wp_1406.pdf#page=28). Julia-specific additions include an optional first argument `T`, which can be `Float32` or `Float64` to specify the precision of computation; and `rng`, which takes a random number generator such as `MersenneTwister(2302394)`.

## On latency
The first time you run `wildboottest()` in a session, Julia's just-in-time compilation will take ~10 seconds. The same will happen the first time you switch between turbo and non-turbo modes or between Float32 and Float64 calculations, or between OLS and IV/2SLS estimation. (Non-turbo and Float32 are defaults.)
The first time you run `wildboottest()` in a session, Julia's just-in-time compilation will take ~10 seconds. The same will happen the first time you switch between Float32 and Float64 calculations, or between OLS and IV/2SLS estimation.
11 changes: 7 additions & 4 deletions src/StrBootTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ mutable struct StrBootTest{T<:AbstractFloat}
peak::NamedTuple{(:X, :p), Tuple{Vector{T}, T}}

Nobs::Int64; NClustVar::Int8; kX₁::Int64; kX₂::Int64; kY₂::Int64; WREnonARubin::Bool; boottest!::Function
turbo::Bool
coldotplus!::Function; colquadformminus!::Function; matmulplus!::Function; panelsum!::Function

sqrt::Bool; _Nobs::T; kZ::Int64; sumwt::T; haswt::Bool; sqrtwt::Vector{T}; REst::Bool; multiplier::T; smallsample::T
dof::Int64; dof_r::T; p::T; BootClust::Int8
Expand Down Expand Up @@ -125,7 +125,7 @@ mutable struct StrBootTest{T<:AbstractFloat}

StrBootTest{T}(R, r, R₁, r₁, y₁, X₁, Y₂, X₂, wt, fweights, LIML,
Fuller, κ, ARubin, B, auxtwtype, rng, maxmatsize, ptype, null, scorebs, bootstrapt, ID, NBootClustVar, NErrClustVar, issorted, robust, small, FEID, FEdfadj, level, rtol, madjtype, NH₀, ML,
β̈, A, sc, willplot, gridmin, gridmax, gridpoints, turbo) where T<:Real =
β̈, A, sc, willplot, gridmin, gridmax, gridpoints) where T<:Real =
begin
kX₂ = ncols(X₂)
scorebs = scorebs || iszero(B) || ML
Expand All @@ -143,7 +143,10 @@ mutable struct StrBootTest{T<:AbstractFloat}
Matrix{T}(undef,0,0),
(X = Vector{T}(undef,0), p = T(NaN)),
nrows(X₁), ncols(ID), ncols(X₁), kX₂, ncols(Y₂), WREnonARubin, WREnonARubin ? boottestWRE! : boottestOLSARubin!,
turbo)
coldotplus_nonturbo!,
colquadformminus_nonturbo!,
matmulplus_nonturbo!,
panelsum_nonturbo!)
end
end

Expand All @@ -154,7 +157,7 @@ function getdist(o::StrBootTest, diststat::DistStatType=nodist)
sort!(o.distCDR)
elseif nrows(o.distCDR)==0 # return test stats
if length(o.dist) > 1
o.distCDR = (@view o.dist[2:end])' * o.multiplier
o.distCDR = (@view o.dist[1,2:end])' * o.multiplier
sort!(o.distCDR, dims=1)
else
o.distCDR = zeros(0,1)
Expand Down
12 changes: 6 additions & 6 deletions src/WRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,24 +137,24 @@ function InitWRE!(o::StrBootTest{T}) where T

o.r₁S✻ReplZR₁Y₂ = o.r₁' * (o.Repl.S✻ZR₁Y₂ - _S✻ZperpReplZR₁' * o.Repl.invZperpZperpZperpY₂ - o.Repl.invZperpZperpZperpZR₁' * o.S✻ZperpY₂)
o.r₁S✻ReplZR₁X = o.r₁' * (_S✻⋂XReplZR₁' - _S✻ZperpReplZR₁' * o.Repl.invZperpZperpZperpX - o.Repl.invZperpZperpZperpZR₁' * o.S✻ZperpX )
o.r₁S✻ReplZR₁DGPZ = o.r₁' * panelcross11(o, o.Repl.ZR₁, o.DGP.Z, o.info✻)
o.r₁S✻ReplZR₁DGPZ = o.r₁' * panelcross(o.Repl.ZR₁, o.DGP.Z, o.info✻)
o.r₁S✻ReplZR₁y₁ = o.r₁' * (o.Repl.S✻ZR₁y₁ - _S✻ZperpReplZR₁' * o.Repl.invZperpZperpZperpy₁ - o.Repl.invZperpZperpZperpZR₁' * o.S✻Zperpy₁)
o.DGP.restricted &&
(o.r₁S✻ReplZR₁DGPZR₁ = o.r₁' * panelcross11(o, o.Repl.ZR₁, o.DGP.ZR₁, o.info✻))
(o.r₁S✻ReplZR₁DGPZR₁ = o.r₁' * panelcross(o.Repl.ZR₁, o.DGP.ZR₁, o.info✻))
end

_S✻ZperpReplZpar = @panelsum(o, o.Repl.S✻⋂ZperpZpar, o.info✻_✻⋂)
_S✻ReplXZ = @panelsum(o, o.Repl.S✻⋂XZpar , o.info✻_✻⋂)

o.S✻ReplZY₂ = o.Repl.S✻ZparY₂ - _S✻ZperpReplZpar' * o.Repl.invZperpZperpZperpY₂ - o.Repl.invZperpZperpZperpZpar' * o.S✻ZperpY₂
o.S✻ReplZX = _S✻ReplXZ' - _S✻ZperpReplZpar' * o.Repl.invZperpZperpZperpX - o.Repl.invZperpZperpZperpZpar' * o.S✻ZperpX
o.S✻ReplZDGPZ = panelcross11(o, o.Repl.Z, o.DGP.Z, o.info✻)
o.S✻ReplZDGPZ = panelcross(o.Repl.Z, o.DGP.Z, o.info✻)
o.S✻ReplZy₁ = o.Repl.S✻Zpary₁ - _S✻ZperpReplZpar' * o.Repl.invZperpZperpZperpy₁ - o.Repl.invZperpZperpZperpZpar' * o.S✻Zperpy₁

if o.DGP.restricted
_S✻⋂XDGPZR₁ = @panelsum(o, o.DGP.S✻⋂XZR₁, o.info✻_✻⋂)

o.S✻ReplZDGPZR₁ = panelcross11(o, o.Repl.Z, o.DGP.ZR₁, o.info✻)
o.S✻ReplZDGPZR₁ = panelcross(o.Repl.Z, o.DGP.ZR₁, o.info✻)
o.S✻DGPZR₁Y₂ = o.DGP.S✻ZR₁Y₂ - _S✻ZperpDGPZR₁' * o.Repl.invZperpZperpZperpY₂ - o.DGP.invZperpZperpZperpZR₁' * o.S✻ZperpY₂
o.S✻DGPZR₁DGPZR₁ = o.DGP.S✻ZR₁ZR₁ - _S✻ZperpDGPZR₁' * o.DGP.invZperpZperpZperpZR₁ - o.DGP.invZperpZperpZperpZR₁' * o.S✻ZperpDGPZR₁
o.S✻DGPZR₁DGPZ = o.DGP.S✻ZR₁Z - _S✻ZperpDGPZR₁' * o.DGP.invZperpZperpZperpZpar - o.DGP.invZperpZperpZperpZR₁' * o.S✻ZperpDGPZ
Expand Down Expand Up @@ -481,7 +481,7 @@ function Filling!(o::StrBootTest{T}, dest::AbstractMatrix{T}, ind1::Int64, β̈s
dest .= reshape(o.F1_0'o.F2_0,:,1) .- (dropdims(o.F1_0'o.F2_1; dims=1) - o.F2_0'o.F1_1) * o.v # 0th- & 1st-order terms
o.Q .= o.F1_1'o.F2_1
@inbounds for g 1:o.N⋂
colquadformminus!(Val(o.turbo), dest, g, o.v, o.Q[:,g,:], o.v)
o.colquadformminus!(dest, g, o.v, o.Q[:,g,:], o.v)
end
else
dest .= reshape(o.F1_0'o.F2_0,:,1) .- dropdims(o.F1_0'o.F2_1; dims=1) * o.v # 0th- & 1st-order terms
Expand All @@ -496,7 +496,7 @@ function Filling!(o::StrBootTest{T}, dest::AbstractMatrix{T}, ind1::Int64, β̈s
dest .+= reshape(o.F1_0'o.F2_0,:,1) .* _β̈ .+ (o.F2_0'o.F1_1 - dropdims(o.F1_0'o.F2_1; dims=1)) * β̈v # "-" because S✻UMZperpX is stored negated as negS✻UMZperpX
o.Q .= o.F1_1'o.F2_1
for g 1:o.N⋂
colquadformminus!(Val(o.turbo), dest, g, o.v, o.Q[:,g,:], β̈v)
o.colquadformminus!(dest, g, o.v, o.Q[:,g,:], β̈v)
end
elseif o.Repl.Yendog[ind1+1]
dest .+= reshape(o.F1_0'o.F2_0,:,1) .* _β̈ .- dropdims(o.F1_0'o.F2_1; dims=1) * β̈v
Expand Down
6 changes: 3 additions & 3 deletions src/WildBootTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module WildBootTests
export BootTestResult, wildboottest, AuxWtType, PType, MAdjType, DistStatType,
teststat, stattype, p, padj, reps, repsfeas, nbootclust, dof, dof_r, plotpoints, peak, CI, dist, statnumer, statvar, auxweights

using LinearAlgebra, Random, Distributions, SortingAlgorithms, LoopVectorization
using LinearAlgebra, Random, Distributions, SortingAlgorithms #, LoopVectorization

include("StrBootTest.jl")
include("utilities.jl")
Expand Down Expand Up @@ -100,13 +100,13 @@ function UpdateBootstrapcDenom!(o::StrBootTest{T} where T, w::Integer)
if o.sqrt
o.dist .= o.numer ./ sqrtNaN.(o.statDenom)
else
negcolquadform!(o.dist, -invsym(o.statDenom), o.numer) # to reduce latency by minimizing @tturbo instances, work with negative of colquadform in order to fuse code with colquadformminus!
negcolquadform!(o.dist, -invsym(o.statDenom), o.numer) # to reduce latency by minimizing #=@tturbo=# instances, work with negative of colquadform in order to fuse code with colquadformminus!
end
end
nothing
end

include("precompile_WildBootTests.jl") # source: https://timholy.github.io/SnoopCompile.jl/stable/snoopi_deep_parcel/#SnoopCompile.write
include("precompile_WildBootTests.jl") # https://timholy.github.io/SnoopCompile.jl/stable/snoopi_deep_parcel/#SnoopCompile.write
_precompile_()

end
56 changes: 28 additions & 28 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ function InitVarsIV!(o::StrEstimator{T}, parent::StrBootTest{T}, Rperp::Abstract
else
o.kZperp = ncols(o.RperpX)
o.Zperp = parent.X₁ * o.RperpX
o.S✻⋂ZperpZperp = panelcross11(parent, o.Zperp, o.Zperp, parent.info✻⋂)
o.S✻⋂ZperpZperp = panelcross(o.Zperp, o.Zperp, parent.info✻⋂)
o.invZperpZperp = iszero(ncols(o.RperpX)) ? Symmetric(Matrix{T}(undef,0,0)) : invsym(sumpanelcross(o.S✻⋂ZperpZperp))

o.Xpar₁ = parent.X₁ * o.RperpXperp
S✻⋂X₁Zperp = panelcross11(parent, o.Xpar₁, o.Zperp, parent.info✻⋂)
S✻⋂X₁Zperp = panelcross(o.Xpar₁, o.Zperp, parent.info✻⋂)
ZperpX₁ = sumpanelcross(S✻⋂X₁Zperp)'
o.invZperpZperpZperpX₁ = o.invZperpZperp * ZperpX₁

S✻⋂X₂Zperp = panelcross11(parent, parent.X₂, o.Zperp, parent.info✻⋂)
S✻⋂X₂Zperp = panelcross(parent.X₂, o.Zperp, parent.info✻⋂)
ZperpX₂ = sumpanelcross(S✻⋂X₂Zperp)'
o.ZperpX = [ZperpX₁ ZperpX₂]
o.S✻⋂XZperp = [S✻⋂X₁Zperp; S✻⋂X₂Zperp]
Expand All @@ -130,9 +130,9 @@ function InitVarsIV!(o::StrEstimator{T}, parent::StrBootTest{T}, Rperp::Abstract
o.X₂ = o.Zperp * o.invZperpZperpZperpX₂; o.X₂ .= parent.X₂ .- o.X₂ # FWL-process X₂
end

S✻⋂X₁X₁ = panelcross11(parent, o.Xpar₁, o.Xpar₁, parent.info✻⋂) # S⋂(X₂:*M_(Z⟂)X₁∥)
S✻⋂X₂X₁ = panelcross11(parent, parent.X₂, o.Xpar₁, parent.info✻⋂)
S✻⋂X₂X₂ = panelcross11(parent, parent.X₂, parent.X₂, parent.info✻⋂)
S✻⋂X₁X₁ = panelcross(o.Xpar₁, o.Xpar₁, parent.info✻⋂) # S⋂(X₂:*M_(Z⟂)X₁∥)
S✻⋂X₂X₁ = panelcross(parent.X₂, o.Xpar₁, parent.info✻⋂)
S✻⋂X₂X₂ = panelcross(parent.X₂, parent.X₂, parent.info✻⋂)
X₂X₁ = sumpanelcross(S✻⋂X₂X₁) - ZperpX₂'o.invZperpZperp * ZperpX₁
X₁X₁ = Symmetric(sumpanelcross(S✻⋂X₁X₁)) - Symmetric(ZperpX₁'o.invZperpZperp * ZperpX₁)
X₂X₂ = Symmetric(sumpanelcross(S✻⋂X₂X₂)) - Symmetric(ZperpX₂'o.invZperpZperp * ZperpX₂)
Expand All @@ -141,80 +141,80 @@ function InitVarsIV!(o::StrEstimator{T}, parent::StrBootTest{T}, Rperp::Abstract
o.kX = ncols(o.XX)
o.invXX = invsym(o.XX)

o.S✻⋂ZperpY₂ = panelcross11(parent, o.Zperp, parent.Y₂, parent.info✻⋂)
o.S✻⋂ZperpY₂ = panelcross(o.Zperp, parent.Y₂, parent.info✻⋂)
ZperpY₂ = sumpanelcross(o.S✻⋂ZperpY₂)
o.invZperpZperpZperpY₂ = o.invZperpZperp * ZperpY₂
((parent.NFE>0 && (parent.LIML || !isone(parent.κ) || parent.bootstrapt)) || (parent.robust && parent.bootstrapt && parent.granular)) &&
(o.Y₂ = parent.Y₂ - o.Zperp * o.invZperpZperpZperpY₂)
o.S✻⋂Zperpy₁ = panelcross11(parent, o.Zperp, parent.y₁, parent.info✻⋂)
o.S✻⋂Zperpy₁ = panelcross(o.Zperp, parent.y₁, parent.info✻⋂)
Zperpy₁ = sumpanelcross(o.S✻⋂Zperpy₁)
o.invZperpZperpZperpy₁ = o.invZperpZperp * Zperpy₁
((parent.NFE>0 && (parent.LIML || !isone(parent.κ) || parent.bootstrapt)) || (parent.scorebs || parent.robust && parent.bootstrapt && parent.granular)) &&
(o.y₁ = parent.y₁ - o.Zperp * o.invZperpZperpZperpy₁)

o.S✻⋂X₁Y₂ = panelcross11(parent, o.Xpar₁, parent.Y₂, parent.info✻⋂)
o.S✻⋂X₂Y₂ = panelcross11(parent, parent.X₂, parent.Y₂, parent.info✻⋂)
o.S✻⋂X₁Y₂ = panelcross(o.Xpar₁, parent.Y₂, parent.info✻⋂)
o.S✻⋂X₂Y₂ = panelcross(parent.X₂, parent.Y₂, parent.info✻⋂)
o.S✻⋂XY₂ = [o.S✻⋂X₁Y₂; o.S✻⋂X₂Y₂]
o.XY₂ = sumpanelcross(o.S✻⋂XY₂) - o.invZperpZperpZperpX'ZperpY₂
o.S✻Y₂y₁ = panelcross11(parent, parent.Y₂, parent.y₁, parent.info✻)
o.S✻Y₂y₁ = panelcross(parent.Y₂, parent.y₁, parent.info✻)
o.Y₂y₁ = sumpanelcross(o.S✻Y₂y₁) - ZperpY₂'o.invZperpZperpZperpy₁
o.S✻Y₂Y₂ = panelcross11(parent, parent.Y₂, parent.Y₂, parent.info✻)
o.S✻Y₂Y₂ = panelcross(parent.Y₂, parent.Y₂, parent.info✻)
o.Y₂Y₂ = Symmetric(sumpanelcross(o.S✻Y₂Y₂)) - Symmetric(ZperpY₂'o.invZperpZperpZperpY₂)
S✻⋂X₂y₁ = panelcross11(parent, parent.X₂, parent.y₁, parent.info✻⋂)
S✻⋂X₂y₁ = panelcross(parent.X₂, parent.y₁, parent.info✻⋂)
o.X₂y₁ = reshape(sumpanelcross(S✻⋂X₂y₁), :) - ZperpX₂'o.invZperpZperpZperpy₁
S✻⋂X₁y₁ = panelcross11(parent, o.Xpar₁, parent.y₁, parent.info✻⋂)
S✻⋂X₁y₁ = panelcross(o.Xpar₁, parent.y₁, parent.info✻⋂)
o.S✻⋂Xy₁ = [S✻⋂X₁y₁; S✻⋂X₂y₁]
o.X₁y₁ = reshape(sumpanelcross(S✻⋂X₁y₁), :) - ZperpX₁'o.invZperpZperpZperpy₁
o.S✻y₁y₁ = reshape(panelcross11(parent, parent.y₁, parent.y₁, parent.info✻), :)
o.S✻y₁y₁ = reshape(panelcross(parent.y₁, parent.y₁, parent.info✻), :)
o.y₁y₁ = sum(o.S✻y₁y₁) - Zperpy₁'o.invZperpZperpZperpy₁
end

o.Z = X₁₂B(parent, parent.X₁, parent.Y₂, o.Rpar) # Z∥

X₁par = parent.X₁ * o.RparX # XXX expressible as a linear combination of Xpar₁??
S✻⋂X₁Zpar = panelcross11(parent, o.Xpar₁ , X₁par, parent.info✻⋂) + o.S✻⋂X₁Y₂ * o.RparY
S✻⋂X₂Zpar = panelcross11(parent, parent.X₂, X₁par, parent.info✻⋂) + o.S✻⋂X₂Y₂ * o.RparY
S✻⋂X₁Zpar = panelcross(o.Xpar₁ , X₁par, parent.info✻⋂) + o.S✻⋂X₁Y₂ * o.RparY
S✻⋂X₂Zpar = panelcross(parent.X₂, X₁par, parent.info✻⋂) + o.S✻⋂X₂Y₂ * o.RparY
o.S✻⋂XZpar = [S✻⋂X₁Zpar; S✻⋂X₂Zpar]
X₁Zpar = sumpanelcross(S✻⋂X₁Zpar)
X₂Zpar = sumpanelcross(S✻⋂X₂Zpar)
S✻⋂ZperpX₁par = panelcross11(parent, o.Zperp, X₁par, parent.info✻⋂)
S✻⋂ZperpX₁par = panelcross(o.Zperp, X₁par, parent.info✻⋂)
ZperpX₁par = sumpanelcross(S✻⋂ZperpX₁par)
o.S✻⋂ZperpZpar = S✻⋂ZperpX₁par + o.S✻⋂ZperpY₂ * o.RparY
ZperpZpar = ZperpX₁par + sumpanelcross(o.S✻⋂ZperpY₂) * o.RparY
o.invZperpZperpZperpZpar = o.invZperpZperp * ZperpZpar

S✻X₁pary₁ = panelcross11(parent, X₁par, parent.y₁, parent.info✻)
S✻X₁pary₁ = panelcross(X₁par, parent.y₁, parent.info✻)
o.S✻Zpary₁ = S✻X₁pary₁ + o.RparY' * o.S✻Y₂y₁
o.Zy₁ = sumpanelcross(S✻X₁pary₁) + o.RparY' * o.Y₂y₁ - ZperpX₁par'o.invZperpZperpZperpy₁

S✻X₁parY₂ = panelcross11(parent, X₁par, parent.Y₂, parent.info✻)
S✻X₁parY₂ = panelcross(X₁par, parent.Y₂, parent.info✻)
o.XZ = [X₁Zpar - o.invZperpZperpZperpX₁'ZperpZpar ; X₂Zpar - o.invZperpZperpZperpX₂'ZperpZpar]
o.S✻ZparY₂ = S✻X₁parY₂ + o.RparY' * o.S✻Y₂Y₂
o.ZY₂ = sumpanelcross(S✻X₁parY₂) - ZperpX₁par'o.invZperpZperpZperpY₂
tmp = S✻X₁parY₂ * o.RparY; o.S✻ZparZpar = panelcross11(parent, X₁par, X₁par, parent.info✻) + tmp + tmp' + o.RparY' * o.S✻Y₂Y₂ * o.RparY
tmp = S✻X₁parY₂ * o.RparY; o.S✻ZparZpar = panelcross(X₁par, X₁par, parent.info✻) + tmp + tmp' + o.RparY' * o.S✻Y₂Y₂ * o.RparY
o.ZZ = Symmetric(sumpanelcross(o.S✻ZparZpar)) - Symmetric(ZperpZpar'o.invZperpZperpZperpZpar)

o.invXXXZ = o.invXX * o.XZ
o.ZXinvXXXZ = o.XZ'o.invXXXZ # symmetric but converting to Symmetric() hampers type inference in the one place it's used

if o.restricted
_ZR₁ = X₁₂B(parent, parent.X₁, parent.Y₂, o.R₁invR₁R₁)
S✻⋂X₁ZR₁ = panelcross11(parent, o.Xpar₁, _ZR₁, parent.info✻⋂)
S✻⋂X₂ZR₁ = panelcross11(parent, parent.X₂, _ZR₁, parent.info⋂)
S✻⋂X₁ZR₁ = panelcross(o.Xpar₁, _ZR₁, parent.info✻⋂)
S✻⋂X₂ZR₁ = panelcross(parent.X₂, _ZR₁, parent.info⋂)
o.S✻⋂XZR₁ = [S✻⋂X₁ZR₁ ; S✻⋂X₂ZR₁]
o.S✻⋂ZperpZR₁ = panelcross11(parent, o.Zperp, _ZR₁, parent.info✻⋂)
o.S✻⋂ZperpZR₁ = panelcross(o.Zperp, _ZR₁, parent.info✻⋂)
o.ZperpZR₁ = sumpanelcross(o.S✻⋂ZperpZR₁)
o.invZperpZperpZperpZR₁ = o.invZperpZperp * o.ZperpZR₁
o.ZR₁ = _ZR₁ - o.Zperp * o.invZperpZperpZperpZR₁
o.X₁ZR₁ = sumpanelcross(S✻⋂X₁ZR₁) - o.invZperpZperpZperpX₁'o.ZperpZR₁
o.X₂ZR₁ = sumpanelcross(S✻⋂X₂ZR₁) - o.invZperpZperpZperpX₂'o.ZperpZR₁
o.S✻ZR₁Z = panelcross11(parent, _ZR₁, o.Z, parent.info✻)
o.S✻ZR₁Z = panelcross(_ZR₁, o.Z, parent.info✻)
o.ZR₁Z = sumpanelcross(o.S✻ZR₁Z) - o.ZperpZR₁'o.invZperpZperp * ZperpZpar
o.S✻ZR₁Y₂ = panelcross11(parent, _ZR₁, parent.Y₂, parent.info✻)
o.S✻ZR₁Y₂ = panelcross(_ZR₁, parent.Y₂, parent.info✻)
o.ZR₁Y₂ = sumpanelcross(o.S✻ZR₁Y₂) - o.ZperpZR₁'o.invZperpZperpZperpY₂
o.S✻ZR₁y₁ = panelcross11(parent, _ZR₁, parent.y₁, parent.info✻)
o.S✻ZR₁y₁ = panelcross(_ZR₁, parent.y₁, parent.info✻)
o.twoZR₁y₁ = 2 * (sumpanelcross(o.S✻ZR₁y₁) - o.ZperpZR₁'o.invZperpZperpZperpy₁)
o.S✻ZR₁ZR₁ = panelcross11(parent, _ZR₁, _ZR₁, parent.info✻)
o.S✻ZR₁ZR₁ = panelcross(_ZR₁, _ZR₁, parent.info✻)
o.ZR₁ZR₁ = Symmetric(sumpanelcross(o.S✻ZR₁ZR₁)) - Symmetric(o.ZperpZR₁'o.invZperpZperp * o.ZperpZR₁)
else
o.Y₂y₁par = o.Y₂y₁
Expand Down
12 changes: 4 additions & 8 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,12 @@ function __wildboottest(
diststat::DistStatType,
getCI::Bool,
getplot::Bool,
getauxweights::Bool,
turbo::Bool) where T
getauxweights::Bool) where T

M = StrBootTest{T}(R, r, R1, r1, resp, predexog, predendog, inst, obswt, fweights, LIML, Fuller, kappa, ARubin,
reps, auxwttype, rng, maxmatsize, ptype, imposenull, scorebs, !bootstrapc, clustid, nbootclustvar, nerrclustvar, issorted, hetrobust, small,
feid, fedfadj, level, rtol, madjtype, NH0, ML, beta, A, scores, getplot,
gridmin, gridmax, gridpoints, turbo)
gridmin, gridmax, gridpoints)

if getplot || (level<1 && getCI)
plot!(M)
Expand Down Expand Up @@ -200,8 +199,7 @@ function _wildboottest(T::DataType,
diststat::DistStatType=nodist,
getCI::Bool=true,
getplot::Bool=getCI,
getauxweights::Bool=false,
turbo::Bool=false)
getauxweights::Bool=false)

nrows(R)>2 && (getplot = getCI = false)

Expand Down Expand Up @@ -294,8 +292,7 @@ function _wildboottest(T::DataType,
diststat,
getCI,
getplot,
getauxweights,
turbo)
getauxweights)
end

_wildboottest(T::DataType, R, r::Number; kwargs...) = _wildboottest(T, R, [r]; kwargs...)
Expand Down Expand Up @@ -358,7 +355,6 @@ Function to perform wild-bootstrap-based hypothesis test
* `getCI::Bool=true`: whether to return CI
* `getplot::Bool=getCI`: whether to generate plot data
* `getauxweights::Bool=false`: whether to save auxilliary weight matrix (v)
* `turbo::Bool=false`: whether to exploit acceleration of the LoopVectorization package: slower on first use in a session, faster after
# Notes
`T`, `ptype`, `auxwttype`, `madjtype`, and `diststat` may also be strings. Examples: `"Float32"` and `"webb"`.
Expand Down
2 changes: 1 addition & 1 deletion src/nonWRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function MakeInterpolables!(o::StrBootTest{T}) where T
o.∂Jcd∂r[h₁,c,d₁] = (o.Jcd[c,d₁] .- o.Jcd₀[c,d₁]) ./ o.poles[h₁]
for d₂ 1:d₁
tmp = coldot(o, o.Jcd₀[c,d₁], o.∂Jcd∂r[h₁,c,d₂])
d₁ d₂ && (coldotplus!(Val(o.turbo), tmp, o.Jcd₀[c,d₂], o.∂Jcd∂r[h₁,c,d₁])) # for diagonal items, faster to just double after the c loop
d₁ d₂ && (o.coldotplus!(tmp, o.Jcd₀[c,d₂], o.∂Jcd∂r[h₁,c,d₁])) # for diagonal items, faster to just double after the c loop
@clustAccum!(o.∂denom∂r[h₁,d₁,d₂], c, tmp)
end
end
Expand Down
Loading

0 comments on commit bf6a0c7

Please sign in to comment.