diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main.rs | 210 |
1 files changed, 78 insertions, 132 deletions
diff --git a/src/main.rs b/src/main.rs index 943104d..b7ff867 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,31 @@ -use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Function}}; -use serde_json::{json, Value}; +use chrono::Utc; +use reqwest::ClientBuilder; use serde::Deserialize; use std::fs::read_to_string; +use std::time::Duration; use std::io::Write; -use chrono::Utc; + +use rig::{ + agent::stream_to_stdout, + prelude::*, + providers::openai, + streaming::StreamingChat, + message::Message, + client::audio_generation::AudioGenerationClient, + audio_generation::AudioGenerationModel, +}; #[derive(Deserialize, Clone, Debug)] struct Config { base_url: String, key: String, model: String, + audio_model: Option<String>, + audio_voice: Option<String>, system_prompt: String, timeout: u64, -} - -async fn get_horoscope(sign: &str) -> String { - format!("{sign}: Next Tuesday you will befriend a baby otter.") -} - -async fn wikipedia_lookup(title: &str) -> String { - format!("{title}: This article is the article on {title} and is quite interesting. {title} has several toes. Many more toes than a {title} should have.") + max_tokens: u64, + temp: f64, } #[tokio::main] @@ -29,137 +35,77 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let config = read_to_string("config.json")?; let config: Config = serde_json::from_str(&config)?; eprintln!("Config Loaded"); - let client = LlmClient::openai_with_config( - &config.key, - Some(&config.base_url), - Some(config.timeout), - None, - )?; - eprintln!("Config Setup"); - let mut tools: Vec<Tool> = Vec::new(); - tools.push(Tool { - tool_type: "function".into(), - function: Function { - name: "get_horoscope".into(), - description: Some("Get today's horoscope for an astrological sign.".into()), - parameters: json!({ - "sign": { - "type": "string", - "description": "An astrological sign like Taurus or Aquarius." - } - }), - }, - }); - tools.push(Tool { - tool_type: "function".into(), - function: Function { - name: "stop".into(), - description: Some("Emergency Stop the Conversation. Only to be used when the user is requesting something dangerous.".into()), - parameters: json!({}), - }, - }); - tools.push(Tool { - tool_type: "function".into(), - function: Function { - name: "wikipedia_lookup".into(), - description: Some("Look up a wikipedia article and have its summary returned.".into()), - parameters: json!({ - "title": { - "type": "string", - "description": "The title of the article to look up." - } - }), - }, - }); - eprintln!("Tools Loaded"); + let conn_timeout = if config.timeout < 30 { + config.timeout + } else if config.timeout < 300 { + config.timeout / 2 + } else { + config.timeout / 4 + }; + let http_client = ClientBuilder::new() + .user_agent("violet-rs/0.1") + .read_timeout(Duration::from_secs(config.timeout)) + .connect_timeout(Duration::from_secs(conn_timeout)) + .build()?; let date: String = Utc::now().date_naive().to_string(); let system_prompt: String = format!("The current date is {date}. {}", &config.system_prompt); - let mut req = ChatRequest { - model: config.model, - messages: vec![ - Message::text(Role::System, &system_prompt), - ], - tools: Some(tools), - //temperature: Some(0.9), - ..Default::default() - }; eprintln!("System Prompt is: {system_prompt}"); + let api = openai::ClientBuilder::new_with_client(&config.key, http_client) + .base_url(&config.base_url) + .build(); + let violet = api.completion_model(&config.model) + .completions_api() + .into_agent_builder() + .preamble(&system_prompt) + .max_tokens(config.max_tokens) + .temperature(config.temp) + .build(); + let audio_model = if let Some(model) = &config.audio_model { + model + } else { + "cosyvoice" + }; + let _audio_voice = if let Some(voice) = &config.audio_voice { + voice + } else { + "english_female" + }; + let _violet_voice = api.audio_generation_model(audio_model); eprintln!("Base Request Setup"); eprintln!("Setup Finished"); let mut s = String::new(); - print!("user> "); - std::io::stdout().flush()?; + print!("> "); + let _ = std::io::stdout().flush(); if let Err(e) = std::io::stdin().read_line(&mut s) { - eprintln!("could not read stdin: {e}"); - return Ok(()) + eprintln!("Error reading stdin: {e}"); } - if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" { - return Ok(()); + let mut history: Vec<Message> = Vec::new(); + let mut uwu = true; + if "stop" == s.as_str().to_lowercase().trim() { + uwu = false; } - req.messages.push(Message::text(Role::User, &s)); - let mut uwu: bool = true; while uwu { - let response = client.chat(&req).await?; - eprintln!("{:?}", &response); - for choice in response.choices.clone() { - if let Some(reason) = choice.finish_reason { - match reason.as_str() { - "tool_calls" => { - if let Some(calls) = choice.message.clone().tool_calls { - for call in calls { - match call.function.name.as_str() { - "get_horoscope" => { - eprintln!("get_horoscope!"); - let v: Value = serde_json::from_str(call.function.arguments.as_str())?; - req.messages.push(choice.message.clone()); - let v = v["sign"].as_str().unwrap_or_default(); - let val: String = get_horoscope(v).await; - req.messages.push(Message::tool(val, call.id)); - eprintln!("{:?}", &req); - }, - "wikipedia_lookup" => { - eprintln!("wikipedia_lookup!"); - let v: Value = serde_json::from_str(call.function.arguments.as_str())?; - req.messages.push(choice.message.clone()); - let v = v["title"].as_str().unwrap_or_default(); - let val: String = wikipedia_lookup(v).await; - req.messages.push(Message::tool(val, call.id)); - eprintln!("{:?}", &req); - } - "stop" => { - println!("Agent has stopped the conversation."); - if let Some(reasoning) = response.reasoning_content { - eprintln!("Agent Stopped with Reasoning: {}", reasoning.as_str().trim()); - } else { - eprintln!("Agent has not given reasoning for emergency stop."); - } - return Ok(()); - } - _ => (), - } - } - } - }, - "stop" => { - println!("agent> {}", &response.content.trim()); - req.messages.push(Message::text(Role::Assistant, &response.content.clone())); - let mut s = String::new(); - print!("user> "); - std::io::stdout().flush()?; - if let Err(e) = std::io::stdin().read_line(&mut s) { - eprintln!("could not read stdin: {e}"); - uwu = false; - break; - } - if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" { - uwu = false; - break; - } - req.messages.push(Message::text(Role::User, &s)); - }, - _ => (), - } - } + let mut stream = violet + .stream_chat(&s, history.clone()) + .await; + let res = stream_to_stdout(&mut stream).await?; + print!("\n"); + //let vres = violet_voice + // .audio_generation_request() + // .text(res.response()) + // .voice(audio_voice) + // .send() + // .await?; + history.push(Message::user(s.clone())); + history.push(Message::assistant(res.response())); + print!("> "); + s = String::new(); + let _ = std::io::stdout().flush(); + if let Err(e) = std::io::stdin().read_line(&mut s) { + eprintln!("Error reading stdin: {e}"); + } + if s.as_str().to_lowercase().trim() == "stop" { + break; } } Ok(()) |
