import os
import torch
import spaces
import psycopg2
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048

HF_TOKEN = os.environ["HF_TOKEN"]

model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
)
# The tokenizer comes from the base Llama checkpoint; the repository is gated,
# so the token is passed here as well.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct", token=HF_TOKEN
)

LANGUAGES = [
    "Hindi",
    "Bengali",
    "Telugu",
    "Marathi",
    "Tamil",
    "Urdu",
    "Gujarati",
    "Kannada",
    "Odia",
    "Malayalam",
    "Punjabi",
    "Assamese",
    "Maithili",
    "Santali",
    "Kashmiri",
    "Nepali",
    "Sindhi",
    "Konkani",
    "Dogri",
    "Manipuri",
    "Bodo",
]

def format_message_for_translation(message, target_lang):
    return f"Translate the following text to {target_lang}: {message}"

def store_feedback(rating, feedback_text, chat_history, tgt_lang):
    try:
        if not rating:
            gr.Warning("Please select a rating before submitting feedback.", duration=5)
            return None
        if not feedback_text or feedback_text.strip() == "":
            gr.Warning("Please provide some feedback before submitting.", duration=5)
            return None
        if not chat_history:
            gr.Warning(
                "Please provide the input text before submitting feedback.", duration=5
            )
            return None
        if len(chat_history[0]) < 2 or not chat_history[0][1]:
            gr.Warning(
                "Please translate the input text before submitting feedback.",
                duration=5,
            )
            return None

        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        cursor = conn.cursor()
        insert_query = """
            INSERT INTO feedback
            (tgt_lang, rating, feedback_txt, chat_history)
            VALUES (%s, %s, %s, %s)
        """
        cursor.execute(
            insert_query, (tgt_lang, int(rating), feedback_text, chat_history)
        )
        conn.commit()
        cursor.close()
        conn.close()
        gr.Info("Thank you for your feedback!", duration=5)
    except Exception:
        raise gr.Error(
            "An error occurred while storing feedback. Please try again later.",
            duration=5,
        )

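# The 'feedback' table is not defined in this file. A minimal sketch of a schema
# compatible with the INSERT above (column names from the query, types assumed):
#
#   CREATE TABLE IF NOT EXISTS feedback (
#       id           SERIAL PRIMARY KEY,
#       tgt_lang     TEXT,
#       rating       INTEGER,
#       feedback_txt TEXT,
#       chat_history TEXT[][],  -- psycopg2 adapts the list of [input, output] pairs to an array
#       created_at   TIMESTAMPTZ DEFAULT now()
#   );
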
def store_output(tgt_lang, input_text, output_text):
    conn = psycopg2.connect(
        host=os.getenv("DB_HOST"),
        database=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        port=os.getenv("DB_PORT"),
    )
    cursor = conn.cursor()
    insert_query = """
        INSERT INTO translation
        (input_txt, output_txt, tgt_lang)
        VALUES (%s, %s, %s)
    """
    cursor.execute(insert_query, (input_text, output_text, tgt_lang))
    conn.commit()
    cursor.close()
    conn.close()

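# Likewise, a minimal sketch of a 'translation' table matching the INSERT above
# (column names from the query, types assumed):
#
#   CREATE TABLE IF NOT EXISTS translation (
#       id         SERIAL PRIMARY KEY,
#       input_txt  TEXT,
#       output_txt TEXT,
#       tgt_lang   TEXT,
#       created_at TIMESTAMPTZ DEFAULT now()
#   );
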
@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    translation_request = format_message_for_translation(message, target_language)
    conversation.append({"role": "user", "content": translation_request})
    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    store_output(target_language, message, "".join(outputs))

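# Illustrative use of the streaming generator outside the Gradio UI
# (assumes the model above is loaded; each iteration yields a longer partial translation):
#
#   for partial in translate_message("Good morning!", [], target_language="Tamil"):
#       print(partial)
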
| css = """ | |
| # body { | |
| # background-color: #f7f7f7; | |
| # } | |
| .feedback-section { | |
| margin-top: 30px; | |
| border-top: 1px solid #ddd; | |
| padding-top: 20px; | |
| } | |
| .container { | |
| max-width: 90%; | |
| margin: 0 auto; | |
| } | |
| .language-selector { | |
| margin-bottom: 20px; | |
| padding: 10px; | |
| background-color: #ffffff; | |
| border-radius: 8px; | |
| box-shadow: 0 2px 5px rgba(0,0,0,0.1); | |
| } | |
| .advanced-options { | |
| margin-top: 20px; | |
| } | |
| """ | |
DESCRIPTION = """\
IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across <b>22 Indic languages</b> with high accuracy. It supports <b>document-level machine translation (MT)</b> and is built to match the performance of other leading SOTA models. <br>

<b>Training data will be released soon!</b>

<h3>Features</h3>

- Supports <b>22 Indic languages</b>
- Enables <b>document-level translation</b>
- Achieves <b>SOTA performance</b> in Indic MT
- Optimized for <b>real-world applications</b>

<h3>Try It Out!</h3>

1. Enter text in any supported language
2. Select the target language
3. Click <b>Translate</b> and get high-quality results!

Built for <b>linguistic diversity and accessibility</b>, IndicTrans3 is a major step forward in <b>Indic language AI</b>.

<b>Source:</b> AI4Bharat | Powered by Hugging Face
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            "# IndicTrans3-beta: Multilingual Translation for 22 Indic Languages"
        )
        gr.Markdown(DESCRIPTION)
        target_language = gr.Dropdown(
            LANGUAGES,
            value="Hindi",
            label="Which language would you like to translate to?",
            elem_id="language-dropdown",
        )
        chatbot = gr.Chatbot(height=400, elem_id="chatbot")
        with gr.Row():
            msg = gr.Textbox(
                placeholder="Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
            )
            submit_btn = gr.Button("Translate", scale=1)
        gr.Examples(
            examples=[
                "The Taj Mahal stands majestically along the banks of river Yamuna, a timeless symbol of eternal love.",
                "Kumbh Mela is the world's largest gathering of people, where millions of pilgrims bathe in sacred rivers for spiritual purification.",
                "India's classical dance forms like Bharatanatyam, Kathak, and Odissi beautifully blend rhythm, expression, and storytelling.",
                "Ayurveda, the ancient Indian medical system, focuses on holistic wellness through natural herbs and balanced living.",
                "During Diwali, homes across India are decorated with oil lamps, colorful rangoli patterns, and twinkling lights to celebrate the victory of light over darkness.",
            ],
            inputs=msg,
        )

        with gr.Accordion("Provide Feedback", open=True):
            gr.Markdown("## Rate Translation & Provide Feedback")
            gr.Markdown(
                "Help us improve the translation quality by providing your feedback."
            )
            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"], label="Translation Rating (1-5)"
                )
                feedback_text = gr.Textbox(
                    placeholder="Share your feedback about the translation...",
                    label="Feedback",
                    lines=3,
                )
            feedback_submit = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="", visible=False)

        with gr.Accordion(
            "Advanced Options", open=False, elem_classes="advanced-options"
        ):
            max_new_tokens = gr.Slider(
                label="Max new tokens",
                minimum=1,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                value=DEFAULT_MAX_NEW_TOKENS,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=1.0,
                step=0.1,
                value=0.1,
            )
            top_p = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                value=0.9,
            )
            top_k = gr.Slider(
                label="Top-k",
                minimum=1,
                maximum=100,
                step=1,
                value=50,
            )
            repetition_penalty = gr.Slider(
                label="Repetition penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
            )

    chat_state = gr.State([])

    def user(user_message, history, target_lang):
        # Append the new user message to the chat history and clear the textbox.
        return "", history + [[user_message, None]]

    def bot(
        history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty
    ):
        # Stream the translation of the most recent user message into the chat history.
        user_message = history[-1][0]
        history[-1][1] = ""
        for chunk in translate_message(
            user_message,
            history[:-1],
            target_lang,
            max_tokens,
            temp,
            top_p_val,
            top_k_val,
            rep_penalty,
        ):
            history[-1][1] = chunk
            yield history

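    # The two callbacks above use Gradio's list-of-pairs chat format, e.g. (illustrative values):
    #   [["Text to translate", None]]              after user() appends the message
    #   [["Text to translate", "partial output"]]  while bot() streams the translation
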
    msg.submit(
        user, [msg, chatbot, target_language], [msg, chatbot], queue=False
    ).then(
        bot,
        [
            chatbot,
            target_language,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        chatbot,
    )

    submit_btn.click(
        user, [msg, chatbot, target_language], [msg, chatbot], queue=False
    ).then(
        bot,
        [
            chatbot,
            target_language,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
        ],
        chatbot,
    )

    feedback_submit.click(
        fn=store_feedback,
        inputs=[rating, feedback_text, chatbot, target_language],
    )

if __name__ == "__main__":
    demo.launch()