diff options
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 69 |
1 files changed, 50 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs index b7ff867..3c139ab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,39 +1,64 @@ use chrono::Utc; use reqwest::ClientBuilder; -use serde::Deserialize; -use std::fs::read_to_string; +use serde::{Deserialize, Serialize}; use std::time::Duration; -use std::io::Write; +use std::io::{Write, Cursor}; +//use base64::{Engine, engine::general_purpose::STANDARD}; +use serde_json::json; +use rodio::Decoder; use rig::{ agent::stream_to_stdout, prelude::*, providers::openai, streaming::StreamingChat, - message::Message, + message::{Message, Image, ImageMediaType, DocumentSourceKind, ImageDetail}, client::audio_generation::AudioGenerationClient, audio_generation::AudioGenerationModel, }; -#[derive(Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug)] struct Config { base_url: String, key: String, model: String, + #[serde(skip_serializing_if = "Option::is_none")] audio_model: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] audio_voice: Option<String>, system_prompt: String, timeout: u64, max_tokens: u64, temp: f64, } +impl std::default::Default for Config { + fn default() -> Self { + Self { + base_url: String::from("https://api.openai.com/v1"), + key: String::from("sk-..."), + model: String::from("gpt-4o"), + audio_model: None, + audio_voice: None, + system_prompt: String::from("You are a helpful assistant!"), + timeout: 30, + max_tokens: 4096, + temp: 0.4, + } + } +} #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { eprintln!("Starting setup"); eprintln!("Loading Config"); - let config = read_to_string("config.json")?; - let config: Config = serde_json::from_str(&config)?; + let config: Config = confy::load("violet", Some("violet"))?; + println!( + "Config file location: {}", + confy::get_configuration_file_path("violet", None)? + .as_path() + .to_str() + .unwrap_or("path does not exist") + ); eprintln!("Config Loaded"); let conn_timeout = if config.timeout < 30 { config.timeout @@ -48,7 +73,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { .connect_timeout(Duration::from_secs(conn_timeout)) .build()?; let date: String = Utc::now().date_naive().to_string(); - let system_prompt: String = format!("The current date is {date}. {}", &config.system_prompt); + let system_prompt: String = format!("The current date is {date}.\n\n{}", &config.system_prompt); eprintln!("System Prompt is: {system_prompt}"); let api = openai::ClientBuilder::new_with_client(&config.key, http_client) .base_url(&config.base_url) @@ -63,15 +88,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let audio_model = if let Some(model) = &config.audio_model { model } else { - "cosyvoice" + "tts-1" }; - let _audio_voice = if let Some(voice) = &config.audio_voice { + let audio_voice = if let Some(voice) = &config.audio_voice { voice } else { - "english_female" + "alloy" }; - let _violet_voice = api.audio_generation_model(audio_model); + let violet_voice = api.audio_generation_model(audio_model); eprintln!("Base Request Setup"); + let mut history: Vec<Message> = Vec::new(); + eprintln!("Getting Audio Device"); + let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?; + let _sink = rodio::Sink::connect_new(&stream_handle.mixer()); eprintln!("Setup Finished"); let mut s = String::new(); print!("> "); @@ -79,7 +108,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { if let Err(e) = std::io::stdin().read_line(&mut s) { eprintln!("Error reading stdin: {e}"); } - let mut history: Vec<Message> = Vec::new(); let mut uwu = true; if "stop" == s.as_str().to_lowercase().trim() { uwu = false; @@ -90,12 +118,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { .await; let res = stream_to_stdout(&mut stream).await?; print!("\n"); - //let vres = violet_voice - // .audio_generation_request() - // .text(res.response()) - // .voice(audio_voice) - // .send() - // .await?; + let vres = violet_voice + .audio_generation_request() + .text(res.response()) + .voice(audio_voice) + .additional_params(json!({"response_format": "mp3"})) + .send() + .await?; + let vdata = Decoder::new(Cursor::new(vres.audio.clone()))?; + stream_handle.mixer().add(vdata); history.push(Message::user(s.clone())); history.push(Message::assistant(res.response())); print!("> "); |
