import os import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch import spaces # PyTorch設定(パフォーマンスと再現性向上のため) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = True HF_TOKEN = os.getenv("HF_TOKEN") # モデルのキャッシュ用辞書(ロード済みなら再利用) loaded_models = {} def get_model_and_tokenizer(model_name): # 既にロード済みならそのまま返す if model_name in loaded_models: return loaded_models[model_name] # ロードされていなければロードする tokenizer = AutoTokenizer.from_pretrained( model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN ) model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN) loaded_models[model_name] = (model, tokenizer) return model, tokenizer def disable_generate_button(): # 生成ボタンを無効化し、テキストを「モデルをロード中……」に変更する return gr.update(interactive=False, value="モデルをロード中……") def load_model(model_name): """ プルダウン変更時や起動時に呼ばれ、モデルをロードして生成ボタンを有効化する。 """ tokenizer = AutoTokenizer.from_pretrained( model_name, attn_implementation="flash_attention_2", use_auth_token=HF_TOKEN ) model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HF_TOKEN) loaded_models[model_name] = (model, tokenizer) status_message = f"Model '{model_name}' loaded successfully." # ロード完了後、生成ボタンを有効化し、テキストを「続きを生成」に戻す return status_message, gr.update(interactive=True, value="続きを生成") @spaces.GPU def generate_text( model_name, input_text, max_length=150, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0 ): """ユーザー入力に基づいてテキストを生成し、元のテキストに追加する関数""" try: if not input_text.strip(): return "" # 既にロード済みのモデルとトークナイザーを使用 model, tokenizer = get_model_and_tokenizer(model_name) # GPUが利用可能ならGPUへ移動。bf16がサポートされている場合はbf16を使用 device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda" and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported(): model.to(device, dtype=torch.bfloat16) else: model.to(device) # 入力テキストのトークン化 input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device) input_token_count = input_ids.shape[1] # 総トークン数の上限を入力トークン数 + max_length(max_lengthは追加するトークン数として扱う) total_max_length = input_token_count + max_length # テキスト生成 output_ids = model.generate( input_ids, max_length=total_max_length, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id, num_return_sequences=1 ) # 生成されたテキストをデコードし、入力部分を除いた生成分を抽出 generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) new_text = generated_text[len(input_text):] # 入力テキストに生成したテキストを追加して返す return input_text + new_text except Exception as e: return f"{input_text}\n\nエラーが発生しました: {str(e)}" # Gradioインターフェースの作成 with gr.Blocks() as demo: gr.Markdown("# テキスト続き生成アシスタント") gr.Markdown("モデルを選択し、テキストボックスに文章を入力してパラメータを調整後、「続きを生成」ボタンをクリックすると、選択したモデルがその続きを生成します。") # モデル選択用プルダウンメニュー model_dropdown = gr.Dropdown( choices=[ "Local-Novel-LLM-project/Vecteus-v1-abliterated", "Local-Novel-LLM-project/Ninja-V3", "Local-Novel-LLM-project/kagemusya-7B-v1" ], label="モデルを選択してください", value="Local-Novel-LLM-project/Vecteus-v1-abliterated" ) # 隠しコンポーネント:モデルロードの状況を表示(ユーザーには見せなくても良い) load_status = gr.Textbox(visible=False) # テキスト入力ボックス input_text = gr.Textbox(label="テキストを入力してください", placeholder="ここにテキストを入力...", lines=10) # 生成パラメータの設定UI max_length_slider = gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="追加するトークン数") temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="創造性(温度)") top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="top_k") top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="top_p") repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.1, label="繰り返しペナルティ") # 生成ボタンは初期状態で無効化 generate_btn = gr.Button("モデルをロード中……", variant="primary", interactive=False) clear_btn = gr.Button("クリア") # プルダウン変更時に、まず生成ボタンを無効化(テキストを「モデルをロード中……」に変更)し、その後モデルをロードして生成ボタンを再有効化するイベントチェーンを設定 model_dropdown.change( fn=disable_generate_button, inputs=None, outputs=generate_btn ).then( fn=load_model, inputs=model_dropdown, outputs=[load_status, generate_btn] ) # 起動時にも load_model を実行する(初期値のモデルでロード) demo.load(fn=load_model, inputs=model_dropdown, outputs=[load_status, generate_btn]) # 生成ボタン押下時のイベント設定 generate_btn.click( fn=generate_text, inputs=[model_dropdown, input_text, max_length_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider], outputs=input_text ) clear_btn.click(lambda: "", None, input_text) # アプリの起動 demo.launch()