about summary refs log tree commit diff stats
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs197
1 files changed, 153 insertions, 44 deletions
diff --git a/src/main.rs b/src/main.rs
index 7a436e7..4dbb9cf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,16 +1,20 @@
-//use base64::{Engine, engine::general_purpose::STANDARD};
+use base64::{Engine, engine::general_purpose::STANDARD};
 use chrono::Utc;
-use reqwest::ClientBuilder;
+use reqwest::{Client, ClientBuilder};
 use rig::{
-  agent::stream_to_stdout,
+  agent::{Agent, stream_to_stdout},
   audio_generation::AudioGenerationModel,
   client::audio_generation::AudioGenerationClient,
-  message::{DocumentSourceKind, Image, ImageDetail, ImageMediaType, Message},
+  completion::Chat,
+  message::{
+    AssistantContent, DocumentSourceKind, Image, ImageDetail, ImageMediaType,
+    Message, UserContent,
+  },
   prelude::*,
   providers::openai,
+  providers::openai::CompletionModel,
   streaming::StreamingChat,
 };
-use rodio::Decoder;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::io::{Cursor, Write};
@@ -20,28 +24,40 @@ use std::time::Duration;
 struct Config {
   base_url: String,
   key: String,
-  model: String,
+  timeout: u64,
+  #[serde(skip_serializing_if = "Option::is_none")]
+  vision_model: Option<String>,
+  #[serde(skip_serializing_if = "Option::is_none")]
+  vision_prompt: Option<String>,
+  #[serde(skip_serializing_if = "Option::is_none")]
+  summary_model: Option<String>,
+  #[serde(skip_serializing_if = "Option::is_none")]
+  summary_prompt: Option<String>,
   #[serde(skip_serializing_if = "Option::is_none")]
   audio_model: Option<String>,
   #[serde(skip_serializing_if = "Option::is_none")]
   audio_voice: Option<String>,
-  system_prompt: String,
-  timeout: u64,
-  max_tokens: u64,
-  temp: f64,
+  vision_tokens: u64,
+  vision_temp: f64,
+  summary_tokens: u64,
+  summary_temp: f64,
 }
 impl std::default::Default for Config {
   fn default() -> Self {
     Self {
       base_url: String::from("https://api.openai.com/v1"),
       key: String::from("sk-..."),
-      model: String::from("gpt-4o"),
+      vision_model: None,
+      vision_prompt: None,
+      summary_prompt: None,
+      summary_model: None,
       audio_model: None,
       audio_voice: None,
-      system_prompt: String::from("You are a helpful assistant!"),
       timeout: 30,
-      max_tokens: 4096,
-      temp: 0.4,
+      vision_tokens: 4096,
+      vision_temp: 0.4,
+      summary_tokens: 8192,
+      summary_temp: 0.9,
     }
   }
 }
@@ -53,7 +69,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
   let config: Config = confy::load("violet", Some("violet"))?;
   println!(
     "Config file location: {}",
-    confy::get_configuration_file_path("violet", None)?
+    confy::get_configuration_file_path("violet", Some("violet"))?
       .as_path()
       .to_str()
       .unwrap_or("path does not exist")
@@ -72,19 +88,50 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     .connect_timeout(Duration::from_secs(conn_timeout))
     .build()?;
   let date: String = Utc::now().date_naive().to_string();
-  let system_prompt: String =
-    format!("The current date is {date}.\n\n{}", &config.system_prompt);
-  eprintln!("System Prompt is: {system_prompt}");
+  let vision_prompt: String = if let Some(prompt) = config.vision_prompt {
+    prompt
+  } else {
+    "You will describe the images attached".into()
+  };
+  let summary_prompt: String = format!(
+    "The current date is {date}.\n\n{}",
+    if let Some(prompt) = config.summary_prompt {
+      prompt
+    } else {
+      String::from("You will create a narrative for the image ")
+        + "descriptions given as if you were telling a story."
+    }
+  );
+  eprintln!("Vision System Prompt is: {vision_prompt}");
+  eprintln!("Summary System Prompt is: {summary_prompt}");
   let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
     .base_url(&config.base_url)
     .build();
-  let violet = api
-    .completion_model(&config.model)
+  let vision_model: String = if let Some(vmodel) = config.vision_model {
+    vmodel
+  } else {
+    "gpt-image-1".into()
+  };
+  let vision = api
+    .completion_model(&vision_model)
+    .completions_api()
+    .into_agent_builder()
+    .preamble(&vision_prompt)
+    .max_tokens(config.vision_tokens)
+    .temperature(config.vision_temp)
+    .build();
+  let summary_model: String = if let Some(smodel) = config.summary_model {
+    smodel
+  } else {
+    "gpt-4o".into()
+  };
+  let summary = api
+    .completion_model(&summary_model)
     .completions_api()
     .into_agent_builder()
-    .preamble(&system_prompt)
-    .max_tokens(config.max_tokens)
-    .temperature(config.temp)
+    .preamble(&summary_prompt)
+    .max_tokens(config.summary_tokens)
+    .temperature(config.summary_temp)
     .build();
   let audio_model = if let Some(model) = &config.audio_model {
     model
@@ -96,13 +143,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
   } else {
     "alloy"
   };
-  let violet_voice = api.audio_generation_model(audio_model);
-  eprintln!("Base Request Setup");
-  let mut history: Vec<Message> = Vec::new();
-  eprintln!("Getting Audio Device");
-  let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
-  let _sink = rodio::Sink::connect_new(&stream_handle.mixer());
+  let audio = api.audio_generation_model(audio_model);
   eprintln!("Setup Finished");
+  routing(vision, summary, audio, audio_voice).await?;
+  Ok(())
+}
+
+async fn chat(
+  agent: Agent<CompletionModel<Client>>,
+) -> Result<Vec<Message>, Box<dyn std::error::Error>> {
+  let mut history: Vec<Message> = Vec::new();
   let mut s = String::new();
   print!("> ");
   let _ = std::io::stdout().flush();
@@ -114,22 +164,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     uwu = false;
   }
   while uwu {
-    let mut stream = violet.stream_chat(&s, history.clone()).await;
+    let mut stream = agent.stream_chat(&s, history.clone()).await;
     let res = stream_to_stdout(&mut stream).await?;
     print!("\n");
-    let vres = violet_voice
-      .audio_generation_request()
-      .text(res.response())
-      .voice(&audio_voice)
-      .additional_params(json!(
-        {
-          "response_format": "mp3",
-        }
-      ))
-      .send()
-      .await?;
-    let vdata = Decoder::new(Cursor::new(vres.audio.clone()))?;
-    stream_handle.mixer().add(vdata);
     history.push(Message::user(s.clone()));
     history.push(Message::assistant(res.response()));
     print!("> ");
@@ -139,8 +176,80 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
       eprintln!("Error reading stdin: {e}");
     }
     if s.as_str().to_lowercase().trim() == "stop" {
-      break;
+      uwu = false;
     }
   }
+  Ok(history)
+}
+
+async fn prompt_model(
+  agent: Agent<CompletionModel<Client>>,
+  prompt: Message,
+  history: Vec<Message>,
+) -> Result<Message, Box<dyn std::error::Error>> {
+  let res = agent.chat(prompt, history).await?;
+  Ok(rig::message::AssistantContent::text(&res).into())
+}
+
+async fn get_audio(
+  audio: openai::audio_generation::AudioGenerationModel,
+  voice: &str,
+  text: &str,
+) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
+  let vres = audio
+    .audio_generation_request()
+    .text(text)
+    .voice(voice)
+    .additional_params(json!(
+      {
+        "response_format": "mp3",
+      }
+    ))
+    .send()
+    .await?;
+  Ok(vres.audio.clone())
+}
+
+async fn routing(
+  vision: Agent<CompletionModel<Client>>,
+  summary: Agent<CompletionModel<Client>>,
+  audio: openai::audio_generation::AudioGenerationModel,
+  audio_voice: &str,
+) -> Result<(), Box<dyn std::error::Error>> {
+  let _vision = vision;
+  let mut s: String = String::new();
+  for m in chat(summary).await? {
+    let text: String = match m {
+      Message::User { content } => {
+        let mut e: String = "User: ".into();
+        for c in content {
+          if let UserContent::Text(content) = c {
+            e = e + content.text().into();
+            e = e + "\n".into();
+          }
+        }
+        e
+      },
+      Message::Assistant { id, content } => {
+        let _id = id;
+        let mut e: String = "Assistant: ".into();
+        for c in content {
+          if let AssistantContent::Text(content) = c {
+            e = e + content.text().into();
+            e = e + "\n".into();
+          }
+        }
+        e
+      },
+    };
+    s = s + &text;
+  }
+  let e = get_audio(audio, audio_voice, &s).await?;
+  let mut fiel = std::fs::OpenOptions::new()
+    .create(true)
+    .write(true)
+    .truncate(true)
+    .open("chat.mp3")?;
+  fiel.write_all(&e.as_slice())?;
   Ok(())
 }