# Interactive streaming chat REPL for a Hugging Face text-generation pipeline.
import os
from threading import Thread

from transformers import pipeline, TextIteratorStreamer

# Use a pipeline as a high-level helper.
print("Loading model...")
pipe = pipeline(
    "text-generation",
    model="PKU-Alignment/ProgressGym-HistLlama3-8B-C019-instruct-v0.2",
    device=0,  # assumes a CUDA device is available at index 0 — TODO confirm
)
print("Model loaded. Type 'exit' to quit, 'clear' to reset conversation.")

# Running chat history: list of {"role": ..., "content": ...} dicts in the
# chat-template format the pipeline expects.
messages = []

while True:
    try:
        user_input = input("\nUser: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D or Ctrl-C at the prompt exits cleanly instead of a traceback.
        break

    if not user_input:
        continue

    if user_input.lower() in ("exit", "quit"):
        break

    if user_input.lower() == "clear":
        os.system('cls' if os.name == 'nt' else 'clear')
        messages = []
        print("Conversation cleared.")
        continue

    messages.append({"role": "user", "content": user_input})

    # Stream response: run generation on a background thread and consume the
    # token stream here so output appears incrementally.
    streamer = TextIteratorStreamer(
        pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(text_inputs=messages, max_new_tokens=512, streamer=streamer)

    thread = Thread(target=pipe, kwargs=generation_kwargs)
    thread.start()

    print("\nAssistant: ", end="", flush=True)
    generated_text = ""
    for new_text in streamer:
        print(new_text, end="", flush=True)
        generated_text += new_text
    print()

    # Make sure the generation thread has fully finished before touching the
    # history or prompting for the next turn (the streamer being exhausted
    # does not guarantee the worker has returned).
    thread.join()

    # Update history so the next turn carries full conversational context.
    messages.append({"role": "assistant", "content": generated_text})