From f2558babd9cd64764730a832f837cad3d920fdd7 Mon Sep 17 00:00:00 2001 From: Ren Kararou Date: Sat, 29 Nov 2025 20:17:15 -0600 Subject: start routing to actually do something useful --- src/main.rs | 197 ++++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 153 insertions(+), 44 deletions(-) (limited to 'src') diff --git a/src/main.rs b/src/main.rs index 7a436e7..4dbb9cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,20 @@ -//use base64::{Engine, engine::general_purpose::STANDARD}; +use base64::{Engine, engine::general_purpose::STANDARD}; use chrono::Utc; -use reqwest::ClientBuilder; +use reqwest::{Client, ClientBuilder}; use rig::{ - agent::stream_to_stdout, + agent::{Agent, stream_to_stdout}, audio_generation::AudioGenerationModel, client::audio_generation::AudioGenerationClient, - message::{DocumentSourceKind, Image, ImageDetail, ImageMediaType, Message}, + completion::Chat, + message::{ + AssistantContent, DocumentSourceKind, Image, ImageDetail, ImageMediaType, + Message, UserContent, + }, prelude::*, providers::openai, + providers::openai::CompletionModel, streaming::StreamingChat, }; -use rodio::Decoder; use serde::{Deserialize, Serialize}; use serde_json::json; use std::io::{Cursor, Write}; @@ -20,28 +24,40 @@ use std::time::Duration; struct Config { base_url: String, key: String, - model: String, + timeout: u64, + #[serde(skip_serializing_if = "Option::is_none")] + vision_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + vision_prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + summary_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + summary_prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] audio_model: Option, #[serde(skip_serializing_if = "Option::is_none")] audio_voice: Option, - system_prompt: String, - timeout: u64, - max_tokens: u64, - temp: f64, + vision_tokens: u64, + vision_temp: f64, + summary_tokens: u64, + summary_temp: f64, } impl std::default::Default for Config { fn default() -> Self { Self { base_url: String::from("https://api.openai.com/v1"), key: String::from("sk-..."), - model: String::from("gpt-4o"), + vision_model: None, + vision_prompt: None, + summary_prompt: None, + summary_model: None, audio_model: None, audio_voice: None, - system_prompt: String::from("You are a helpful assistant!"), timeout: 30, - max_tokens: 4096, - temp: 0.4, + vision_tokens: 4096, + vision_temp: 0.4, + summary_tokens: 8192, + summary_temp: 0.9, } } } @@ -53,7 +69,7 @@ async fn main() -> Result<(), Box> { let config: Config = confy::load("violet", Some("violet"))?; println!( "Config file location: {}", - confy::get_configuration_file_path("violet", None)? + confy::get_configuration_file_path("violet", Some("violet"))? .as_path() .to_str() .unwrap_or("path does not exist") @@ -72,19 +88,50 @@ async fn main() -> Result<(), Box> { .connect_timeout(Duration::from_secs(conn_timeout)) .build()?; let date: String = Utc::now().date_naive().to_string(); - let system_prompt: String = - format!("The current date is {date}.\n\n{}", &config.system_prompt); - eprintln!("System Prompt is: {system_prompt}"); + let vision_prompt: String = if let Some(prompt) = config.vision_prompt { + prompt + } else { + "You will describe the images attached".into() + }; + let summary_prompt: String = format!( + "The current date is {date}.\n\n{}", + if let Some(prompt) = config.summary_prompt { + prompt + } else { + String::from("You will create a narrative for the image ") + + "descriptions given as if you were telling a story." + } + ); + eprintln!("Vision System Prompt is: {vision_prompt}"); + eprintln!("Summary System Prompt is: {summary_prompt}"); let api = openai::ClientBuilder::new_with_client(&config.key, http_client) .base_url(&config.base_url) .build(); - let violet = api - .completion_model(&config.model) + let vision_model: String = if let Some(vmodel) = config.vision_model { + vmodel + } else { + "gpt-image-1".into() + }; + let vision = api + .completion_model(&vision_model) + .completions_api() + .into_agent_builder() + .preamble(&vision_prompt) + .max_tokens(config.vision_tokens) + .temperature(config.vision_temp) + .build(); + let summary_model: String = if let Some(smodel) = config.summary_model { + smodel + } else { + "gpt-4o".into() + }; + let summary = api + .completion_model(&summary_model) .completions_api() .into_agent_builder() - .preamble(&system_prompt) - .max_tokens(config.max_tokens) - .temperature(config.temp) + .preamble(&summary_prompt) + .max_tokens(config.summary_tokens) + .temperature(config.summary_temp) .build(); let audio_model = if let Some(model) = &config.audio_model { model @@ -96,13 +143,16 @@ async fn main() -> Result<(), Box> { } else { "alloy" }; - let violet_voice = api.audio_generation_model(audio_model); - eprintln!("Base Request Setup"); - let mut history: Vec = Vec::new(); - eprintln!("Getting Audio Device"); - let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?; - let _sink = rodio::Sink::connect_new(&stream_handle.mixer()); + let audio = api.audio_generation_model(audio_model); eprintln!("Setup Finished"); + routing(vision, summary, audio, audio_voice).await?; + Ok(()) +} + +async fn chat( + agent: Agent>, +) -> Result, Box> { + let mut history: Vec = Vec::new(); let mut s = String::new(); print!("> "); let _ = std::io::stdout().flush(); @@ -114,22 +164,9 @@ async fn main() -> Result<(), Box> { uwu = false; } while uwu { - let mut stream = violet.stream_chat(&s, history.clone()).await; + let mut stream = agent.stream_chat(&s, history.clone()).await; let res = stream_to_stdout(&mut stream).await?; print!("\n"); - let vres = violet_voice - .audio_generation_request() - .text(res.response()) - .voice(&audio_voice) - .additional_params(json!( - { - "response_format": "mp3", - } - )) - .send() - .await?; - let vdata = Decoder::new(Cursor::new(vres.audio.clone()))?; - stream_handle.mixer().add(vdata); history.push(Message::user(s.clone())); history.push(Message::assistant(res.response())); print!("> "); @@ -139,8 +176,80 @@ async fn main() -> Result<(), Box> { eprintln!("Error reading stdin: {e}"); } if s.as_str().to_lowercase().trim() == "stop" { - break; + uwu = false; } } + Ok(history) +} + +async fn prompt_model( + agent: Agent>, + prompt: Message, + history: Vec, +) -> Result> { + let res = agent.chat(prompt, history).await?; + Ok(rig::message::AssistantContent::text(&res).into()) +} + +async fn get_audio( + audio: openai::audio_generation::AudioGenerationModel, + voice: &str, + text: &str, +) -> Result, Box> { + let vres = audio + .audio_generation_request() + .text(text) + .voice(voice) + .additional_params(json!( + { + "response_format": "mp3", + } + )) + .send() + .await?; + Ok(vres.audio.clone()) +} + +async fn routing( + vision: Agent>, + summary: Agent>, + audio: openai::audio_generation::AudioGenerationModel, + audio_voice: &str, +) -> Result<(), Box> { + let _vision = vision; + let mut s: String = String::new(); + for m in chat(summary).await? { + let text: String = match m { + Message::User { content } => { + let mut e: String = "User: ".into(); + for c in content { + if let UserContent::Text(content) = c { + e = e + content.text().into(); + e = e + "\n".into(); + } + } + e + }, + Message::Assistant { id, content } => { + let _id = id; + let mut e: String = "Assistant: ".into(); + for c in content { + if let AssistantContent::Text(content) = c { + e = e + content.text().into(); + e = e + "\n".into(); + } + } + e + }, + }; + s = s + &text; + } + let e = get_audio(audio, audio_voice, &s).await?; + let mut fiel = std::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open("chat.mp3")?; + fiel.write_all(&e.as_slice())?; Ok(()) } -- cgit 1.4.1-2-gfad0