Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vector_jacobian_product and jacobian_vector_product functions #623

Merged
merged 5 commits into from
May 6, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented May 5, 2024

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: afba62e Previous: 6c2fbc8 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3654.25 ns 3508.4375 ns 1.04
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7153.333333333333 ns 7246.666666666667 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20919 ns 20578 ns 1.02
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9838 ns 9676 ns 1.02
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9072.8 ns 8922.8 ns 1.02
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4528.5 ns 4548.5 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1127.4935064935064 ns 1127.0986842105262 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1183.3925925925926 ns 1205.8192307692307 ns 0.98
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1804.4545454545455 ns 1808.1851851851852 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.07757404795487 ns 179.81652661064425 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17392 ns 17302 ns 1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17653 ns 17522 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37610 ns 36929 ns 1.02
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28704 ns 28272 ns 1.02
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19917 ns 19537 ns 1.02
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17212.5 ns 17102 ns 1.01
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3931.125 ns 3886 ns 1.01
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3981.125 ns 3974.875 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 5065.142857142857 ns 4914.857142857143 ns 1.03
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1653.1 ns 1660 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 39651659.5 ns 39195567 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 83019705 ns 83173909.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 69539922 ns 76132879 ns 0.91
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 80902614 ns 85391910.5 ns 0.95
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 65418328 ns 65468388 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11702797 ns 11744400 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 14321910 ns 14297090 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 14260866 ns 14209798 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 10058783 ns 10009941 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6419130 ns 6419783 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) 124748009 ns 115357617 ns 1.08
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 849955072 ns 894708632 ns 0.95
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 3231168666 ns 3130396854 ns 1.03
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) 159942733.5 ns 148583143 ns 1.08
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 784098352 ns 703083578 ns 1.12
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2599092176 ns 2645612324 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) 81391066 ns 79580760.5 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 648873929 ns 652507658.5 ns 0.99
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 3000983905 ns 2782809284 ns 1.08
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) 32692800 ns 29165219 ns 1.12
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 308210080 ns 307555362 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 949208188.5 ns 952567181 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) 29015669 ns 28824975 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 305664709 ns 303950539 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 938295779.5 ns 940356219 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 1) 24304756 ns 24207518 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 186418519 ns 208674267.5 ns 0.89
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 779190653.5 ns 711398513 ns 1.10
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1091691990.5 ns 1040023601 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1890239060 ns 1881706524 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2168593335 ns 2139508656.5 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2561023822 ns 2366498132.5 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1888988086 ns 1868515057 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 346510951 ns 339882096 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 344023803 ns 338963529 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 367503103.5 ns 348597582 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11977359 ns 11947596.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 29536902.5 ns 29496655 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19394787.5 ns 19276906 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23961700 ns 23972028 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18008562.5 ns 18098635 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1176351 ns 1142592 ns 1.03
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 4421352 ns 4412312 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 4254689 ns 4235285 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2105148 ns 2081648 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 199622 ns 206394 ns 0.97
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 293448 ns 292354 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 268391.5 ns 265855 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 369224 ns 364489 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 411378 ns 408631 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 277257.5 ns 274251 ns 1.01
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 408082 ns 412438 ns 0.99
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81923 ns 81182 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82254 ns 81802 ns 1.01
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87463 ns 86932 ns 1.01
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104665 ns 104665 ns 1
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 198222135 ns 204164825 ns 0.97
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 346243659 ns 346236548 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 398959136 ns 396930806.5 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 458440994.5 ns 457911693 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 365204968 ns 372160775 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 341163429 ns 344593576 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 54069724.5 ns 55820278 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 53953057 ns 55703332.5 ns 0.97
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 52082055.5 ns 49665442 ns 1.05
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 29415775 ns 28300682 ns 1.04
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 20059082 ns 19100132.5 ns 1.05
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19565651 ns 19547856 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23684156 ns 23447627 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24213424 ns 24070344.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19864849 ns 19636438 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6520588 ns 6488947 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6538420 ns 6488751.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6515477.5 ns 6500113 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal force-pushed the ap/pullback branch 4 times, most recently from e48e58f to ff8d375 Compare May 5, 2024 23:20
@avik-pal avik-pal merged commit 5515383 into main May 6, 2024
25 of 39 checks passed
@avik-pal avik-pal deleted the ap/pullback branch May 6, 2024 00:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Differentiating Zygote.pullback
1 participant