author    Ren Kararou <[email protected]>  2025-11-28 21:11:11 -0600
committer Ren Kararou <[email protected]>  2025-11-28 21:11:11 -0600
commit    b0429fdf9164c3b59f9cd8186b971d9cb038c46c (patch)
tree      0d86884506beeffebe2f272d6241abd55045092a /src
parent    b062128ec1715e5de948347fea1b3df8c6333cac (diff)
download  violet-b0429fdf9164c3b59f9cd8186b971d9cb038c46c.tar.gz
          violet-b0429fdf9164c3b59f9cd8186b971d9cb038c46c.tar.bz2
          violet-b0429fdf9164c3b59f9cd8186b971d9cb038c46c.zip
llm now speaks; config autogenerates
Diffstat (limited to 'src')
-rw-r--r--  src/main.rs  69
1 file changed, 50 insertions(+), 19 deletions(-)
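
The "config autogenerates" half of the commit comes from confy's first-run behavior: confy::load falls back to the type's Default impl and writes that default config to disk when no file exists yet, so a fresh install starts with usable settings instead of a missing-file error. A minimal standalone sketch of the pattern, not the committed code: DemoConfig here is a made-up stand-in for the real Config struct in src/main.rs, and the sketch passes the same config name to both confy calls.

    use serde::{Deserialize, Serialize};

    // Hypothetical cut-down config; the real Config also carries model, prompt and audio fields.
    #[derive(Serialize, Deserialize, Debug, Default)]
    struct DemoConfig {
        base_url: String,
        model: String,
    }

    fn main() -> Result<(), confy::ConfyError> {
        // First run: no file exists, so confy writes DemoConfig::default() and returns it.
        // Later runs: the file on disk is parsed and returned.
        let cfg: DemoConfig = confy::load("violet", Some("violet"))?;
        println!(
            "Config file location: {}",
            confy::get_configuration_file_path("violet", Some("violet"))?.display()
        );
        println!("Loaded: {cfg:?}");
        Ok(())
    }

Because audio_model and audio_voice carry skip_serializing_if = "Option::is_none", the autogenerated file stays minimal until those options are actually set.
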
diff --git a/src/main.rs b/src/main.rs
index b7ff867..3c139ab 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,39 +1,64 @@
 use chrono::Utc;
 use reqwest::ClientBuilder;
-use serde::Deserialize;
-use std::fs::read_to_string;
+use serde::{Deserialize, Serialize};
 use std::time::Duration;
-use std::io::Write;
+use std::io::{Write, Cursor};
+//use base64::{Engine, engine::general_purpose::STANDARD};
+use serde_json::json;
+use rodio::Decoder;
 
 use rig::{
     agent::stream_to_stdout,
     prelude::*,
     providers::openai,
     streaming::StreamingChat,
-    message::Message,
+    message::{Message, Image, ImageMediaType, DocumentSourceKind, ImageDetail},
     client::audio_generation::AudioGenerationClient,
     audio_generation::AudioGenerationModel,
 };
 
-#[derive(Deserialize, Clone, Debug)]
+#[derive(Serialize, Deserialize, Clone, Debug)]
 struct Config {
     base_url: String,
     key: String,
     model: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
     audio_model: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     audio_voice: Option<String>,
     system_prompt: String,
     timeout: u64,
     max_tokens: u64,
     temp: f64,
 }
+impl std::default::Default for Config {
+    fn default() -> Self {
+        Self {
+            base_url: String::from("https://api.openai.com/v1"),
+            key: String::from("sk-..."),
+            model: String::from("gpt-4o"),
+            audio_model: None,
+            audio_voice: None,
+            system_prompt: String::from("You are a helpful assistant!"),
+            timeout: 30,
+            max_tokens: 4096,
+            temp: 0.4,
+        }
+    }
+}
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     eprintln!("Starting setup");
     eprintln!("Loading Config");
-    let config = read_to_string("config.json")?;
-    let config: Config = serde_json::from_str(&config)?;
+    let config: Config = confy::load("violet", Some("violet"))?;
+    println!(
+        "Config file location: {}",
+        confy::get_configuration_file_path("violet", None)?
+            .as_path()
+            .to_str()
+            .unwrap_or("path does not exist")
+    );
     eprintln!("Config Loaded");
     let conn_timeout = if config.timeout < 30 {
         config.timeout
@@ -48,7 +73,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         .connect_timeout(Duration::from_secs(conn_timeout))
         .build()?;
     let date: String = Utc::now().date_naive().to_string();
-    let system_prompt: String = format!("The current date is {date}.  {}", &config.system_prompt);
+    let system_prompt: String = format!("The current date is {date}.\n\n{}", &config.system_prompt);
     eprintln!("System Prompt is: {system_prompt}");
     let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
         .base_url(&config.base_url)
@@ -63,15 +88,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let audio_model = if let Some(model) = &config.audio_model {
         model
     } else {
-        "cosyvoice"
+        "tts-1"
     };
-    let _audio_voice = if let Some(voice) = &config.audio_voice {
+    let audio_voice = if let Some(voice) = &config.audio_voice {
         voice
     } else {
-        "english_female"
+        "alloy"
     };
-    let _violet_voice = api.audio_generation_model(audio_model);
+    let violet_voice = api.audio_generation_model(audio_model);
     eprintln!("Base Request Setup");
+    let mut history: Vec<Message> = Vec::new();
+    eprintln!("Getting Audio Device");
+    let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
+    let _sink = rodio::Sink::connect_new(&stream_handle.mixer());
     eprintln!("Setup Finished");
     let mut s = String::new();
     print!("> ");
@@ -79,7 +108,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     if let Err(e) = std::io::stdin().read_line(&mut s) {
         eprintln!("Error reading stdin: {e}");
     }
-    let mut history: Vec<Message> = Vec::new();
     let mut uwu = true;
     if "stop" == s.as_str().to_lowercase().trim() {
         uwu = false;
@@ -90,12 +118,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             .await;
         let res = stream_to_stdout(&mut stream).await?;
         print!("\n");
-        //let vres = violet_voice
-        //    .audio_generation_request()
-        //    .text(res.response())
-        //    .voice(audio_voice)
-        //    .send()
-        //    .await?;
+        let vres = violet_voice
+            .audio_generation_request()
+            .text(res.response())
+            .voice(audio_voice)
+            .additional_params(json!({"response_format": "mp3"}))
+            .send()
+            .await?;
+        let vdata = Decoder::new(Cursor::new(vres.audio.clone()))?;
+        stream_handle.mixer().add(vdata);
         history.push(Message::user(s.clone()));
         history.push(Message::assistant(res.response()));
         print!("> ");
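
The other half, "llm now speaks", decodes the TTS response bytes in memory and pushes them onto rodio's default output mixer rather than writing a temporary file. A minimal sketch of that playback path in isolation, assuming the same rodio API the diff uses (OutputStreamBuilder::open_default_stream, Decoder, Mixer::add) and taking already-fetched MP3 bytes in place of vres.audio:

    use std::io::Cursor;
    use rodio::Decoder;

    fn play_mp3_bytes(mp3_bytes: Vec<u8>) -> Result<(), Box<dyn std::error::Error>> {
        // Open the default audio device once; the stream must stay alive while audio plays.
        let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
        // Decode straight from the in-memory buffer, no temporary file needed.
        let source = Decoder::new(Cursor::new(mp3_bytes))?;
        stream_handle.mixer().add(source);
        // Give the clip time to play; the real program keeps looping on stdin instead.
        std::thread::sleep(std::time::Duration::from_secs(5));
        Ok(())
    }

Dropping the stream handle closes the output device and cuts playback off, which is why the commit opens it once during setup and keeps it in scope for the whole chat loop.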