Spaces: Running on A100

Gong Junmin committed
Commit · 11a221a
1 Parent(s): 509f9f2

first commit
Browse files
- .gitignore +4 -0
- LICENSE +246 -201
- acestep/acestep_v15_pipeline.py +67 -0
- acestep/gradio_ui.py +744 -0
- acestep/handler.py +1100 -0
- acestep/third_parts/nano-vllm/LICENSE +21 -0
- acestep/third_parts/nano-vllm/README.md +66 -0
- acestep/third_parts/nano-vllm/assets/logo.png +3 -0
- acestep/third_parts/nano-vllm/bench.py +32 -0
- acestep/third_parts/nano-vllm/example.py +33 -0
- acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
- acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +112 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +120 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +315 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +222 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +89 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py +50 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/linear.py +153 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py +61 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +15 -0
- acestep/third_parts/nano-vllm/nanovllm/llm.py +5 -0
- acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py +215 -0
- acestep/third_parts/nano-vllm/nanovllm/sampling_params.py +13 -0
- acestep/third_parts/nano-vllm/nanovllm/utils/context.py +27 -0
- acestep/third_parts/nano-vllm/nanovllm/utils/loader.py +28 -0
- acestep/third_parts/nano-vllm/pyproject.toml +27 -0
- requirements.txt +4 -0
.gitignore CHANGED
@@ -205,3 +205,7 @@ cython_debug/
 marimo/_static/
 marimo/_lsp/
 __marimo__/
+tests/
+checkpoints/
+playground.ipynb
+.history/
LICENSE CHANGED
@@ -1,201 +1,246 @@
(The previous 201-line license is removed in full; its contents are not rendered in this view. The new license text follows.)

================================================================================
ACE-STEP LICENSE
Version 1.0
================================================================================

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

--------------------------------------------------------------------------------

1. DEFINITIONS

"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 11 of this document.

"Licensor" shall mean ACE Studio (or the copyright owner/entity authorized
by ACE Studio) that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other
entities that control, are controlled by, or are under common control with
that entity. For the purposes of this definition, "control" means (i) the
power, direct or indirect, to cause the direction or management of such
entity, whether by contract or otherwise, or (ii) ownership of fifty
percent (50%) or more of the outstanding shares, or (iii) beneficial
ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation source,
configuration files, and model training code.

"Object" form shall mean any form resulting from mechanical transformation
or translation of a Source form, including but not limited to compiled
object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form,
made available under the License, as indicated by a copyright notice that
is included in or attached to the work. For the avoidance of doubt, "Work"
explicitly includes the Model Weights, parameters, and configuration files
provided by the Licensor.

"Derivative Works" shall mean any work, whether in Source or Object form,
that is based on (or derived from) the Work and for which the editorial
revisions, annotations, elaborations, or other modifications represent, as
a whole, an original work of authorship. For the purposes of this License,
"Derivative Works" shall explicitly include "Derivative Models," defined as
any modifications to the Model Weights, including but not limited to
Fine-tunes, LoRAs (Low-Rank Adaptation), adapters, and other distinct
parameter sets derived from the Work.

"Output" shall mean any audio, music, sound recordings, or data generated
by the use or execution of the Work or Derivative Works.

"Contribution" shall mean any work of authorship, including the original
version of the Work and any modifications or additions to that Work or
Derivative Works thereof, that is intentionally submitted to Licensor for
inclusion in the Work by the copyright owner or by an individual or Legal
Entity authorized to submit on behalf of the copyright owner.

"Contributor" shall mean Licensor and any individual or Legal Entity on
behalf of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.

--------------------------------------------------------------------------------

2. GRANT OF COPYRIGHT LICENSE

Subject to the terms and conditions of this License (including the specific
restrictions in Section 5 and Section 6), each Contributor hereby grants to
You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the Work and
such Derivative Works in Source or Object form.

--------------------------------------------------------------------------------

3. GRANT OF PATENT LICENSE

Subject to the terms and conditions of this License, each Contributor
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable (except as stated in this section) patent license
to make, have made, use, offer to sell, sell, import, and otherwise transfer
the Work, where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their Contribution(s)
alone or by combination of their Contribution(s) with the Work to which such
Contribution(s) was submitted.

If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this
License for that Work shall terminate as of the date such litigation is
filed.

--------------------------------------------------------------------------------

4. REDISTRIBUTION

You may reproduce and distribute copies of the Work or Derivative Works
thereof in any medium, with or without modifications, and in Source or
Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a
    copy of this License; and

(b) You must cause any modified files to carry prominent notices stating
    that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You
    distribute, all copyright, patent, trademark, and attribution notices
    from the Source form of the Work, excluding those notices that do not
    pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution,
    then any Derivative Works that You distribute must include a readable
    copy of the attribution notices contained within such NOTICE file,
    excluding those notices that do not pertain to any part of the
    Derivative Works, in at least one of the following places: within a
    NOTICE text file distributed as part of the Derivative Works; within
    the Source form or documentation, if provided along with the Derivative
    Works; or, within a display generated by the Derivative Works, if and
    wherever such third-party notices normally appear. The contents of the
    NOTICE file are for informational purposes only and do not modify the
    License. You may add Your own attribution notices within Derivative
    Works that You distribute, alongside or as an addendum to the NOTICE
    text from the Work, provided that such additional attribution notices
    cannot be construed as modifying the License.

(e) Community Contribution Requirement: If You create a Derivative Work
    (specifically a Derivative Model, such as a Fine-tune or LoRA) and You
    distribute it, publicly perform it, or use it to generate publicly
    available Output, You must make the Source form (including the weights,
    parameters, and training configuration) of said Derivative Work publicly
    available under the terms of this License. This clause ensures that
    improvements to the model remain accessible to the community.

--------------------------------------------------------------------------------

5. RESTRICTIONS ON COMMERCIAL SERVICES (THE "ANTI-SAAS" CLAUSE)

Notwithstanding the grants in Section 2, You are strictly prohibited from
using the Work or Derivative Works to operate, promote, or distribute a
commercial service where the primary value provided to users is the ability
to generate Outputs using the Work. This includes, but is not limited to:

(a) Offering the Work as a hosted Application Programming Interface (API);

(b) Offering the Work as a Software-as-a-Service (SaaS) product;

(c) Integrating the Work into a commercial platform that charges users
    specifically for generation capabilities.

Note: You may use the Work to develop independent software tools, plugins,
or applications (e.g., local plugins for DAWs), provided that such tools
run locally on end-users' hardware and do not violate the restrictions on
hosted commercial generation services defined above.

--------------------------------------------------------------------------------

6. OWNERSHIP AND COMMERCIALIZATION OF OUTPUTS

(a) Personal and Non-Commercial Use: You are free to use Outputs generated
    by the Work for personal use, research, educational purposes, and
    non-commercial creative projects (e.g., background music for personal
    videos) without restriction.

(b) Commercial Use and Verification: The Licensor provides the Work without
    any default warranty of copyright ownership for raw Outputs. To use
    Outputs for Commercial Purposes (including but not limited to
    distributing to music streaming platforms such as Spotify, Apple Music,
    or YouTube Music, or registering Content ID), You must obtain Proof of
    Human Creation or authorization through ACE Studio (or channels
    officially designated by the Licensor). "Commercial Purposes" implies
    intent to profit from the direct exploitation of the generated audio.

(c) Prohibition on Mass Generation: You are expressly prohibited from using
    the Work to generate and distribute mass quantities of content
    ("spamming") for the purpose of flooding streaming services,
    manipulating royalty systems, or engaging in automated content farming.

--------------------------------------------------------------------------------

7. SUBMISSION OF CONTRIBUTIONS

Unless You explicitly state otherwise, any Contribution intentionally
submitted for inclusion in the Work by You to the Licensor shall be under
the terms and conditions of this License, without any additional terms or
conditions. Notwithstanding the above, nothing herein shall supersede or
modify the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

--------------------------------------------------------------------------------

8. TRADEMARKS

This License does not grant permission to use the trade names, trademarks,
service marks, or product names of the Licensor, except as required for
reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.

--------------------------------------------------------------------------------

9. DISCLAIMER OF WARRANTY

Unless required by applicable law or agreed to in writing, Licensor provides
the Work (and each Contributor provides its Contributions) on an "AS IS"
BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions of
TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR
PURPOSE. You are solely responsible for determining the appropriateness of
using or redistributing the Work and assume any risks associated with Your
exercise of permissions under this License.

--------------------------------------------------------------------------------

10. LIMITATION OF LIABILITY

In no event and under no legal theory, whether in tort (including
negligence), contract, or otherwise, unless required by applicable law
(such as deliberate and grossly negligent acts) or agreed to in writing,
shall any Contributor be liable to You for damages, including any direct,
indirect, special, incidental, or consequential damages of any character
arising as a result of this License or out of the use or inability to use
the Work (including but not limited to damages for loss of goodwill, work
stoppage, computer failure or malfunction, or any and all other commercial
damages or losses), even if such Contributor has been advised of the
possibility of such damages.

--------------------------------------------------------------------------------

11. ACCEPTING WARRANTY OR ADDITIONAL LIABILITY

While redistributing the Work or Derivative Works thereof, You may choose
to offer, and charge a fee for, acceptance of support, warranty,
indemnity, or other liability obligations and/or rights consistent with
this License. However, in accepting such obligations, You may act only on
Your own behalf and on Your sole responsibility, not on behalf of any
other Contributor, and only if You agree to indemnify, defend, and hold
each Contributor harmless for any liability incurred by, or claims asserted
against, such Contributor by reason of your accepting any such warranty or
additional liability.

================================================================================
END OF TERMS AND CONDITIONS
================================================================================
acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,67 @@

"""
ACE-Step V1.5 Pipeline
Handler wrapper connecting model and UI
"""
import os
import sys

# Clear proxy settings that may affect Gradio
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
    os.environ.pop(proxy_var, None)

os.environ['CUDA_VISIBLE_DEVICES'] = '7'  # Adjust as needed

from .handler import AceStepHandler
from .gradio_ui import create_gradio_interface


def create_demo():
    """
    Create Gradio demo interface

    Returns:
        Gradio Blocks instance
    """
    # Create handler instance (business logic processor)
    handler = AceStepHandler()

    # Create Gradio interface
    demo = create_gradio_interface(handler)

    return demo


def main():
    """Main entry function"""
    import argparse

    parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
    parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
    args = parser.parse_args()

    try:
        # Create and launch demo
        print("Creating Gradio interface...")
        demo = create_demo()
        print(f"Launching server on {args.server_name}:{args.port}...")
        demo.launch(
            server_name=args.server_name,
            server_port=args.port,
            share=args.share,
            debug=args.debug,
            show_error=True,
            prevent_thread_lock=False,  # Keep thread locked to maintain server running
            inbrowser=False,  # Don't auto-open browser
        )
    except Exception as e:
        print(f"Error launching Gradio: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
acestep/gradio_ui.py ADDED
@@ -0,0 +1,744 @@

"""
Gradio UI Components Module
Contains all Gradio interface component definitions and layouts
"""
import gradio as gr
from typing import Callable, Optional


def create_gradio_interface(handler) -> gr.Blocks:
    """
    Create Gradio interface

    Args:
        handler: Business logic handler instance

    Returns:
        Gradio Blocks instance
    """
    with gr.Blocks(
        title="ACE-Step V1.5 Demo",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            margin-bottom: 2rem;
        }
        .section-header {
            background: linear-gradient(90deg, #4CAF50, #45a049);
            color: white;
            padding: 10px;
            border-radius: 5px;
            margin: 10px 0;
        }
        """
    ) as demo:

        gr.HTML("""
        <div class="main-header">
            <h1>♪ACE-Step V1.5 Demo</h1>
            <p>Generate music from text captions and lyrics using diffusion models</p>
        </div>
        """)

        # Dataset Explorer Section
        dataset_section = create_dataset_section(handler)

        # Generation Section
        generation_section = create_generation_section(handler)

        # Results Section
        results_section = create_results_section(handler)

        # Connect event handlers
        setup_event_handlers(demo, handler, dataset_section, generation_section, results_section)

    return demo


def create_dataset_section(handler) -> dict:
    """Create dataset explorer section"""
    with gr.Group():
        gr.HTML('<div class="section-header"><h3>📊 Dataset Explorer</h3></div>')

        with gr.Row(equal_height=True):
            dataset_type = gr.Dropdown(
                choices=["train", "test"],
                value="train",
                label="Dataset",
                info="Choose dataset to explore",
                scale=2
            )
            import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)

            search_type = gr.Dropdown(
                choices=["keys", "idx", "random"],
                value="random",
                label="Search Type",
                info="How to find items",
                scale=1
            )
            search_value = gr.Textbox(
                label="Search Value",
                placeholder="Enter keys or index (leave empty for random)",
                info="Keys: exact match, Index: 0 to dataset size-1",
                scale=2
            )

        instruction_display = gr.Textbox(
            label="📝 Instruction",
            interactive=False,
            placeholder="No instruction available",
            lines=1
        )

        repaint_viz_plot = gr.Plot()

        with gr.Accordion("📋 Item Metadata (JSON)", open=False):
            item_info_json = gr.Code(
                label="Complete Item Information",
                language="json",
                interactive=False,
                lines=15
            )

        with gr.Row(equal_height=True):
            item_src_audio = gr.Audio(
                label="Source Audio",
                type="filepath",
                interactive=False,
                scale=8
            )
            get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)

        with gr.Row(equal_height=True):
            item_target_audio = gr.Audio(
                label="Target Audio",
                type="filepath",
                interactive=False,
                scale=8
            )
            item_refer_audio = gr.Audio(
                label="Reference Audio",
                type="filepath",
                interactive=False,
                scale=2
            )

        with gr.Row():
            use_src_checkbox = gr.Checkbox(
                label="Use Source Audio from Dataset",
                value=True,
                info="Check to use the source audio from dataset"
            )

        data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
        auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")

    return {
        "dataset_type": dataset_type,
        "import_dataset_btn": import_dataset_btn,
        "search_type": search_type,
        "search_value": search_value,
        "instruction_display": instruction_display,
        "repaint_viz_plot": repaint_viz_plot,
        "item_info_json": item_info_json,
        "item_src_audio": item_src_audio,
        "get_item_btn": get_item_btn,
        "item_target_audio": item_target_audio,
        "item_refer_audio": item_refer_audio,
        "use_src_checkbox": use_src_checkbox,
        "data_status": data_status,
        "auto_fill_btn": auto_fill_btn,
    }


def create_generation_section(handler) -> dict:
    """Create generation section"""
    with gr.Group():
        gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')

        # Service Configuration
        with gr.Accordion("🔧 Service Configuration", open=True) as service_config_accordion:
            with gr.Row():
                with gr.Column(scale=2):
                    checkpoint_dropdown = gr.Dropdown(
                        label="Checkpoint File",
                        choices=handler.get_available_checkpoints(),
                        value=None,
                        info="Select a trained model checkpoint file (full path or filename)"
                    )
                with gr.Column(scale=1):
                    refresh_btn = gr.Button("🔄 Refresh", size="sm")

            with gr.Row():
                # Get available acestep-v15- model list
                available_models = handler.get_available_acestep_v15_models()
                default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)

                config_path = gr.Dropdown(
                    label="Main Model Path",
                    choices=available_models,
                    value=default_model,
                    info="Select the model configuration directory (auto-scanned from checkpoints)"
                )
                device = gr.Dropdown(
                    choices=["auto", "cuda", "cpu"],
                    value="auto",
                    label="Device",
                    info="Processing device (auto-detect recommended)"
                )

            with gr.Row():
                # Get available 5Hz LM model list
                available_lm_models = handler.get_available_5hz_lm_models()
                default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)

                lm_model_path = gr.Dropdown(
                    label="5Hz LM Model Path",
                    choices=available_lm_models,
                    value=default_lm_model,
                    info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
                )
                init_llm_checkbox = gr.Checkbox(
                    label="Initialize 5Hz LM",
                    value=False,
                    info="Check to initialize 5Hz LM during service initialization"
                )

            with gr.Row():
                # Auto-detect flash attention availability
                flash_attn_available = handler.is_flash_attention_available()
                use_flash_attention_checkbox = gr.Checkbox(
                    label="Use Flash Attention",
                    value=flash_attn_available,
                    interactive=flash_attn_available,
                    info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
                )

            init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
            init_status = gr.Textbox(label="Status", interactive=False, lines=3)

        # Inputs
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("📝 Required Inputs", open=True):
                    # Task type
                    with gr.Row():
                        with gr.Column(scale=2):
                            task_type = gr.Dropdown(
                                choices=["text2music", "repaint", "cover", "extract", "lego", "complete"],
                                value="text2music",
                                label="Task Type",
                                info="Select the task type for generation",
                            )
                        with gr.Column(scale=8):
                            instruction_display_gen = gr.Textbox(
                                label="Instruction",
                                value="Fill the audio semantic mask based on the given conditions:",
                                interactive=False,
                                lines=1,
                                info="Instruction is automatically generated based on task type",
                            )

                    track_name = gr.Dropdown(
                        choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
                                 "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
                        value=None,
                        label="Track Name",
                        info="Select track name for lego/extract tasks",
                        visible=False
                    )

                    complete_track_classes = gr.CheckboxGroup(
                        choices=["woodwinds", "brass", "fx", "synth", "strings", "percussion",
                                 "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"],
                        label="Track Names",
                        info="Select multiple track classes for complete task",
                        visible=False
                    )

                    # Audio uploads
                    with gr.Accordion("🎵 Audio Uploads", open=False):
                        with gr.Row():
                            with gr.Column(scale=2):
                                reference_audio = gr.Audio(
                                    label="Reference Audio (optional)",
                                    type="filepath",
                                )
                            with gr.Column(scale=8):
                                src_audio = gr.Audio(
                                    label="Source Audio (optional)",
                                    type="filepath",
                                )

                        audio_code_string = gr.Textbox(
                            label="Audio Codes (optional)",
                            placeholder="<|audio_code_10695|><|audio_code_54246|>...",
                            lines=4,
                            visible=False,
                            info="Paste precomputed audio code tokens"
                        )

                    # Audio Codes for text2music
                    with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
                        text2music_audio_code_string = gr.Textbox(
                            label="Audio Codes",
                            placeholder="<|audio_code_10695|><|audio_code_54246|>...",
                            lines=6,
                            info="Paste precomputed audio code tokens for text2music generation"
                        )

                    # 5Hz LM
                    with gr.Row(visible=False) as use_5hz_lm_row:
                        use_5hz_lm_btn = gr.Button(
                            "Generate LM Hints",
                            variant="secondary",
                            size="lg",
                        )
                        lm_temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.0,
                            maximum=2.0,
                            value=0.7,
                            step=0.1,
                            scale=2,
                            info="Temperature for 5Hz LM sampling"
                        )

                    # Repainting controls
                    with gr.Group(visible=False) as repainting_group:
                        gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
                        with gr.Row():
                            repainting_start = gr.Number(
                                label="Repainting Start",
                                value=0.0,
                                step=0.1,
                            )
                            repainting_end = gr.Number(
                                label="Repainting End",
                                value=-1,
                                minimum=-1,
                                step=0.1,
                            )

                    # Audio Cover Strength
                    audio_cover_strength = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.01,
                        label="Audio Cover Strength",
                        info="Control how many denoising steps use cover mode",
                        visible=False
                    )

                    # Music Caption
                    with gr.Accordion("📝 Music Caption", open=True):
                        captions = gr.Textbox(
                            label="Music Caption (optional)",
                            placeholder="A peaceful acoustic guitar melody with soft vocals...",
                            lines=3,
                            info="Describe the style, genre, instruments, and mood"
                        )

                    # Lyrics
                    with gr.Accordion("📝 Lyrics", open=True):
                        lyrics = gr.Textbox(
                            label="Lyrics (optional)",
                            placeholder="[Verse 1]\nUnder the starry night\nI feel so alive...",
                            lines=8,
                            info="Song lyrics with structure"
                        )

                # Optional Parameters
                with gr.Accordion("⚙️ Optional Parameters", open=True):
                    with gr.Row():
                        vocal_language = gr.Dropdown(
                            choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
                            value="en",
                            label="Vocal Language (optional)",
                            allow_custom_value=True
                        )
                        bpm = gr.Number(
                            label="BPM (optional)",
                            value=None,
                            step=1,
                            info="leave empty for N/A"
                        )
                        key_scale = gr.Textbox(
                            label="Key/Scale (optional)",
                            placeholder="Leave empty for N/A",
                            value="",
                        )
                        time_signature = gr.Dropdown(
                            choices=["2", "3", "4", "N/A", ""],
                            value="4",
                            label="Time Signature (optional)",
                            allow_custom_value=True
                        )
                        audio_duration = gr.Number(
                            label="Audio Duration (seconds)",
                            value=-1,
                            minimum=-1,
                            maximum=600.0,
                            step=0.1,
                            info="Use -1 for random"
                        )
                        batch_size_input = gr.Number(
                            label="Batch Size",
                            value=1,
                            minimum=1,
                            maximum=8,
                            step=1,
                            info="Number of audio files to parallel generate"
                        )

                # Advanced Settings
                with gr.Accordion("🔧 Advanced Settings", open=False):
                    with gr.Row():
                        inference_steps = gr.Slider(
                            minimum=1,
                            maximum=8,
                            value=8,
                            step=1,
                            label="Inference Steps",
                            info="Turbo: max 8, Base: max 100"
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=15.0,
                            value=7.0,
                            step=0.1,
                            label="Guidance Scale",
                            info="Higher values follow text more closely",
                            visible=False
                        )
                        seed = gr.Textbox(
                            label="Seed",
                            value="-1",
                            info="Use comma-separated values for batches"
                        )
                        random_seed_checkbox = gr.Checkbox(
                            label="Random Seed",
                            value=True,
                            info="Enable to auto-generate seeds"
                        )

                    with gr.Row():
                        use_adg = gr.Checkbox(
                            label="Use ADG",
                            value=False,
                            info="Enable Angle Domain Guidance",
                            visible=False
                        )

                    with gr.Row():
                        cfg_interval_start = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.01,
                            label="CFG Interval Start",
                            visible=False
                        )
                        cfg_interval_end = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=1.0,
                            step=0.01,
                            label="CFG Interval End",
                            visible=False
                        )

                    with gr.Row():
                        audio_format = gr.Dropdown(
                            choices=["mp3", "flac"],
                            value="mp3",
                            label="Audio Format",
                            info="Audio format for saved files"
                        )

                generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=False)

    return {
        "checkpoint_dropdown": checkpoint_dropdown,
        "refresh_btn": refresh_btn,
        "config_path": config_path,
        "device": device,
        "init_btn": init_btn,
        "init_status": init_status,
        "lm_model_path": lm_model_path,
        "init_llm_checkbox": init_llm_checkbox,
        "use_flash_attention_checkbox": use_flash_attention_checkbox,
        "task_type": task_type,
        "instruction_display_gen": instruction_display_gen,
        "track_name": track_name,
        "complete_track_classes": complete_track_classes,
        "reference_audio": reference_audio,
        "src_audio": src_audio,
        "audio_code_string": audio_code_string,
        "text2music_audio_code_string": text2music_audio_code_string,
        "text2music_audio_codes_group": text2music_audio_codes_group,
        "use_5hz_lm_row": use_5hz_lm_row,
        "use_5hz_lm_btn": use_5hz_lm_btn,
        "lm_temperature": lm_temperature,
        "repainting_group": repainting_group,
        "repainting_start": repainting_start,
        "repainting_end": repainting_end,
        "audio_cover_strength": audio_cover_strength,
        "captions": captions,
        "lyrics": lyrics,
        "vocal_language": vocal_language,
        "bpm": bpm,
        "key_scale": key_scale,
        "time_signature": time_signature,
        "audio_duration": audio_duration,
        "batch_size_input": batch_size_input,
        "inference_steps": inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "random_seed_checkbox": random_seed_checkbox,
        "use_adg": use_adg,
        "cfg_interval_start": cfg_interval_start,
        "cfg_interval_end": cfg_interval_end,
        "audio_format": audio_format,
        "generate_btn": generate_btn,
    }


def create_results_section(handler) -> dict:
    """Create results display section"""
    with gr.Group():
        gr.HTML('<div class="section-header"><h3>🎧 Generated Results</h3></div>')

        status_output = gr.Textbox(label="Generation Status", interactive=False)

        with gr.Row():
            with gr.Column():
                generated_audio_1 = gr.Audio(
                    label="🎵 Generated Music (Sample 1)",
                    type="filepath",
                    interactive=False
                )
            with gr.Column():
                generated_audio_2 = gr.Audio(
                    label="🎵 Generated Music (Sample 2)",
                    type="filepath",
                    interactive=False
                )

        with gr.Accordion("📁 Batch Results & Generation Details", open=False):
            generated_audio_batch = gr.File(
                label="📁 All Generated Files (Download)",
                file_count="multiple",
                interactive=False
            )
            generation_info = gr.Markdown(label="Generation Details")

        gr.Markdown("### ⚖️ Alignment Preference Analysis")

        with gr.Row():
            with gr.Column():
                align_score_1 = gr.Textbox(label="Alignment Score (Sample 1)", interactive=False)
                align_text_1 = gr.Textbox(label="Lyric Timestamps (Sample 1)", interactive=False, lines=10)
                align_plot_1 = gr.Plot(label="Alignment Heatmap (Sample 1)")
            with gr.Column():
                align_score_2 = gr.Textbox(label="Alignment Score (Sample 2)", interactive=False)
                align_text_2 = gr.Textbox(label="Lyric Timestamps (Sample 2)", interactive=False, lines=10)
                align_plot_2 = gr.Plot(label="Alignment Heatmap (Sample 2)")

    return {
        "status_output": status_output,
        "generated_audio_1": generated_audio_1,
        "generated_audio_2": generated_audio_2,
        "generated_audio_batch": generated_audio_batch,
        "generation_info": generation_info,
        "align_score_1": align_score_1,
        "align_text_1": align_text_1,
        "align_plot_1": align_plot_1,
        "align_score_2": align_score_2,
        "align_text_2": align_text_2,
        "align_plot_2": align_plot_2,
    }


def setup_event_handlers(demo, handler, dataset_section, generation_section, results_section):
    """Setup event handlers connecting UI components and business logic"""

    def update_init_status(status_msg, enable_btn):
        """Update initialization status and enable/disable generate button"""
        return status_msg, gr.update(interactive=enable_btn)

    # Dataset handlers
    dataset_section["import_dataset_btn"].click(
        fn=handler.import_dataset,
        inputs=[dataset_section["dataset_type"]],
        outputs=[dataset_section["data_status"]]
    )

    # Service initialization - refresh checkpoints
    def refresh_checkpoints():
        choices = handler.get_available_checkpoints()
        return gr.update(choices=choices)

    generation_section["refresh_btn"].click(
        fn=refresh_checkpoints,
        outputs=[generation_section["checkpoint_dropdown"]]
    )

    # Update UI based on model type (turbo vs base)
    def update_model_type_settings(config_path):
        """Update UI settings based on model type"""
        if config_path is None:
            config_path = ""
        config_path_lower = config_path.lower()

        if "turbo" in config_path_lower:
            # Turbo model: max 8 steps, hide CFG/ADG
            return (
                gr.update(value=8, maximum=8, minimum=1),  # inference_steps
                gr.update(visible=False),  # guidance_scale
                gr.update(visible=False),  # use_adg
                gr.update(visible=False),  # cfg_interval_start
                gr.update(visible=False),  # cfg_interval_end
            )
        elif "base" in config_path_lower:
            # Base model: max 100 steps, show CFG/ADG
            return (
                gr.update(value=32, maximum=100, minimum=1),  # inference_steps
                gr.update(visible=True),  # guidance_scale
                gr.update(visible=True),  # use_adg
                gr.update(visible=True),  # cfg_interval_start
                gr.update(visible=True),  # cfg_interval_end
            )
        else:
            # Default to turbo settings
            return (
                gr.update(value=8, maximum=8, minimum=1),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
            )

    generation_section["config_path"].change(
        fn=update_model_type_settings,
        inputs=[generation_section["config_path"]],
        outputs=[
            generation_section["inference_steps"],
            generation_section["guidance_scale"],
            generation_section["use_adg"],
            generation_section["cfg_interval_start"],
            generation_section["cfg_interval_end"],
        ]
    )

    # Service initialization
    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention):
        """Wrapper for service initialization, returns status and button state"""
        status, enable = handler.initialize_service(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention)
        return status, gr.update(interactive=enable)

    generation_section["init_btn"].click(
        fn=init_service_wrapper,
        inputs=[
            generation_section["checkpoint_dropdown"],
            generation_section["config_path"],
            generation_section["device"],
            generation_section["init_llm_checkbox"],
            generation_section["lm_model_path"],
            generation_section["use_flash_attention_checkbox"],
        ],
        outputs=[generation_section["init_status"], generation_section["generate_btn"]]
    )

    # Generation with progress bar
    def generate_with_progress(
        captions, lyrics, bpm, key_scale, time_signature, vocal_language,
        inference_steps, guidance_scale, random_seed_checkbox, seed,
        reference_audio, audio_duration, batch_size_input, src_audio,
        text2music_audio_code_string, repainting_start, repainting_end,
        instruction_display_gen, audio_cover_strength, task_type,
        use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
        progress=gr.Progress(track_tqdm=True)
    ):
        return handler.generate_music(
            captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
            time_signature=time_signature, vocal_language=vocal_language,
            inference_steps=inference_steps, guidance_scale=guidance_scale,
            use_random_seed=random_seed_checkbox, seed=seed,
            reference_audio=reference_audio, audio_duration=audio_duration,
            batch_size=batch_size_input, src_audio=src_audio,
            audio_code_string=text2music_audio_code_string,
            repainting_start=repainting_start, repainting_end=repainting_end,
            instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
            task_type=task_type, use_adg=use_adg,
            cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
            audio_format=audio_format, lm_temperature=lm_temperature,
            progress=progress
        )

    generation_section["generate_btn"].click(
        fn=generate_with_progress,
        inputs=[
            generation_section["captions"],
            generation_section["lyrics"],
            generation_section["bpm"],
            generation_section["key_scale"],
            generation_section["time_signature"],
            generation_section["vocal_language"],
            generation_section["inference_steps"],
            generation_section["guidance_scale"],
            generation_section["random_seed_checkbox"],
            generation_section["seed"],
            generation_section["reference_audio"],
            generation_section["audio_duration"],
            generation_section["batch_size_input"],
            generation_section["src_audio"],
            generation_section["text2music_audio_code_string"],
            generation_section["repainting_start"],
            generation_section["repainting_end"],
            generation_section["instruction_display_gen"],
            generation_section["audio_cover_strength"],
            generation_section["task_type"],
            generation_section["use_adg"],
            generation_section["cfg_interval_start"],
            generation_section["cfg_interval_end"],
            generation_section["audio_format"],
            generation_section["lm_temperature"]
        ],
        outputs=[
            results_section["generated_audio_1"],
            results_section["generated_audio_2"],
            results_section["generated_audio_batch"],
            results_section["generation_info"],
            results_section["status_output"],
            generation_section["seed"],
            results_section["align_score_1"],
            results_section["align_text_1"],
            results_section["align_plot_1"],
            results_section["align_score_2"],
            results_section["align_text_2"],
            results_section["align_plot_2"]
        ]
    )

    # 5Hz LM generation (simplified version, can be extended as needed)
    def generate_lm_hints_wrapper(caption, lyrics, temperature):
        """Wrapper for 5Hz LM generation"""
        metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature)
        # Return the formatted result; adjust the format as needed
        result_text = f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes[:200]}..." if len(audio_codes) > 200 else f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes}"
        return result_text

    generation_section["use_5hz_lm_btn"].click(
        fn=generate_lm_hints_wrapper,
        inputs=[
            generation_section["captions"],
            generation_section["lyrics"],
            generation_section["lm_temperature"]
        ],
        outputs=[generation_section["text2music_audio_code_string"]]
    )
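
The UI module above only assumes a handler object exposing the methods it calls (get_available_checkpoints, get_available_acestep_v15_models, get_available_5hz_lm_models, is_flash_attention_available, import_dataset, initialize_service, generate_with_5hz_lm, generate_music). A hypothetical stand-in like the sketch below, which mirrors those call sites but loads no models, is enough to smoke-test the layout; it is not the real AceStepHandler from acestep/handler.py.

# Hypothetical stub for exercising the Gradio layout only; return shapes follow
# the wrappers in gradio_ui.py, not the real implementation in acestep/handler.py.
class DummyHandler:
    def get_available_checkpoints(self):
        return []

    def get_available_acestep_v15_models(self):
        return ["acestep-v15-turbo"]

    def get_available_5hz_lm_models(self):
        return ["acestep-5Hz-lm-0.6B"]

    def is_flash_attention_available(self):
        return False

    def import_dataset(self, dataset_type):
        # Feeds the "Data Status" textbox
        return f"Stub: no dataset imported (requested: {dataset_type})"

    def initialize_service(self, checkpoint, config_path, device, init_llm,
                           lm_model_path, use_flash_attention):
        # init_service_wrapper expects (status message, enable-generate flag)
        return "Stub service initialized", True

    def generate_with_5hz_lm(self, caption, lyrics, temperature):
        # generate_lm_hints_wrapper expects (metadata, audio_codes, status)
        return "{}", "", "stub"

    def generate_music(self, **kwargs):
        # generate_btn.click maps the return value onto 12 outputs
        return (None, None, [], "", "stub run", kwargs.get("seed", "-1"),
                "", "", None, "", "", None)

With such a stub, create_gradio_interface(DummyHandler()).launch() would bring up the interface without any checkpoints present.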
acestep/handler.py
ADDED
|
@@ -0,0 +1,1100 @@
"""
Business Logic Handler
Encapsulates all data processing and business logic as a bridge between model and UI
"""
import os
import math
import glob
import tempfile
import traceback
import re
import random
from typing import Optional, Dict, Any, Tuple, List, Union

import torch
import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile as wavfile
import soundfile as sf
import time

from transformers import AutoTokenizer, AutoModel
from diffusers.models import AutoencoderOobleck


class AceStepHandler:
    """ACE-Step Business Logic Handler"""

    def __init__(self):
        self.model = None
        self.config = None
        self.device = "cpu"
        self.dtype = torch.float32  # Will be set based on device in initialize_service
        self.temp_dir = tempfile.mkdtemp()

        # VAE for audio encoding/decoding
        self.vae = None

        # Text encoder and tokenizer
        self.text_encoder = None
        self.text_tokenizer = None

        # Silence latent for initialization
        self.silence_latent = None

        # Sample rate
        self.sample_rate = 48000

        # 5Hz LM related
        self.llm = None
        self.llm_tokenizer = None
        self.llm_initialized = False

        # Reward model (temporarily disabled)
        self.reward_model = None

        # Dataset related (temporarily disabled)
        self.dataset = None
        self.dataset_imported = False

        # Batch size
        self.batch_size = 2

        # Custom layers config
        self.custom_layers_config = {
            2: [6, 7],
            3: [10, 11],
            4: [3],
            5: [8, 9, 11],
            6: [8]
        }

    def get_available_checkpoints(self) -> List[str]:
        """Return the checkpoints directory path as a single-item list (if it exists)"""
        # Get project root (handler.py is in acestep/, so go up two levels to project root)
        current_file = os.path.abspath(__file__)
        project_root = os.path.dirname(os.path.dirname(current_file))
        # default checkpoints
        checkpoint_dir = os.path.join(project_root, "checkpoints")
        if os.path.exists(checkpoint_dir):
            return [checkpoint_dir]
        else:
            return []

    def get_available_acestep_v15_models(self) -> List[str]:
        """Scan and return all model directory names starting with 'acestep-v15-'"""
        # Get project root
        current_file = os.path.abspath(__file__)
        project_root = os.path.dirname(os.path.dirname(current_file))
        checkpoint_dir = os.path.join(project_root, "checkpoints")

        models = []
        if os.path.exists(checkpoint_dir):
            # Scan all directories starting with 'acestep-v15-' in the checkpoints folder
            for item in os.listdir(checkpoint_dir):
                item_path = os.path.join(checkpoint_dir, item)
                if os.path.isdir(item_path) and item.startswith("acestep-v15-"):
                    models.append(item)

        # Sort by name
        models.sort()
        return models

    def get_available_5hz_lm_models(self) -> List[str]:
        """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
        current_file = os.path.abspath(__file__)
        project_root = os.path.dirname(os.path.dirname(current_file))
        checkpoint_dir = os.path.join(project_root, "checkpoints")

        models = []
        if os.path.exists(checkpoint_dir):
            for item in os.listdir(checkpoint_dir):
                item_path = os.path.join(checkpoint_dir, item)
                if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
                    models.append(item)

        models.sort()
        return models

    def is_flash_attention_available(self) -> bool:
        """Check if flash attention is available on the system"""
        try:
            import flash_attn
            return True
        except ImportError:
            return False

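The three scanners above only look under a `checkpoints/` folder next to the `acestep/` package. A minimal usage sketch; the returned folder names are just examples of the prefixes the code searches for, not guaranteed contents:

```python
# Illustrative only; results depend on what is present under checkpoints/.
from acestep.handler import AceStepHandler

handler = AceStepHandler()
print(handler.get_available_checkpoints())         # e.g. ["/path/to/repo/checkpoints"]
print(handler.get_available_acestep_v15_models())  # e.g. ["acestep-v15-turbo"]
print(handler.get_available_5hz_lm_models())       # e.g. ["acestep-5Hz-lm-0.6B"]
```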
    def initialize_service(
        self,
        project_root: str,
        config_path: str,
        device: str = "auto",
        init_llm: bool = False,
        lm_model_path: str = "acestep-5Hz-lm-0.6B",
        use_flash_attention: bool = False,
    ) -> Tuple[str, bool]:
        """
        Initialize model service

        Args:
            project_root: Project root path (may be checkpoints directory, will be handled automatically)
            config_path: Model config directory name (e.g., "acestep-v15-turbo")
            device: Device type
            init_llm: Whether to initialize 5Hz LM model
            lm_model_path: 5Hz LM model path
            use_flash_attention: Whether to use flash attention (requires flash_attn package)

        Returns:
            (status_message, enable_generate_button)
        """
        try:
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"

            self.device = device
            # Set dtype based on device: bfloat16 for cuda, float32 for cpu
            self.dtype = torch.bfloat16 if device == "cuda" else torch.float32

            # Auto-detect project root (independent of the passed project_root parameter)
            current_file = os.path.abspath(__file__)
            actual_project_root = os.path.dirname(os.path.dirname(current_file))
            checkpoint_dir = os.path.join(actual_project_root, "checkpoints")

            # 1. Load main model
            # config_path is a relative name (e.g., "acestep-v15-turbo") joined onto the checkpoints directory
            acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path)
            if os.path.exists(acestep_v15_checkpoint_path):
                # Determine attention implementation
                attn_implementation = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"
                self.model = AutoModel.from_pretrained(
                    acestep_v15_checkpoint_path,
                    trust_remote_code=True,
                    attn_implementation=attn_implementation
                )
                self.config = self.model.config
                # Move model to device and set dtype
                self.model = self.model.to(device).to(self.dtype)
                self.model.eval()
                silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
                if os.path.exists(silence_latent_path):
                    self.silence_latent = torch.load(silence_latent_path).transpose(1, 2).squeeze(0)  # [L, C]
                    self.silence_latent = self.silence_latent.to(device).to(self.dtype)
                else:
                    raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
            else:
                raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}")

            # 2. Load VAE
            vae_checkpoint_path = os.path.join(checkpoint_dir, "vae")
            if os.path.exists(vae_checkpoint_path):
                self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
                self.vae = self.vae.to(device).to(self.dtype)
                self.vae.eval()
            else:
                raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")

            # 3. Load text encoder and tokenizer
            text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B")
            if os.path.exists(text_encoder_path):
                self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
                self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
                self.text_encoder = self.text_encoder.to(device).to(self.dtype)
                self.text_encoder.eval()
            else:
                raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")

            # 4. Load 5Hz LM model (optional, only if init_llm is True)
            if init_llm:
                full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
                if os.path.exists(full_lm_model_path):
                    if device == "cuda":
                        status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
                        if not self.llm_initialized:
                            return status_msg, False
                    else:
                        self.llm = AutoModel.from_pretrained(full_lm_model_path)
                        self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path)
                else:
                    # 5Hz LM path not found
                    return f"❌ 5Hz LM model not found at {full_lm_model_path}", False

            # Determine actual attention implementation used
            actual_attn = "flash_attention_2" if use_flash_attention and self.is_flash_attention_available() else "eager"

            status_msg = f"✅ Model initialized successfully on {device}\n"
            status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
            status_msg += f"VAE: {vae_checkpoint_path}\n"
            status_msg += f"Text encoder: {text_encoder_path}\n"
            if init_llm and hasattr(self, 'llm') and self.llm is not None:
                status_msg += f"5Hz LM model: {os.path.join(checkpoint_dir, lm_model_path)}\n"
            else:
                status_msg += "5Hz LM model: Not loaded (checkbox not selected)\n"
            status_msg += f"Dtype: {self.dtype}\n"
            status_msg += f"Attention: {actual_attn}"

            return status_msg, True

        except Exception as e:
            error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg, False

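A hedged usage sketch of `initialize_service`. The `acestep-v15-turbo` name is an assumption; any folder under `checkpoints/` with the expected prefix works, and `vae/`, `Qwen3-Embedding-0.6B/` and `silence_latent.pt` must be present as the code above requires.

```python
from acestep.handler import AceStepHandler

handler = AceStepHandler()
status, ready = handler.initialize_service(
    project_root=".",                 # ignored internally; the root is re-detected from __file__
    config_path="acestep-v15-turbo",  # assumed folder name under checkpoints/
    device="auto",                    # cuda + bfloat16 when available, else cpu + float32
    init_llm=False,                   # set True to also load the 5Hz LM
)
print(status)
```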
    def import_dataset(self, dataset_type: str) -> str:
        """Import dataset (temporarily disabled)"""
        self.dataset_imported = False
        return "⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."

    def get_item_data(self, *args, **kwargs):
        """Get dataset item (temporarily disabled)"""
        return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"

    def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> float:
        """Get GPU memory utilization ratio"""
        try:
            device = torch.device("cuda:0")
            total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
            allocated_mem_bytes = torch.cuda.memory_allocated(device)
            reserved_mem_bytes = torch.cuda.memory_reserved(device)

            total_gpu = total_gpu_mem_bytes / 1024**3
            allocated_gpu = allocated_mem_bytes / 1024**3
            reserved_gpu = reserved_mem_bytes / 1024**3
            available_gpu = total_gpu - reserved_gpu

            if available_gpu >= minimal_gpu:
                ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu))
            else:
                ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))

            return ratio
        except Exception:
            return 0.9

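To make the heuristic above concrete, a worked example with illustrative numbers (not measured values):

```python
# 24 GiB card with 4 GiB already reserved -> 20 GiB free >= minimal_gpu (8 GiB),
# so the ratio is clamp(8 / 24, 0.2, 0.9) ≈ 0.33 of total VRAM handed to nano-vllm.
total_gpu, reserved_gpu, minimal_gpu = 24.0, 4.0, 8.0
available_gpu = total_gpu - reserved_gpu
ratio = min(0.9, max(0.2, minimal_gpu / total_gpu))       # 0.333...
tight = min(0.9, max(0.2, (6.0 * 0.8) / total_gpu))       # only 6 GiB free -> clamped to 0.2
```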
    def _initialize_5hz_lm_cuda(self, model_path: str) -> str:
        """Initialize 5Hz LM model"""
        try:
            from nanovllm import LLM, SamplingParams

            if not torch.cuda.is_available():
                return "❌ CUDA is not available. Please check your GPU setup."

            current_device = torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(current_device)

            torch.cuda.empty_cache()
            gpu_memory_utilization = self.get_gpu_memory_utilization(
                minimal_gpu=8,
                min_ratio=0.2,
                max_ratio=0.9
            )

            self.llm = LLM(
                model=model_path,
                enforce_eager=False,
                tensor_parallel_size=1,
                max_model_len=4096,
                gpu_memory_utilization=gpu_memory_utilization,
            )
            self.llm_tokenizer = self.llm.tokenizer
            self.llm_initialized = True
            return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
        except Exception as e:
            self.llm_initialized = False
            error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg

    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
        """Generate metadata and audio codes using the 5Hz LM"""
        if not self.llm_initialized or self.llm is None:
            return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."

        try:
            from nanovllm import SamplingParams

            prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"

            formatted_prompt = self.llm_tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
                    {"role": "user", "content": prompt}
                ],
                tokenize=False,
                add_generation_prompt=True,
            )

            sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
            outputs = self.llm.generate([formatted_prompt], sampling_params)

            if isinstance(outputs, list) and len(outputs) > 0:
                if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
                    output_text = outputs[0].outputs[0].text
                elif hasattr(outputs[0], 'text'):
                    output_text = outputs[0].text
                else:
                    output_text = str(outputs[0])
            else:
                output_text = str(outputs)

            metadata, audio_codes = self.parse_lm_output(output_text)
            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"

        except Exception as e:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return {}, "", error_msg

    def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
        """Parse LM output into a metadata dict and a serialized audio-code string"""
        metadata = {}
        audio_codes = ""

        # Extract audio codes
        code_pattern = r'<\|audio_code_\d+\|>'
        code_matches = re.findall(code_pattern, output_text)
        if code_matches:
            audio_codes = "".join(code_matches)

        # Extract metadata
        reasoning_patterns = [
            r'<think>(.*?)</think>',
            r'<reasoning>(.*?)</reasoning>',
        ]

        reasoning_text = None
        for pattern in reasoning_patterns:
            match = re.search(pattern, output_text, re.DOTALL)
            if match:
                reasoning_text = match.group(1).strip()
                break

        if not reasoning_text:
            lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
            reasoning_text = lines_before_codes.strip()

        # Parse metadata fields
        if reasoning_text:
            for line in reasoning_text.split('\n'):
                line = line.strip()
                if ':' in line and not line.startswith('<'):
                    parts = line.split(':', 1)
                    if len(parts) == 2:
                        key = parts[0].strip().lower()
                        value = parts[1].strip()

                        if key == 'bpm':
                            try:
                                metadata['bpm'] = int(value)
                            except ValueError:
                                metadata['bpm'] = value
                        elif key == 'duration':
                            try:
                                metadata['duration'] = int(value)
                            except ValueError:
                                metadata['duration'] = value
                        elif key in ['genres', 'keyscale', 'timesignature']:
                            metadata[key] = value

        return metadata, audio_codes

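A small round-trip sketch of `parse_lm_output`. The LM output text is made up, but the result follows the parsing above (bpm/duration cast to int, code tokens concatenated):

```python
from acestep.handler import AceStepHandler

handler = AceStepHandler()  # parse_lm_output needs no model weights
sample = "bpm: 120\nduration: 30\ngenres: synthwave\n<|audio_code_17|><|audio_code_42|>"
meta, codes = handler.parse_lm_output(sample)
# meta  -> {"bpm": 120, "duration": 30, "genres": "synthwave"}
# codes -> "<|audio_code_17|><|audio_code_42|>"
```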
    def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
        """Process reference audio into a stereo 48 kHz tensor cropped/padded to 30 s"""
        if audio_file is None:
            return None

        try:
            # Load audio using soundfile
            audio_np, sr = sf.read(audio_file, dtype='float32')
            # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
            if audio_np.ndim == 1:
                audio = torch.from_numpy(audio_np).unsqueeze(0)
            else:
                audio = torch.from_numpy(audio_np.T)

            if audio.shape[0] == 1:
                audio = torch.cat([audio, audio], dim=0)

            audio = audio[:2]

            # Resample if needed
            if sr != 48000:
                import torch.nn.functional as F
                # Simple resampling using interpolate
                ratio = 48000 / sr
                new_length = int(audio.shape[-1] * ratio)
                audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)

            audio = torch.clamp(audio, -1.0, 1.0)

            target_frames = 30 * 48000
            if audio.shape[-1] > target_frames:
                start_frame = (audio.shape[-1] - target_frames) // 2
                audio = audio[:, start_frame:start_frame + target_frames]
            elif audio.shape[-1] < target_frames:
                audio = torch.nn.functional.pad(
                    audio, (0, target_frames - audio.shape[-1]), 'constant', 0
                )

            return audio
        except Exception as e:
            print(f"Error processing reference audio: {e}")
            return None

    def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
        """Process target (source) audio into a stereo 48 kHz tensor at full length"""
        if audio_file is None:
            return None

        try:
            # Load audio using soundfile
            audio_np, sr = sf.read(audio_file, dtype='float32')
            # Convert to torch: [samples, channels] or [samples] -> [channels, samples]
            if audio_np.ndim == 1:
                audio = torch.from_numpy(audio_np).unsqueeze(0)
            else:
                audio = torch.from_numpy(audio_np.T)

            if audio.shape[0] == 1:
                audio = torch.cat([audio, audio], dim=0)

            audio = audio[:2]

            # Resample if needed
            if sr != 48000:
                import torch.nn.functional as F
                ratio = 48000 / sr
                new_length = int(audio.shape[-1] * ratio)
                audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)

            audio = torch.clamp(audio, -1.0, 1.0)

            return audio
        except Exception as e:
            print(f"Error processing target audio: {e}")
            return None

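The two loaders above resample with linear interpolation, which is a rough approximation. If `torchaudio` happens to be installed, a band-limited resampler could be swapped in; this is a suggestion, not what the handler currently does:

```python
import torchaudio.functional as AF

# audio: [channels, samples] float tensor at the original rate `sr`
audio_48k = AF.resample(audio, orig_freq=sr, new_freq=48000)
```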
    def _parse_audio_code_string(self, code_str: str) -> List[int]:
        """Extract integer audio codes from prompt tokens like <|audio_code_123|>."""
        if not code_str:
            return []
        try:
            return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
        except Exception:
            return []

    def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
        """
        Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer.
        """
        if not self.model or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'):
            return None

        code_ids = self._parse_audio_code_string(code_str)
        if len(code_ids) == 0:
            return None

        quantizer = self.model.tokenizer.quantizer
        detokenizer = self.model.detokenizer

        num_quantizers = getattr(quantizer, "num_quantizers", 1)
        indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0)  # [1, T_5Hz]

        # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
        if indices.dim() == 2:
            indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)

        # Get quantized representation from indices: [1, T_5Hz, dim]
        quantized = quantizer.get_output_from_indices(indices)
        if quantized.dtype != self.dtype:
            quantized = quantized.to(self.dtype)

        # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
        lm_hints_25hz = detokenizer(quantized)
        return lm_hints_25hz

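For reference, the code-string format these helpers consume looks like the sketch below (token values are hypothetical). `_decode_audio_codes_to_latents` then pushes the parsed 5 Hz ids through the model's quantizer and detokenizer to obtain 25 Hz latents:

```python
from acestep.handler import AceStepHandler

handler = AceStepHandler()  # parsing alone needs no model weights
code_str = "<|audio_code_5|><|audio_code_123|><|audio_code_7|>"
ids = handler._parse_audio_code_string(code_str)   # -> [5, 123, 7]
```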
    def _create_default_meta(self) -> str:
        """Create default metadata string."""
        return (
            "- bpm: N/A\n"
            "- timesignature: N/A\n"
            "- keyscale: N/A\n"
            "- duration: 30 seconds\n"
        )

    def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
        """Convert metadata dict to formatted string."""
        bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A'))
        timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A'))
        keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A')))
        duration = meta_dict.get('duration', meta_dict.get('length', 30))

        # Format duration
        if isinstance(duration, (int, float)):
            duration = f"{int(duration)} seconds"
        elif not isinstance(duration, str):
            duration = "30 seconds"

        return (
            f"- bpm: {bpm}\n"
            f"- timesignature: {timesignature}\n"
            f"- keyscale: {keyscale}\n"
            f"- duration: {duration}\n"
        )

    def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]:
        """Parse and normalize metadata with fallbacks."""
        parsed_metas = []
        for meta in metas:
            if meta is None:
                parsed_meta = self._create_default_meta()
            elif isinstance(meta, str):
                parsed_meta = meta
            elif isinstance(meta, dict):
                parsed_meta = self._dict_to_meta_string(meta)
            else:
                parsed_meta = self._create_default_meta()
            parsed_metas.append(parsed_meta)
        return parsed_metas

    def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get text hidden states from text encoder."""
        if self.text_tokenizer is None or self.text_encoder is None:
            raise ValueError("Text encoder not initialized")

        # Tokenize
        text_inputs = self.text_tokenizer(
            text_prompt,
            padding="longest",
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        text_attention_mask = text_inputs.attention_mask.to(self.device).bool()

        # Encode
        with torch.no_grad():
            text_outputs = self.text_encoder(text_input_ids)
            if hasattr(text_outputs, 'last_hidden_state'):
                text_hidden_states = text_outputs.last_hidden_state
            elif isinstance(text_outputs, tuple):
                text_hidden_states = text_outputs[0]
            else:
                text_hidden_states = text_outputs

            text_hidden_states = text_hidden_states.to(self.dtype)

        return text_hidden_states, text_attention_mask

    def extract_caption_from_sft_format(self, caption: str) -> str:
        """Extract caption from SFT format if needed."""
        # Simple extraction - can be enhanced if needed
        if caption and isinstance(caption, str):
            return caption.strip()
        return caption if caption else ""

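What `_dict_to_meta_string` renders for a typical metadata dict, as consumed by the `# Metas` section of the SFT prompt inside `generate_music` below (the dict values are made up):

```python
from acestep.handler import AceStepHandler

handler = AceStepHandler()
print(handler._dict_to_meta_string(
    {"bpm": 120, "keyscale": "C major", "timesignature": "4/4", "duration": 30}))
# - bpm: 120
# - timesignature: 4/4
# - keyscale: C major
# - duration: 30 seconds
```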
    def generate_music(
        self,
        captions: str,
        lyrics: str,
        bpm: Optional[int] = None,
        key_scale: str = "",
        time_signature: str = "",
        vocal_language: str = "en",
        inference_steps: int = 8,
        guidance_scale: float = 7.0,
        use_random_seed: bool = True,
        seed: Optional[Union[str, float, int]] = -1,
        reference_audio=None,
        audio_duration: Optional[float] = None,
        batch_size: Optional[int] = None,
        src_audio=None,
        audio_code_string: str = "",
        repainting_start: float = 0.0,
        repainting_end: Optional[float] = None,
        instruction: str = "Fill the audio semantic mask based on the given conditions:",
        audio_cover_strength: float = 1.0,
        task_type: str = "text2music",
        use_adg: bool = False,
        cfg_interval_start: float = 0.0,
        cfg_interval_end: float = 1.0,
        audio_format: str = "mp3",
        lm_temperature: float = 0.6,
        progress=None
    ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
        """
        Main interface for music generation

        Returns:
            (first_audio, second_audio, all_audio_paths, generation_info, status_message,
             seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
             align_score_2, align_text_2, align_plot_2)
        """
        if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
            return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None

        try:
            print("[generate_music] Starting generation...")
            if progress:
                progress(0.05, desc="Preparing inputs...")
            print("[generate_music] Preparing inputs...")

            # Determine actual batch size
            actual_batch_size = batch_size if batch_size is not None else self.batch_size
            actual_batch_size = max(1, min(actual_batch_size, 8))  # Limit to 8 for memory safety

            # Process seeds
            if use_random_seed:
                seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
            else:
                # Parse seed input
                if isinstance(seed, str):
                    seed_parts = [s.strip() for s in seed.split(",")]
                    seed_list = [int(float(s)) if s != "-1" and s else random.randint(0, 2**32 - 1) for s in seed_parts[:actual_batch_size]]
                elif isinstance(seed, (int, float)) and seed >= 0:
                    seed_list = [int(seed)] * actual_batch_size
                else:
                    seed_list = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]

                # Pad if needed
                while len(seed_list) < actual_batch_size:
                    seed_list.append(random.randint(0, 2**32 - 1))

            seed_value_for_ui = ", ".join(str(s) for s in seed_list)

            # Process audio inputs
            processed_ref_audio = self.process_reference_audio(reference_audio) if reference_audio else None
            processed_src_audio = self.process_target_audio(src_audio) if src_audio else None

            # Extract caption
            pure_caption = self.extract_caption_from_sft_format(captions)

            # Determine task type and update instruction if needed
            if task_type == "text2music" and audio_code_string and str(audio_code_string).strip():
                task_type = "cover"
                instruction = "Generate audio semantic tokens based on the given conditions:"

            # Build metadata
            metadata_dict = {
                "bpm": bpm if bpm else "N/A",
                "keyscale": key_scale if key_scale else "N/A",
                "timesignature": time_signature if time_signature else "N/A",
            }

            # Calculate duration
            if processed_src_audio is not None:
                calculated_duration = processed_src_audio.shape[-1] / self.sample_rate
            elif audio_duration is not None and audio_duration > 0:
                calculated_duration = audio_duration
            else:
                calculated_duration = 30.0  # Default 30 seconds

            metadata_dict["duration"] = f"{int(calculated_duration)} seconds"

            if progress:
                progress(0.1, desc="Processing audio inputs...")
            print("[generate_music] Processing audio inputs...")

            # Prepare batch data
            captions_batch = [pure_caption] * actual_batch_size
            lyrics_batch = [lyrics] * actual_batch_size
            vocal_languages_batch = [vocal_language] * actual_batch_size
            instructions_batch = [instruction] * actual_batch_size
            metas_batch = [metadata_dict.copy()] * actual_batch_size
            audio_code_hints_batch = [audio_code_string if audio_code_string else None] * actual_batch_size

            # Process reference audios
            if processed_ref_audio is not None:
                refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)]
            else:
                # Create silence as fallback
                silence_frames = 30 * self.sample_rate
                silence = torch.zeros(2, silence_frames)
                refer_audios = [[silence] for _ in range(actual_batch_size)]

            # Process target wavs (src_audio)
            if processed_src_audio is not None:
                target_wavs_list = [processed_src_audio.clone() for _ in range(actual_batch_size)]
            else:
                # Create silence based on duration
                target_frames = int(calculated_duration * self.sample_rate)
                silence = torch.zeros(2, target_frames)
                target_wavs_list = [silence for _ in range(actual_batch_size)]

            # Pad target_wavs to consistent length
            max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
            target_wavs = torch.stack([
                torch.nn.functional.pad(wav, (0, max_target_frames - wav.shape[-1]), 'constant', 0)
                for wav in target_wavs_list
            ])

            if progress:
                progress(0.2, desc="Encoding audio to latents...")
            print("[generate_music] Encoding audio to latents...")

            # Encode target_wavs to latents using VAE
            target_latents_list = []
            latent_lengths = []

            with torch.no_grad():
                for i in range(actual_batch_size):
                    # Check if audio codes are provided
                    code_hint = audio_code_hints_batch[i]
                    if code_hint:
                        decoded_latents = self._decode_audio_codes_to_latents(code_hint)
                        if decoded_latents is not None:
                            decoded_latents = decoded_latents.squeeze(0)  # Remove batch dim
                            target_latents_list.append(decoded_latents)
                            latent_lengths.append(decoded_latents.shape[0])
                            continue

                    # If no src_audio provided, use silence_latent directly (skip VAE)
                    if processed_src_audio is None:
                        # Calculate required latent length based on duration
                        # VAE downsample ratio is 1920 (2*4*4*6*10), so latent rate is 48000/1920 = 25Hz
                        latent_length = int(calculated_duration * 25)  # 25Hz latent rate
                        latent_length = max(128, latent_length)  # Minimum 128

                        # Tile silence_latent to required length
                        if self.silence_latent.shape[0] >= latent_length:
                            target_latent = self.silence_latent[:latent_length].to(self.device).to(self.dtype)
                        else:
                            repeat_times = (latent_length // self.silence_latent.shape[0]) + 1
                            target_latent = self.silence_latent.repeat(repeat_times, 1)[:latent_length].to(self.device).to(self.dtype)
                        target_latents_list.append(target_latent)
                        latent_lengths.append(target_latent.shape[0])
                        continue

                    # Encode from audio using VAE
                    current_wav = target_wavs[i].unsqueeze(0).to(self.device).to(self.dtype)
                    target_latent = self.vae.encode(current_wav).latent_dist.sample()
                    target_latent = target_latent.squeeze(0).transpose(0, 1)  # [latent_length, latent_dim]
                    target_latents_list.append(target_latent)
                    latent_lengths.append(target_latent.shape[0])

            # Pad latents to same length
            max_latent_length = max(latent_lengths)
            max_latent_length = max(128, max_latent_length)  # Minimum 128

            padded_latents = []
            for i, latent in enumerate(target_latents_list):
                if latent.shape[0] < max_latent_length:
                    pad_length = max_latent_length - latent.shape[0]
                    # Tile silence_latent to pad_length (silence_latent is [L, C])
                    if self.silence_latent.shape[0] >= pad_length:
                        pad_latent = self.silence_latent[:pad_length]
                    else:
                        repeat_times = (pad_length // self.silence_latent.shape[0]) + 1
                        pad_latent = self.silence_latent.repeat(repeat_times, 1)[:pad_length]
                    latent = torch.cat([latent, pad_latent.to(self.device).to(self.dtype)], dim=0)
                padded_latents.append(latent)

            target_latents = torch.stack(padded_latents).to(self.device).to(self.dtype)
            latent_masks = torch.stack([
                torch.cat([
                    torch.ones(l, dtype=torch.long, device=self.device),
                    torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device)
                ])
                for l in latent_lengths
            ])

            if progress:
                progress(0.3, desc="Preparing conditions...")
            print("[generate_music] Preparing conditions...")

            # Determine task type and create chunk masks
            is_covers = []
            chunk_masks = []
            repainting_ranges = {}

            for i in range(actual_batch_size):
                has_code_hint = audio_code_hints_batch[i] is not None
                has_repainting = (repainting_end is not None and repainting_end > repainting_start)

                if has_repainting:
                    # Repainting mode
                    start_sec = max(0, repainting_start)
                    end_sec = repainting_end if repainting_end is not None else calculated_duration

                    start_latent = int(start_sec * self.sample_rate // 1920)
                    end_latent = int(end_sec * self.sample_rate // 1920)
                    start_latent = max(0, min(start_latent, max_latent_length - 1))
                    end_latent = max(start_latent + 1, min(end_latent, max_latent_length))

                    mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device)
                    mask[start_latent:end_latent] = True
                    chunk_masks.append(mask)
                    repainting_ranges[i] = (start_latent, end_latent)
                    is_covers.append(False)
                else:
                    # Full generation or cover
                    chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device))
                    # Check if cover task
                    instruction_lower = instructions_batch[i].lower()
                    is_cover = ("generate audio semantic tokens" in instruction_lower and
                                "based on the given conditions" in instruction_lower) or has_code_hint
                    is_covers.append(is_cover)

            chunk_masks = torch.stack(chunk_masks).unsqueeze(-1).expand(-1, -1, 64)  # [batch, length, 64]
            is_covers = torch.tensor(is_covers, dtype=torch.bool, device=self.device)

            # Create src_latents
            # Tile silence_latent to max_latent_length (silence_latent is now [L, C])
            if self.silence_latent.shape[0] >= max_latent_length:
                silence_latent_tiled = self.silence_latent[:max_latent_length].to(self.device).to(self.dtype)
            else:
                repeat_times = (max_latent_length // self.silence_latent.shape[0]) + 1
                silence_latent_tiled = self.silence_latent.repeat(repeat_times, 1)[:max_latent_length].to(self.device).to(self.dtype)
            src_latents_list = []

            for i in range(actual_batch_size):
                has_target_audio = (target_wavs[i].abs().sum() > 1e-6) or (audio_code_hints_batch[i] is not None)

                if has_target_audio:
                    if i in repainting_ranges:
                        # Repaint: replace inpainting region with silence
                        src_latent = target_latents[i].clone()
                        start_latent, end_latent = repainting_ranges[i]
                        src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent]
                        src_latents_list.append(src_latent)
                    else:
                        # Cover/extract/complete/lego: use target_latents
                        src_latents_list.append(target_latents[i].clone())
                else:
                    # Text2music: use silence
                    src_latents_list.append(silence_latent_tiled.clone())

            src_latents = torch.stack(src_latents_list)  # [batch, length, channels]

            if progress:
                progress(0.4, desc="Tokenizing text inputs...")
            print("[generate_music] Tokenizing text inputs...")

            # Prepare text and lyric hidden states
            SFT_GEN_PROMPT = """# Instruction
{}

# Caption
{}

# Metas
{}<|endoftext|>
"""

            text_hidden_states_list = []
            text_attention_masks_list = []
            lyric_hidden_states_list = []
            lyric_attention_masks_list = []

            with torch.no_grad():
                for i in range(actual_batch_size):
                    # Format text prompt
                    inst = instructions_batch[i]
                    if not inst.endswith(":"):
                        inst = inst + ":"

                    meta_str = self._dict_to_meta_string(metas_batch[i])
                    text_prompt = SFT_GEN_PROMPT.format(inst, captions_batch[i], meta_str)

                    # Tokenize and encode text
                    text_hidden, text_mask = self._get_text_hidden_states(text_prompt)
                    text_hidden_states_list.append(text_hidden.squeeze(0))
                    text_attention_masks_list.append(text_mask.squeeze(0))

                    # Format and tokenize lyrics
                    lyrics_text = f"# Languages\n{vocal_languages_batch[i]}\n\n# Lyric\n{lyrics_batch[i]}<|endoftext|>"
                    lyric_hidden, lyric_mask = self._get_text_hidden_states(lyrics_text)
                    lyric_hidden_states_list.append(lyric_hidden.squeeze(0))
                    lyric_attention_masks_list.append(lyric_mask.squeeze(0))

            # Pad sequences
            max_text_length = max(h.shape[0] for h in text_hidden_states_list)
            max_lyric_length = max(h.shape[0] for h in lyric_hidden_states_list)

            text_hidden_states = torch.stack([
                torch.nn.functional.pad(h, (0, 0, 0, max_text_length - h.shape[0]), 'constant', 0)
                for h in text_hidden_states_list
            ]).to(self.device).to(self.dtype)

            text_attention_mask = torch.stack([
                torch.nn.functional.pad(m, (0, max_text_length - m.shape[0]), 'constant', 0)
                for m in text_attention_masks_list
            ]).to(self.device)

            lyric_hidden_states = torch.stack([
                torch.nn.functional.pad(h, (0, 0, 0, max_lyric_length - h.shape[0]), 'constant', 0)
                for h in lyric_hidden_states_list
            ]).to(self.device).to(self.dtype)

            lyric_attention_mask = torch.stack([
                torch.nn.functional.pad(m, (0, max_lyric_length - m.shape[0]), 'constant', 0)
                for m in lyric_attention_masks_list
            ]).to(self.device)

+
if progress:
|
| 934 |
+
progress(0.5, desc="Processing reference audio...")
|
| 935 |
+
print("[generate_music] Processing reference audio...")
|
| 936 |
+
|
| 937 |
+
# Process reference audio for timbre
|
| 938 |
+
# Model expects: refer_audio_acoustic_hidden_states_packed [N, timbre_fix_frame, audio_acoustic_hidden_dim]
|
| 939 |
+
# refer_audio_order_mask [N] indicating batch assignment
|
| 940 |
+
timbre_fix_frame = getattr(self.config, 'timbre_fix_frame', 750)
|
| 941 |
+
refer_audio_acoustic_hidden_states_packed_list = []
|
| 942 |
+
refer_audio_order_mask_list = []
|
| 943 |
+
|
| 944 |
+
with torch.no_grad():
|
| 945 |
+
for i, ref_audio_list in enumerate(refer_audios):
|
| 946 |
+
if ref_audio_list and len(ref_audio_list) > 0 and ref_audio_list[0].abs().sum() > 1e-6:
|
| 947 |
+
# Encode reference audio: [channels, samples] -> [1, latent_dim, T] -> [T, latent_dim]
|
| 948 |
+
ref_audio = ref_audio_list[0].unsqueeze(0).to(self.device).to(self.dtype)
|
| 949 |
+
ref_latent = self.vae.encode(ref_audio).latent_dist.sample() # [1, latent_dim, T]
|
| 950 |
+
ref_latent = ref_latent.squeeze(0).transpose(0, 1) # [T, latent_dim]
|
| 951 |
+
# Ensure dimension matches audio_acoustic_hidden_dim (64)
|
| 952 |
+
if ref_latent.shape[-1] != self.config.audio_acoustic_hidden_dim:
|
| 953 |
+
ref_latent = ref_latent[:, :self.config.audio_acoustic_hidden_dim]
|
| 954 |
+
# Pad or truncate to timbre_fix_frame
|
| 955 |
+
if ref_latent.shape[0] < timbre_fix_frame:
|
| 956 |
+
pad_length = timbre_fix_frame - ref_latent.shape[0]
|
| 957 |
+
padding = torch.zeros(pad_length, ref_latent.shape[1], device=self.device, dtype=self.dtype)
|
| 958 |
+
ref_latent = torch.cat([ref_latent, padding], dim=0)
|
| 959 |
+
else:
|
| 960 |
+
ref_latent = ref_latent[:timbre_fix_frame]
|
| 961 |
+
refer_audio_acoustic_hidden_states_packed_list.append(ref_latent)
|
| 962 |
+
refer_audio_order_mask_list.append(i)
|
| 963 |
+
else:
|
| 964 |
+
# Use silence_latent directly instead of running VAE
|
| 965 |
+
if self.silence_latent.shape[0] >= timbre_fix_frame:
|
| 966 |
+
silence_ref = self.silence_latent[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
|
| 967 |
+
else:
|
| 968 |
+
repeat_times = (timbre_fix_frame // self.silence_latent.shape[0]) + 1
|
| 969 |
+
silence_ref = self.silence_latent.repeat(repeat_times, 1)[:timbre_fix_frame, :self.config.audio_acoustic_hidden_dim]
|
| 970 |
+
refer_audio_acoustic_hidden_states_packed_list.append(silence_ref.to(self.device).to(self.dtype))
|
| 971 |
+
refer_audio_order_mask_list.append(i)
|
| 972 |
+
|
| 973 |
+
# Stack all reference audios: [N, timbre_fix_frame, audio_acoustic_hidden_dim]
|
| 974 |
+
refer_audio_acoustic_hidden_states_packed = torch.stack(refer_audio_acoustic_hidden_states_packed_list, dim=0).to(self.device).to(self.dtype)
|
| 975 |
+
# Order mask: [N] indicating which batch item each reference belongs to
|
| 976 |
+
refer_audio_order_mask = torch.tensor(refer_audio_order_mask_list, dtype=torch.long, device=self.device)
|
| 977 |
+
|
| 978 |
+
if progress:
|
| 979 |
+
progress(0.6, desc="Generating audio...")
|
| 980 |
+
print("[generate_music] Calling model.generate_audio()...")
|
| 981 |
+
print(f" - text_hidden_states: {text_hidden_states.shape}, dtype={text_hidden_states.dtype}")
|
| 982 |
+
print(f" - text_attention_mask: {text_attention_mask.shape}, dtype={text_attention_mask.dtype}")
|
| 983 |
+
print(f" - lyric_hidden_states: {lyric_hidden_states.shape}, dtype={lyric_hidden_states.dtype}")
|
| 984 |
+
print(f" - lyric_attention_mask: {lyric_attention_mask.shape}, dtype={lyric_attention_mask.dtype}")
|
| 985 |
+
print(f" - refer_audio_acoustic_hidden_states_packed: {refer_audio_acoustic_hidden_states_packed.shape}, dtype={refer_audio_acoustic_hidden_states_packed.dtype}")
|
| 986 |
+
print(f" - refer_audio_order_mask: {refer_audio_order_mask.shape}, dtype={refer_audio_order_mask.dtype}")
|
| 987 |
+
print(f" - src_latents: {src_latents.shape}, dtype={src_latents.dtype}")
|
| 988 |
+
print(f" - chunk_masks: {chunk_masks.shape}, dtype={chunk_masks.dtype}")
|
| 989 |
+
print(f" - is_covers: {is_covers.shape}, dtype={is_covers.dtype}")
|
| 990 |
+
print(f" - silence_latent: {self.silence_latent.unsqueeze(0).shape}")
|
| 991 |
+
print(f" - seed: {seed_list[0] if len(seed_list) > 0 else None}")
|
| 992 |
+
print(f" - fix_nfe: {inference_steps}")
|
| 993 |
+
|
| 994 |
+
# Call model to generate
|
| 995 |
+
with torch.no_grad():
|
| 996 |
+
outputs = self.model.generate_audio(
|
| 997 |
+
text_hidden_states=text_hidden_states,
|
| 998 |
+
text_attention_mask=text_attention_mask,
|
| 999 |
+
lyric_hidden_states=lyric_hidden_states,
|
| 1000 |
+
lyric_attention_mask=lyric_attention_mask,
|
| 1001 |
+
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
| 1002 |
+
refer_audio_order_mask=refer_audio_order_mask,
|
| 1003 |
+
src_latents=src_latents,
|
| 1004 |
+
chunk_masks=chunk_masks,
|
| 1005 |
+
is_covers=is_covers,
|
| 1006 |
+
                silence_latent=self.silence_latent.unsqueeze(0),  # [1, L, C]
                seed=seed_list[0] if len(seed_list) > 0 else None,
                fix_nfe=inference_steps,
                infer_method="ode",
                use_cache=True,
            )

            print("[generate_music] Model generation completed. Decoding latents...")
            pred_latents = outputs["target_latents"]  # [batch, latent_length, latent_dim]
            time_costs = outputs["time_costs"]
            print(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
            print(f" - time_costs: {time_costs}")
            if progress:
                progress(0.8, desc="Decoding audio...")
            print("[generate_music] Decoding latents with VAE...")

            # Decode latents to audio
            start_time = time.time()
            with torch.no_grad():
                # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
                pred_latents_for_decode = pred_latents.transpose(1, 2)
                pred_wavs = self.vae.decode(pred_latents_for_decode).sample  # [batch, channels, samples]
            end_time = time.time()
            time_costs["vae_decode_time_cost"] = end_time - start_time
            time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]

            print("[generate_music] VAE decode completed. Saving audio files...")
            if progress:
                progress(0.9, desc="Saving audio files...")

            # Save audio files using soundfile (supports wav, flac, mp3 via format param)
            audio_format_lower = audio_format.lower() if audio_format else "wav"
            if audio_format_lower not in ["wav", "flac", "mp3"]:
                audio_format_lower = "wav"

            saved_files = []
            for i in range(actual_batch_size):
                audio_file = os.path.join(self.temp_dir, f"generated_{i}_{seed_list[i]}.{audio_format_lower}")
                # Convert to numpy: [channels, samples] -> [samples, channels]
                audio_np = pred_wavs[i].cpu().float().numpy().T
                sf.write(audio_file, audio_np, self.sample_rate)
                saved_files.append(audio_file)

            # Prepare return values
            first_audio = saved_files[0] if len(saved_files) > 0 else None
            second_audio = saved_files[1] if len(saved_files) > 1 else None

            # Format time costs if available
            time_costs_str = ""
            if time_costs:
                if isinstance(time_costs, dict):
                    time_costs_str = "\n\n**⏱️ Time Costs:**\n"
                    for key, value in time_costs.items():
                        # Format key: encoder_time_cost -> Encoder
                        formatted_key = key.replace("_time_cost", "").replace("_", " ").title()
                        time_costs_str += f" - {formatted_key}: {value:.2f}s\n"
                elif isinstance(time_costs, (int, float)):
                    time_costs_str = f"\n\n**⏱️ Time Cost:** {time_costs:.2f}s"

            generation_info = f"""**🎵 Generation Complete**

**Seeds:** {seed_value_for_ui}
**Duration:** {calculated_duration:.1f}s
**Steps:** {inference_steps}
**Files:** {len(saved_files)} audio(s){time_costs_str}"""
            status_message = f"✅ Generation completed successfully!"
            print(f"[generate_music] Done! Generated {len(saved_files)} audio files.")

            # Alignment scores and plots (placeholder for now)
            align_score_1 = ""
            align_text_1 = ""
            align_plot_1 = None
            align_score_2 = ""
            align_text_2 = ""
            align_plot_2 = None

            return (
                first_audio,
                second_audio,
                saved_files,
                generation_info,
                status_message,
                seed_value_for_ui,
                align_score_1,
                align_text_1,
                align_plot_1,
                align_score_2,
                align_text_2,
                align_plot_2,
            )

        except Exception as e:
            error_msg = f"❌ Error generating music: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return None, None, [], "", error_msg, "-1", "", "", None, "", "", None
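The save path above leans on two layout conventions: the VAE decoder wants latents as [batch, latent_dim, latent_length], while soundfile wants audio as [samples, channels]. A minimal standalone sketch of that last step, with a dummy waveform standing in for one item of `pred_wavs` (the 48 kHz rate and file name here are assumptions for illustration only, not values taken from the handler):

```python
import soundfile as sf
import torch

# Dummy stereo waveform shaped like pred_wavs[i]: [channels, samples].
sample_rate = 48000  # assumed for this sketch only
pred_wav = torch.randn(2, sample_rate * 2)  # 2 seconds of noise

# soundfile expects [samples, channels], so transpose before writing.
audio_np = pred_wav.cpu().float().numpy().T
sf.write("generated_demo.wav", audio_np, sample_rate)
```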
acestep/third_parts/nano-vllm/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Xingkai Yu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
acestep/third_parts/nano-vllm/README.md
ADDED
@@ -0,0 +1,66 @@
<p align="center">
  <img width="300" src="assets/logo.png">
</p>

<p align="center">
  <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</p>

# Nano-vLLM

A lightweight vLLM implementation built from scratch.

## Key Features

* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation in ~1,200 lines of Python code
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.

## Installation

```bash
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
```

## Model Download

To download the model weights manually, use the following command:

```bash
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
  --local-dir ~/huggingface/Qwen3-0.6B/ \
  --local-dir-use-symlinks False
```

## Quick Start

See `example.py` for usage. The API mirrors vLLM's interface, with minor differences in the `LLM.generate` method:

```python
from nanovllm import LLM, SamplingParams
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
prompts = ["Hello, Nano-vLLM."]
outputs = llm.generate(prompts, sampling_params)
outputs[0]["text"]
```

## Benchmark

See `bench.py` for the benchmark script.

**Test Configuration:**
- Hardware: RTX 4070 Laptop (8GB)
- Model: Qwen3-0.6B
- Total Requests: 256 sequences
- Input Length: Randomly sampled between 100–1024 tokens
- Output Length: Randomly sampled between 100–1024 tokens

**Performance Results:**

| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|------------------|---------------|----------|-----------------------|
| vLLM             | 133,966       | 98.37    | 1361.84               |
| Nano-vLLM        | 133,966       | 93.41    | 1434.13               |

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
acestep/third_parts/nano-vllm/assets/logo.png
ADDED
Git LFS Details
acestep/third_parts/nano-vllm/bench.py
ADDED
@@ -0,0 +1,32 @@
import os
import time
from random import randint, seed
from nanovllm import LLM, SamplingParams
# from vllm import LLM, SamplingParams


def main():
    seed(0)
    num_seqs = 256
    max_input_len = 1024
    max_output_len = 1024

    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    llm = LLM(path, enforce_eager=False, max_model_len=4096)

    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
    # uncomment the following line for vllm
    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

    llm.generate(["Benchmark: "], SamplingParams())
    t = time.time()
    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    t = (time.time() - t)
    total_tokens = sum(sp.max_tokens for sp in sampling_params)
    throughput = total_tokens / t
    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")


if __name__ == "__main__":
    main()
acestep/third_parts/nano-vllm/example.py
ADDED
@@ -0,0 +1,33 @@
import os
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer


def main():
    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    tokenizer = AutoTokenizer.from_pretrained(path)
    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)

    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
    prompts = [
        "introduce yourself",
        "list all prime numbers within 100",
    ]
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]
    outputs = llm.generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print("\n")
        print(f"Prompt: {prompt!r}")
        print(f"Completion: {output['text']!r}")


if __name__ == "__main__":
    main()
acestep/third_parts/nano-vllm/nanovllm/__init__.py
ADDED
@@ -0,0 +1,2 @@
from nanovllm.llm import LLM
from nanovllm.sampling_params import SamplingParams
acestep/third_parts/nano-vllm/nanovllm/config.py
ADDED
@@ -0,0 +1,26 @@
import os
from dataclasses import dataclass
from transformers import AutoConfig


@dataclass
class Config:
    model: str
    max_num_batched_tokens: int = 16384
    max_num_seqs: int = 512
    max_model_len: int = 4096
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: AutoConfig | None = None
    eos: int = -1
    kvcache_block_size: int = 256
    num_kvcache_blocks: int = -1

    def __post_init__(self):
        assert os.path.isdir(self.model)
        assert self.kvcache_block_size % 256 == 0
        assert 1 <= self.tensor_parallel_size <= 8
        self.hf_config = AutoConfig.from_pretrained(self.model)
        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
        assert self.max_num_batched_tokens >= self.max_model_len
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py
ADDED
@@ -0,0 +1,112 @@
from collections import deque
import xxhash
import numpy as np

from nanovllm.engine.sequence import Sequence


class Block:

    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def update(self, hash: int, token_ids: list[int]):
        self.hash = hash
        self.token_ids = token_ids

    def reset(self):
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []


class BlockManager:

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _allocate_block(self, block_id: int) -> Block:
        block = self.blocks[block_id]
        assert block.ref_count == 0
        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return self.blocks[block_id]

    def _deallocate_block(self, block_id: int) -> Block:
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence):
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                cache_miss = True
            if cache_miss:
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)

    def deallocate(self, seq: Sequence):
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()

    def can_append(self, seq: Sequence) -> bool:
        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

    def may_append(self, seq: Sequence):
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            assert last_block.hash != -1
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            assert last_block.hash == -1
            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            assert last_block.hash == -1
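The prefix cache above keys each full block on a chained xxhash: a block's hash folds in the previous block's hash, so sequences that share a prefix resolve to the same cache entries while diverging suffixes get fresh blocks. A small standalone sketch of that chaining, mirroring `BlockManager.compute_hash` (the token values and tiny block size are made up for illustration):

```python
import numpy as np
import xxhash

def chained_hash(token_ids, prefix=-1):
    # Same recipe as BlockManager.compute_hash: fold the previous block's hash in.
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()

block_size = 4
prompt_a = [1, 2, 3, 4, 5, 6, 7, 8]
prompt_b = [1, 2, 3, 4, 9, 9, 9, 9]  # same first block, different second block

h_a1 = chained_hash(prompt_a[:block_size])
h_b1 = chained_hash(prompt_b[:block_size])
assert h_a1 == h_b1   # shared prefix -> same cache key for block 0

h_a2 = chained_hash(prompt_a[block_size:], prefix=h_a1)
h_b2 = chained_hash(prompt_b[block_size:], prefix=h_b1)
assert h_a2 != h_b2   # diverging tokens -> different keys for block 1
```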
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
ADDED
@@ -0,0 +1,120 @@
import atexit
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp

from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner


class LLMEngine:

    def __init__(self, model, **kwargs):
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config)
        atexit.register(self.exit)

    def exit(self):
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
        if sampling_params.cfg_scale > 1.0:
            if unconditional_prompt is None:
                # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
                # This is a fallback - ideally users should provide unconditional_prompt
                if isinstance(prompt, list):
                    # For now, just use the same prompt (user should provide unconditional_prompt)
                    # TODO: Implement automatic "NO USER INPUT" replacement if possible
                    unconditional_prompt = prompt
                else:
                    unconditional_prompt = prompt
            if isinstance(unconditional_prompt, str):
                unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
            # Create unconditional sequence first (so we can reference it from conditional)
            uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
            # Create conditional sequence with reference to unconditional
            cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
            uncond_seq.paired_seq = cond_seq  # Link them bidirectionally
            # Add both sequences to scheduler
            self.scheduler.add(cond_seq)
            self.scheduler.add(uncond_seq)
        else:
            seq = Sequence(prompt, sampling_params)
            self.scheduler.add(seq)

    def step(self):
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        # Only output conditional sequences (unconditional sequences are just for CFG computation)
        output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
        return outputs, num_tokens

    def is_finished(self):
        return self.scheduler.is_finished()

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
        unconditional_prompts: list[str] | list[list[int]] | None = None,
    ) -> list[str]:
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if unconditional_prompts is None:
            unconditional_prompts = [None] * len(prompts)
        for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
            self.add_request(prompt, sp, uncond_prompt)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        while not self.is_finished():
            t = perf_counter()
            output, num_tokens = self.step()
            if use_tqdm:
                if num_tokens > 0:
                    prefill_throughput = num_tokens / (perf_counter() - t)
                else:
                    decode_throughput = -num_tokens / (perf_counter() - t)
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "Decode": f"{int(decode_throughput)}tok/s",
                })
            for seq_id, token_ids in output:
                outputs[seq_id] = token_ids
                if use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
        return outputs
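The CFG path added to `add_request`/`generate` above is driven entirely by the per-request `cfg_scale` plus an optional unconditional prompt. A hedged usage sketch, assuming `cfg_scale` is a `SamplingParams` field (as `Sequence` reads it) and that `LLM` exposes `LLMEngine.generate` unchanged, as in the README Quick Start; the model path, prompt strings, and `cfg_scale=2.0` are placeholders:

```python
from nanovllm import LLM, SamplingParams

llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)

# cfg_scale > 1.0 makes the engine schedule a paired unconditional sequence
# alongside each conditional one and blend their logits before sampling.
sampling_params = SamplingParams(temperature=0.6, max_tokens=128, cfg_scale=2.0)

outputs = llm.generate(
    ["write a short melody description"],      # conditional prompts
    sampling_params,
    unconditional_prompts=["NO USER INPUT"],   # echoes the fallback named in add_request's comment
)
print(outputs[0]["text"])
```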
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
ADDED
@@ -0,0 +1,315 @@
import pickle
import torch
import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory

from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model


class ModelRunner:

    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
        self.config = config
        hf_config = config.hf_config
        self.block_size = config.kvcache_block_size
        self.enforce_eager = config.enforce_eager
        self.world_size = config.tensor_parallel_size
        self.rank = rank
        self.event = event

        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
        torch.cuda.set_device(rank)
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(hf_config.torch_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        load_model(self.model, config.model)
        self.sampler = Sampler()
        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

        if self.world_size > 1:
            if rank == 0:
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

    def exit(self):
        if self.world_size > 1:
            self.shm.close()
            dist.barrier()
            if self.rank == 0:
                self.shm.unlink()
        if not self.enforce_eager:
            del self.graphs, self.graph_pool
        torch.cuda.synchronize()
        dist.destroy_process_group()

    def loop(self):
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

    def read_shm(self):
        assert self.world_size > 1 and self.rank > 0
        self.event.wait()
        n = int.from_bytes(self.shm.buf[0:4], "little")
        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
        self.event.clear()
        return method_name, args

    def write_shm(self, method_name, *args):
        assert self.world_size > 1 and self.rank == 0
        data = pickle.dumps([method_name, *args])
        n = len(data)
        self.shm.buf[0:4] = n.to_bytes(4, "little")
        self.shm.buf[4:n+4] = data
        for event in self.event:
            event.set()

    def call(self, method_name, *args):
        if self.world_size > 1 and self.rank == 0:
            self.write_shm(method_name, *args)
        method = getattr(self, method_name, None)
        return method(*args)

    def warmup_model(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
        self.run(seqs, True)
        torch.cuda.empty_cache()

    def allocate_kv_cache(self):
        config = self.config
        hf_config = config.hf_config
        free, total = torch.cuda.mem_get_info()
        used = total - free
        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
        assert config.num_kvcache_blocks > 0
        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                module.k_cache = self.kv_cache[0, layer_id]
                module.v_cache = self.kv_cache[1, layer_id]
                layer_id += 1

    def prepare_block_tables(self, seqs: list[Sequence]):
        max_len = max(len(seq.block_table) for seq in seqs)
        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        return block_tables

    def prepare_prefill(self, seqs: list[Sequence]):
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        block_tables = None
        for seq in seqs:
            seqlen = len(seq)
            input_ids.extend(seq[seq.num_cached_tokens:])
            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
            seqlen_q = seqlen - seq.num_cached_tokens
            seqlen_k = seqlen
            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
            max_seqlen_q = max(seqlen_q, max_seqlen_q)
            max_seqlen_k = max(seqlen_k, max_seqlen_k)
            if not seq.block_table:  # warmup
                continue
            for i in range(seq.num_cached_blocks, seq.num_blocks):
                start = seq.block_table[i] * self.block_size
                if i != seq.num_blocks - 1:
                    end = start + self.block_size
                else:
                    end = start + seq.last_block_num_tokens
                slot_mapping.extend(list(range(start, end)))
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
            block_tables = self.prepare_block_tables(seqs)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions

    def prepare_decode(self, seqs: list[Sequence]):
        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for seq in seqs:
            input_ids.append(seq.last_token)
            positions.append(len(seq) - 1)
            context_lens.append(len(seq))
            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        block_tables = self.prepare_block_tables(seqs)
        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
        return input_ids, positions

    def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
        """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
        if is_cfg_batch:
            # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
            # We only need temperatures for conditional sequences (first half)
            num_cond = len(seqs) // 2
            temperatures = []
            cfg_scales = []
            for seq in seqs[:num_cond]:
                temperatures.append(seq.temperature)
                cfg_scales.append(seq.cfg_scale)
        else:
            temperatures = []
            cfg_scales = []
            for seq in seqs:
                temperatures.append(seq.temperature)
                cfg_scales.append(seq.cfg_scale)
        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        return temperatures, cfg_scales

    @torch.inference_mode()
    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
            return self.model.compute_logits(self.model(input_ids, positions))
        else:
            bs = input_ids.size(0)
            context = get_context()
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["positions"][:bs] = positions
            graph_vars["slot_mapping"].fill_(-1)
            graph_vars["slot_mapping"][:bs] = context.slot_mapping
            graph_vars["context_lens"].zero_()
            graph_vars["context_lens"][:bs] = context.context_lens
            graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
            graph.replay()
            return self.model.compute_logits(graph_vars["outputs"][:bs])

    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        """Run model forward and sampling. For CFG sequences, batch is structured as:
        [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
        where uncond_seqi is the paired unconditional sequence of cond_seqi."""
        # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
        is_cfg_batch = False
        if len(seqs) > 0:
            # CFG batch if first sequence has cfg_scale > 1.0 and paired_seq
            if seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
                is_cfg_batch = True
                # Verify batch structure: first half conditional, second half unconditional
                num_cond = len(seqs) // 2
                for i in range(num_cond):
                    if seqs[i].is_unconditional or seqs[i + num_cond].is_unconditional == False:
                        is_cfg_batch = False
                        break

        if is_cfg_batch:
            # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
            num_cond = len(seqs) // 2
            cond_seqs = seqs[:num_cond]
            uncond_seqs = seqs[num_cond:]

            # Prepare inputs for both conditional and unconditional (they're already in the batch)
            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                                    else self.prepare_decode(seqs))
            temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else (None, None)

            # Run model forward (processes entire batch: cond + uncond)
            logits_all = self.run_model(input_ids, positions, is_prefill)
            reset_context()

            if self.rank == 0:
                # Split logits: first half is conditional, second half is unconditional
                logits_cond = logits_all[:num_cond]
                logits_uncond = logits_all[num_cond:]

                # Apply CFG formula: logits_cfg = logits_cond + cfg_scale * (logits_cond - logits_uncond)
                cfg_scales_tensor = cfg_scales.unsqueeze(1)  # [num_cond, 1]
                logits_cfg = logits_cond + cfg_scales_tensor * (logits_cond - logits_uncond)

                # Sample from CFG logits
                token_ids_cfg = self.sampler(logits_cfg, temperatures).tolist()

                # Return token_ids (will be applied to both conditional and unconditional sequences)
                return token_ids_cfg
            else:
                return None
        else:
            # Normal batch (non-CFG)
            input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                                    else self.prepare_decode(seqs))
            temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else (None, None)
            logits = self.run_model(input_ids, positions, is_prefill)
            reset_context()
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
            return token_ids

    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config
        hf_config = config.hf_config
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64)
        positions = torch.zeros(max_bs, dtype=torch.int64)
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
        context_lens = torch.zeros(max_bs, dtype=torch.int32)
        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
        outputs = torch.zeros(max_bs, hf_config.hidden_size)
        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
        self.graphs = {}
        self.graph_pool = None

        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # warmup
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_context()

        self.graph_vars = dict(
            input_ids=input_ids,
            positions=positions,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
            outputs=outputs,
        )
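The core of the CFG branch in `run` is a single tensor expression: the conditional half of the batch is pushed away from the unconditional half before sampling. A tiny self-contained illustration of that same formula with random logits (the shapes and scale values are arbitrary, and greedy argmax stands in for the temperature-based sampler):

```python
import torch

num_cond, vocab = 2, 8
logits_cond = torch.randn(num_cond, vocab)
logits_uncond = torch.randn(num_cond, vocab)
cfg_scales = torch.tensor([2.0, 3.0]).unsqueeze(1)  # [num_cond, 1], broadcasts over vocab

# Same formula as in ModelRunner.run:
# logits_cfg = logits_cond + cfg_scale * (logits_cond - logits_uncond)
logits_cfg = logits_cond + cfg_scales * (logits_cond - logits_uncond)

# Greedy pick from the guided logits (the real code samples with per-sequence temperature).
print(logits_cfg.argmax(dim=-1))
```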
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py
ADDED
@@ -0,0 +1,222 @@
from collections import deque

from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager


class Scheduler:

    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.eos = config.eos
        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

    def is_finished(self):
        return not self.waiting and not self.running

    def add(self, seq: Sequence):
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        # prefill
        scheduled_seqs = []
        num_seqs = 0
        num_batched_tokens = 0
        processed_seqs = set()  # Track processed sequences to handle CFG pairs

        while self.waiting and num_seqs < self.max_num_seqs:
            seq = self.waiting[0]

            # For CFG sequences, ensure conditional and unconditional are scheduled together
            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
                # This is a conditional sequence, need to schedule its paired unconditional sequence too
                paired_seq = seq.paired_seq
                if paired_seq.status != SequenceStatus.WAITING:
                    # Paired sequence not in waiting, skip this conditional sequence for now
                    break

                # Calculate tokens for both sequences
                total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
                can_allocate_both = (self.block_manager.can_allocate(seq) and
                                     self.block_manager.can_allocate(paired_seq))

                if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
                    break

                # Schedule both sequences: conditional first, then unconditional
                for s in [seq, paired_seq]:
                    num_seqs += 1
                    self.block_manager.allocate(s)
                    num_batched_tokens += len(s) - s.num_cached_tokens
                    s.status = SequenceStatus.RUNNING
                    self.waiting.remove(s)
                    self.running.append(s)
                    scheduled_seqs.append(s)
                    processed_seqs.add(s.seq_id)
            else:
                # Normal sequence or unconditional sequence (already processed with its conditional)
                if seq.seq_id in processed_seqs:
                    # Skip if already processed as part of a CFG pair
                    self.waiting.popleft()
                    continue

                if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
                    break
                num_seqs += 1
                self.block_manager.allocate(seq)
                num_batched_tokens += len(seq) - seq.num_cached_tokens
                seq.status = SequenceStatus.RUNNING
                self.waiting.popleft()
                self.running.append(seq)
                scheduled_seqs.append(seq)

        if scheduled_seqs:
            # For CFG batches, ensure conditional sequences come before their unconditional pairs
            cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
            cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
            non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]

            # Reorder: non-CFG, then CFG conditional, then CFG unconditional
            scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
            return scheduled_seqs, True

        # decode
        processed_seqs = set()
        temp_running = list(self.running)  # Work with a copy

        while temp_running and num_seqs < self.max_num_seqs:
            seq = temp_running.pop(0)

            # For CFG sequences, ensure conditional and unconditional are scheduled together
            if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
                paired_seq = seq.paired_seq
                if paired_seq not in temp_running:
                    # Paired sequence not available, skip for now
                    continue

                # Remove paired_seq from temp_running
                temp_running.remove(paired_seq)

                # Check if both can append
                can_append_both = (self.block_manager.can_append(seq) and
                                   self.block_manager.can_append(paired_seq))

                if not can_append_both:
                    # Try preempting other sequences
                    preempted = False
                    while not can_append_both and temp_running:
                        other_seq = temp_running.pop(0)
                        if other_seq != seq and other_seq != paired_seq:
                            self.preempt(other_seq)
                            can_append_both = (self.block_manager.can_append(seq) and
                                               self.block_manager.can_append(paired_seq))
                            preempted = True
                        else:
                            temp_running.append(other_seq)
                            break

                    if not can_append_both:
                        # Can't schedule this pair right now
                        temp_running.append(seq)
                        temp_running.append(paired_seq)
                        continue

                # Schedule both sequences
                for s in [seq, paired_seq]:
                    num_seqs += 1
                    self.block_manager.may_append(s)
                    scheduled_seqs.append(s)
                    processed_seqs.add(s.seq_id)
                    # Remove from actual running list if scheduled
                    if s in self.running:
                        self.running.remove(s)
            else:
                # Normal sequence or unconditional (already processed)
                if seq.seq_id in processed_seqs:
                    continue

                while not self.block_manager.can_append(seq):
                    if temp_running:
                        other_seq = temp_running.pop(0)
                        if other_seq != seq:
                            self.preempt(other_seq)
                        else:
                            temp_running.append(other_seq)
                            break
                    else:
                        self.preempt(seq)
                        if seq in self.running:
                            self.running.remove(seq)
                        break
                else:
                    num_seqs += 1
                    self.block_manager.may_append(seq)
                    scheduled_seqs.append(seq)
                    if seq in self.running:
                        self.running.remove(seq)

        assert scheduled_seqs

        # For CFG batches in decode, ensure conditional sequences come before unconditional
        cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
        cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
        non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
        scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs

        self.running.extendleft(reversed(scheduled_seqs))
        return scheduled_seqs, False

    def preempt(self, seq: Sequence):
        seq.status = SequenceStatus.WAITING
        self.block_manager.deallocate(seq)
        self.waiting.appendleft(seq)

    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
        # Check if this is a CFG batch
        is_cfg_batch = False
        if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
            num_cond = len(seqs) // 2
            is_cfg_batch = (num_cond > 0 and
                            not seqs[0].is_unconditional and
                            seqs[num_cond].is_unconditional)

        if is_cfg_batch:
            # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
            # token_ids correspond to conditional sequences only (sampled from CFG logits)
            num_cond = len(seqs) // 2
            cond_seqs = seqs[:num_cond]
            uncond_seqs = seqs[num_cond:]

            # Apply the same sampled token to both conditional and unconditional sequences
            for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
                cond_seq.append_token(token_id)
                uncond_seq.append_token(token_id)  # Same token for unconditional

                # Check if either sequence is finished
                cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
                                 cond_seq.num_completion_tokens == cond_seq.max_tokens)
                uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
                                   uncond_seq.num_completion_tokens == uncond_seq.max_tokens)

                if cond_finished or uncond_finished:
                    # Mark both as finished
                    cond_seq.status = SequenceStatus.FINISHED
                    uncond_seq.status = SequenceStatus.FINISHED
                    self.block_manager.deallocate(cond_seq)
                    self.block_manager.deallocate(uncond_seq)
                    if cond_seq in self.running:
                        self.running.remove(cond_seq)
                    if uncond_seq in self.running:
                        self.running.remove(uncond_seq)
        else:
            # Normal batch
            for seq, token_id in zip(seqs, token_ids):
                seq.append_token(token_id)
                if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                    seq.status = SequenceStatus.FINISHED
                    self.block_manager.deallocate(seq)
                    self.running.remove(seq)
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
ADDED
@@ -0,0 +1,89 @@
from copy import copy
from enum import Enum, auto
from itertools import count

from nanovllm.sampling_params import SamplingParams


class SequenceStatus(Enum):
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


class Sequence:
    block_size = 256
    counter = count()

    def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(self.token_ids)
        self.num_prompt_tokens = len(token_ids)
        self.num_cached_tokens = 0
        self.block_table = []
        self.temperature = sampling_params.temperature
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos
        self.cfg_scale = sampling_params.cfg_scale
        # For CFG: mark if this is an unconditional sequence
        self.is_unconditional = is_unconditional
        # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
        # For conditional sequences, this points to the unconditional sequence
        self.paired_seq = conditional_seq  # For conditional seq, points to uncond; for uncond seq, points to cond

    def __len__(self):
        return self.num_tokens

    def __getitem__(self, key):
        return self.token_ids[key]

    @property
    def is_finished(self):
        return self.status == SequenceStatus.FINISHED

    @property
    def num_completion_tokens(self):
        return self.num_tokens - self.num_prompt_tokens

    @property
    def prompt_token_ids(self):
        return self.token_ids[:self.num_prompt_tokens]

    @property
    def completion_token_ids(self):
        return self.token_ids[self.num_prompt_tokens:]

    @property
    def num_cached_blocks(self):
        return self.num_cached_tokens // self.block_size

    @property
    def num_blocks(self):
        return (self.num_tokens + self.block_size - 1) // self.block_size

    @property
    def last_block_num_tokens(self):
        return self.num_tokens - (self.num_blocks - 1) * self.block_size

    def block(self, i):
        assert 0 <= i < self.num_blocks
        return self.token_ids[i*self.block_size: (i+1)*self.block_size]

    def append_token(self, token_id: int):
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
                self.token_ids if self.num_completion_tokens == 0 else self.last_token)

    def __setstate__(self, state):
        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
        if self.num_completion_tokens == 0:
            self.token_ids = state[-1]
        else:
            self.last_token = state[-1]
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py
ADDED
@@ -0,0 +1,14 @@
import torch
from torch import nn
import torch.nn.functional as F


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import torch
+from torch import nn
+import triton
+import triton.language as tl
+
+from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from nanovllm.utils.context import get_context
+
+
+@triton.jit
+def store_kvcache_kernel(
+    key_ptr,
+    key_stride,
+    value_ptr,
+    value_stride,
+    k_cache_ptr,
+    v_cache_ptr,
+    slot_mapping_ptr,
+    D: tl.constexpr,
+):
+    idx = tl.program_id(0)
+    slot = tl.load(slot_mapping_ptr + idx)
+    if slot == -1: return
+    key_offsets = idx * key_stride + tl.arange(0, D)
+    value_offsets = idx * value_stride + tl.arange(0, D)
+    key = tl.load(key_ptr + key_offsets)
+    value = tl.load(value_ptr + value_offsets)
+    cache_offsets = slot * D + tl.arange(0, D)
+    tl.store(k_cache_ptr + cache_offsets, key)
+    tl.store(v_cache_ptr + cache_offsets, value)
+
+
+def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
+    N, num_heads, head_dim = key.shape
+    D = num_heads * head_dim
+    assert key.stride(-1) == 1 and value.stride(-1) == 1
+    assert key.stride(1) == head_dim and value.stride(1) == head_dim
+    assert k_cache.stride(1) == D and v_cache.stride(1) == D
+    assert slot_mapping.numel() == N
+    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+
+
+class Attention(nn.Module):
+
+    def __init__(
+        self,
+        num_heads,
+        head_dim,
+        scale,
+        num_kv_heads,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = scale
+        self.num_kv_heads = num_kv_heads
+        self.k_cache = self.v_cache = torch.tensor([])
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+        context = get_context()
+        k_cache, v_cache = self.k_cache, self.v_cache
+        if k_cache.numel() and v_cache.numel():
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        if context.is_prefill:
+            if context.block_tables is not None:  # prefix cache
+                k, v = k_cache, v_cache
+            o = flash_attn_varlen_func(q, k, v,
+                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
+                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
+        else:  # decode
+            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
+                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
+                                        softmax_scale=self.scale, causal=True)
+        return o
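store_kvcache launches one Triton program per new token: each program reads that token's key/value row, flattens it to num_heads * head_dim values, and writes it into the cache row named by slot_mapping; a slot of -1 means the token has no cache slot and is skipped. The following plain-Python loop is a rough functional equivalent of the kernel (shapes and slot numbers invented):

import torch

N, num_heads, head_dim = 3, 2, 4
key = torch.randn(N, num_heads, head_dim)
value = torch.randn(N, num_heads, head_dim)
k_cache = torch.zeros(10, num_heads * head_dim)   # 10 cache slots
v_cache = torch.zeros(10, num_heads * head_dim)
slot_mapping = torch.tensor([4, -1, 7])           # token 1 has no slot

for i, slot in enumerate(slot_mapping.tolist()):
    if slot == -1:
        continue                                  # nothing to store for this token
    k_cache[slot] = key[i].reshape(-1)
    v_cache[slot] = value[i].reshape(-1)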
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py
ADDED
@@ -0,0 +1,66 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from nanovllm.utils.context import get_context
+
+
+class VocabParallelEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+    ):
+        super().__init__()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
+        assert num_embeddings % self.tp_size == 0
+        self.num_embeddings = num_embeddings
+        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
+        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
+        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
+        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
+        self.weight.weight_loader = self.weight_loader
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        shard_size = param_data.size(0)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor):
+        if self.tp_size > 1:
+            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+            x = mask * (x - self.vocab_start_idx)
+        y = F.embedding(x, self.weight)
+        if self.tp_size > 1:
+            y = mask.unsqueeze(1) * y
+            dist.all_reduce(y)
+        return y
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+    ):
+        assert not bias
+        super().__init__(num_embeddings, embedding_dim)
+
+    def forward(self, x: torch.Tensor):
+        context = get_context()
+        if context.is_prefill:
+            last_indices = context.cu_seqlens_q[1:] - 1
+            x = x[last_indices].contiguous()
+        logits = F.linear(x, self.weight)
+        if self.tp_size > 1:
+            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+            dist.gather(logits, all_logits, 0)
+            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+        return logits
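VocabParallelEmbedding shards the vocabulary rows across tensor-parallel ranks: token ids outside a rank's slice are clamped to row 0, looked up anyway, zeroed by the mask, and the all-reduce then sums in the rows owned by the other ranks. A single-process illustration of that masking trick (rank boundaries and sizes invented):

import torch
import torch.nn.functional as F

vocab_start, vocab_end = 5, 10            # this "rank" owns tokens 5..9
weight_shard = torch.randn(5, 4)          # its 5 rows of the embedding table

x = torch.tensor([2, 7, 9])               # token 2 belongs to another rank
mask = (x >= vocab_start) & (x < vocab_end)
local = mask * (x - vocab_start)          # out-of-range ids collapse to row 0
y = F.embedding(local, weight_shard)
y = mask.unsqueeze(1) * y                 # zero the rows this rank does not own
# dist.all_reduce(y) would then add the contributions of the other ranks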
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py
ADDED
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+
+
+class RMSNorm(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+
+    @torch.compile
+    def rms_forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        orig_dtype = x.dtype
+        x = x.float()
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        x.mul_(torch.rsqrt(var + self.eps))
+        x = x.to(orig_dtype).mul_(self.weight)
+        return x
+
+    @torch.compile
+    def add_rms_forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        orig_dtype = x.dtype
+        x = x.float().add_(residual.float())
+        residual = x.to(orig_dtype)
+        var = x.pow(2).mean(dim=-1, keepdim=True)
+        x.mul_(torch.rsqrt(var + self.eps))
+        x = x.to(orig_dtype).mul_(self.weight)
+        return x, residual
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            return self.rms_forward(x)
+        else:
+            return self.add_rms_forward(x, residual)
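RMSNorm follows the fused pre-norm pattern: when a residual is supplied, add_rms_forward first adds it to the input, hands that sum back as the new residual, and normalizes the sum for the next sub-layer. Ignoring the float32 upcasting the real code does for numerical stability, it is equivalent to this small sketch:

import torch

def unfused_add_rms(x, residual, weight, eps=1e-6):
    summed = x + residual                                          # residual connection first
    rms = torch.rsqrt(summed.pow(2).mean(-1, keepdim=True) + eps)
    return summed * rms * weight, summed                           # (normalized, new residual)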
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py
ADDED
@@ -0,0 +1,153 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+
+def divide(numerator, denominator):
+    assert numerator % denominator == 0
+    return numerator // denominator
+
+
+class LinearBase(nn.Module):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+        tp_dim: int | None = None,
+    ):
+        super().__init__()
+        self.tp_dim = tp_dim
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
+        self.weight = nn.Parameter(torch.empty(output_size, input_size))
+        self.weight.weight_loader = self.weight_loader
+        if bias:
+            self.bias = nn.Parameter(torch.empty(output_size))
+            self.bias.weight_loader = self.weight_loader
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        super().__init__(input_size, output_size, bias)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param.data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weight, self.bias)
+
+
+class ColumnParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        tp_size = dist.get_world_size()
+        super().__init__(input_size, divide(output_size, tp_size), bias, 0)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        shard_size = param_data.size(self.tp_dim)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weight, self.bias)
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = False,
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size, sum(output_sizes), bias)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
+        param_data = param.data
+        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
+        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+        param_data.copy_(loaded_weight)
+
+
+class QKVParallelLinear(ColumnParallelLinear):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: int | None = None,
+        bias: bool = False,
+    ):
+        tp_size = dist.get_world_size()
+        total_num_kv_heads = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, tp_size)
+        self.num_kv_heads = divide(total_num_kv_heads, tp_size)
+        output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+        super().__init__(hidden_size, output_size, bias)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
+        param_data = param.data
+        assert loaded_shard_id in ["q", "k", "v"]
+        if loaded_shard_id == "q":
+            shard_size = self.num_heads * self.head_size
+            shard_offset = 0
+        elif loaded_shard_id == "k":
+            shard_size = self.num_kv_heads * self.head_size
+            shard_offset = self.num_heads * self.head_size
+        else:
+            shard_size = self.num_kv_heads * self.head_size
+            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
+        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+        param_data.copy_(loaded_weight)
+
+
+class RowParallelLinear(LinearBase):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = False,
+    ):
+        tp_size = dist.get_world_size()
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        shard_size = param_data.size(self.tp_dim)
+        start_idx = self.tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+        param_data.copy_(loaded_weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+        if self.tp_size > 1:
+            dist.all_reduce(y)
+        return y
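The column/row split is the usual Megatron-style pairing: ColumnParallelLinear shards the output features, RowParallelLinear shards the input features, so a column-parallel projection feeding a row-parallel one needs only the single all-reduce at the end. A single-process sanity check of the arithmetic RowParallelLinear relies on (sizes invented):

import torch

x = torch.randn(3, 64)
w = torch.randn(16, 64)                    # full (output, input) weight
y_full = x @ w.t()

w0, w1 = w.chunk(2, dim=1)                 # per-rank shards along the input dim
x0, x1 = x.chunk(2, dim=1)
y_sharded = x0 @ w0.t() + x1 @ w1.t()      # the sum plays the role of dist.all_reduce
assert torch.allclose(y_full, y_sharded, atol=1e-4)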
acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py
ADDED
@@ -0,0 +1,61 @@
+from functools import lru_cache
+import torch
+from torch import nn
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
+    y1 = x1 * cos - x2 * sin
+    y2 = x2 * cos + x1 * sin
+    return torch.cat((y1, y2), dim=-1).to(x.dtype)
+
+
+class RotaryEmbedding(nn.Module):
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        assert rotary_dim == head_size
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query = apply_rotary_emb(query, cos, sin)
+        key = apply_rotary_emb(key, cos, sin)
+        return query, key
+
+
+@lru_cache(1)
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: float,
+    rope_scaling: dict | None = None,
+):
+    assert rope_scaling is None
+    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+    return rotary_emb
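apply_rotary_emb treats the two halves of each head vector as coordinate pairs and rotates them by a position-dependent angle taken from cos_sin_cache; at position 0 the angle is zero, so the embedding is left untouched. A quick check of that boundary case (shapes invented):

import torch

head_dim = 8
x = torch.randn(1, 1, head_dim)
half = head_dim // 2
cos = torch.ones(1, 1, half)       # position 0: cos(0) = 1
sin = torch.zeros(1, 1, half)      # position 0: sin(0) = 0

x1, x2 = torch.chunk(x, 2, dim=-1)
rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
assert torch.allclose(rotated, x)  # zero rotation leaves the vector unchanged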
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
ADDED
@@ -0,0 +1,15 @@
+import torch
+from torch import nn
+
+
+class Sampler(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @torch.compile
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
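The sampler avoids torch.multinomial by using the Gumbel-max trick in exponential form: for i.i.d. noise E_i ~ Exp(1), argmax_i(p_i / E_i) is an exact draw from the categorical distribution with probabilities p_i, and it vectorizes over the whole batch. A quick empirical check (probabilities and sample count invented):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
n = 200_000
noise = torch.empty(n, 3).exponential_(1)
samples = (probs / noise).argmax(dim=-1)
print(torch.bincount(samples, minlength=3) / n)   # roughly [0.1, 0.2, 0.7]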
acestep/third_parts/nano-vllm/nanovllm/llm.py
ADDED
@@ -0,0 +1,5 @@
+from nanovllm.engine.llm_engine import LLMEngine
+
+
+class LLM(LLMEngine):
+    pass
acestep/third_parts/nano-vllm/nanovllm/models/qwen3.py
ADDED
@@ -0,0 +1,215 @@
+import torch
+from torch import nn
+import torch.distributed as dist
+from transformers import Qwen3Config
+
+from nanovllm.layers.activation import SiluAndMul
+from nanovllm.layers.attention import Attention
+from nanovllm.layers.layernorm import RMSNorm
+from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
+from nanovllm.layers.rotary_embedding import get_rope
+from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
+
+
+class Qwen3Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        head_dim: int | None = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+    ) -> None:
+        super().__init__()
+        tp_size = dist.get_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim ** -0.5
+        self.qkv_bias = qkv_bias
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=qkv_bias,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+        if not self.qkv_bias:
+            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        if not self.qkv_bias:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+        q, k = self.rotary_emb(positions, q, k)
+        o = self.attn(q, k, v)
+        output = self.o_proj(o.flatten(1, -1))
+        return output
+
+
+class Qwen3MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
+
+
+class Qwen3DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        self.self_attn = Qwen3Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', True),
+            head_dim=getattr(config, 'head_dim', None),
+            rope_theta=getattr(config, "rope_theta", 1000000),
+            rope_scaling=getattr(config, "rope_scaling", None),
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3Model(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(positions, hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class Qwen3ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen3Config
+    ) -> None:
+        super().__init__()
+        self.model = Qwen3Model(config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.lm_head(hidden_states)
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py
ADDED
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SamplingParams:
+    temperature: float = 1.0
+    max_tokens: int = 64
+    ignore_eos: bool = False
+    cfg_scale: float = 1.0  # CFG guidance scale. When > 1.0, applies classifier-free guidance
+
+    def __post_init__(self):
+        assert self.temperature > 1e-10, "greedy sampling is not permitted"
+        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
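The cfg_scale field presumably works together with the conditional/unconditional sequence pairing seen at the top of this diff in sequence.py; values above 1.0 enable classifier-free guidance. A usage sketch with invented values:

from nanovllm.sampling_params import SamplingParams

params = SamplingParams(temperature=0.85, max_tokens=2048, cfg_scale=2.0)
# temperature must stay strictly positive and cfg_scale must be >= 1.0,
# otherwise __post_init__ raises an AssertionError.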
acestep/third_parts/nano-vllm/nanovllm/utils/context.py
ADDED
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+import torch
+
+
+@dataclass
+class Context:
+    is_prefill: bool = False
+    cu_seqlens_q: torch.Tensor | None = None
+    cu_seqlens_k: torch.Tensor | None = None
+    max_seqlen_q: int = 0
+    max_seqlen_k: int = 0
+    slot_mapping: torch.Tensor | None = None
+    context_lens: torch.Tensor | None = None
+    block_tables: torch.Tensor | None = None
+
+_CONTEXT = Context()
+
+def get_context():
+    return _CONTEXT
+
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
+    global _CONTEXT
+    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+
+def reset_context():
+    global _CONTEXT
+    _CONTEXT = Context()
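The context module is a tiny process-global registry: the model runner is presumably expected to call set_context before each forward pass and reset_context afterwards, so the attention and LM-head layers can read batch metadata without it being threaded through every forward signature. A usage sketch with invented values:

import torch
from nanovllm.utils.context import set_context, get_context, reset_context

set_context(True, cu_seqlens_q=torch.tensor([0, 5, 9]), max_seqlen_q=5)
ctx = get_context()
print(ctx.is_prefill, ctx.max_seqlen_q)   # True 5
reset_context()                           # back to the default empty Context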
acestep/third_parts/nano-vllm/nanovllm/utils/loader.py
ADDED
@@ -0,0 +1,28 @@
+import os
+from glob import glob
+import torch
+from torch import nn
+from safetensors import safe_open
+
+
+def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
+    param.data.copy_(loaded_weight)
+
+
+def load_model(model: nn.Module, path: str):
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    for file in glob(os.path.join(path, "*.safetensors")):
+        with safe_open(file, "pt", "cpu") as f:
+            for weight_name in f.keys():
+                for k in packed_modules_mapping:
+                    if k in weight_name:
+                        v, shard_id = packed_modules_mapping[k]
+                        param_name = weight_name.replace(k, v)
+                        param = model.get_parameter(param_name)
+                        weight_loader = getattr(param, "weight_loader")
+                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        break
+                else:
+                    param = model.get_parameter(weight_name)
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, f.get_tensor(weight_name))
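load_model walks every *.safetensors shard and routes each tensor through packed_modules_mapping, so q_proj/k_proj/v_proj weights from a Hugging Face checkpoint land in the fused qkv_proj parameter as separate shards. A sketch of restoring the Qwen3 model this way; the path is a placeholder, and the tensor-parallel layers assume torch.distributed has already been initialized by the engine:

from transformers import Qwen3Config
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.utils.loader import load_model

ckpt = "/path/to/qwen3-checkpoint"        # placeholder path
config = Qwen3Config.from_pretrained(ckpt)
model = Qwen3ForCausalLM(config)          # requires an initialized process group
load_model(model, ckpt)                   # fills parameters from *.safetensors shards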
acestep/third_parts/nano-vllm/pyproject.toml
ADDED
@@ -0,0 +1,27 @@
+[build-system]
+requires = ["setuptools>=61"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "nano-vllm"
+version = "0.2.0"
+authors = [{ name = "Xingkai Yu" }]
+license = "MIT"
+license-files = ["LICENSE"]
+readme = "README.md"
+description = "a lightweight vLLM implementation built from scratch"
+requires-python = ">=3.10,<3.13"
+dependencies = [
+    "torch>=2.4.0",
+    "triton>=3.0.0",
+    "transformers>=4.51.0",
+    "flash-attn",
+    "xxhash",
+]
+
+[project.urls]
+Homepage="https://github.com/GeeeekExplorer/nano-vllm"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["nanovllm*"]
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch
+transformers
+diffusers
+gradio