diff options
| author | Ren Kararou <[email protected]> | 2025-11-17 16:33:34 -0600 |
|---|---|---|
| committer | Ren Kararou <[email protected]> | 2025-11-17 16:33:34 -0600 |
| commit | b94644314eafeede379199a382fac25ac5753b06 (patch) | |
| tree | 97f67cb1b40fcf8d743b0c200e0e445ec812c42f /src | |
| parent | ad72227d9763ff2b8c7a31725d12cc6df3a1265b (diff) | |
| download | violet-b94644314eafeede379199a382fac25ac5753b06.tar.gz violet-b94644314eafeede379199a382fac25ac5753b06.tar.bz2 violet-b94644314eafeede379199a382fac25ac5753b06.zip | |
make the agent smarter, I guess.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main.rs | 88 |
1 files changed, 65 insertions, 23 deletions
diff --git a/src/main.rs b/src/main.rs index 1eda602..78bb61b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,8 @@ use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Functio use serde_json::{json, Value}; use serde::Deserialize; use std::fs::read_to_string; +use std::io::Write; +use chrono::Utc; #[derive(Deserialize, Clone, Debug)] struct Config { @@ -16,14 +18,18 @@ async fn get_horoscope(sign: &str) -> String { #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { + eprintln!("Starting setup"); + eprintln!("Loading Config"); let config = read_to_string("config.json")?; let config: Config = serde_json::from_str(&config)?; + eprintln!("Config Loaded"); let client = LlmClient::openai_with_config( &config.key, Some(&config.base_url), Some(300), None, )?; + eprintln!("Config Setup"); let mut tools: Vec<Tool> = Vec::new(); tools.push(Tool { tool_type: "function".into(), @@ -38,41 +44,77 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { }), }, }); + tools.push(Tool { + tool_type: "function".into(), + function: Function { + name: "stop".into(), + description: Some("Emergency Stop the Conversation".into()), + parameters: json!({}), + }, + }); + eprintln!("Tools Loaded"); + let date: String = Utc::now().date_naive().to_string(); + let system_prompt: String = format!("You are a helpful agent! You will comply with all user requests. The current date is {date}."); let mut req = ChatRequest { model: config.model, messages: vec![ - Message::text(Role::System, "You will comply with all horoscope requests using tools."), - Message::text(Role::User, "I am an Aries. What is my horoscope for today?"), + Message::text(Role::System, &system_prompt), ], tools: Some(tools), //temperature: Some(0.9), ..Default::default() }; - let mut response = client.chat(&req).await?; - for choice in response.choices.clone() { - if let Some(reason) = choice.finish_reason { - match reason.as_str() { - "tool_calls" => { - if let Some(calls) = choice.message.clone().tool_calls { - for call in calls { - match call.function.name.as_str() { - "get_horoscope" => { - let v: Value = serde_json::from_str(call.function.arguments.as_str())?; - req.messages.push(choice.message.clone()); - let v = v["sign"].as_str().unwrap_or_default(); - let val: String = get_horoscope(v).await; - req.messages.push(Message::tool(val, call.id)); - response = client.chat(&req).await?; - }, - _ => (), + eprintln!("System Prompt is: {system_prompt}"); + eprintln!("Base Request Setup"); + eprintln!("Setup Finished"); + loop { + let mut s = String::new(); + print!("user> "); + std::io::stdout().flush()?; + if let Err(e) = std::io::stdin().read_line(&mut s) { + eprintln!("could not read stdin: {e}"); + break; + } + if s == "stop" || s == "quit" { + break; + } + req.messages.push(Message::text(Role::User, &s)); + let mut response = client.chat(&req).await?; + for choice in response.choices.clone() { + if let Some(reason) = choice.finish_reason { + match reason.as_str() { + "tool_calls" => { + if let Some(calls) = choice.message.clone().tool_calls { + for call in calls { + match call.function.name.as_str() { + "get_horoscope" => { + let v: Value = serde_json::from_str(call.function.arguments.as_str())?; + req.messages.push(choice.message.clone()); + let v = v["sign"].as_str().unwrap_or_default(); + let val: String = get_horoscope(v).await; + req.messages.push(Message::tool(val, call.id)); + response = client.chat(&req).await?; + }, + "stop" => { + println!("Agent has stopped the conversation."); + if let Some(reasoning) = response.reasoning_content { + eprintln!("Agent Stopped with Reasoning: {}", reasoning.as_str().trim_start().trim_end()); + } else { + eprintln!("Agent has not given reasoning for emergency stop."); + } + return Ok(()); + } + _ => (), + } } } - } - }, - _ => (), + }, + _ => (), + } } } + println!("agent> {}", &response.content.trim_start().trim_end()); + req.messages.push(Message::text(Role::Assistant, &response.content.clone())); } - println!("Response: {}", &response.content); Ok(()) } |
