Gong Junmin committed on
Commit 11a221a · 1 Parent(s): 509f9f2

first commit

Files changed (31)
  1. .gitignore +4 -0
  2. LICENSE +246 -201
  3. acestep/acestep_v15_pipeline.py +67 -0
  4. acestep/gradio_ui.py +744 -0
  5. acestep/handler.py +1100 -0
  6. acestep/third_parts/nano-vllm/LICENSE +21 -0
  7. acestep/third_parts/nano-vllm/README.md +66 -0
  8. acestep/third_parts/nano-vllm/assets/logo.png +3 -0
  9. acestep/third_parts/nano-vllm/bench.py +32 -0
  10. acestep/third_parts/nano-vllm/example.py +33 -0
  11. acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
  12. acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
  13. acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +112 -0
  14. acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +120 -0
  15. acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +315 -0
  16. acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +222 -0
  17. acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +89 -0
  18. acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
  19. acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
  20. acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
  21. acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py +50 -0
  22. acestep/third_parts/nano-vllm/nanovllm/layers/linear.py +153 -0
  23. acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py +61 -0
  24. acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +15 -0
  25. acestep/third_parts/nano-vllm/nanovllm/llm.py +5 -0
  26. acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py +215 -0
  27. acestep/third_parts/nano-vllm/nanovllm/sampling_params.py +13 -0
  28. acestep/third_parts/nano-vllm/nanovllm/utils/context.py +27 -0
  29. acestep/third_parts/nano-vllm/nanovllm/utils/loader.py +28 -0
  30. acestep/third_parts/nano-vllm/pyproject.toml +27 -0
  31. requirements.txt +4 -0
.gitignore CHANGED
@@ -205,3 +205,7 @@ cython_debug/
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
+ tests/
209
+ checkpoints/
210
+ playground.ipynb
211
+ .history/
LICENSE CHANGED
@@ -1,201 +1,246 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
1
+ ================================================================================
2
+ ACE-STEP LICENSE
3
+ Version 1.0
4
+ ================================================================================
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ --------------------------------------------------------------------------------
9
+
10
+ 1. DEFINITIONS
11
+
12
+ "License" shall mean the terms and conditions for use, reproduction, and
13
+ distribution as defined by Sections 1 through 11 of this document.
14
+
15
+ "Licensor" shall mean ACE Studio (or the copyright owner/entity authorized
16
+ by ACE Studio) that is granting the License.
17
+
18
+ "Legal Entity" shall mean the union of the acting entity and all other
19
+ entities that control, are controlled by, or are under common control with
20
+ that entity. For the purposes of this definition, "control" means (i) the
21
+ power, direct or indirect, to cause the direction or management of such
22
+ entity, whether by contract or otherwise, or (ii) ownership of fifty
23
+ percent (50%) or more of the outstanding shares, or (iii) beneficial
24
+ ownership of such entity.
25
+
26
+ "You" (or "Your") shall mean an individual or Legal Entity exercising
27
+ permissions granted by this License.
28
+
29
+ "Source" form shall mean the preferred form for making modifications,
30
+ including but not limited to software source code, documentation source,
31
+ configuration files, and model training code.
32
+
33
+ "Object" form shall mean any form resulting from mechanical transformation
34
+ or translation of a Source form, including but not limited to compiled
35
+ object code, generated documentation, and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or Object form,
38
+ made available under the License, as indicated by a copyright notice that
39
+ is included in or attached to the work. For the avoidance of doubt, "Work"
40
+ explicitly includes the Model Weights, parameters, and configuration files
41
+ provided by the Licensor.
42
+
43
+ "Derivative Works" shall mean any work, whether in Source or Object form,
44
+ that is based on (or derived from) the Work and for which the editorial
45
+ revisions, annotations, elaborations, or other modifications represent, as
46
+ a whole, an original work of authorship. For the purposes of this License,
47
+ "Derivative Works" shall explicitly include "Derivative Models," defined as
48
+ any modifications to the Model Weights, including but not limited to
49
+ Fine-tunes, LoRAs (Low-Rank Adaptation), adapters, and other distinct
50
+ parameter sets derived from the Work.
51
+
52
+ "Output" shall mean any audio, music, sound recordings, or data generated
53
+ by the use or execution of the Work or Derivative Works.
54
+
55
+ "Contribution" shall mean any work of authorship, including the original
56
+ version of the Work and any modifications or additions to that Work or
57
+ Derivative Works thereof, that is intentionally submitted to Licensor for
58
+ inclusion in the Work by the copyright owner or by an individual or Legal
59
+ Entity authorized to submit on behalf of the copyright owner.
60
+
61
+ "Contributor" shall mean Licensor and any individual or Legal Entity on
62
+ behalf of whom a Contribution has been received by Licensor and subsequently
63
+ incorporated within the Work.
64
+
65
+ --------------------------------------------------------------------------------
66
+
67
+ 2. GRANT OF COPYRIGHT LICENSE
68
+
69
+ Subject to the terms and conditions of this License (including the specific
70
+ restrictions in Section 5 and Section 6), each Contributor hereby grants to
71
+ You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
72
+ irrevocable copyright license to reproduce, prepare Derivative Works of,
73
+ publicly display, publicly perform, sublicense, and distribute the Work and
74
+ such Derivative Works in Source or Object form.
75
+
76
+ --------------------------------------------------------------------------------
77
+
78
+ 3. GRANT OF PATENT LICENSE
79
+
80
+ Subject to the terms and conditions of this License, each Contributor
81
+ hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
82
+ royalty-free, irrevocable (except as stated in this section) patent license
83
+ to make, have made, use, offer to sell, sell, import, and otherwise transfer
84
+ the Work, where such license applies only to those patent claims licensable
85
+ by such Contributor that are necessarily infringed by their Contribution(s)
86
+ alone or by combination of their Contribution(s) with the Work to which such
87
+ Contribution(s) was submitted.
88
+
89
+ If You institute patent litigation against any entity (including a
90
+ cross-claim or counterclaim in a lawsuit) alleging that the Work or a
91
+ Contribution incorporated within the Work constitutes direct or contributory
92
+ patent infringement, then any patent licenses granted to You under this
93
+ License for that Work shall terminate as of the date such litigation is
94
+ filed.
95
+
96
+ --------------------------------------------------------------------------------
97
+
98
+ 4. REDISTRIBUTION
99
+
100
+ You may reproduce and distribute copies of the Work or Derivative Works
101
+ thereof in any medium, with or without modifications, and in Source or
102
+ Object form, provided that You meet the following conditions:
103
+
104
+ (a) You must give any other recipients of the Work or Derivative Works a
105
+ copy of this License; and
106
+
107
+ (b) You must cause any modified files to carry prominent notices stating
108
+ that You changed the files; and
109
+
110
+ (c) You must retain, in the Source form of any Derivative Works that You
111
+ distribute, all copyright, patent, trademark, and attribution notices
112
+ from the Source form of the Work, excluding those notices that do not
113
+ pertain to any part of the Derivative Works; and
114
+
115
+ (d) If the Work includes a "NOTICE" text file as part of its distribution,
116
+ then any Derivative Works that You distribute must include a readable
117
+ copy of the attribution notices contained within such NOTICE file,
118
+ excluding those notices that do not pertain to any part of the
119
+ Derivative Works, in at least one of the following places: within a
120
+ NOTICE text file distributed as part of the Derivative Works; within
121
+ the Source form or documentation, if provided along with the Derivative
122
+ Works; or, within a display generated by the Derivative Works, if and
123
+ wherever such third-party notices normally appear. The contents of the
124
+ NOTICE file are for informational purposes only and do not modify the
125
+ License. You may add Your own attribution notices within Derivative
126
+ Works that You distribute, alongside or as an addendum to the NOTICE
127
+ text from the Work, provided that such additional attribution notices
128
+ cannot be construed as modifying the License.
129
+
130
+ (e) Community Contribution Requirement: If You create a Derivative Work
131
+ (specifically a Derivative Model, such as a Fine-tune or LoRA) and You
132
+ distribute it, publicly perform it, or use it to generate publicly
133
+ available Output, You must make the Source form (including the weights,
134
+ parameters, and training configuration) of said Derivative Work publicly
135
+ available under the terms of this License. This clause ensures that
136
+ improvements to the model remain accessible to the community.
137
+
138
+ --------------------------------------------------------------------------------
139
+
140
+ 5. RESTRICTIONS ON COMMERCIAL SERVICES (THE "ANTI-SAAS" CLAUSE)
141
+
142
+ Notwithstanding the grants in Section 2, You are strictly prohibited from
143
+ using the Work or Derivative Works to operate, promote, or distribute a
144
+ commercial service where the primary value provided to users is the ability
145
+ to generate Outputs using the Work. This includes, but is not limited to:
146
+
147
+ (a) Offering the Work as a hosted Application Programming Interface (API);
148
+
149
+ (b) Offering the Work as a Software-as-a-Service (SaaS) product;
150
+
151
+ (c) Integrating the Work into a commercial platform that charges users
152
+ specifically for generation capabilities.
153
+
154
+ Note: You may use the Work to develop independent software tools, plugins,
155
+ or applications (e.g., local plugins for DAWs), provided that such tools
156
+ run locally on end-users' hardware and do not violate the restrictions on
157
+ hosted commercial generation services defined above.
158
+
159
+ --------------------------------------------------------------------------------
160
+
161
+ 6. OWNERSHIP AND COMMERCIALIZATION OF OUTPUTS
162
+
163
+ (a) Personal and Non-Commercial Use: You are free to use Outputs generated
164
+ by the Work for personal use, research, educational purposes, and
165
+ non-commercial creative projects (e.g., background music for personal
166
+ videos) without restriction.
167
+
168
+ (b) Commercial Use and Verification: The Licensor provides the Work without
169
+ any default warranty of copyright ownership for raw Outputs. To use
170
+ Outputs for Commercial Purposes (including but not limited to
171
+ distributing to music streaming platforms such as Spotify, Apple Music,
172
+ or YouTube Music, or registering Content ID), You must obtain Proof of
173
+ Human Creation or authorization through ACE Studio (or channels
174
+ officially designated by the Licensor). "Commercial Purposes" implies
175
+ intent to profit from the direct exploitation of the generated audio.
176
+
177
+ (c) Prohibition on Mass Generation: You are expressly prohibited from using
178
+ the Work to generate and distribute mass quantities of content
179
+ ("spamming") for the purpose of flooding streaming services,
180
+ manipulating royalty systems, or engaging in automated content farming.
181
+
182
+ --------------------------------------------------------------------------------
183
+
184
+ 7. SUBMISSION OF CONTRIBUTIONS
185
+
186
+ Unless You explicitly state otherwise, any Contribution intentionally
187
+ submitted for inclusion in the Work by You to the Licensor shall be under
188
+ the terms and conditions of this License, without any additional terms or
189
+ conditions. Notwithstanding the above, nothing herein shall supersede or
190
+ modify the terms of any separate license agreement you may have executed
191
+ with Licensor regarding such Contributions.
192
+
193
+ --------------------------------------------------------------------------------
194
+
195
+ 8. TRADEMARKS
196
+
197
+ This License does not grant permission to use the trade names, trademarks,
198
+ service marks, or product names of the Licensor, except as required for
199
+ reasonable and customary use in describing the origin of the Work and
200
+ reproducing the content of the NOTICE file.
201
+
202
+ --------------------------------------------------------------------------------
203
+
204
+ 9. DISCLAIMER OF WARRANTY
205
+
206
+ Unless required by applicable law or agreed to in writing, Licensor provides
207
+ the Work (and each Contributor provides its Contributions) on an "AS IS"
208
+ BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
209
+ implied, including, without limitation, any warranties or conditions of
210
+ TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR
211
+ PURPOSE. You are solely responsible for determining the appropriateness of
212
+ using or redistributing the Work and assume any risks associated with Your
213
+ exercise of permissions under this License.
214
+
215
+ --------------------------------------------------------------------------------
216
+
217
+ 10. LIMITATION OF LIABILITY
218
+
219
+ In no event and under no legal theory, whether in tort (including
220
+ negligence), contract, or otherwise, unless required by applicable law
221
+ (such as deliberate and grossly negligent acts) or agreed to in writing,
222
+ shall any Contributor be liable to You for damages, including any direct,
223
+ indirect, special, incidental, or consequential damages of any character
224
+ arising as a result of this License or out of the use or inability to use
225
+ the Work (including but not limited to damages for loss of goodwill, work
226
+ stoppage, computer failure or malfunction, or any and all other commercial
227
+ damages or losses), even if such Contributor has been advised of the
228
+ possibility of such damages.
229
+
230
+ --------------------------------------------------------------------------------
231
+
232
+ 11. ACCEPTING WARRANTY OR ADDITIONAL LIABILITY
233
+
234
+ While redistributing the Work or Derivative Works thereof, You may choose
235
+ to offer, and charge a fee for, acceptance of support, warranty,
236
+ indemnity, or other liability obligations and/or rights consistent with
237
+ this License. However, in accepting such obligations, You may act only on
238
+ Your own behalf and on Your sole responsibility, not on behalf of any
239
+ other Contributor, and only if You agree to indemnify, defend, and hold
240
+ each Contributor harmless for any liability incurred by, or claims asserted
241
+ against, such Contributor by reason of your accepting any such warranty or
242
+ additional liability.
243
+
244
+ ================================================================================
245
+ END OF TERMS AND CONDITIONS
246
+ ================================================================================
acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ ACE-Step V1.5 Pipeline
3
+ Handler wrapper connecting model and UI
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ # Clear proxy settings that may affect Gradio
9
+ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
10
+ os.environ.pop(proxy_var, None)
11
+
12
+ os.environ['CUDA_VISIBLE_DEVICES'] = '7' # Adjust as needed
13
+
14
+ from .handler import AceStepHandler
15
+ from .gradio_ui import create_gradio_interface
16
+
17
+
18
+ def create_demo():
19
+ """
20
+ Create Gradio demo interface
21
+
22
+ Returns:
23
+ Gradio Blocks instance
24
+ """
25
+ # Create handler instance (business logic processor)
26
+ handler = AceStepHandler()
27
+
28
+ # Create Gradio interface
29
+ demo = create_gradio_interface(handler)
30
+
31
+ return demo
32
+
33
+
34
+ def main():
35
+ """Main entry function"""
36
+ import argparse
37
+
38
+ parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
39
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
40
+ parser.add_argument("--share", action="store_true", help="Create a public link")
41
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
42
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
43
+ args = parser.parse_args()
44
+
45
+ try:
46
+ # Create and launch demo
47
+ print("Creating Gradio interface...")
48
+ demo = create_demo()
49
+ print(f"Launching server on {args.server_name}:{args.port}...")
50
+ demo.launch(
51
+ server_name=args.server_name,
52
+ server_port=args.port,
53
+ share=args.share,
54
+ debug=args.debug,
55
+ show_error=True,
56
+ prevent_thread_lock=False, # Keep thread locked to maintain server running
57
+ inbrowser=False, # Don't auto-open browser
58
+ )
59
+ except Exception as e:
60
+ print(f"Error launching Gradio: {e}", file=sys.stderr)
61
+ import traceback
62
+ traceback.print_exc()
63
+ sys.exit(1)
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()
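
For context on how this entry point is meant to be used: a minimal local launch, sketched under the assumption that the `acestep` package is importable and checkpoints are available (illustrative only, not part of the commit):

```python
# Minimal sketch: launch the demo programmatically instead of via main().
# Assumes `acestep` is on PYTHONPATH; host/port values here are illustrative.
from acestep.acestep_v15_pipeline import create_demo

demo = create_demo()  # builds AceStepHandler + Gradio Blocks
demo.launch(server_name="127.0.0.1", server_port=7860, share=False)
```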
acestep/gradio_ui.py ADDED
@@ -0,0 +1,744 @@
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from typing import Callable, Optional
7
+
8
+
9
+ def create_gradio_interface(handler) -> gr.Blocks:
10
+ """
11
+ Create Gradio interface
12
+
13
+ Args:
14
+ handler: Business logic handler instance
15
+
16
+ Returns:
17
+ Gradio Blocks instance
18
+ """
19
+ with gr.Blocks(
20
+ title="ACE-Step V1.5 Demo",
21
+ theme=gr.themes.Soft(),
22
+ css="""
23
+ .main-header {
24
+ text-align: center;
25
+ margin-bottom: 2rem;
26
+ }
27
+ .section-header {
28
+ background: linear-gradient(90deg, #4CAF50, #45a049);
29
+ color: white;
30
+ padding: 10px;
31
+ border-radius: 5px;
32
+ margin: 10px 0;
33
+ }
34
+ """
35
+ ) as demo:
36
+
37
+ gr.HTML("""
38
+ <div class="main-header">
39
+ <h1>♪ ACE-Step V1.5 Demo</h1>
40
+ <p>Generate music from text captions and lyrics using diffusion models</p>
41
+ </div>
42
+ """)
43
+
44
+ # Dataset Explorer Section
45
+ dataset_section = create_dataset_section(handler)
46
+
47
+ # Generation Section
48
+ generation_section = create_generation_section(handler)
49
+
50
+ # Results Section
51
+ results_section = create_results_section(handler)
52
+
53
+ # Connect event handlers
54
+ setup_event_handlers(demo, handler, dataset_section, generation_section, results_section)
55
+
56
+ return demo
57
+
58
+
59
+ def create_dataset_section(handler) -> dict:
60
+ """Create dataset explorer section"""
61
+ with gr.Group():
62
+ gr.HTML('<div class="section-header"><h3>📊 Dataset Explorer</h3></div>')
63
+
64
+ with gr.Row(equal_height=True):
65
+ dataset_type = gr.Dropdown(
66
+ choices=["train", "test"],
67
+ value="train",
68
+ label="Dataset",
69
+ info="Choose dataset to explore",
70
+ scale=2
71
+ )
72
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
73
+
74
+ search_type = gr.Dropdown(
75
+ choices=["keys", "idx", "random"],
76
+ value="random",
77
+ label="Search Type",
78
+ info="How to find items",
79
+ scale=1
80
+ )
81
+ search_value = gr.Textbox(
82
+ label="Search Value",
83
+ placeholder="Enter keys or index (leave empty for random)",
84
+ info="Keys: exact match, Index: 0 to dataset size-1",
85
+ scale=2
86
+ )
87
+
88
+ instruction_display = gr.Textbox(
89
+ label="📝 Instruction",
90
+ interactive=False,
91
+ placeholder="No instruction available",
92
+ lines=1
93
+ )
94
+
95
+ repaint_viz_plot = gr.Plot()
96
+
97
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
98
+ item_info_json = gr.Code(
99
+ label="Complete Item Information",
100
+ language="json",
101
+ interactive=False,
102
+ lines=15
103
+ )
104
+
105
+ with gr.Row(equal_height=True):
106
+ item_src_audio = gr.Audio(
107
+ label="Source Audio",
108
+ type="filepath",
109
+ interactive=False,
110
+ scale=8
111
+ )
112
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
113
+
114
+ with gr.Row(equal_height=True):
115
+ item_target_audio = gr.Audio(
116
+ label="Target Audio",
117
+ type="filepath",
118
+ interactive=False,
119
+ scale=8
120
+ )
121
+ item_refer_audio = gr.Audio(
122
+ label="Reference Audio",
123
+ type="filepath",
124
+ interactive=False,
125
+ scale=2
126
+ )
127
+
128
+ with gr.Row():
129
+ use_src_checkbox = gr.Checkbox(
130
+ label="Use Source Audio from Dataset",
131
+ value=True,
132
+ info="Check to use the source audio from dataset"
133
+ )
134
+
135
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
136
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
137
+
138
+ return {
139
+ "dataset_type": dataset_type,
140
+ "import_dataset_btn": import_dataset_btn,
141
+ "search_type": search_type,
142
+ "search_value": search_value,
143
+ "instruction_display": instruction_display,
144
+ "repaint_viz_plot": repaint_viz_plot,
145
+ "item_info_json": item_info_json,
146
+ "item_src_audio": item_src_audio,
147
+ "get_item_btn": get_item_btn,
148
+ "item_target_audio": item_target_audio,
149
+ "item_refer_audio": item_refer_audio,
150
+ "use_src_checkbox": use_src_checkbox,
151
+ "data_status": data_status,
152
+ "auto_fill_btn": auto_fill_btn,
153
+ }
154
+
155
+
156
+ def create_generation_section(handler) -> dict:
157
+ """Create generation section"""
158
+ with gr.Group():
159
+ gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
160
+
161
+ # Service Configuration
162
+ with gr.Accordion("🔧 Service Configuration", open=True) as service_config_accordion:
163
+ with gr.Row():
164
+ with gr.Column(scale=2):
165
+ checkpoint_dropdown = gr.Dropdown(
166
+ label="Checkpoint File",
167
+ choices=handler.get_available_checkpoints(),
168
+ value=None,
169
+ info="Select a trained model checkpoint file (full path or filename)"
170
+ )
171
+ with gr.Column(scale=1):
172
+ refresh_btn = gr.Button("🔄 Refresh", size="sm")
173
+
174
+ with gr.Row():
175
+ # Get available acestep-v15- model list
176
+ available_models = handler.get_available_acestep_v15_models()
177
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
178
+
179
+ config_path = gr.Dropdown(
180
+ label="Main Model Path",
181
+ choices=available_models,
182
+ value=default_model,
183
+ info="Select the model configuration directory (auto-scanned from checkpoints)"
184
+ )
185
+ device = gr.Dropdown(
186
+ choices=["auto", "cuda", "cpu"],
187
+ value="auto",
188
+ label="Device",
189
+ info="Processing device (auto-detect recommended)"
190
+ )
191
+
192
+ with gr.Row():
193
+ # Get available 5Hz LM model list
194
+ available_lm_models = handler.get_available_5hz_lm_models()
195
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
196
+
197
+ lm_model_path = gr.Dropdown(
198
+ label="5Hz LM Model Path",
199
+ choices=available_lm_models,
200
+ value=default_lm_model,
201
+ info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
202
+ )
203
+ init_llm_checkbox = gr.Checkbox(
204
+ label="Initialize 5Hz LM",
205
+ value=False,
206
+ info="Check to initialize 5Hz LM during service initialization"
207
+ )
208
+
209
+ with gr.Row():
210
+ # Auto-detect flash attention availability
211
+ flash_attn_available = handler.is_flash_attention_available()
212
+ use_flash_attention_checkbox = gr.Checkbox(
213
+ label="Use Flash Attention",
214
+ value=flash_attn_available,
215
+ interactive=flash_attn_available,
216
+ info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
217
+ )
218
+
219
+ init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
220
+ init_status = gr.Textbox(label="Status", interactive=False, lines=3)
221
+
222
+ # Inputs
223
+ with gr.Row():
224
+ with gr.Column(scale=2):
225
+ with gr.Accordion("📝 Required Inputs", open=True):
226
+ # Task type
227
+ with gr.Row():
228
+ with gr.Column(scale=2):
229
+ task_type = gr.Dropdown(
230
+ choices=["text2music", "repaint", "cover", "extract", "lego", "complete"],
231
+ value="text2music",
232
+ label="Task Type",
233
+ info="Select the task type for generation",
234
+ )
235
+ with gr.Column(scale=8):
236
+ instruction_display_gen = gr.Textbox(
237
+ label="Instruction",
238
+ value="Fill the audio semantic mask based on the given conditions:",
239
+ interactive=False,
240
+ lines=1,
241
+ info="Instruction is automatically generated based on task type",
242
+ )
243
+
244
+ track_name = gr.Dropdown(
245
+ choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
246
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
247
+ value=None,
248
+ label="Track Name",
249
+ info="Select track name for lego/extract tasks",
250
+ visible=False
251
+ )
252
+
253
+ complete_track_classes = gr.CheckboxGroup(
254
+ choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
255
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
256
+ label="Track Names",
257
+ info="Select multiple track classes for complete task",
258
+ visible=False
259
+ )
260
+
261
+ # Audio uploads
262
+ with gr.Accordion("🎵 Audio Uploads", open=False):
263
+ with gr.Row():
264
+ with gr.Column(scale=2):
265
+ reference_audio = gr.Audio(
266
+ label="Reference Audio (optional)",
267
+ type="filepath",
268
+ )
269
+ with gr.Column(scale=8):
270
+ src_audio = gr.Audio(
271
+ label="Source Audio (optional)",
272
+ type="filepath",
273
+ )
274
+
275
+ audio_code_string = gr.Textbox(
276
+ label="Audio Codes (optional)",
277
+ placeholder="<|audio_code_10695|><|audio_code_54246|>...",
278
+ lines=4,
279
+ visible=False,
280
+ info="Paste precomputed audio code tokens"
281
+ )
282
+
283
+ # Audio Codes for text2music
284
+ with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
285
+ text2music_audio_code_string = gr.Textbox(
286
+ label="Audio Codes",
287
+ placeholder="<|audio_code_10695|><|audio_code_54246|>...",
288
+ lines=6,
289
+ info="Paste precomputed audio code tokens for text2music generation"
290
+ )
291
+
292
+ # 5Hz LM
293
+ with gr.Row(visible=False) as use_5hz_lm_row:
294
+ use_5hz_lm_btn = gr.Button(
295
+ "Generate LM Hints",
296
+ variant="secondary",
297
+ size="lg",
298
+ )
299
+ lm_temperature = gr.Slider(
300
+ label="Temperature",
301
+ minimum=0.0,
302
+ maximum=2.0,
303
+ value=0.7,
304
+ step=0.1,
305
+ scale=2,
306
+ info="Temperature for 5Hz LM sampling"
307
+ )
308
+
309
+ # Repainting controls
310
+ with gr.Group(visible=False) as repainting_group:
311
+ gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
312
+ with gr.Row():
313
+ repainting_start = gr.Number(
314
+ label="Repainting Start",
315
+ value=0.0,
316
+ step=0.1,
317
+ )
318
+ repainting_end = gr.Number(
319
+ label="Repainting End",
320
+ value=-1,
321
+ minimum=-1,
322
+ step=0.1,
323
+ )
324
+
325
+ # Audio Cover Strength
326
+ audio_cover_strength = gr.Slider(
327
+ minimum=0.0,
328
+ maximum=1.0,
329
+ value=1.0,
330
+ step=0.01,
331
+ label="Audio Cover Strength",
332
+ info="Control how many denoising steps use cover mode",
333
+ visible=False
334
+ )
335
+
336
+ # Music Caption
337
+ with gr.Accordion("📝 Music Caption", open=True):
338
+ captions = gr.Textbox(
339
+ label="Music Caption (optional)",
340
+ placeholder="A peaceful acoustic guitar melody with soft vocals...",
341
+ lines=3,
342
+ info="Describe the style, genre, instruments, and mood"
343
+ )
344
+
345
+ # Lyrics
346
+ with gr.Accordion("📝 Lyrics", open=True):
347
+ lyrics = gr.Textbox(
348
+ label="Lyrics (optional)",
349
+ placeholder="[Verse 1]\nUnder the starry night\nI feel so alive...",
350
+ lines=8,
351
+ info="Song lyrics with structure"
352
+ )
353
+
354
+ # Optional Parameters
355
+ with gr.Accordion("⚙️ Optional Parameters", open=True):
356
+ with gr.Row():
357
+ vocal_language = gr.Dropdown(
358
+ choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
359
+ value="en",
360
+ label="Vocal Language (optional)",
361
+ allow_custom_value=True
362
+ )
363
+ bpm = gr.Number(
364
+ label="BPM (optional)",
365
+ value=None,
366
+ step=1,
367
+ info="leave empty for N/A"
368
+ )
369
+ key_scale = gr.Textbox(
370
+ label="Key/Scale (optional)",
371
+ placeholder="Leave empty for N/A",
372
+ value="",
373
+ )
374
+ time_signature = gr.Dropdown(
375
+ choices=["2", "3", "4", "N/A", ""],
376
+ value="4",
377
+ label="Time Signature (optional)",
378
+ allow_custom_value=True
379
+ )
380
+ audio_duration = gr.Number(
381
+ label="Audio Duration (seconds)",
382
+ value=-1,
383
+ minimum=-1,
384
+ maximum=600.0,
385
+ step=0.1,
386
+ info="Use -1 for random"
387
+ )
388
+ batch_size_input = gr.Number(
389
+ label="Batch Size",
390
+ value=1,
391
+ minimum=1,
392
+ maximum=8,
393
+ step=1,
394
+ info="Number of audio files to parallel generate"
395
+ )
396
+
397
+ # Advanced Settings
398
+ with gr.Accordion("🔧 Advanced Settings", open=False):
399
+ with gr.Row():
400
+ inference_steps = gr.Slider(
401
+ minimum=1,
402
+ maximum=8,
403
+ value=8,
404
+ step=1,
405
+ label="Inference Steps",
406
+ info="Turbo: max 8, Base: max 100"
407
+ )
408
+ guidance_scale = gr.Slider(
409
+ minimum=1.0,
410
+ maximum=15.0,
411
+ value=7.0,
412
+ step=0.1,
413
+ label="Guidance Scale",
414
+ info="Higher values follow text more closely",
415
+ visible=False
416
+ )
417
+ seed = gr.Textbox(
418
+ label="Seed",
419
+ value="-1",
420
+ info="Use comma-separated values for batches"
421
+ )
422
+ random_seed_checkbox = gr.Checkbox(
423
+ label="Random Seed",
424
+ value=True,
425
+ info="Enable to auto-generate seeds"
426
+ )
427
+
428
+ with gr.Row():
429
+ use_adg = gr.Checkbox(
430
+ label="Use ADG",
431
+ value=False,
432
+ info="Enable Angle Domain Guidance",
433
+ visible=False
434
+ )
435
+
436
+ with gr.Row():
437
+ cfg_interval_start = gr.Slider(
438
+ minimum=0.0,
439
+ maximum=1.0,
440
+ value=0.0,
441
+ step=0.01,
442
+ label="CFG Interval Start",
443
+ visible=False
444
+ )
445
+ cfg_interval_end = gr.Slider(
446
+ minimum=0.0,
447
+ maximum=1.0,
448
+ value=1.0,
449
+ step=0.01,
450
+ label="CFG Interval End",
451
+ visible=False
452
+ )
453
+
454
+ with gr.Row():
455
+ audio_format = gr.Dropdown(
456
+ choices=["mp3", "flac"],
457
+ value="mp3",
458
+ label="Audio Format",
459
+ info="Audio format for saved files"
460
+ )
461
+
462
+ generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=False)
463
+
464
+ return {
465
+ "checkpoint_dropdown": checkpoint_dropdown,
466
+ "refresh_btn": refresh_btn,
467
+ "config_path": config_path,
468
+ "device": device,
469
+ "init_btn": init_btn,
470
+ "init_status": init_status,
471
+ "lm_model_path": lm_model_path,
472
+ "init_llm_checkbox": init_llm_checkbox,
473
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
474
+ "task_type": task_type,
475
+ "instruction_display_gen": instruction_display_gen,
476
+ "track_name": track_name,
477
+ "complete_track_classes": complete_track_classes,
478
+ "reference_audio": reference_audio,
479
+ "src_audio": src_audio,
480
+ "audio_code_string": audio_code_string,
481
+ "text2music_audio_code_string": text2music_audio_code_string,
482
+ "text2music_audio_codes_group": text2music_audio_codes_group,
483
+ "use_5hz_lm_row": use_5hz_lm_row,
484
+ "use_5hz_lm_btn": use_5hz_lm_btn,
485
+ "lm_temperature": lm_temperature,
486
+ "repainting_group": repainting_group,
487
+ "repainting_start": repainting_start,
488
+ "repainting_end": repainting_end,
489
+ "audio_cover_strength": audio_cover_strength,
490
+ "captions": captions,
491
+ "lyrics": lyrics,
492
+ "vocal_language": vocal_language,
493
+ "bpm": bpm,
494
+ "key_scale": key_scale,
495
+ "time_signature": time_signature,
496
+ "audio_duration": audio_duration,
497
+ "batch_size_input": batch_size_input,
498
+ "inference_steps": inference_steps,
499
+ "guidance_scale": guidance_scale,
500
+ "seed": seed,
501
+ "random_seed_checkbox": random_seed_checkbox,
502
+ "use_adg": use_adg,
503
+ "cfg_interval_start": cfg_interval_start,
504
+ "cfg_interval_end": cfg_interval_end,
505
+ "audio_format": audio_format,
506
+ "generate_btn": generate_btn,
507
+ }
508
+
509
+
510
+ def create_results_section(handler) -> dict:
511
+ """Create results display section"""
512
+ with gr.Group():
513
+ gr.HTML('<div class="section-header"><h3>🎧 Generated Results</h3></div>')
514
+
515
+ status_output = gr.Textbox(label="Generation Status", interactive=False)
516
+
517
+ with gr.Row():
518
+ with gr.Column():
519
+ generated_audio_1 = gr.Audio(
520
+ label="🎵 Generated Music (Sample 1)",
521
+ type="filepath",
522
+ interactive=False
523
+ )
524
+ with gr.Column():
525
+ generated_audio_2 = gr.Audio(
526
+ label="🎵 Generated Music (Sample 2)",
527
+ type="filepath",
528
+ interactive=False
529
+ )
530
+
531
+ with gr.Accordion("📁 Batch Results & Generation Details", open=False):
532
+ generated_audio_batch = gr.File(
533
+ label="📁 All Generated Files (Download)",
534
+ file_count="multiple",
535
+ interactive=False
536
+ )
537
+ generation_info = gr.Markdown(label="Generation Details")
538
+
539
+ gr.Markdown("### ⚖️ Alignment Preference Analysis")
540
+
541
+ with gr.Row():
542
+ with gr.Column():
543
+ align_score_1 = gr.Textbox(label="Alignment Score (Sample 1)", interactive=False)
544
+ align_text_1 = gr.Textbox(label="Lyric Timestamps (Sample 1)", interactive=False, lines=10)
545
+ align_plot_1 = gr.Plot(label="Alignment Heatmap (Sample 1)")
546
+ with gr.Column():
547
+ align_score_2 = gr.Textbox(label="Alignment Score (Sample 2)", interactive=False)
548
+ align_text_2 = gr.Textbox(label="Lyric Timestamps (Sample 2)", interactive=False, lines=10)
549
+ align_plot_2 = gr.Plot(label="Alignment Heatmap (Sample 2)")
550
+
551
+ return {
552
+ "status_output": status_output,
553
+ "generated_audio_1": generated_audio_1,
554
+ "generated_audio_2": generated_audio_2,
555
+ "generated_audio_batch": generated_audio_batch,
556
+ "generation_info": generation_info,
557
+ "align_score_1": align_score_1,
558
+ "align_text_1": align_text_1,
559
+ "align_plot_1": align_plot_1,
560
+ "align_score_2": align_score_2,
561
+ "align_text_2": align_text_2,
562
+ "align_plot_2": align_plot_2,
563
+ }
564
+
565
+
566
+ def setup_event_handlers(demo, handler, dataset_section, generation_section, results_section):
567
+ """Setup event handlers connecting UI components and business logic"""
568
+
569
+ def update_init_status(status_msg, enable_btn):
570
+ """Update initialization status and enable/disable generate button"""
571
+ return status_msg, gr.update(interactive=enable_btn)
572
+
573
+ # Dataset handlers
574
+ dataset_section["import_dataset_btn"].click(
575
+ fn=handler.import_dataset,
576
+ inputs=[dataset_section["dataset_type"]],
577
+ outputs=[dataset_section["data_status"]]
578
+ )
579
+
580
+ # Service initialization - refresh checkpoints
581
+ def refresh_checkpoints():
582
+ choices = handler.get_available_checkpoints()
583
+ return gr.update(choices=choices)
584
+
585
+ generation_section["refresh_btn"].click(
586
+ fn=refresh_checkpoints,
587
+ outputs=[generation_section["checkpoint_dropdown"]]
588
+ )
589
+
590
+ # Update UI based on model type (turbo vs base)
591
+ def update_model_type_settings(config_path):
592
+ """Update UI settings based on model type"""
593
+ if config_path is None:
594
+ config_path = ""
595
+ config_path_lower = config_path.lower()
596
+
597
+ if "turbo" in config_path_lower:
598
+ # Turbo model: max 8 steps, hide CFG/ADG
599
+ return (
600
+ gr.update(value=8, maximum=8, minimum=1), # inference_steps
601
+ gr.update(visible=False), # guidance_scale
602
+ gr.update(visible=False), # use_adg
603
+ gr.update(visible=False), # cfg_interval_start
604
+ gr.update(visible=False), # cfg_interval_end
605
+ )
606
+ elif "base" in config_path_lower:
607
+ # Base model: max 100 steps, show CFG/ADG
608
+ return (
609
+ gr.update(value=32, maximum=100, minimum=1), # inference_steps
610
+ gr.update(visible=True), # guidance_scale
611
+ gr.update(visible=True), # use_adg
612
+ gr.update(visible=True), # cfg_interval_start
613
+ gr.update(visible=True), # cfg_interval_end
614
+ )
615
+ else:
616
+ # Default to turbo settings
617
+ return (
618
+ gr.update(value=8, maximum=8, minimum=1),
619
+ gr.update(visible=False),
620
+ gr.update(visible=False),
621
+ gr.update(visible=False),
622
+ gr.update(visible=False),
623
+ )
624
+
625
+ generation_section["config_path"].change(
626
+ fn=update_model_type_settings,
627
+ inputs=[generation_section["config_path"]],
628
+ outputs=[
629
+ generation_section["inference_steps"],
630
+ generation_section["guidance_scale"],
631
+ generation_section["use_adg"],
632
+ generation_section["cfg_interval_start"],
633
+ generation_section["cfg_interval_end"],
634
+ ]
635
+ )
636
+
637
+ # Service initialization
638
+ def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention):
639
+ """Wrapper for service initialization, returns status and button state"""
640
+ status, enable = handler.initialize_service(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention)
641
+ return status, gr.update(interactive=enable)
642
+
643
+ generation_section["init_btn"].click(
644
+ fn=init_service_wrapper,
645
+ inputs=[
646
+ generation_section["checkpoint_dropdown"],
647
+ generation_section["config_path"],
648
+ generation_section["device"],
649
+ generation_section["init_llm_checkbox"],
650
+ generation_section["lm_model_path"],
651
+ generation_section["use_flash_attention_checkbox"],
652
+ ],
653
+ outputs=[generation_section["init_status"], generation_section["generate_btn"]]
654
+ )
655
+
656
+ # Generation with progress bar
657
+ def generate_with_progress(
658
+ captions, lyrics, bpm, key_scale, time_signature, vocal_language,
659
+ inference_steps, guidance_scale, random_seed_checkbox, seed,
660
+ reference_audio, audio_duration, batch_size_input, src_audio,
661
+ text2music_audio_code_string, repainting_start, repainting_end,
662
+ instruction_display_gen, audio_cover_strength, task_type,
663
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
664
+ progress=gr.Progress(track_tqdm=True)
665
+ ):
666
+ return handler.generate_music(
667
+ captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
668
+ time_signature=time_signature, vocal_language=vocal_language,
669
+ inference_steps=inference_steps, guidance_scale=guidance_scale,
670
+ use_random_seed=random_seed_checkbox, seed=seed,
671
+ reference_audio=reference_audio, audio_duration=audio_duration,
672
+ batch_size=batch_size_input, src_audio=src_audio,
673
+ audio_code_string=text2music_audio_code_string,
674
+ repainting_start=repainting_start, repainting_end=repainting_end,
675
+ instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
676
+ task_type=task_type, use_adg=use_adg,
677
+ cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
678
+ audio_format=audio_format, lm_temperature=lm_temperature,
679
+ progress=progress
680
+ )
681
+
682
+ generation_section["generate_btn"].click(
683
+ fn=generate_with_progress,
684
+ inputs=[
685
+ generation_section["captions"],
686
+ generation_section["lyrics"],
687
+ generation_section["bpm"],
688
+ generation_section["key_scale"],
689
+ generation_section["time_signature"],
690
+ generation_section["vocal_language"],
691
+ generation_section["inference_steps"],
692
+ generation_section["guidance_scale"],
693
+ generation_section["random_seed_checkbox"],
694
+ generation_section["seed"],
695
+ generation_section["reference_audio"],
696
+ generation_section["audio_duration"],
697
+ generation_section["batch_size_input"],
698
+ generation_section["src_audio"],
699
+ generation_section["text2music_audio_code_string"],
700
+ generation_section["repainting_start"],
701
+ generation_section["repainting_end"],
702
+ generation_section["instruction_display_gen"],
703
+ generation_section["audio_cover_strength"],
704
+ generation_section["task_type"],
705
+ generation_section["use_adg"],
706
+ generation_section["cfg_interval_start"],
707
+ generation_section["cfg_interval_end"],
708
+ generation_section["audio_format"],
709
+ generation_section["lm_temperature"]
710
+ ],
711
+ outputs=[
712
+ results_section["generated_audio_1"],
713
+ results_section["generated_audio_2"],
714
+ results_section["generated_audio_batch"],
715
+ results_section["generation_info"],
716
+ results_section["status_output"],
717
+ generation_section["seed"],
718
+ results_section["align_score_1"],
719
+ results_section["align_text_1"],
720
+ results_section["align_plot_1"],
721
+ results_section["align_score_2"],
722
+ results_section["align_text_2"],
723
+ results_section["align_plot_2"]
724
+ ]
725
+ )
726
+
727
+ # 5Hz LM generation (simplified version, can be extended as needed)
728
+ def generate_lm_hints_wrapper(caption, lyrics, temperature):
729
+ """Wrapper for 5Hz LM generation"""
730
+ metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature)
731
+ # 返回格式化的结果,可以根据需要调整
732
+ result_text = f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes[:200]}..." if len(audio_codes) > 200 else f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes}"
733
+ return result_text
734
+
735
+ generation_section["use_5hz_lm_btn"].click(
736
+ fn=generate_lm_hints_wrapper,
737
+ inputs=[
738
+ generation_section["captions"],
739
+ generation_section["lyrics"],
740
+ generation_section["lm_temperature"]
741
+ ],
742
+ outputs=[generation_section["text2music_audio_code_string"]]
743
+ )
744
+
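
Before the full handler below, it may help to see the surface that `create_gradio_interface()` and `setup_event_handlers()` actually rely on. The stub below is a hypothetical stand-in (not part of the commit) that mirrors the method names and return shapes called from the UI code above, so the layout can be smoke-tested without loading any models:

```python
# Hypothetical stand-in for AceStepHandler (illustration only): it mirrors the
# methods and return shapes that create_gradio_interface()/setup_event_handlers()
# call, so the UI can be exercised without model weights.
class StubHandler:
    def get_available_checkpoints(self):
        return []  # dropdown choices

    def get_available_acestep_v15_models(self):
        return ["acestep-v15-turbo"]

    def get_available_5hz_lm_models(self):
        return ["acestep-5Hz-lm-0.6B"]

    def is_flash_attention_available(self):
        return False

    def import_dataset(self, dataset_type):
        return f"Imported {dataset_type} (stub)"  # -> data_status textbox

    def initialize_service(self, checkpoint, config_path, device,
                           init_llm, lm_model_path, use_flash_attention):
        return "Service initialized (stub)", True  # status text, enable generate button

    def generate_with_5hz_lm(self, caption, lyrics, temperature):
        return "{}", "", "ok"  # metadata, audio_code_string, status

    def generate_music(self, **kwargs):
        # Must match the 12 outputs wired to the generate button.
        return (None, None, [], "stub details", "done", kwargs.get("seed", "-1"),
                "", "", None, "", "", None)
```

A quick smoke test would then be `create_gradio_interface(StubHandler()).launch()`.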
acestep/handler.py ADDED
@@ -0,0 +1,1100 @@
1
+ """
2
+ Business Logic Handler
3
+ Encapsulates all data processing and business logic as a bridge between model and UI
4
+ """
5
+ import os
6
+ import math
7
+ import glob
8
+ import tempfile
9
+ import traceback
10
+ import re
11
+ import random
12
+ from typing import Optional, Dict, Any, Tuple, List, Union
13
+
14
+ import torch
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import scipy.io.wavfile as wavfile
18
+ import soundfile as sf
19
+ import time
20
+
21
+ from transformers import AutoTokenizer, AutoModel
22
+ from diffusers.models import AutoencoderOobleck
23
+
24
+
25
+ class AceStepHandler:
26
+ """ACE-Step Business Logic Handler"""
27
+
28
+ def __init__(self):
29
+ self.model = None
30
+ self.config = None
31
+ self.device = "cpu"
32
+ self.dtype = torch.float32 # Will be set based on device in initialize_service
33
+ self.temp_dir = tempfile.mkdtemp()
34
+
35
+ # VAE for audio encoding/decoding
36
+ self.vae = None
37
+
38
+ # Text encoder and tokenizer
39
+ self.text_encoder = None
40
+ self.text_tokenizer = None
41
+
42
+ # Silence latent for initialization
43
+ self.silence_latent = None
44
+
45
+ # Sample rate
46
+ self.sample_rate = 48000
47
+
48
+ # 5Hz LM related
49
+ self.lm_model = None
50
+ self.lm_tokenizer = None
51
+ self.lm_initialized = False
52
+
53
+ # Reward model (temporarily disabled)
54
+ self.reward_model = None
55
+
56
+ # Dataset related (temporarily disabled)
57
+ self.dataset = None
58
+ self.dataset_imported = False
59
+
60
+ # Batch size
61
+ self.batch_size = 2
62
+
63
+ # Custom layers config
64
+ self.custom_layers_config = {
65
+ 2: [6, 7],
66
+ 3: [10, 11],
67
+ 4: [3],
68
+ 5: [8, 9, 11],
69
+ 6: [8]
70
+ }
71
+
72
+ def get_available_checkpoints(self) -> List[str]:
73
+ """Return available checkpoint directory paths"""
74
+ # Get project root (handler.py is in acestep/, so go up two levels to project root)
75
+ current_file = os.path.abspath(__file__)
76
+ project_root = os.path.dirname(os.path.dirname(current_file))
77
+ # default checkpoints
78
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
79
+ if os.path.exists(checkpoint_dir):
80
+ return [checkpoint_dir]
81
+ else:
82
+ return []
83
+
84
+ def get_available_acestep_v15_models(self) -> List[str]:
85
+ """Scan and return all model directory names starting with 'acestep-v15-'"""
86
+ # Get project root
87
+ current_file = os.path.abspath(__file__)
88
+ project_root = os.path.dirname(os.path.dirname(current_file))
89
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
90
+
91
+ models = []
92
+ if os.path.exists(checkpoint_dir):
93
+ # Scan all directories starting with 'acestep-v15-' in checkpoints folder
94
+ for item in os.listdir(checkpoint_dir):
95
+ item_path = os.path.join(checkpoint_dir, item)
96
+ if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
97
+ models.append(item)
98
+
99
+ # Sort by name
100
+ models.sort()
101
+ return models
102
+
103
+ def get_available_5hz_lm_models(self) -> List[str]:
104
+ """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
105
+ current_file = os.path.abspath(__file__)
106
+ project_root = os.path.dirname(os.path.dirname(current_file))
107
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
108
+
109
+ models = []
110
+ if os.path.exists(checkpoint_dir):
111
+ for item in os.listdir(checkpoint_dir):
112
+ item_path = os.path.join(checkpoint_dir, item)
113
+ if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
114
+ models.append(item)
115
+
116
+ models.sort()
117
+ return models
118
+
119
+ def is_flash_attention_available(self) -> bool:
120
+ """Check if flash attention is available on the system"""
121
+ try:
122
+ import flash_attn
123
+ return True
124
+ except ImportError:
125
+ return False
126
+
127
+ def initialize_service(
128
+ self,
129
+ project_root: str,
130
+ config_path: str,
131
+ device: str = "auto",
132
+ init_llm: bool = False,
133
+ lm_model_path: str = "acestep-5Hz-lm-0.6B",
134
+ use_flash_attention: bool = False,
135
+ ) -> Tuple[str, bool]:
136
+ """
137
+ Initialize model service
138
+
139
+ Args:
140
+ project_root: Project root path (may be checkpoints directory, will be handled automatically)
141
+ config_path: Model config directory name (e.g., "acestep-v15-turbo")
142
+ device: Device type
143
+ init_llm: Whether to initialize 5Hz LM model
144
+ lm_model_path: 5Hz LM model path
145
+ use_flash_attention: Whether to use flash attention (requires flash_attn package)
146
+
147
+ Returns:
148
+ (status_message, enable_generate_button)
149
+ """
150
+ try:
151
+ if device == "auto":
152
+ device = "cuda" if torch.cuda.is_available() else "cpu"
153
+
154
+ self.device = device
155
+ # Set dtype based on device: bfloat16 for cuda, float32 for cpu
156
+ self.dtype = torch.bfloat16 if device == "cuda" else torch.float32
157
+
158
+ # Auto-detect project root (independent of passed project_root parameter)
159
+ current_file = os.path.abspath(__file__)
160
+ actual_project_root = os.path.dirname(os.path.dirname(current_file))
161
+ checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
162
+
163
+ # 1. Load main model
164
+ # config_path is relative path (e.g., "acestep-v15-turbo"), concatenate to checkpoints directory
165
+ acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
166
+ if os.path.exists(acestep_v15_checkpoint_path):
167
+ # Determine attention implementation
168
+ attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
169
+ self.model = AutoModel.from_pretrained(
170
+ acestep_v15_checkpoint_path,
171
+ trust_remote_code=True,
172
+ attn_implementation=attn_implementation
173
+ )
174
+ self.config = self.model.config
175
+ # Move model to device and set dtype
176
+ self.model = self.model.to(device).to(self.dtype)
177
+ self.model.eval()
178
+ silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
179
+ if os.path.exists(silence_latent_path):
180
+ self.silence_latent = torch.load(silence_latent_path).transpose(1, 2).squeeze(0) # [L, C]
181
+ self.silence_latent = self.silence_latent.to(device).to(self.dtype)
182
+ else:
183
+ raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
184
+ else:
185
+ raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")
186
+
187
+ # 2. Load VAE
188
+ vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
189
+ if os.path.exists(vae_checkpoint_path):
190
+ self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
191
+ self.vae = self.vae.to(device).to(self.dtype)
192
+ self.vae.eval()
193
+ else:
194
+ raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
195
+
196
+ # 3. Load text encoder and tokenizer
197
+ text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
198
+ if os.path.exists(text_encoder_path):
199
+ self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
200
+ self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
201
+ self.text_encoder = self.text_encoder.to(device).to(self.dtype)
202
+ self.text_encoder.eval()
203
+ else:
204
+ raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
205
+
206
+ # 4. Load 5Hz LM model (optional, only if init_llm is True)
207
+ if init_llm:
208
+ full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
209
+ if os.path.exists(full_lm_model_path):
210
+ if device == "cuda":
211
+ status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
212
+ if not self.llm_initialized:
213
+ return status_msg, False
214
+ else:
+ # Fallback for non-CUDA devices: load the LM with transformers (no nanovllm acceleration)
+ self.llm = AutoModel.from_pretrained(full_lm_model_path)
216
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path)
216
+ else:
217
+ # 5Hz LM path not found
218
+ return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
219
+
220
+ # Determine actual attention implementation used
221
+ actual_attn = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
222
+
223
+ status_msg = f"✅ Model initialized successfully on {device}\n"
224
+ status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
225
+ status_msg += f"VAE: {vae_checkpoint_path}\n"
226
+ status_msg += f"Text encoder: {text_encoder_path}\n"
227
+ if init_llm and hasattr(self, 'llm') and self.llm is not None:
228
+ status_msg += f"5Hz LM model: {os.path.join(checkpoint_dir, lm_model_path)}\n"
229
+ else:
230
+ status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
231
+ status_msg += f"Dtype: {self.dtype}\n"
232
+ status_msg += f"Attention: {actual_attn}"
233
+
234
+ return status_msg, True
235
+
236
+ except Exception as e:
237
+ error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
238
+ return error_msg, False
239
+
240
+ def import_dataset(self, dataset_type: str) -> str:
241
+ """Import dataset (temporarily disabled)"""
242
+ self.dataset_imported = False
243
+ return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
244
+
245
+ def get_item_data(self, *args, **kwargs):
246
+ """Get dataset item (temporarily disabled)"""
247
+ return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
248
+
249
+ def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> float:
250
+ """Get GPU memory utilization ratio"""
251
+ try:
252
+ device = torch.device("cuda:0")
253
+ total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
254
+ allocated_mem_bytes = torch.cuda.memory_allocated(device)
255
+ reserved_mem_bytes = torch.cuda.memory_reserved(device)
256
+
257
+ total_gpu = total_gpu_mem_bytes / 1024**3
258
+ allocated_gpu = allocated_mem_bytes / 1024**3
259
+ reserved_gpu = reserved_mem_bytes / 1024**3
260
+ available_gpu = total_gpu - reserved_gpu
261
+
262
+ if available_gpu >= minimal_gpu:
263
+ ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu))
264
+ else:
265
+ ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))
266
+
267
+ return ratio
268
+ except Exception as e:
269
+ return 0.9
270
+
271
+ def _initialize_5hz_lm_cuda(self, model_path: str) -> str:
272
+ """Initialize 5Hz LM model"""
273
+ try:
274
+ from nanovllm import LLM, SamplingParams
275
+
276
+ if not torch.cuda.is_available():
277
+ return "❌ CUDA is not available. Please check your GPU setup."
278
+
279
+ current_device = torch.cuda.current_device()
280
+ device_name = torch.cuda.get_device_name(current_device)
281
+
282
+ torch.cuda.empty_cache()
283
+ gpu_memory_utilization = self.get_gpu_memory_utilization(
284
+ minimal_gpu=8,
285
+ min_ratio=0.2,
286
+ max_ratio=0.9
287
+ )
288
+
289
+ self.llm = LLM(
290
+ model=model_path,
291
+ enforce_eager=False,
292
+ tensor_parallel_size=1,
293
+ max_model_len=4096,
294
+ gpu_memory_utilization=gpu_memory_utilization,
295
+ )
296
+ self.llm_tokenizer = self.llm.tokenizer
297
+ self.llm_initialized = True
298
+ return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
299
+ except Exception as e:
300
+ self.llm_initialized = False
301
+ error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
302
+ return error_msg
303
+
304
+ def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
305
+ """Generate metadata and audio codes using 5Hz LM"""
306
+ if not self.llm_initialized or self.llm is None:
307
+ return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."
308
+
309
+ try:
310
+ from nanovllm import SamplingParams
311
+
312
+ prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
313
+
314
+ formatted_prompt = self.llm_tokenizer.apply_chat_template(
315
+ [
316
+ {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
317
+ {"role": "user", "content": prompt}
318
+ ],
319
+ tokenize=False,
320
+ add_generation_prompt=True,
321
+ )
322
+
323
+ sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
324
+ outputs = self.llm.generate([formatted_prompt], sampling_params)
325
+
326
+ if isinstance(outputs, list) and len(outputs) > 0:
327
+ if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
328
+ output_text = outputs[0].outputs[0].text
329
+ elif hasattr(outputs[0], 'text'):
330
+ output_text = outputs[0].text
331
+ else:
332
+ output_text = str(outputs[0])
333
+ else:
334
+ output_text = str(outputs)
335
+
336
+ metadata, audio_codes = self.parse_lm_output(output_text)
337
+ codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
338
+ return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
339
+
340
+ except Exception as e:
341
+ error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
342
+ return {}, "", error_msg
343
+
344
+ def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
345
+ """Parse LM output"""
346
+ metadata = {}
347
+ audio_codes = ""
348
+
349
+ import re
350
+
351
+ # Extract audio codes
352
+ code_pattern = r'<\|audio_code_\d+\|>'
353
+ code_matches = re.findall(code_pattern, output_text)
354
+ if code_matches:
355
+ audio_codes = "".join(code_matches)
356
+
357
+ # Extract metadata
358
+ reasoning_patterns = [
359
+ r'<think>(.*?)</think>',
360
+ r'<reasoning>(.*?)</reasoning>',
361
+ ]
362
+
363
+ reasoning_text = None
364
+ for pattern in reasoning_patterns:
365
+ match = re.search(pattern, output_text, re.DOTALL)
366
+ if match:
367
+ reasoning_text = match.group(1).strip()
368
+ break
369
+
370
+ if not reasoning_text:
371
+ lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
372
+ reasoning_text = lines_before_codes.strip()
373
+
374
+ # Parse metadata fields
375
+ if reasoning_text:
376
+ for line in reasoning_text.split('\n'):
377
+ line = line.strip()
378
+ if ':' in line and not line.startswith('<'):
379
+ parts = line.split(':', 1)
380
+ if len(parts) == 2:
381
+ key = parts[0].strip().lower()
382
+ value = parts[1].strip()
383
+
384
+ if key == 'bpm':
385
+ try:
386
+ metadata['bpm'] = int(value)
387
+ except:
388
+ metadata['bpm'] = value
389
+ elif key == 'duration':
390
+ try:
391
+ metadata['duration'] = int(value)
392
+ except:
393
+ metadata['duration'] = value
394
+ elif key in ['genres', 'keyscale', 'timesignature']:
395
+ metadata[key] = value
396
+
397
+ return metadata, audio_codes
398
+
399
+ def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
400
+ """Process reference audio"""
401
+ if audio_file is None:
402
+ return None
403
+
404
+ try:
405
+ # Load audio using soundfile
406
+ audio_np, sr = sf.read(audio_file, dtype='float32')
407
+ # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
408
+ if audio_np.ndim == 1:
409
+ audio = torch.from_numpy(audio_np).unsqueeze(0)
410
+ else:
411
+ audio = torch.from_numpy(audio_np.T)
412
+
413
+ if audio.shape[0] == 1:
414
+ audio = torch.cat([audio, audio], dim=0)
415
+
416
+ audio = audio[:2]
417
+
418
+ # Resample if needed
419
+ if sr != 48000:
420
+ import torch.nn.functional as F
421
+ # Simple resampling using interpolate
422
+ ratio = 48000 / sr
423
+ new_length = int(audio.shape[-1] * ratio)
424
+ audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
425
+
426
+ audio = torch.clamp(audio, -1.0, 1.0)
427
+
428
+ target_frames = 30 * 48000
429
+ if audio.shape[-1] > target_frames:
430
+ start_frame = (audio.shape[-1] - target_frames) // 2
431
+ audio = audio[:, start_frame:start_frame + target_frames]
432
+ elif audio.shape[-1] < target_frames:
433
+ audio = torch.nn.functional.pad(
434
+ audio, (0, target_frames - audio.shape[-1]), 'constant', 0
435
+ )
436
+
437
+ return audio
438
+ except Exception as e:
439
+ print(f"Error processing reference audio: {e}")
440
+ return None
441
+
442
+ def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
443
+ """Process target audio"""
444
+ if audio_file is None:
445
+ return None
446
+
447
+ try:
448
+ # Load audio using soundfile
449
+ audio_np, sr = sf.read(audio_file, dtype='float32')
450
+ # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
451
+ if audio_np.ndim == 1:
452
+ audio = torch.from_numpy(audio_np).unsqueeze(0)
453
+ else:
454
+ audio = torch.from_numpy(audio_np.T)
455
+
456
+ if audio.shape[0] == 1:
457
+ audio = torch.cat([audio, audio], dim=0)
458
+
459
+ audio = audio[:2]
460
+
461
+ # Resample if needed
462
+ if sr != 48000:
463
+ import torch.nn.functional as F
464
+ ratio = 48000 / sr
465
+ new_length = int(audio.shape[-1] * ratio)
466
+ audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
467
+
468
+ audio = torch.clamp(audio, -1.0, 1.0)
469
+
470
+ return audio
471
+ except Exception as e:
472
+ print(f"Error processing target audio: {e}")
473
+ return None
474
+
475
+ def _parse_audio_code_string(self, code_str: str) -> List[int]:
476
+ """Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
477
+ if not code_str:
478
+ return []
479
+ try:
480
+ return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
481
+ except Exception:
482
+ return []
483
+
484
+ def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
485
+ """
486
+ Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer.
487
+ """
488
+ if not self.model or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'):
489
+ return None
490
+
491
+ code_ids = self._parse_audio_code_string(code_str)
492
+ if len(code_ids) == 0:
493
+ return None
494
+
495
+ quantizer = self.model.tokenizer.quantizer
496
+ detokenizer = self.model.detokenizer
497
+
498
+ num_quantizers = getattr(quantizer, "num_quantizers", 1)
499
+ indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0) # [1, T_5Hz]
500
+
501
+ # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
502
+ if indices.dim() == 2:
503
+ indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
504
+
505
+ # Get quantized representation from indices: [1, T_5Hz, dim]
506
+ quantized = quantizer.get_output_from_indices(indices)
507
+ if quantized.dtype != self.dtype:
508
+ quantized = quantized.to(self.dtype)
509
+
510
+ # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
511
+ lm_hints_25hz = detokenizer(quantized)
512
+ return lm_hints_25hz
513
+
514
+ def _create_default_meta(self) -> str:
515
+ """Create default metadata string."""
516
+ return (
517
+ "- bpm: N/A\n"
518
+ "- timesignature: N/A\n"
519
+ "- keyscale: N/A\n"
520
+ "- duration: 30 seconds\n"
521
+ )
522
+
523
+ def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
524
+ """Convert metadata dict to formatted string."""
525
+ bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A'))
526
+ timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A'))
527
+ keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A')))
528
+ duration = meta_dict.get('duration', meta_dict.get('length', 30))
529
+
530
+ # Format duration
531
+ if isinstance(duration, (int, float)):
532
+ duration = f"{int(duration)} seconds"
533
+ elif not isinstance(duration, str):
534
+ duration = "30 seconds"
535
+
536
+ return (
537
+ f"- bpm: {bpm}\n"
538
+ f"- timesignature: {timesignature}\n"
539
+ f"- keyscale: {keyscale}\n"
540
+ f"- duration: {duration}\n"
541
+ )
542
+
543
+ def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
544
+ """Parse and normalize metadata with fallbacks."""
545
+ parsed_metas = []
546
+ for meta in metas:
547
+ if meta is None:
548
+ parsed_meta = self._create_default_meta()
549
+ elif isinstance(meta, str):
550
+ parsed_meta = meta
551
+ elif isinstance(meta, dict):
552
+ parsed_meta = self._dict_to_meta_string(meta)
553
+ else:
554
+ parsed_meta = self._create_default_meta()
555
+ parsed_metas.append(parsed_meta)
556
+ return parsed_metas
557
+
558
+ def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
559
+ """Get text hidden states from text encoder."""
560
+ if self.text_tokenizer is None or self.text_encoder is None:
561
+ raise ValueError("Text encoder not initialized")
562
+
563
+ # Tokenize
564
+ text_inputs = self.text_tokenizer(
565
+ text_prompt,
566
+ padding="longest",
567
+ truncation=True,
568
+ max_length=256,
569
+ return_tensors="pt",
570
+ )
571
+ text_input_ids = text_inputs.input_ids.to(self.device)
572
+ text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
573
+
574
+ # Encode
575
+ with torch.no_grad():
576
+ text_outputs = self.text_encoder(text_input_ids)
577
+ if hasattr(text_outputs, 'last_hidden_state'):
578
+ text_hidden_states = text_outputs.last_hidden_state
579
+ elif isinstance(text_outputs, tuple):
580
+ text_hidden_states = text_outputs[0]
581
+ else:
582
+ text_hidden_states = text_outputs
583
+
584
+ text_hidden_states = text_hidden_states.to(self.dtype)
585
+
586
+ return text_hidden_states, text_attention_mask
587
+
588
+ def extract_caption_from_sft_format(self, caption: str) -> str:
589
+ """Extract caption from SFT format if needed."""
590
+ # Simple extraction - can be enhanced if needed
591
+ if caption and isinstance(caption, str):
592
+ return caption.strip()
593
+ return caption if caption else ""
594
+
595
+ def generate_music(
596
+ self,
597
+ captions: str,
598
+ lyrics: str,
599
+ bpm: Optional[int] = None,
600
+ key_scale: str = "",
601
+ time_signature: str = "",
602
+ vocal_language: str = "en",
603
+ inference_steps: int = 8,
604
+ guidance_scale: float = 7.0,
605
+ use_random_seed: bool = True,
606
+ seed: Optional[Union[str, float, int]] = -1,
607
+ reference_audio=None,
608
+ audio_duration: Optional[float] = None,
609
+ batch_size: Optional[int] = None,
610
+ src_audio=None,
611
+ audio_code_string: str = "",
612
+ repainting_start: float = 0.0,
613
+ repainting_end: Optional[float] = None,
614
+ instruction: str = "Fill the audio semantic mask based on the given conditions:",
615
+ audio_cover_strength: float = 1.0,
616
+ task_type: str = "text2music",
617
+ use_adg: bool = False,
618
+ cfg_interval_start: float = 0.0,
619
+ cfg_interval_end: float = 1.0,
620
+ audio_format: str = "mp3",
621
+ lm_temperature: float = 0.6,
622
+ progress=None
623
+ ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
624
+ """
625
+ Main interface for music generation
626
+
627
+ Returns:
628
+ (first_audio, second_audio, all_audio_paths, generation_info, status_message,
629
+ seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
630
+ align_score_2, align_text_2, align_plot_2)
631
+ """
632
+ if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
633
+ return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
634
+
635
+ try:
636
+ print("[generate_music] Starting generation...")
637
+ if progress:
638
+ progress(0.05, desc="Preparing inputs...")
639
+ print("[generate_music] Preparing inputs...")
640
+
641
+ # Determine actual batch size
642
+ actual_batch_size = batch_size if batch_size is not None else self.batch_size
643
+ actual_batch_size = max(1, min(actual_batch_size, 8)) # Limit to 8 for memory safety
644
+
645
+ # Process seeds
646
+ if use_random_seed:
647
+ seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
648
+ else:
649
+ # Parse seed input
650
+ if isinstance(seed, str):
651
+ seed_parts = [s.strip() for s in seed.split(",")]
652
+ seed_list = [int(float(s)) if s != "-1" and s else random.randint(0, 2**32 - 1) for s in seed_parts[:actual_batch_size]]
653
+ elif isinstance(seed, (int, float)) and seed >= 0:
654
+ seed_list = [int(seed)] * actual_batch_size
655
+ else:
656
+ seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
657
+
658
+ # Pad if needed
659
+ while len(seed_list) < actual_batch_size:
660
+ seed_list.append(random.randint(0, 2**32 - 1))
661
+
662
+ seed_value_for_ui = ", ".join(str(s) for s in seed_list)
663
+
664
+ # Process audio inputs
665
+ processed_ref_audio = self.process_reference_audio(reference_audio) if reference_audio else None
666
+ processed_src_audio = self.process_target_audio(src_audio) if src_audio else None
667
+
668
+ # Extract caption
669
+ pure_caption = self.extract_caption_from_sft_format(captions)
670
+
671
+ # Determine task type and update instruction if needed
672
+ if task_type == "text2music" and audio_code_string and str(audio_code_string).strip():
673
+ task_type = "cover"
674
+ instruction = "Generate audio semantic tokens based on the given conditions:"
675
+
676
+ # Build metadata
677
+ metadata_dict = {
678
+ "bpm": bpm if bpm else "N/A",
679
+ "keyscale": key_scale if key_scale else "N/A",
680
+ "timesignature": time_signature if time_signature else "N/A",
681
+ }
682
+
683
+ # Calculate duration
684
+ if processed_src_audio is not None:
685
+ calculated_duration = processed_src_audio.shape[-1] / self.sample_rate
686
+ elif audio_duration is not None and audio_duration > 0:
687
+ calculated_duration = audio_duration
688
+ else:
689
+ calculated_duration = 30.0 # Default 30 seconds
690
+
691
+ metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
692
+
693
+ if progress:
694
+ progress(0.1, desc="Processing audio inputs...")
695
+ print("[generate_music] Processing audio inputs...")
696
+
697
+ # Prepare batch data
698
+ captions_batch = [pure_caption] * actual_batch_size
699
+ lyrics_batch = [lyrics] * actual_batch_size
700
+ vocal_languages_batch = [vocal_language] * actual_batch_size
701
+ instructions_batch = [instruction] * actual_batch_size
702
+ metas_batch = [metadata_dict.copy()] * actual_batch_size
703
+ audio_code_hints_batch = [audio_code_string if audio_code_string else None] * actual_batch_size
704
+
705
+ # Process reference audios
706
+ if processed_ref_audio is not None:
707
+ refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
708
+ else:
709
+ # Create silence as fallback
710
+ silence_frames = 30 * self.sample_rate
711
+ silence = torch.zeros(2, silence_frames)
712
+ refer_audios = [[silence] for _ in range(actual_batch_size)]
713
+
714
+ # Process target wavs (src_audio)
715
+ if processed_src_audio is not None:
716
+ target_wavs_list = [processed_src_audio.clone() for _ in range(actual_batch_size)]
717
+ else:
718
+ # Create silence based on duration
719
+ target_frames = int(calculated_duration * self.sample_rate)
720
+ silence = torch.zeros(2, target_frames)
721
+ target_wavs_list = [silence for _ in range(actual_batch_size)]
722
+
723
+ # Pad target_wavs to consistent length
724
+ max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
725
+ target_wavs = torch.stack([
726
+ torch.nn.functional.pad(wav, (0, max_target_frames - wav.shape[-1]), 'constant', 0)
727
+ for wav in target_wavs_list
728
+ ])
729
+
730
+ if progress:
731
+ progress(0.2, desc="Encoding audio to latents...")
732
+ print("[generate_music] Encoding audio to latents...")
733
+
734
+ # Encode target_wavs to latents using VAE
735
+ target_latents_list = []
736
+ latent_lengths = []
737
+
738
+ with torch.no_grad():
739
+ for i in range(actual_batch_size):
740
+ # Check if audio codes are provided
741
+ code_hint = audio_code_hints_batch[i]
742
+ if code_hint:
743
+ decoded_latents = self._decode_audio_codes_to_latents(code_hint)
744
+ if decoded_latents is not None:
745
+ decoded_latents = decoded_latents.squeeze(0) # Remove batch dim
746
+ target_latents_list.append(decoded_latents)
747
+ latent_lengths.append(decoded_latents.shape[0])
748
+ continue
749
+
750
+ # If no src_audio provided, use silence_latent directly (skip VAE)
751
+ if processed_src_audio is None:
752
+ # Calculate required latent length based on duration
753
+ # VAE downsample ratio is 1920 (2*4*4*6*10), so latent rate is 48000/1920 = 25Hz
754
+ latent_length = int(calculated_duration * 25) # 25Hz latent rate
755
+ latent_length = max(128, latent_length) # Minimum 128
756
+
757
+ # Tile silence_latent to required length
758
+ if self.silence_latent.shape[0] >= latent_length:
759
+ target_latent = self.silence_latent[:latent_length].to(self.device).to(self.dtype)
760
+ else:
761
+ repeat_times = (latent_length // self.silence_latent.shape[0]) + 1
762
+ target_latent = self.silence_latent.repeat(repeat_times, 1)[:latent_length].to(self.device).to(self.dtype)
763
+ target_latents_list.append(target_latent)
764
+ latent_lengths.append(target_latent.shape[0])
765
+ continue
766
+
767
+ # Encode from audio using VAE
768
+ current_wav = target_wavs[i].unsqueeze(0).to(self.device).to(self.dtype)
769
+ target_latent = self.vae.encode(current_wav).latent_dist.sample() # [1, latent_dim, T]
770
+ target_latent = target_latent.squeeze(0).transpose(0, 1) # [latent_length, latent_dim]
771
+ target_latents_list.append(target_latent)
772
+ latent_lengths.append(target_latent.shape[0])
773
+
774
+ # Pad latents to same length
775
+ max_latent_length = max(latent_lengths)
776
+ max_latent_length = max(128, max_latent_length) # Minimum 128
777
+
778
+ padded_latents = []
779
+ for i, latent in enumerate(target_latents_list):
780
+ if latent.shape[0] < max_latent_length:
781
+ pad_length = max_latent_length - latent.shape[0]
782
+ # Tile silence_latent to pad_length (silence_latent is [L, C])
783
+ if self.silence_latent.shape[0] >= pad_length:
784
+ pad_latent = self.silence_latent[:pad_length]
785
+ else:
786
+ repeat_times = (pad_length // self.silence_latent.shape[0]) + 1
787
+ pad_latent = self.silence_latent.repeat(repeat_times, 1)[:pad_length]
788
+ latent = torch.cat([latent, pad_latent.to(self.device).to(self.dtype)], dim=0)
789
+ padded_latents.append(latent)
790
+
791
+ target_latents = torch.stack(padded_latents).to(self.device).to(self.dtype)
792
+ latent_masks = torch.stack([
793
+ torch.cat([
794
+ torch.ones(l, dtype=torch.long, device=self.device),
795
+ torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
796
+ ])
797
+ for l in latent_lengths
798
+ ])
799
+
800
+ if progress:
801
+ progress(0.3, desc="Preparing conditions...")
802
+ print("[generate_music] Preparing conditions...")
803
+
804
+ # Determine task type and create chunk masks
805
+ is_covers = []
806
+ chunk_masks = []
807
+ repainting_ranges = {}
808
+
809
+ for i in range(actual_batch_size):
810
+ has_code_hint = audio_code_hints_batch[i] is not None
811
+ has_repainting = (repainting_end is not None and repainting_end > repainting_start)
812
+
813
+ if has_repainting:
814
+ # Repainting mode
815
+ start_sec = max(0, repainting_start)
816
+ end_sec = repainting_end if repainting_end is not None else calculated_duration
817
+
818
+ start_latent = int(start_sec * self.sample_rate // 1920)
819
+ end_latent = int(end_sec * self.sample_rate // 1920)
820
+ start_latent = max(0, min(start_latent, max_latent_length - 1))
821
+ end_latent = max(start_latent + 1, min(end_latent, max_latent_length))
822
+
823
+ mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
824
+ mask[start_latent:end_latent] = True
825
+ chunk_masks.append(mask)
826
+ repainting_ranges[i] = (start_latent, end_latent)
827
+ is_covers.append(False)
828
+ else:
829
+ # Full generation or cover
830
+ chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
831
+ # Check if cover task
832
+ instruction_lower = instructions_batch[i].lower()
833
+ is_cover = ("generate audio semantic tokens" in instruction_lower and
834
+ "based on the given conditions" in instruction_lower) or has_code_hint
835
+ is_covers.append(is_cover)
836
+
837
+ chunk_masks = torch.stack(chunk_masks).unsqueeze(-1).expand(-1, -1, 64) # [batch, length, 64]
838
+ is_covers = torch.tensor(is_covers, dtype=torch.bool, device=self.device)
839
+
840
+ # Create src_latents
841
+ # Tile silence_latent to max_latent_length (silence_latent is now [L, C])
842
+ if self.silence_latent.shape[0] >= max_latent_length:
843
+ silence_latent_tiled = self.silence_latent[:max_latent_length].to(self.device).to(self.dtype)
844
+ else:
845
+ repeat_times = (max_latent_length // self.silence_latent.shape[0]) + 1
846
+ silence_latent_tiled = self.silence_latent.repeat(repeat_times, 1)[:max_latent_length].to(self.device).to(self.dtype)
847
+ src_latents_list = []
848
+
849
+ for i in range(actual_batch_size):
850
+ has_target_audio = (target_wavs[i].abs().sum() > 1e-6) or (audio_code_hints_batch[i] is not None)
851
+
852
+ if has_target_audio:
853
+ if i in repainting_ranges:
854
+ # Repaint: replace inpainting region with silence
855
+ src_latent = target_latents[i].clone()
856
+ start_latent, end_latent = repainting_ranges[i]
857
+ src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
858
+ src_latents_list.append(src_latent)
859
+ else:
860
+ # Cover/extract/complete/lego: use target_latents
861
+ src_latents_list.append(target_latents[i].clone())
862
+ else:
863
+ # Text2music: use silence
864
+ src_latents_list.append(silence_latent_tiled.clone())
865
+
866
+ src_latents = torch.stack(src_latents_list) # [batch, length, channels]
867
+
868
+ if progress:
869
+ progress(0.4, desc="Tokenizing text inputs...")
870
+ print("[generate_music] Tokenizing text inputs...")
871
+
872
+ # Prepare text and lyric hidden states
873
+ SFT_GEN_PROMPT = """# Instruction
874
+ {}
875
+
876
+ # Caption
877
+ {}
878
+
879
+ # Metas
880
+ {}<|endoftext|>
881
+ """
882
+
883
+ text_hidden_states_list = []
884
+ text_attention_masks_list = []
885
+ lyric_hidden_states_list = []
886
+ lyric_attention_masks_list = []
887
+
888
+ with torch.no_grad():
889
+ for i in range(actual_batch_size):
890
+ # Format text prompt
891
+ inst = instructions_batch[i]
892
+ if not inst.endswith(":"):
893
+ inst = inst + ":"
894
+
895
+ meta_str = self._dict_to_meta_string(metas_batch[i])
896
+ text_prompt = SFT_GEN_PROMPT.format(inst, captions_batch[i], meta_str)
897
+
898
+ # Tokenize and encode text
899
+ text_hidden, text_mask = self._get_text_hidden_states(text_prompt)
900
+ text_hidden_states_list.append(text_hidden.squeeze(0))
901
+ text_attention_masks_list.append(text_mask.squeeze(0))
902
+
903
+ # Format and tokenize lyrics
904
+ lyrics_text = f"# Languages\n{vocal_languages_batch[i]}\n\n# Lyric\n{lyrics_batch[i]}<|endoftext|>"
905
+ lyric_hidden, lyric_mask = self._get_text_hidden_states(lyrics_text)
906
+ lyric_hidden_states_list.append(lyric_hidden.squeeze(0))
907
+ lyric_attention_masks_list.append(lyric_mask.squeeze(0))
908
+
909
+ # Pad sequences
910
+ max_text_length = max(h.shape[0] for h in text_hidden_states_list)
911
+ max_lyric_length = max(h.shape[0] for h in lyric_hidden_states_list)
912
+
913
+ text_hidden_states = torch.stack([
914
+ torch.nn.functional.pad(h, (0, 0, 0, max_text_length - h.shape[0]), 'constant', 0)
915
+ for h in text_hidden_states_list
916
+ ]).to(self.device).to(self.dtype)
917
+
918
+ text_attention_mask = torch.stack([
919
+ torch.nn.functional.pad(m, (0, max_text_length - m.shape[0]), 'constant', 0)
920
+ for m in text_attention_masks_list
921
+ ]).to(self.device)
922
+
923
+ lyric_hidden_states = torch.stack([
924
+ torch.nn.functional.pad(h, (0, 0, 0, max_lyric_length - h.shape[0]), 'constant', 0)
925
+ for h in lyric_hidden_states_list
926
+ ]).to(self.device).to(self.dtype)
927
+
928
+ lyric_attention_mask = torch.stack([
929
+ torch.nn.functional.pad(m, (0, max_lyric_length - m.shape[0]), 'constant', 0)
930
+ for m in lyric_attention_masks_list
931
+ ]).to(self.device)
932
+
933
+ if progress:
934
+ progress(0.5, desc="Processing reference audio...")
935
+ print("[generate_music] Processing reference audio...")
936
+
937
+ # Process reference audio for timbre
938
+ # Model expects: refer_audio_acoustic_hidden_states_packed [N, timbre_fix_frame, audio_acoustic_hidden_dim]
939
+ # refer_audio_order_mask [N] indicating batch assignment
940
+ timbre_fix_frame = getattr(self.config, 'timbre_fix_frame', 750)
941
+ refer_audio_acoustic_hidden_states_packed_list = []
942
+ refer_audio_order_mask_list = []
943
+
944
+ with torch.no_grad():
945
+ for i, ref_audio_list in enumerate(refer_audios):
946
+ if ref_audio_list and len(ref_audio_list) > 0 and ref_audio_list[0].abs().sum() > 1e-6:
947
+ # Encode reference audio: [channels, samples] -> [1, latent_dim, T] -> [T, latent_dim]
948
+ ref_audio = ref_audio_list[0].unsqueeze(0).to(self.device).to(self.dtype)
949
+ ref_latent = self.vae.encode(ref_audio).latent_dist.sample() # [1, latent_dim, T]
950
+ ref_latent = ref_latent.squeeze(0).transpose(0, 1) # [T, latent_dim]
951
+ # Ensure dimension matches audio_acoustic_hidden_dim (64)
952
+ if ref_latent.shape[-1] != self.config.audio_acoustic_hidden_dim:
953
+ ref_latent = ref_latent[:, :self.config.audio_acoustic_hidden_dim]
954
+ # Pad or truncate to timbre_fix_frame
955
+ if ref_latent.shape[0] < timbre_fix_frame:
956
+ pad_length = timbre_fix_frame - ref_latent.shape[0]
957
+ padding = torch.zeros(pad_length, ref_latent.shape[1], device=self.device, dtype=self.dtype)
958
+ ref_latent = torch.cat([ref_latent, padding], dim=0)
959
+ else:
960
+ ref_latent = ref_latent[:timbre_fix_frame]
961
+ refer_audio_acoustic_hidden_states_packed_list.append(ref_latent)
962
+ refer_audio_order_mask_list.append(i)
963
+ else:
964
+ # Use silence_latent directly instead of running VAE
965
+ if self.silence_latent.shape[0] >= timbre_fix_frame:
966
+ silence_ref = self.silence_latent[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
967
+ else:
968
+ repeat_times = (timbre_fix_frame // self.silence_latent.shape[0]) + 1
969
+ silence_ref = self.silence_latent.repeat(repeat_times, 1)[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
970
+ refer_audio_acoustic_hidden_states_packed_list.append(silence_ref.to(self.device).to(self.dtype))
971
+ refer_audio_order_mask_list.append(i)
972
+
973
+ # Stack all reference audios: [N, timbre_fix_frame, audio_acoustic_hidden_dim]
974
+ refer_audio_acoustic_hidden_states_packed = torch.stack(refer_audio_acoustic_hidden_states_packed_list, dim=0).to(self.device).to(self.dtype)
975
+ # Order mask: [N] indicating which batch item each reference belongs to
976
+ refer_audio_order_mask = torch.tensor(refer_audio_order_mask_list, dtype=torch.long, device=self.device)
977
+
978
+ if progress:
979
+ progress(0.6, desc="Generating audio...")
980
+ print("[generate_music] Calling model.generate_audio()...")
981
+ print(f" - text_hidden_states: {text_hidden_states.shape}, dtype={text_hidden_states.dtype}")
982
+ print(f" - text_attention_mask: {text_attention_mask.shape}, dtype={text_attention_mask.dtype}")
983
+ print(f" - lyric_hidden_states: {lyric_hidden_states.shape}, dtype={lyric_hidden_states.dtype}")
984
+ print(f" - lyric_attention_mask: {lyric_attention_mask.shape}, dtype={lyric_attention_mask.dtype}")
985
+ print(f" - refer_audio_acoustic_hidden_states_packed: {refer_audio_acoustic_hidden_states_packed.shape}, dtype={refer_audio_acoustic_hidden_states_packed.dtype}")
986
+ print(f" - refer_audio_order_mask: {refer_audio_order_mask.shape}, dtype={refer_audio_order_mask.dtype}")
987
+ print(f" - src_latents: {src_latents.shape}, dtype={src_latents.dtype}")
988
+ print(f" - chunk_masks: {chunk_masks.shape}, dtype={chunk_masks.dtype}")
989
+ print(f" - is_covers: {is_covers.shape}, dtype={is_covers.dtype}")
990
+ print(f" - silence_latent: {self.silence_latent.unsqueeze(0).shape}")
991
+ print(f" - seed: {seed_list[0] if len(seed_list) > 0 else None}")
992
+ print(f" - fix_nfe: {inference_steps}")
993
+
994
+ # Call model to generate
995
+ with torch.no_grad():
996
+ outputs = self.model.generate_audio(
997
+ text_hidden_states=text_hidden_states,
998
+ text_attention_mask=text_attention_mask,
999
+ lyric_hidden_states=lyric_hidden_states,
1000
+ lyric_attention_mask=lyric_attention_mask,
1001
+ refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
1002
+ refer_audio_order_mask=refer_audio_order_mask,
1003
+ src_latents=src_latents,
1004
+ chunk_masks=chunk_masks,
1005
+ is_covers=is_covers,
1006
+ silence_latent=self.silence_latent.unsqueeze(0), # [1, L, C]
1007
+ seed=seed_list[0] if len(seed_list) > 0 else None,
1008
+ fix_nfe=inference_steps,
1009
+ infer_method="ode",
1010
+ use_cache=True,
1011
+ )
1012
+
1013
+ print("[generate_music] Model generation completed. Decoding latents...")
1014
+ pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
1015
+ time_costs = outputs["time_costs"]
1016
+ print(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
1017
+ print(f" - time_costs: {time_costs}")
1018
+ if progress:
1019
+ progress(0.8, desc="Decoding audio...")
1020
+ print("[generate_music] Decoding latents with VAE...")
1021
+
1022
+ # Decode latents to audio
1023
+ start_time = time.time()
1024
+ with torch.no_grad():
1025
+ # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
1026
+ pred_latents_for_decode = pred_latents.transpose(1, 2)
1027
+ pred_wavs = self.vae.decode(pred_latents_for_decode).sample # [batch, channels, samples]
1028
+ end_time = time.time()
1029
+ time_costs["vae_decode_time_cost"] = end_time - start_time
1030
+ time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
1031
+
1032
+ print("[generate_music] VAE decode completed. Saving audio files...")
1033
+ if progress:
1034
+ progress(0.9, desc="Saving audio files...")
1035
+
1036
+ # Save audio files using soundfile (supports wav, flac, mp3 via format param)
1037
+ audio_format_lower = audio_format.lower() if audio_format else "wav"
1038
+ if audio_format_lower not in ["wav", "flac", "mp3"]:
1039
+ audio_format_lower = "wav"
1040
+
1041
+ saved_files = []
1042
+ for i in range(actual_batch_size):
1043
+ audio_file = os.path.join(self.temp_dir, f"generated_{i}_{seed_list[i]}.{audio_format_lower}")
1044
+ # Convert to numpy: [channels, samples] -> [samples, channels]
1045
+ audio_np = pred_wavs[i].cpu().float().numpy().T
1046
+ sf.write(audio_file, audio_np, self.sample_rate)
1047
+ saved_files.append(audio_file)
1048
+
1049
+ # Prepare return values
1050
+ first_audio = saved_files[0] if len(saved_files) > 0 else None
1051
+ second_audio = saved_files[1] if len(saved_files) > 1 else None
1052
+
1053
+ # Format time costs if available
1054
+ time_costs_str = ""
1055
+ if time_costs:
1056
+ if isinstance(time_costs, dict):
1057
+ time_costs_str = "\n\n**⏱️ Time Costs:**\n"
1058
+ for key, value in time_costs.items():
1059
+ # Format key: encoder_time_cost -> Encoder
1060
+ formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
1061
+ time_costs_str += f" - {formatted_key}: {value:.2f}s\n"
1062
+ elif isinstance(time_costs, (int, float)):
1063
+ time_costs_str = f"\n\n**⏱️ Time Cost:** {time_costs:.2f}s"
1064
+
1065
+ generation_info = f"""**🎵 Generation Complete**
1066
+
1067
+ **Seeds:** {seed_value_for_ui}
1068
+ **Duration:** {calculated_duration:.1f}s
1069
+ **Steps:** {inference_steps}
1070
+ **Files:** {len(saved_files)} audio(s){time_costs_str}"""
1071
+ status_message = f"✅ Generation completed successfully!"
1072
+ print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
1073
+
1074
+ # Alignment scores and plots (placeholder for now)
1075
+ align_score_1 = ""
1076
+ align_text_1 = ""
1077
+ align_plot_1 = None
1078
+ align_score_2 = ""
1079
+ align_text_2 = ""
1080
+ align_plot_2 = None
1081
+
1082
+ return (
1083
+ first_audio,
1084
+ second_audio,
1085
+ saved_files,
1086
+ generation_info,
1087
+ status_message,
1088
+ seed_value_for_ui,
1089
+ align_score_1,
1090
+ align_text_1,
1091
+ align_plot_1,
1092
+ align_score_2,
1093
+ align_text_2,
1094
+ align_plot_2,
1095
+ )
1096
+
1097
+ except Exception as e:
1098
+ error_msg = f"❌ Error generating music: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
1099
+ return None, None, [], "", error_msg, "-1", "", "", None, "", "", None
1100
+
acestep/third_parts/nano-vllm/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xingkai Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
acestep/third_parts/nano-vllm/README.md ADDED
@@ -0,0 +1,66 @@
1
+ <p align="center">
2
+ <img width="300" src="assets/logo.png">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
7
+ </p>
8
+
9
+ # Nano-vLLM
10
+
11
+ A lightweight vLLM implementation built from scratch.
12
+
13
+ ## Key Features
14
+
15
+ * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
16
+ * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
17
+ * ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
23
+ ```
24
+
25
+ ## Model Download
26
+
27
+ To download the model weights manually, use the following command:
28
+ ```bash
29
+ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
30
+ --local-dir ~/huggingface/Qwen3-0.6B/ \
31
+ --local-dir-use-symlinks False
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
37
+ ```python
38
+ from nanovllm import LLM, SamplingParams
39
+ llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
40
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
41
+ prompts = ["Hello, Nano-vLLM."]
42
+ outputs = llm.generate(prompts, sampling_params)
43
+ outputs[0]["text"]
44
+ ```
45
+
46
+ ## Benchmark
47
+
48
+ See `bench.py` for benchmark.
49
+
50
+ **Test Configuration:**
51
+ - Hardware: RTX 4070 Laptop (8GB)
52
+ - Model: Qwen3-0.6B
53
+ - Total Requests: 256 sequences
54
+ - Input Length: Randomly sampled between 100–1024 tokens
55
+ - Output Length: Randomly sampled between 100–1024 tokens
56
+
57
+ **Performance Results:**
58
+ | Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
59
+ |----------------|-------------|----------|-----------------------|
60
+ | vLLM | 133,966 | 98.37 | 1361.84 |
61
+ | Nano-vLLM | 133,966 | 93.41 | 1434.13 |
62
+
63
+
64
+ ## Star History
65
+
66
+ [![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
acestep/third_parts/nano-vllm/assets/logo.png ADDED

Git LFS Details

  • SHA256: 03ec4039dc248e97e9943694d3ccfb52c1a73a6dab94c4cd6fd4288e08de98c8
  • Pointer size: 131 Bytes
  • Size of remote file: 397 kB
acestep/third_parts/nano-vllm/bench.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import time
3
+ from random import randint, seed
4
+ from nanovllm import LLM, SamplingParams
5
+ # from vllm import LLM, SamplingParams
6
+
7
+
8
+ def main():
9
+ seed(0)
10
+ num_seqs = 256
11
+ max_input_len = 1024
12
+ max_ouput_len = 1024
13
+
14
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
15
+ llm = LLM(path, enforce_eager=False, max_model_len=4096)
16
+
17
+ prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
18
+ sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
19
+ # uncomment the following line for vllm
20
+ # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
21
+
22
+ llm.generate(["Benchmark: "], SamplingParams())
23
+ t = time.time()
24
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
25
+ t = (time.time() - t)
26
+ total_tokens = sum(sp.max_tokens for sp in sampling_params)
27
+ throughput = total_tokens / t
28
+ print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
acestep/third_parts/nano-vllm/example.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ from nanovllm import LLM, SamplingParams
3
+ from transformers import AutoTokenizer
4
+
5
+
6
+ def main():
7
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
8
+ tokenizer = AutoTokenizer.from_pretrained(path)
9
+ llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
10
+
11
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
12
+ prompts = [
13
+ "introduce yourself",
14
+ "list all prime numbers within 100",
15
+ ]
16
+ prompts = [
17
+ tokenizer.apply_chat_template(
18
+ [{"role": "user", "content": prompt}],
19
+ tokenize=False,
20
+ add_generation_prompt=True,
21
+ )
22
+ for prompt in prompts
23
+ ]
24
+ outputs = llm.generate(prompts, sampling_params)
25
+
26
+ for prompt, output in zip(prompts, outputs):
27
+ print("\n")
28
+ print(f"Prompt: {prompt!r}")
29
+ print(f"Completion: {output['text']!r}")
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main()
acestep/third_parts/nano-vllm/nanovllm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from nanovllm.llm import LLM
2
+ from nanovllm.sampling_params import SamplingParams
acestep/third_parts/nano-vllm/nanovllm/config.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from transformers import AutoConfig
4
+
5
+
6
+ @dataclass
7
+ class Config:
8
+ model: str
9
+ max_num_batched_tokens: int = 16384
10
+ max_num_seqs: int = 512
11
+ max_model_len: int = 4096
12
+ gpu_memory_utilization: float = 0.9
13
+ tensor_parallel_size: int = 1
14
+ enforce_eager: bool = False
15
+ hf_config: AutoConfig | None = None
16
+ eos: int = -1
17
+ kvcache_block_size: int = 256
18
+ num_kvcache_blocks: int = -1
19
+
20
+ def __post_init__(self):
21
+ assert os.path.isdir(self.model)
22
+ assert self.kvcache_block_size % 256 == 0
23
+ assert 1 <= self.tensor_parallel_size <= 8
24
+ self.hf_config = AutoConfig.from_pretrained(self.model)
25
+ self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
26
+ assert self.max_num_batched_tokens >= self.max_model_len
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py ADDED
@@ -0,0 +1,112 @@
1
+ from collections import deque
2
+ import xxhash
3
+ import numpy as np
4
+
5
+ from nanovllm.engine.sequence import Sequence
6
+
7
+
8
+ class Block:
9
+
10
+ def __init__(self, block_id):
11
+ self.block_id = block_id
12
+ self.ref_count = 0
13
+ self.hash = -1
14
+ self.token_ids = []
15
+
16
+ def update(self, hash: int, token_ids: list[int]):
17
+ self.hash = hash
18
+ self.token_ids = token_ids
19
+
20
+ def reset(self):
21
+ self.ref_count = 1
22
+ self.hash = -1
23
+ self.token_ids = []
24
+
25
+
26
+ class BlockManager:
27
+
28
+ def __init__(self, num_blocks: int, block_size: int):
29
+ self.block_size = block_size
30
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
31
+ self.hash_to_block_id: dict[int, int] = dict()
32
+ self.free_block_ids: deque[int] = deque(range(num_blocks))
33
+ self.used_block_ids: set[int] = set()
34
+
35
+ @classmethod
36
+ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
37
+ h = xxhash.xxh64()
38
+ if prefix != -1:
39
+ h.update(prefix.to_bytes(8, "little"))
40
+ h.update(np.array(token_ids).tobytes())
41
+ return h.intdigest()
42
+
43
+ def _allocate_block(self, block_id: int) -> Block:
44
+ block = self.blocks[block_id]
45
+ assert block.ref_count == 0
46
+ block.reset()
47
+ self.free_block_ids.remove(block_id)
48
+ self.used_block_ids.add(block_id)
49
+ return self.blocks[block_id]
50
+
51
+ def _deallocate_block(self, block_id: int) -> Block:
52
+ assert self.blocks[block_id].ref_count == 0
53
+ self.used_block_ids.remove(block_id)
54
+ self.free_block_ids.append(block_id)
55
+
56
+ def can_allocate(self, seq: Sequence) -> bool:
57
+ return len(self.free_block_ids) >= seq.num_blocks
58
+
59
+ def allocate(self, seq: Sequence):
60
+ assert not seq.block_table
61
+ h = -1
62
+ cache_miss = False
63
+ for i in range(seq.num_blocks):
64
+ token_ids = seq.block(i)
65
+ h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
66
+ block_id = self.hash_to_block_id.get(h, -1)
67
+ if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
68
+ cache_miss = True
69
+ if cache_miss:
70
+ block_id = self.free_block_ids[0]
71
+ block = self._allocate_block(block_id)
72
+ else:
73
+ seq.num_cached_tokens += self.block_size
74
+ if block_id in self.used_block_ids:
75
+ block = self.blocks[block_id]
76
+ block.ref_count += 1
77
+ else:
78
+ block = self._allocate_block(block_id)
79
+ if h != -1:
80
+ block.update(h, token_ids)
81
+ self.hash_to_block_id[h] = block_id
82
+ seq.block_table.append(block_id)
83
+
84
+ def deallocate(self, seq: Sequence):
85
+ for block_id in reversed(seq.block_table):
86
+ block = self.blocks[block_id]
87
+ block.ref_count -= 1
88
+ if block.ref_count == 0:
89
+ self._deallocate_block(block_id)
90
+ seq.num_cached_tokens = 0
91
+ seq.block_table.clear()
92
+
93
+ def can_append(self, seq: Sequence) -> bool:
94
+ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
95
+
96
+ def may_append(self, seq: Sequence):
97
+ block_table = seq.block_table
98
+ last_block = self.blocks[block_table[-1]]
99
+ if len(seq) % self.block_size == 1:
100
+ assert last_block.hash != -1
101
+ block_id = self.free_block_ids[0]
102
+ self._allocate_block(block_id)
103
+ block_table.append(block_id)
104
+ elif len(seq) % self.block_size == 0:
105
+ assert last_block.hash == -1
106
+ token_ids = seq.block(seq.num_blocks-1)
107
+ prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
108
+ h = self.compute_hash(token_ids, prefix)
109
+ last_block.update(h, token_ids)
110
+ self.hash_to_block_id[h] = last_block.block_id
111
+ else:
112
+ assert last_block.hash == -1
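Note (not part of the diff): a small sketch of the prefix-hash chaining that drives the prefix cache above. A full block's hash folds in the previous block's hash, so a cached block is reused only when the entire prefix matches.

from nanovllm.engine.block_manager import BlockManager

block_a = list(range(256))                        # one full block of token ids
block_b = list(range(256, 512))
h_a = BlockManager.compute_hash(block_a)          # first block, no prefix
h_b = BlockManager.compute_hash(block_b, h_a)     # second block, chained on h_a
assert h_b != BlockManager.compute_hash(block_b)  # same tokens, different prefix -> different hash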
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py ADDED
@@ -0,0 +1,120 @@
1
+ import atexit
2
+ from dataclasses import fields
3
+ from time import perf_counter
4
+ from tqdm.auto import tqdm
5
+ from transformers import AutoTokenizer
6
+ import torch.multiprocessing as mp
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.sampling_params import SamplingParams
10
+ from nanovllm.engine.sequence import Sequence
11
+ from nanovllm.engine.scheduler import Scheduler
12
+ from nanovllm.engine.model_runner import ModelRunner
13
+
14
+
15
+ class LLMEngine:
16
+
17
+ def __init__(self, model, **kwargs):
18
+ config_fields = {field.name for field in fields(Config)}
19
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
20
+ config = Config(model, **config_kwargs)
21
+ self.ps = []
22
+ self.events = []
23
+ ctx = mp.get_context("spawn")
24
+ for i in range(1, config.tensor_parallel_size):
25
+ event = ctx.Event()
26
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
27
+ process.start()
28
+ self.ps.append(process)
29
+ self.events.append(event)
30
+ self.model_runner = ModelRunner(config, 0, self.events)
31
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
32
+ config.eos = self.tokenizer.eos_token_id
33
+ self.scheduler = Scheduler(config)
34
+ atexit.register(self.exit)
35
+
36
+ def exit(self):
37
+ self.model_runner.call("exit")
38
+ del self.model_runner
39
+ for p in self.ps:
40
+ p.join()
41
+
42
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
43
+ if isinstance(prompt, str):
44
+ prompt = self.tokenizer.encode(prompt)
45
+ # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
46
+ if sampling_params.cfg_scale > 1.0:
47
+ if unconditional_prompt is None:
48
+ # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
49
+ # This is a fallback - ideally users should provide unconditional_prompt
50
+ if isinstance(prompt, list):
51
+ # For now, just use the same prompt (user should provide unconditional_prompt)
52
+ # TODO: Implement automatic "NO USER INPUT" replacement if possible
53
+ unconditional_prompt = prompt
54
+ else:
55
+ unconditional_prompt = prompt
56
+ if isinstance(unconditional_prompt, str):
57
+ unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
58
+ # Create unconditional sequence first (so we can reference it from conditional)
59
+ uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
60
+ # Create conditional sequence with reference to unconditional
61
+ cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
62
+ uncond_seq.paired_seq = cond_seq # Link them bidirectionally
63
+ # Add both sequences to scheduler
64
+ self.scheduler.add(cond_seq)
65
+ self.scheduler.add(uncond_seq)
66
+ else:
67
+ seq = Sequence(prompt, sampling_params)
68
+ self.scheduler.add(seq)
69
+
70
+ def step(self):
71
+ seqs, is_prefill = self.scheduler.schedule()
72
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
73
+ self.scheduler.postprocess(seqs, token_ids)
74
+ # Only output conditional sequences (unconditional sequences are just for CFG computation)
75
+ output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
76
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
77
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
78
+ return outputs, num_tokens
79
+
80
+ def is_finished(self):
81
+ return self.scheduler.is_finished()
82
+
83
+ def generate(
84
+ self,
85
+ prompts: list[str] | list[list[int]],
86
+ sampling_params: SamplingParams | list[SamplingParams],
87
+ use_tqdm: bool = True,
88
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
89
+ ) -> list[str]:
90
+ if use_tqdm:
91
+ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
92
+ if not isinstance(sampling_params, list):
93
+ sampling_params = [sampling_params] * len(prompts)
94
+ if unconditional_prompts is None:
95
+ unconditional_prompts = [None] * len(prompts)
96
+ for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
97
+ self.add_request(prompt, sp, uncond_prompt)
98
+ outputs = {}
99
+ prefill_throughput = decode_throughput = 0.
100
+ while not self.is_finished():
101
+ t = perf_counter()
102
+ output, num_tokens = self.step()
103
+ if use_tqdm:
104
+ if num_tokens > 0:
105
+ prefill_throughput = num_tokens / (perf_counter() - t)
106
+ else:
107
+ decode_throughput = -num_tokens / (perf_counter() - t)
108
+ pbar.set_postfix({
109
+ "Prefill": f"{int(prefill_throughput)}tok/s",
110
+ "Decode": f"{int(decode_throughput)}tok/s",
111
+ })
112
+ for seq_id, token_ids in output:
113
+ outputs[seq_id] = token_ids
114
+ if use_tqdm:
115
+ pbar.update(1)
116
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
117
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
118
+ if use_tqdm:
119
+ pbar.close()
120
+ return outputs
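Note (not part of the diff): a hedged end-to-end usage sketch of the engine above; the checkpoint path and prompts are placeholders. With cfg_scale > 1.0, add_request schedules a conditional/unconditional pair and generate returns only the conditional outputs.

from nanovllm.llm import LLM
from nanovllm.sampling_params import SamplingParams

llm = LLM("/path/to/Qwen3-0.6B", enforce_eager=True)   # hypothetical local checkpoint dir
sp = SamplingParams(temperature=0.8, max_tokens=32, cfg_scale=2.0)
outputs = llm.generate(
    ["Write a tag list for a calm piano track."],
    sp,
    unconditional_prompts=["NO USER INPUT"],            # the CFG negative/unconditional prompt
)
print(outputs[0]["text"])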
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py ADDED
@@ -0,0 +1,315 @@
1
+ import pickle
2
+ import torch
3
+ import torch.distributed as dist
4
+ from multiprocessing.synchronize import Event
5
+ from multiprocessing.shared_memory import SharedMemory
6
+
7
+ from nanovllm.config import Config
8
+ from nanovllm.engine.sequence import Sequence
9
+ from nanovllm.models.qwen3 import Qwen3ForCausalLM
10
+ from nanovllm.layers.sampler import Sampler
11
+ from nanovllm.utils.context import set_context, get_context, reset_context
12
+ from nanovllm.utils.loader import load_model
13
+
14
+
15
+ class ModelRunner:
16
+
17
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
18
+ self.config = config
19
+ hf_config = config.hf_config
20
+ self.block_size = config.kvcache_block_size
21
+ self.enforce_eager = config.enforce_eager
22
+ self.world_size = config.tensor_parallel_size
23
+ self.rank = rank
24
+ self.event = event
25
+
26
+ dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
27
+ torch.cuda.set_device(rank)
28
+ default_dtype = torch.get_default_dtype()
29
+ torch.set_default_dtype(hf_config.torch_dtype)
30
+ torch.set_default_device("cuda")
31
+ self.model = Qwen3ForCausalLM(hf_config)
32
+ load_model(self.model, config.model)
33
+ self.sampler = Sampler()
34
+ self.warmup_model()
35
+ self.allocate_kv_cache()
36
+ if not self.enforce_eager:
37
+ self.capture_cudagraph()
38
+ torch.set_default_device("cpu")
39
+ torch.set_default_dtype(default_dtype)
40
+
41
+ if self.world_size > 1:
42
+ if rank == 0:
43
+ self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
44
+ dist.barrier()
45
+ else:
46
+ dist.barrier()
47
+ self.shm = SharedMemory(name="nanovllm")
48
+ self.loop()
49
+
50
+ def exit(self):
51
+ if self.world_size > 1:
52
+ self.shm.close()
53
+ dist.barrier()
54
+ if self.rank == 0:
55
+ self.shm.unlink()
56
+ if not self.enforce_eager:
57
+ del self.graphs, self.graph_pool
58
+ torch.cuda.synchronize()
59
+ dist.destroy_process_group()
60
+
61
+ def loop(self):
62
+ while True:
63
+ method_name, args = self.read_shm()
64
+ self.call(method_name, *args)
65
+ if method_name == "exit":
66
+ break
67
+
68
+ def read_shm(self):
69
+ assert self.world_size > 1 and self.rank > 0
70
+ self.event.wait()
71
+ n = int.from_bytes(self.shm.buf[0:4], "little")
72
+ method_name, *args = pickle.loads(self.shm.buf[4:n+4])
73
+ self.event.clear()
74
+ return method_name, args
75
+
76
+ def write_shm(self, method_name, *args):
77
+ assert self.world_size > 1 and self.rank == 0
78
+ data = pickle.dumps([method_name, *args])
79
+ n = len(data)
80
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
81
+ self.shm.buf[4:n+4] = data
82
+ for event in self.event:
83
+ event.set()
84
+
85
+ def call(self, method_name, *args):
86
+ if self.world_size > 1 and self.rank == 0:
87
+ self.write_shm(method_name, *args)
88
+ method = getattr(self, method_name, None)
89
+ return method(*args)
90
+
91
+ def warmup_model(self):
92
+ torch.cuda.empty_cache()
93
+ torch.cuda.reset_peak_memory_stats()
94
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
95
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
96
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
97
+ self.run(seqs, True)
98
+ torch.cuda.empty_cache()
99
+
100
+ def allocate_kv_cache(self):
101
+ config = self.config
102
+ hf_config = config.hf_config
103
+ free, total = torch.cuda.mem_get_info()
104
+ used = total - free
105
+ peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
106
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
107
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
108
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
109
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
110
+ config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
111
+ assert config.num_kvcache_blocks > 0
112
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
113
+ layer_id = 0
114
+ for module in self.model.modules():
115
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
116
+ module.k_cache = self.kv_cache[0, layer_id]
117
+ module.v_cache = self.kv_cache[1, layer_id]
118
+ layer_id += 1
119
+
120
+ def prepare_block_tables(self, seqs: list[Sequence]):
121
+ max_len = max(len(seq.block_table) for seq in seqs)
122
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
123
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
124
+ return block_tables
125
+
126
+ def prepare_prefill(self, seqs: list[Sequence]):
127
+ input_ids = []
128
+ positions = []
129
+ cu_seqlens_q = [0]
130
+ cu_seqlens_k = [0]
131
+ max_seqlen_q = 0
132
+ max_seqlen_k = 0
133
+ slot_mapping = []
134
+ block_tables = None
135
+ for seq in seqs:
136
+ seqlen = len(seq)
137
+ input_ids.extend(seq[seq.num_cached_tokens:])
138
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
139
+ seqlen_q = seqlen - seq.num_cached_tokens
140
+ seqlen_k = seqlen
141
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
142
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
143
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
144
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
145
+ if not seq.block_table: # warmup
146
+ continue
147
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
148
+ start = seq.block_table[i] * self.block_size
149
+ if i != seq.num_blocks - 1:
150
+ end = start + self.block_size
151
+ else:
152
+ end = start + seq.last_block_num_tokens
153
+ slot_mapping.extend(list(range(start, end)))
154
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
155
+ block_tables = self.prepare_block_tables(seqs)
156
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
157
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
158
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
159
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
160
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
161
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
162
+ return input_ids, positions
163
+
164
+ def prepare_decode(self, seqs: list[Sequence]):
165
+ input_ids = []
166
+ positions = []
167
+ slot_mapping = []
168
+ context_lens = []
169
+ for seq in seqs:
170
+ input_ids.append(seq.last_token)
171
+ positions.append(len(seq) - 1)
172
+ context_lens.append(len(seq))
173
+ slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
174
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
175
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
176
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
177
+ context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
178
+ block_tables = self.prepare_block_tables(seqs)
179
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
180
+ return input_ids, positions
181
+
182
+ def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
183
+ """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
184
+ if is_cfg_batch:
185
+ # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
186
+ # We only need temperatures for conditional sequences (first half)
187
+ num_cond = len(seqs) // 2
188
+ temperatures = []
189
+ cfg_scales = []
190
+ for seq in seqs[:num_cond]:
191
+ temperatures.append(seq.temperature)
192
+ cfg_scales.append(seq.cfg_scale)
193
+ else:
194
+ temperatures = []
195
+ cfg_scales = []
196
+ for seq in seqs:
197
+ temperatures.append(seq.temperature)
198
+ cfg_scales.append(seq.cfg_scale)
199
+ temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
200
+ cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
201
+ return temperatures, cfg_scales
202
+
203
+ @torch.inference_mode()
204
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
205
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
206
+ return self.model.compute_logits(self.model(input_ids, positions))
207
+ else:
208
+ bs = input_ids.size(0)
209
+ context = get_context()
210
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
211
+ graph_vars = self.graph_vars
212
+ graph_vars["input_ids"][:bs] = input_ids
213
+ graph_vars["positions"][:bs] = positions
214
+ graph_vars["slot_mapping"].fill_(-1)
215
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
216
+ graph_vars["context_lens"].zero_()
217
+ graph_vars["context_lens"][:bs] = context.context_lens
218
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
219
+ graph.replay()
220
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
221
+
222
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
223
+ """Run model forward and sampling. For CFG sequences, batch is structured as:
224
+ [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
225
+ where uncond_seqi is the paired unconditional sequence of cond_seqi."""
226
+ # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
227
+ is_cfg_batch = False
228
+ if len(seqs) > 0:
229
+ # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
230
+ if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
231
+ is_cfg_batch = True
232
+ # Verify batch structure: first half conditional, second half unconditional
233
+ num_cond = len(seqs) // 2
234
+ for i in range(num_cond):
235
+ if seqs[i].is_unconditional or not seqs[i + num_cond].is_unconditional:
236
+ is_cfg_batch = False
237
+ break
238
+
239
+ if is_cfg_batch:
240
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
241
+ num_cond = len(seqs) // 2
242
+ cond_seqs = seqs[:num_cond]
243
+ uncond_seqs = seqs[num_cond:]
244
+
245
+ # Prepare inputs for both conditional and unconditional (they're already in the batch)
246
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
247
+ else self.prepare_decode(seqs))
248
+ temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else (None, None)
249
+
250
+ # Run model forward (processes entire batch: cond + uncond)
251
+ logits_all = self.run_model(input_ids, positions, is_prefill)
252
+ reset_context()
253
+
254
+ if self.rank == 0:
255
+ # Split logits: first half is conditional, second half is unconditional
256
+ logits_cond = logits_all[:num_cond]
257
+ logits_uncond = logits_all[num_cond:]
258
+
259
+ # Apply CFG formula: logits_cfg = logits_cond + cfg_scale * (logits_cond - logits_uncond)
260
+ cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
261
+ logits_cfg = logits_cond + cfg_scales_tensor * (logits_cond - logits_uncond)
262
+
263
+ # Sample from CFG logits
264
+ token_ids_cfg = self.sampler(logits_cfg, temperatures).tolist()
265
+
266
+ # Return token_ids (will be applied to both conditional and unconditional sequences)
267
+ return token_ids_cfg
268
+ else:
269
+ return None
270
+ else:
271
+ # Normal batch (non-CFG)
272
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
273
+ else self.prepare_decode(seqs))
274
+ temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else (None, None)
275
+ logits = self.run_model(input_ids, positions, is_prefill)
276
+ reset_context()
277
+ token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
278
+ return token_ids
279
+
280
+ @torch.inference_mode()
281
+ def capture_cudagraph(self):
282
+ config = self.config
283
+ hf_config = config.hf_config
284
+ max_bs = min(self.config.max_num_seqs, 512)
285
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
286
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
287
+ positions = torch.zeros(max_bs, dtype=torch.int64)
288
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
289
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
290
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
291
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
292
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
293
+ self.graphs = {}
294
+ self.graph_pool = None
295
+
296
+ for bs in reversed(self.graph_bs):
297
+ graph = torch.cuda.CUDAGraph()
298
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
299
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
300
+ with torch.cuda.graph(graph, self.graph_pool):
301
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
302
+ if self.graph_pool is None:
303
+ self.graph_pool = graph.pool()
304
+ self.graphs[bs] = graph
305
+ torch.cuda.synchronize()
306
+ reset_context()
307
+
308
+ self.graph_vars = dict(
309
+ input_ids=input_ids,
310
+ positions=positions,
311
+ slot_mapping=slot_mapping,
312
+ context_lens=context_lens,
313
+ block_tables=block_tables,
314
+ outputs=outputs,
315
+ )
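Note (not part of the diff): the guidance combination used in run(), shown on toy tensors with arbitrary values.

import torch

logits_cond = torch.tensor([[1.0, 2.0, 0.5]])
logits_uncond = torch.tensor([[0.5, 1.0, 1.5]])
cfg_scales = torch.tensor([2.0]).unsqueeze(1)            # [num_cond, 1], as in run()
logits_cfg = logits_cond + cfg_scales * (logits_cond - logits_uncond)
print(logits_cfg)                                         # -> [[2.0, 4.0, -1.5]]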
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py ADDED
@@ -0,0 +1,222 @@
1
+ from collections import deque
2
+
3
+ from nanovllm.config import Config
4
+ from nanovllm.engine.sequence import Sequence, SequenceStatus
5
+ from nanovllm.engine.block_manager import BlockManager
6
+
7
+
8
+ class Scheduler:
9
+
10
+ def __init__(self, config: Config):
11
+ self.max_num_seqs = config.max_num_seqs
12
+ self.max_num_batched_tokens = config.max_num_batched_tokens
13
+ self.eos = config.eos
14
+ self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
15
+ self.waiting: deque[Sequence] = deque()
16
+ self.running: deque[Sequence] = deque()
17
+
18
+ def is_finished(self):
19
+ return not self.waiting and not self.running
20
+
21
+ def add(self, seq: Sequence):
22
+ self.waiting.append(seq)
23
+
24
+ def schedule(self) -> tuple[list[Sequence], bool]:
25
+ # prefill
26
+ scheduled_seqs = []
27
+ num_seqs = 0
28
+ num_batched_tokens = 0
29
+ processed_seqs = set() # Track processed sequences to handle CFG pairs
30
+
31
+ while self.waiting and num_seqs < self.max_num_seqs:
32
+ seq = self.waiting[0]
33
+
34
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
35
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
36
+ # This is a conditional sequence, need to schedule its paired unconditional sequence too
37
+ paired_seq = seq.paired_seq
38
+ if paired_seq.status != SequenceStatus.WAITING:
39
+ # Paired sequence not in waiting, skip this conditional sequence for now
40
+ break
41
+
42
+ # Calculate tokens for both sequences
43
+ total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
44
+ can_allocate_both = (self.block_manager.can_allocate(seq) and
45
+ self.block_manager.can_allocate(paired_seq))
46
+
47
+ if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
48
+ break
49
+
50
+ # Schedule both sequences: conditional first, then unconditional
51
+ for s in [seq, paired_seq]:
52
+ num_seqs += 1
53
+ self.block_manager.allocate(s)
54
+ num_batched_tokens += len(s) - s.num_cached_tokens
55
+ s.status = SequenceStatus.RUNNING
56
+ self.waiting.remove(s)
57
+ self.running.append(s)
58
+ scheduled_seqs.append(s)
59
+ processed_seqs.add(s.seq_id)
60
+ else:
61
+ # Normal sequence or unconditional sequence (already processed with its conditional)
62
+ if seq.seq_id in processed_seqs:
63
+ # Skip if already processed as part of a CFG pair
64
+ self.waiting.popleft()
65
+ continue
66
+
67
+ if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
68
+ break
69
+ num_seqs += 1
70
+ self.block_manager.allocate(seq)
71
+ num_batched_tokens += len(seq) - seq.num_cached_tokens
72
+ seq.status = SequenceStatus.RUNNING
73
+ self.waiting.popleft()
74
+ self.running.append(seq)
75
+ scheduled_seqs.append(seq)
76
+
77
+ if scheduled_seqs:
78
+ # For CFG batches, ensure conditional sequences come before their unconditional pairs
79
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
80
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
81
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
82
+
83
+ # Reorder: non-CFG, then CFG conditional, then CFG unconditional
84
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
85
+ return scheduled_seqs, True
86
+
87
+ # decode
88
+ processed_seqs = set()
89
+ temp_running = list(self.running) # Work with a copy
90
+
91
+ while temp_running and num_seqs < self.max_num_seqs:
92
+ seq = temp_running.pop(0)
93
+
94
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
95
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
96
+ paired_seq = seq.paired_seq
97
+ if paired_seq not in temp_running:
98
+ # Paired sequence not available, skip for now
99
+ continue
100
+
101
+ # Remove paired_seq from temp_running
102
+ temp_running.remove(paired_seq)
103
+
104
+ # Check if both can append
105
+ can_append_both = (self.block_manager.can_append(seq) and
106
+ self.block_manager.can_append(paired_seq))
107
+
108
+ if not can_append_both:
109
+ # Try preempting other sequences
110
+ preempted = False
111
+ while not can_append_both and temp_running:
112
+ other_seq = temp_running.pop(0)
113
+ if other_seq != seq and other_seq != paired_seq:
114
+ self.preempt(other_seq)
115
+ can_append_both = (self.block_manager.can_append(seq) and
116
+ self.block_manager.can_append(paired_seq))
117
+ preempted = True
118
+ else:
119
+ temp_running.append(other_seq)
120
+ break
121
+
122
+ if not can_append_both:
123
+ # Can't schedule this pair right now
124
+ temp_running.append(seq)
125
+ temp_running.append(paired_seq)
126
+ continue
127
+
128
+ # Schedule both sequences
129
+ for s in [seq, paired_seq]:
130
+ num_seqs += 1
131
+ self.block_manager.may_append(s)
132
+ scheduled_seqs.append(s)
133
+ processed_seqs.add(s.seq_id)
134
+ # Remove from actual running list if scheduled
135
+ if s in self.running:
136
+ self.running.remove(s)
137
+ else:
138
+ # Normal sequence or unconditional (already processed)
139
+ if seq.seq_id in processed_seqs:
140
+ continue
141
+
142
+ while not self.block_manager.can_append(seq):
143
+ if temp_running:
144
+ other_seq = temp_running.pop(0)
145
+ if other_seq != seq:
146
+ self.preempt(other_seq)
147
+ else:
148
+ temp_running.append(other_seq)
149
+ break
150
+ else:
151
+ self.preempt(seq)
152
+ if seq in self.running:
153
+ self.running.remove(seq)
154
+ break
155
+ else:
156
+ num_seqs += 1
157
+ self.block_manager.may_append(seq)
158
+ scheduled_seqs.append(seq)
159
+ if seq in self.running:
160
+ self.running.remove(seq)
161
+
162
+ assert scheduled_seqs
163
+
164
+ # For CFG batches in decode, ensure conditional sequences come before unconditional
165
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
166
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
167
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
168
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
169
+
170
+ self.running.extendleft(reversed(scheduled_seqs))
171
+ return scheduled_seqs, False
172
+
173
+ def preempt(self, seq: Sequence):
174
+ seq.status = SequenceStatus.WAITING
175
+ self.block_manager.deallocate(seq)
176
+ self.waiting.appendleft(seq)
177
+
178
+ def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
179
+ # Check if this is a CFG batch
180
+ is_cfg_batch = False
181
+ if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
182
+ num_cond = len(seqs) // 2
183
+ is_cfg_batch = (num_cond > 0 and
184
+ not seqs[0].is_unconditional and
185
+ seqs[num_cond].is_unconditional)
186
+
187
+ if is_cfg_batch:
188
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
189
+ # token_ids correspond to conditional sequences only (sampled from CFG logits)
190
+ num_cond = len(seqs) // 2
191
+ cond_seqs = seqs[:num_cond]
192
+ uncond_seqs = seqs[num_cond:]
193
+
194
+ # Apply the same sampled token to both conditional and unconditional sequences
195
+ for cond_seq, uncond_seq, token_id in zip(cond_seqs, uncond_seqs, token_ids):
196
+ cond_seq.append_token(token_id)
197
+ uncond_seq.append_token(token_id) # Same token for unconditional
198
+
199
+ # Check if either sequence is finished
200
+ cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
201
+ cond_seq.num_completion_tokens == cond_seq.max_tokens)
202
+ uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
203
+ uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
204
+
205
+ if cond_finished or uncond_finished:
206
+ # Mark both as finished
207
+ cond_seq.status = SequenceStatus.FINISHED
208
+ uncond_seq.status = SequenceStatus.FINISHED
209
+ self.block_manager.deallocate(cond_seq)
210
+ self.block_manager.deallocate(uncond_seq)
211
+ if cond_seq in self.running:
212
+ self.running.remove(cond_seq)
213
+ if uncond_seq in self.running:
214
+ self.running.remove(uncond_seq)
215
+ else:
216
+ # Normal batch
217
+ for seq, token_id in zip(seqs, token_ids):
218
+ seq.append_token(token_id)
219
+ if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
220
+ seq.status = SequenceStatus.FINISHED
221
+ self.block_manager.deallocate(seq)
222
+ self.running.remove(seq)
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py ADDED
@@ -0,0 +1,89 @@
1
+ from copy import copy
2
+ from enum import Enum, auto
3
+ from itertools import count
4
+
5
+ from nanovllm.sampling_params import SamplingParams
6
+
7
+
8
+ class SequenceStatus(Enum):
9
+ WAITING = auto()
10
+ RUNNING = auto()
11
+ FINISHED = auto()
12
+
13
+
14
+ class Sequence:
15
+ block_size = 256
16
+ counter = count()
17
+
18
+ def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
19
+ self.seq_id = next(Sequence.counter)
20
+ self.status = SequenceStatus.WAITING
21
+ self.token_ids = copy(token_ids)
22
+ self.last_token = token_ids[-1]
23
+ self.num_tokens = len(self.token_ids)
24
+ self.num_prompt_tokens = len(token_ids)
25
+ self.num_cached_tokens = 0
26
+ self.block_table = []
27
+ self.temperature = sampling_params.temperature
28
+ self.max_tokens = sampling_params.max_tokens
29
+ self.ignore_eos = sampling_params.ignore_eos
30
+ self.cfg_scale = sampling_params.cfg_scale
31
+ # For CFG: mark if this is an unconditional sequence
32
+ self.is_unconditional = is_unconditional
33
+ # For CFG: the other member of the conditional/unconditional pair. For a conditional
35
+ # sequence this is its unconditional counterpart (and vice versa; LLMEngine.add_request links both).
36
+ self.paired_seq = conditional_seq
36
+
37
+ def __len__(self):
38
+ return self.num_tokens
39
+
40
+ def __getitem__(self, key):
41
+ return self.token_ids[key]
42
+
43
+ @property
44
+ def is_finished(self):
45
+ return self.status == SequenceStatus.FINISHED
46
+
47
+ @property
48
+ def num_completion_tokens(self):
49
+ return self.num_tokens - self.num_prompt_tokens
50
+
51
+ @property
52
+ def prompt_token_ids(self):
53
+ return self.token_ids[:self.num_prompt_tokens]
54
+
55
+ @property
56
+ def completion_token_ids(self):
57
+ return self.token_ids[self.num_prompt_tokens:]
58
+
59
+ @property
60
+ def num_cached_blocks(self):
61
+ return self.num_cached_tokens // self.block_size
62
+
63
+ @property
64
+ def num_blocks(self):
65
+ return (self.num_tokens + self.block_size - 1) // self.block_size
66
+
67
+ @property
68
+ def last_block_num_tokens(self):
69
+ return self.num_tokens - (self.num_blocks - 1) * self.block_size
70
+
71
+ def block(self, i):
72
+ assert 0 <= i < self.num_blocks
73
+ return self.token_ids[i*self.block_size: (i+1)*self.block_size]
74
+
75
+ def append_token(self, token_id: int):
76
+ self.token_ids.append(token_id)
77
+ self.last_token = token_id
78
+ self.num_tokens += 1
79
+
80
+ def __getstate__(self):
81
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
82
+ self.token_ids if self.num_completion_tokens == 0 else self.last_token)
83
+
84
+ def __setstate__(self, state):
85
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
86
+ if self.num_completion_tokens == 0:
87
+ self.token_ids = state[-1]
88
+ else:
89
+ self.last_token = state[-1]
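Note (not part of the diff): a quick sanity sketch of the block bookkeeping in Sequence, using the default block_size of 256.

from nanovllm.engine.sequence import Sequence

seq = Sequence(list(range(300)))          # default SamplingParams
assert seq.num_blocks == 2                # ceil(300 / 256)
assert seq.last_block_num_tokens == 44    # 300 - 256
assert seq.block(0) == list(range(256))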
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SiluAndMul(nn.Module):
7
+
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ @torch.compile
12
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
13
+ x, y = x.chunk(2, -1)
14
+ return F.silu(x) * y
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ from torch import nn
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
7
+ from nanovllm.utils.context import get_context
8
+
9
+
10
+ @triton.jit
11
+ def store_kvcache_kernel(
12
+ key_ptr,
13
+ key_stride,
14
+ value_ptr,
15
+ value_stride,
16
+ k_cache_ptr,
17
+ v_cache_ptr,
18
+ slot_mapping_ptr,
19
+ D: tl.constexpr,
20
+ ):
21
+ idx = tl.program_id(0)
22
+ slot = tl.load(slot_mapping_ptr + idx)
23
+ if slot == -1: return
24
+ key_offsets = idx * key_stride + tl.arange(0, D)
25
+ value_offsets = idx * value_stride + tl.arange(0, D)
26
+ key = tl.load(key_ptr + key_offsets)
27
+ value = tl.load(value_ptr + value_offsets)
28
+ cache_offsets = slot * D + tl.arange(0, D)
29
+ tl.store(k_cache_ptr + cache_offsets, key)
30
+ tl.store(v_cache_ptr + cache_offsets, value)
31
+
32
+
33
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
34
+ N, num_heads, head_dim = key.shape
35
+ D = num_heads * head_dim
36
+ assert key.stride(-1) == 1 and value.stride(-1) == 1
37
+ assert key.stride(1) == head_dim and value.stride(1) == head_dim
38
+ assert k_cache.stride(1) == D and v_cache.stride(1) == D
39
+ assert slot_mapping.numel() == N
40
+ store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
41
+
42
+
43
+ class Attention(nn.Module):
44
+
45
+ def __init__(
46
+ self,
47
+ num_heads,
48
+ head_dim,
49
+ scale,
50
+ num_kv_heads,
51
+ ):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ self.head_dim = head_dim
55
+ self.scale = scale
56
+ self.num_kv_heads = num_kv_heads
57
+ self.k_cache = self.v_cache = torch.tensor([])
58
+
59
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
60
+ context = get_context()
61
+ k_cache, v_cache = self.k_cache, self.v_cache
62
+ if k_cache.numel() and v_cache.numel():
63
+ store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
64
+ if context.is_prefill:
65
+ if context.block_tables is not None: # prefix cache
66
+ k, v = k_cache, v_cache
67
+ o = flash_attn_varlen_func(q, k, v,
68
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
69
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
70
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
71
+ else: # decode
72
+ o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
73
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
74
+ softmax_scale=self.scale, causal=True)
75
+ return o
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from nanovllm.utils.context import get_context
7
+
8
+
9
+ class VocabParallelEmbedding(nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ num_embeddings: int,
14
+ embedding_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.tp_rank = dist.get_rank()
18
+ self.tp_size = dist.get_world_size()
19
+ assert num_embeddings % self.tp_size == 0
20
+ self.num_embeddings = num_embeddings
21
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
22
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
23
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
24
+ self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
25
+ self.weight.weight_loader = self.weight_loader
26
+
27
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
28
+ param_data = param.data
29
+ shard_size = param_data.size(0)
30
+ start_idx = self.tp_rank * shard_size
31
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
32
+ param_data.copy_(loaded_weight)
33
+
34
+ def forward(self, x: torch.Tensor):
35
+ if self.tp_size > 1:
36
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
37
+ x = mask * (x - self.vocab_start_idx)
38
+ y = F.embedding(x, self.weight)
39
+ if self.tp_size > 1:
40
+ y = mask.unsqueeze(1) * y
41
+ dist.all_reduce(y)
42
+ return y
43
+
44
+
45
+ class ParallelLMHead(VocabParallelEmbedding):
46
+
47
+ def __init__(
48
+ self,
49
+ num_embeddings: int,
50
+ embedding_dim: int,
51
+ bias: bool = False,
52
+ ):
53
+ assert not bias
54
+ super().__init__(num_embeddings, embedding_dim)
55
+
56
+ def forward(self, x: torch.Tensor):
57
+ context = get_context()
58
+ if context.is_prefill:
59
+ last_indices = context.cu_seqlens_q[1:] - 1
60
+ x = x[last_indices].contiguous()
61
+ logits = F.linear(x, self.weight)
62
+ if self.tp_size > 1:
63
+ all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
64
+ dist.gather(logits, all_logits, 0)
65
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
66
+ return logits
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+
7
+ def __init__(
8
+ self,
9
+ hidden_size: int,
10
+ eps: float = 1e-6,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.eps = eps
14
+ self.weight = nn.Parameter(torch.ones(hidden_size))
15
+
16
+ @torch.compile
17
+ def rms_forward(
18
+ self,
19
+ x: torch.Tensor,
20
+ ) -> torch.Tensor:
21
+ orig_dtype = x.dtype
22
+ x = x.float()
23
+ var = x.pow(2).mean(dim=-1, keepdim=True)
24
+ x.mul_(torch.rsqrt(var + self.eps))
25
+ x = x.to(orig_dtype).mul_(self.weight)
26
+ return x
27
+
28
+ @torch.compile
29
+ def add_rms_forward(
30
+ self,
31
+ x: torch.Tensor,
32
+ residual: torch.Tensor,
33
+ ) -> tuple[torch.Tensor, torch.Tensor]:
34
+ orig_dtype = x.dtype
35
+ x = x.float().add_(residual.float())
36
+ residual = x.to(orig_dtype)
37
+ var = x.pow(2).mean(dim=-1, keepdim=True)
38
+ x.mul_(torch.rsqrt(var + self.eps))
39
+ x = x.to(orig_dtype).mul_(self.weight)
40
+ return x, residual
41
+
42
+ def forward(
43
+ self,
44
+ x: torch.Tensor,
45
+ residual: torch.Tensor | None = None,
46
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
47
+ if residual is None:
48
+ return self.rms_forward(x)
49
+ else:
50
+ return self.add_rms_forward(x, residual)
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+
7
+ def divide(numerator, denominator):
8
+ assert numerator % denominator == 0
9
+ return numerator // denominator
10
+
11
+
12
+ class LinearBase(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ input_size: int,
17
+ output_size: int,
18
+ bias: bool = False,
19
+ tp_dim: int | None = None,
20
+ ):
21
+ super().__init__()
22
+ self.tp_dim = tp_dim
23
+ self.tp_rank = dist.get_rank()
24
+ self.tp_size = dist.get_world_size()
25
+ self.weight = nn.Parameter(torch.empty(output_size, input_size))
26
+ self.weight.weight_loader = self.weight_loader
27
+ if bias:
28
+ self.bias = nn.Parameter(torch.empty(output_size))
29
+ self.bias.weight_loader = self.weight_loader
30
+ else:
31
+ self.register_parameter("bias", None)
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ raise NotImplementedError
35
+
36
+
37
+ class ReplicatedLinear(LinearBase):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int,
43
+ bias: bool = False,
44
+ ):
45
+ super().__init__(input_size, output_size, bias)
46
+
47
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
48
+ param.data.copy_(loaded_weight)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return F.linear(x, self.weight, self.bias)
52
+
53
+
54
+ class ColumnParallelLinear(LinearBase):
55
+
56
+ def __init__(
57
+ self,
58
+ input_size: int,
59
+ output_size: int,
60
+ bias: bool = False,
61
+ ):
62
+ tp_size = dist.get_world_size()
63
+ super().__init__(input_size, divide(output_size, tp_size), bias, 0)
64
+
65
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
66
+ param_data = param.data
67
+ shard_size = param_data.size(self.tp_dim)
68
+ start_idx = self.tp_rank * shard_size
69
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
70
+ param_data.copy_(loaded_weight)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return F.linear(x, self.weight, self.bias)
74
+
75
+
76
+ class MergedColumnParallelLinear(ColumnParallelLinear):
77
+
78
+ def __init__(
79
+ self,
80
+ input_size: int,
81
+ output_sizes: list[int],
82
+ bias: bool = False,
83
+ ):
84
+ self.output_sizes = output_sizes
85
+ super().__init__(input_size, sum(output_sizes), bias)
86
+
87
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
88
+ param_data = param.data
89
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
90
+ shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
91
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
92
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
93
+ param_data.copy_(loaded_weight)
94
+
95
+
96
+ class QKVParallelLinear(ColumnParallelLinear):
97
+
98
+ def __init__(
99
+ self,
100
+ hidden_size: int,
101
+ head_size: int,
102
+ total_num_heads: int,
103
+ total_num_kv_heads: int | None = None,
104
+ bias: bool = False,
105
+ ):
106
+ tp_size = dist.get_world_size()
107
+ total_num_kv_heads = total_num_kv_heads or total_num_heads
108
+ self.head_size = head_size
109
+ self.num_heads = divide(total_num_heads, tp_size)
110
+ self.num_kv_heads = divide(total_num_kv_heads, tp_size)
111
+ output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
112
+ super().__init__(hidden_size, output_size, bias)
113
+
114
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
115
+ param_data = param.data
116
+ assert loaded_shard_id in ["q", "k", "v"]
117
+ if loaded_shard_id == "q":
118
+ shard_size = self.num_heads * self.head_size
119
+ shard_offset = 0
120
+ elif loaded_shard_id == "k":
121
+ shard_size = self.num_kv_heads * self.head_size
122
+ shard_offset = self.num_heads * self.head_size
123
+ else:
124
+ shard_size = self.num_kv_heads * self.head_size
125
+ shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
126
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
127
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
128
+ param_data.copy_(loaded_weight)
129
+
130
+
131
+ class RowParallelLinear(LinearBase):
132
+
133
+ def __init__(
134
+ self,
135
+ input_size: int,
136
+ output_size: int,
137
+ bias: bool = False,
138
+ ):
139
+ tp_size = dist.get_world_size()
140
+ super().__init__(divide(input_size, tp_size), output_size, bias, 1)
141
+
142
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
143
+ param_data = param.data
144
+ shard_size = param_data.size(self.tp_dim)
145
+ start_idx = self.tp_rank * shard_size
146
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
147
+ param_data.copy_(loaded_weight)
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
151
+ if self.tp_size > 1:
152
+ dist.all_reduce(y)
153
+ return y
acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py ADDED
@@ -0,0 +1,61 @@
1
+ from functools import lru_cache
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ def apply_rotary_emb(
7
+ x: torch.Tensor,
8
+ cos: torch.Tensor,
9
+ sin: torch.Tensor,
10
+ ) -> torch.Tensor:
11
+ x1, x2 = torch.chunk(x.float(), 2, dim=-1)
12
+ y1 = x1 * cos - x2 * sin
13
+ y2 = x2 * cos + x1 * sin
14
+ return torch.cat((y1, y2), dim=-1).to(x.dtype)
15
+
16
+
17
+ class RotaryEmbedding(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ head_size: int,
22
+ rotary_dim: int,
23
+ max_position_embeddings: int,
24
+ base: float,
25
+ ) -> None:
26
+ super().__init__()
27
+ self.head_size = head_size
28
+ assert rotary_dim == head_size
29
+ inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
30
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
31
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
32
+ cos = freqs.cos()
33
+ sin = freqs.sin()
34
+ cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
35
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
36
+
37
+ @torch.compile
38
+ def forward(
39
+ self,
40
+ positions: torch.Tensor,
41
+ query: torch.Tensor,
42
+ key: torch.Tensor,
43
+ ) -> tuple[torch.Tensor, torch.Tensor]:
44
+ cos_sin = self.cos_sin_cache[positions]
45
+ cos, sin = cos_sin.chunk(2, dim=-1)
46
+ query = apply_rotary_emb(query, cos, sin)
47
+ key = apply_rotary_emb(key, cos, sin)
48
+ return query, key
49
+
50
+
51
+ @lru_cache(1)
52
+ def get_rope(
53
+ head_size: int,
54
+ rotary_dim: int,
55
+ max_position: int,
56
+ base: float,
57
+ rope_scaling: dict | None = None,
58
+ ):
59
+ assert rope_scaling is None
60
+ rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
61
+ return rotary_emb
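Note (not part of the diff): apply_rotary_emb above rotates each (x1, x2) pair, so it preserves the norm over the head dimension; the shapes here are illustrative.

import torch
from nanovllm.layers.rotary_embedding import apply_rotary_emb

x = torch.randn(4, 2, 64)                 # (tokens, heads, head_dim)
angles = torch.rand(4, 1, 32)             # broadcast over heads, half of head_dim
y = apply_rotary_emb(x, angles.cos(), angles.sin())
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)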
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class Sampler(nn.Module):
6
+
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ @torch.compile
11
+ def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
12
+ logits = logits.float().div_(temperatures.unsqueeze(dim=1))
13
+ probs = torch.softmax(logits, dim=-1)
14
+ sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
15
+ return sample_tokens
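Note (not part of the diff): dividing probabilities by i.i.d. Exp(1) noise and taking the argmax, as Sampler.forward does, samples index i with probability probs[i] (the Gumbel-max trick in exponential form). A quick empirical check:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])
counts = torch.zeros(3)
for _ in range(20000):
    noise = torch.empty_like(probs).exponential_(1)
    counts[(probs / noise).argmax()] += 1
print(counts / counts.sum())              # close to tensor([0.1, 0.6, 0.3])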
acestep/third_parts/nano-vllm/nanovllm/llm.py ADDED
@@ -0,0 +1,5 @@
1
+ from nanovllm.engine.llm_engine import LLMEngine
2
+
3
+
4
+ class LLM(LLMEngine):
5
+ pass
acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py ADDED
@@ -0,0 +1,215 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.distributed as dist
4
+ from transformers import Qwen3Config
5
+
6
+ from nanovllm.layers.activation import SiluAndMul
7
+ from nanovllm.layers.attention import Attention
8
+ from nanovllm.layers.layernorm import RMSNorm
9
+ from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
10
+ from nanovllm.layers.rotary_embedding import get_rope
11
+ from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
12
+
13
+
14
+ class Qwen3Attention(nn.Module):
15
+
16
+ def __init__(
17
+ self,
18
+ hidden_size: int,
19
+ num_heads: int,
20
+ num_kv_heads: int,
21
+ max_position: int = 4096 * 32,
22
+ head_dim: int | None = None,
23
+ rms_norm_eps: float = 1e-06,
24
+ qkv_bias: bool = False,
25
+ rope_theta: float = 10000,
26
+ rope_scaling: tuple | None = None,
27
+ ) -> None:
28
+ super().__init__()
29
+ tp_size = dist.get_world_size()
30
+ self.total_num_heads = num_heads
31
+ assert self.total_num_heads % tp_size == 0
32
+ self.num_heads = self.total_num_heads // tp_size
33
+ self.total_num_kv_heads = num_kv_heads
34
+ assert self.total_num_kv_heads % tp_size == 0
35
+ self.num_kv_heads = self.total_num_kv_heads // tp_size
36
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
37
+ self.q_size = self.num_heads * self.head_dim
38
+ self.kv_size = self.num_kv_heads * self.head_dim
39
+ self.scaling = self.head_dim ** -0.5
40
+ self.qkv_bias = qkv_bias
41
+
42
+ self.qkv_proj = QKVParallelLinear(
43
+ hidden_size,
44
+ self.head_dim,
45
+ self.total_num_heads,
46
+ self.total_num_kv_heads,
47
+ bias=qkv_bias,
48
+ )
49
+ self.o_proj = RowParallelLinear(
50
+ self.total_num_heads * self.head_dim,
51
+ hidden_size,
52
+ bias=False,
53
+ )
54
+ self.rotary_emb = get_rope(
55
+ self.head_dim,
56
+ rotary_dim=self.head_dim,
57
+ max_position=max_position,
58
+ base=rope_theta,
59
+ rope_scaling=rope_scaling,
60
+ )
61
+ self.attn = Attention(
62
+ self.num_heads,
63
+ self.head_dim,
64
+ self.scaling,
65
+ self.num_kv_heads,
66
+ )
67
+ if not self.qkv_bias:
68
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
69
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
70
+
71
+ def forward(
72
+ self,
73
+ positions: torch.Tensor,
74
+ hidden_states: torch.Tensor,
75
+ ) -> torch.Tensor:
76
+ qkv = self.qkv_proj(hidden_states)
77
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
78
+ q = q.view(-1, self.num_heads, self.head_dim)
79
+ k = k.view(-1, self.num_kv_heads, self.head_dim)
80
+ v = v.view(-1, self.num_kv_heads, self.head_dim)
81
+ if not self.qkv_bias:
82
+ q = self.q_norm(q)
83
+ k = self.k_norm(k)
84
+ q, k = self.rotary_emb(positions, q, k)
85
+ o = self.attn(q, k, v)
86
+ output = self.o_proj(o.flatten(1, -1))
87
+ return output
88
+
89
+
90
+ class Qwen3MLP(nn.Module):
91
+
92
+ def __init__(
93
+ self,
94
+ hidden_size: int,
95
+ intermediate_size: int,
96
+ hidden_act: str,
97
+ ) -> None:
98
+ super().__init__()
99
+ self.gate_up_proj = MergedColumnParallelLinear(
100
+ hidden_size,
101
+ [intermediate_size] * 2,
102
+ bias=False,
103
+ )
104
+ self.down_proj = RowParallelLinear(
105
+ intermediate_size,
106
+ hidden_size,
107
+ bias=False,
108
+ )
109
+ assert hidden_act == "silu"
110
+ self.act_fn = SiluAndMul()
111
+
112
+ def forward(self, x):
113
+ gate_up = self.gate_up_proj(x)
114
+ x = self.act_fn(gate_up)
115
+ x = self.down_proj(x)
116
+ return x
117
+
118
+
119
+ class Qwen3DecoderLayer(nn.Module):
120
+
121
+ def __init__(
122
+ self,
123
+ config: Qwen3Config,
124
+ ) -> None:
125
+ super().__init__()
126
+ self.self_attn = Qwen3Attention(
127
+ hidden_size=config.hidden_size,
128
+ num_heads=config.num_attention_heads,
129
+ num_kv_heads=config.num_key_value_heads,
130
+ max_position=config.max_position_embeddings,
131
+ rms_norm_eps=config.rms_norm_eps,
132
+ qkv_bias=getattr(config, 'attention_bias', True),
133
+ head_dim=getattr(config, 'head_dim', None),
134
+ rope_theta=getattr(config, "rope_theta", 1000000),
135
+ rope_scaling=getattr(config, "rope_scaling", None),
136
+ )
137
+ self.mlp = Qwen3MLP(
138
+ hidden_size=config.hidden_size,
139
+ intermediate_size=config.intermediate_size,
140
+ hidden_act=config.hidden_act,
141
+ )
142
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
143
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
144
+
145
+ def forward(
146
+ self,
147
+ positions: torch.Tensor,
148
+ hidden_states: torch.Tensor,
149
+ residual: torch.Tensor | None,
150
+ ) -> tuple[torch.Tensor, torch.Tensor]:
151
+ if residual is None:
152
+ hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
153
+ else:
154
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
155
+ hidden_states = self.self_attn(positions, hidden_states)
156
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
157
+ hidden_states = self.mlp(hidden_states)
158
+ return hidden_states, residual
159
+
160
+
161
+ class Qwen3Model(nn.Module):
162
+
163
+ def __init__(
164
+ self,
165
+ config: Qwen3Config,
166
+ ) -> None:
167
+ super().__init__()
168
+ self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
169
+ self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
170
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
171
+
172
+ def forward(
173
+ self,
174
+ input_ids: torch.Tensor,
175
+ positions: torch.Tensor,
176
+ ) -> torch.Tensor:
177
+ hidden_states = self.embed_tokens(input_ids)
178
+ residual = None
179
+ for layer in self.layers:
180
+ hidden_states, residual = layer(positions, hidden_states, residual)
181
+ hidden_states, _ = self.norm(hidden_states, residual)
182
+ return hidden_states
183
+
184
+
185
+ class Qwen3ForCausalLM(nn.Module):
186
+ packed_modules_mapping = {
187
+ "q_proj": ("qkv_proj", "q"),
188
+ "k_proj": ("qkv_proj", "k"),
189
+ "v_proj": ("qkv_proj", "v"),
190
+ "gate_proj": ("gate_up_proj", 0),
191
+ "up_proj": ("gate_up_proj", 1),
192
+ }
193
+
194
+ def __init__(
195
+ self,
196
+ config: Qwen3Config
197
+ ) -> None:
198
+ super().__init__()
199
+ self.model = Qwen3Model(config)
200
+ self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
201
+ if config.tie_word_embeddings:
202
+ self.lm_head.weight.data = self.model.embed_tokens.weight.data
203
+
204
+ def forward(
205
+ self,
206
+ input_ids: torch.Tensor,
207
+ positions: torch.Tensor,
208
+ ) -> torch.Tensor:
209
+ return self.model(input_ids, positions)
210
+
211
+ def compute_logits(
212
+ self,
213
+ hidden_states: torch.Tensor,
214
+ ) -> torch.Tensor:
215
+ return self.lm_head(hidden_states)
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py ADDED
@@ -0,0 +1,13 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ @dataclass
5
+ class SamplingParams:
6
+ temperature: float = 1.0
7
+ max_tokens: int = 64
8
+ ignore_eos: bool = False
9
+ cfg_scale: float = 1.0 # CFG guidance scale. When > 1.0, applies classifier-free guidance
10
+
11
+ def __post_init__(self):
12
+ assert self.temperature > 1e-10, "greedy sampling is not permitted"
13
+ assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
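Note (not part of the diff): both asserts above run in __post_init__, so invalid combinations fail at construction time.

from nanovllm.sampling_params import SamplingParams

sp = SamplingParams(temperature=0.7, max_tokens=128, cfg_scale=1.5)   # valid
try:
    SamplingParams(cfg_scale=0.5)          # rejected: cfg_scale must be >= 1.0
except AssertionError as e:
    print(e)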
acestep/third_parts/nano-vllm/nanovllm/utils/context.py ADDED
@@ -0,0 +1,27 @@
1
+ from dataclasses import dataclass
2
+ import torch
3
+
4
+
5
+ @dataclass
6
+ class Context:
7
+ is_prefill: bool = False
8
+ cu_seqlens_q: torch.Tensor | None = None
9
+ cu_seqlens_k: torch.Tensor | None = None
10
+ max_seqlen_q: int = 0
11
+ max_seqlen_k: int = 0
12
+ slot_mapping: torch.Tensor | None = None
13
+ context_lens: torch.Tensor | None = None
14
+ block_tables: torch.Tensor | None = None
15
+
16
+ _CONTEXT = Context()
17
+
18
+ def get_context():
19
+ return _CONTEXT
20
+
21
+ def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
22
+ global _CONTEXT
23
+ _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
24
+
25
+ def reset_context():
26
+ global _CONTEXT
27
+ _CONTEXT = Context()
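Note (not part of the diff): the module-level context is a simple set/get/reset round trip, which is how the model runner hands attention metadata to the layers.

from nanovllm.utils.context import set_context, get_context, reset_context

set_context(True, max_seqlen_q=8)
assert get_context().is_prefill and get_context().max_seqlen_q == 8
reset_context()
assert not get_context().is_prefill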
acestep/third_parts/nano-vllm/nanovllm/utils/loader.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ from glob import glob
3
+ import torch
4
+ from torch import nn
5
+ from safetensors import safe_open
6
+
7
+
8
+ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
9
+ param.data.copy_(loaded_weight)
10
+
11
+
12
+ def load_model(model: nn.Module, path: str):
13
+ packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
14
+ for file in glob(os.path.join(path, "*.safetensors")):
15
+ with safe_open(file, "pt", "cpu") as f:
16
+ for weight_name in f.keys():
17
+ for k in packed_modules_mapping:
18
+ if k in weight_name:
19
+ v, shard_id = packed_modules_mapping[k]
20
+ param_name = weight_name.replace(k, v)
21
+ param = model.get_parameter(param_name)
22
+ weight_loader = getattr(param, "weight_loader")
23
+ weight_loader(param, f.get_tensor(weight_name), shard_id)
24
+ break
25
+ else:
26
+ param = model.get_parameter(weight_name)
27
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
28
+ weight_loader(param, f.get_tensor(weight_name))
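Note (not part of the diff): an illustration of the name rewriting load_model performs via packed_modules_mapping; the names below are examples only. A checkpoint tensor for q_proj is routed into the fused qkv_proj parameter under shard id "q".

mapping = {"q_proj": ("qkv_proj", "q"), "gate_proj": ("gate_up_proj", 0)}
name = "model.layers.0.self_attn.q_proj.weight"
for k, (v, shard_id) in mapping.items():
    if k in name:
        print(name.replace(k, v), shard_id)   # model.layers.0.self_attn.qkv_proj.weight q
        break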
acestep/third_parts/nano-vllm/pyproject.toml ADDED
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "nano-vllm"
7
+ version = "0.2.0"
8
+ authors = [{ name = "Xingkai Yu" }]
9
+ license = "MIT"
10
+ license-files = ["LICENSE"]
11
+ readme = "README.md"
12
+ description = "a lightweight vLLM implementation built from scratch"
13
+ requires-python = ">=3.10,<3.13"
14
+ dependencies = [
15
+ "torch>=2.4.0",
16
+ "triton>=3.0.0",
17
+ "transformers>=4.51.0",
18
+ "flash-attn",
19
+ "xxhash",
20
+ ]
21
+
22
+ [project.urls]
23
+ Homepage="https://github.com/GeeeekExplorer/nano-vllm"
24
+
25
+ [tool.setuptools.packages.find]
26
+ where = ["."]
27
+ include = ["nanovllm*"]
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ torch
2
+ transformers
3
+ diffusers
4
+ gradio