about summary refs log tree commit diff stats
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs151
1 files changed, 97 insertions, 54 deletions
diff --git a/src/main.rs b/src/main.rs
index 4dbb9cf..d86baf8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -19,49 +19,85 @@ use serde::{Deserialize, Serialize};
 use serde_json::json;
 use std::io::{Cursor, Write};
 use std::time::Duration;
+use surrealdb::{
+  Surreal,
+  engine::local::RocksDb,
+};
 
 #[derive(Serialize, Deserialize, Clone, Debug)]
-struct Config {
+struct ModelOptions {
+  #[serde(skip_serializing_if = "Option::is_none")]
+  model: Option<String>,
+  #[serde(skip_serializing_if = "Option::is_none")]
+  prompt: Option<String>,
+  tokens: u64,
+  temp: f64,
+}
+impl std::default::Default for ModelOptions {
+  fn default() -> Self {
+    Self {
+      model: None,
+      prompt: None,
+      tokens: 4096,
+      temp: 1.0,
+    }
+  }
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct TTSOptions {
+  model: String,
+  voice: String,
+}
+impl std::default::Default for TTSOptions {
+  fn default() -> Self {
+    Self {
+      model: String::from("tts-1"),
+      voice: String::from("Alloy"),
+    }
+  }
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct ApiOptions {
   base_url: String,
   key: String,
   timeout: u64,
+}
+impl std::default::Default for ApiOptions {
+  fn default() -> Self {
+  Self {
+    base_url: String::from("https://api.openai.com/v1"),
+    key: String::from("sk-..."),
+    timeout: 30,
+    }
+  }
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug)]
+struct Config {
+  api: ApiOptions,
+  vision: ModelOptions,
+  summary: ModelOptions,
   #[serde(skip_serializing_if = "Option::is_none")]
-  vision_model: Option<String>,
-  #[serde(skip_serializing_if = "Option::is_none")]
-  vision_prompt: Option<String>,
-  #[serde(skip_serializing_if = "Option::is_none")]
-  summary_model: Option<String>,
-  #[serde(skip_serializing_if = "Option::is_none")]
-  summary_prompt: Option<String>,
-  #[serde(skip_serializing_if = "Option::is_none")]
-  audio_model: Option<String>,
-  #[serde(skip_serializing_if = "Option::is_none")]
-  audio_voice: Option<String>,
-  vision_tokens: u64,
-  vision_temp: f64,
-  summary_tokens: u64,
-  summary_temp: f64,
+  tts: Option<TTSOptions>,
 }
 impl std::default::Default for Config {
   fn default() -> Self {
     Self {
-      base_url: String::from("https://api.openai.com/v1"),
-      key: String::from("sk-..."),
-      vision_model: None,
-      vision_prompt: None,
-      summary_prompt: None,
-      summary_model: None,
-      audio_model: None,
-      audio_voice: None,
-      timeout: 30,
-      vision_tokens: 4096,
-      vision_temp: 0.4,
-      summary_tokens: 8192,
-      summary_temp: 0.9,
+      api: ApiOptions::default(),
+      vision: ModelOptions::default(),
+      summary: ModelOptions::default(),
+      tts: None,
     }
   }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+struct Record {
+  id: surrealdb::RecordId,
+}
+
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
   eprintln!("Starting setup");
@@ -75,27 +111,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
       .unwrap_or("path does not exist")
   );
   eprintln!("Config Loaded");
-  let conn_timeout = if config.timeout < 30 {
-    config.timeout
-  } else if config.timeout < 300 {
-    config.timeout / 2
+  let conn_timeout = if config.api.timeout < 30 {
+    config.api.timeout
+  } else if config.api.timeout < 300 {
+    config.api.timeout / 2
   } else {
-    config.timeout / 4
+    config.api.timeout / 4
   };
   let http_client = ClientBuilder::new()
     .user_agent("violet-rs/0.1")
-    .read_timeout(Duration::from_secs(config.timeout))
+    .read_timeout(Duration::from_secs(config.api.timeout))
     .connect_timeout(Duration::from_secs(conn_timeout))
     .build()?;
   let date: String = Utc::now().date_naive().to_string();
-  let vision_prompt: String = if let Some(prompt) = config.vision_prompt {
+  let vision_prompt: String = if let Some(prompt) = config.vision.prompt {
     prompt
   } else {
     "You will describe the images attached".into()
   };
   let summary_prompt: String = format!(
     "The current date is {date}.\n\n{}",
-    if let Some(prompt) = config.summary_prompt {
+    if let Some(prompt) = config.summary.prompt {
       prompt
     } else {
       String::from("You will create a narrative for the image ")
@@ -104,10 +140,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
   );
   eprintln!("Vision System Prompt is: {vision_prompt}");
   eprintln!("Summary System Prompt is: {summary_prompt}");
-  let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
-    .base_url(&config.base_url)
+  let api = openai::ClientBuilder::new_with_client(&config.api.key, http_client)
+    .base_url(&config.api.base_url)
     .build();
-  let vision_model: String = if let Some(vmodel) = config.vision_model {
+  let vision_model: String = if let Some(vmodel) = config.vision.model {
     vmodel
   } else {
     "gpt-image-1".into()
@@ -117,10 +153,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     .completions_api()
     .into_agent_builder()
     .preamble(&vision_prompt)
-    .max_tokens(config.vision_tokens)
-    .temperature(config.vision_temp)
+    .max_tokens(config.vision.tokens)
+    .temperature(config.vision.temp)
     .build();
-  let summary_model: String = if let Some(smodel) = config.summary_model {
+  let summary_model: String = if let Some(smodel) = config.summary.model {
     smodel
   } else {
     "gpt-4o".into()
@@ -130,18 +166,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     .completions_api()
     .into_agent_builder()
     .preamble(&summary_prompt)
-    .max_tokens(config.summary_tokens)
-    .temperature(config.summary_temp)
+    .max_tokens(config.summary.tokens)
+    .temperature(config.summary.temp)
     .build();
-  let audio_model = if let Some(model) = &config.audio_model {
-    model
-  } else {
-    "tts-1"
-  };
-  let audio_voice = if let Some(voice) = &config.audio_voice {
-    voice
+  let (audio_model, audio_voice) = if let Some(tts) = &config.tts {
+      (tts.model.as_str(), tts.voice.as_str())
   } else {
-    "alloy"
+      ("tts-1", "Alloy")
   };
   let audio = api.audio_generation_model(audio_model);
   eprintln!("Setup Finished");
@@ -188,7 +219,7 @@ async fn prompt_model(
   history: Vec<Message>,
 ) -> Result<Message, Box<dyn std::error::Error>> {
   let res = agent.chat(prompt, history).await?;
-  Ok(rig::message::AssistantContent::text(&res).into())
+  Ok(AssistantContent::text(&res).into())
 }
 
 async fn get_audio(
@@ -216,6 +247,18 @@ async fn routing(
   audio: openai::audio_generation::AudioGenerationModel,
   audio_voice: &str,
 ) -> Result<(), Box<dyn std::error::Error>> {
+  let db = Surreal::new::<RocksDb>("db").await?;
+  db.use_ns("violet").use_db("manga").await?;
+  /*TODO:
+   * 1. pick manga based on tags via mangadex api.
+   * 2. batch and send pages to gemma-3 VL model
+   * 3. send descriptions to gpt-oss for verification
+   * 4. send approved descriptions back to gpt-oss for narrative generation
+   * 5. send gpt-oss output to kokoro
+   * 6. potentially get gemma-3 to mark important frames for video purposes?
+   * 7. compile video from template, images, and kokoro output
+   * 8. upload directly to youtube via some youtube api crate
+   */
   let _vision = vision;
   let mut s: String = String::new();
   for m in chat(summary).await? {