<?xml version="1.0" encoding="utf-8" standalone="yes" ?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
<channel>
<title>FS</title>
<link>https://frankschae.github.io/</link>
<atom:link href="https://frankschae.github.io/index.xml" rel="self" type="application/rss+xml" />
<description>FS</description>
<generator>Hugo Blox Builder (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Mon, 24 Oct 2022 00:00:00 +0000</lastBuildDate>
<image>
<url>https://frankschae.github.io/media/icon_hu0b7a4cb9992c9ac0e91bd28ffd38dd00_9727_512x512_fill_lanczos_center_3.png</url>
<title>FS</title>
<link>https://frankschae.github.io/</link>
</image>
<item>
<title>Example Talk</title>
<link>https://frankschae.github.io/talk/example-talk/</link>
<pubDate>Sat, 01 Jun 2030 13:00:00 +0000</pubDate>
<guid>https://frankschae.github.io/talk/example-talk/</guid>
<description><div class="alert alert-note">
<div>
Click on the <strong>Slides</strong> button above to view the built-in slides feature.
</div>
</div>
<p>Slides can be added in a few ways:</p>
<ul>
<li><strong>Create</strong> slides using Hugo Blox Builder&rsquo;s <a href="https://docs.hugoblox.com/reference/content-types/" target="_blank" rel="noopener"><em>Slides</em></a> feature and link using <code>slides</code> parameter in the front matter of the talk file</li>
<li><strong>Upload</strong> an existing slide deck to <code>static/</code> and link using <code>url_slides</code> parameter in the front matter of the talk file</li>
<li><strong>Embed</strong> your slides (e.g. Google Slides) or presentation video on this page using <a href="https://docs.hugoblox.com/reference/markdown/" target="_blank" rel="noopener">shortcodes</a>.</li>
</ul>
<p>Further event details, including <a href="https://docs.hugoblox.com/reference/markdown/" target="_blank" rel="noopener">page elements</a> such as image galleries, can be added to the body of this page.</p>
</description>
</item>
<item>
<title>Automatic Differentiation of Programs with Discrete Randomness</title>
<link>https://frankschae.github.io/project/stochasticad/</link>
<pubDate>Mon, 15 Jan 2024 10:49:24 -0500</pubDate>
<guid>https://frankschae.github.io/project/stochasticad/</guid>
<description><p>Automatic differentiation (AD) has become ubiquitous throughout scientific computing and deep learning. However, AD systems have been restricted to the subset of programs that have a continuous dependence on parameters. Programs that have discrete stochastic behaviors governed by distribution parameters, such as flipping a coin with probability p of being heads, pose a challenge to these systems. In this work, we develop a new AD methodology for programs with discrete randomness. We demonstrate that this method yields an unbiased and low-variance estimator of the derivative.</p>
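<p>Here is a minimal illustration of the problem. It is written in plain Julia and is not the method developed in this project. A coin flip written as <code>u ≤ p ? 1.0 : 0.0</code> with <code>u = rand()</code> is piecewise constant in p, so its pathwise derivative is zero almost everywhere, even though the expectation E[X] = p has derivative 1 with respect to p. The sketch therefore uses the classic score-function (REINFORCE) estimator instead, which is a swapped-in textbook technique. The helper names <code>coin</code>, <code>score</code>, and <code>score_function_derivative</code> are hypothetical.</p>

```julia
using Random, Statistics

# Hypothetical helper: a coin flip that returns 1.0 (heads) with probability p.
# For a fixed random number u the output is piecewise constant in p, so the
# pathwise derivative d/dp coin(p, u) is zero for every p ≠ u.
coin(p, u) = u ≤ p ? 1.0 : 0.0

# Score function of the Bernoulli distribution: d/dp log P(X = x; p).
score(p, x) = x / p - (1 - x) / (1 - p)

# Score-function (REINFORCE) estimate of d/dp E[X] from samples of X.
score_function_derivative(p, samples) = mean(x * score(p, x) for x in samples)

Random.seed!(1)
p = 0.3
samples = [coin(p, rand()) for _ in 1:100_000]

println("E[X] estimate:      ", mean(samples))                          # close to p
println("d/dp E[X] estimate: ", score_function_derivative(p, samples))  # close to 1
```

<p>For this estimator the per-sample variance is 1/p - 1, which grows as p approaches 0.</p>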
</description>
</item>
<item>
<title>Performance Bounds for Quantum Control</title>
<link>https://frankschae.github.io/project/controlbounds/</link>
<pubDate>Sat, 29 Apr 2023 17:26:49 -0400</pubDate>
<guid>https://frankschae.github.io/project/controlbounds/</guid>
<description><p>Control of devices at the quantum level holds enormous potential for current and future applications in the field of quantum information science. However, due to the nonlinear and stochastic nature of quantum systems under continuous observation, analytical solutions to all but the simplest quantum control problems remain unknown. In this project, we present a convex optimization framework to compute informative bounds on the best attainable control performance. Since our approach provides an under-approximator for the value function, we can use it directly to construct near-optimal heuristic controllers as demonstrated for a qubit subjected to homodyne detection and photon counting.</p>
</description>
</item>
<item>
<title>StochasticAD.jl: Automatic Differentiation of Programs with Discrete Randomness</title>
<link>https://frankschae.github.io/software/stochasticad/</link>
<pubDate>Mon, 16 Jan 2023 00:20:37 +0100</pubDate>
<guid>https://frankschae.github.io/software/stochasticad/</guid>
<description></description>
</item>
<item>
<title>AbstractDifferentiation.jl: Backend-Agnostic Differentiable Programming in Julia</title>
<link>https://frankschae.github.io/software/abstractdifferentiation/</link>
<pubDate>Fri, 03 Dec 2021 00:20:37 +0100</pubDate>
<guid>https://frankschae.github.io/software/abstractdifferentiation/</guid>
<description></description>
</item>
<item>
<title>Neural Hybrid Differential Equations and Adjoint Sensitivity Analysis</title>
<link>https://frankschae.github.io/post/gsoc-2021/</link>
<pubDate>Fri, 13 Aug 2021 21:41:45 +0200</pubDate>
<guid>https://frankschae.github.io/post/gsoc-2021/</guid>
<description><h2 id="project-summary">Project summary</h2>
<p>In this project, we have implemented state-of-the-art sensitivity tools for chaotic dynamical systems, continuous adjoint sensitivity methods for hybrid differential equations, and a high-level API for automatic differentiation.</p>
<p>Possible fields of application for these tools range from model discovery with explicit dosing times in pharmacology, through accurate gradient estimates for chaotic fluid dynamics, to the control of open quantum systems. A more detailed summary is available <a href="https://summerofcode.withgoogle.com/projects/#5357798591823872" target="_blank" rel="noopener">on the GSoC page</a>.</p>
<h2 id="blog-posts">Blog posts</h2>
<p>The following blog posts describe the work throughout the GSoC period in more detail:</p>
<ol>
<li><a href="https://frankschae.github.io/post/hybridde/" target="_blank" rel="noopener">Neural Hybrid Differential Equations</a></li>
<li><a href="https://frankschae.github.io/post/shadowing/" target="_blank" rel="noopener">Shadowing Methods for Forward and Adjoint Sensitivity Analysis of Chaotic Systems</a></li>
<li><a href="https://frankschae.github.io/post/bouncing_ball/" target="_blank" rel="noopener">Sensitivity Analysis of Hybrid Differential Equations</a></li>
<li><a href="https://frankschae.github.io/post/abstract_differentiation/" target="_blank" rel="noopener">AbstractDifferentiation.jl for AD-backend agnostic code</a></li>
</ol>
<h2 id="docs">Docs</h2>
<p>Documentation for the adjoint sensitivity tools will be available on the <a href="https://diffeq.sciml.ai/latest/analysis/sensitivity/" target="_blank" rel="noopener">local sensitivity analysis</a> and <a href="http://scimlbase.sciml.ai/dev/fundamentals/Differentiation/" target="_blank" rel="noopener">control of automatic differentiation choices</a> pages.</p>
<h2 id="achievements">Achievements</h2>
<p>Below is a list of PRs in the various repositories in chronological order.</p>
<h4 id="diffeqsensitivityjl">DiffEqSensitivity.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/415" target="_blank" rel="noopener">Add additive noise downstream test for DiffEqFlux</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/416" target="_blank" rel="noopener">DiscreteCallback fixes</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/417" target="_blank" rel="noopener">Allow for changes of p in callbacks</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/418" target="_blank" rel="noopener">Fix for using the correct uleft/pleft in continuous callback</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/419" target="_blank" rel="noopener">Fix broadcasting error on steady state adjoint</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/420" target="_blank" rel="noopener">Forward Least Squares Shadowing (LSS)</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/422" target="_blank" rel="noopener">Adjoint-mode for the LSS method</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/423" target="_blank" rel="noopener">concrete_solve dispatch for LSS methods</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/437" target="_blank" rel="noopener">Non-Intrusive Least Square Shadowing (NILSS)</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/442" target="_blank" rel="noopener">concrete_solve for NILSS</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/443" target="_blank" rel="noopener">Remove allocation in NILSS</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/444" target="_blank" rel="noopener">Handle additional callback case</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/445" target="_blank" rel="noopener">State-dependent Continuous Callbacks for BacksolveAdjoint</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/474" target="_blank" rel="noopener">QuadratureAdjoint() for ContinuousCallback</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/475" target="_blank" rel="noopener">More tests for Neural ODEs with callbacks for different sensitivity algorithms</a></li>
<li><a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/476" target="_blank" rel="noopener">Support for PeriodicCallbacks in continuous adjoint methods</a></li>
</ul>
<h4 id="abstractdifferentiationjl">AbstractDifferentiation.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/2" target="_blank" rel="noopener">Fixes gradient, Jacobian, Hessian, and vjp tests</a></li>
</ul>
<p>Open:</p>
<ul>
<li><a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/3" target="_blank" rel="noopener">Add ForwardDiff and Zygote</a></li>
</ul>
<h4 id="ordinarydiffeqjl">OrdinaryDiffEq.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/OrdinaryDiffEq.jl/pull/1424" target="_blank" rel="noopener">Fix discrete reverse mode for some standard controllers</a></li>
</ul>
<h4 id="diffeqcallbacksjl">DiffEqCallbacks.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/DiffEqCallbacks.jl/pull/102" target="_blank" rel="noopener">Introduce a PeriodicCallbackAffect struct</a></li>
</ul>
<h4 id="steadystatediffeqjl">SteadyStateDiffEq.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/SteadyStateDiffEq.jl/pull/31" target="_blank" rel="noopener">Convert alg.tspan to type of prob.u0</a></li>
</ul>
<h4 id="chainrulesjl">ChainRules.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/JuliaDiff/ChainRules.jl/pull/506" target="_blank" rel="noopener">Do not differentiate through the construction of BitArray</a></li>
<li><a href="https://github.com/JuliaDiff/ChainRules.jl/pull/508" target="_blank" rel="noopener">Use splatting in BitArray</a></li>
</ul>
<h4 id="diffeqnoiseprocessjl">DiffEqNoiseProcess.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/DiffEqNoiseProcess.jl/pull/94" target="_blank" rel="noopener">Allow solvers to use Noise Grid with SVectors</a></li>
</ul>
<h4 id="stochasticdiffeqjl">StochasticDiffEq.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/StochasticDiffEq.jl/pull/428" target="_blank" rel="noopener">Remove Ihat2 matrix from weak solvers</a></li>
</ul>
<h4 id="diffeqdocsjl">DiffEqDocs.jl</h4>
<p>Merged:</p>
<ul>
<li><a href="https://github.com/SciML/DiffEqDocs.jl/pull/490" target="_blank" rel="noopener">Small typo on plot page</a></li>
<li><a href="https://github.com/SciML/DiffEqDocs.jl/pull/492" target="_blank" rel="noopener">Add docs for shadowing methods</a></li>
</ul>
<h2 id="future-work">Future work</h2>
<p>Besides the implementation of more shadowing methods, such as</p>
<ul>
<li><a href="https://arxiv.org/abs/1801.08674" target="_blank" rel="noopener">NILSAS</a>,</li>
<li><a href="https://arxiv.org/abs/1711.06633" target="_blank" rel="noopener">FD-NILSS</a>, or</li>
<li><a href="https://arxiv.org/abs/2009.00595" target="_blank" rel="noopener">Fast linear response</a>,</li>
</ul>
<p>we are planning to</p>
<ul>
<li>benchmark the new adjoints,</li>
<li>refine the AbstractDifferentiation.jl package and use it within DiffEqSensitivity.jl,</li>
<li>add more docs and examples.</li>
</ul>
<p>If you have any further suggestions or comments, check out our Slack/Zulip channels #sciml-bridged and #diffeq-bridged or the <a href="https://discourse.julialang.org/" target="_blank" rel="noopener">Julia language Discourse</a>.</p>
<h2 id="acknowledgement">Acknowledgement</h2>
<p>Many thanks to my mentors <a href="https://github.com/ChrisRackauckas" target="_blank" rel="noopener">Chris Rackauckas</a>, <a href="https://github.com/mschauer" target="_blank" rel="noopener">Moritz Schauer</a>, <a href="https://github.com/YingboMa" target="_blank" rel="noopener">Yingbo Ma</a>, and <a href="https://github.com/mohamed82008" target="_blank" rel="noopener">Mohamed Tarek</a> for their exceptional, continuous support. It was a great opportunity to be part of such an inspiring collaboration. I really appreciated how quickly and flexibly we could schedule our meetings.
I would also like to thank <a href="https://quantumtheory-bruder.physik.unibas.ch/en/people/group-members/christoph-bruder/" target="_blank" rel="noopener">Christoph Bruder</a>, <a href="https://github.com/arnoldjulian" target="_blank" rel="noopener">Julian Arnold</a>, and <a href="https://github.com/mako-git" target="_blank" rel="noopener">Martin Koppenhöfer</a> for helpful comments on my blog posts. Special thanks to <a href="https://github.com/Zymrael" target="_blank" rel="noopener">Michael Poli</a> and <a href="https://github.com/massastrello" target="_blank" rel="noopener">Stefano Massaroli</a> for their suggestions on adjoints for hybrid differential equations. Finally, thanks to the very supportive Julia community and to Google&rsquo;s open source program for funding this experience!</p>
</description>
</item>
<item>
<title>AbstractDifferentiation.jl for AD-backend agnostic code </title>
<link>https://frankschae.github.io/post/abstract_differentiation/</link>
<pubDate>Sun, 01 Aug 2021 12:03:17 +0200</pubDate>
<guid>https://frankschae.github.io/post/abstract_differentiation/</guid>
<description><p><a href="https://sinews.siam.org/Details-Page/scientific-machine-learning-how-julia-employs-differentiable-programming-to-do-it-best" target="_blank" rel="noopener">Differentiable programming (∂P)</a>, i.e., the ability to differentiate general computer program structures, has enabled the efficient combination of existing packages for scientific computation and machine learning<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>. The Julia<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> language is <a href="https://github.com/tensorflow/swift/blob/main/docs/WhySwiftForTensorFlow.md" target="_blank" rel="noopener">well suited for ∂P</a>; see also Chris&rsquo; article<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup> for a detailed examination. There is already a plethora of examples where ∂P has provided massive performance <em>and</em> accuracy advantages over black-box approaches to machine learning. These advantages arise because black-box machine learning approaches, while flexible, require a large amount of data. Incorporating previously acquired knowledge about the structure of a problem reduces the amount of required data and simplifies the learning task<sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>, for example, by focusing on learning only the parts of the model that are actually missing<sup id="fnref1:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> <sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup>.
In the context of quantum control, we have demonstrated the power of this framework for closed<sup id="fnref:6"><a href="#fn:6" class="footnote-ref" role="doc-noteref">6</a></sup> and <a href="https://www.youtube.com/watch?v=uDUwdAqKzYM&amp;list=PLP8iPy9hna6TxktMt-IzdU2vQpGp3bwDn&amp;index=3&amp;t=12s" target="_blank" rel="noopener">open quantum systems</a><sup id="fnref:7"><a href="#fn:7" class="footnote-ref" role="doc-noteref">7</a></sup>.</p>
<p>∂P is commonly realized by automatic differentiation (AD), a family of techniques to efficiently and accurately differentiate numeric functions expressed as computer programs. AD comes in two main branches, forward mode and reverse mode, and <a href="https://juliadiff.org/" target="_blank" rel="noopener">a large variety of software implementations</a> with different <a href="https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083" target="_blank" rel="noopener">pros and cons</a> is available. The goal is to make the best choice in every part of the program without requiring users to significantly customize their code. <a href="https://github.com/JuliaDiff/ChainRules.jl" target="_blank" rel="noopener">ChainRules.jl</a> provides a common ground for this idea of a <a href="http://www.stochasticlifestyle.com/glue-ad-for-full-language-differentiable-programming/" target="_blank" rel="noopener">Glue AD</a>, in which backend developers only need to define ChainRules overloads. However, switching from one backend to another on the user side can still be tedious because the user has to look up the syntax of the new AD package.</p>
<p><a href="https://github.com/mohamed82008" target="_blank" rel="noopener">Mohamed Tarek</a> has started to <a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/1" target="_blank" rel="noopener">implement a high-level API for differentiation</a> that unifies the APIs of all the AD packages in the Julia ecosystem. Ultimately, the API of our new package, <a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl" target="_blank" rel="noopener">AbstractDifferentiation.jl</a>, aims to enable AD users to write AD backend-agnostic code. This will greatly facilitate switching between different AD packages. Once the interface is complete and all tests are added, <a href="https://github.com/SciML/DiffEqSensitivity.jl" target="_blank" rel="noopener">DiffEqSensitivity.jl</a> within the <a href="https://sciml.ai/" target="_blank" rel="noopener">SciML</a> software suite is also planned to adopt AbstractDifferentiation.jl as a better way of handling AD choices. In this part of my GSoC project, I&rsquo;ve started to fix the remaining errors in the <a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/1" target="_blank" rel="noopener">initial PR</a>.</p>
<p>The interested reader is encouraged to look at Mohamed&rsquo;s <a href="https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/1" target="_blank" rel="noopener">first PR</a> for a complete list of functions provided by AbstractDifferentiation.jl (and some great discussions about the package). In the rest of this blog post, I will focus on a concrete example to illustrate the main idea.</p>
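<p>The core idea of backend-agnostic code can be sketched without any AD package. The following self-contained Julia example is not the API of AbstractDifferentiation.jl. The types <code>CentralDifferences</code>, <code>DualNumbers</code>, and <code>Dual</code> and the function <code>grad</code> are hypothetical. The call site <code>grad(backend, f, x)</code> stays the same. Switching from finite differences to forward-mode AD only requires passing a different backend object, and multiple dispatch selects the implementation.</p>

```julia
# Two hypothetical backend types (AbstractDifferentiation.jl defines its own).
struct CentralDifferences
    h::Float64          # finite-difference step size
end
struct DualNumbers end  # forward-mode AD via dual numbers

# Minimal dual number: value `v` and directional derivative `d`.
struct Dual
    v::Float64
    d::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.v + b.v, a.d + b.d)
Base.:-(a::Dual, b::Dual) = Dual(a.v - b.v, a.d - b.d)
Base.:*(a::Dual, b::Dual) = Dual(a.v * b.v, a.d * b.v + a.v * b.d)
Base.:^(a::Dual, n::Integer) = Dual(a.v^n, n * a.v^(n - 1) * a.d)
# Mixed operations with plain numbers, which act as constants (zero derivative).
Base.:+(a::Real, b::Dual) = Dual(a + b.v, b.d)
Base.:-(a::Real, b::Dual) = Dual(a - b.v, -b.d)
Base.:*(a::Real, b::Dual) = Dual(a * b.v, a * b.d)

# Same function name; multiple dispatch picks the implementation per backend.
function grad(b::CentralDifferences, f, x::Vector{Float64})
    out = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = b.h
        out[i] = (f(x + e) - f(x - e)) / (2 * b.h)
    end
    return out
end

function grad(::DualNumbers, f, x::Vector{Float64})
    # One forward pass per input, seeding derivative 1 in component i.
    return map(eachindex(x)) do i
        seeded = [Dual(x[j], i == j ? 1.0 : 0.0) for j in eachindex(x)]
        f(seeded).d
    end
end

# User code: the Rosenbrock function and a backend-agnostic call site.
g(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
p = [1.0, 100.0]
x₀ = [0.0, -0.1]
f = x -> g(x, p)

for backend in (CentralDifferences(1e-6), DualNumbers())
    println(nameof(typeof(backend)), ": ", grad(backend, f, x₀))  # exact gradient: [-2, -20]
end
```

<p>AbstractDifferentiation.jl follows the same pattern of passing backend objects to generic functions such as <code>gradient</code>. Its actual type and function names should be taken from the package documentation.</p>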
<h2 id="optimization-of-the-rosenbrock-function">Optimization of the Rosenbrock function</h2>
<p>The <a href="https://en.wikipedia.org/wiki/Rosenbrock_function" target="_blank" rel="noopener">Rosenbrock function</a> is defined by</p>
<p>$$
g(x_1,x_2) = (a-x_1)^2 + b(x_2-x_1^2)^2.
$$</p>
<p>The function $g$ has a global minimum at $(x_1^\star, x_2^\star)= (a, a^2)$ with $g(x_1^\star, x_2^\star)=0$. In the following, we fix $a = 1$ and $b = 100$. The global minimum lies inside a long, narrow, banana-shaped, flat valley. Finding the valley is easy, but converging along its flat floor to the minimum is difficult, which makes the function a common test case for optimization algorithms.</p>
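<p>These properties can be checked numerically. The following plain-Julia sketch uses the same <code>g(x, p)</code> signature as the code further below, with <code>p = [a, b]</code>. The hand-derived gradient $\nabla g = \big(-2(a-x_1) - 4b\,x_1(x_2-x_1^2),\ 2b(x_2-x_1^2)\big)$ and the helper name <code>grad_g</code> are additions for illustration. The sketch also shows why the valley is hard to traverse: along its floor $x_2 = x_1^2$ the function changes slowly, while small steps off the floor are penalized by the factor $b$.</p>

```julia
# Rosenbrock function with parameters p = [a, b].
g(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

# Gradient derived by hand from the definition (hypothetical helper name).
function grad_g(x, p)
    r = x[2] - x[1]^2
    return [-2 * (p[1] - x[1]) - 4 * p[2] * x[1] * r, 2 * p[2] * r]
end

p = [1.0, 100.0]          # a = 1, b = 100
xstar = [p[1], p[1]^2]    # global minimum (a, a²) = (1, 1)

# Both the value and the gradient vanish at the minimum.
println(g(xstar, p), "  ", grad_g(xstar, p))

# On the valley floor x₂ = x₁² only the first term remains, so g changes slowly ...
println(g([0.9, 0.81], p))   # ≈ (1 - 0.9)² = 0.01
# ... whereas leaving the floor by 0.1 in x₂ adds b * 0.1² = 1.
println(g([0.9, 0.91], p))   # ≈ 1.01
```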
<p>Let us now implement the <a href="https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm" target="_blank" rel="noopener">Gauss–Newton algorithm</a> to find the global minimum. The Gauss–Newton algorithm iteratively finds the values of the $N$ variables ${\bf{x}}=(x_1,\dots, x_N)$ that minimize the sum of squares of $M$ residuals $(f_1,\dots, f_M)$,</p>
<p>$$
S({\bf x}) = \frac{1}{2} \sum_{i=1}^M f_i({\bf x})^2.
$$</p>
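<p>For the Rosenbrock function, one natural choice is $M = N = 2$ residuals $f_1 = a - x_1$ and $f_2 = \sqrt{b}\,(x_2 - x_1^2)$. With this choice $S({\bf x}) = \frac{1}{2} g(x_1, x_2)$, so both functions share the same minimizer. The following plain-Julia sketch checks this identity at a few points. This particular residual split and the name <code>residuals</code> are assumptions made for illustration.</p>

```julia
# Rosenbrock function with parameters p = [a, b].
g(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

# Residual vector f(x) = (f₁, f₂) with f₁ = a - x₁ and f₂ = √b (x₂ - x₁²).
residuals(x, p) = [p[1] - x[1], sqrt(p[2]) * (x[2] - x[1]^2)]

# Sum-of-squares objective S(x) = ½ Σᵢ fᵢ(x)².
S(x, p) = sum(abs2, residuals(x, p)) / 2

p = [1.0, 100.0]
for x in ([0.0, -0.1], [1.0, 1.0], [-1.2, 1.0])
    println(S(x, p), " vs ", g(x, p) / 2)   # equal up to rounding
end
```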
<p>Starting from an initial guess ${\bf x_0}$ for the minimum, the method runs through the iterations</p>
<p>$$
{\bf x}^{k+1} = {\bf x}^k - \alpha_k \left(J^T J \right)^{-1} J^T f({\bf x}^k),
$$
where $J$ is the $M \times N$ Jacobian matrix of the residuals, $J_{ij} = \partial f_i / \partial x_j$, evaluated at ${\bf{x}}^k$, and $\alpha_k$ is the step length determined via a <a href="https://de.wikipedia.org/wiki/Gau%C3%9F-Newton-Verfahren#Beispiel" target="_blank" rel="noopener">line search subroutine</a>.</p>
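<p>Below is a compact sketch of this iteration in plain Julia. It rests on several assumptions. The residuals are $f_1 = a - x_1$ and $f_2 = \sqrt{b}\,(x_2 - x_1^2)$. The Jacobian is derived by hand instead of being computed by an AD backend. The step length $\alpha_k$ comes from a simple backtracking (Armijo) line search that halves the step until $S$ decreases sufficiently. This line search may differ from the subroutine linked above, and all function names are hypothetical.</p>

```julia
using LinearAlgebra

# Residuals of the Rosenbrock function, so that ½‖f(x)‖² = g(x)/2 (p = [a, b]).
residuals(x, p) = [p[1] - x[1], sqrt(p[2]) * (x[2] - x[1]^2)]

# Hand-derived 2×2 Jacobian with entries J[i, j] = ∂fᵢ/∂xⱼ.
function jacobian(x, p)
    s = sqrt(p[2])
    return [-1.0 0.0; (-2 * s * x[1]) s]
end

S(x, p) = sum(abs2, residuals(x, p)) / 2

function gauss_newton(x, p; maxiter = 100, tol = 1e-12, c = 1e-4)
    for k in 1:maxiter
        f = residuals(x, p)
        J = jacobian(x, p)
        grad = J' * f                   # ∇S(x) = Jᵀ f
        if norm(grad) ≤ tol
            return x, k - 1             # converged after k - 1 updates
        end
        Δ = (J' * J) \ grad             # Gauss–Newton step (JᵀJ)⁻¹ Jᵀ f
        # Backtracking (Armijo) line search for the step length αₖ.
        α, Sx = 1.0, S(x, p)
        while S(x - α * Δ, p) > Sx - c * α * dot(grad, Δ)
            α /= 2
            α ≥ 1e-10 || break
        end
        x = x - α * Δ
    end
    return x, maxiter
end

p = [1.0, 100.0]
xmin, iters = gauss_newton([0.0, -0.1], p)
println("x ≈ ", xmin, " after ", iters, " Gauss–Newton steps")
```

<p>Here $J$ is square with $\det J = -\sqrt{b} \neq 0$. Hence $(J^T J)^{-1} J^T = J^{-1}$, and each Gauss–Newton step coincides with a Newton step for $f({\bf x}) = 0$. In a backend-agnostic version, <code>jacobian</code> would be replaced by a call to an AD backend.</p>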
<p>The following plot shows the Rosenbrock function in 3D as well as a 2D heatmap including the global minimum ${\bf x^\star}=(1,1)$ and our initial guess ${\bf x_0}=(0,-0.1)$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="k">using</span> <span class="n">Pkg</span>
</span></span><span class="line"><span class="cl"><span class="n">path</span> <span class="o">=</span> <span class="nd">@__DIR__</span>
</span></span><span class="line"><span class="cl"><span class="n">cd</span><span class="p">(</span><span class="n">path</span><span class="p">);</span> <span class="n">Pkg</span><span class="o">.</span><span class="n">activate</span><span class="p">(</span><span class="s">&#34;.&#34;</span><span class="p">);</span> <span class="n">Pkg</span><span class="o">.</span><span class="n">instantiate</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">## AbstractDifferentiation is not released yet!!</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">AbstractDifferentiation</span>
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">Test</span><span class="p">,</span> <span class="n">LinearAlgebra</span>
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">FiniteDifferences</span><span class="p">,</span> <span class="n">ForwardDiff</span><span class="p">,</span> <span class="n">Zygote</span>
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">Enzyme</span><span class="p">,</span> <span class="n">UnPack</span>
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">Plots</span><span class="p">,</span> <span class="n">LaTeXStrings</span>
</span></span><span class="line"><span class="cl"><span class="c"># using Diffractor: ∂⃖¹ ## Diffractor requires a more recent Julia version</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">## Rosenbrock function</span>
</span></span><span class="line"><span class="cl"><span class="c"># R: R^2 -&gt; R: x -&gt; (a-x₁)² + b(x₂-x₁²)²</span>
</span></span><span class="line"><span class="cl"><span class="n">g</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">p</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">-</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span><span class="o">^</span><span class="mi">2</span> <span class="o">+</span> <span class="n">p</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">*</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">-</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">^</span><span class="mi">2</span><span class="p">)</span><span class="o">^</span><span class="mi">2</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># visualization</span>
</span></span><span class="line"><span class="cl"><span class="n">p</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span><span class="mf">100.0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x₀</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span><span class="o">-</span><span class="mf">0.1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">xopt</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span><span class="mf">1.0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">do_plot</span> <span class="o">=</span> <span class="nb">true</span>
</span></span><span class="line"><span class="cl"><span class="k">if</span> <span class="n">do_plot</span>
</span></span><span class="line"><span class="cl"> <span class="n">x₁</span><span class="p">,</span> <span class="n">x₂</span> <span class="o">=</span> <span class="o">-</span><span class="mf">2.0</span><span class="o">:</span><span class="mf">0.01</span><span class="o">:</span><span class="mf">2.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6</span><span class="o">:</span><span class="mf">0.01</span><span class="o">:</span><span class="mf">3.5</span>
</span></span><span class="line"><span class="cl"> <span class="n">z</span> <span class="o">=</span> <span class="n">Surface</span><span class="p">((</span><span class="n">x₁</span><span class="p">,</span><span class="n">x₂</span><span class="p">)</span><span class="o">-&gt;</span><span class="n">g</span><span class="p">([</span><span class="n">x₁</span><span class="p">,</span><span class="n">x₂</span><span class="p">],</span><span class="n">p</span><span class="p">),</span> <span class="n">x₁</span><span class="p">,</span> <span class="n">x₂</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">pl1</span> <span class="o">=</span> <span class="n">surface</span><span class="p">(</span><span class="n">x₁</span><span class="p">,</span><span class="n">x₂</span><span class="p">,</span><span class="n">z</span><span class="p">,</span> <span class="n">linealpha</span> <span class="o">=</span> <span class="mf">0.3</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">cgrad</span><span class="p">(</span><span class="ss">:thermal</span><span class="p">,</span> <span class="n">scale</span> <span class="o">=</span> <span class="ss">:exp</span><span class="p">),</span> <span class="n">colorbar</span><span class="o">=</span><span class="nb">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span><span class="n">camera</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">50</span><span class="p">),</span>
</span></span><span class="line"><span class="cl"> <span class="n">xlabel</span> <span class="o">=</span> <span class="sa">L</span><span class="s">&#34;x_1&#34;</span><span class="p">,</span> <span class="n">ylabel</span> <span class="o">=</span> <span class="sa">L</span><span class="s">&#34;x_2&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="n">pl2</span> <span class="o">=</span> <span class="n">heatmap</span><span class="p">(</span><span class="n">x₁</span><span class="p">,</span><span class="n">x₂</span><span class="p">,</span><span class="n">z</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">cgrad</span><span class="p">(</span><span class="ss">:thermal</span><span class="p">,</span> <span class="n">scale</span> <span class="o">=</span> <span class="ss">:exp</span><span class="p">),</span>
</span></span><span class="line"><span class="cl"> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"> <span class="n">xlabel</span> <span class="o">=</span> <span class="sa">L</span><span class="s">&#34;x_1&#34;</span><span class="p">,</span> <span class="n">ylabel</span> <span class="o">=</span> <span class="sa">L</span><span class="s">&#34;x_2&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">scatter!</span><span class="p">(</span><span class="n">pl2</span><span class="p">,</span> <span class="p">[(</span><span class="n">x₀</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="n">x₀</span><span class="p">[</span><span class="mi">2</span><span class="p">])],</span> <span class="n">label</span><span class="o">=</span><span class="sa">L</span><span class="s">&#34;x_0&#34;</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">markershape</span> <span class="o">=</span> <span class="ss">:circle</span><span class="p">,</span> <span class="n">markersize</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">markercolor</span> <span class="o">=</span> <span class="ss">:green</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">scatter!</span><span class="p">(</span><span class="n">pl2</span><span class="p">,</span> <span class="p">[(</span><span class="n">xopt</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="n">xopt</span><span class="p">[</span><span class="mi">2</span><span class="p">])],</span><span class="n">label</span><span class="o">=</span><span class="sa">L</span><span class="s">&#34;x^\star&#34;</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">markershape</span> <span class="o">=</span> <span class="ss">:star</span><span class="p">,</span> <span class="n">markersize</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">markercolor</span> <span class="o">=</span> <span class="ss">:red</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="n">pl</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">pl1</span><span class="p">,</span><span class="n">pl2</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"> <span class="n">savefig</span><span class="p">(</span><span class="n">pl</span><span class="p">,</span> <span class="s">&#34;Rosenbrock.png&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span></code></pre></div>
<figure >
<div class="d-flex justify-content-center">
<div class="w-100" ><img src="https://frankschae.github.io/img/Rosenbrock.png" alt="" loading="lazy" data-zoomable /></div>
</div></figure>
<p>To apply the Gauss-Newton algorithm to the Rosenbrock function $g$, we first bring $g$ into the sum-of-squares form $S({\bf x})$, i.e., we work with the residual function</p>
<p>$$
f:\mathbb{R}^2\rightarrow\mathbb{R}^2: {\bf x} \mapsto \begin{pmatrix}
f_1({\bf x}) \\
f_2({\bf x}) \\
\end{pmatrix} = \begin{pmatrix}
\sqrt{2}(a-x_1) \\
\sqrt{2b}(x_2-x_1^2)\\
\end{pmatrix},
$$</p>
<p>instead of $g$. The Jacobian of $f$ is easily computed by hand:</p>
<p>$$
J = \begin{pmatrix}
-\sqrt{2} &amp; 0 \\
-2x_1\sqrt{2b} &amp; \sqrt{2b} \\
\end{pmatrix}.
$$</p>
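<p>As a quick numerical check, the sketch below verifies that $\frac{1}{2}\|f({\bf x})\|^2$ reproduces $g({\bf x})$ (assuming, consistent with the $\sqrt{2}$ factors, that $S$ carries a prefactor $\frac{1}{2}$) and that the Gauss-Newton direction ${\bf d} = -(J^\top J)^{-1}J^\top f$ reduces here to the Newton step $-J^{-1}f$, because $J$ is square with $\det J = -2\sqrt{b} \neq 0$. Solving $J{\bf d} = -f$ by hand gives ${\bf d} = (a - x_1,\ 2ax_1 - x_1^2 - x_2)^\top$. The helper names and the parameter values $a = 1$, $b = 100$ are illustrative assumptions.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using LinearAlgebra

# Assumed Rosenbrock parameters (illustrative, not taken from the text above)
a, b = 1.0, 100.0

rosenbrock(x) = (a - x[1])^2 + b * (x[2] - x[1]^2)^2

# Residuals f with g(x) = ½‖f(x)‖²
residuals(x) = [sqrt(2) * (a - x[1]), sqrt(2b) * (x[2] - x[1]^2)]

# Hand-derived Jacobian of the residuals
function jacobian_manual(x)
    j21 = -2 * x[1] * sqrt(2b)
    return [-sqrt(2) 0.0; j21 sqrt(2b)]
end

x = [0.3, -0.7]                    # arbitrary test point
r = residuals(x)
J = jacobian_manual(x)

d_gn     = -((J' * J) \ (J' * r))  # Gauss-Newton direction via the normal equations
d_newton = -(J \ r)                # Newton step for the root-finding problem f(x) = 0
d_closed = [a - x[1], 2a * x[1] - x[1]^2 - x[2]]  # closed form from J d = -f

println(0.5 * dot(r, r) ≈ rosenbrock(x))  # true
println(d_gn ≈ d_newton)                  # true
println(d_newton ≈ d_closed)              # true
</code></pre></div>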
<p>We can then implement a simple, non-optimized version of the Gauss-Newton algorithm, in which a backtracking line search halves the step size $\alpha$ until $g$ decreases, as follows.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># bring Rosenbrock function into the form &#34;sum of squares of functions&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">f1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">=</span> <span class="n">convert</span><span class="p">(</span><span class="n">eltype</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">))</span><span class="o">*</span><span class="p">(</span><span class="n">p</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">-</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"><span class="n">f2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">=</span> <span class="n">convert</span><span class="p">(</span><span class="n">eltype</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">p</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span><span class="o">*</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">-</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">^</span><span class="mi">2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">=</span> <span class="p">[</span><span class="n">f1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">),</span><span class="n">f2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)]</span>
</span></span><span class="line"><span class="cl"><span class="k">function</span> <span class="n">f</span><span class="p">(</span><span class="n">res</span><span class="p">,</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="c"># Enzyme works with inplace functions</span>
</span></span><span class="line"><span class="cl"> <span class="n">res</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">f1</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">res</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">f2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="nb">nothing</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">## manually pre-defined Jacobian</span>
</span></span><span class="line"><span class="cl"><span class="k">function</span> <span class="n">Jacobian</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="p">[</span><span class="o">-</span><span class="n">convert</span><span class="p">(</span><span class="n">eltype</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="p">))</span> <span class="mi">0</span>
</span></span><span class="line"><span class="cl"> <span class="o">-</span><span class="mi">2</span><span class="o">*</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">*</span><span class="n">convert</span><span class="p">(</span><span class="n">eltype</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">p</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span> <span class="n">convert</span><span class="p">(</span><span class="n">eltype</span><span class="p">(</span><span class="n">x</span><span class="p">),</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">p</span><span class="p">[</span><span class="mi">2</span><span class="p">]))]</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c">## Gauss-Newton scheme</span>
</span></span><span class="line"><span class="cl"><span class="k">function</span> <span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">;</span> <span class="n">maxiter</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="nb">nothing</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">for</span> <span class="n">i</span><span class="o">=</span><span class="mi">1</span><span class="o">:</span><span class="n">maxiter</span>
</span></span><span class="line"><span class="cl"> <span class="n">x</span> <span class="o">=</span> <span class="n">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@info</span> <span class="n">i</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@show</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl"> <span class="n">push!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">xs</span><span class="p">,</span> <span class="n">x</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="n">done</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">x2</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">=</span> <span class="n">g</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span><span class="n">p</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">g</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="k">function</span> <span class="n">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="o">::</span><span class="kt">Nothing</span><span class="p">,</span> <span class="n">α</span><span class="o">=</span><span class="mi">1</span><span class="o">//</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x2</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
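</span></span><span class="line"><span class="cl"> <span class="k">while</span> <span class="o">!</span><span class="n">done</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">x2</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">α</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="o">//</span><span class="mi">2</span><span class="o">^</span><span class="mi">30</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">x</span> <span class="c"># no decrease after ~30 halvings (e.g., x is already the minimum)</span>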
</span></span><span class="line"><span class="cl"> <span class="n">J</span> <span class="o">=</span> <span class="n">Jacobian</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">d</span> <span class="o">=</span> <span class="o">-</span><span class="n">inv</span><span class="p">(</span><span class="n">J</span><span class="o">&#39;*</span><span class="n">J</span><span class="p">)</span><span class="o">*</span><span class="n">J</span><span class="o">&#39;*</span><span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">copyto!</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span><span class="n">x</span> <span class="o">+</span> <span class="n">α</span><span class="o">*</span><span class="n">d</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">α</span> <span class="o">=</span> <span class="n">α</span><span class="o">//</span><span class="mi">2</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">x2</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span></code></pre></div><p>When we run the algorithm, the iterates reach the global minimum $(1, 1)$ up to machine precision in the 7th iteration and exactly in the 8th.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="n">xs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x₀</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># output:</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">1</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.125</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.08750000000000001</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">2</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.234375</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.047265625000000006</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">3</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.4257812499999995</span><span class="p">,</span> <span class="mf">0.06800537109374968</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">4</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.5693359374999986</span><span class="p">,</span> <span class="mf">0.21857223510742047</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">5</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.784667968749996</span><span class="p">,</span> <span class="mf">0.5165503501892037</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">6</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.9999999999999961</span><span class="p">,</span> <span class="mf">0.9536321163177449</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">7</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.9999999999999989</span><span class="p">,</span> <span class="mf">0.9999999999999999</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">[</span> <span class="n">Info</span><span class="o">:</span> <span class="mi">8</span> <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]</span>
</span></span></code></pre></div><p>If computing the Jacobian by hand is too cumbersome (or not possible for other reasons), we can approximate it with finite differences instead. Within the AbstractDifferentiation API, we can, for instance, declare the Jacobian of <a href="https://github.com/JuliaDiff/FiniteDifferences.jl" target="_blank" rel="noopener">FiniteDifferences.jl</a> as the primitive operation of a new backend.</p>
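<p>To see what such a finite-difference primitive approximates, here is a package-free sketch: column $j$ of $J$ is estimated by the central difference $\big(f({\bf x} + h{\bf e}_j) - f({\bf x} - h{\bf e}_j)\big)/(2h)$, whose truncation error is $O(h^2)$. The helper <code>central_jacobian</code>, the fixed step $h = 10^{-6}$, and the values $a = 1$, $b = 100$ are illustrative assumptions; FiniteDifferences.jl's <code>central_fdm(5, 1)</code> used below instead employs a five-point stencil and by default estimates a suitable step size itself.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Hypothetical helper: central-difference approximation of the Jacobian of F at x
function central_jacobian(F, x::AbstractVector; h = 1e-6)
    Fx = F(x)
    J = zeros(length(Fx), length(x))
    for j in 1:length(x)
        e = zeros(length(x))
        e[j] = h                                # perturb only coordinate j
        J[:, j] = (F(x + e) - F(x - e)) / (2h)  # column j approximates ∂F/∂x_j
    end
    return J
end

# Usage with the Rosenbrock residuals (assumed a = 1, b = 100)
a, b = 1.0, 100.0
F(x) = [sqrt(2) * (a - x[1]), sqrt(2b) * (x[2] - x[1]^2)]
J_exact(x) = [-sqrt(2) 0.0; (-2 * x[1] * sqrt(2b)) sqrt(2b)]

x0 = [0.5, 0.2]
J_fd = central_jacobian(F, x0)
println(maximum(abs.(J_fd - J_exact(x0))))  # small finite-difference error
</code></pre></div>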
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c">## FiniteDifferences</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="kt">FDMBackend</span><span class="p">{</span><span class="kt">A</span><span class="p">}</span> <span class="o">&lt;:</span> <span class="kt">AD</span><span class="o">.</span><span class="n">AbstractFiniteDifference</span>
</span></span><span class="line"><span class="cl"> <span class="n">alg</span><span class="o">::</span><span class="kt">A</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="n">FDMBackend</span><span class="p">()</span> <span class="o">=</span> <span class="n">FDMBackend</span><span class="p">(</span><span class="n">central_fdm</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">fdm_backend</span> <span class="o">=</span> <span class="n">FDMBackend</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="c"># Minimal interface</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="nd">@primitive</span> <span class="k">function</span> <span class="n">jacobian</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">FDMBackend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">FiniteDifferences</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">ab</span><span class="o">.</span><span class="n">alg</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># AD.jacobian returns a tuple with one Jacobian per input argument</span>
</span></span><span class="line"><span class="cl"><span class="c"># df_dx = AD.jacobian(fdm_backend, f, x₀, p)[1]</span>
</span></span><span class="line"><span class="cl"><span class="c"># df_dp = AD.jacobian(fdm_backend, f, x₀, p)[2]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">x</span><span class="o">-&gt;</span><span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">),</span> <span class="n">x₀</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="o">≈</span> <span class="n">Jacobian</span><span class="p">(</span><span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="o">≈</span> <span class="n">Jacobian</span><span class="p">(</span><span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span></code></pre></div><p>After adding a more generic <code>step</code> method that obtains the Jacobian from an AbstractDifferentiation backend, we can run the Gauss-Newton algorithm with finite differences as follows:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="k">function</span> <span class="n">step</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="p">,</span> <span class="n">α</span><span class="o">=</span><span class="mi">1</span><span class="o">//</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x2</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
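</span></span><span class="line"><span class="cl"> <span class="k">while</span> <span class="o">!</span><span class="n">done</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">x2</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">α</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="o">//</span><span class="mi">2</span><span class="o">^</span><span class="mi">30</span> <span class="o">&amp;&amp;</span> <span class="k">return</span> <span class="n">x</span> <span class="c"># no decrease after ~30 halvings (e.g., x is already the minimum)</span>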
</span></span><span class="line"><span class="cl"> <span class="n">J</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"> <span class="n">d</span> <span class="o">=</span> <span class="o">-</span><span class="n">inv</span><span class="p">(</span><span class="n">J</span><span class="o">&#39;*</span><span class="n">J</span><span class="p">)</span><span class="o">*</span><span class="n">J</span><span class="o">&#39;*</span><span class="n">f</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">copyto!</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span><span class="n">x</span> <span class="o">+</span> <span class="n">α</span><span class="o">*</span><span class="n">d</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">α</span> <span class="o">=</span> <span class="n">α</span><span class="o">//</span><span class="mi">2</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">x2</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">xs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x₀</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="n">fdm_backend</span><span class="p">)</span>
</span></span></code></pre></div><p>If we want to use reverse-mode AD instead, for example via <a href="https://github.com/FluxML/Zygote.jl" target="_blank" rel="noopener">Zygote.jl</a>, a natural choice for the primitive is to define the pullback function. AbstractDifferentiation then generates the associated code to compute the Jacobian for us.</p>
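<p>Conceptually, a pullback maps an output cotangent ${\bf v}$ to the vector-Jacobian product $J^\top{\bf v}$, so seeding it with the unit vectors ${\bf e}_1,\dots,{\bf e}_M$ returns the rows of $J$. The sketch below illustrates this with a hand-written pullback for the Rosenbrock residuals instead of Zygote; <code>residuals_pullback</code>, <code>jacobian_from_pullback</code>, and the values $a = 1$, $b = 100$ are hypothetical, and AbstractDifferentiation's actual Jacobian construction may differ in its details.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">a, b = 1.0, 100.0  # assumed Rosenbrock parameters

residuals(x) = [sqrt(2) * (a - x[1]), sqrt(2b) * (x[2] - x[1]^2)]

# Hand-written pullback: returns y = f(x) and a closure computing Jᵀv (a VJP)
function residuals_pullback(x)
    y = residuals(x)
    back(v) = [-sqrt(2) * v[1] - 2 * x[1] * sqrt(2b) * v[2],
               sqrt(2b) * v[2]]
    return y, back
end

# Assemble J row by row: row i is the pullback of the i-th unit vector e_i
function jacobian_from_pullback(pullback, x)
    y, back = pullback(x)
    M = length(y)
    rows = [back([k == i ? 1.0 : 0.0 for k in 1:M])' for i in 1:M]
    return reduce(vcat, rows)
end

x0 = [0.5, 0.2]
J = jacobian_from_pullback(residuals_pullback, x0)
println(J)
</code></pre></div>
<p>For our $f:\mathbb{R}^2\rightarrow\mathbb{R}^2$, this takes $M = 2$ reverse passes.</p>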
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c">## Zygote</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="kt">ZygoteBackend</span> <span class="o">&lt;:</span> <span class="kt">AD</span><span class="o">.</span><span class="n">AbstractReverseMode</span> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">zygote_backend</span> <span class="o">=</span> <span class="n">ZygoteBackend</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="nd">@primitive</span> <span class="k">function</span> <span class="n">pullback_function</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">ZygoteBackend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="k">function</span> <span class="p">(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="c"># Supports only single output</span>
</span></span><span class="line"><span class="cl"> <span class="n">_</span><span class="p">,</span> <span class="n">back</span> <span class="o">=</span> <span class="n">Zygote</span><span class="o">.</span><span class="n">pullback</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">if</span> <span class="n">vs</span> <span class="k">isa</span> <span class="kt">AbstractVector</span>
</span></span><span class="line"><span class="cl"> <span class="n">back</span><span class="p">(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">else</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@assert</span> <span class="n">length</span><span class="p">(</span><span class="n">vs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl"> <span class="n">back</span><span class="p">(</span><span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="c">##</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">all</span><span class="p">(</span><span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="o">.≈</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">zygote_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">xs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x₀</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="n">zygote_backend</span><span class="p">)</span>
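</span></span></code></pre></div><p>Reverse-mode AD typically pays off only for functions $f:\mathbb{R}^N\rightarrow\mathbb{R}^M$ with $M \ll N$, because assembling the full Jacobian takes one reverse pass per output, whereas forward-mode AD takes one pass per input. Since our $f$ maps $\mathbb{R}^2$ to $\mathbb{R}^2$, it is worth comparing the performance against forward-mode AD (<a href="https://github.com/JuliaDiff/ForwardDiff.jl" target="_blank" rel="noopener">ForwardDiff.jl</a>), for which a natural primitive is the pushforward function.</p>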
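<p>Forward mode carries a tangent alongside every value: evaluating $f$ on dual numbers seeded with ${\bf e}_j$ yields the Jacobian-vector product $J{\bf e}_j$, i.e., the $j$-th column of $J$. The minimal dual-number type below is a hypothetical sketch (<code>Dual</code> and <code>jacobian_forward</code> are not ForwardDiff.jl's API), again assuming $a = 1$, $b = 100$; ForwardDiff.jl builds on the same idea but can propagate several tangent directions per pass (chunk mode).</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Minimal dual number: a value plus one tangent (directional derivative)
struct Dual
    val::Float64
    der::Float64
end

# Arithmetic rules (sum, difference, product and power rule) that propagate tangents
Base.:+(x::Dual, y::Dual) = Dual(x.val + y.val, x.der + y.der)
Base.:-(x::Dual, y::Dual) = Dual(x.val - y.val, x.der - y.der)
Base.:-(c::Real, x::Dual) = Dual(c - x.val, -x.der)
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.der * y.val + x.val * y.der)
Base.:*(c::Real, x::Dual) = Dual(c * x.val, c * x.der)
Base.:^(x::Dual, n::Integer) = Dual(x.val^n, n * x.val^(n - 1) * x.der)

a, b = 1.0, 100.0  # assumed Rosenbrock parameters
residuals(x) = [sqrt(2) * (a - x[1]), sqrt(2b) * (x[2] - x[1]^2)]

# Column j of J: seed the tangent of input j with 1 and all other tangents with 0
function jacobian_forward(F, x)
    N = length(x)
    cols = map(1:N) do j
        xd = [Dual(x[k], k == j ? 1.0 : 0.0) for k in 1:N]
        [y.der for y in F(xd)]
    end
    return reduce(hcat, cols)
end

x0 = [0.5, 0.2]
J = jacobian_forward(residuals, x0)
println(J)
</code></pre></div>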
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c">## ForwardDiff</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="kt">ForwardDiffBackend</span> <span class="o">&lt;:</span> <span class="kt">AD</span><span class="o">.</span><span class="n">AbstractForwardMode</span> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">forwarddiff_backend</span> <span class="o">=</span> <span class="n">ForwardDiffBackend</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="nd">@primitive</span> <span class="k">function</span> <span class="n">pushforward_function</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">ForwardDiffBackend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="c"># jvp = f&#39;(x)*v, i.e., differentiate f(x + h*v) wrt h at 0</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="k">function</span> <span class="p">(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">if</span> <span class="n">xs</span> <span class="k">isa</span> <span class="kt">Tuple</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@assert</span> <span class="n">length</span><span class="p">(</span><span class="n">xs</span><span class="p">)</span> <span class="o">&lt;=</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl"> <span class="k">if</span> <span class="n">length</span><span class="p">(</span><span class="n">xs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span>
</span></span><span class="line"><span class="cl"> <span class="p">(</span><span class="n">ForwardDiff</span><span class="o">.</span><span class="n">derivative</span><span class="p">(</span><span class="n">h</span><span class="o">-&gt;</span><span class="n">f</span><span class="p">(</span><span class="n">xs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">+</span><span class="n">h</span><span class="o">*</span><span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span><span class="mi">0</span><span class="p">),)</span>
</span></span><span class="line"><span class="cl"> <span class="k">else</span>
</span></span><span class="line"><span class="cl"> <span class="n">ForwardDiff</span><span class="o">.</span><span class="n">derivative</span><span class="p">(</span><span class="n">h</span><span class="o">-&gt;</span><span class="n">f</span><span class="p">(</span><span class="n">xs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">+</span><span class="n">h</span><span class="o">*</span><span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">xs</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">+</span><span class="n">h</span><span class="o">*</span><span class="n">vs</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">else</span>
</span></span><span class="line"><span class="cl"> <span class="n">ForwardDiff</span><span class="o">.</span><span class="n">derivative</span><span class="p">(</span><span class="n">h</span><span class="o">-&gt;</span><span class="n">f</span><span class="p">(</span><span class="n">xs</span><span class="o">+</span><span class="n">h</span><span class="o">*</span><span class="n">vs</span><span class="p">),</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="c">##</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">minimum</span><span class="p">(</span><span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="o">.≈</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">xs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x₀</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="n">forwarddiff_backend</span><span class="p">)</span>
</span></span></code></pre></div><p>where we have used that the Jacobian-vector product $f'(x)v$, the primitive operation of forward-mode AD, can be computed by <a href="https://discourse.julialang.org/t/help-with-jacobian-vector-product-to-get-natural-gradient/51115/12" target="_blank" rel="noopener">differentiating $f(x + hv)$ with respect to $h$ at $h = 0$</a>.</p>
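<p>As a minimal, self-contained sketch of both the derivative trick and the $M \ll N$ rule of thumb, the snippet below calls ForwardDiff and Zygote directly rather than going through AbstractDifferentiation. The test function <code>toyf</code> ($\mathbb{R}^3\rightarrow\mathbb{R}^2$) and the helpers <code>jvp_fd</code> and <code>basisvec</code> are hypothetical and serve only as an illustration. The sketch first checks that the derivative trick reproduces $Jv$ and then assembles the full Jacobian twice: forward mode needs one JVP per input direction ($N = 3$ passes), whereas reverse mode needs one VJP per output direction ($M = 2$ passes), which is why reverse mode tends to win when $M \ll N$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using ForwardDiff, Zygote

# Hypothetical test function toyf: R^3 -> R^2 (not the f from this post)
toyf(x) = [x[1]^2 + sin(x[2]), x[2] * x[3]]

# JVP via the derivative trick: d/dh toyf(x + h*v) evaluated at h = 0
function jvp_fd(x, v)
    return ForwardDiff.derivative(0.0) do h
        toyf(x .+ h .* v)
    end
end

# Hypothetical helper: i-th standard basis vector of length n
basisvec(i, n) = Float64.(1:n .== i)

x = [1.0, 2.0, 3.0]
v = [0.5, -1.0, 2.0]

# Reference Jacobian (2×3) from ForwardDiff
J = ForwardDiff.jacobian(toyf, x)
println(jvp_fd(x, v) ≈ J * v)

# Forward mode: one JVP per input direction yields the N = 3 columns of J
J_fwd = reduce(hcat, [jvp_fd(x, basisvec(i, 3)) for i in 1:3])

# Reverse mode: one VJP per output direction yields the M = 2 rows of J
_, back = Zygote.pullback(toyf, x)
J_rev = reduce(vcat, [back(basisvec(i, 2))[1]' for i in 1:2])

println(J_fwd ≈ J)
println(J_rev ≈ J)
</code></pre></div>
<p>For a scalar loss ($M = 1$), the reverse-mode loop above collapses to a single pullback, independent of the number of inputs $N$.</p>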
<p>Many AD packages, such as Zygote, struggle with mutating functions. <a href="https://github.com/wsmoses/Enzyme.jl" target="_blank" rel="noopener">Enzyme.jl</a> is one of the exceptions. It is also very fast and has further improved the performance of the <a href="https://github.com/SciML/DiffEqSensitivity.jl/pull/427#issuecomment-866509944" target="_blank" rel="noopener">adjoints implemented within the DiffEqSensitivity package</a>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c">## Enzyme</span>
</span></span><span class="line"><span class="cl"><span class="k">struct</span> <span class="kt">EnzymeBackend</span><span class="p">{</span><span class="kt">T1</span><span class="p">,</span><span class="kt">T2</span><span class="p">,</span><span class="kt">T3</span><span class="p">,</span><span class="kt">T4</span><span class="p">}</span> <span class="o">&lt;:</span> <span class="kt">AD</span><span class="o">.</span><span class="n">AbstractReverseMode</span>
</span></span><span class="line"><span class="cl"> <span class="n">out</span><span class="o">::</span><span class="kt">T1</span>
</span></span><span class="line"><span class="cl"> <span class="n">λ</span><span class="o">::</span><span class="kt">T2</span>
</span></span><span class="line"><span class="cl"> <span class="n">∂f_∂x</span><span class="o">::</span><span class="kt">T3</span>
</span></span><span class="line"><span class="cl"> <span class="n">∂f_∂p</span><span class="o">::</span><span class="kt">T4</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">out</span> <span class="o">=</span> <span class="n">zero</span><span class="p">(</span><span class="n">x₀</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">λ</span> <span class="o">=</span> <span class="n">zero</span><span class="p">(</span><span class="n">x₀</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">∂f_∂x</span> <span class="o">=</span> <span class="n">zero</span><span class="p">(</span><span class="n">x₀</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">∂f_∂p</span> <span class="o">=</span> <span class="n">zero</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">enzyme_backend</span> <span class="o">=</span> <span class="n">EnzymeBackend</span><span class="p">(</span><span class="n">out</span><span class="p">,</span><span class="n">λ</span><span class="p">,</span><span class="n">∂f_∂x</span><span class="p">,</span><span class="n">∂f_∂p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="nd">@primitive</span> <span class="k">function</span> <span class="n">pullback_function</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">EnzymeBackend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="k">function</span> <span class="p">(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="c"># enzyme works only with inplace functions</span>
</span></span><span class="line"><span class="cl"> <span class="k">if</span> <span class="o">!</span><span class="p">(</span><span class="n">vs</span> <span class="k">isa</span> <span class="kt">AbstractVector</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@assert</span> <span class="n">length</span><span class="p">(</span><span class="n">vs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="c"># Supports only single output</span>
</span></span><span class="line"><span class="cl"> <span class="n">vs</span> <span class="o">=</span> <span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="k">if</span> <span class="n">xs</span> <span class="k">isa</span> <span class="kt">Tuple</span>
</span></span><span class="line"><span class="cl"> <span class="nd">@assert</span> <span class="n">length</span><span class="p">(</span><span class="n">xs</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> <span class="c"># hard-coded for use case with two inputs</span>
</span></span><span class="line"><span class="cl"> <span class="n">x₀</span> <span class="o">=</span> <span class="n">xs</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"> <span class="n">p</span> <span class="o">=</span> <span class="n">xs</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="nd">@unpack</span> <span class="n">out</span><span class="p">,</span> <span class="n">λ</span><span class="p">,</span> <span class="n">∂f_∂x</span><span class="p">,</span> <span class="n">∂f_∂p</span> <span class="o">=</span> <span class="n">ab</span> <span class="c"># cached in the struct, could also be created in here</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="n">∂f_∂x</span> <span class="o">.*=</span> <span class="nb">false</span>
</span></span><span class="line"><span class="cl"> <span class="n">∂f_∂p</span> <span class="o">.*=</span> <span class="nb">false</span>
</span></span><span class="line"><span class="cl"> <span class="n">out</span> <span class="o">.*=</span> <span class="nb">false</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="n">copyto!</span><span class="p">(</span><span class="n">λ</span><span class="p">,</span> <span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"> <span class="n">autodiff</span><span class="p">(</span><span class="n">Duplicated</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">λ</span><span class="p">),</span> <span class="n">Duplicated</span><span class="p">(</span><span class="n">x₀</span><span class="p">,</span> <span class="n">∂f_∂x</span><span class="p">),</span> <span class="n">Duplicated</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">∂f_∂p</span><span class="p">))</span> <span class="k">do</span> <span class="n">_out</span><span class="p">,</span><span class="n">_x</span><span class="p">,</span> <span class="n">_p</span>
</span></span><span class="line"><span class="cl"> <span class="n">f</span><span class="p">(</span><span class="n">_out</span><span class="p">,</span><span class="n">_x</span><span class="p">,</span><span class="n">_p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="p">(</span><span class="n">∂f_∂x</span><span class="p">,</span><span class="n">∂f_∂p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="n">isinplace</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">EnzymeBackend</span><span class="p">)</span> <span class="o">=</span> <span class="nb">true</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="n">primalvalue</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">EnzymeBackend</span><span class="p">,</span> <span class="nb">nothing</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">f</span><span class="p">(</span><span class="n">ab</span><span class="o">.</span><span class="n">out</span><span class="p">,</span><span class="n">xs</span><span class="o">...</span><span class="p">);</span><span class="k">return</span> <span class="n">ab</span><span class="o">.</span><span class="n">out</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c">##</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">minimum</span><span class="p">(</span><span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="o">.≈</span> <span class="n">AD</span><span class="o">.</span><span class="n">jacobian</span><span class="p">(</span><span class="n">enzyme_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">xs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x₀</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">GaussNewton!</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">backend</span><span class="o">=</span><span class="n">enzyme_backend</span><span class="p">)</span>
</span></span></code></pre></div><p>Note that we have declared the Enzyme backend as in-place via <code>AD.isinplace</code> (which is important for the internal control flow) and defined a <code>primalvalue</code> method that returns the primal value of the forward pass.</p>
<h2 id="some-current-glitches">Some current glitches</h2>
<p>First, the pushforward of a tuple of vectors, e.g., $(v_1, v_2)$, for a function with several input arguments is currently ambiguous. While <code>AD.jacobian</code> and <code>AD.pullback_function</code> primitives interpret the pushforward of our $f$ function as</p>
<p>$$
\left(\frac{\partial f(x_0,p)}{\partial x} v_1 , \frac{\partial f(x_0,p)}{\partial p} v_2 \right),
$$</p>
<p><code>AD.pushforward_function</code> primitives compute</p>
<p>$$
\frac{\partial f(x_0,p)}{\partial x} v_1 + \frac{\partial f(x_0,p)}{\partial p} v_2.
$$</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># pushforward_function with respect to multiple vectors is currently ambiguous</span>
</span></span><span class="line"><span class="cl"><span class="n">vs</span> <span class="o">=</span> <span class="p">(</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">res1</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">pushforward_function</span><span class="p">(</span><span class="n">fdm_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">res2</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">pushforward_function</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">res2</span> <span class="o">≈</span> <span class="n">res1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">res1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</span></span></code></pre></div><p>For now, we resolve this ambiguity for <code>AD.pushforward_function</code> primitives by padding the input with zero vectors, so that each call pushes forward along a single argument only:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="n">res2a</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">pushforward_function</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)((</span><span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">zero</span><span class="p">(</span><span class="n">vs</span><span class="p">[</span><span class="mi">2</span><span class="p">])))</span>
</span></span><span class="line"><span class="cl"><span class="n">res2b</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">pushforward_function</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)((</span><span class="n">zero</span><span class="p">(</span><span class="n">vs</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">vs</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">res2a</span> <span class="o">≈</span> <span class="n">res1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="nd">@test</span> <span class="n">res2b</span> <span class="o">≈</span> <span class="n">res1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</span></span></code></pre></div><p>The plural &ldquo;primitives&rdquo; is used here because different backends may define different <code>pushforward_function</code> primitives. For instance, we can define an additional <code>pushforward_function</code> primitive for FiniteDifferences as follows:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="k">struct</span> <span class="kt">FDMBackend2</span><span class="p">{</span><span class="kt">A</span><span class="p">}</span> <span class="o">&lt;:</span> <span class="kt">AD</span><span class="o">.</span><span class="n">AbstractFiniteDifference</span>
</span></span><span class="line"><span class="cl"> <span class="n">alg</span><span class="o">::</span><span class="kt">A</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="n">FDMBackend2</span><span class="p">()</span> <span class="o">=</span> <span class="n">FDMBackend2</span><span class="p">(</span><span class="n">central_fdm</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="k">const</span> <span class="n">fdm_backend2</span> <span class="o">=</span> <span class="n">FDMBackend2</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="n">AD</span><span class="o">.</span><span class="nd">@primitive</span> <span class="k">function</span> <span class="n">pushforward_function</span><span class="p">(</span><span class="n">ab</span><span class="o">::</span><span class="kt">FDMBackend2</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">xs</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="k">function</span> <span class="p">(</span><span class="n">vs</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">FDM</span><span class="o">.</span><span class="n">jvp</span><span class="p">(</span><span class="n">ab</span><span class="o">.</span><span class="n">alg</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">tuple</span><span class="o">.</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">vs</span><span class="p">)</span><span class="o">...</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="k">end</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span></code></pre></div><p>Second, to avoid ambiguity in the output of the Hessian of a function with several input arguments, <code>AD.hessian</code> accepts only a single input argument.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># Hessian only defined with respect to single input variable</span>
</span></span><span class="line"><span class="cl"><span class="nd">@test_throws</span> <span class="kt">AssertionError</span> <span class="n">AD</span><span class="o">.</span><span class="n">hessian</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">g</span><span class="p">,</span> <span class="n">x₀</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">H1</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">hessian</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">x</span><span class="o">-&gt;</span><span class="n">g</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">),</span> <span class="n">x₀</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">H2</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">hessian</span><span class="p">(</span><span class="n">forwarddiff_backend</span><span class="p">,</span> <span class="n">p</span><span class="o">-&gt;</span><span class="n">g</span><span class="p">(</span><span class="n">x₀</span><span class="p">,</span><span class="n">p</span><span class="p">),</span> <span class="n">p</span><span class="p">)</span>
</span></span></code></pre></div><p>Third, computing the Hessian requires nesting AD/backend calls, which can fail for some combinations, e.g., Zygote over Zygote. To solve this problem, we have implemented a <code>HigherOrderBackend</code> that takes a tuple of backends, so that a working combination can be chosen for each level (using ForwardDiff over Zygote, for example, is perfectly fine).</p>
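<p>Conceptually, a ForwardDiff-over-Zygote Hessian is the forward-mode Jacobian of a reverse-mode gradient. The following minimal sketch writes this out with the two packages directly, independent of AbstractDifferentiation. The scalar test function <code>toyh</code> and the wrapper <code>grad_toyh</code> are hypothetical. The result is checked against <code>ForwardDiff.hessian</code> and the analytic Hessian.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using ForwardDiff, Zygote

# Hypothetical scalar test function toyh: R^3 -> R (not the g from this post)
toyh(x) = x[1]^2 * x[2] + sin(x[3])

# Inner level: reverse-mode gradient computed by Zygote
grad_toyh(y) = Zygote.gradient(toyh, y)[1]

x = [1.0, 2.0, 3.0]

# Outer level: forward-mode Jacobian of the gradient (ForwardDiff over Zygote)
H_fwd_rev = ForwardDiff.jacobian(grad_toyh, x)

# References: pure forward-mode Hessian and the analytic Hessian
H_fwd = ForwardDiff.hessian(toyh, x)
H_exact = [2x[2] 2x[1] 0.0;
           2x[1] 0.0   0.0;
           0.0   0.0   -sin(x[3])]

println(H_fwd_rev ≈ H_fwd)
println(H_fwd_rev ≈ H_exact)
</code></pre></div>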
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># Hessian might fail if AD system calls must not be nested (e.g. Zygote over Zygote)</span>
</span></span><span class="line"><span class="cl"><span class="n">backends</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">HigherOrderBackend</span><span class="p">((</span><span class="n">forwarddiff_backend</span><span class="p">,</span><span class="n">zygote_backend</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">H3</span> <span class="o">=</span> <span class="n">AD</span><span class="o">.</span><span class="n">hessian</span><span class="p">(</span><span class="n">backends</span><span class="p">,</span> <span class="n">x</span><span class="o">-&gt;</span><span class="n">g</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="n">p</span><span class="p">),</span> <span class="n">x₀</span><span class="p">)</span>
</span></span></code></pre></div><h2 id="outlook">Outlook</h2>
<p>There are many other use cases, e.g.,</p>
<ul>
<li><a href="https://diffeq.sciml.ai/stable/analysis/sensitivity/" target="_blank" rel="noopener">Sensitivity analysis of differential equations</a> requires vector-Jacobian products for adjoint methods and Jacobian-vector products for tangent methods.</li>
<li>The <a href="https://en.wikipedia.org/wiki/Newton%27s_method" target="_blank" rel="noopener">Newton–Raphson method</a> for root finding requires the derivative in the case of a scalar function $f:\mathbb{R}\rightarrow\mathbb{R}$ and the Jacobian in the case of a system of $N$ nonlinear equations, i.e., finding the zeros of $f:\mathbb{R}^N\rightarrow\mathbb{R}^N$.</li>
<li>The <a href="https://en.wikipedia.org/wiki/Newton%27s_method_in_optimization" target="_blank" rel="noopener">Newton method</a> in optimization requires the computation of the Hessian.</li>
</ul>
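<p>To illustrate the Newton–Raphson item, here is a minimal sketch for a system of two nonlinear equations, with the Jacobian supplied directly by ForwardDiff rather than by an AbstractDifferentiation backend. The system <code>F</code>, the helper <code>newton_raphson</code>, the starting point, and the tolerance are hypothetical choices for this example. Each iteration solves $J(x_k)\,\Delta x_k = F(x_k)$ and updates $x_{k+1} = x_k - \Delta x_k$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using ForwardDiff, LinearAlgebra

# Hypothetical system F: R^2 -> R^2 with roots (1, 2), (2, 1), (-1, -2), (-2, -1)
F(x) = [x[1]^2 + x[2]^2 - 5, x[1] * x[2] - 2]

# Hypothetical helper: plain Newton–Raphson iteration without line search
function newton_raphson(F, x; tol=1e-12, maxiter=50)
    for _ in 1:maxiter
        Fx = F(x)
        if norm(Fx) ≤ tol
            return x
        end
        J = ForwardDiff.jacobian(F, x)  # Jacobian of F at the current iterate
        x = x - J \ Fx                   # Newton step: solve J * Δx = F(x)
    end
    return x
end

root = newton_raphson(F, [0.8, 2.3])
println(root)
</code></pre></div>
<p>Starting from <code>[0.8, 2.3]</code>, the iteration should approach the root $(1, 2)$; other starting points may converge to one of the other roots, since plain Newton steps only converge locally.</p>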
<p>AbstractDifferentiation.jl is by no means complete yet. We are still in the very early stages, but we hope to make significant progress in the coming weeks. Some of the next steps are:</p>
<ul>
<li>fixing remaining bugs, e.g., with respect to the computation of the Hessian and</li>
<li>adding AD/Finite Differentiation packages such as <a href="https://github.com/JuliaDiff/Diffractor.jl" target="_blank" rel="noopener">Diffractor</a>.</li>
</ul>
<p>If you have any questions or comments, please don’t hesitate to contact me!</p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Mike Innes, Alan Edelman, et al., arXiv preprint arXiv:1907.07587 (2019).&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>Jeff Bezanson, Stefan Karpinski, et al., arXiv preprint arXiv:1209.5145 (2012).&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>Chris Rackauckas, The Winnower 8, DOI: 10.15200/winn.156631.13064 (2019).&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>Chris Rackauckas, Yingbo Ma, et al., arXiv preprint arXiv:2001.04385 (2020).&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a>&#160;<a href="#fnref1:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>Raj Dandekar, Chris Rackauckas, et al., Patterns <strong>1</strong>, 100145 (2020).&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:6">
<p>Frank Schäfer, Michal Kloc, et al., Mach. Learn.: Sci. Technol. <strong>1</strong>, 035009 (2020).&#160;<a href="#fnref:6" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:7">
<p>Frank Schäfer, Pavel Sekatski, et al., Mach. Learn.: Sci. Technol. <strong>2</strong>, 035004 (2021).&#160;<a href="#fnref:7" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
</description>
</item>
<item>
<title>Sensitivity Analysis of Hybrid Differential Equations</title>
<link>https://frankschae.github.io/post/bouncing_ball/</link>
<pubDate>Fri, 16 Jul 2021 13:24:04 +0200</pubDate>
<guid>https://frankschae.github.io/post/bouncing_ball/</guid>
<description><p>In this post, we discuss sensitivity analysis of differential equations with state changes caused by events triggered at defined moments, for example reflections, bounces off a wall or other sudden forces. These are described by hybrid differential equations<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup>. We highlight differences between explicit<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> and implicit events<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup> <sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>. As a paradigmatic example, we consider a bouncing ball described by the ODE</p>
<p>$$
\begin{aligned}
\text{d}z(t) &amp;= v(t) \text{d}t, \\
\text{d}v(t) &amp;= -\mathrm g\thinspace \text{d}t
\end{aligned}
$$</p>
<p>with initial condition</p>
<p>$$
\begin{aligned}
z(t=0) &amp;= z_0 = 5, \\
v(t=0) &amp;= v_0 = -0.1.
\end{aligned}
$$</p>
<p>The initial condition contains the initial height $z_0$ and initial velocity $v_0$ of the ball. The system has two important parameters. First, the gravitational acceleration $\mathrm g=10$ models the acceleration of the ball in an approximately uniform gravitational field.</p>
<p>Second, we model the ground as a barrier at $z = 0$ off which the ball bounces in the opposite direction. A dissipation factor $\gamma=0.8$ (the <a href="https://en.wikipedia.org/wiki/Coefficient_of_restitution" target="_blank" rel="noopener">coefficient of restitution</a>) accounts for the imperfectly elastic bounce on the ground.</p>
<p>Ignoring the bounces, we can integrate the ODE either analytically,</p>
<p>$$
\begin{aligned}
z(t) &amp;= z_0 + v_0 t - \frac{\mathrm g}{2} t^2, \\
v(t) &amp;= v_0 - \mathrm g\thinspace t
\end{aligned}<br>
$$</p>
<p>or numerically using the OrdinaryDiffEq package from the <a href="https://sciml.ai/" target="_blank" rel="noopener">SciML</a> ecosystem.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c">### simulate forward</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">ForwardDiff</span><span class="p">,</span> <span class="n">Zygote</span><span class="p">,</span> <span class="n">OrdinaryDiffEq</span><span class="p">,</span> <span class="n">DiffEqSensitivity</span>
</span></span><span class="line"><span class="cl"><span class="k">using</span> <span class="n">Plots</span><span class="p">,</span> <span class="n">LaTeXStrings</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># dynamics</span>
</span></span><span class="line"><span class="cl"><span class="k">function</span> <span class="n">f</span><span class="p">(</span><span class="n">du</span><span class="p">,</span><span class="n">u</span><span class="p">,</span><span class="n">p</span><span class="p">,</span><span class="n">t</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">du</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">u</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"> <span class="n">du</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="n">p</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># parameters and solve</span>
</span></span><span class="line"><span class="cl"><span class="n">z0</span> <span class="o">=</span> <span class="mf">5.0</span>
</span></span><span class="line"><span class="cl"><span class="n">v0</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.1</span>
</span></span><span class="line"><span class="cl"><span class="n">t0</span> <span class="o">=</span> <span class="mf">0.0</span>
</span></span><span class="line"><span class="cl"><span class="n">tend</span> <span class="o">=</span> <span class="mf">1.9</span>
</span></span><span class="line"><span class="cl"><span class="n">g</span> <span class="o">=</span> <span class="mi">10</span>
</span></span><span class="line"><span class="cl"><span class="n">γ</span> <span class="o">=</span> <span class="mf">0.8</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">u0</span> <span class="o">=</span> <span class="p">[</span><span class="n">z0</span><span class="p">,</span><span class="n">v0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">tspan</span> <span class="o">=</span> <span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">tend</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">p</span> <span class="o">=</span> <span class="p">[</span><span class="n">g</span><span class="p">,</span> <span class="n">γ</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">prob</span> <span class="o">=</span> <span class="n">ODEProblem</span><span class="p">(</span><span class="n">f</span><span class="p">,</span><span class="n">u0</span><span class="p">,</span><span class="n">tspan</span><span class="p">,</span><span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c"># plot forward trajectory</span>
</span></span><span class="line"><span class="cl"><span class="n">sol</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">prob</span><span class="p">,</span><span class="n">Tsit5</span><span class="p">(),</span><span class="n">saveat</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">pl</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">sol</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="s">&#34;z(t)&#34;</span> <span class="s">&#34;v(t)&#34;</span><span class="p">],</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomleft</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">hline!</span><span class="p">(</span><span class="n">pl</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">savefig</span><span class="p">(</span><span class="n">pl</span><span class="p">,</span><span class="s">&#34;BB_forward_no_bounce.png&#34;</span><span class="p">)</span>
</span></span></code></pre></div>
<figure >
<div class="d-flex justify-content-center">
<div class="w-100" ><img src="https://frankschae.github.io/img/BB_forward_no_bounce.png" alt="" loading="lazy" data-zoomable /></div>
</div></figure>
<p>Of course, without any event handling the ball simply keeps falling through the barrier at $z=0$.</p>
<h2 id="forward-simulation-with-events">Forward simulation with events</h2>
<p>At a time $\tau \approx 1$, the ball hits the ground, $z(\tau) = 0$, and is inelastically reflected, dissipating a fraction of its energy. This can be modeled by re-initializing the ODE at time $\tau$ with new initial conditions</p>
<p>$$
\begin{aligned}
z({\tau}) &amp;= z(\tau-) ,\\
v({\tau})&amp;= -\gamma v(\tau-) ,
\end{aligned}
$$</p>
<p>so that the velocity jumps at the event time: the velocity right before the bounce, i.e., the left limit $v(\tau-)$, differs from the velocity $v(\tau)$ with which the ball continues its motion after the bounce.</p>
<p>Given our analytical solution for the state as a function of time, we can compute the event time $\tau$ in terms of the initial condition and parameters by solving $z(\tau) = 0$ for its positive root:</p>
<p>$$
\tau = \frac{v_0 + \sqrt{v_0^2 + 2 \mathrm g z_0}}{\mathrm g}.
$$</p>
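<p>As a quick numerical check, the following plain-Julia sketch (no packages needed; the helper names <code>bounce_time</code> and <code>height</code> are illustrative, not part of any library) evaluates this formula with the initial condition and parameters from above and confirms that the height vanishes at $\tau$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Positive root of z0 + v0*t - g/2*t^2 = 0, i.e. the first bounce time
bounce_time(z0, v0, g) = (v0 + sqrt(v0^2 + 2 * g * z0)) / g

# Free-fall height without any event
height(t, z0, v0, g) = z0 + v0 * t - g / 2 * t^2

z0, v0, g = 5.0, -0.1, 10.0
τ = bounce_time(z0, v0, g)
println("τ = ", round(τ, digits = 5))
println("z(τ) = ", height(τ, z0, v0, g))   # zero up to floating-point rounding
</code></pre></div>
<p>With $z_0 = 5$, $v_0 = -0.1$ and $\mathrm g = 10$, the computed $\tau$ rounds to 0.99005, the value used for the explicit event.</p>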
<h3 id="explicit-events">Explicit events</h3>
<p>We can define the bounce of the ball as an explicit event by inserting the values of the initial condition and the parameters into the formula for $\tau$. We obtain</p>
<p>$$
\tau = 0.99005.
$$</p>
<p>The full explicit trajectory $z_{\rm exp}(t) = z(t)$ is determined by</p>
<p>$$
z(t) = \begin{cases}
z_0 + v_0 t - \dfrac{\mathrm g}{2} t^2 ,&amp; \forall t &lt; \tau, \\
-0.4901 \mathrm g - 0.5 \mathrm g (-0.99005 + t)^2 + 0.99005 v_0 + z_0\\
\quad - (-0.99005 + t) (-0.99005 \mathrm g + v_0)\gamma ,&amp; \forall t \ge \tau,
\end{cases}
$$</p>
<p>where we used</p>
<p>$$
\begin{aligned}
z({\tau})&amp;= z_0 + 0.99005 v_0 -0.4901 \mathrm g, \\
v({\tau})&amp;= -\gamma v({\tau-}) = -\gamma(v_0 - 0.99005 \mathrm g) .
\end{aligned}
$$</p>
<p>Here the change in state $(z,v)$ at the event time is defined with the help of an <em>affect function</em></p>
<p>$$
a(z,v) = (z, -\gamma v).
$$</p>
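<p>The explicit trajectory amounts to restarting free fall from the post-bounce state $a(z(\tau-), v(\tau-))$ at a fixed time. Below is a minimal plain-Julia sketch of this. It assumes the hard-coded event time $\tau = 0.99005$, and the names <code>height</code>, <code>velocity</code>, <code>affect</code> and <code>z_exp</code> are hypothetical helpers.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Free fall between events and the corresponding velocity
height(t, z0, v0, g) = z0 + v0 * t - g / 2 * t^2
velocity(t, v0, g) = v0 - g * t

# Affect function a(z, v) = (z, -γ v): keep the height, reflect and damp the velocity
affect(z, v, γ) = (z, -γ * v)

# Explicit event: the bounce time is hard-coded, independent of (z0, v0, g)
function z_exp(t, z0, v0, g, γ; τ = 0.99005)
    if t ≤ τ
        return height(t, z0, v0, g)              # before the bounce (z is continuous at τ)
    end
    zτ, vτ = affect(height(τ, z0, v0, g), velocity(τ, v0, g), γ)
    return height(t - τ, zτ, vτ, g)              # free fall restarted from the post-bounce state
end

println("z_exp(1.9) = ", z_exp(1.9, 5.0, -0.1, 10.0, 0.8))
</code></pre></div>
<p>For the parameters above, <code>z_exp(1.9, 5.0, -0.1, 10.0, 0.8)</code> gives approximately 3.14. Because <code>τ</code> is a fixed keyword default, changing <code>z0</code> or <code>g</code> does not move the bounce.</p>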
<p>Numerically, we use a <code>DiscreteCallback</code> in this case to simulate the system.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># solve with DiscreteCallback (explicit event)</span>
</span></span><span class="line"><span class="cl"><span class="n">tstar</span> <span class="o">=</span> <span class="p">(</span><span class="n">v0</span> <span class="o">+</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">v0</span><span class="o">^</span><span class="mi">2</span><span class="o">+</span><span class="mi">2</span><span class="o">*</span><span class="n">z0</span><span class="o">*</span><span class="n">g</span><span class="p">))</span><span class="o">/</span><span class="n">g</span>
</span></span><span class="line"><span class="cl"><span class="n">condition1</span><span class="p">(</span><span class="n">u</span><span class="p">,</span><span class="n">t</span><span class="p">,</span><span class="n">integrator</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">t</span> <span class="o">==</span> <span class="n">tstar</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">affect!</span><span class="p">(</span><span class="n">integrator</span><span class="p">)</span> <span class="o">=</span> <span class="n">integrator</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="n">integrator</span><span class="o">.</span><span class="n">p</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">*</span><span class="n">integrator</span><span class="o">.</span><span class="n">u</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">cb1</span> <span class="o">=</span> <span class="n">DiscreteCallback</span><span class="p">(</span><span class="n">condition1</span><span class="p">,</span><span class="n">affect!</span><span class="p">,</span><span class="n">save_positions</span><span class="o">=</span><span class="p">(</span><span class="nb">true</span><span class="p">,</span><span class="nb">true</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">sol1</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">prob</span><span class="p">,</span><span class="n">Tsit5</span><span class="p">(),</span><span class="n">callback</span><span class="o">=</span><span class="n">cb1</span><span class="p">,</span> <span class="n">saveat</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">tstops</span><span class="o">=</span><span class="p">[</span><span class="n">tstar</span><span class="p">])</span>
</span></span></code></pre></div><p>With an explicit definition of the event, the impact time is fixed: the reflection is always triggered at $\tau = 0.99005$, a time at which, for different initial conditions or parameters, the ball may not yet have reached the ground.</p>
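<p>A small plain-Julia sketch makes this concrete (the helper <code>height</code> is illustrative). It evaluates the free-fall height at the hard-coded event time for the original initial height and for a hypothetical $z_0 = 5.5$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Free-fall height before any bounce
height(t, z0, v0, g) = z0 + v0 * t - g / 2 * t^2

tstar = 0.99005   # explicit event time, computed once for z0 = 5, v0 = -0.1, g = 10

for z0 in (5.0, 5.5)
    println("z0 = ", z0, ": height at tstar = ", height(tstar, z0, -0.1, 10.0))
end
</code></pre></div>
<p>For $z_0 = 5.5$, the ball is still about 0.5 units above the ground when the explicit event reverses its velocity.</p>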
<h3 id="implicit-events">Implicit events</h3>
<p>A physically more meaningful description of the bouncing ball is therefore an implicit definition of the event in the form of a condition (event function)</p>
<p>$$
g(z,v,p,t),
$$</p>
<p>where an event occurs at time $\tau$ if $g(z(\tau),v(\tau),p,\tau) = 0$. For the bouncing ball, the event function is simply the height, $g(z,v,p,t) = z$. We have already used this condition to compute the impact time $\tau$ when modeling the bounce explicitly. The implicit formulation also naturally accounts for multiple bounces, since the event is triggered every time $g(z,v,p,t) = 0$.</p>
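<p>Conceptually, an implicit event requires finding a zero of the event function along the trajectory. The plain-Julia sketch below illustrates only this idea, using bisection on the closed-form free-fall solution. It is not the root-finding procedure that <code>ContinuousCallback</code> actually uses, and the names <code>event_fn</code> and <code>find_event_time</code> are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Event (condition) function g(z, v, p, t) = z: an event occurs when the height crosses zero
event_fn(z, v, p, t) = z

height(t, z0, v0, g) = z0 + v0 * t - g / 2 * t^2
velocity(t, v0, g) = v0 - g * t

# Bisection for the zero of the event function on [ta, tb], with p = [g, γ].
# Assumes the event function changes sign exactly once on the bracket.
function find_event_time(z0, v0, p, ta, tb; iters = 60)
    condition(t) = event_fn(height(t, z0, v0, p[1]), velocity(t, v0, p[1]), p, t)
    for _ in 1:iters
        tm = (ta + tb) / 2
        if sign(condition(tm)) == sign(condition(ta))
            ta = tm   # zero lies in [tm, tb]
        else
            tb = tm   # zero lies in [ta, tm]
        end
    end
    return (ta + tb) / 2
end

τ10  = find_event_time(5.0, -0.1, [10.0, 0.8], 0.0, 1.9)
τ981 = find_event_time(5.0, -0.1, [9.81, 0.8], 0.0, 1.9)
println("τ(g = 10) = ", τ10, ", τ(g = 9.81) = ", τ981)
</code></pre></div>
<p>Because the event time is defined through the condition, it follows the parameters: switching from $\mathrm g = 10$ to $\mathrm g = 9.81$ moves the bounce from about 0.99005 to about 0.9995.</p>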
<p>As in the previous case, we can compute the full trajectory of the ball analytically. Substituting the formula for $\tau$, we obtain at the event time</p>
<p>$$
\begin{aligned}
z({\tau})&amp;= 0, \\
v({\tau}-)&amp;= - \sqrt{v_0^2 + 2 \mathrm g z_0}
\end{aligned}
$$</p>
<p>for the left limit and</p>
<p>$$
\begin{aligned}
v({\tau})&amp;= \gamma \sqrt{v_0^2 + 2 \mathrm g z_0}
\end{aligned}
$$</p>
<p>right after the bounce. Thus, the full trajectory $z_{\rm imp}(t) = z(t)$ is given by</p>
<p>$$
(\star) \quad z(t) = \begin{cases}
z_0 + v_0 t - \dfrac{\mathrm g}{2} t^2 ,&amp; \forall t &lt; \tau ,\\
-\dfrac{-\mathrm g t + v_0 + \sqrt{v_0^2 + 2 \mathrm g z_0}}{2 \mathrm g} \\
\quad\cdot \space (-\mathrm g t + v_0 + \sqrt{v_0^2 + 2 \mathrm g z_0} (1 + 2 \gamma)), &amp; \forall t \ge \tau.
\end{cases}
$$</p>
<p>Unlike the explicit trajectory, this expression remains correct when the initial condition or the parameters change, e.g., when substituting the more precise value $\mathrm g = 9.81$ for the gravitational acceleration.</p>
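<p>The implicit trajectory $(\star)$ can be coded the same way. In the plain-Julia sketch below (the function name <code>z_imp</code> is ours), the bounce time is recomputed from $(z_0, v_0, \mathrm g)$ on every call instead of being hard-coded.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Implicit trajectory (⋆): the bounce time is recomputed from the event condition z(τ) = 0
function z_imp(t, z0, v0, g, γ)
    s = sqrt(v0^2 + 2 * g * z0)   # impact speed |v(τ-)|
    τ = (v0 + s) / g              # bounce time as a function of (z0, v0, g)
    if t ≤ τ
        return z0 + v0 * t - g / 2 * t^2
    end
    w = -g * t + v0 + s
    return -w / (2 * g) * (w + 2 * γ * s)   # second branch of (⋆)
end

for g in (10.0, 9.81)
    println("g = ", g, ": z_imp(1.9) = ", z_imp(1.9, 5.0, -0.1, g, 0.8))
end
</code></pre></div>
<p>For $\mathrm g = 10$, <code>z_imp(1.9, 5.0, -0.1, 10.0, 0.8)</code> agrees with the explicit trajectory up to the rounding of $\tau$ (about 3.14), whereas for $\mathrm g = 9.81$ only the implicit version still bounces exactly at $z = 0$.</p>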
<p>Numerically, we use a <code>ContinuousCallback</code> in this case.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># solve with ContinuousCallback (implicit event)</span>
</span></span><span class="line"><span class="cl"><span class="n">condition2</span><span class="p">(</span><span class="n">u</span><span class="p">,</span><span class="n">t</span><span class="p">,</span><span class="n">integrator</span><span class="p">)</span> <span class="o">=</span> <span class="n">u</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="c"># Event happens when condition2(u,t,integrator) == 0</span>
</span></span><span class="line"><span class="cl"><span class="n">cb2</span> <span class="o">=</span> <span class="n">ContinuousCallback</span><span class="p">(</span><span class="n">condition2</span><span class="p">,</span><span class="n">affect!</span><span class="p">,</span><span class="n">save_positions</span><span class="o">=</span><span class="p">(</span><span class="nb">true</span><span class="p">,</span><span class="nb">true</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">sol2</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">prob</span><span class="p">,</span><span class="n">Tsit5</span><span class="p">(),</span><span class="n">callback</span><span class="o">=</span><span class="n">cb2</span><span class="p">,</span><span class="n">saveat</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
</span></span></code></pre></div><p>We can verify that both callbacks lead to the same forward time evolution (for fixed initial conditions and parameters).</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># plot forward trajectory</span>
</span></span><span class="line"><span class="cl"><span class="n">pl1</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">sol1</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="s">&#34;z(t)&#34;</span> <span class="s">&#34;v(t)&#34;</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">&#34;explicit event&#34;</span><span class="p">,</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">pl2</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">sol2</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="s">&#34;z(t)&#34;</span> <span class="s">&#34;v(t)&#34;</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">&#34;implicit event&#34;</span><span class="p">,</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">hline!</span><span class="p">(</span><span class="n">pl1</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">hline!</span><span class="p">(</span><span class="n">pl2</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">pl</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">pl1</span><span class="p">,</span><span class="n">pl2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">savefig</span><span class="p">(</span><span class="n">pl</span><span class="p">,</span><span class="s">&#34;BB_forward.png&#34;</span><span class="p">)</span>
</span></span></code></pre></div>
<figure >
<div class="d-flex justify-content-center">
<div class="w-100" ><img src="https://frankschae.github.io/img/BB_forward.png" alt="" loading="lazy" data-zoomable /></div>
</div></figure>
<p>In addition, the impact time defined implicitly via the <code>ContinuousCallback</code> adapts appropriately when the initial conditions or the parameters change, for example when using $\mathrm g = 9.81$ for the gravitational acceleration. In other words, the event time $\tau=\tau(p,z_0,v_0,t_0)$ is a function of the parameters and initial conditions, implicitly defined by the event condition.</p>
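<p>Since $\tau$ depends on the inputs only through the event condition, its derivatives follow from the implicit function theorem, without solving for $\tau$ explicitly. As a short worked example, let $\alpha$ be one of the inputs $(z_0, v_0, \mathrm g)$ and differentiate the condition $z(\tau(\alpha)-) = 0$ with respect to $\alpha$, using $\dot z(\tau-) = v(\tau-)$:</p>
<p>$$
\begin{aligned}
0 &amp;= \frac{\partial z(\tau-)}{\partial \alpha} + v(\tau-)\,\frac{\text{d}\tau}{\text{d}\alpha}
\quad\Longrightarrow\quad
\frac{\text{d}\tau}{\text{d}\alpha} = -\frac{1}{v(\tau-)}\,\frac{\partial z(\tau-)}{\partial \alpha} , \\
\frac{\text{d}\tau}{\text{d}z_0} &amp;= \frac{1}{\sqrt{v_0^2 + 2 \mathrm g z_0}} \approx 0.099995 .
\end{aligned}
$$</p>
<p>Here the partial derivative is taken at fixed time. For $\alpha = z_0$ it equals 1, and inserting $v(\tau-) = -\sqrt{v_0^2 + 2 \mathrm g z_0}$ gives the second line, which agrees with differentiating the closed-form expression for $\tau$ directly.</p>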
<p>Suppose we now drop the ball from a slightly higher position. Does increasing the initial height $z_0$ increase or decrease the height at the end time $t_\text{end}=1.9$? Sensitivity analysis answers this question. Differentiating $(\star)$ with respect to $z_0$ and evaluating at $t_\text{end}$ gives</p>
<p>$$
\frac{\text{d} z(t_\text{end})}{\text{d} z_0} = 0.84,
$$</p>
<p>so that, to first order, raising the initial height by a small amount raises the height at $t_\text{end}$ by about 0.84 times that amount.</p>
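<p>The value 0.84 can be traced back to $(\star)$. Writing $s = \sqrt{v_0^2 + 2 \mathrm g z_0}$ and $u = \mathrm g t - v_0 - s = \mathrm g\,(t - \tau)$, the second branch of $(\star)$ reads $z(t) = \gamma s u/\mathrm g - u^2/(2 \mathrm g)$. Since $\partial s/\partial z_0 = \mathrm g/s$ and $\partial u/\partial z_0 = -\mathrm g/s$, a short calculation gives, for $t \ge \tau$,</p>
<p>$$
\frac{\text{d} z(t)}{\text{d} z_0} = \frac{\gamma}{\mathrm g}\left(\frac{\mathrm g\, u}{s} - \mathrm g\right) + \frac{u}{s} = (1+\gamma)\,\frac{u}{s} - \gamma .
$$</p>
<p>At $t_\text{end} = 1.9$, we have $s \approx 10.0005$ and $u \approx 9.0995$, so that $(1+\gamma)\,u/s - \gamma \approx 0.838$, which rounds to 0.84.</p>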
<p>We can verify this visually:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="c"># animate forward trajectory</span>
</span></span><span class="line"><span class="cl"><span class="n">sol3</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">remake</span><span class="p">(</span><span class="n">prob</span><span class="p">,</span><span class="n">u0</span><span class="o">=</span><span class="p">[</span><span class="n">u0</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">+</span><span class="mf">0.5</span><span class="p">,</span><span class="n">u0</span><span class="p">[</span><span class="mi">2</span><span class="p">]]),</span><span class="n">Tsit5</span><span class="p">(),</span><span class="n">callback</span><span class="o">=</span><span class="n">cb2</span><span class="p">,</span><span class="n">saveat</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">plt2</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">sol2</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="nb">false</span><span class="p">,</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">,</span> <span class="n">xlims</span><span class="o">=</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">tend</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="n">hline!</span><span class="p">(</span><span class="n">plt2</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">plot!</span><span class="p">(</span><span class="n">plt2</span><span class="p">,</span> <span class="n">sol3</span><span class="p">,</span> <span class="n">tspan</span><span class="o">=</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">tend</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="mi">2</span><span class="p">],</span> <span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="s">&#34;z(t)&#34;</span> <span class="s">&#34;v(t)&#34;</span><span class="p">],</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">,</span> <span class="n">denseplot</span><span class="o">=</span><span class="nb">true</span><span class="p">,</span> <span class="n">xlims</span><span class="o">=</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">tend</span><span class="p">),</span> <span class="n">ylims</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">11</span><span class="p">,</span><span class="mi">9</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">list_plots</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="n">t</span> <span class="k">in</span> <span class="n">sol3</span><span class="o">.</span><span class="n">t</span>
</span></span><span class="line"><span class="cl"> <span class="n">plt1</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">sol2</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="nb">false</span><span class="p">,</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">hline!</span><span class="p">(</span><span class="n">plt1</span><span class="p">,</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">&#34;black&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">plot!</span><span class="p">(</span><span class="n">plt1</span><span class="p">,</span> <span class="n">sol3</span><span class="p">,</span> <span class="n">tspan</span><span class="o">=</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">t</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span> <span class="mi">2</span><span class="p">],</span> <span class="n">label</span> <span class="o">=</span> <span class="p">[</span><span class="s">&#34;z(t)&#34;</span> <span class="s">&#34;v(t)&#34;</span><span class="p">],</span> <span class="n">labelfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">legendfontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">lw</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">xlabel</span> <span class="o">=</span> <span class="s">&#34;t&#34;</span><span class="p">,</span> <span class="n">legend</span><span class="o">=</span><span class="ss">:bottomright</span><span class="p">,</span> <span class="n">denseplot</span><span class="o">=</span><span class="nb">true</span><span class="p">,</span> <span class="n">xlims</span><span class="o">=</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span><span class="n">tend</span><span class="p">),</span> <span class="n">ylims</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">11</span><span class="p">,</span><span class="mi">9</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"> <span class="n">scatter!</span><span class="p">(</span><span class="n">plt1</span><span class="p">,[</span><span class="n">t</span><span class="p">,</span><span class="n">t</span><span class="p">],</span> <span class="n">sol3</span><span class="p">(</span><span class="n">t</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="nb">false</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">plt</span> <span class="o">=</span> <span class="n">plot</span><span class="p">(</span><span class="n">plt1</span><span class="p">,</span><span class="n">plt2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">push!</span><span class="p">(</span><span class="n">list_plots</span><span class="p">,</span> <span class="n">plt</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">plot</span><span class="p">(</span><span class="n">list_plots</span><span class="p">[</span><span class="mi">100</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">anim</span> <span class="o">=</span> <span class="n">animate</span><span class="p">(</span><span class="n">list_plots</span><span class="p">,</span><span class="n">every</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</span></span></code></pre></div>
<figure >
<div class="d-flex justify-content-center">
<div class="w-100" ><img src="https://frankschae.github.io/img/BB.gif" alt="" loading="lazy" data-zoomable /></div>
</div></figure>
<p>The original trajectory (initial height $z_0 = 5$) is shown in black; the colored curves show the trajectory for an initial height increased by 0.5.</p>
<h2 id="sensitivity-analysis-with-events">Sensitivity analysis with events</h2>
<p>In more general terms, the last example can be seen as a loss function acting on $z$ at the end time</p>
<p>$$
L^{\text{exp}} = z(t_{\text{end}}),
$$</p>
<p>where the superscript &ldquo;exp&rdquo; refers to an explicit definition of the time ($t_{\text{end}}$) at which we evaluate the loss function. Let $\alpha$ denote any of the inputs $(z_0,v_0,g,\gamma)$. The sensitivity with respect to $\alpha$ is then given by the chain rule</p>
<p>$$
\frac{\text{d}L^{\text{exp}}}{\text{d} \alpha} = \frac{\text{d}L^{\text{exp}}}{\text{d} z} \frac{\text{d}z(t_{\text{end}})}{\text{d} \alpha} = \frac{\text{d}z(t_{\text{end}})}{\text{d} \alpha}.
$$</p>
<p>Inserting $z(t) = z_{\rm exp}(t)$ instead of $z(t) = z_{\rm imp}(t)$ at $t_{\text{end}}$ yields a different value for the sensitivity, one that ignores the changes in $z$ caused by changes in the bouncing time $\tau$. For example,</p>
<p>$$
\quad \frac{\text{d}L^{\text{exp}}}{\text{d} z_0} = \begin{cases}
1 &amp; \text{with } z(t) = z_{\rm exp}(t) ,\\
0.84 &amp; \text{with } z(t) = z_{\rm imp}(t) .
\end{cases}
$$</p>
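<p>To make the difference concrete, the following sketch evaluates both derivatives by central finite differences on the closed-form bouncing ball. It assumes the model $\dot z = v$, $\dot v = -g$ with the bounce $v \to -\gamma v$ at $z = 0$, and the parameter values $z_0 = 5$, $v_0 = -0.1$, $g = 10$, $\gamma = 0.8$, $t_{\text{end}} = 1.9$, which are chosen for illustration rather than taken from this section. In the explicit variant the bounce time is frozen at its nominal value; in the implicit variant it is recomputed from the perturbed $z_0$. The helpers <code>impact_time</code>, <code>height</code>, and <code>fd</code> are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Assumed bouncing-ball model: z' = v, v' = -g, bounce v -> -γ v at z = 0.
# Parameter values are assumptions for illustration.
z0, v0, g, γ, tend = 5.0, -0.1, 10.0, 0.8, 1.9

# First impact time: positive root of z0 + v0*t - g*t^2/2 = 0.
impact_time(z0, v0, g) = (v0 + sqrt(v0^2 + 2g * z0)) / g

# Height at time t if the bounce happens at time τ.
function height(t, z0, v0, g, γ, τ)
    t > τ || return z0 + v0 * t - g * t^2 / 2   # before the bounce: free flight
    zτ = z0 + v0 * τ - g * τ^2 / 2               # height at τ (zero if τ is the true impact)
    vτ = v0 - g * τ                              # velocity just before τ
    return zτ - γ * vτ * (t - τ) - g * (t - τ)^2 / 2
end

# Central finite difference of a scalar function.
fd(f, a; h=1e-6) = (f(a + h) - f(a - h)) / (2h)

τ0 = impact_time(z0, v0, g)
dexp = fd(z -> height(tend, z, v0, g, γ, τ0), z0)                      # τ frozen
dimp = fd(z -> height(tend, z, v0, g, γ, impact_time(z, v0, g)), z0)   # τ follows z0
println("explicit: ", dexp, "  implicit: ", dimp)
</code></pre></div>
<p>With these assumed values, the explicit derivative is 1 up to rounding, while the implicit one equals $(g/S)(t_{\text{end}} - \tau)(1 + \gamma) - \gamma \approx 0.838$ with $S = \sqrt{v_0^2 + 2 g z_0}$.</p>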
<p>Besides an explicit description of time points, we may also encounter implicitly defined time points. For example, for the bouncing ball with velocity at the impact time $\tau-$ (i.e., at the left limit) as the quantity of interest,</p>
<p>$$
L = v(\tau-),
$$</p>
<p>we could be interested in the sensitivity of $L$ with respect to $g$. If we use the velocity $v(t)$ at the explicit time $t = 0.99005$ instead of the implicit $v(\tau-)$, we again obtain a different value for the sensitivity, one that again ignores the changes in $\tau$ and is not the one we are looking for, even though $\tau = 0.99005$ in both cases:</p>
<p>$$
\quad \frac{\text{d}L}{\text{d} g} = \begin{cases}
-0.99005 &amp; \text{with } v(t) = v_{\rm exp}(t) ,\\
-\frac{z_0}{\sqrt{v_0^2 + 2 g z_0}} \approx -0.49998 &amp; \text{ with } v(t) = v_{\rm imp}(t) .
\end{cases}
$$</p>
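<p>The same comparison can be sketched for $L = v(\tau-)$ and the parameter $g$. The sketch assumes $z_0 = 5$, $v_0 = -0.1$, and $g = 10$, values inferred from (and consistent with) the explicit sensitivity $-0.99005 = -\tau$ above rather than stated in this section. Freezing $\tau$ gives $-\tau$, whereas letting $\tau$ follow $g$ differentiates the closed form $v(\tau-) = -\sqrt{v_0^2 + 2 g z_0}$. The helper names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Assumed values, consistent with the explicit sensitivity -0.99005 = -τ.
z0, v0, g = 5.0, -0.1, 10.0

impact_time(z0, v0, g) = (v0 + sqrt(v0^2 + 2g * z0)) / g   # first root of z(t) = 0
velocity_left(τ, v0, g) = v0 - g * τ                        # v(τ-) during free flight
fd(f, a; h=1e-6) = (f(a + h) - f(a - h)) / (2h)             # central difference

τ0 = impact_time(z0, v0, g)                                          # ≈ 0.99005
dexp = fd(gg -> velocity_left(τ0, v0, gg), g)                        # τ frozen: -τ0
dimp = fd(gg -> velocity_left(impact_time(z0, v0, gg), v0, gg), g)   # τ follows g
println("τ = ", τ0, "  explicit: ", dexp, "  implicit: ", dimp)
</code></pre></div>
<p>The implicit value is about $-0.49998$, i.e. $-z_0/\sqrt{v_0^2 + 2 g z_0}$.</p>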
<p>Thanks to our analytical results for the bouncing ball, computing the sensitivities has been straightforward so far. However, for most systems we won&rsquo;t be able to solve a differential equation</p>
<p>$$
\text{d}x(t) = f(x,p,t) \text{d}t
$$</p>
<p>with initial condition $x_0=x(t_0)$. Instead, we have to numerically solve for the trajectory $x(t)$.</p>
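<p>As a dependency-free illustration of such a numerical solution, a fixed-step classical Runge-Kutta (RK4) integrator is sketched below. The name <code>rk4_solve</code> and its argument order, chosen to mirror $\text{solve}(t_0, x_0, t, p)$, are assumptions; in practice one would use an adaptive solver, e.g. from DifferentialEquations.jl.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Hypothetical helper: classical fixed-step RK4 for dx/dt = f(x, p, t).
# Returns x(t) for the initial condition x(t0) = x0 (t may also be smaller than t0).
function rk4_solve(f, t0, x0, t, p; n=1000)
    h = (t - t0) / n
    x, s = x0, t0
    for _ in 1:n
        k1 = f(x, p, s)
        k2 = f(x .+ (h / 2) .* k1, p, s + h / 2)
        k3 = f(x .+ (h / 2) .* k2, p, s + h / 2)
        k4 = f(x .+ h .* k3, p, s + h)
        x = x .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        s += h
    end
    return x
end

# Example: free flight of the ball before the bounce, x = [z, v], p = g.
ball(x, g, t) = [x[2], -g]
xt = rk4_solve(ball, 0.0, [5.0, -0.1], 0.5, 10.0)   # ≈ [3.7, -5.1]
</code></pre></div>
<p>For the free flight of the ball the acceleration is constant, so RK4 reproduces the exact values $z(0.5) = 3.7$ and $v(0.5) = -5.1$ up to rounding.</p>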
<p>In the following, we derive an adjoint sensitivity method for the more general loss function</p>
<p>$$L = \sum_j L_j(\tau_j,x(\tau_j),p) + \sum_i L^{\text{exp}}_i(s_i,x(s_i),p) ,$$</p>
<p>with $L_i^{\text{exp}}$ at explicit time points $s_i$, such as $t_{\text{end}}$, and $L_j(\tau_j,x(\tau_j),p)$ at implicit time points, such as $\tau$, which allows us to compute the sensitivity of $L$ with respect to changes of the parameters or initial condition without the requirement of an analytical solution for $x(t)$.</p>
<h3 id="backsolve-adjoint-algorithm-for-ordinary-differential-equations">Backsolve-Adjoint algorithm for ordinary differential equations</h3>
<p>Taking derivatives (or computing sensitivities) is a beautifully mechanical process: we, or a computer, can differentiate complex expressions simply by applying the chain rule repeatedly.</p>
<p>We write $\text{solve}(t_0, x_0, t, p) = x(t)$ for the solution of the ODE at time $t$, given the initial condition $x(t_0) = x_0$.</p>
<p>Regarding the computation of the sensitivities (the derivatives of the function <code>solve</code>), we may then choose one of the <a href="https://diffeq.sciml.ai/stable/analysis/sensitivity/" target="_blank" rel="noopener">available algorithms</a> for the given differential equation. Currently, <code>BacksolveAdjoint()</code>, <code>InterpolatingAdjoint()</code>, <code>QuadratureAdjoint()</code>, <code>ReverseDiffAdjoint()</code>, <code>TrackerAdjoint()</code>, and <code>ForwardDiffAdjoint()</code> are compatible with events in ordinary differential equations.</p>
<p>Let us focus on the <code>BacksolveAdjoint()</code> algorithm, which computes the sensitivities</p>
<p>$$
\begin{aligned}
\frac{\text{d}\thinspace L(\text{solve}(t_0, x_0, t, p))}{\text{d}x_{0}} &amp;= \lambda(t_{0}),\\
\frac{\text{d}\thinspace L(\text{solve}(t_0, x_0, t, p))}{\text{d}p} &amp;= \lambda_{p}(t_{0}),
\end{aligned}
$$</p>
<p>of a loss function $L$ acting on the final state, with respect to the initial state and the parameters. It does so by solving an ODE for $\lambda(s)$ and $\lambda_p(s)$ in reverse time, from $t$ to $t_0$,</p>
<p>$$
\begin{aligned}
\frac{\text{d}\lambda(s)}{\text{d}s} &amp;= -\lambda(s)^\dagger \frac{\text{d} f(\rightarrow x(s), p, s)}{\text{d} x(s)} \\
\frac{\text{d}\lambda_{p}(s)}{\text{d}s} &amp;= -\lambda(s)^\dagger \frac{\text{d} f(x(s), \rightarrow p, s)}{\text{d} p},
\end{aligned}
$$</p>
<p>with initial conditions:</p>
<p>$$
\begin{aligned}
\lambda(t)&amp;= \frac{\text{d}\thinspace L(\text{solve}(t_0, x_0, t, p))}{\text{d}x(t)}, \\
\lambda_{p}(t) &amp;= 0.
\end{aligned}
$$</p>
<p>The arrows ($\rightarrow$) indicate the variable with respect to which we differentiate, which will become important later when the same variable shows up in multiple function arguments.</p>
<p>Note that computing the vector-Jacobian products (vjp) in the adjoint ODE requires the value of $x(s)$ along its trajectory. In <code>BacksolveAdjoint()</code>, we recompute $x(s)$ &ndash; together with the adjoint variables &ndash; backwards in time starting with its final value $x(t)$. A derivation of the ODE adjoint is given in <a href="https://mitmath.github.io/18337/lecture11/adjoints" target="_blank" rel="noopener">Chris&rsquo; MIT 18.337 lecture notes</a>.</p>
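<p>A minimal, dependency-free sketch of this reverse pass follows. It assumes the scalar toy problem $\dot x = p x$ with $L = x(T)^2/2$ and hand-coded Jacobians; the function names are hypothetical, and this is not the <code>BacksolveAdjoint()</code> implementation itself. As in <code>BacksolveAdjoint()</code>, only $x(T)$ is kept from the forward pass, and $x(s)$ is recomputed backwards together with $\lambda$ and $\lambda_p$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Toy problem (assumption): dx/dt = f(x, p) = p*x with loss L = x(T)^2 / 2.
f(x, p) = p * x
dfdx(x, p) = p   # ∂f/∂x
dfdp(x, p) = x   # ∂f/∂p

# Augmented right-hand side for u = (x, λ, λp): the ODE plus both adjoint equations.
aug(u, p) = [f(u[1], p), -u[2] * dfdx(u[1], p), -u[2] * dfdp(u[1], p)]

# Classical RK4 with n fixed steps from t0 to t1 (t1 may be smaller than t0).
function rk4(rhs, u, p, t0, t1; n=2000)
    h = (t1 - t0) / n
    for _ in 1:n
        k1 = rhs(u, p)
        k2 = rhs(u .+ (h / 2) .* k1, p)
        k3 = rhs(u .+ (h / 2) .* k2, p)
        k4 = rhs(u .+ h .* k3, p)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

function backsolve_adjoint(x0, p, T)
    # Forward pass: keep only the final state x(T).
    xT = rk4((u, q) -> [f(u[1], q)], [x0], p, 0.0, T)[1]
    # Reverse pass from λ(T) = dL/dx(T) = x(T), λp(T) = 0;
    # x is recomputed backwards together with λ and λp.
    u0 = rk4(aug, [xT, xT, 0.0], p, T, 0.0)
    return u0[2], u0[3]   # (dL/dx0, dL/dp)
end

x0, p, T = 1.5, 0.3, 2.0
dLdx0, dLdp = backsolve_adjoint(x0, p, T)
println("dL/dx0 = ", dLdx0, ", dL/dp = ", dLdp)
</code></pre></div>
<p>The exact gradients for this toy problem are $x_0 e^{2pT}$ and $T x_0^2 e^{2pT}$. Recomputing $x$ backwards saves memory but can lose accuracy when the reversed dynamics are unstable; here they are benign.</p>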
<p><code>BacksolveAdjoint()</code>, essentially the custom primitive differentiation rule of <code>solve</code>, is the elementary building block for deriving sensitivities in more complicated settings as well:</p>
<p>Consider a loss depending on the state $x(s_i)$ at fixed time points $s_i$ through loss functions $L_i^{\text{exp}}$,</p>
<p>$$
L^{\text{exp}} = \sum_i L_i^{\text{exp}}(s_i, x(s_i), p).
$$</p>
<p>We can obtain the sensitivity of $L^{\text{exp}}$ with respect to $p$ (or $x_0$) with the tool we already have. For readers familiar with automatic differentiation, this is perhaps easiest to see if we write $L^{\text{exp}}$ as pseudocode:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="k">function</span> <span class="n">loss</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x1</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">s1</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">=</span> <span class="n">L1</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x2</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">s2</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">+=</span> <span class="n">L2</span><span class="p">(</span><span class="n">s2</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="o">...</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">L</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span></code></pre></div><p>and consider the problem of automatically differentiating it. You&rsquo;ll only need the primitives of <code>solve</code>, <code>L1</code>, and <code>L2</code>! The sensitivities can be computed by repeated calls to <code>BacksolveAdjoint()</code> on the intervals <code>(s_i, s_{i+1})</code> backward in time, adding the gradients of the loss terms $L_i^{\text{exp}}(s_i, x(s_i), p)$ to the adjoints at the times $s_i$, or, in the common notation<sup id="fnref1:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>, as a single call of <code>BacksolveAdjoint()</code> with discrete callbacks:</p>
<p>$$
\begin{aligned}
\frac{\text{d}\lambda(t)}{\text{d}t} &amp;= -\lambda(t)^\dagger \frac{\text{d} f(\rightarrow x(t), p, t)}{\text{d} x(t)} - \sum_i \frac{\text{d} L_i^{\text{exp}}(s_i, \rightarrow x(t), p)}{\text{d} x(t)}^\dagger \delta(t-s_i), \\
\frac{\text{d}\lambda_{p}(t)}{\text{d}t} &amp;= -\lambda(t)^\dagger \frac{\text{d} f(x(t), \rightarrow p, t)}{\text{d} p} - \sum_i \frac{\text{d} L_i^{\text{exp}}(s_i, x(t),\rightarrow p)}{\text{d} p}^\dagger \delta(t-s_i).
\end{aligned}
$$</p>
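<p>The delta-function terms amount to jumps of the adjoints at the save times. A minimal sketch, assuming the toy ODE $\dot x = p x$ and hypothetical loss terms $L_i^{\text{exp}}(s_i, x, p) = x + p\, s_i$, integrates $(x, \lambda, \lambda_p)$ backwards with RK4 between the $s_i$ and adds $\partial L_i^{\text{exp}}/\partial x = 1$ to $\lambda$ and $\partial L_i^{\text{exp}}/\partial p = s_i$ to $\lambda_p$ whenever it reaches a save time. All function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Toy problem (assumption): dx/dt = p*x, with hypothetical loss terms
# L_i(s_i, x, p) = x + p*s_i evaluated at the save times s_i.
aug(u, p) = [p * u[1], -u[2] * p, -u[2] * u[1]]   # (x, λ, λp); ∂f/∂x = p, ∂f/∂p = x
dLi_dx(s, x, p) = 1.0                            # ∂L_i/∂x
dLi_dp(s, x, p) = s                              # ∂L_i/∂p

# Classical RK4 with n fixed steps from t0 to t1 (t1 may be smaller than t0).
function rk4(rhs, u, p, t0, t1; n=2000)
    h = (t1 - t0) / n
    for _ in 1:n
        k1 = rhs(u, p)
        k2 = rhs(u .+ (h / 2) .* k1, p)
        k3 = rhs(u .+ (h / 2) .* k2, p)
        k4 = rhs(u .+ h .* k3, p)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

function adjoint_with_jumps(x0, p, s)   # s: sorted save times with s[1] > 0
    # Forward pass: record x(s_i) at every save time.
    xs = similar(s)
    x, t = x0, 0.0
    for i in eachindex(s)
        x = rk4((u, q) -> [q * u[1]], [x], p, t, s[i])[1]
        xs[i], t = x, s[i]
    end
    # Reverse pass: start at the last save time with zero adjoints.
    u = [xs[end], 0.0, 0.0]
    for i in reverse(eachindex(s))
        u[2] += dLi_dx(s[i], u[1], p)    # jump from the δ(t - s_i) term for λ
        u[3] += dLi_dp(s[i], u[1], p)    # jump from the δ(t - s_i) term for λp
        t_prev = i > 1 ? s[i - 1] : 0.0
        u = rk4(aug, u, p, s[i], t_prev) # backsolve to the previous save time
    end
    return u[2], u[3]   # (dL/dx0, dL/dp)
end

s = [0.5, 1.0, 2.0]
x0, p = 1.5, 0.3
dLdx0, dLdp = adjoint_with_jumps(x0, p, s)
</code></pre></div>
<p>For this toy problem, the exact gradients are $\sum_i e^{p s_i}$ and $\sum_i (s_i x_0 e^{p s_i} + s_i)$, which the returned values match to integration accuracy.</p>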
<p>The sensitivities of the ordinary <code>solve</code> with respect to the other arguments are also needed and given by</p>
<p>$$
\frac{\text{d}(\text{solve}(s, x, \rightarrow t, p))}{\text{d}t} = f(\text{solve}(s, x, t, p), p, t)
$$</p>
<p>and</p>
<p>$$
\frac{\text{d}(\text{solve}(\rightarrow s, x, t, p))}{\text{d}s} = -\frac{\text{d}(\text{solve}(s, \rightarrow x, t, p))}{\text{d}x} f(x, p, s).
$$</p>
<p>Now we can even properly define the <code>rrule</code> of <code>solve</code> in the sense of <code>ChainRulesCore.jl</code>.</p>
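<p>Both derivative rules with respect to the time arguments of <code>solve</code> can be checked by finite differences on a toy ODE whose solution map is known in closed form, $\dot x = p x$ with $\text{solve}(s, x, t, p) = x\, e^{p(t-s)}$. This toy problem is an assumption made only for this sketch, and <code>solve_lin</code> and <code>fd</code> are hypothetical names. The derivative with respect to the end time matches $f$ evaluated at the solution, while the derivative with respect to the start time equals $-f(x, p, s)$ multiplied by the flow Jacobian $\partial\,\text{solve}/\partial x = e^{p(t-s)}$, a factor that is 1 only at $t = s$.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Toy ODE (assumption): dx/dt = f(x, p, t) = p*x, with a closed-form solution map.
f(x, p, t) = p * x
solve_lin(s, x, t, p) = x * exp(p * (t - s))   # state at time t, starting from x at time s

# Central finite difference of a scalar function.
fd(g, a; h=1e-6) = (g(a + h) - g(a - h)) / (2h)

s, x, t, p = 0.5, 1.2, 2.0, 0.3
dsolve_dt = fd(tt -> solve_lin(s, x, tt, p), t)   # derivative w.r.t. the end time
dsolve_ds = fd(ss -> solve_lin(ss, x, t, p), s)   # derivative w.r.t. the start time
dsolve_dx = fd(xx -> solve_lin(s, xx, t, p), x)   # flow Jacobian ∂solve/∂x

println(dsolve_dt, " ≈ ", f(solve_lin(s, x, t, p), p, t))
println(dsolve_ds, " ≈ ", -dsolve_dx * f(x, p, s))
</code></pre></div>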
<h3 id="explicit-events-1">Explicit events</h3>
<p>To make <code>BacksolveAdjoint()</code> compatible with explicit events<sup id="fnref2:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup>,</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"><span class="line"><span class="cl"><span class="k">function</span> <span class="n">loss</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x1</span><span class="o">-</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">t0</span><span class="p">,</span> <span class="n">x0</span><span class="p">,</span> <span class="n">s1</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">=</span> <span class="n">L1</span><span class="o">-</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">x1</span><span class="o">-</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="c"># saved before affect</span>
</span></span><span class="line"><span class="cl"> <span class="n">x1</span> <span class="o">=</span> <span class="n">a</span><span class="p">(</span><span class="n">x1</span><span class="o">-</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">s1</span><span class="p">)</span> <span class="c"># affect</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">+=</span> <span class="n">L1</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="c"># saved after affect</span>
</span></span><span class="line"><span class="cl"> <span class="n">x2</span><span class="o">-</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">s1</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">s2</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">+=</span> <span class="n">L2</span><span class="o">-</span><span class="p">(</span><span class="n">s2</span><span class="p">,</span> <span class="n">x2</span><span class="o">-</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="n">x2</span> <span class="o">=</span> <span class="n">a</span><span class="p">(</span><span class="n">x2</span><span class="o">-</span><span class="p">,</span> <span class="n">p</span><span class="p">,</span> <span class="n">s2</span><span class="p">)</span> <span class="c"># affect</span>
</span></span><span class="line"><span class="cl"> <span class="n">L</span> <span class="o">+=</span> <span class="n">L2</span><span class="p">(</span><span class="n">s2</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"> <span class="o">...</span>
</span></span><span class="line"><span class="cl"> <span class="k">return</span> <span class="n">L</span>
</span></span><span class="line"><span class="cl"><span class="k">end</span>
</span></span></code></pre></div><p>we have to store the event times $s_j$ as well as the state $x(s_j-)$ at the left limit of $s_j$.<sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup> We then solve the adjoint ODE backwards in time between the events. As soon as we reach an event time $s_j$ from the right, we update the augmented state according to</p>
<p>$$
\begin{aligned}
\lambda({s_j}-) &amp;= \lambda({s_j})^\dagger \frac{\text{d} a(\rightarrow x({s_j}-), p, {s_j}-)}{\text{d} x(s_j-)} \\
\lambda_p({s_j}-) &amp;= \lambda_p({s_j}) + \lambda({s_j})^\dagger \frac{\text{d} a(x({s_j}-), \rightarrow p, {s_j}-)}{\text{d} p}
\end{aligned}
$$</p>
<p>where $a$ is the affect function applied at the discontinuity. That is, to lift the adjoint from the right to the left limit, we compute a vjp with the adjoint $\lambda({s_j})$ from the right and the Jacobian of the affect function evaluated immediately before the event time, at $s_j-$.</p>
<p>In addition, if the state was saved during the forward pass and enters the loss function directly, we apply a loss-function callback both before and after this update.</p>
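<p>The jump condition can be sketched on a toy problem, which is an assumption and not taken from the text: $\dot x = p x$ on $[0, T]$, one explicit event at time $s$ with the affect $a(x, \gamma) = \gamma x$, and the loss $L = x(T)$, so that the parameters are $(p, \gamma)$. The forward pass stores $x(s-)$. In the reverse pass, $\lambda$ is multiplied by $\partial a/\partial x = \gamma$ at the event, and $\lambda\, \partial a/\partial \gamma = \lambda\, x(s-)$ is added to the $\gamma$-component of the parameter adjoint so that it ends up equal to the gradient. The function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia"># Toy problem (assumption): dx/dt = p*x on [0, T], explicit event at time s with
# affect a(x, γ) = γ*x, loss L = x(T). Parameters: p (growth rate) and γ (affect).

# Augmented RHS for u = (x, λ, λp, λγ); f does not depend on γ, so λγ only jumps.
aug(u, p) = [p * u[1], -u[2] * p, -u[2] * u[1], 0.0]

# Classical RK4 with n fixed steps from t0 to t1 (t1 may be smaller than t0).
function rk4(rhs, u, p, t0, t1; n=2000)
    h = (t1 - t0) / n
    for _ in 1:n
        k1 = rhs(u, p)
        k2 = rhs(u .+ (h / 2) .* k1, p)
        k3 = rhs(u .+ (h / 2) .* k2, p)
        k4 = rhs(u .+ h .* k3, p)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

function adjoint_explicit_event(x0, p, γ, s, T)
    flow(u, q) = [q * u[1]]
    x_left = rk4(flow, [x0], p, 0.0, s)[1]        # store x(s-) during the forward pass
    xT = rk4(flow, [γ * x_left], p, s, T)[1]      # apply the affect, continue to T
    u = rk4(aug, [xT, 1.0, 0.0, 0.0], p, T, s)    # λ(T) = dL/dx(T) = 1
    λ = u[2]                                      # adjoint from the right at s
    # Jump at s: λ(s-) = λ ∂a/∂x, λγ += λ ∂a/∂γ (∂a/∂p = 0); restart from stored x(s-).
    u = [x_left, λ * γ, u[3], u[4] + λ * x_left]
    u = rk4(aug, u, p, s, 0.0)
    return u[2], u[3], u[4]                       # dL/dx0, dL/dp, dL/dγ
end

x0, p, γ, s, T = 1.5, 0.3, 0.8, 0.7, 2.0
dx0, dp, dγ = adjoint_explicit_event(x0, p, γ, s, T)
</code></pre></div>
<p>The exact gradients for this toy problem are $\gamma e^{pT}$, $T x_0 \gamma e^{pT}$, and $x_0 e^{pT}$.</p>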
<h3 id="implicit-events-1">Implicit events</h3>
<p>Implicit events are handled similarly: being able to differentiate the ODE when an implicit event terminates it gives us the custom primitive differentiation rule of a <code>solve</code> with an implicit callback.</p>
<p>We have to account for one important change: besides the value $\xi = x(\tau-)$ at the time of the implicit event, the solver returns the (parameter-dependent) event time $\tau$ itself.</p>
<p>We could write for an event condition function $g$</p>
<p>$$(\tau_1, \xi_1) = \text{solve2}(t_0, x_0, g, p)$$</p>
<p>to emphasize this or, equivalently, compute $\frac{\text{d}L}{\text{d}p}$ for an unspecified function $L(t, x, p)$ as</p>
<p>$$
\begin{aligned}
\frac{\text{d}L}{\text{d}p} &amp;= \frac{\text{d}L( \text{solve2}(t_0, x_0, g, \rightarrow p),\rightarrow p)}{\text{d}p}\\
&amp;= \frac{\text{d}L(\tau_1({\color{black}\rightarrow} p), \text{solve}(t_0, x_0, \tau_1({\color{black}\rightarrow}p), \rightarrow p),\rightarrow p)}{\text{d}p},
\end{aligned}
$$</p>
<p>which indicates that changing $p$ influences $L$ both through changes in $\tau_1$ as well as changes in</p>
<p>$$\xi_1 = x(\tau_1-).$$</p>
<p>This case, in which a loss function $L = L_1$ depends on $\tau_1$, $x(\tau_1)$, and $p$,
was also considered by Ricky T. Q. Chen, Brandon Amos, and Maximilian Nickel in their ICLR 2021 paper<sup id="fnref1:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>.</p>
<p>The sensitivity of the event time with respect to the parameters, $\frac{\text{d}\tau}{\text{d}p}$, must therefore be taken into account.
For simplicity, here and in the following we consider only the $p$-dependence of $\tau_1$; a dependence on the initial state $x_0$ can be included in an analogous way<sup id="fnref2:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup>.</p>
<p>As a first step, we compute the sensitivity of $\tau_1(p)$ with respect to $p$ (or $x_0$) from the event condition $g(t, x(t)) = 0$ using the <a href="https://www.uni-siegen.de/fb6/analysis/overhagen/vorlesungsbeschreibungen/skripte/analysis3_1.pdf" target="_blank" rel="noopener">implicit function theorem</a>: $\tau_1(p)$ is implicitly defined by $F(p, \tau_1) = g( \tau_1, \text{solve}(t_0, x_0, \tau_1, p)) = 0$, which yields</p>
<p>$$
\begin{aligned}
\frac{\text{d}\tau_1(p)}{\text{d}p} &amp;= - \left(\frac{\text{d}g(\rightarrow \tau_1, \text{solve}(t_0, x_0, \rightarrow \tau_1, p))}{\text{d}\tau_1}\right)^{-1} \frac{\text{d}g(\tau_1, \text{solve}(t_0, x_0, \tau_1, \rightarrow p))}{\text{d}p} .\\
\end{aligned}
$$</p>
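<p>The implicit function theorem can be checked numerically on the bouncing ball before its first impact. This sketch assumes $z_0 = 5$, $v_0 = -0.1$ and uses the gravitational acceleration as the parameter, called <code>p</code> in the code to avoid a clash with the event condition $g(t, x) = z$. Because free flight has a closed-form solution, the derivative of $\text{solve}$ with respect to $p$ at fixed $\tau_1$ is written out by hand; all names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using LinearAlgebra   # for dot

# Assumed values (illustration): z0 = 5, v0 = -0.1, p = gravitational acceleration = 10.
z0, v0, p = 5.0, -0.1, 10.0
# Closed-form solve(0, x0, t, p) for the free flight, state x = [z, v].
solve_ball(t, p) = [z0 + v0 * t - p * t^2 / 2, v0 - p * t]
f_ball(x, p) = [x[2], -p]            # right-hand side of the ODE
event_cond(t, x) = x[1]              # event condition (the text's g): z = 0
impact_time(p) = (v0 + sqrt(v0^2 + 2p * z0)) / p

τ = impact_time(p)
ξ = solve_ball(τ, p)
dcond_dx = [1.0, 0.0]                          # ∂(event_cond)/∂x
dcond_dτ = 0.0 + dot(dcond_dx, f_ball(ξ, p))   # total derivative: ∂/∂t + (∂/∂x)ᵀ f
dsolve_dp = [-τ^2 / 2, -τ]                     # ∂solve(0, x0, τ, p)/∂p at fixed τ
dcond_dp = dot(dcond_dx, dsolve_dp)
dτdp_ift = -dcond_dp / dcond_dτ                # implicit function theorem

h = 1e-6
dτdp_fd = (impact_time(p + h) - impact_time(p - h)) / (2h)
println("IFT: ", dτdp_ift, "  finite difference: ", dτdp_fd)
</code></pre></div>
<p>Both values are about $-0.049$, i.e. $\tau_1^2 / (2\, v(\tau_1-))$.</p>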
<p>The total derivative<sup id="fnref:6"><a href="#fn:6" class="footnote-ref" role="doc-noteref">6</a></sup> inside the bracket is:
$$
\begin{aligned}
\frac{\text{d}g}{\text{d}\tau_1} \stackrel{\text{def}}{=} \frac{\text{d}g(\rightarrow \tau_1, \text{solve}(t_0, x_0, \rightarrow \tau_1, p))}{\text{d}\tau_1} &amp;= \frac{\text{d}g(\rightarrow \tau_1, \xi_1)}{\text{d}\tau_1} + \frac{\text{d}g(\tau_1, \text{solve}(t_0, x_0, \rightarrow \tau_1, p))}{\text{d}\tau_1}\\
\end{aligned}
$$</p>
<p>Since</p>
<p>$$
\frac{\text{d}(\text{solve}(t_0, x_0, \rightarrow \tau_1, p))}{\text{d}\tau_1} = f(\xi_1, p, \tau_1)
$$</p>
<p>by definition of the ODE, we can write</p>
<p>$$
\begin{aligned}
\frac{\text{d}g(\tau_1, \text{solve}(t_0, x_0, \rightarrow \tau_1, p))}{\text{d}\tau_1} = \frac{\text{d}g(\tau_1, \xi_1)}{\text{d} \xi_1}^{\dagger} f(\xi_1, p, \tau_1).
\end{aligned}
$$</p>
<p>Furthermore, we have
$$
\begin{aligned}
\frac{\text{d}g(\tau_1, \text{solve}(t_0, x_0, \tau_1, \rightarrow p))}{\text{d}p} = \frac{\text{d}g(\tau_1, \xi_1)}{\text{d} \xi_1}^{\dagger} \frac{\text{d}\text{ solve}(t_0, x_0, \tau_1,\rightarrow p)}{\text{d}p}
\end{aligned}
$$
for the second term of $\dfrac{\text{d}\tau_1(p)}{\text{d}p}$.</p>
<p>We can now write the gradient as:</p>
<p>$$
\begin{aligned}
\frac{\text{d}L(\tau_1({\color{black}\rightarrow} p), \text{solve}(t_0, x_0, \tau_1({\color{black}\rightarrow}p), \rightarrow p),\rightarrow p)}{\text{d}p} &amp;= \frac{\text{d}L(\tau_1(p), \text{solve}(t_0, x_0, \tau_1(p), p), \rightarrow p)}{\text{d}p} \\
+&amp; \frac{\text{d}L(\tau_1(p), \text{solve}(t_0, x_0, \tau_1(p), \rightarrow p), p)}{\text{d}p} \\
+&amp; \frac{\text{d}L(\rightarrow \tau_1(p), \text{solve}(t_0, x_0, \rightarrow \tau_1(p), p), p)}{\text{d}\tau_1} \frac{\text{d} \tau_1(p)}{\text{d}p},
\end{aligned}
$$</p>
<p>which, after insertion of our results above, can be cast into the form:</p>
<p>$$
\frac{\text{d}L}{\text{d}p} = v^\dagger \frac{\text{d}\text{ solve}(t_0, x_0, \tau_1(p), \rightarrow p)}{\text{d}p} + \frac{\text{d}L(\tau_1(p), \text{solve}(t_0, x_0, \tau_1(p), p), \rightarrow p)}{\text{d}p},
$$</p>
<p>with</p>
<p>$$
\begin{aligned}
v &amp;= \rho \left(-\frac{\text{d}g}{\text{d}\tau_1}\right)^{-1} \frac{\text{d}g}{\text{d}\xi_1} + \frac{\text{d}L(\tau_1, \xi_1)}{\text{d} \xi_1},
\end{aligned}
$$</p>
<p>where we introduced the scalar pre-factor</p>
<p>$$
\begin{aligned}
\rho = \left( \frac{\text{d}L(\rightarrow \tau_1, \xi_1)}{\text{d}\tau_1} + \frac{\text{d}L(\tau_1, \xi_1)}{\text{d} \xi_1}^\dagger f(\xi_1, p, \tau_1)\right).
\end{aligned}
$$</p>
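<p>As a consistency check, the sketch below evaluates $\rho$, $v$, and $\frac{\text{d}L}{\text{d}p}$ for the bouncing-ball example $L = v(\tau-)$, with the gravitational acceleration as the parameter (called <code>p</code> in the code to avoid a clash with the event condition $g(t, x) = z$). The values $z_0 = 5$, $v_0 = -0.1$, $p = 10$ are assumptions, and because free flight has a closed-form solution, $\frac{\text{d}\,\text{solve}}{\text{d}p}$ at fixed $\tau_1$ is written out by hand instead of being obtained from an adjoint solve. All names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-julia" data-lang="julia">using LinearAlgebra   # for dot

# Assumed values (illustration): z0 = 5, v0 = -0.1, p = gravitational acceleration = 10.
z0, v0, p = 5.0, -0.1, 10.0
solve_ball(t, p) = [z0 + v0 * t - p * t^2 / 2, v0 - p * t]   # closed-form solve(0, x0, t, p)
f_ball(x, p) = [x[2], -p]
impact_time(p) = (v0 + sqrt(v0^2 + 2p * z0)) / p

τ = impact_time(p)
ξ = solve_ball(τ, p)

# Loss L(τ, ξ, p) = v(τ-) = ξ[2]: partial derivatives.
dL_dτ, dL_dξ, dL_dp = 0.0, [0.0, 1.0], 0.0
# Event condition g(t, x) = z: partial w.r.t. the state and total derivative in τ.
dg_dξ = [1.0, 0.0]
dg_dτ = 0.0 + dot(dg_dξ, f_ball(ξ, p))

ρ = dL_dτ + dot(dL_dξ, f_ball(ξ, p))     # scalar prefactor
v = ρ * (-dg_dτ)^(-1) * dg_dξ + dL_dξ     # vector v
dsolve_dp = [-τ^2 / 2, -τ]                # ∂solve(0, x0, τ, p)/∂p at fixed τ
dLdp = dot(v, dsolve_dp) + dL_dp          # total derivative dL/dp

h = 1e-6
vminus(p) = solve_ball(impact_time(p), p)[2]   # v(τ-) with τ recomputed
dLdp_fd = (vminus(p + h) - vminus(p - h)) / (2h)
println("adjoint formula: ", dLdp, "  finite difference: ", dLdp_fd)
</code></pre></div>
<p>Both numbers agree and equal $-z_0/\sqrt{v_0^2 + 2 p z_0} \approx -0.49998$, the derivative of the closed form $v(\tau-) = -\sqrt{v_0^2 + 2 p z_0}$.</p>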