diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 09eece68..8ed89839 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,7 +18,6 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
           - '1'
         os:
           - ubuntu-latest
@@ -85,8 +84,7 @@ jobs:
       - run: |
           julia --project=HPCG -e '
             using Pkg
-            Pkg.develop(path=".")
-            Pkg.develop(path="./PartitionedSolvers")
+            Pkg.develop([Pkg.PackageSpec(path="."),Pkg.PackageSpec(path="./PartitionedSolvers")])
             Pkg.test("HPCG")'
   docs:
     name: Documentation
diff --git a/HPCG/Project.toml b/HPCG/Project.toml
index 14b7e9c7..ad8bbac5 100644
--- a/HPCG/Project.toml
+++ b/HPCG/Project.toml
@@ -23,7 +23,7 @@ DelimitedFiles = "1.9"
 JSON = "0.21"
 MPI = "0.20"
 PartitionedArrays = "0.5"
-PartitionedSolvers = "0.2"
+PartitionedSolvers = "0.3"
 Primes = "0.5"
 SparseMatricesCSR = "0.6"
 julia = "1.1"
diff --git a/HPCG/src/mg_preconditioner.jl b/HPCG/src/mg_preconditioner.jl
index cda9a6a7..4c999299 100644
--- a/HPCG/src/mg_preconditioner.jl
+++ b/HPCG/src/mg_preconditioner.jl
@@ -110,7 +110,7 @@ function generate_problem(ranks, npx, npy, npz, nx, ny, nz, solver)
     Axf = similar(r)
     Axf .= 0
     x .= 0
-    gs_state = setup(solver, x, A, r)
+    gs_state = solver(PartitionedSolvers.linear_problem(x, A, r))
     return A, r, x, Axf, gs_state
 end
 
@@ -139,11 +139,11 @@ function pc_setup(np, ranks, l, nx, ny, nz)
     r = Vector{PVector}(undef, l)
     x = Vector{PVector}(undef, l)
     Axf = Vector{PVector}(undef, l)
-    gs_states = Vector{PartitionedSolvers.Preconditioner}(undef, l)
+    gs_states = Vector{PartitionedSolvers.LinearSolver}(undef, l)
     npx, npy, npz = compute_optimal_shape_XYZ(np)
     nnz_vec = Vector{Int64}(undef, l)
     nrows_vec = Vector{Int64}(undef, l)
-    solver = PartitionedSolvers.additive_schwarz_correction_partition(gauss_seidel(; iters = 1))
+    solver = p -> PartitionedSolvers.gauss_seidel(p;iterations=1)
 
     # create top problem 
     A, r, x, Axf, gs_state = generate_problem(ranks, npx, npy, npz, nx, ny, nz, solver)
@@ -313,16 +313,16 @@ end
 """
 function pc_solve!(x, s::Mg_preconditioner, b, l; zero_guess = false)
     if l == 1
-        solve!(x, s.gs_states[l], b; zero_guess) # bottom solve
+        PartitionedSolvers.smooth!(x, s.gs_states[l], b; zero_guess) # bottom solve
     else
-        solve!(x, s.gs_states[l], b; zero_guess) # presmoother 
+        PartitionedSolvers.smooth!(x, s.gs_states[l], b; zero_guess) # presmoother 
         mul_no_lat!(s.Axf[l], s.A_vec[l], x)
         p_restrict!(s.r[l-1], b, s.Axf[l], s.f2c[l-1])
         s.x[l-1] .= 0.0
         pc_solve!(s.x[l-1], s, s.r[l-1], l - 1; zero_guess = true)
         p_prolongate!(x, s.x[l-1], s.f2c[l-1])
-        consistent!(x) |> wait
-        solve!(x, s.gs_states[l], b) # post smooth
+        #consistent!(x) |> wait #Already inside gauss_seidel
+        PartitionedSolvers.smooth!(x, s.gs_states[l], b) # post smooth
     end
     x
 end
diff --git a/PartitionedSolvers/Project.toml b/PartitionedSolvers/Project.toml
index 61d3b903..7a6a21dd 100644
--- a/PartitionedSolvers/Project.toml
+++ b/PartitionedSolvers/Project.toml
@@ -1,18 +1,21 @@
 name = "PartitionedSolvers"
 uuid = "11b65f7f-80ac-401b-9ef2-3db765482d62"
 authors = ["Francesc Verdugo <f.verdugo.rojano@vu.nl>"]
-version = "0.2.2"
+version = "0.3.0"
 
 [deps]
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
 PartitionedArrays = "5a9dfac6-5c52-46f7-8278-5e2210713be9"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
 
 [compat]
 IterativeSolvers = "0.9"
+NLsolve = "4"
 PartitionedArrays = "0.4.4, 0.5"
 SparseArrays = "1"
 julia = "1.6"
diff --git a/PartitionedSolvers/src/PartitionedSolvers.jl b/PartitionedSolvers/src/PartitionedSolvers.jl
index ee5b8f68..5ea3a80e 100644
--- a/PartitionedSolvers/src/PartitionedSolvers.jl
+++ b/PartitionedSolvers/src/PartitionedSolvers.jl
@@ -5,39 +5,13 @@ using PartitionedArrays: val_parameter
 using SparseArrays
 using LinearAlgebra
 using IterativeSolvers
+using Printf
+import NLsolve
 using SparseMatricesCSR
 
-export setup
-export solve!
-export update!
-export finalize!
-export AbstractLinearSolver
-export linear_solver
-export default_nullspace
-export nullspace
-export uses_nullspace
-export uses_initial_guess
-export iterations!
 include("interfaces.jl")
-
-export lu_solver
-export jacobi_correction
-export richardson
-export jacobi
-export gauss_seidel
-export additive_schwarz_correction
-export additive_schwarz
+include("wrappers.jl")
 include("smoothers.jl")
-
-export amg
-export smoothed_aggregation
-export v_cycle
-export w_cycle
-export amg_level_params
-export amg_level_params_linear_elasticity
-export amg_fine_params
-export amg_coarse_params
-export amg_statistics
 include("amg.jl")
 
 end # module
diff --git a/PartitionedSolvers/src/amg.jl b/PartitionedSolvers/src/amg.jl
index 5e6a85bb..a5aea212 100644
--- a/PartitionedSolvers/src/amg.jl
+++ b/PartitionedSolvers/src/amg.jl
@@ -1,4 +1,15 @@
 
+function default_nullspace(A)
+    T = eltype(A)
+    [ones(T,size(A,2))]
+end
+
+function default_nullspace(A::PSparseMatrix)
+    col_partition = partition(axes(A,2))
+    T = eltype(A)
+    [ pones(T,col_partition) ]
+end
+
 function aggregate(A,diagA=dense_diag(A);epsilon)
     # TODO It assumes CSC format for the moment
 
@@ -744,7 +755,7 @@ function dofs_to_node(dofs, block_size)
 end
 
 function amg_level_params_linear_elasticity(block_size;
-    pre_smoother = additive_schwarz(gauss_seidel(;iters=1);iters=1),
+    pre_smoother = p-> additive_schwarz(p;local_solver=q->gauss_seidel(q;iterations=1),iterations=1),
     coarsening = smoothed_aggregation_with_block_size(;approximate_omega = lambda_generic,
     tentative_prolongator = tentative_prolongator_with_block_size, block_size = block_size),
     cycle = v_cycle,
@@ -756,7 +767,7 @@ function amg_level_params_linear_elasticity(block_size;
 end
 
 function amg_level_params(;
-    pre_smoother = additive_schwarz(gauss_seidel(;iters=1);iters=1),
+    pre_smoother = p-> additive_schwarz(p;local_solver=q->gauss_seidel(q;iterations=1),iterations=1),
     coarsening = smoothed_aggregation(;),
     cycle = v_cycle,
     pos_smoother = pre_smoother,
@@ -774,23 +785,24 @@ end
 
 function amg_coarse_params(;
     #TODO more resonable defaults?
-    coarse_solver = lu_solver(),
+    coarse_solver = LinearAlgebra_lu,
     coarse_size = 10,
     )
     coarse_params = (;coarse_solver,coarse_size)
     coarse_params
 end
 
-function amg(;
+function amg(p;
         fine_params=amg_fine_params(),
         coarse_params=amg_coarse_params(),)
     amg_params = (;fine_params,coarse_params)
-    setup(x,O,b,options) = amg_setup(x,O,b,nullspace(options),amg_params)
-    update! = amg_update!
-    solve! = amg_solve!
-    finalize! = amg_finalize!
-    uses_nullspace = Val(true)
-    linear_solver(;setup,update!,solve!,finalize!,uses_nullspace)
+
+    x = solution(p)
+    b = rhs(p)
+    A = matrix(p)
+    B = nullspace(p)
+    setup = amg_setup(x,A,b,B,amg_params)
+    linear_solver(amg_update!,amg_step!,p,setup)
 end
 
 function amg_setup(x,A,b,::Nothing,amg_params)
@@ -807,8 +819,8 @@ function amg_setup(x,A,b,B,amg_params)
             return nothing
         end
         (;pre_smoother,pos_smoother,coarsening,cycle) = fine_level
-        pre_setup = setup(pre_smoother,x,A,b)
-        pos_setup = setup(pos_smoother,x,A,b)
+        pre_setup = pre_smoother(linear_problem(x,A,b))
+        pos_setup = pos_smoother(linear_problem(x,A,b))
         coarsen, _ = coarsening
         Ac,Bc,R,P,Ac_setup = coarsen(A,B)
         nc = size(Ac,1)
@@ -830,28 +842,29 @@ function amg_setup(x,A,b,B,amg_params)
     end
     n_fine_levels = count(i->i!==nothing,fine_levels)
     nlevels = n_fine_levels+1
-    coarse_solver_setup = setup(coarse_solver,x,A,b)
+    coarse_solver_setup = coarse_solver(linear_problem(x,A,b))
     coarse_level = (;coarse_solver_setup)
     (;nlevels,fine_levels,coarse_level,amg_params)
 end
 
-function amg_solve!(x,setup,b,options)
+function amg_step!(x,setup,b,phase=:start;kwargs...)
     level=1
     amg_cycle!(x,setup,b,level)
-    x
+    phase=:stop
+    x,setup,phase
 end
 
 function amg_cycle!(x,setup,b,level)
     amg_params = setup.amg_params
     if level == setup.nlevels
         coarse_solver_setup = setup.coarse_level.coarse_solver_setup
-        return solve!(x,coarse_solver_setup,b)
+        return smooth!(x,coarse_solver_setup,b)
     end
     level_params = amg_params.fine_params[level]
     level_setup = setup.fine_levels[level]
     (;cycle) = level_params
     (;R,P,r,rc,rc2,e,ec,ec2,A,Ac,pre_setup,pos_setup) = level_setup
-    solve!(x,pre_setup,b)
+    smooth!(x,pre_setup,b)
     mul!(r,A,x)
     r .= b .- r
     mul!(rc2,R,r)
@@ -861,16 +874,16 @@ function amg_cycle!(x,setup,b,level)
     ec2 .= ec
     mul!(e,P,ec2)
     x .+= e
-    solve!(x,pos_setup,b)
+    smooth!(x,pos_setup,b)
     x
 end
 
-function amg_statistics(P::Preconditioner)
+function amg_statistics(P)
     # Taken from: An Introduction to Algebraic Multigrid, R. D. Falgout, April 25, 2006
     # Grid complexity is the total number of grid points on all grids divided by the number
     # of grid points on the fine grid. Operator complexity is the total number of nonzeroes in the linear operators
     # on all grids divided by the number of nonzeroes in the fine grid operator
-    setup = P.solver_setup
+    setup = P.workspace
     nlevels = setup.nlevels
     level_rows = zeros(Int,nlevels)
     level_nnz = zeros(Int,nlevels)
@@ -909,7 +922,7 @@ end
     amg_cycle!(args...)
 end
 
-function amg_update!(setup,A,options)
+function amg_update!(setup,A)
     amg_params = setup.amg_params
     nlevels = setup.nlevels
     for level in 1:(nlevels-1)
@@ -918,13 +931,13 @@ function amg_update!(setup,A,options)
         (;coarsening) = level_params
         _, coarsen! = coarsening
         (;R,P,A,Ac,Ac_setup,pre_setup,pos_setup) = level_setup
-        update!(pre_setup,A)
-        update!(pos_setup,A)
+        update(pre_setup,matrix=A)
+        update(pos_setup,matrix=A)
         coarsen!(A,Ac,R,P,Ac_setup)
         A = Ac
     end
     coarse_solver_setup = setup.coarse_level.coarse_solver_setup
-    update!(coarse_solver_setup,A)
+    update(coarse_solver_setup,matrix=A)
     setup
 end
 
diff --git a/PartitionedSolvers/src/interfaces.jl b/PartitionedSolvers/src/interfaces.jl
index edc40335..fca4998e 100644
--- a/PartitionedSolvers/src/interfaces.jl
+++ b/PartitionedSolvers/src/interfaces.jl
@@ -4,155 +4,1029 @@ function Base.show(io::IO,data::AbstractType)
     print(io,"PartitionedSolvers.$(nameof(typeof(data)))(…)")
 end
 
-function default_nullspace(A)
-    T = eltype(A)
-    [ones(T,size(A,2))]
+abstract type AbstractProblem <: AbstractType end
+abstract type AbstractSolver <: AbstractType end
+#abstract type AbstractAge <: AbstractType end
+
+#function update(;kwargs...)
+#    function update_it(p)
+#        update(p;kwargs...)
+#    end
+#end
+#
+#function update(s::AbstractSolver)
+#    function update_solver(p)
+#        update(s,p)
+#    end
+#end
+#
+#function update(p::AbstractProblem;kwargs...)
+#    function update_problem(args...)
+#        update(p,args...;kwargs...)
+#    end
+#end
+
+function solve(solver;kwargs...)
+    solver, state = step(solver;kwargs...)
+    while state !== :stop
+        solver, state = step(solver,state)
+    end
+    solver
+end
+
+function history(solver;kwargs...)
+    History(identity,solver,kwargs)
+end
+
+function history(f,solver;kwargs...)
+    History(f,solver,kwargs)
+end
+
+struct History{A,B,C}
+    f::A
+    solver::B
+    kwargs::C
+end
+
+function Base.iterate(a::History)
+    solver, state = step(a.solver;a.kwargs...)
+    a.f(solver), (solver,state)
+end
+
+function Base.iterate(a::History,(solver,state))
+    if state === :stop
+        return nothing
+    end
+    solver, state = step(solver,state)
+    a.f(solver), (solver,state)
+end
+
+solution(a) = a.solution
+jacobian(a) = a.jacobian
+residual(a) = a.residual
+attributes(a) = a.attributes
+problem(a) = a.problem
+matrix(a) = a.matrix
+rhs(a) = a.rhs
+statement(a) = a.statement
+workspace(a) = a.workspace
+interval(a) = a.interval
+coefficients(a) = a.coefficients
+uses_initial_guess(a) = val_parameter(a.uses_initial_guess)
+constant_jacobian(a) = val_parameter(a.constant_jacobian)
+uses_mutable_types(a) = val_parameter(a.uses_mutable_types)
+
+uses_initial_guess(a::AbstractSolver) = uses_initial_guess(attributes(a))
+constant_jacobian(a::AbstractProblem) = constant_jacobian(attributes(a))
+uses_mutable_types(a::AbstractProblem) = uses_mutable_types(attributes(a))
+
+#solution(a::AbstractProblem) = solution(workspace(a))
+#jacobian(a::AbstractProblem) = jacobian(workspace(a))
+#residual(a::AbstractProblem) = residual(workspace(a))
+#age(a::AbstractProblem) = age(workspace(a))
+#matrix(a::AbstractProblem) = matrix(workspace(a))
+#rhs(a::AbstractProblem) = rhs(workspace(a))
+
+solution(a::AbstractSolver) = solution(problem(a))
+matrix(a::AbstractSolver) = matrix(problem(a))
+rhs(a::AbstractSolver) = rhs(problem(a))
+residual(a::AbstractSolver) = residual(problem(a))
+jacobian(a::AbstractSolver) = jacobian(problem(a))
+
+abstract type AbstractLinearProblem <: AbstractProblem end
+
+#struct LinearProblemAge <: AbstractAge
+#    solution::Int
+#    matrix::Int
+#    rhs::Int
+#end
+#
+#function linear_problem_age()
+#    LinearProblemAge(1,1,1)
+#end
+#
+#function increment(a::LinearProblemAge;solution=0,matrix=0,rhs=0)
+#    LinearProblemAge(
+#                          a.solution + solution,
+#                          a.matrix + matrix,
+#                          a.rhs + rhs,
+#                         )
+#end
+
+function linear_problem(solution,matrix,rhs;uses_mutable_types=Val(true),nullspace=nothing)
+    attributes = (;uses_mutable_types,nullspace)
+    LinearProblem(solution,matrix,rhs,attributes)
 end
 
-function default_nullspace(A::PSparseMatrix)
-    col_partition = partition(axes(A,2))
-    T = eltype(A)
-    [ pones(T,col_partition) ]
+struct LinearProblem{A,B,C,D} <: AbstractLinearProblem
+    solution::A
+    matrix::B
+    rhs::C
+    attributes::D
 end
 
-abstract type AbstractLinearSolver <: AbstractType end
+nullspace(a::LinearProblem) = a.attributes.nullspace
+
+function update(p::LinearProblem;kwargs...)
+    data = (;kwargs...)
+    if hasproperty(data,:matrix)
+        A = data.matrix
+    else
+        A = matrix(p)
+    end
+    if hasproperty(data,:rhs)
+        b = data.rhs
+    else
+        b = rhs(p)
+    end
+    if hasproperty(data,:solution)
+        x = data.solution
+    else
+        x = solution(p)
+    end
+    if hasproperty(data,:attributes)
+        attrs = data.attributes
+    else
+        attrs = attributes(p)
+    end
+    LinearProblem(x,A,b,attrs)
+end
 
-function linear_solver(;
-        setup,
-        solve!,
-        update!,
-        finalize! = ls_setup->nothing,
-        step! = nothing,
-        uses_nullspace = Val(false),
+abstract type AbstractLinearSolver <: AbstractSolver end
+
+function LinearAlgebra.ldiv!(x,solver::AbstractLinearSolver,b)
+    if uses_initial_guess(attributes(solver))
+        fill!(x,zero(eltype(x)))
+    end
+    smooth!(x,solver,b;zero_guess=true)
+    x
+end
+
+function smooth!(x,s::AbstractLinearSolver,b;kwargs...)
+    s = update(s,solution=x,rhs=b)
+    s = solve(s;kwargs...)
+    x
+end
+
+function linear_solver(args...;
         uses_initial_guess = Val(true),
-        returns_history = Val(false),
     )
-    if step! === nothing
-        step! = (x,ls_setup,b,options,step=0) -> begin
-            if step !=0
-                return nothing
-            end
-            x = solve!(x,ls_setup,b,options)
-            x,step+1
-        end
+    attributes = (;uses_initial_guess)
+    LinearSolver(args...,attributes)
+end
+
+struct LinearSolver{A,B,C,D,E} <: AbstractLinearSolver
+    update::A
+    step::B
+    problem::C
+    workspace::D
+    attributes::E
+end
+
+function update(s::LinearSolver;problem=s.problem,kwargs...)
+    p = update(problem;kwargs...)
+    workspace = s.workspace
+    if haskey(kwargs,:matrix) || problem !== s.problem
+        workspace = s.update(workspace,matrix(p))
     end
-    traits = LinearSolverTraits(uses_nullspace,uses_initial_guess,returns_history)
-    LinearSolver(setup,solve!,update!,finalize!,step!,traits)
+    LinearSolver(s.update,s.step,p,workspace,s.attributes)
 end
 
-struct LinearSolverTraits{A,B,C} <: AbstractType
-    uses_nullspace::A
-    uses_initial_guess::B
-    returns_history::C
+function step(s::LinearSolver;kwargs...)
+    p = s.problem
+    x = solution(p)
+    b = rhs(p)
+    next = s.step(x,s.workspace,b;kwargs...)
+    #if next === nothing
+    #    return nothing
+    #end
+    x,workspace,state = next
+    p = update(p,solution=x)
+    s = LinearSolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
-struct LinearSolver{A,B,C,D,E,F} <: AbstractLinearSolver
-    setup::A
-    solve!::B
-    update!::C
-    finalize!::D
-    step!::E
-    traits::F
+function step(s::LinearSolver,state)
+    p = s.problem
+    x = solution(p)
+    b = rhs(p)
+    next = s.step(x,s.workspace,b,state)
+    #if next === nothing
+    #    return nothing
+    #end
+    x,workspace,state = next
+    p = update(p,solution=x)
+    s = LinearSolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
-function linear_solver(s::LinearSolver)
-    s
+function preconditioner(solver,p)
+    dx = similar(solution(p),axes(matrix(p),2))
+    r = similar(rhs(p),axes(matrix(p),1))
+    dp = update(p,solution=dx,rhs=r)
+    solver(dp)
 end
 
-struct Preconditioner{A,B} <: AbstractType
-    solver::A
-    solver_setup::B
+abstract type AbstractNonlinearProblem <: AbstractProblem end
+
+#function update(p::AbstractNonlinearProblem;kwargs...)
+#    function update_nonlinear_problem(x)
+#        update(p,x;kwargs...)
+#    end
+#end
+
+#struct NonlinearProblemAge <: AbstractAge
+#    solution::Int
+#    residual::Int
+#    jacobian::Int
+#end
+#
+#function nonlinear_problem_age()
+#    NonlinearProblemAge(0,0,0)
+#end
+#
+#function increment(a::NonlinearProblemAge;solution=0,residual=0,jacobian=0)
+#    NonlinearProblemAge(
+#                          a.solution + solution,
+#                          a.residual + residual,
+#                          a.jacobian + jacobian,
+#                         )
+#end
+
+function nonlinear_problem(args...;uses_mutable_types=Val(true))
+    attributes = (;uses_mutable_types)
+    NonlinearProblem(args...,attributes)
 end
 
-function setup_options(;nullspace=nothing)
-    options = (;nullspace)
+struct NonlinearProblem{A,B,C,D,E,F} <: AbstractNonlinearProblem
+    statement::A
+    solution::B
+    residual::C
+    jacobian::D
+    workspace::E
+    attributes::F
 end
 
-function nullspace(options)
-    options.nullspace
+#function linear_problem(p::NonlinearProblem)
+#    x = p.solution
+#    r = p.residual
+#    j = p.jacobian
+#    attrs = p.attributes
+#    dx = similar(x,axes(j,2))
+#    linear_problem(dx,j,r;attrs...)
+#end
+#
+#function update(lp::LinearProblem,p::NonlinearProblem)
+#    r = p.residual
+#    j = p.jacobian
+#    update(lp,matrix=j,rhs=j)
+#end
+
+function set(p::NonlinearProblem;kwargs...)
+    data = (;kwargs...)
+    if hasproperty(data,:statement)
+        st = data.statement
+    else
+        st = statement(p)
+    end
+    if hasproperty(data,:residual)
+        b = data.residual
+    else
+        b = residual(p)
+    end
+    if hasproperty(data,:jacobian)
+        A = data.jacobian
+    else
+        A = jacobian(p)
+    end
+    if hasproperty(data,:solution)
+        x = data.solution
+    else
+        x = solution(p)
+    end
+    if hasproperty(data,:attributes)
+        attrs = data.attributes
+    else
+        attrs = attributes(p)
+    end
+    NonlinearProblem(st,x,b,A,p.workspace,attrs)
 end
 
-function uses_nullspace(a::LinearSolver)
-    val_parameter(a.traits.uses_nullspace)
+function update(p::NonlinearProblem;kwargs...)
+    p = set(p;kwargs...)
+    q = set(p,statement=Base.identity)
+    q = p.statement(q)
+    p = set(q,statement=p.statement)
+    p
 end
 
-function uses_initial_guess(a::LinearSolver)
-    val_parameter(a.traits.uses_initial_guess)
+abstract type AbstractNonlinearSolver <: AbstractSolver end
+
+function nonlinear_solver(args...;attributes...)
+    NonlinearSolver(args...,attributes)
 end
 
-function returns_history(a::LinearSolver)
-    val_parameter(a.traits.returns_history)
+struct NonlinearSolver{A,B,C,D,E} <: AbstractNonlinearSolver
+    update::A
+    step::B
+    problem::C
+    workspace::D
+    attributes::E
 end
 
-function setup(solver::LinearSolver,x,A,b;kwargs...)
-    options = setup_options(;kwargs...)
-    solver_setup = solver.setup(x,A,b,options)
-    Preconditioner(solver,solver_setup)
+function update(s::NonlinearSolver;problem=s.problem,kwargs...)
+    p = update(problem;kwargs...)
+    workspace = s.update(s.workspace,p)
+    NonlinearSolver(s.update,s.step,p,workspace,s.attributes)
 end
 
-function update!(P::Preconditioner,A;kwargs...)
-    options = setup_options(;kwargs...)
-    P.solver.update!(P.solver_setup,A,options)
-    P
+function step(s::NonlinearSolver;kwargs...)
+    next = s.step(s.workspace,s.problem;kwargs...)
+    #if next === nothing
+    #    return nothing
+    #end
+    workspace,p,state = next
+    s = NonlinearSolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
-function solve_options(;zero_guess=false,history=Val(false))
-    options = (;zero_guess,history)
+function step(s::NonlinearSolver,state)
+    next = s.step(s.workspace,s.problem,state)
+    #if next === nothing
+    #    return nothing
+    #end
+    workspace,p,state = next
+    s = NonlinearSolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
-function solve!(x,P::Preconditioner,b;kwargs...)
-    options = solve_options(;kwargs...)
-    next = P.solver.solve!(x,P.solver_setup,b,options)
-    if returns_history(P.solver)
-        x,log = next
+abstract type AbstractODEProblem <: AbstractProblem end
+
+#function update(p::AbstractODEProblem;kwargs...)
+#    function update_ode_problem(x)
+#        update(p,x;kwargs...)
+#    end
+#end
+
+#struct ODEProblemAge <: AbstractAge
+#    solution::Int
+#    residual::Int
+#    jacobian::Int
+#end
+#
+#function ode_problem_age()
+#    ODEProblemAge(0,0,0)
+#end
+#
+#function increment(a::ODEProblemAge;solution=0,residual=0,jacobian=0)
+#    ODEProblemAge(
+#                          a.solution + solution,
+#                          a.residual + residual,
+#                          a.jacobian + jacobian,
+#                         )
+#end
+
+
+function ode_problem(args...;constant_jacobian=Val(false),uses_mutable_types=Val(true))
+    attributes = (;constant_jacobian,uses_mutable_types)
+    ODEProblem(args...,attributes)
+end
+
+struct ODEProblem{A,B,C,D,E,F,G,H} <: AbstractODEProblem
+    statement::A
+    solution::B
+    residual::C
+    jacobian::D
+    interval::E
+    coefficients::F
+    workspace::G
+    attributes::H
+end
+
+function set(p::ODEProblem;kwargs...)
+    data = (;kwargs...)
+    if hasproperty(data,:statement)
+        st = data.statement
+    else
+        st = statement(p)
+    end
+    if hasproperty(data,:residual)
+        b = data.residual
+    else
+        b = residual(p)
+    end
+    if hasproperty(data,:jacobian)
+        A = data.jacobian
+    else
+        A = jacobian(p)
+    end
+    if hasproperty(data,:solution)
+        x = data.solution
+    else
+        x = solution(p)
+    end
+    if hasproperty(data,:attributes)
+        attrs = data.attributes
+    else
+        attrs = attributes(p)
+    end
+    if hasproperty(data,:interval)
+        i = data.interval
     else
-        x = next
-        log = nothing
+        i = interval(p)
     end
-    if val_parameter(options.history) == true
-        return x, log
+    if hasproperty(data,:coefficients)
+        c = data.coefficients
     else
-        return x
+        c = coefficients(p)
     end
+    p = ODEProblem(st,x,b,A,i,c,p.workspace,attrs)
 end
 
-function LinearAlgebra.ldiv!(x,P::Preconditioner,b)
-    if uses_initial_guess(P.solver)
-        fill!(x,zero(eltype(x)))
-    end
-    solve!(x,P,b;zero_guess=true)
-    x
+function update(p::ODEProblem;kwargs...)
+    p = set(p;kwargs...)
+    q = set(p,statement=Base.identity)
+    q = p.statement(q)
+    p = set(q,statement=p.statement)
+    p
 end
 
-function finalize!(P::Preconditioner)
-    P.solver.finalize!(P.solver_setup)
+abstract type AbstractODESolver <: AbstractSolver end
+
+function ode_solver(args...;attributes...)
+    ODESolver(args...,attributes)
 end
 
-function iterations!(x,P::Preconditioner,b;kwargs...)
-    options = solve_options(;kwargs...)
-    params = (;options,x,P,b)
-    LinearSolverIterator(params)
+struct ODESolver{A,B,C,D,E} <: AbstractODESolver
+    update::A
+    step::B
+    problem::C
+    workspace::D
+    attributes::E
 end
 
-struct LinearSolverIterator{A} <: AbstractType
-    params::A
+function update(s::ODESolver;problem=s.problem,kwargs...)
+    p = update(problem;kwargs...)
+    workspace = s.update(s.workspace,p)
+    ODESolver(s.update,s.step,p,workspace,s.attributes)
 end
 
-function Base.iterate(a::LinearSolverIterator)
-    P = a.params.P
-    options = a.params.options
-    b = a.params.b
-    x = a.params.x
-    next = P.solver.step!(x,P.solver_setup,b,options)
-    next
+function step(s::ODESolver;kwargs...)
+    next = s.step(s.workspace,s.problem;kwargs...)
+    #if next === nothing
+    #    return nothing
+    #end
+    workspace,p,state = next
+    s = ODESolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
-function Base.iterate(a::LinearSolverIterator,state)
-    P = a.params.P
-    options = a.params.options
-    b = a.params.b
-    x = a.params.x
-    next = P.solver.step!(x,P.solver_setup,b,options,state)
-    next
+function step(s::ODESolver,state)
+    next = s.step(s.workspace,s.problem,state)
+    #if next === nothing
+    #    return nothing
+    #end
+    workspace,p,state = next
+    s = ODESolver(s.update,s.step,p,workspace,s.attributes)
+    s,state
 end
 
+#function linear_problem(args...;nullspace=nothing,block_size=1)
+#    attributes = (;nullspace,block_size)
+#    LinearProblem(args...,attributes)
+#end
+#
+#mutable struct LinearProblem{A,B,C,D}
+#    solution::A
+#    matrix::B
+#    rhs::C
+#    attributes::D
+#end
+#
+#solution(a) = a.solution
+#matrix(a) = a.matrix
+#rhs(a) = a.rhs
+#attributes(a) = a.attributes
+#
+#function update!(p::LinearProblem;kwargs...)
+#    @assert issubset(propertynames(kwargs),propertynames(p))
+#    if hasproperty(kwargs,:matrix)
+#        p.matrix = kwargs.matrix
+#    end
+#    if hasproperty(kwargs,:rhs)
+#        p.rhs = kwargs.rhs
+#    end
+#    if hasproperty(kwargs,:solution)
+#        p.solution = kwargs.solution
+#    end
+#    if hasproperty(kwargs,:attributes)
+#        p.attributes = kwargs.attributes
+#    end
+#    p
+#end
+#
+#function nonlinear_problem(args...;nullspace=nothing,block_size=1)
+#    attributes = (;nullspace,block_size)
+#    NonlinearProblem(args...,attributes)
+#end
+#
+#struct NonlinearProblem{A,B,C,D,E}
+#    statement::A
+#    solution::B
+#    residual::C
+#    jacobian::D
+#    attributes::E
+#end
+#
+#statement(a) = a.statement
+#residual(a) = a.residual
+#jacobian(a) = a.jacobian
+#
+##function update!(p::NonlinearProblem;kwargs...)
+##    @assert issubset(propertynames(kwargs),propertynames(p))
+##    if hasproperty(kwargs,:statement)
+##        p.statement = kwargs.statement
+##    end
+##    if hasproperty(kwargs,:residual)
+##        p.residual = kwargs.residual
+##    end
+##    if hasproperty(kwargs,:jacobian)
+##        p.jacobian = kwargs.jacobian
+##    end
+##    if hasproperty(kwargs,:solution)
+##        p.solution = kwargs.solution
+##    end
+##    if hasproperty(kwargs,:attributes)
+##        p.attributes = kwargs.attributes
+##    end
+##    p
+##end
+#
+#function ode_problem(args...;nullspace=nothing,block_size=1,constant_jacobian=false)
+#    attributes = (;nullspace,block_size,constant_jacobian)
+#    ODEProblem(args...,attributes)
+#end
+#
+#struct ODEProblem{A,B,C,D,E,F}
+#    statement::A
+#    interval::B
+#    solution::C
+#    residual::D
+#    jacobian::E
+#    attributes::F
+#end
+#
+#interval(a) = a.interval
+#
+##function update!(p::ODEProblem;kwargs...)
+##    @assert issubset(propertynames(kwargs),propertynames(p))
+##    if hasproperty(kwargs,:statement)
+##        p.statement = kwargs.statement
+##    end
+##    if hasproperty(kwargs,:residual)
+##        p.residual = kwargs.residual
+##    end
+##    if hasproperty(kwargs,:jacobian)
+##        p.jacobian = kwargs.jacobian
+##    end
+##    if hasproperty(kwargs,:solution)
+##        p.solution = kwargs.solution
+##    end
+##    if hasproperty(kwargs,:attributes)
+##        p.attributes = kwargs.attributes
+##    end
+##    if hasproperty(kwargs,:interval)
+##        p.interval = kwargs.interval
+##    end
+##    p
+##end
+#
+#function solve!(solver;kwargs...)
+#    next = step!(solver;kwargs...)
+#    while next !== nothing
+#        solver,phase = next
+#        next = step!(solver,next)
+#    end
+#    solution(solver)
+#end
+#
+#function history(solver)
+#    History(solver)
+#end
+#
+#struct History{A}
+#    solver::A
+#end
+#
+#function Base.iterate(a::History)
+#    next = step!(a.solver)
+#    if next === nothing
+#        return nothing
+#    end
+#    solution(problem(a.solver)), next
+#end
+#
+#function Base.iterate(a::History,next)
+#    next = step!(a.solver,next)
+#    if next === nothing
+#        return nothing
+#    end
+#    solution(problem(a.solver)), next
+#end
+#
+#struct Solver{A,B,C,D,E}
+#    problem::A
+#    step!::B
+#    update!::C
+#    finalize!::D
+#    attributes::E
+#end
+#
+#uses_initial_guess(a) = a.uses_initial_guess
+#
+#function solver(args...;uses_initial_guess=Val(true))
+#    attributes = (;uses_initial_guess)
+#    Solver(args...,attributes)
+#end
+#
+#function step!(solver::Solver)
+#    solver.step!()
+#end
+#
+#function step!(solver::Solver,state)
+#    solver.step!(state)
+#end
+#
+#function update!(solver::Solver;kwargs...)
+#    kwargs2 = solver.update!(;kwargs...)
+#    update!(solver.problem,kwargs2...)
+#    solver
+#end
+#
+#function finalize!(solver::Solver)
+#    solver.step!()
+#end
+#
+#function LinearAlgebra.ldiv!(x,solver::Solver,b)
+#    if uses_initial_guess(solver.attributes)
+#        fill!(x,zero(eltype(x)))
+#    end
+#    smooth!(x,solver,b;zero_guess=true)
+#    x
+#end
+#
+#function smooth!(x,solver,b;kwargs...)
+#    update!(solver,solution=x,rhs=b)
+#    solve!(solver;kwargs...)
+#    x
+#end
+#
+#function LinearAlgebra_lu(problem)
+#    F = lu(matrix(problem))
+#    function lu_step!(phase=:start)
+#        if phase === :stop
+#            return nothing
+#        end
+#        x = solution(problem)
+#        b = rhs(problem)
+#        ldiv!(x,F,b)
+#        phase = :stop
+#        phase
+#    end
+#    function lu_update!(;kwargs...)
+#        if hasproperty(kwargs,:matrix)
+#            lu!(F,kwargs.matrix)
+#        end
+#        kwargs
+#    end
+#    function lu_finalize!()
+#        nothing
+#    end
+#    uses_initial_guess = false
+#    solver(problem,lu_step!,lu_update!,lu_finalize!;uses_initial_guess)
+#end
+#
+#function identity_solver(problem)
+#    function id_step!(phase=:start)
+#        if phase === :stop
+#            return nothing
+#        end
+#        x = solution(problem)
+#        b = rhs(problem)
+#        copyto!(x,b)
+#        phase = :stop
+#        phase
+#    end
+#    function id_update!(;kwargs...)
+#        kwargs
+#    end
+#    function id_finalize!()
+#        nothing
+#    end
+#    uses_initial_guess = false
+#    solver(problem,id_step!,id_update!,id_finalize!;uses_initial_guess)
+#end
+#
+#function jacobi_correction(problem)
+#    Adiag = dense_diag(matrix(problem))
+#    function jc_step!(phase=:start)
+#        if phase === :stop
+#            return nothing
+#        end
+#        x = solution(problem)
+#        b = rhs(problem)
+#        x .= Adiag .\ b
+#        phase = :stop
+#        phase
+#    end
+#    function jc_update!(;kwargs...)
+#        if hasproperty(kwargs,:matrix)
+#            dense_diag!(Adiag,kwargs.matrix)
+#        end
+#        kwargs
+#    end
+#    function jc_finalize!()
+#        nothing
+#    end
+#    uses_initial_guess = false
+#    solver(problem,id_step!,id_update!,id_finalize!;uses_initial_guess)
+#end
+#
+#function richardson(problem;
+#    P = identity_solver(problem),
+#    iterations = 10,
+#    omega = 1,
+#    )
+#    iteration = 0
+#    b = rhs(problem)
+#    A = matrix(problem)
+#    x = solution(problem)
+#    dx = similar(x,axes(A,2))
+#    r = similar(b,axes(A,1))
+#    function rc_step!(phase=:start;zero_guess=false)
+#        @assert phase in (:start,:stop,:advance)
+#        if phase === :stop
+#            return nothing
+#        end
+#        if phase === :start
+#            iteration = 0
+#            phase = :advance
+#        end
+#        b = rhs(problem)
+#        A = matrix(problem)
+#        x = solution(problem)
+#        dx .= x
+#        if zero_guess
+#            r .= .- b
+#        else
+#            mul!(r,A,dx)
+#            r .-= b
+#        end
+#        ldiv!(dx,P,r)
+#        x .-= omega .* dx
+#        iteration += 1
+#        if iteration == iterations
+#            phase = :stop
+#        end
+#        phase
+#    end
+#    function rc_update!(;kwargs...)
+#        kwargs
+#    end
+#    function rc_finalize!()
+#        nothing
+#    end
+#    solver(problem,id_step!,id_update!,id_finalize!)
+#end
+#
+#function jacobi(problem;iterations=10,omega=1)
+#    P = jacobi_correction(problem)
+#    R = richardson(problem;P,iterations,omega)
+#    function ja_step!(args...)
+#        step!(R,args...)
+#    end
+#    function update!(;kwargs...)
+#        update!(P;kwargs...)
+#        update!(R;kwargs...)
+#        kwargs
+#    end
+#    function rc_finalize!()
+#        finalize!(P)
+#        finalize!(R)
+#    end
+#    solver(problem,ja_step!,ja_update!,ja_finalize!)
+#end
+#
+#function convergence(;kwargs...)
+#    convergence(Float64;kwargs...)
+#end
+#
+#function convergence(::Type{T};
+#    iterations = 1000,
+#    abs_res_tol = typemax(T),
+#    rel_res_tol = T(1e-12),
+#    abs_sol_tol = zero(T),
+#    res_norm = norm,
+#    sol_norm = dx -> maximum(abs,dx)
+#    ) where T
+#    (;iterations,abs_res_tol,rel_res_tol,abs_sol_tol,res_norm,sol_norm)
+#end
+#
+#function verbosity(;
+#        level=0,
+#        prefix="")
+#    (;level,prefix)
+#end
+#
+#function status(params)
+#    (;abs_res_tol,abs_sol_tol,rel_res_tol,iterations) = params
+#    res_target = zero(abs_res_tol)
+#    sol_target = zero(abs_sol_tol)
+#    iteration = 0
+#    status = Status(iterations,iteration,res_target,sol_target,res_error,sol_error)
+#    start!(status,res_error,sol_error)
+#    status
+#end
+#
+#mutable struct Status{T}
+#    iterations::Int
+#    iteration::Int
+#    res_target::T
+#    sol_target::T
+#    res_error::T
+#    sol_error::T
+#end
+#
+#function start!(status::Status,res_error,sol_error)
+#    res_target = min(abs_res_tol,res_error*rel_res_tol)
+#    sol_target = abs_sol_tol
+#    iteration = 0
+#end
+#
+#function step!(status::Status,res_error,sol_error)
+#    status.iteration += 1
+#    status.res_error = res_error
+#    status.sol_error = sol_error
+#    status
+#end
+#
+#function tired(status)
+#    status.iteration >= status.iterations
+#end
+#
+#function converged(status)
+#    status.res_error <= status.res_target || status.sol_error <= status.sol_target
+#end
+#
+#function print_progress(verbosity,status)
+#    if verbosity.level > 0
+#        s = verbosity.prefix
+#        @printf "%s%6i %6i %12.3e %12.3e\n" s status.iteration status.iterations status.res_error a.res_target
+#    end
+#end
+#
+##struct Usage
+##    num_updates::Dict{Symbol,Int}
+##end
+##
+##function usage()
+##    num_updates = Dict{Symbol,Int}()
+##    Usage(num_updates)
+##end
+##
+##function start!(usage::Usage)
+##    for k in keys(usage.num_updates)
+##        usage.num_updates[k] = 0
+##    end
+##    usage
+##end
+##
+##function update!(usage::Usage;kwargs...)
+##    for k in propertynames(kwargs)
+##        if ! haskey(usage.num_updates,k)
+##            usage.num_updates[k] = 0
+##        end
+##        usage.num_updates[k] += 1
+##    end
+##    usage
+##end
+#
+#function newton_raphson(problem;
+#        solver=lp->LinearAlgebra_lu(lp),
+#        convergence = PartitionedSolvers.convergence(eltype(solution(problem))),
+#        verbosity = PartitionedSolvers.verbosity(),
+#    )
+#    x = solution(problem)
+#    J = jacobian(problem)
+#    r = residual(problem)
+#    dx = similar(x,axes(J,2))
+#    lp = linear_problem(dx,J,r)
+#    S = solver(lp)
+#    status = PartitionedSolvers.status(convergence)
+#    function nr_step!(phase=:start;kwargs...)
+#        @assert phase in (:start,:stop,:advance)
+#        if phase === :stop
+#            return nothing
+#        end
+#        x = solution(problem)
+#        J = jacobian(problem)
+#        r = residual(problem)
+#        rj! = statement(problem)
+#        if phase === :start
+#            rj!(r,J,x)
+#            res_error = convergence.res_norm(r)
+#            sol_error = typemax(res_error)
+#            start!(status,res_error,sol_error)
+#            print_progress(verbosity,status)
+#            phase = :advance
+#        end
+#        update!(S,matrix=J)
+#        dx = solution(S)
+#        ldiv!(dx,S,b)
+#        x .-= dx
+#        rj!(r,J,x)
+#        res_error = convergence.res_norm(r)
+#        sol_error = convergence.sol_norm(dx)
+#        step!(status,res_error,sol_error)
+#        print_progress(verbosity,status)
+#        if converged(status) || tired(status)
+#            phase = :stop
+#        end
+#        phase
+#    end
+#    function nr_update!(;kwargs...)
+#        kwargs
+#    end
+#    function nr_finalize!()
+#        nothing
+#    end
+#    PartitionedSolvers.solver(problem,nr_step!,nr_update!,nr_finalize!)
+#end
+#
+#function print_time_step(verbosity,t,tend)
+#    if verbosity.level > 0
+#        s = verbosity.prefix
+#        @printf "%s%12.3e %12.3e\n" s t tend
+#    end
+#end
+#
+#function backward_euler(ode;
+#        dt = (interval(ode)[2]-interval(ode)[1])/100,
+#        solver = constant_jacobian(ode) ? LinearAlgebra_lu : newton_raphson,
+#        verbosity = PartitionedSolvers.verbosity(),
+#    )
+#
+#    (t,u,v) = solution(ode)
+#    x = copy(u)
+#    J = jacobian(ode)
+#    r = residual(ode)
+#    attrs = attributes(problem)
+#    rj! = statement(problem)
+#    if constant_jacobian(ode)
+#        rj!(r,j,(t,u,v),(1,1/dt))
+#        lp = linear_problem(x,J,r;attrs...)
+#        S = solver(lp)
+#    else
+#        nlp = nonlinear_problem(x,J,r;attrs...) do r,j,x
+#            v .= (x .- u) ./ dt
+#            rj!(r,j,(t,x,v),(1,1/dt))
+#        end
+#        S = solver(nlp)
+#    end
+#    function be_step!(phase=:start;kwargs...)
+#        @assert phase in (:start,:stop,:advance)
+#        if phase === :stop
+#            return nothing
+#        end
+#        J = jacobian(ode)
+#        r = residual(ode)
+#        tend = last(interval(ode))
+#        if phase === :start
+#            t = first(interval(ode))
+#            phase = :advance
+#            if constant_jacobian
+#                rj!(r,J,(t,u,v),(1,1/dt))
+#            end
+#            print_time_step(verbosity,t,tend)
+#        end
+#        x = solve!(S)
+#        v .= (x .- u) ./ dt
+#        u .= x
+#        t += dt
+#        if constant_jacobian(ode)
+#            rj!(r,nothing,(t,u,v),(1,1/dt))
+#        end
+#        print_time_step(verbosity,t,tend)
+#        if t >= tend
+#            phase = :stop
+#        end
+#    end
+#    function be_update!(;kwargs...)
+#        kwargs
+#    end
+#    function be_finalize!()
+#        nothing
+#    end
+#    PartitionedSolvers.solver(problem,be_step!,be_update!,be_finalize!)
+#end
+
diff --git a/PartitionedSolvers/src/smoothers.jl b/PartitionedSolvers/src/smoothers.jl
index a015cc7d..da809330 100644
--- a/PartitionedSolvers/src/smoothers.jl
+++ b/PartitionedSolvers/src/smoothers.jl
@@ -1,343 +1,356 @@
 
-function lu_solver()
-    setup(x,op,b,options) = lu(op)
-    update!(state,op,options) = lu!(state,op)
-    solve!(x,P,b,options) = ldiv!(x,P,b)
+function identity_solver(p)
+    @assert uses_mutable_types(p)
+    workspace = nothing
+    function update(workspace,A)
+        workspace
+    end
+    function step(x,workspace,b,phase=:start;kwargs...)
+        copyto!(x,b)
+        phase = :stop
+        x,workspace,phase
+    end
     uses_initial_guess = Val(false)
-    linear_solver(;setup,solve!,update!,uses_initial_guess)
+    linear_solver(update,step,p,uses_initial_guess)
 end
 
-function jacobi_correction()
-    setup(x,op,b,options) = dense_diag!(similar(b),op)
-    update!(state,op,options) = dense_diag!(state,op)
-    function solve!(x,state,b,options)
-        x .= state .\ b
-        x
+function jacobi_correction(p)
+    @assert uses_mutable_types(p)
+    Adiag = dense_diag!(similar(rhs(p)),matrix(p))
+    function update(Adiag,A)
+        dense_diag!(Adiag,A)
+        Adiag
+    end
+    function step(x,Adiag,b,phase=:start;kwargs...)
+        x .= Adiag .\ b
+        phase = :stop
+        x,Adiag,phase
     end
     uses_initial_guess = Val(false)
-    linear_solver(;setup,update!,solve!,uses_initial_guess)
+    linear_solver(update,step,p,Adiag;uses_initial_guess)
 end
 
-function richardson(solver;iters,omega=1)
-    function setup(x,A,b,options)
-        A_ref = Ref(A)
-        r = similar(b)
-        dx = similar(x,axes(A,2))
-        P = PartitionedSolvers.setup(solver,dx,A,r)
-        state = (r,dx,P,A_ref)
-    end
-    function update!(state,A,options)
-        (r,dx,P,A_ref) = state
-        A_ref[] = A
-        PartitionedSolvers.update!(P,A)
-        state
+function richardson(p;
+        P=preconditioner(identity_solver,p) ,
+        iterations=10,
+        omega=1,
+        update_P = true,
+    )
+    @assert uses_mutable_types(p)
+    iteration = 0
+    A = matrix(p)
+    ws = (;iterations,P,iteration,omega,update_P,A)
+    linear_solver(richardson_update,richardson_step,p,ws)
+end
+
+function richardson_update(ws,A)
+    (;iterations,P,iteration,omega,update_P) = ws
+    if update_P
+        P = update(P,matrix=A)
     end
-    function solve!(x,state,b,options)
-        (r,dx,P,A_ref) = state
-        A = A_ref[]
-        for iter in 1:iters
-            dx .= x
-            mul!(r,A,dx)
-            r .-= b
-            ldiv!(dx,P,r)
-            x .-= omega .* dx
+    iteration = 0
+    ws = (;iterations,P,iteration,omega,update_P,A)
+end
+
+function richardson_step(x,ws,b,phase=:start;kwargs...)
+    (;iterations,P,iteration,omega,update_P,A) = ws
+    if phase === :start
+        iteration = 0
+        phase = :advance
+    end
+    dx = solution(P)
+    r = rhs(P)
+    dx .= x
+    mul!(r,A,dx)
+    r .-= b
+    ldiv!(dx,P,r)
+    x .-= omega .* dx
+    iteration += 1
+    if iteration == iterations
+        phase = :stop
+    end
+    ws = (;iterations,P,iteration,omega,update_P,A)
+    x,ws,phase
+end
+
+function jacobi(p;iterations=10,omega=1)
+    P = preconditioner(jacobi_correction,p)
+    update_P = true
+    richardson(p;P,update_P,iterations,omega)
+end
+
+function gauss_seidel(p;iterations=1,sweep=:symmetric)
+    @assert uses_mutable_types(p)
+    iteration = 0
+    A = matrix(p)
+    Adiag = dense_diag!(similar(rhs(p)),A)
+    ws = (;iterations,sweep,iteration,A,Adiag)
+    linear_solver(gauss_seidel_update,gauss_seidel_step,p,ws)
+end
+
+function gauss_seidel_update(ws,A)
+    (;iterations,sweep,iteration,A,Adiag) = ws
+    dense_diag!(Adiag,A)
+    iteration = 0
+    ws = (;iterations,sweep,iteration,A,Adiag)
+end
+
+function gauss_seidel_step(x,ws,b,phase=:start;zero_guess=false,kwargs...)
+    (;iterations,sweep,iteration,A,Adiag) = ws
+    if phase === :start
+        iteration = 0
+        phase = :advance
+    end
+    if (! zero_guess) && isa(x,PVector)
+        consistent!(x) |> wait
+    end
+    # TODO the meaning of :forward and :backward
+    # depends on the sparse matrix format
+    if sweep === :symmetric || sweep === :forward
+        if zero_guess
+            gauss_seidel_forward_sweep_zero!(x,A,Adiag,b)
+        else
+            gauss_seidel_forward_sweep!(x,A,Adiag,b)
         end
-        x, (;iters)
     end
-    function step!(x,state,b,options,step=0)
-        if step == iters
-            return nothing
-        end
-        (r,dx,P,A_ref) = state
-        A = A_ref[]
-        dx .= x
-        mul!(r,A,dx)
-        r .-= b
-        ldiv!(dx,P,r)
-        x .-= omega .* dx
-        x,step+1
+    if sweep === :symmetric || sweep === :backward
+        gauss_seidel_backward_sweep!(x,A,Adiag,b)
     end
-    function finalize!(state)
-        (r,dx,P,A_ref) = state
-        PartitionedSolvers.finalize!(P)
+    iteration += 1
+    if iteration == iterations
+        phase = :stop
     end
-    returns_history = Val(true)
-    linear_solver(;setup,update!,solve!,finalize!,step!,returns_history)
+    ws = (;iterations,sweep,iteration,A,Adiag)
+    x,ws,phase
+end
+
+function gauss_seidel_forward_sweep!(x,A,diagA,b)
+    n = length(b)
+    gauss_seidel_sweep!(x,A,diagA,b,1:n)
+end
+
+function gauss_seidel_backward_sweep!(x,A,diagA,b)
+    n = length(b)
+    gauss_seidel_sweep!(x,A,diagA,b,n:-1:1)
 end
 
-function jacobi(;kwargs...)
-    solver = jacobi_correction()
-    richardson(solver;kwargs...)
+function gauss_seidel_forward_sweep!(x,A::PSparseMatrix,diagA,b)
+    foreach(gauss_seidel_forward_sweep!,partition(x),partition(A),partition(diagA),own_values(b))
 end
 
-function gauss_seidel(;iters=1,sweep=:symmetric)
-    @assert sweep in (:forward,:backward,:symmetric)
-    function setup(x,A,b,options)
-        diagA = dense_diag!(similar(b),A)
-        A_ref = Ref(A)
-        (diagA,A_ref)
+function gauss_seidel_backward_sweep!(x,A::PSparseMatrix,diagA,b)
+    foreach(gauss_seidel_backward_sweep!,partition(x),partition(A),partition(diagA),own_values(b))
+end
+
+function gauss_seidel_sweep!(x,A::SparseArrays.AbstractSparseMatrixCSC,diagA,b,cols)
+    # assumes symmetric matrix
+    for col in cols
+        s = b[col]
+        for p in nzrange(A,col)
+            row = A.rowval[p]
+            a = A.nzval[p]
+            s -= a*x[row]
+        end
+        d = diagA[col]
+        s += d*x[col]
+        s = s/d
+        x[col] = s
     end
-    function update!(state,A,options)
-        (diagA,A_ref) = state
-        dense_diag!(diagA,A)
-        A_ref[] = A
-        state
+    x
+end
+
+function gauss_seidel_sweep!(x,A::SparseMatricesCSR.SparseMatrixCSR,diagA,b,rows)
+    for row in rows
+        s = b[row]
+        for p in nzrange(A,row)
+            col = A.colval[p]
+            a = A.nzval[p]
+            s -= a * x[col]
+        end
+        d = diagA[row]
+        s += d * x[row]
+        s = s / d
+        x[row] = s
     end
-    function gauss_seidel_sweep!(x,A::SparseArrays.AbstractSparseMatrixCSC,diagA,b,cols)
-        #assumes symmetric matrix
-        for col in cols
-            s = b[col]
-            for p in nzrange(A,col)
-                row = A.rowval[p]
-                a = A.nzval[p]
-                s -= a*x[row]
-            end
-            d = diagA[col]
-            s += d*x[col]
-            s = s/d
-            x[col] = s
+    x
+end
+
+function gauss_seidel_sweep!(x,A::PartitionedArrays.AbstractSplitMatrix,diagA,b,cols)
+    @assert isa(A.row_permutation,UnitRange)
+    @assert isa(A.col_permutation,UnitRange)
+    Aoo = A.blocks.own_own
+    Aoh = A.blocks.own_ghost
+    gauss_seidel_sweep_split!(x,Aoo,Aoh,diagA,b,cols)
+end
+
+function gauss_seidel_sweep_split!(x,Aoo::SparseMatricesCSR.SparseMatrixCSR,Aoh,diagA,b,rows)
+    for row in rows
+        s = b[row]
+        for p in nzrange(Aoo,row)
+            col = Aoo.colval[p]
+            a = Aoo.nzval[p]
+            s -= a * x[col]
+        end
+        for p in nzrange(Aoh,row)
+            col = Aoh.colval[p]
+            a = Aoh.nzval[p]
+            s -= a * x[col]
         end
-        x
+        d = diagA[row]
+        s += d * x[row]
+        s = s / d
+        x[row] = s
     end
-    function gauss_seidel_sweep!(x,A::SparseMatricesCSR.SparseMatrixCSR,diagA,b,rows)
-        #assumes symmetric matrix
-        for row in rows
-            s = b[row]
-            for p in nzrange(A,row)
-                col = A.colval[p]
+    x
+end
+
+function gauss_seidel_forward_sweep_zero!(x,A,diagA,b)
+    n = length(b)
+    gauss_seidel_sweep_zero!(x,A,diagA,b,1:n)
+end
+
+# TODO not sure if correct
+function gauss_seidel_backward_sweep_zero!(x,A,diagA,b)
+    n = length(b)
+    gauss_seidel_sweep_zero!(x,A,diagA,b,n:-1:1)
+end
+
+function gauss_seidel_forward_sweep_zero!(x,A::PSparseMatrix,diagA,b)
+    foreach(gauss_seidel_forward_sweep_zero!,partition(x),partition(A),partition(diagA),own_values(b))
+end
+
+function gauss_seidel_backward_sweep_zero!(x,A::PSparseMatrix,diagA,b)
+    foreach(gauss_seidel_backward_sweep_zero!,partition(x),partition(A),partition(diagA),own_values(b))
+end
+
+# Zero guess: only calculate points below diagonal of sparse matrix in forward sweep.
+function gauss_seidel_sweep_zero!(x,A::SparseArrays.AbstractSparseMatrixCSC,diagA,b,cols)
+    gauss_seidel_sweep!(x,A,diagA,b,cols)
+    ## There is a bug, falling back to nonzero x
+    ##assumes symmetric matrix
+    #for col in cols
+    #    s = b[col]
+    #    for p in nzrange(A,col)
+    #        row = A.rowval[p]
+    #        if col < row
+    #            a = A.nzval[p]
+    #            s -= a*x[row]
+    #        end
+    #    end
+    #    d = diagA[col]
+    #    #s += d*x[col]
+    #    s = s/d
+    #    x[col] = s
+    #end
+    #x
+end
+# Zero guess: only calculate points below diagonal of sparse matrix in forward sweep.
+function gauss_seidel_sweep_zero!(x,A::SparseMatricesCSR.SparseMatrixCSR,diagA,b,rows)
+    rows
+    length(x)
+    size(A)
+    length(b)
+    #assumes symmetric matrix
+    for row in rows
+        s = b[row]
+        for p in nzrange(A,row)
+            col = A.colval[p]
+            if col < row
                 a = A.nzval[p]
                 s -= a * x[col]
             end
-            d = diagA[row]
-            s += d * x[row]
-            s = s / d
-            x[row] = s
         end
-        x
+        d = diagA[row]
+        #s += d * x[col]
+        s = s / d
+        x[row] = s
     end
+    x
+end
 
-    # Zero guess: only calculate points below diagonal of sparse matrix in forward sweep.
-    function gauss_seidel_sweep_zero!(x,A::SparseArrays.AbstractSparseMatrixCSC,diagA,b,cols)
-        #assumes symmetric matrix
-        for col in cols
-            s = b[col]
-            for p in nzrange(A,col)
-                row = A.rowval[p]
-                if col < row
-                    a = A.nzval[p]
-                    s -= a*x[row]
-                end
-            end
-            d = diagA[col]
-            #s += d*x[col]
-            s = s/d
-            x[col] = s
-        end
-        x
-    end
-    # Zero guess: only calculate points below diagonal of sparse matrix in forward sweep.
-    function gauss_seidel_sweep_zero!(x,A::SparseMatricesCSR.SparseMatrixCSR,diagA,b,rows)
-        #assumes symmetric matrix
-        for row in rows
-            s = b[row]
-            for p in nzrange(A,row)
-                col = A.colval[p]
-                if col < row
-                    a = A.nzval[p]
-                    s -= a * x[col]
-                end
+function gauss_seidel_sweep_zero!(x,A::PartitionedArrays.AbstractSplitMatrix,diagA,b,cols)
+    @assert isa(A.row_permutation,UnitRange)
+    @assert isa(A.col_permutation,UnitRange)
+    Aoo = A.blocks.own_own
+    Aoh = A.blocks.own_ghost
+    gauss_seidel_sweep_zero_split!(x,Aoo,Aoh,diagA,b,cols)
+end
+
+function gauss_seidel_sweep_zero_split!(x,Aoo::SparseMatricesCSR.SparseMatrixCSR,Aoh,diagA,b,rows)
+    for row in rows
+        s = b[row]
+        for p in nzrange(Aoo,row)
+            col = Aoo.colval[p]
+            if col < row
+                a = Aoo.nzval[p]
+                s -= a * x[col]
             end
-            d = diagA[row]
-            #s += d * x[col]
-            s = s / d
-            x[row] = s
         end
-        x
-    end
-    function solve!(x,state,b,options)
-        (diagA,A_ref) = state
-        A = A_ref[]
-        n = length(b)
-
-        for iter in 1:iters
-            if sweep === :symmetric || sweep === :forward
-                if options.zero_guess
-                    gauss_seidel_sweep_zero!(x,A,diagA,b,1:n)
-                else
-                    gauss_seidel_sweep!(x,A,diagA,b,1:n)
-                end
-            end
-            if sweep === :symmetric || sweep === :backward
-                gauss_seidel_sweep!(x,A,diagA,b,n:-1:1)
+        for p in nzrange(Aoh,row)
+            col = Aoh.colval[p]
+            if col < row
+                a = Aoh.nzval[p]
+                s -= a * x[col]
             end
         end
-        x
+        d = diagA[row]
+        #s += d * x[col]
+        s = s / d
+        x[row] = s
     end
-    linear_solver(;setup,update!,solve!)
+    x
 end
 
-function additive_schwarz(local_solver;iters=1)
-    richardson(additive_schwarz_correction(local_solver);iters)
+function additive_schwarz_correction(p;local_solver=LinearAlgebra_lu)
+    x = solution(p)
+    A = matrix(p)
+    b = rhs(p)
+    local_s = additive_schwarz_correction_setup(local_solver,x,A,b)
+    uses_initial_guess = Val(false)
+    linear_solver(
+        additive_schwarz_correction_update,
+        additive_schwarz_correction_step,
+        p,
+        local_s;
+        uses_initial_guess
+       )
 end
 
-function local_setup_options(A,options)
-    if nullspace(options) !== nothing
-        ns = map(i->own_values(i),nullspace(options))
-        map(ns) do ns
-            setup_options(;nullspace=ns)
-        end
-    else
-        map(partition(A)) do A
-            options
-        end
-    end
+function additive_schwarz_correction_setup(local_solver,x,A::PSparseMatrix,b)
+    local_p = map(linear_problem,own_values(x),own_own_values(A),own_values(b))
+    local_s = map(local_solver,local_p)
 end
 
-function local_solver_options(A,options)
-    map(partition(A)) do A
-        options
-    end
+function additive_schwarz_correction_update(local_s,A::PSparseMatrix)
+    local_s = map(additive_schwarz_correction_update,local_s,own_own_values(A))
 end
 
-struct AdditiveSchwarzSetup{A} <: AbstractType
-    local_setups::A
+function additive_schwarz_correction_step(x::PVector,local_s,b,phase=:start;kwargs...)
+    foreach(ldiv!,own_values(x),local_s,own_values(b))
+    phase = :stop
+    x,local_s,phase
 end
 
-function additive_schwarz_correction(local_solver)
-    # For parallel matrices
-    function setup(x,A::PSparseMatrix,b,options)
-        map(
-            local_solver.setup,
-            own_values(x),
-            own_own_values(A),
-            own_values(b),
-            local_setup_options(A,options),
-           ) |> AdditiveSchwarzSetup
-    end
-    function update!(state::AdditiveSchwarzSetup,A,options)
-        map(
-            local_solver.update!,
-            state.local_setups,
-            own_own_values(A),
-            local_setup_options(A,options),
-           )
-    end
-    function solve!(x,state::AdditiveSchwarzSetup,b,options)
-        map(
-            local_solver.solve!,
-            own_values(x),
-            state.local_setups,
-            own_values(b),
-            local_solver_options(b,options)
-           )
-        x
-    end
-    function finalize!(state::AdditiveSchwarzSetup)
-        map(
-            local_solver.finalize!,
-            state.local_setups)
-        nothing
-    end
-    # Fall back for sequential matrices
-    function setup(x,A,b,options)
-        local_solver.setup(x,A,b,options)
-    end
-    function update!(state,A,options)
-        local_solver.update!(state,A,options)
-    end
-    function solve!(x,state,b,options)
-        local_solver.solve!(x,state,b,options)
-        x
-    end
-    function finalize!(state)
-        local_solver.finalize!(state)
-        nothing
-    end
-    linear_solver(;setup,update!,solve!,finalize!)
-end
-
-function additive_schwarz_correction_partition(local_solver)
-    # For parallel matrices
-    function setup(x,A::PSparseMatrix,b,options)
-        map(
-            local_solver.setup,
-            partition(x),
-            partition(A),
-            own_values(b),
-            local_setup_options(A,options),
-        ) |> AdditiveSchwarzSetup
-    end
-    function update!(state::AdditiveSchwarzSetup,A,options)
-        map(
-            local_solver.update!,
-            state.local_setups,
-            partition(A),
-            local_setup_options(A,options),
-        )
-    end
-    function solve!(x,state::AdditiveSchwarzSetup,b,options)
-        map(
-            local_solver.solve!,
-            partition(x),
-            state.local_setups,
-            own_values(b),
-            local_solver_options(b,options),
-        )
-        x
-    end
-    function finalize!(state::AdditiveSchwarzSetup)
-        map(
-            local_solver.finalize!,
-            state.local_setups)
-        nothing
-    end
-    # Fall back for sequential matrices
-    function setup(x,A,b,options)
-        local_solver.setup(x,A,b,options)
-    end
-    function update!(state,A,options)
-        local_solver.update!(state,A,options)
-    end
-    function solve!(x,state,b,options)
-        local_solver.solve!(x,state,b,options)
-        x
-    end
-    function finalize!(state)
-        local_solver.finalize!(state)
-        nothing
-    end
-    linear_solver(;setup,update!,solve!,finalize!)
+function additive_schwarz_correction_setup(local_solver,x,A,b)
+    local_p = linear_problem(x,A,b)
+    local_s = local_solver(local_p)
 end
 
-# Wrappers
+function additive_schwarz_correction_update(local_s,A)
+    local_s = update(local_s,matrix=A)
+end
 
-function linear_solver(::typeof(LinearAlgebra.lu))
-    lu_solver()
+function additive_schwarz_correction_step(x,local_s,b,phase=:start;kwargs...)
+    ldiv!(x,local_s,b)
+    phase = :stop
+    x,local_s,phase
 end
 
-function linear_solver(::typeof(IterativeSolvers.cg);Pl,kwargs...)
-    function setup(x,A,b,options)
-        Pl_solver = linear_solver(Pl)
-        P = PartitionedSolvers.setup(Pl_solver,x,A,b;options...)
-        A_ref = Ref(A)
-        (;P,A_ref)
-    end
-    function update!(state,A,options)
-        (;P,A_ref) = state
-        A_ref[] = A
-        P = PartitionedSolvers.update!(P,A;options...)
-        state
-    end
-    function solve!(x,state,b,options)
-        (;P,A_ref) = state
-        A = A_ref[]
-        IterativeSolvers.cg!(x,A,b;Pl=P,kwargs...)
-    end
-    function finalize!(state,A,options)
-        (;P) = state
-        PartitionedSolvers.finalize!(P)
-        nothing
+function additive_schwarz(p;local_solver=LinearAlgebra_lu,iterations=1)
+    P = preconditioner(p) do dp
+        additive_schwarz_correction(dp;local_solver)
     end
-    linear_solver(;setup,update!,solve!,finalize!)
+    update_P = true
+    richardson(p;P,update_P,iterations)
 end
 
diff --git a/PartitionedSolvers/src/wrappers.jl b/PartitionedSolvers/src/wrappers.jl
new file mode 100644
index 00000000..3789fe6f
--- /dev/null
+++ b/PartitionedSolvers/src/wrappers.jl
@@ -0,0 +1,80 @@
+
+function LinearAlgebra_lu(p)
+    @assert uses_mutable_types(p)
+    F = lu(matrix(p))
+    function update(F,A)
+        lu!(F,A)
+        F
+    end
+    function step(x,F,b,phase=:start;kwargs...)
+        ldiv!(x,F,b)
+        phase = :stop
+        x,F,phase
+    end
+    uses_initial_guess = Val(false)
+    linear_solver(update,step,p,F;uses_initial_guess)
+end
+
+function IterativeSolvers_cg(p;kwargs...)
+    A = matrix(p)
+    function update(state,A)
+        A
+    end
+    function step(x,A,b,phase=:start;kwargs...)
+        IterativeSolvers.cg!(x,A,b;kwargs...)
+        phase = :stop
+        x,A,phase
+    end
+    linear_solver(update,step,p,A)
+end
+
+function NLSolvers_nlsolve_setup(p)
+    function f!(r,x)
+        update(p,residual=r,jacobian=nothing,solution=x)
+        r
+    end
+    function j!(j,x)
+        update(p,residual=nothing,jacobian=j,solution=x)
+        j
+    end
+    function fj!(r,j,x)
+        update(p,residual=r,jacobian=j,solution=x)
+        r,j
+    end
+    df = NLsolve.OnceDifferentiable(f!,j!,fj!,solution(p),residual(p),jacobian(p))
+end
+
+function NLSolvers_nlsolve(p;kwargs...)
+    @assert uses_mutable_types(p)
+    workspace = NLSolvers_nlsolve_setup(p)
+    function update(workspace,p)
+        workspace = NLSolvers_nlsolve_setup(p)
+    end
+    function step(workspace,p,phase=:start;kwargs...)
+        if phase === :stop
+            return nothing
+        end
+        df = workspace
+        x = solution(p)
+        result = NLsolve.nlsolve(df,x;kwargs...)
+        copyto!(x,result.zero)
+        phase = :stop
+        workspace,p,phase
+    end
+    nonlinear_solver(update,step,p,workspace)
+end
+
+function NLSolvers_nlsolve_linsolve(solver,p)
+    x = solution(p)
+    A = jacobian(p)
+    r = residual(p)
+    dx = similar(x,axes(A,2))
+    lp = linear_problem(dx,A,r)
+    ls = solver(lp)
+    function linsolve(dx,A,b)
+        ls = update(ls,matrix=A)
+        ldiv!(dx,ls,b)
+        dx
+    end
+end
+
diff --git a/PartitionedSolvers/test/amg_tests.jl b/PartitionedSolvers/test/amg_tests.jl
index a7328c8c..82a849f9 100644
--- a/PartitionedSolvers/test/amg_tests.jl
+++ b/PartitionedSolvers/test/amg_tests.jl
@@ -1,7 +1,8 @@
 module AMGTests
 
 using PartitionedArrays
-using PartitionedSolvers
+import PartitionedSolvers as PS
+import PartitionedSolvers
 using LinearAlgebra
 using IterativeSolvers
 using IterativeSolvers: cg!
@@ -222,61 +223,56 @@ b = A*x
 y = similar(x)
 y .= 0
 
-solver = amg()
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
-finalize!(S)
+p = PS.linear_problem(y,A,b)
 
-amg_statistics(S) |> display
+S = PS.amg(p)
+S = PS.solve(S)
+S = PS.update(S,matrix=2*A)
+S = PS.solve(S)
+
+PS.amg_statistics(S) |> display
 
 # Non-default options
 
-level_params = amg_level_params(;
-    pre_smoother = PartitionedSolvers.jacobi(;iters=10,omega=2/3),
-    cycle = v_cycle,
+level_params = PS.amg_level_params(;
+    pre_smoother = p->PartitionedSolvers.jacobi(p;iterations=10,omega=2/3),
+    cycle = PS.v_cycle,
    )
 
-fine_params = amg_fine_params(;
+fine_params = PS.amg_fine_params(;
     level_params,
     n_fine_levels=5)
 
 coarse_params = (;
-    coarse_solver = lu_solver(),
+    coarse_solver = PS.LinearAlgebra_lu,
     coarse_size = 15,
    )
 
-solver = amg(;fine_params,coarse_params)
 
 # Now with a nullspace
 
-B = default_nullspace(A)
-S = setup(solver,y,A,b;nullspace=B)
-solve!(y,S,b)
-update!(S,2*A;nullspace=B)
-solve!(y,S,b)
-finalize!(S)
+B = PS.default_nullspace(A)
+p = PS.linear_problem(y,A,b;nullspace=B)
+
+S = PS.amg(p;fine_params,coarse_params)
+PS.solve(S)
+PS.update(S,matrix=2*A)
+PS.solve(S)
 
 # Now as a preconditioner
 
-level_params = amg_level_params(;
-   pre_smoother = PartitionedSolvers.gauss_seidel(;iters=1),
+level_params = PS.amg_level_params(;
+   pre_smoother = p->PartitionedSolvers.gauss_seidel(p;iterations=1),
    )
 
-fine_params = amg_fine_params(;level_params)
-
-Pl = setup(amg(;fine_params),y,A,b;nullspace=B)
-y .= 0
-cg!(y,A,b;Pl,verbose=true)
+fine_params = PS.amg_fine_params(;level_params)
 
-solver = linear_solver(IterativeSolvers.cg;Pl=amg(;fine_params),verbose=true)
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
+p = PS.linear_problem(y,A,b;nullspace=B)
 
+Pl = PS.amg(p;fine_params)
 
+y .= 0
+cg!(y,A,b;Pl,verbose=true)
 
 # Now for linear elasticity (sequential)
 
@@ -297,17 +293,16 @@ b = A*x_exact
 y = similar(x_exact)
 y .= 0
 
-level_params = amg_level_params_linear_elasticity(block_size)
-fine_params = amg_fine_params(;level_params)
-solver = amg(;fine_params)
+level_params = PS.amg_level_params_linear_elasticity(block_size)
+fine_params = PS.amg_fine_params(;level_params)
+
+p = PS.linear_problem(y,A,b;nullspace=B)
+Pl = PS.amg(p;fine_params)
 
-Pl = setup(solver,y,A,b;nullspace=B)
 cg!(y,A,b;Pl,verbose=true)
 println("Linear Elasticity norm of error: $(norm(y-x_exact))")
 @test y ≈ x_exact
 
-
-
 # Now in parallel
 
 parts_per_dir = (2,2)
@@ -322,36 +317,34 @@ b = A*x
 y = similar(x)
 y .= 0
 
-solver = amg()
-S = setup(solver,y,A,b)
-amg_statistics(S) |> display
-solve!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
-finalize!(S)
+p = PS.linear_problem(y,A,b)
+
+S = PS.amg(p)
+PS.amg_statistics(S) |> display
+PS.solve(S)
+PS.update(S,matrix=2*A)
+PS.solve(S)
 
 # Now with a nullspace
 
-B = default_nullspace(A)
-solver = amg()
-S = setup(solver,y,A,b;nullspace=B)
-solve!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
-finalize!(S)
-
-level_params = amg_level_params(;
-    pre_smoother = PartitionedSolvers.jacobi(;iters=1,omega=2/3),
-    coarsening = smoothed_aggregation(;repartition_threshold=10000000)
+B = PS.default_nullspace(A)
+p = PS.linear_problem(y,A,b;nullspace=B)
+S = PS.amg(p)
+PS.solve(S)
+PS.update(S,matrix=2*A)
+PS.solve(S)
+
+level_params = PS.amg_level_params(;
+    pre_smoother = p->PartitionedSolvers.jacobi(p;iterations=1,omega=2/3),
+    coarsening = PS.smoothed_aggregation(;repartition_threshold=10000000)
    )
 
-fine_params = amg_fine_params(;
+fine_params = PS.amg_fine_params(;
     level_params,
     n_fine_levels=5)
 
-solver = amg(;fine_params)
-
-Pl = setup(solver,y,A,b;nullspace=B)
+p = PS.linear_problem(y,A,b;nullspace=B)
+Pl = PS.amg(p;fine_params)
 y .= 0
 cg!(y,A,b;Pl,verbose=true)
 
@@ -364,7 +357,9 @@ x_exact = pones(partition(axes(A,2)))
 b = A*x_exact
 x = similar(b,axes(A,2))
 x .= 0
-Pl = setup(amg(),x,A,b)
+
+p = PS.linear_problem(x,A,b)
+Pl = PS.amg(p)
 _, history = cg!(x,A,b;Pl,log=true)
 display(history)
 
@@ -385,12 +380,12 @@ b = A*x_exact
 y = similar(b,axes(A,2))
 y .= 0
 
-level_params = amg_level_params_linear_elasticity(block_size)
-fine_params = amg_fine_params(;level_params)
-solver = amg(;fine_params)
+level_params = PS.amg_level_params_linear_elasticity(block_size)
+fine_params = PS.amg_fine_params(;level_params)
 
-Pl = setup(solver,y,A,b;nullspace=B) 
+p = PS.linear_problem(y,A,b;nullspace=B)
+Pl = PS.amg(p;fine_params)
 cg!(y,A,b;Pl,verbose=true)
 @test y ≈ x_exact
 
-end
\ No newline at end of file
+end
diff --git a/PartitionedSolvers/test/interfaces_tests.jl b/PartitionedSolvers/test/interfaces_tests.jl
new file mode 100644
index 00000000..e2a057a8
--- /dev/null
+++ b/PartitionedSolvers/test/interfaces_tests.jl
@@ -0,0 +1,259 @@
+module InterfacesTests
+
+import PartitionedSolvers as PS
+
+using Test
+
+function mock_linear_solver(p)
+    @assert ! PS.uses_mutable_types(p)
+    Ainv = 1/PS.matrix(p)
+    function update(Ainv,A)
+        Ainv = 1/A
+    end
+    function step(x,Ainv,b,phase=:start)
+        #if phase === :stop
+        #    return nothing
+        #end
+        x = Ainv*b
+        phase = :stop
+        x,Ainv,phase
+    end
+    uses_initial_guess = false
+    PS.linear_solver(update,step,p,Ainv;uses_initial_guess)
+end
+
+#function main()
+#    x = 0.0
+#    A = 2.0
+#    b = 12.0
+#    @time lp = PS.linear_problem(x,A,b)
+#    @time ls = mock_linear_solver(lp)
+#    @time ls = PS.solve(ls)
+#    @time x = PS.solution(ls)
+#    @time ls = PS.update(ls,matrix=2*A)
+#    @time ls = PS.solve(ls)
+#end
+#main()
+
+x = 0.0
+A = 2.0
+b = 12.0
+lp = PS.linear_problem(x,A,b;uses_mutable_types=false)
+ls = mock_linear_solver(lp)
+ls = PS.solve(ls)
+x = PS.solution(ls)
+@test x == A\b
+
+ls,phase = PS.step(ls)
+ls,phase = PS.step(ls,phase)
+@test phase === :stop
+
+ls = PS.update(ls,matrix=2*A)
+ls = PS.solve(ls)
+x = PS.solution(ls)
+@test x == (2*A)\b
+
+ls = PS.update(ls,rhs=4*b)
+ls = PS.solve(ls)
+x = PS.solution(ls)
+@test x == (2*A)\(4*b)
+
+for ls in PS.history(ls)
+    @show ls
+end
+
+for x in PS.history(PS.solution,ls)
+    @show x
+end
+
+function mock_nonlinear_problem(x0)
+    r0 = 0*x0
+    j0 = 0*x0
+    workspace = nothing
+    PS.nonlinear_problem(x0,r0,j0,workspace;uses_mutable_types=false) do p
+        x = PS.solution(p)
+        if PS.residual(p) !== nothing
+            p = PS.update(p,residual = 2*x^2 - 4)
+        end
+        if PS.jacobian(p) !== nothing
+            p = PS.update(p,jacobian = 4*x)
+        end
+        p
+    end |> PS.update
+end
+
+function mock_nonlinear_solver_update(ws,p)
+    (;ls,iteration,iterations) = ws
+    ls = PS.update(ls;matrix=PS.jacobian(p),rhs=PS.residual(p))
+    (;ls,iteration,iterations)
+end
+
+function mock_nonlinear_solver_step(ws,p,phase=:start)
+    #if phase === :stop
+    #    return nothing
+    #end
+    (;ls,iteration,iterations) = ws
+    if phase === :start
+        iteration = 0
+        phase = :advance
+    end
+    ls = PS.solve(ls)
+    x = PS.solution(p)
+    x -= PS.solution(ls)
+    p = PS.update(p,solution=x)
+    iteration += 1
+    if iteration == iterations
+        phase = :stop
+    end
+    ws = (;ls,iteration,iterations)
+    ws = mock_nonlinear_solver_update(ws,p)
+    ws,p,phase
+end
+
+function mock_nonlinear_solver(p;solver=mock_linear_solver,iterations=10)
+    @assert ! PS.uses_mutable_types(p)
+    iteration = 0
+    dx = PS.solution(p)
+    lp = PS.linear_problem(dx,PS.jacobian(p),PS.residual(p);uses_mutable_types=false)
+    ls = solver(lp)
+    workspace = (;ls,iteration,iterations)
+    PS.nonlinear_solver(
+        mock_nonlinear_solver_update,
+        mock_nonlinear_solver_step,
+        p,
+        workspace)
+end
+
+#function main()
+#    x = 1.0
+#    @time p = mock_nonlinear_problem(x)
+#    @time s = mock_nonlinear_solver(p)
+#    @time s = PS.solve(s)
+#end
+#
+#main()
+
+x = 1.0
+p = mock_nonlinear_problem(x)
+@show PS.residual(p)
+@show PS.jacobian(p)
+
+s = mock_nonlinear_solver(p)
+s = PS.solve(s)
+
+x = 1.0
+s = PS.update(s,solution=x)
+for x in PS.history(PS.solution,s)
+    @show x
+end
+
+function mock_ode(u)
+    r = 0*u
+    j = 0*u
+    v = 0*u
+    x = (0,u,v)
+    ts = (0,10)
+    dx = (u,u)
+    workspace = nothing
+    PS.ode_problem(x,r,j,ts,dx,workspace;uses_mutable_types=false) do ode
+        (t,u2,v2) = PS.solution(ode)
+        du,dv = PS.coefficients(ode)
+        if PS.residual(ode) !== nothing
+            ode = PS.update(ode,residual = 2*u2^2 + v2 - 4*t + 1)
+        end
+        if PS.jacobian(ode) !== nothing
+            ode = PS.update(ode,jacobian = 4*u2*du + dv)
+        end
+        ode
+    end |> PS.update
+end
+
+function mock_ode_solver_problem(x0,dt,ode0)
+    t,u, = PS.solution(ode0)
+    workspace = nothing
+    PS.nonlinear_problem(PS.residual(ode0),PS.jacobian(ode0),x0,workspace;uses_mutable_types=false) do p
+        x = PS.solution(p)
+        v = (x - u) / dt
+        r = PS.residual(p)
+        j = PS.jacobian(p)
+        ode = PS.update(ode0,residual=r,jacobian=p,solution=(t,x,v))
+        r = PS.residual(ode)
+        j = PS.jacobian(ode)
+        p = PS.update(p,residual=r,jacobian=j)
+    end
+end
+
+function mock_ode_solver_update(workspace,ode0)
+    (;s,dt) = workspace
+    ode = PS.update(ode0,coefficients=(1.0,1/dt))
+    x = PS.solution(s)
+    p = mock_ode_solver_problem(x,dt,ode)
+    s = PS.update(s,problem=p)
+    (;s,dt)
+end
+
+function mock_ode_solver_step(workspace,ode,phase=:start)
+    #if phase === :stop
+    #    return nothing
+    #end
+    (;s,dt) = workspace
+    t,u,v = PS.solution(ode)
+    if phase === :start
+        t = first(PS.interval(ode))
+        phase = :advance
+    end
+    s = PS.solve(s)
+    x = PS.solution(s)
+    t += dt
+    v = (x - u) / dt
+    u = x
+    ode = PS.update(ode,solution=(t,u,v))
+    tend = last(PS.interval(ode))
+    if t >= tend
+        phase = :stop
+    end
+    workspace = (;s,dt)
+    workspace = mock_ode_solver_update(workspace,ode)
+    workspace,ode,phase
+end
+
+function mock_ode_solver(ode0;
+        dt = (PS.interval(ode0)[2]-PS.interval(ode0)[1])/10,
+        solver = mock_nonlinear_solver)
+
+    @assert ! PS.uses_mutable_types(ode0)
+    ode = PS.update(ode0,coefficients=(1.0,1/dt))
+    _,u,_ = PS.solution(ode)
+    x = u
+    p = mock_ode_solver_problem(x,dt,ode)
+    s = solver(p)
+    workspace = (;s,dt)
+    PS.ode_solver(mock_ode_solver_update,mock_ode_solver_step,ode,workspace)
+end
+
+#function main()
+#    u = 2.0
+#    p = mock_ode(u)
+#    s = mock_ode_solver(p)
+#    for x in PS.history(s)
+#        @show x
+#    end
+#    s = PS.update(s,solution=(0.0,u,0.0))
+#    @time s = PS.solve(s)
+#end
+#
+#main()
+
+u = 2.0
+p = mock_ode(u)
+s = mock_ode_solver(p)
+
+for x in PS.history(PS.solution,s)
+    @show x
+end
+#s = PS.update(s,solution=(0.0,u,0.0))
+#PS.solve(s)
+
+end # module
+
+
diff --git a/PartitionedSolvers/test/runtests.jl b/PartitionedSolvers/test/runtests.jl
index 2b061722..a4ffbe83 100644
--- a/PartitionedSolvers/test/runtests.jl
+++ b/PartitionedSolvers/test/runtests.jl
@@ -5,6 +5,8 @@ using PartitionedSolvers
 using Test
 
 @testset "PartitionedSolvers" begin
+    @testset "interfaces" begin include("interfaces_tests.jl") end
+    @testset "wrappers" begin include("wrappers_tests.jl") end
     @testset "smoothers" begin include("smoothers_tests.jl") end
     @testset "amg" begin include("amg_tests.jl") end
 end
diff --git a/PartitionedSolvers/test/smoothers_tests.jl b/PartitionedSolvers/test/smoothers_tests.jl
index cc978fe4..4a34c673 100644
--- a/PartitionedSolvers/test/smoothers_tests.jl
+++ b/PartitionedSolvers/test/smoothers_tests.jl
@@ -1,108 +1,130 @@
 module SmoothersTests
 
+import PartitionedSolvers as PS
 using PartitionedArrays
-using PartitionedArrays: laplace_matrix
-using PartitionedSolvers
 using LinearAlgebra
 using Test
+using SparseMatricesCSR
 
 np = 4
 parts = DebugArray(LinearIndices((np,)))
-
 parts_per_dir = (2,2)
 nodes_per_dir = (8,8)
-A = laplace_matrix(nodes_per_dir,parts_per_dir,parts)
+args = laplacian_fem(nodes_per_dir,parts_per_dir,parts)
+A = psparse(args...) |> fetch
 x = pones(partition(axes(A,2)))
 b = A*x
 
-solver = lu_solver()
-y = similar(x)
-S = setup(solver,y,A,b)
-solve!(y,S,b)
 tol = 1.e-8
+y = similar(x)
+y .= 0
+p = PS.linear_problem(y,A,b)
+s = PS.jacobi(p;iterations=1000)
+s = PS.solve(s)
+y = PS.solution(s)
 @test norm(y-x)/norm(x) < tol
-update!(S,2*A)
-solve!(y,S,b)
-@test norm(y-x/2)/norm(x/2) < tol
 
-istep = Ref(0)
-for y in iterations!(y,S,b)
-    isa(y,AbstractVector)
-    istep[] += 1
+s = PS.update(s,matrix=2*A)
+for x in PS.history(PS.solution,s)
 end
-@test istep[] == 1
+@test norm(y-x/2)/norm(x/2) < tol
 
-y2,hist = solve!(y,S,b;history=true)
-@test y2 === y
-@test hist === nothing
+y .= 0
+p = PS.update(p,solution=y)
+s = PS.additive_schwarz(p)
+s = PS.solve(s)
 
-finalize!(S)
+s = PS.additive_schwarz(p;local_solver=PS.jacobi)
+s = PS.solve(s)
 
-solver = linear_solver(LinearAlgebra.lu)
-y = similar(x)
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-tol = 1.e-8
-@test norm(y-x)/norm(x) < tol
-update!(S,2*A)
-solve!(y,S,b)
-@test norm(y-x/2)/norm(x/2) < tol
-finalize!(S)
+y .= 0
+p = PS.update(p,solution=y)
+s = PS.additive_schwarz(p;local_solver=PS.gauss_seidel)
+s = PS.solve(s)
 
-solver = richardson(lu_solver(),iters=1)
+args = laplacian_fdm(nodes_per_dir,parts_per_dir,parts)
+T = SparseMatrixCSR{1,Float64,Int32}
+A = psparse(T,args...;assembled=true) |> fetch
+x = pones(partition(axes(A,2)))
+b = A*x
 y = similar(x)
 y .= 0
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-tol = 1.e-8
-@test norm(y-x)/norm(x) < tol
-update!(S,2*A)
-solve!(y,S,b)
-@test norm(y-x/2)/norm(x/2) < tol
-finalize!(S)
+p = PS.linear_problem(y,A,b)
+s = PS.gauss_seidel(p)
+s = PS.solve(s;zero_guess=true)
+n1 = norm(PS.solution(s))
+y .= 0
+s = PS.update(s,solution=y)
+s = PS.solve(s)
+n2 = norm(PS.solution(s))
+n = n1
+@test n ≈ n1 ≈ n2
 
-solver = jacobi(;iters=1000)
+args = laplacian_fdm(nodes_per_dir,parts_per_dir,parts)
+T = SparseMatrixCSR{1,Float64,Int32}
+A = psparse(T,args...;assembled=true,split_format=false) |> fetch
+x = pones(partition(axes(A,2)))
+b = A*x
 y = similar(x)
 y .= 0
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-tol = 1.e-8
-@test norm(y-x)/norm(x) < tol
-update!(S,2*A)
-solve!(y,S,b)
-@test norm(y-x/2)/norm(x/2) < tol
-
-y,hist = solve!(y,S,b;history=true)
-@test hist.iters == 1000
+p = PS.linear_problem(y,A,b)
+s = PS.gauss_seidel(p)
+s = PS.solve(s;zero_guess=true)
+n1 = norm(PS.solution(s))
+y .= 0
+s = PS.update(s,solution=y)
+s = PS.solve(s)
+n2 = norm(PS.solution(s))
+@test n ≈ n1 ≈ n2
 
-istep = Ref(0)
+args = laplacian_fdm(nodes_per_dir,parts_per_dir,parts)
+A = psparse(args...;assembled=true,split_format=false) |> fetch
+x = pones(partition(axes(A,2)))
+b = A*x
+y = similar(x)
 y .= 0
-for y in iterations!(y,S,b)
-    isa(y,AbstractVector)
-    istep[] += 1
-end
-@test norm(y-x/2)/norm(x/2) < tol
-@test istep[] == 1000
-finalize!(S)
+p = PS.linear_problem(y,A,b)
+s = PS.gauss_seidel(p)
+s = PS.solve(s;zero_guess=true)
+n1 = norm(PS.solution(s))
+y .= 0
+s = PS.update(s,solution=y)
+s = PS.solve(s)
+n2 = norm(PS.solution(s))
+@test n ≈ n1 ≈ n2
 
-solver = additive_schwarz(lu_solver())
+args = laplacian_fdm(nodes_per_dir)
+A = sparse_matrix(args...)
+x = ones(axes(A,2))
+b = A*x
 y = similar(x)
 y .= 0
-S = setup(solver,y,A,b)
-solve!(y,S,b;zero_guess=true)
-solve!(y,S,b)
-ldiv!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
-finalize!(S)
+p = PS.linear_problem(y,A,b)
+s = PS.gauss_seidel(p)
+s = PS.solve(s;zero_guess=true)
+n1 = norm(PS.solution(s))
+n = n1
+y .= 0
+s = PS.update(s,solution=y)
+s = PS.solve(s)
+n2 = norm(PS.solution(s))
+@test n ≈ n1 ≈ n2
 
-solver = additive_schwarz(gauss_seidel(;iters=1))
+args = laplacian_fdm(nodes_per_dir)
+A = sparse_matrix(T,args...)
+x = ones(axes(A,2))
+b = A*x
 y = similar(x)
 y .= 0
-S = setup(solver,y,A,b)
-solve!(y,S,b)
-update!(S,2*A)
-solve!(y,S,b)
-finalize!(S)
+p = PS.linear_problem(y,A,b)
+s = PS.gauss_seidel(p)
+s = PS.solve(s;zero_guess=true)
+n1 = norm(PS.solution(s))
+y .= 0
+s = PS.update(s,solution=y)
+s = PS.solve(s)
+n2 = norm(PS.solution(s))
+@test n ≈ n1 ≈ n2
+
 
-end #module
+end # module
diff --git a/PartitionedSolvers/test/wrappers_tests.jl b/PartitionedSolvers/test/wrappers_tests.jl
new file mode 100644
index 00000000..2fd85aee
--- /dev/null
+++ b/PartitionedSolvers/test/wrappers_tests.jl
@@ -0,0 +1,68 @@
+module WrappersTests
+
+import PartitionedSolvers as PS
+import PartitionedArrays as PA
+using Test
+using LinearAlgebra
+
+nodes=(10,10)
+args = PA.laplacian_fem(nodes)
+A = PA.sparse_matrix(args...)
+x_exact = ones(axes(A,2))
+b = A*x_exact
+
+x = similar(x_exact)
+p = PS.linear_problem(x,A,b)
+
+s = PS.LinearAlgebra_lu(p)
+s = PS.solve(s)
+x = PS.solution(s)
+@test x ≈ x_exact
+
+s = PS.update(s,matrix=2*A)
+s = PS.solve(s)
+x = PS.solution(s)
+@test x*2 ≈ x_exact
+
+ldiv!(x,s,b)
+PS.smooth!(x,s,b)
+
+x .= 0
+p = PS.update(p,solution=x)
+s = PS.IterativeSolvers_cg(p;verbose=true)
+s = PS.solve(s)
+
+Pl = PS.LinearAlgebra_lu(p)
+x .= 0
+p = PS.update(p,solution=x)
+s = PS.IterativeSolvers_cg(p;verbose=true,Pl)
+s = PS.solve(s)
+
+r = similar(b)
+w = nothing
+p = PS.nonlinear_problem(x,r,A,w) do p2
+    x2 = PS.solution(p2)
+    if PS.residual(p2) !== nothing
+        r2 = PS.residual(p2)
+        mul!(r2,A,x2)
+        r2 .-= b
+        p2 = PS.update(p2,residual = r2)
+    end
+    if PS.jacobian(p2) !== nothing
+        p2 = PS.update(p2,jacobian = A)
+    end
+    p2
+end
+
+x .= 0
+p = PS.update(p,solution=x)
+s = PS.NLSolvers_nlsolve(p;show_trace=true,method=:newton)
+s = PS.solve(s)
+
+linsolve = PS.NLSolvers_nlsolve_linsolve(PS.LinearAlgebra_lu,p)
+x .= 0
+p = PS.update(p,solution=x)
+s = PS.NLSolvers_nlsolve(p;show_trace=true,linsolve,method=:newton)
+s = PS.solve(s)
+
+end # module
diff --git a/docs/examples.jl b/docs/examples.jl
index 12af2f57..af09fdbe 100644
--- a/docs/examples.jl
+++ b/docs/examples.jl
@@ -236,10 +236,10 @@ history
 
 # Now solve the system while using an AMG preconditioner.
 
-using PartitionedSolvers
+import PartitionedSolvers as PS
 
 x .= 0
-Pl = setup(amg(),x,A,b)
+Pl = PS.amg(PS.linear_problem(x,A,b))
 _, history = IterativeSolvers.cg!(x,A,b;Pl,log=true)
 history
 
diff --git a/src/gallery.jl b/src/gallery.jl
index f68fec9d..b48f9575 100644
--- a/src/gallery.jl
+++ b/src/gallery.jl
@@ -85,6 +85,18 @@ function laplacian_fdm(
     I,J,V,node_partition,node_partition
 end
 
+function laplacian_fdm(nodes_per_dir;kwargs...)
+    parts_per_dir = map(n->1,nodes_per_dir)
+    ranks = LinearIndices((1,))
+    args = laplacian_fdm(nodes_per_dir,parts_per_dir,ranks;kwargs...)
+    I = getany(args[1])
+    J = getany(args[2])
+    V = getany(args[3])
+    nrows = length(PRange(args[4]))
+    ncols = length(PRange(args[5]))
+    I,J,V,nrows,ncols
+end
+
 """
     laplacian_fem(
             nodes_per_dir,
@@ -226,6 +238,18 @@ function laplacian_fem(
     I,J,V,node_partition,node_partition
 end
 
+function laplacian_fem(nodes_per_dir;kwargs...)
+    parts_per_dir = map(n->1,nodes_per_dir)
+    ranks = LinearIndices((1,))
+    args = laplacian_fem(nodes_per_dir,parts_per_dir,ranks;kwargs...)
+    I = getany(args[1])
+    J = getany(args[2])
+    V = getany(args[3])
+    nrows = length(PRange(args[4]))
+    ncols = length(PRange(args[5]))
+    I,J,V,nrows,ncols
+end
+
 function linear_elasticity_fem(
         nodes_per_dir, # free (== interior) nodes
         parts_per_dir,
@@ -386,6 +410,18 @@ function linear_elasticity_fem(
     I,J,V,dof_partition,dof_partition
 end
 
+function linear_elasticity_fem(nodes_per_dir;kwargs...)
+    parts_per_dir = map(n->1,nodes_per_dir)
+    ranks = LinearIndices((1,))
+    args = linear_elasticity_fem(nodes_per_dir,parts_per_dir,ranks;kwargs...)
+    I = getany(args[1])
+    J = getany(args[2])
+    V = getany(args[3])
+    nrows = length(PRange(args[4]))
+    ncols = length(PRange(args[5]))
+    I,J,V,nrows,ncols
+end
+
 function node_to_dof_partition(node_partition,D)
     global_node_to_owner = global_to_owner(node_partition)
     dof_partition = map(node_partition) do mynodes
diff --git a/src/p_sparse_matrix.jl b/src/p_sparse_matrix.jl
index f2b65fc8..6ad7b98e 100644
--- a/src/p_sparse_matrix.jl
+++ b/src/p_sparse_matrix.jl
@@ -1353,7 +1353,7 @@ function assemble!(B::PSparseMatrix,A::PSparseMatrix,cache)
     psparse_assemble_impl!(B,A,T,cache)
 end
 
-function psparse_assemble_impl(A,::Type,rows)
+function psparse_assemble_impl(A,::Type,rows;kwargs...)
     error("Case not implemented yet")
 end