From e22975b10fcdd7a6f718a1f090e27ad64dfa221f Mon Sep 17 00:00:00 2001
From: NekoDaemon
Date: Tue, 23 Aug 2022 14:44:26 +0000
Subject: [PATCH 1/6] Squashed commit of the following:
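
This patch adds the sharding annotation, inference, and expansion passes. The snippet below is a minimal usage sketch distilled from the test file touched later in this series (tests/python/pass/test_pass_sharding.py), not part of the diff itself; the toy model, the 2x2 spec, and the choice of which calls to annotate are illustrative only.

    import numpy as np

    import raf
    from raf._ffi.pass_ import AnnotateShardOpCall, ExpandShardOpCall, InferShardSpec, InferType, ToGraphNormalForm
    from raf._lib import relay
    from raf.distributed.sharding import ShardOpCallAttrs, make_shard_spec, make_unset_spec
    from tvm.relay.analysis.analysis import post_order_visit

    class Model(raf.Model):
        def build(self):
            pass

        @raf.model.trace
        def forward(self, x, y):
            z = raf.add(x, y)
            return raf.relu(z)

    model = Model()
    m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
    m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
    mod = InferType()(model._internal(m_x, m_y).mod)

    # Collect the op calls so the annotation map can refer to them by position.
    call_list = []
    post_order_visit(
        mod["main"].body,
        lambda op: call_list.append(op) if isinstance(op, relay.Call) else None,
    )

    # Leave the add unconstrained; ask for the relu output to be sharded 2x2
    # over 4 ranks. InferShardSpec fills in the remaining placeholder specs.
    attrs_map = {
        call_list[0]: ShardOpCallAttrs(
            [make_unset_spec(), make_unset_spec()], [make_unset_spec()]
        ),
        call_list[1]: ShardOpCallAttrs(
            [make_unset_spec()], [make_shard_spec([2, 2], ranks=4, mutable=False)]
        ),
    }

    mod = AnnotateShardOpCall(attrs_map)(mod)
    mod = ToGraphNormalForm()(mod)
    mod = InferType()(mod)
    mod = InferShardSpec()(mod)
    mod = InferType()(mod)
    mod = ExpandShardOpCall()(mod)
    print(raf._ffi.ir.AsText(mod))

The pass ordering mirrors the test: InferType must run before InferShardSpec and again before ExpandShardOpCall, because both the infer hints and the expansion rules read checked_type to obtain concrete shapes before rewriting annotated calls (e.g. _reshard into strided_slice or _allgather).
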
---
 include/raf/pass.h                            |  23 ++
 python/raf/distributed/sharding/expandrule.py | 229 ++++++++++++++++++
 python/raf/distributed/sharding/inferhint.py  | 141 +++++++++++
 src/pass/sharding.cc                          | 163 +++++++++++++
 4 files changed, 556 insertions(+)
 create mode 100644 python/raf/distributed/sharding/expandrule.py
 create mode 100644 python/raf/distributed/sharding/inferhint.py
 create mode 100644 src/pass/sharding.cc

diff --git a/include/raf/pass.h b/include/raf/pass.h
index 605c6d5f..e08a40fe 100644
--- a/include/raf/pass.h
+++ b/include/raf/pass.h
@@ -17,6 +17,7 @@
 #include "raf/value.h"
 #include "raf/ir_ext.h"
 #include "raf/pass_manager.h"
+#include "raf/sharding.h"
 
 namespace raf {
 namespace pass {
@@ -342,6 +343,28 @@ Pass WavefrontStreamSchedule();
  */
 Pass ASAPStreamSchedule();
 
+/*!
+ * \brief Set ShardOpCallAttrs for annotated Relay Op Call
+ *
+ * \return The created pass.
+ */
+Pass AnnotateShardOpCall(const ir::Map& attrs_map);
+
+/*!
+ * \brief Expand Op Call with ShardOpCallAttrs to a series of expressions + * according to the corresponding expansion pattern + * + * \return The created pass. + */ +Pass ExpandShardOpCall(); + +/*! + * \brief . + * + * \return . + */ +Pass InferShardSpec(); + /*! * \brief This pass transforms BBNF into ANF and schedules operators to improve overlapping * between computation and communication. diff --git a/python/raf/distributed/sharding/expandrule.py b/python/raf/distributed/sharding/expandrule.py new file mode 100644 index 00000000..7ab927ac --- /dev/null +++ b/python/raf/distributed/sharding/expandrule.py @@ -0,0 +1,229 @@ +# pylint: disable=invalid-name, unused-argument +"""Implementation of Expansion Rules""" +from ctypes import Union +import functools +import numpy as np +import raf +import tvm +from queue import PriorityQueue +from typing import Callable, List, Tuple + +from raf._ffi.sharding._make import ShardOpCallAttrs +from raf._ffi.op import GetOp +from raf._lib import _register_func, relay +from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec +from raf._core.value import Value +from raf import distributed as dist +from raf.ir.anf_builder import ANFBuilder +from tvm.relay import Call, Expr +from tvm.ir import Op +from tvm.relay.op.transform import full +from tvm.runtime.object import Object + +pattern_map = { + 0: "kElemWise", + 1: "kBroadcast", + 2: "kInjective", + 3: "kCommReduce", + 4: "kOutEWiseFusable", + 7: "kTuple", + 8: "kOpaque", +} +# TODO: this pattern map is replicated multiple times in source code + +class ShardInfo: + call: relay.Call + op: Op + args: List[Expr] + attrs: Object + sin: List[BaseShardSpec] + sout: List[BaseShardSpec] + + def __init__(self, call: relay.Call): + assert isinstance(call, relay.Call) + self.call = call + self.op = call.op + self.args = call.args + self.attrs = call.attrs + self.sin = call.attrs.sin + self.sout = call.attrs.sout + +def all_satisfied(conds: List[Callable[[ShardInfo], bool]]): + def func(s: ShardInfo): + for c in conds: + if not c(s): + return False + return True + return func + +def is_same_spec(*args): + for e in args[1:]: + if not tvm.ir.structural_equal(args[0], e): + return False + return True + +def is_sharded(s: BaseShardSpec): + return isinstance(s, ShardSpec) + +def is_replicated(s: BaseShardSpec): + if not isinstance(s, ShardSpec): + return False + return s.nshard == 1 + +def no_subgroup(s: BaseShardSpec): + if not isinstance(s, ShardSpec): + return False + return s.ngroup == 1 + +def always_apply(s: ShardInfo): + """Always apply this rule to expand op call""" + return True + +def expand_when(cond: Callable[[ShardInfo], bool], priority=1): + """Specify the priority and the condition when this expansion rule should be used. + + Parameters + ---------- + cond : function(call) -> bool + A function answering this expansion rule is eligible under particular conditions + (e.g. 
with particular sharding specifications) + """ + if not hasattr(expand_when, "counter"): + expand_when.counter = 0 + if not hasattr(expand_when, "rules"): + expand_when.rules = {} + + def decorator(pyfunc): + if not hasattr(pyfunc, "op_names"): + raise ValueError("Must register expansion rule first") + for op_name in pyfunc.op_names: + op = GetOp(op_name) + if op not in expand_when.rules: + expand_when.rules[op] = PriorityQueue() + expand_when.rules[op].put((-priority, expand_when.counter, cond, pyfunc)) + expand_when.counter += 1 + return pyfunc + + return decorator + + +def register_expansion_rule(op_name): + """Register an expansion rule that converts a full-sized op into a partitioned-size op + + Parameters + ---------- + op_name: str or List[str] + Name of op to register + """ + op_names = [op_name] if isinstance(op_name, str) else op_name + assert isinstance(op_names, list) + + def decorator(pyfunc): + @functools.wraps(pyfunc) + def new_pyfunc(call: relay.Call): + return pyfunc(call) + + setattr(new_pyfunc, "op_names", op_names) + return new_pyfunc + + return decorator + +@_register_func("raf.sharding._match_expansion_rule") +def expand_opcall(call: relay.Call): + """Match an eligible expansion rule and return expanded IR expr""" + rules = expand_when.rules[call.op] + s = ShardInfo(call) + for rule in rules.queue: + _, _, cond, irgen = rule + if cond(s): + return irgen(s) + return None + +@expand_when( + all_satisfied([ + lambda s: is_replicated(s.sin[0]), + lambda s: is_sharded(s.sout[0]) + ]), + priority=1, +) +@register_expansion_rule("raf.op._reshard") +def reshard_replicated_to_sharded(s: ShardInfo): + """_reshard -> _reshard_r2s (strided_slice)""" + begin, end = [], [] + shape = s.args[0].checked_type.concrete_shape + spec = s.sout[0] + # spec = ShardSpec() + for idx, dim_nshard, dim_size in zip(spec.logic_index, spec.logic_shape, shape): + assert dim_size % dim_nshard == 0 + begin.append(int((dim_size // dim_nshard) * idx)) + end.append(int((dim_size // dim_nshard) * (idx + 1))) + return relay.Call(GetOp("raf.op.strided_slice"), [s.args[0], raf.ir.const(begin), raf.ir.const(end), raf.ir.const([1] * spec.ndim), raf.ir.const("end")]) + +@expand_when( + all_satisfied([ + lambda s: print(s.sin[0], s.sout[0]) or True, + lambda s: is_sharded(s.sin[0]), + lambda s: is_replicated(s.sout[0]), + ]), + priority=1, +) +@register_expansion_rule("raf.op._reshard") +def reshard_sharded_to_replicated(s: ShardInfo): + """_reshard -> _reshard_s2r (allgather)""" + spec = s.sin[0] + axis = [] + full_shape = [] + for i in range(spec.ndim): + if spec.logic_shape[i] > 0: + axis.append(i) + full_shape.append(int(spec.logic_shape[i])) + full_shape.append(int(spec.subgroup_shape[i])) + ranks = np.array([int(e) for e in spec.ranks]).reshape(full_shape) + nshard_on_dim = int(spec.logic_shape[axis[0]]) + rank_list = np.moveaxis(ranks, axis[0], -1).reshape((ranks.size // nshard_on_dim, nshard_on_dim)) + return relay.Call(GetOp("raf.op._allgather"), [s.args[0], raf.ir.const(axis[0]), raf.ir.const(rank_list.tolist())]) + +# @expand_when(always_apply, priority=0) +# @register_expansion_rule("raf.op._reshard") +# def reshard_mismatch(s: ShardInfo): +# """_reshard -> """ +# raise NotImplementedError("Unable to process the given sharding specifications") + + +@expand_when(lambda s: is_same_spec(s.sin[0], s.sin[1], s.sout[0])) +@register_expansion_rule(["raf.op.add", "raf.op.subtract"]) +def add_or_sub(s: ShardInfo): + """add/sub -> add/sub""" + return relay.Call(s.op, s.args) + +@expand_when(lambda s: 
is_same_spec(s.sin[0], s.sout[0])) +@register_expansion_rule(["raf.op.relu"]) +def element_wise(s: ShardInfo): + return relay.Call(s.op, s.args) + +@expand_when(all_satisfied([ + lambda s: is_sharded(s.sin[0]) and is_sharded(s.sin[1]), + lambda s: no_subgroup(s.sin[0]) and no_subgroup(s.sin[1]), + lambda s: is_replicated(s.sout[0]), + lambda s: s.sin[0].logic_shape[1] == s.sin[1].logic_shape[0] +])) +@register_expansion_rule(["raf.op.matmul"]) +def matmul_algor1(s: ShardInfo): + y_1 = relay.Call(s.op, s.args) + y_2 = tvm.relay.Tuple([y_1]) + return relay.Call(GetOp("raf.op._allreduce"), [y_2, raf.ir.const("sum"), raf.ir.const(None)]) + +# @expand_when(always_apply) +# @register_expansion_rule("_fallback") +# def fallback_reshard_to_replicated(s: ShardInfo): +# """Gather partitioned tensors for op without matched rules""" +# op, args, attrs = call.op, call.args, call.attrs +# if ( +# len(args) != 1 +# or isinstance(attrs.shard_in, TupleSpec) +# or isinstance(attrs.shard_out, TupleSpec) +# ): +# raise NotImplementedError("Currently coverting multiple args is not supported") +# new_attrs = ShardOpCallAttrs(attrs.shard_in, MirroredSpec()) +# new_args = [relay.Call(GetOp("raf.op._reshard"), args, new_attrs)] +# return relay.Call(op, new_args) diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py new file mode 100644 index 00000000..b973072a --- /dev/null +++ b/python/raf/distributed/sharding/inferhint.py @@ -0,0 +1,141 @@ +# pylint: disable=invalid-name, unused-argument +"""Implementaion of Infer Hints""" +from ctypes import Union +import functools +import numpy as np +import raf +import tvm +from queue import PriorityQueue +from typing import Callable, List, Tuple + +from raf._ffi.sharding._make import ShardOpCallAttrs +from raf._ffi.op import GetOp +from raf._lib import _register_func, relay +from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec +from raf.distributed.sharding.utils import make_replicated_spec +from raf._core.value import Value +from raf import distributed as dist +from raf.ir.anf_builder import ANFBuilder +from tvm.relay import Call, Expr +from tvm.ir import Op + +from .expandrule import ShardInfo, all_satisfied, always_apply, expand_opcall, is_same_spec, is_sharded +from .expandrule import register_expansion_rule as register_infer_hint + +def try_when(cond: Callable[[ShardInfo], bool], priority=1): + if not hasattr(try_when, "counter"): + try_when.counter = 0 + if not hasattr(try_when, "rules"): + try_when.rules = {} + + def decorator(pyfunc): + if not hasattr(pyfunc, "op_names"): + raise ValueError("Must register infer hint first") + for op_name in pyfunc.op_names: + op = GetOp(op_name) + if op not in try_when.rules: + try_when.rules[op] = PriorityQueue() + try_when.rules[op].put((-priority, try_when.counter, cond, pyfunc)) + try_when.counter += 1 + return pyfunc + + return decorator + +@_register_func("raf.sharding._infer_shardspec") +def infer_shardspec(call: relay.Call): + rules = try_when.rules[call.op] + s = ShardInfo(call) + + # Step 1: Inherit ShardSpec from previous output + filled_sin = [] + for i in range(len(s.sin)): + if isinstance(s.sin[i], UnsetShardSpec): + if isinstance(s.args[i], relay.Call) and hasattr(s.args[i].attrs, "sin"): + # cannot use isinstance to check the type of OpCall Attrs + # direct inherit ShardSpec + prev_sinfo = ShardInfo(s.args[i]) + filled_sin.append(prev_sinfo.sout[0]) + else: + # the previous output doesn't have ShardSpec + ndim = 
len(s.args[0].checked_type.concrete_shape) + filled_sin.append(make_replicated_spec(ndim)) + + else: + # already exist a specified ShardSpec + filled_sin.append(s.sin[i]) + + filled_attrs = ShardOpCallAttrs(filled_sin, s.sout) + filled_call = relay.Call(s.op, s.args, filled_attrs) + filled_s = ShardInfo(filled_call) + + # Step 2: Match an InferHint + guessed_calls = [] + for rule in rules.queue: + _, _, cond, irgen = rule + if cond(filled_s): + guessed_calls.extend([relay.Call(s.op, s.args, a) for a in irgen(filled_s)]) + if not guessed_calls: + raise ValueError("Failed to match an InferHint") + + # Step 3: Check the solution is practicable + ninputs = len(filled_s.sin) + noutputs = len(filled_s.sout) + immut_in_idx = [i for i in range(ninputs) if is_sharded(filled_s.sin[i]) and filled_s.sin[i].mutable == False] + immut_out_idx = [i for i in range(noutputs) if is_sharded(filled_s.sout[i]) and filled_s.sout[i].mutable == False] + + possible_calls = [] + for guessed_call in guessed_calls: + if not expand_opcall(guessed_call): + continue + guessed_s = ShardInfo(guessed_call) + immut_args = [(filled_s.sin[i], guessed_s.sin[i]) for i in immut_in_idx] + \ + [(filled_s.sout[i], guessed_s.sout[i]) for i in immut_out_idx] + for pair in immut_args: + if not is_same_spec(pair[0], pair[1]): + break + else: + possible_calls.append(guessed_call) + + # Step 4: Pick an OpCall with full ShardSpec + # TODO: should use graph searching algorithm with cost map here. For now, always select the first solution. + inferred_call = possible_calls[0] + inferred_s = ShardInfo(inferred_call) + + # Step 5: Insert Reshard OpCall + resharded_args = [] + for i in range(ninputs): + if is_same_spec(filled_s.sin[i], inferred_s.sin[i]): + resharded_args.append(inferred_s.args[i]) + else: + resharded_args.append(relay.Call( + GetOp("raf.op._reshard"), + [inferred_s.args[i]], + ShardOpCallAttrs([filled_s.sin[i]], [inferred_s.sin[i]]))) + + print("[Sharding Infer] %s %s ### %s" % (filled_s.op, inferred_s.attrs, filled_s.attrs)) + return relay.Call(inferred_s.op, resharded_args, inferred_s.attrs) + +def is_unset(s: BaseShardSpec): + return isinstance(s, UnsetShardSpec) + +@try_when(always_apply) +@register_infer_hint(["raf.op.add", "raf.op.subtract"]) +def element_wise_op_with_2in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]: + specs = [] + for e in (s.sin[0], s.sin[1], s.sout[0]): + if not is_unset(e): + specs.append(e) + return [ + ShardOpCallAttrs([e, e], [e]) for e in specs + ] + +@try_when(always_apply) +@register_infer_hint(["raf.op.relu"]) +def element_wise_op_with_1in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]: + specs = [] + for e in (s.sin[0], s.sout[0]): + if not is_unset(e): + specs.append(e) + return [ + ShardOpCallAttrs([e], [e]) for e in specs + ] diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc new file mode 100644 index 00000000..0491702b --- /dev/null +++ b/src/pass/sharding.cc @@ -0,0 +1,163 @@ +/*! 
+ * Copyright (c) 2021 by Contributors + * \file init_shardspec.cc + * \brief Gradient operator input selection pass + */ +#include +#include "raf/op.h" +#include "raf/ir.h" +#include "raf/pass.h" +#include "raf/sharding.h" +#include +#include + +namespace raf { +namespace pass { + +using namespace raf::ir; +using namespace raf::op; +using namespace raf::value; +using namespace raf::sharding; + +namespace shard_pass { + +class ShardOpCallAttrsSetter : public ExprMutator { + public: + explicit ShardOpCallAttrsSetter(const Map& attrs_map) : _attrs_map(attrs_map) { + } + + Expr VisitExpr_(const CallNode* node) override { + const Expr& callee = node->op; + if (callee->IsInstance()) { + auto ref = GetRef(node); + if (_attrs_map.count(ref)) { + auto new_expr = Call(node->op, node->args, Attrs(_attrs_map[ref])); + return ExprMutator::VisitExpr_(new_expr.as()); + } + } + return ExprMutator::VisitExpr_(node); + } + + private: + const Map& _attrs_map; +}; + +class ShardOpCallExpander : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* node) override { + const Expr& op = node->op; + const Attrs& attrs = node->attrs; + const auto* f = tvm::runtime::Registry::Get("raf.sharding._match_expansion_rule"); + if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { + auto call = GetRef(node); + Expr new_expr = (*f)(call); + // return call.same_as(new_expr) ? new_expr : ExprMutator::VisitExpr(new_expr); + return new_expr; + } + return ExprMutator::VisitExpr_(node); + } +}; + + // // Step 1: Propagate ShardSpec + // Array sin; + // for (int64_t i = 0; i < sattr->sin.size(); ++i) { + // if (sattr->sin[i]->IsInstance()) { + // LOG(INFO) << i << " is unset shardspec"; + // bool flag_unchanged = true; + // if (args[i]->IsInstance()) { + // // Copy ShardSpec from previous output + // LOG(INFO) << i << " is call"; + // const auto pcall = Downcast(args[i]); + // if (pcall->attrs->IsInstance()) { + // const auto pattr = pcall->attrs.as(); + // sin.push_back(pattr->sout[0]); + // flag_unchanged = false; + // } + // } + // if (flag_unchanged) { + // // sin[i] = ShardSpec::make() + // } + // } else { + // sin.push_back(sattr->sin[i]); + // } + // } + +class ShardSpecPropagator : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* node) override { + Call call = Downcast(ExprMutator::VisitExpr_(node)); + const Expr& op = call->op; + const Attrs& attrs = call->attrs; + const Array& args = call->args; + const auto* f = tvm::runtime::Registry::Get("raf.sharding._infer_shardspec"); + if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { + LOG(INFO) << op << " " << call->op; + + Expr new_expr = (*f)(call); + return new_expr; + } + return call; + } +}; + +} // namespace shard_pass + +Pass AnnotateShardOpCall(const Map& attrs_map) { + return CreateModulePass( + [=](IRModule mod, const PassContext& pass_ctx) { + DLOG(INFO) << "pass::AnnotateShardOpCall"; + IRModule updated_mod = IRModule(mod->functions); + for (auto kv : updated_mod->functions) { + if (kv.second.as()) { + auto setter = shard_pass::ShardOpCallAttrsSetter(attrs_map); + auto func = tvm::runtime::Downcast(setter.VisitExpr(kv.second)); + updated_mod->Add(kv.first, func, true); + } + } + return updated_mod; + }, + 0, "AnnotateShardOpCall", {}); +} + +RAF_REGISTER_GLOBAL("raf.pass_.AnnotateShardOpCall").set_body_typed(AnnotateShardOpCall); + +Pass ExpandShardOpCall() { + return CreateModulePass( + [=](IRModule mod, const PassContext& pass_ctx) { + DLOG(INFO) << "pass::ExpandShardOpCall"; + IRModule updated_mod = 
IRModule(mod->functions); + for (auto kv : updated_mod->functions) { + if (kv.second.as()) { + auto setter = shard_pass::ShardOpCallExpander(); + auto func = tvm::runtime::Downcast(setter.VisitExpr(kv.second)); + updated_mod->Add(kv.first, func, true); + } + } + return updated_mod; + }, + 0, "ExpandShardOpCall", {}); +} + +RAF_REGISTER_GLOBAL("raf.pass_.ExpandShardOpCall").set_body_typed(ExpandShardOpCall); + +Pass InferShardSpec() { + return CreateModulePass( + [=](IRModule mod, const PassContext& pass_ctx) { + DLOG(INFO) << "pass::InferShardSpec"; + IRModule updated_mod = IRModule(mod->functions); + for (auto kv : updated_mod->functions) { + if (kv.second.as()) { + auto setter = shard_pass::ShardSpecPropagator(); + auto func = tvm::runtime::Downcast(setter.VisitExpr(kv.second)); + updated_mod->Add(kv.first, func, true); + } + } + return updated_mod; + }, + 0, "InferShardSpec", {}); +} + +RAF_REGISTER_GLOBAL("raf.pass_.InferShardSpec").set_body_typed(InferShardSpec); + +} // namespace pass +} // namespace raf From 2c556e7dadfca9caf403399fde677431770db55f Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Tue, 23 Aug 2022 21:02:27 +0000 Subject: [PATCH 2/6] lint --- python/raf/distributed/sharding/expandrule.py | 141 ++++++++++-------- python/raf/distributed/sharding/inferhint.py | 82 ++++++---- src/pass/sharding.cc | 34 +---- 3 files changed, 138 insertions(+), 119 deletions(-) diff --git a/python/raf/distributed/sharding/expandrule.py b/python/raf/distributed/sharding/expandrule.py index 7ab927ac..7dbdd04c 100644 --- a/python/raf/distributed/sharding/expandrule.py +++ b/python/raf/distributed/sharding/expandrule.py @@ -1,23 +1,21 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + # pylint: disable=invalid-name, unused-argument """Implementation of Expansion Rules""" -from ctypes import Union import functools +from queue import PriorityQueue +from typing import Callable, List import numpy as np import raf import tvm -from queue import PriorityQueue -from typing import Callable, List, Tuple -from raf._ffi.sharding._make import ShardOpCallAttrs + from raf._ffi.op import GetOp from raf._lib import _register_func, relay -from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec -from raf._core.value import Value -from raf import distributed as dist -from raf.ir.anf_builder import ANFBuilder -from tvm.relay import Call, Expr +from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec +from tvm.relay import Expr from tvm.ir import Op -from tvm.relay.op.transform import full from tvm.runtime.object import Object pattern_map = { @@ -31,14 +29,18 @@ } # TODO: this pattern map is replicated multiple times in source code + class ShardInfo: + """Helper for parsing ShardSpec.""" + + # pylint: disable=too-few-public-methods call: relay.Call op: Op args: List[Expr] attrs: Object sin: List[BaseShardSpec] sout: List[BaseShardSpec] - + def __init__(self, call: relay.Call): assert isinstance(call, relay.Call) self.call = call @@ -48,46 +50,60 @@ def __init__(self, call: relay.Call): self.sin = call.attrs.sin self.sout = call.attrs.sout + def all_satisfied(conds: List[Callable[[ShardInfo], bool]]): + """Return true when all conditions are satisfied.""" + def func(s: ShardInfo): for c in conds: if not c(s): return False return True + return func + def is_same_spec(*args): + """Check whether two ShardSpecs are same.""" for e in args[1:]: if not tvm.ir.structural_equal(args[0], e): return False return True + 
def is_sharded(s: BaseShardSpec): + """Check whether it is a ShardSpec.""" return isinstance(s, ShardSpec) + def is_replicated(s: BaseShardSpec): + """Check whether it is a replicated ShardSpec.""" if not isinstance(s, ShardSpec): return False return s.nshard == 1 + def no_subgroup(s: BaseShardSpec): + """Check whether subgrouping feature is disabled.""" if not isinstance(s, ShardSpec): return False return s.ngroup == 1 + def always_apply(s: ShardInfo): - """Always apply this rule to expand op call""" + """Always return True.""" return True + def expand_when(cond: Callable[[ShardInfo], bool], priority=1): """Specify the priority and the condition when this expansion rule should be used. Parameters ---------- - cond : function(call) -> bool - A function answering this expansion rule is eligible under particular conditions - (e.g. with particular sharding specifications) + cond : function(ShardInfo) -> bool + A function validating this expansion rule is eligible to apply. """ + if not hasattr(expand_when, "counter"): expand_when.counter = 0 if not hasattr(expand_when, "rules"): @@ -113,7 +129,7 @@ def register_expansion_rule(op_name): Parameters ---------- op_name: str or List[str] - Name of op to register + Names of op to register """ op_names = [op_name] if isinstance(op_name, str) else op_name assert isinstance(op_names, list) @@ -128,48 +144,58 @@ def new_pyfunc(call: relay.Call): return decorator + @_register_func("raf.sharding._match_expansion_rule") def expand_opcall(call: relay.Call): - """Match an eligible expansion rule and return expanded IR expr""" + """Match an eligible expansion rule and return expanded IR expr.""" rules = expand_when.rules[call.op] s = ShardInfo(call) for rule in rules.queue: _, _, cond, irgen = rule if cond(s): return irgen(s) - return None + return None + @expand_when( - all_satisfied([ - lambda s: is_replicated(s.sin[0]), - lambda s: is_sharded(s.sout[0]) - ]), + all_satisfied([lambda s: is_replicated(s.sin[0]), lambda s: is_sharded(s.sout[0])]), priority=1, ) @register_expansion_rule("raf.op._reshard") def reshard_replicated_to_sharded(s: ShardInfo): - """_reshard -> _reshard_r2s (strided_slice)""" + """_reshard (R to S) -> strided_slice""" begin, end = [], [] shape = s.args[0].checked_type.concrete_shape spec = s.sout[0] - # spec = ShardSpec() for idx, dim_nshard, dim_size in zip(spec.logic_index, spec.logic_shape, shape): assert dim_size % dim_nshard == 0 begin.append(int((dim_size // dim_nshard) * idx)) end.append(int((dim_size // dim_nshard) * (idx + 1))) - return relay.Call(GetOp("raf.op.strided_slice"), [s.args[0], raf.ir.const(begin), raf.ir.const(end), raf.ir.const([1] * spec.ndim), raf.ir.const("end")]) + return relay.Call( + GetOp("raf.op.strided_slice"), + [ + s.args[0], + raf.ir.const(begin), + raf.ir.const(end), + raf.ir.const([1] * spec.ndim), + raf.ir.const("end"), + ], + ) + @expand_when( - all_satisfied([ - lambda s: print(s.sin[0], s.sout[0]) or True, - lambda s: is_sharded(s.sin[0]), - lambda s: is_replicated(s.sout[0]), - ]), + all_satisfied( + [ + lambda s: print(s.sin[0], s.sout[0]) or True, + lambda s: is_sharded(s.sin[0]), + lambda s: is_replicated(s.sout[0]), + ] + ), priority=1, ) @register_expansion_rule("raf.op._reshard") def reshard_sharded_to_replicated(s: ShardInfo): - """_reshard -> _reshard_s2r (allgather)""" + """_reshard (S to R) -> allgather""" spec = s.sin[0] axis = [] full_shape = [] @@ -178,16 +204,16 @@ def reshard_sharded_to_replicated(s: ShardInfo): axis.append(i) full_shape.append(int(spec.logic_shape[i])) 
full_shape.append(int(spec.subgroup_shape[i])) + assert len(axis) == 1 # TODO: remove this constrain ranks = np.array([int(e) for e in spec.ranks]).reshape(full_shape) nshard_on_dim = int(spec.logic_shape[axis[0]]) - rank_list = np.moveaxis(ranks, axis[0], -1).reshape((ranks.size // nshard_on_dim, nshard_on_dim)) - return relay.Call(GetOp("raf.op._allgather"), [s.args[0], raf.ir.const(axis[0]), raf.ir.const(rank_list.tolist())]) - -# @expand_when(always_apply, priority=0) -# @register_expansion_rule("raf.op._reshard") -# def reshard_mismatch(s: ShardInfo): -# """_reshard -> """ -# raise NotImplementedError("Unable to process the given sharding specifications") + rank_list = np.moveaxis(ranks, axis[0], -1).reshape( + (ranks.size // nshard_on_dim, nshard_on_dim) + ) + return relay.Call( + GetOp("raf.op._allgather"), + [s.args[0], raf.ir.const(axis[0]), raf.ir.const(rank_list.tolist())], + ) @expand_when(lambda s: is_same_spec(s.sin[0], s.sin[1], s.sout[0])) @@ -196,34 +222,27 @@ def add_or_sub(s: ShardInfo): """add/sub -> add/sub""" return relay.Call(s.op, s.args) + @expand_when(lambda s: is_same_spec(s.sin[0], s.sout[0])) -@register_expansion_rule(["raf.op.relu"]) +@register_expansion_rule(["raf.op.relu"]) # TODO: should use a generated list instead def element_wise(s: ShardInfo): + """element wise -> element wise""" return relay.Call(s.op, s.args) -@expand_when(all_satisfied([ - lambda s: is_sharded(s.sin[0]) and is_sharded(s.sin[1]), - lambda s: no_subgroup(s.sin[0]) and no_subgroup(s.sin[1]), - lambda s: is_replicated(s.sout[0]), - lambda s: s.sin[0].logic_shape[1] == s.sin[1].logic_shape[0] -])) + +@expand_when( + all_satisfied( + [ + lambda s: is_sharded(s.sin[0]) and is_sharded(s.sin[1]), + lambda s: no_subgroup(s.sin[0]) and no_subgroup(s.sin[1]), + lambda s: is_replicated(s.sout[0]), + lambda s: s.sin[0].logic_shape[1] == s.sin[1].logic_shape[0], + ] + ) +) @register_expansion_rule(["raf.op.matmul"]) def matmul_algor1(s: ShardInfo): + """matmul -> matmul + allreduce""" y_1 = relay.Call(s.op, s.args) y_2 = tvm.relay.Tuple([y_1]) return relay.Call(GetOp("raf.op._allreduce"), [y_2, raf.ir.const("sum"), raf.ir.const(None)]) - -# @expand_when(always_apply) -# @register_expansion_rule("_fallback") -# def fallback_reshard_to_replicated(s: ShardInfo): -# """Gather partitioned tensors for op without matched rules""" -# op, args, attrs = call.op, call.args, call.attrs -# if ( -# len(args) != 1 -# or isinstance(attrs.shard_in, TupleSpec) -# or isinstance(attrs.shard_out, TupleSpec) -# ): -# raise NotImplementedError("Currently coverting multiple args is not supported") -# new_attrs = ShardOpCallAttrs(attrs.shard_in, MirroredSpec()) -# new_args = [relay.Call(GetOp("raf.op._reshard"), args, new_attrs)] -# return relay.Call(op, new_args) diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py index b973072a..35db2ffa 100644 --- a/python/raf/distributed/sharding/inferhint.py +++ b/python/raf/distributed/sharding/inferhint.py @@ -1,28 +1,36 @@ -# pylint: disable=invalid-name, unused-argument +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=invalid-name, unused-argument, missing-function-docstring """Implementaion of Infer Hints""" -from ctypes import Union -import functools -import numpy as np -import raf -import tvm from queue import PriorityQueue -from typing import Callable, List, Tuple +from typing import Callable, List from raf._ffi.sharding._make import ShardOpCallAttrs from raf._ffi.op import GetOp from raf._lib import _register_func, relay -from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec +from raf.distributed.sharding.shardspec import BaseShardSpec, UnsetShardSpec from raf.distributed.sharding.utils import make_replicated_spec -from raf._core.value import Value -from raf import distributed as dist -from raf.ir.anf_builder import ANFBuilder -from tvm.relay import Call, Expr -from tvm.ir import Op -from .expandrule import ShardInfo, all_satisfied, always_apply, expand_opcall, is_same_spec, is_sharded +from .expandrule import ( + ShardInfo, + always_apply, + expand_opcall, + is_same_spec, + is_sharded, +) from .expandrule import register_expansion_rule as register_infer_hint + def try_when(cond: Callable[[ShardInfo], bool], priority=1): + """Specify the priority and the condition when this infer hint should be used. + + Parameters + ---------- + cond : function(ShardInfo) -> bool + A function validating this infer hint is eligible to apply. + """ + if not hasattr(try_when, "counter"): try_when.counter = 0 if not hasattr(try_when, "rules"): @@ -41,8 +49,11 @@ def decorator(pyfunc): return decorator + @_register_func("raf.sharding._infer_shardspec") def infer_shardspec(call: relay.Call): + # pylint: disable=too-many-locals, too-many-branches + """Fill the placeholders of ShardSpec with infer hints.""" rules = try_when.rules[call.op] s = ShardInfo(call) @@ -63,7 +74,7 @@ def infer_shardspec(call: relay.Call): else: # already exist a specified ShardSpec filled_sin.append(s.sin[i]) - + filled_attrs = ShardOpCallAttrs(filled_sin, s.sout) filled_call = relay.Call(s.op, s.args, filled_attrs) filled_s = ShardInfo(filled_call) @@ -80,16 +91,21 @@ def infer_shardspec(call: relay.Call): # Step 3: Check the solution is practicable ninputs = len(filled_s.sin) noutputs = len(filled_s.sout) - immut_in_idx = [i for i in range(ninputs) if is_sharded(filled_s.sin[i]) and filled_s.sin[i].mutable == False] - immut_out_idx = [i for i in range(noutputs) if is_sharded(filled_s.sout[i]) and filled_s.sout[i].mutable == False] + immut_in_idx = [ + i for i in range(ninputs) if is_sharded(filled_s.sin[i]) and not filled_s.sin[i].mutable + ] + immut_out_idx = [ + i for i in range(noutputs) if is_sharded(filled_s.sout[i]) and not filled_s.sout[i].mutable + ] possible_calls = [] for guessed_call in guessed_calls: if not expand_opcall(guessed_call): continue guessed_s = ShardInfo(guessed_call) - immut_args = [(filled_s.sin[i], guessed_s.sin[i]) for i in immut_in_idx] + \ - [(filled_s.sout[i], guessed_s.sout[i]) for i in immut_out_idx] + immut_args = [(filled_s.sin[i], guessed_s.sin[i]) for i in immut_in_idx] + [ + (filled_s.sout[i], guessed_s.sout[i]) for i in immut_out_idx + ] for pair in immut_args: if not is_same_spec(pair[0], pair[1]): break @@ -97,7 +113,8 @@ def infer_shardspec(call: relay.Call): possible_calls.append(guessed_call) # Step 4: Pick an OpCall with full ShardSpec - # TODO: should use graph searching algorithm with cost map here. For now, always select the first solution. 
+ # TODO: should use graph searching algorithm with cost map here. + # For now, always select the first solution. inferred_call = possible_calls[0] inferred_s = ShardInfo(inferred_call) @@ -107,17 +124,23 @@ def infer_shardspec(call: relay.Call): if is_same_spec(filled_s.sin[i], inferred_s.sin[i]): resharded_args.append(inferred_s.args[i]) else: - resharded_args.append(relay.Call( - GetOp("raf.op._reshard"), - [inferred_s.args[i]], - ShardOpCallAttrs([filled_s.sin[i]], [inferred_s.sin[i]]))) - + resharded_args.append( + relay.Call( + GetOp("raf.op._reshard"), + [inferred_s.args[i]], + ShardOpCallAttrs([filled_s.sin[i]], [inferred_s.sin[i]]), + ) + ) + print("[Sharding Infer] %s %s ### %s" % (filled_s.op, inferred_s.attrs, filled_s.attrs)) return relay.Call(inferred_s.op, resharded_args, inferred_s.attrs) + def is_unset(s: BaseShardSpec): + """Check whether it is an UnsetShardSpec (placeholder of ShardSpec).""" return isinstance(s, UnsetShardSpec) + @try_when(always_apply) @register_infer_hint(["raf.op.add", "raf.op.subtract"]) def element_wise_op_with_2in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]: @@ -125,9 +148,8 @@ def element_wise_op_with_2in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]: for e in (s.sin[0], s.sin[1], s.sout[0]): if not is_unset(e): specs.append(e) - return [ - ShardOpCallAttrs([e, e], [e]) for e in specs - ] + return [ShardOpCallAttrs([e, e], [e]) for e in specs] + @try_when(always_apply) @register_infer_hint(["raf.op.relu"]) @@ -136,6 +158,4 @@ def element_wise_op_with_1in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]: for e in (s.sin[0], s.sout[0]): if not is_unset(e): specs.append(e) - return [ - ShardOpCallAttrs([e], [e]) for e in specs - ] + return [ShardOpCallAttrs([e], [e]) for e in specs] diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc index 0491702b..5f8f626a 100644 --- a/src/pass/sharding.cc +++ b/src/pass/sharding.cc @@ -1,7 +1,12 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + /*! * Copyright (c) 2021 by Contributors - * \file init_shardspec.cc - * \brief Gradient operator input selection pass + * \file sharding.cc + * \brief Sharding-related Passes (C++ Side) */ #include #include "raf/op.h" @@ -51,37 +56,12 @@ class ShardOpCallExpander : public ExprMutator { if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { auto call = GetRef(node); Expr new_expr = (*f)(call); - // return call.same_as(new_expr) ? 
new_expr : ExprMutator::VisitExpr(new_expr); return new_expr; } return ExprMutator::VisitExpr_(node); } }; - // // Step 1: Propagate ShardSpec - // Array sin; - // for (int64_t i = 0; i < sattr->sin.size(); ++i) { - // if (sattr->sin[i]->IsInstance()) { - // LOG(INFO) << i << " is unset shardspec"; - // bool flag_unchanged = true; - // if (args[i]->IsInstance()) { - // // Copy ShardSpec from previous output - // LOG(INFO) << i << " is call"; - // const auto pcall = Downcast(args[i]); - // if (pcall->attrs->IsInstance()) { - // const auto pattr = pcall->attrs.as(); - // sin.push_back(pattr->sout[0]); - // flag_unchanged = false; - // } - // } - // if (flag_unchanged) { - // // sin[i] = ShardSpec::make() - // } - // } else { - // sin.push_back(sattr->sin[i]); - // } - // } - class ShardSpecPropagator : public ExprMutator { public: Expr VisitExpr_(const CallNode* node) override { From 505f1aba1aeeb8eb9a211aa954cc836cffc82065 Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Fri, 16 Sep 2022 21:22:24 +0000 Subject: [PATCH 3/6] bugfix --- python/raf/distributed/sharding/__init__.py | 2 + src/pass/sharding.cc | 23 +++---- tests/python/pass/test_pass_sharding.py | 70 ++++++++++++++++++++- 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/python/raf/distributed/sharding/__init__.py b/python/raf/distributed/sharding/__init__.py index f86bc27a..19ebff03 100644 --- a/python/raf/distributed/sharding/__init__.py +++ b/python/raf/distributed/sharding/__init__.py @@ -5,3 +5,5 @@ from raf._ffi.sharding._make import ShardOpCallAttrs from .shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec from .utils import make_replicated_spec, make_shard_spec, make_unset_spec +from .expandrule import expand_opcall +from .inferhint import infer_shardspec diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc index 5f8f626a..355832b4 100644 --- a/src/pass/sharding.cc +++ b/src/pass/sharding.cc @@ -32,15 +32,14 @@ class ShardOpCallAttrsSetter : public ExprMutator { } Expr VisitExpr_(const CallNode* node) override { - const Expr& callee = node->op; - if (callee->IsInstance()) { - auto ref = GetRef(node); - if (_attrs_map.count(ref)) { - auto new_expr = Call(node->op, node->args, Attrs(_attrs_map[ref])); - return ExprMutator::VisitExpr_(new_expr.as()); + Call call = Downcast(ExprMutator::VisitExpr_(node)); + const Expr& op = call->op; + if (op->IsInstance()) { + if (_attrs_map.count(call)) { + return Call(node->op, node->args, Attrs(_attrs_map[call])); } } - return ExprMutator::VisitExpr_(node); + return call; } private: @@ -50,15 +49,17 @@ class ShardOpCallAttrsSetter : public ExprMutator { class ShardOpCallExpander : public ExprMutator { public: Expr VisitExpr_(const CallNode* node) override { - const Expr& op = node->op; - const Attrs& attrs = node->attrs; + Call call = Downcast(ExprMutator::VisitExpr_(node)); + const Expr& op = call->op; + const Attrs& attrs = call->attrs; const auto* f = tvm::runtime::Registry::Get("raf.sharding._match_expansion_rule"); if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { - auto call = GetRef(node); + LOG(INFO) << op << " " << call->op; + Expr new_expr = (*f)(call); return new_expr; } - return ExprMutator::VisitExpr_(node); + return call; } }; diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py index c50c316f..908983ab 100644 --- a/tests/python/pass/test_pass_sharding.py +++ b/tests/python/pass/test_pass_sharding.py @@ -2,9 +2,25 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: 
disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access +import raf import pytest +import numpy as np +from raf.distributed.sharding import ( + ShardSpec, + BaseShardSpec, + ShardOpCallAttrs, +) +from raf._ffi.pass_ import ( + AnnotateShardOpCall, + ToGraphNormalForm, + ExpandShardOpCall, + InferType, + InferShardSpec, +) +from raf._lib import relay from raf.distributed.sharding import make_replicated_spec, make_shard_spec, make_unset_spec from tvm.ir import structural_equal +from tvm.relay.analysis.analysis import post_order_visit def test_shardspec(): @@ -32,5 +48,57 @@ def test_shardspec(): assert not structural_equal(a, i) +def test_infer_hint_with_reshard(): + class Model(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x, y): + z = raf.add(x, y) + a = raf.relu(z) + return a + + model = Model() + m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4))) + m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4))) + record = model._internal(m_x, m_y) + mod_before = record.mod + mod_before = InferType()(mod_before) + + print(m_x) + call_list = [] + post_order_visit( + mod_before["main"].body, + lambda op: call_list.append(op) if isinstance(op, relay.Call) else None, + ) + + spec = make_shard_spec([2, 2], [1, 2], 4, mutable=False) + + attrs_map = { + call_list[0]: ShardOpCallAttrs([make_unset_spec(), make_unset_spec()], [make_unset_spec()]), + call_list[1]: ShardOpCallAttrs([make_unset_spec()], [spec]), + } + + mod0 = AnnotateShardOpCall(attrs_map)(mod_before) + mod1 = ToGraphNormalForm()(mod0) + mod2 = InferType()(mod1) + print("after 1st infer type") + print(raf._ffi.ir.AsText(mod2)) + + mod3 = InferShardSpec()(mod2) + print("after infer shard spec") + print(raf._ffi.ir.AsText(mod3)) + + mod4 = InferType()(mod3) + print("after 2nd infer type") + print(raf._ffi.ir.AsText(mod4)) + + mod5 = ExpandShardOpCall()(mod4) + print("after expand shard opcall") + print(raf._ffi.ir.AsText(mod5)) + + if __name__ == "__main__": - pytest.main([__file__]) + test_infer_hint_with_reshard() + # pytest.main([__file__]) From 81013aa9505aac8c8bca00ebb0c8fa20fa9a356f Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Sat, 17 Sep 2022 17:59:41 +0000 Subject: [PATCH 4/6] bugfix & refactor & add test --- python/raf/distributed/sharding/expandrule.py | 2 +- python/raf/distributed/sharding/inferhint.py | 54 ++++++++++--------- src/pass/sharding.cc | 20 +++---- tests/python/pass/test_pass_sharding.py | 53 +++++++++++++++++- 4 files changed, 90 insertions(+), 39 deletions(-) diff --git a/python/raf/distributed/sharding/expandrule.py b/python/raf/distributed/sharding/expandrule.py index 7dbdd04c..fb95ce83 100644 --- a/python/raf/distributed/sharding/expandrule.py +++ b/python/raf/distributed/sharding/expandrule.py @@ -200,7 +200,7 @@ def reshard_sharded_to_replicated(s: ShardInfo): axis = [] full_shape = [] for i in range(spec.ndim): - if spec.logic_shape[i] > 0: + if spec.logic_shape[i] > 1: axis.append(i) full_shape.append(int(spec.logic_shape[i])) full_shape.append(int(spec.subgroup_shape[i])) diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py index 35db2ffa..2455d302 100644 --- a/python/raf/distributed/sharding/inferhint.py +++ b/python/raf/distributed/sharding/inferhint.py @@ -58,59 +58,61 @@ def infer_shardspec(call: relay.Call): s = ShardInfo(call) # Step 1: Inherit ShardSpec from previous output - filled_sin = [] + inherit_sin = [] for i in range(len(s.sin)): if isinstance(s.sin[i], 
UnsetShardSpec): if isinstance(s.args[i], relay.Call) and hasattr(s.args[i].attrs, "sin"): # cannot use isinstance to check the type of OpCall Attrs # direct inherit ShardSpec prev_sinfo = ShardInfo(s.args[i]) - filled_sin.append(prev_sinfo.sout[0]) + inherit_sin.append(prev_sinfo.sout[0]) else: # the previous output doesn't have ShardSpec ndim = len(s.args[0].checked_type.concrete_shape) - filled_sin.append(make_replicated_spec(ndim)) + inherit_sin.append(make_replicated_spec(ndim)) else: # already exist a specified ShardSpec - filled_sin.append(s.sin[i]) + inherit_sin.append(s.sin[i]) - filled_attrs = ShardOpCallAttrs(filled_sin, s.sout) - filled_call = relay.Call(s.op, s.args, filled_attrs) - filled_s = ShardInfo(filled_call) + inherit_attrs = ShardOpCallAttrs(inherit_sin, s.sout) + inherit_call = relay.Call(s.op, s.args, inherit_attrs) + inherit_s = ShardInfo(inherit_call) - # Step 2: Match an InferHint - guessed_calls = [] + # Step 2: Match InferHints + filled_calls = [] for rule in rules.queue: _, _, cond, irgen = rule - if cond(filled_s): - guessed_calls.extend([relay.Call(s.op, s.args, a) for a in irgen(filled_s)]) - if not guessed_calls: + if cond(inherit_s): + filled_calls.extend([relay.Call(s.op, s.args, a) for a in irgen(inherit_s)]) + if not filled_calls: raise ValueError("Failed to match an InferHint") # Step 3: Check the solution is practicable - ninputs = len(filled_s.sin) - noutputs = len(filled_s.sout) + ninputs = len(s.sin) + noutputs = len(s.sout) immut_in_idx = [ - i for i in range(ninputs) if is_sharded(filled_s.sin[i]) and not filled_s.sin[i].mutable + i for i in range(ninputs) if is_sharded(s.sin[i]) and not s.sin[i].mutable ] immut_out_idx = [ - i for i in range(noutputs) if is_sharded(filled_s.sout[i]) and not filled_s.sout[i].mutable + i for i in range(noutputs) if is_sharded(s.sout[i]) and not s.sout[i].mutable ] possible_calls = [] - for guessed_call in guessed_calls: - if not expand_opcall(guessed_call): + for filled_call in filled_calls: + if not expand_opcall(filled_call): + # there doesn't exist a expansion rule that accepts this sharding solution continue - guessed_s = ShardInfo(guessed_call) - immut_args = [(filled_s.sin[i], guessed_s.sin[i]) for i in immut_in_idx] + [ - (filled_s.sout[i], guessed_s.sout[i]) for i in immut_out_idx + filled_s = ShardInfo(filled_call) + immut_args = [(inherit_s.sin[i], filled_s.sin[i]) for i in immut_in_idx] + [ + (inherit_s.sout[i], filled_s.sout[i]) for i in immut_out_idx ] for pair in immut_args: if not is_same_spec(pair[0], pair[1]): + # violate immutable flag break else: - possible_calls.append(guessed_call) + possible_calls.append(filled_call) # Step 4: Pick an OpCall with full ShardSpec # TODO: should use graph searching algorithm with cost map here. 
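[Reviewer note] A minimal standalone sketch of the Step 3 check above, for reference: a candidate solution survives only if it reproduces every user-pinned immutable spec unchanged. Plain dicts stand in for raf ShardSpec objects here, and keep_candidate / same_layout are illustrative names, not part of the codebase.

    def same_layout(a, b):
        # Mirrors is_same_spec: compare only the sharding layout, not the mutable flag.
        return (a["ranks"], a["phy_shape"], a["subgroup_shape"]) == (
            b["ranks"], b["phy_shape"], b["subgroup_shape"]
        )

    def keep_candidate(pinned_specs, candidate_specs):
        # Specs pair up positionally (inputs, then outputs); only specs pinned as
        # immutable (mutable=False) constrain the candidate, as in Step 3.
        return all(
            same_layout(p, c)
            for p, c in zip(pinned_specs, candidate_specs)
            if not p["mutable"]
        )

    pinned = [{"ranks": 4, "phy_shape": (2, 2), "subgroup_shape": (1, 1), "mutable": False}]
    good = [{"ranks": 4, "phy_shape": (2, 2), "subgroup_shape": (1, 1), "mutable": True}]
    bad = [{"ranks": 4, "phy_shape": (4, 1), "subgroup_shape": (1, 1), "mutable": True}]
    print(keep_candidate(pinned, good))  # True: layout preserved, candidate kept
    print(keep_candidate(pinned, bad))   # False: immutable layout changed, candidate dropped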
@@ -118,21 +120,21 @@ def infer_shardspec(call: relay.Call): inferred_call = possible_calls[0] inferred_s = ShardInfo(inferred_call) - # Step 5: Insert Reshard OpCall + # Step 5: Insert Reshard OpCalls resharded_args = [] for i in range(ninputs): - if is_same_spec(filled_s.sin[i], inferred_s.sin[i]): + if is_same_spec(inherit_s.sin[i], inferred_s.sin[i]): resharded_args.append(inferred_s.args[i]) else: resharded_args.append( relay.Call( GetOp("raf.op._reshard"), [inferred_s.args[i]], - ShardOpCallAttrs([filled_s.sin[i]], [inferred_s.sin[i]]), + ShardOpCallAttrs([inherit_s.sin[i]], [inferred_s.sin[i]]), ) ) - print("[Sharding Infer] %s %s ### %s" % (filled_s.op, inferred_s.attrs, filled_s.attrs)) + print("[Sharding Infer] %s %s ### %s" % (inherit_s.op, inferred_s.attrs, inherit_s.attrs)) return relay.Call(inferred_s.op, resharded_args, inferred_s.attrs) diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc index 355832b4..119cae56 100644 --- a/src/pass/sharding.cc +++ b/src/pass/sharding.cc @@ -49,17 +49,17 @@ class ShardOpCallAttrsSetter : public ExprMutator { class ShardOpCallExpander : public ExprMutator { public: Expr VisitExpr_(const CallNode* node) override { - Call call = Downcast(ExprMutator::VisitExpr_(node)); + Call call = GetRef(node); const Expr& op = call->op; const Attrs& attrs = call->attrs; const auto* f = tvm::runtime::Registry::Get("raf.sharding._match_expansion_rule"); - if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { - LOG(INFO) << op << " " << call->op; - Expr new_expr = (*f)(call); - return new_expr; + if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { + Call new_opcall = (*f)(call); + return ExprMutator::VisitExpr_(new_opcall.as()); } - return call; + + return ExprMutator::VisitExpr_(node); } }; @@ -71,12 +71,12 @@ class ShardSpecPropagator : public ExprMutator { const Attrs& attrs = call->attrs; const Array& args = call->args; const auto* f = tvm::runtime::Registry::Get("raf.sharding._infer_shardspec"); - if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { - LOG(INFO) << op << " " << call->op; - Expr new_expr = (*f)(call); - return new_expr; + if (attrs.defined() && op->IsInstance() && attrs->IsInstance()) { + Call new_opcall = (*f)(call); + return new_opcall; } + return call; } }; diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py index 908983ab..35142204 100644 --- a/tests/python/pass/test_pass_sharding.py +++ b/tests/python/pass/test_pass_sharding.py @@ -47,8 +47,56 @@ def test_shardspec(): i = make_shard_spec([4], ranks=4, mutable=False) assert not structural_equal(a, i) +def test_infer_hint_without_prev_spec(): + class Model(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x, y): + z = raf.add(x, y) + a = raf.relu(z) + b = raf.relu(a) + return b + + model = Model() + m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4))) + m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4))) + record = model._internal(m_x, m_y) + mod_before = record.mod + mod_before = InferType()(mod_before) + + print(m_x) + call_list = [] + post_order_visit( + mod_before["main"].body, + lambda op: call_list.append(op) if isinstance(op, relay.Call) else None, + ) + + attrs_map = { + call_list[1]: ShardOpCallAttrs([make_unset_spec()], [make_shard_spec([4, 1], ranks=4, mutable=False)]), + call_list[2]: ShardOpCallAttrs([make_unset_spec()], [make_replicated_spec(2, mutable=False)]) + } + + mod0 = AnnotateShardOpCall(attrs_map)(mod_before) + mod1 = 
ToGraphNormalForm()(mod0) + mod2 = InferType()(mod1) + print("after 1st infer type") + print(raf._ffi.ir.AsText(mod2)) + + mod3 = InferShardSpec()(mod2) + print("after infer shard spec") + print(raf._ffi.ir.AsText(mod3)) + + mod4 = InferType()(mod3) + print("after 2nd infer type") + print(raf._ffi.ir.AsText(mod4)) + + mod5 = ExpandShardOpCall()(mod4) + print("after expand shard opcall") + print(raf._ffi.ir.AsText(mod5)) -def test_infer_hint_with_reshard(): +def test_infer_hint_inserting_reshard(): class Model(raf.Model): def build(self): pass @@ -100,5 +148,6 @@ def forward(self, x, y): if __name__ == "__main__": - test_infer_hint_with_reshard() + test_infer_hint_inserting_reshard() + # test_infer_hint_without_prev_spec() # pytest.main([__file__]) From 1c4ab1df7c7be503d9b9781184fb23c868c07869 Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Sat, 17 Sep 2022 18:00:10 +0000 Subject: [PATCH 5/6] lint --- python/raf/distributed/sharding/inferhint.py | 8 ++------ src/pass/sharding.cc | 2 +- tests/python/pass/test_pass_sharding.py | 10 ++++++++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py index 2455d302..d7184892 100644 --- a/python/raf/distributed/sharding/inferhint.py +++ b/python/raf/distributed/sharding/inferhint.py @@ -91,12 +91,8 @@ def infer_shardspec(call: relay.Call): # Step 3: Check the solution is practicable ninputs = len(s.sin) noutputs = len(s.sout) - immut_in_idx = [ - i for i in range(ninputs) if is_sharded(s.sin[i]) and not s.sin[i].mutable - ] - immut_out_idx = [ - i for i in range(noutputs) if is_sharded(s.sout[i]) and not s.sout[i].mutable - ] + immut_in_idx = [i for i in range(ninputs) if is_sharded(s.sin[i]) and not s.sin[i].mutable] + immut_out_idx = [i for i in range(noutputs) if is_sharded(s.sout[i]) and not s.sout[i].mutable] possible_calls = [] for filled_call in filled_calls: diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc index 119cae56..93fda62f 100644 --- a/src/pass/sharding.cc +++ b/src/pass/sharding.cc @@ -76,7 +76,7 @@ class ShardSpecPropagator : public ExprMutator { Call new_opcall = (*f)(call); return new_opcall; } - + return call; } }; diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py index 35142204..4fef7f47 100644 --- a/tests/python/pass/test_pass_sharding.py +++ b/tests/python/pass/test_pass_sharding.py @@ -47,6 +47,7 @@ def test_shardspec(): i = make_shard_spec([4], ranks=4, mutable=False) assert not structural_equal(a, i) + def test_infer_hint_without_prev_spec(): class Model(raf.Model): def build(self): @@ -74,8 +75,12 @@ def forward(self, x, y): ) attrs_map = { - call_list[1]: ShardOpCallAttrs([make_unset_spec()], [make_shard_spec([4, 1], ranks=4, mutable=False)]), - call_list[2]: ShardOpCallAttrs([make_unset_spec()], [make_replicated_spec(2, mutable=False)]) + call_list[1]: ShardOpCallAttrs( + [make_unset_spec()], [make_shard_spec([4, 1], ranks=4, mutable=False)] + ), + call_list[2]: ShardOpCallAttrs( + [make_unset_spec()], [make_replicated_spec(2, mutable=False)] + ), } mod0 = AnnotateShardOpCall(attrs_map)(mod_before) @@ -96,6 +101,7 @@ def forward(self, x, y): print("after expand shard opcall") print(raf._ffi.ir.AsText(mod5)) + def test_infer_hint_inserting_reshard(): class Model(raf.Model): def build(self): From 5ee2f1ee3f5984087255b0ab05f6076c8b92d731 Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Tue, 27 Sep 2022 22:44:54 +0000 Subject: [PATCH 6/6] impl & bugfix --- 
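[Reviewer note] One note on how the tests above attach attributes to specific calls: post_order_visit visits producers before consumers, so the call_list indices used as attrs_map keys follow dataflow order (index 0 is the add, 1 the first relu, 2 the second). A self-contained sketch mirroring the test setup; the model name is illustrative.

    import numpy as np
    import raf
    from raf._ffi.pass_ import InferType
    from raf._lib import relay
    from tvm.relay.analysis.analysis import post_order_visit

    class TwoReluModel(raf.Model):
        def build(self):
            pass

        @raf.model.trace
        def forward(self, x, y):
            z = raf.add(x, y)
            a = raf.relu(z)
            return raf.relu(a)

    m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
    m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
    mod = InferType()(TwoReluModel()._internal(m_x, m_y).mod)

    call_list = []
    post_order_visit(
        mod["main"].body,
        lambda op: call_list.append(op) if isinstance(op, relay.Call) else None,
    )
    # Producers come first, so attrs_map in the tests keys off these indices.
    for i, call in enumerate(call_list):
        print(i, call.op)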
include/raf/sharding.h | 4 +- python/raf/distributed/sharding/expandrule.py | 43 +++++++++- python/raf/distributed/sharding/inferhint.py | 85 +++++++++++++------ python/raf/distributed/sharding/shardspec.py | 8 ++ src/impl/sharding.cc | 6 +- src/pass/sharding.cc | 6 ++ tests/python/pass/test_pass_sharding.py | 41 +++------ 7 files changed, 126 insertions(+), 67 deletions(-) diff --git a/include/raf/sharding.h b/include/raf/sharding.h index b9c24147..1ca9c8ad 100644 --- a/include/raf/sharding.h +++ b/include/raf/sharding.h @@ -100,8 +100,8 @@ class ShardSpecObj final : public BaseShardSpecObj { v->Visit("ranks", &ranks); v->Visit("logic_shape", &logic_shape); v->Visit("logic_index", &logic_index_); - v->Visit("phy_shape", &logic_shape); - v->Visit("phy_index", &logic_index_); + v->Visit("phy_shape", &phy_shape); + v->Visit("phy_index", &phy_index_); v->Visit("subgroup_shape", &subgroup_shape); v->Visit("subgroup_index", &subgroup_index_); } diff --git a/python/raf/distributed/sharding/expandrule.py b/python/raf/distributed/sharding/expandrule.py index fb95ce83..77f87f1c 100644 --- a/python/raf/distributed/sharding/expandrule.py +++ b/python/raf/distributed/sharding/expandrule.py @@ -13,8 +13,12 @@ from raf._ffi.op import GetOp from raf._lib import _register_func, relay -from raf.distributed.sharding.shardspec import BaseShardSpec, ShardSpec -from tvm.relay import Expr +from raf.distributed.sharding import ( + ShardSpec, + BaseShardSpec, + ShardOpCallAttrs, +) +from tvm.relay import Call, Expr from tvm.ir import Op from tvm.runtime.object import Object @@ -50,6 +54,20 @@ def __init__(self, call: relay.Call): self.sin = call.attrs.sin self.sout = call.attrs.sout + def make_updated(self, op=None, args=None, sin=None, sout=None, attrs=None): + # pylint: disable=too-many-arguments + """Make a new ShardInfo based on this ShardInfo with a few fields modified""" + op = op if op else self.op + args = args if args else self.args + if sin or sout: + sin = sin if sin else self.sin + sout = sout if sout else self.sout + attrs = ShardOpCallAttrs(sin, sout) + elif not attrs: + attrs = self.attrs + call = Call(op, args, attrs) + return ShardInfo(call) + def all_satisfied(conds: List[Callable[[ShardInfo], bool]]): """Return true when all conditions are satisfied.""" @@ -63,14 +81,31 @@ def func(s: ShardInfo): return func -def is_same_spec(*args): - """Check whether two ShardSpecs are same.""" +def is_exact_same_spec(*args): + """Check whether two ShardSpecs are exact same.""" for e in args[1:]: if not tvm.ir.structural_equal(args[0], e): return False return True +def is_same_spec(*args): + """Check whether two ShardSpecs are same except Mutable Attr.""" + if is_sharded(args[0]): + for e in args[1:]: + if not is_sharded(e): + return False + if not tvm.ir.structural_equal(args[0].ranks, e.ranks): + return False + if not tvm.ir.structural_equal(args[0].phy_shape, e.phy_shape): + return False + if not tvm.ir.structural_equal(args[0].subgroup_shape, e.subgroup_shape): + return False + else: + return is_exact_same_spec(*args) + return True + + def is_sharded(s: BaseShardSpec): """Check whether it is a ShardSpec.""" return isinstance(s, ShardSpec) diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py index d7184892..a5593e9a 100644 --- a/python/raf/distributed/sharding/inferhint.py +++ b/python/raf/distributed/sharding/inferhint.py @@ -9,13 +9,14 @@ from raf._ffi.sharding._make import ShardOpCallAttrs from raf._ffi.op import GetOp from raf._lib import 
_register_func, relay -from raf.distributed.sharding.shardspec import BaseShardSpec, UnsetShardSpec +from raf.distributed.sharding.shardspec import BaseShardSpec, UnsetShardSpec, ShardSpec from raf.distributed.sharding.utils import make_replicated_spec from .expandrule import ( ShardInfo, always_apply, expand_opcall, + is_exact_same_spec, is_same_spec, is_sharded, ) @@ -52,40 +53,53 @@ def decorator(pyfunc): @_register_func("raf.sharding._infer_shardspec") def infer_shardspec(call: relay.Call): - # pylint: disable=too-many-locals, too-many-branches + # pylint: disable=too-many-locals, too-many-branches, too-many-statements """Fill the placeholders of ShardSpec with infer hints.""" rules = try_when.rules[call.op] s = ShardInfo(call) - # Step 1: Inherit ShardSpec from previous output + # Step 1: Inherit input spec from previous output + + # inherit_sin should be the correct specs of current inputs inherit_sin = [] + # specified_sin should be the user-specified specs with filled unset shard specs + specified_sin = [] + for i in range(len(s.sin)): - if isinstance(s.sin[i], UnsetShardSpec): - if isinstance(s.args[i], relay.Call) and hasattr(s.args[i].attrs, "sin"): - # cannot use isinstance to check the type of OpCall Attrs - # direct inherit ShardSpec - prev_sinfo = ShardInfo(s.args[i]) - inherit_sin.append(prev_sinfo.sout[0]) + if isinstance(s.args[i], relay.Call) and hasattr(s.args[i].attrs, "sin"): + # cannot use isinstance to check the type of OpCall Attrs + # direct inherit ShardSpec + prev_sinfo = ShardInfo(s.args[i]) + inherit_sin.append(prev_sinfo.sout[0]) + else: + # the previous output isn't annotated with ShardSpec + if isinstance(s.sin[i], ShardSpec): + # already exist a specified ShardSpec + inherit_sin.append(s.sin[i]) else: - # the previous output doesn't have ShardSpec - ndim = len(s.args[0].checked_type.concrete_shape) + # assume the previous output is replicated on all ranks + ndim = len(s.args[i].checked_type.concrete_shape) inherit_sin.append(make_replicated_spec(ndim)) + if isinstance(s.sin[i], UnsetShardSpec): + specified_sin.append(inherit_sin[-1]) else: - # already exist a specified ShardSpec - inherit_sin.append(s.sin[i]) + specified_sin.append(s.sin[i]) - inherit_attrs = ShardOpCallAttrs(inherit_sin, s.sout) - inherit_call = relay.Call(s.op, s.args, inherit_attrs) - inherit_s = ShardInfo(inherit_call) + inherit_s = s.make_updated(sin=inherit_sin) + specified_s = s.make_updated(sin=specified_sin) # Step 2: Match InferHints - filled_calls = [] + + filled_s_list: List[ShardInfo] = [] # TODO: try to remove duplicated solutions for rule in rules.queue: _, _, cond, irgen = rule + if cond(specified_s): + filled_s_list.extend([s.make_updated(attrs=a) for a in irgen(specified_s)]) if cond(inherit_s): - filled_calls.extend([relay.Call(s.op, s.args, a) for a in irgen(inherit_s)]) - if not filled_calls: + filled_s_list.extend([s.make_updated(attrs=a) for a in irgen(inherit_s)]) + + if not filled_s_list: raise ValueError("Failed to match an InferHint") # Step 3: Check the solution is practicable @@ -94,27 +108,29 @@ def infer_shardspec(call: relay.Call): immut_in_idx = [i for i in range(ninputs) if is_sharded(s.sin[i]) and not s.sin[i].mutable] immut_out_idx = [i for i in range(noutputs) if is_sharded(s.sout[i]) and not s.sout[i].mutable] - possible_calls = [] - for filled_call in filled_calls: - if not expand_opcall(filled_call): + possible_s_list: List[ShardInfo] = [] + for filled_s in filled_s_list: + if not expand_opcall(filled_s.call): # there doesn't exist a expansion rule 
that accepts this sharding solution continue - filled_s = ShardInfo(filled_call) immut_args = [(inherit_s.sin[i], filled_s.sin[i]) for i in immut_in_idx] + [ (inherit_s.sout[i], filled_s.sout[i]) for i in immut_out_idx ] for pair in immut_args: if not is_same_spec(pair[0], pair[1]): - # violate immutable flag + # violate immutable attribute of shard spec break else: - possible_calls.append(filled_call) + # reset Mutable flag for outputs to prevent from spreading this flag mistakenly + sout = [ + spec if spec.mutable else spec.make_updated(mutable=True) for spec in filled_s.sout + ] + possible_s_list.append(filled_s.make_updated(sout=sout)) # Step 4: Pick an OpCall with full ShardSpec # TODO: should use graph searching algorithm with cost map here. # For now, always select the first solution. - inferred_call = possible_calls[0] - inferred_s = ShardInfo(inferred_call) + inferred_s = possible_s_list[0] # Step 5: Insert Reshard OpCalls resharded_args = [] @@ -130,7 +146,20 @@ def infer_shardspec(call: relay.Call): ) ) - print("[Sharding Infer] %s %s ### %s" % (inherit_s.op, inferred_s.attrs, inherit_s.attrs)) + print("[Sharding Infer] OpCall: %s" % s.op) + for phase in ("In", "Out"): + for i in range(ninputs if phase == "In" else noutputs): + if phase == "In": + a_spec, b_spec, c_spec = s.sin[i], inherit_s.sin[i], inferred_s.sin[i] + else: + a_spec, b_spec, c_spec = s.sout[i], inherit_s.sout[i], inferred_s.sout[i] + print(" %sArg %s: %s" % (phase, i, a_spec), end="") + if not is_exact_same_spec(a_spec, b_spec): + print(" -> %s" % b_spec, end="") + if not is_exact_same_spec(b_spec, c_spec): + print(" -> %s" % c_spec, end="") + print() + return relay.Call(inferred_s.op, resharded_args, inferred_s.attrs) diff --git a/python/raf/distributed/sharding/shardspec.py b/python/raf/distributed/sharding/shardspec.py index 60872ad2..b3abd90f 100644 --- a/python/raf/distributed/sharding/shardspec.py +++ b/python/raf/distributed/sharding/shardspec.py @@ -35,6 +35,14 @@ def __init__(self, ranks, phy_shape, subgroup_shape, mutable): _make.ShardSpec, ranks, phy_shape, subgroup_shape, mutable ) + def make_updated(self, ranks=None, phy_shape=None, subgroup_shape=None, mutable=None): + """Make a new spec based on this spec with a few fields modified""" + ranks = ranks if ranks else self.ranks + phy_shape = phy_shape if phy_shape else self.phy_shape + subgroup_shape = subgroup_shape if subgroup_shape else self.subgroup_shape + mutable = mutable if mutable else self.mutable + return ShardSpec(ranks, phy_shape, subgroup_shape, mutable) + @register_node("raf.sharding.UnsetShardSpec") class UnsetShardSpec(BaseShardSpec): diff --git a/src/impl/sharding.cc b/src/impl/sharding.cc index 5a62c8cc..d97854c7 100644 --- a/src/impl/sharding.cc +++ b/src/impl/sharding.cc @@ -170,7 +170,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto r = Downcast(ref); auto ndim = r->ndim_; if (r->nshard_ == 1) { - p->stream << "ShardSpec(Replicated)"; + p->stream << "ShardSpec(Replicated, " << (r->mutable_ ? "Mut)" : "Immut)"); } else { p->stream << "ShardSpec(" << "["; @@ -179,9 +179,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) auto ngroup_on_dim = r->subgroup_shape[i]->value; p->stream << (nshard_on_dim == 1 ? ":" : std::to_string(nshard_on_dim)) << (ngroup_on_dim == 1 ? "" : "(x" + std::to_string(ngroup_on_dim) + ")") - << (i != ndim - 1 ? ", " : ""); + << (i != ndim - 1 ? ", " : "], "); } - p->stream << "])"; + p->stream << (r->mutable_ ? 
"Mut)" : "Immut)"); } }); diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc index 93fda62f..4f19ce34 100644 --- a/src/pass/sharding.cc +++ b/src/pass/sharding.cc @@ -48,6 +48,12 @@ class ShardOpCallAttrsSetter : public ExprMutator { class ShardOpCallExpander : public ExprMutator { public: + Expr VisitExpr_(const FunctionNode* node) override { + // remove inferred function return type as IR has changed + Expr new_body = VisitExpr(node->body); + return Function(node->params, new_body, {}, {}); + } + Expr VisitExpr_(const CallNode* node) override { Call call = GetRef(node); const Expr& op = call->op; diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py index 4fef7f47..e465f50a 100644 --- a/tests/python/pass/test_pass_sharding.py +++ b/tests/python/pass/test_pass_sharding.py @@ -1,15 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access -import raf -import pytest +# pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access, no-self-use, too-many-locals import numpy as np -from raf.distributed.sharding import ( - ShardSpec, - BaseShardSpec, - ShardOpCallAttrs, -) +import pytest +import raf +from raf.distributed.sharding import ShardOpCallAttrs from raf._ffi.pass_ import ( AnnotateShardOpCall, ToGraphNormalForm, @@ -67,7 +63,6 @@ def forward(self, x, y): mod_before = record.mod mod_before = InferType()(mod_before) - print(m_x) call_list = [] post_order_visit( mod_before["main"].body, @@ -86,17 +81,8 @@ def forward(self, x, y): mod0 = AnnotateShardOpCall(attrs_map)(mod_before) mod1 = ToGraphNormalForm()(mod0) mod2 = InferType()(mod1) - print("after 1st infer type") - print(raf._ffi.ir.AsText(mod2)) - mod3 = InferShardSpec()(mod2) - print("after infer shard spec") - print(raf._ffi.ir.AsText(mod3)) - mod4 = InferType()(mod3) - print("after 2nd infer type") - print(raf._ffi.ir.AsText(mod4)) - mod5 = ExpandShardOpCall()(mod4) print("after expand shard opcall") print(raf._ffi.ir.AsText(mod5)) @@ -111,7 +97,8 @@ def build(self): def forward(self, x, y): z = raf.add(x, y) a = raf.relu(z) - return a + b = raf.relu(a) + return b model = Model() m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4))) @@ -137,23 +124,17 @@ def forward(self, x, y): mod0 = AnnotateShardOpCall(attrs_map)(mod_before) mod1 = ToGraphNormalForm()(mod0) mod2 = InferType()(mod1) - print("after 1st infer type") - print(raf._ffi.ir.AsText(mod2)) - mod3 = InferShardSpec()(mod2) - print("after infer shard spec") - print(raf._ffi.ir.AsText(mod3)) - mod4 = InferType()(mod3) - print("after 2nd infer type") + print("after infer type") print(raf._ffi.ir.AsText(mod4)) - mod5 = ExpandShardOpCall()(mod4) print("after expand shard opcall") print(raf._ffi.ir.AsText(mod5)) + mod6 = InferType()(mod5) + print("after infer type2") + print(raf._ffi.ir.AsText(mod6)) if __name__ == "__main__": - test_infer_hint_inserting_reshard() - # test_infer_hint_without_prev_spec() - # pytest.main([__file__]) + pytest.main([__file__])