about summary refs log tree commit diff stats
path: root/src/main.rs
diff options
context:
space:
mode:
authorRen Kararou <[email protected]>2025-11-23 00:42:54 -0600
committerRen Kararou <[email protected]>2025-11-23 00:49:14 -0600
commitb062128ec1715e5de948347fea1b3df8c6333cac (patch)
tree3c787c504851a8bb27a20cf3b103cb86ac99206c /src/main.rs
parentdb1edb4c90a6f1dbe391d81b9e5b34969494a7af (diff)
downloadviolet-b062128ec1715e5de948347fea1b3df8c6333cac.tar.gz
violet-b062128ec1715e5de948347fea1b3df8c6333cac.tar.bz2
violet-b062128ec1715e5de948347fea1b3df8c6333cac.zip
switch to rig; working repl
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs210
1 files changed, 78 insertions, 132 deletions
diff --git a/src/main.rs b/src/main.rs
index 943104d..b7ff867 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,25 +1,31 @@
-use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Function}};
-use serde_json::{json, Value};
+use chrono::Utc;
+use reqwest::ClientBuilder;
 use serde::Deserialize;
 use std::fs::read_to_string;
+use std::time::Duration;
 use std::io::Write;
-use chrono::Utc;
+
+use rig::{
+    agent::stream_to_stdout,
+    prelude::*,
+    providers::openai,
+    streaming::StreamingChat,
+    message::Message,
+    client::audio_generation::AudioGenerationClient,
+    audio_generation::AudioGenerationModel,
+};
 
 #[derive(Deserialize, Clone, Debug)]
 struct Config {
     base_url: String,
     key: String,
     model: String,
+    audio_model: Option<String>,
+    audio_voice: Option<String>,
     system_prompt: String,
     timeout: u64,
-}
-
-async fn get_horoscope(sign: &str) -> String {
-    format!("{sign}: Next Tuesday you will befriend a baby otter.")
-}
-
-async fn wikipedia_lookup(title: &str) -> String {
-    format!("{title}: This article is the article on {title} and is quite interesting.  {title} has several toes.  Many more toes than a {title} should have.")
+    max_tokens: u64,
+    temp: f64,
 }
 
 #[tokio::main]
@@ -29,137 +35,77 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let config = read_to_string("config.json")?;
     let config: Config = serde_json::from_str(&config)?;
     eprintln!("Config Loaded");
-    let client = LlmClient::openai_with_config(
-        &config.key,
-        Some(&config.base_url),
-        Some(config.timeout),
-        None,
-    )?;
-    eprintln!("Config Setup");
-    let mut tools: Vec<Tool> = Vec::new();
-    tools.push(Tool {
-        tool_type: "function".into(),
-        function: Function {
-            name: "get_horoscope".into(),
-            description: Some("Get today's horoscope for an astrological sign.".into()),
-            parameters: json!({
-                "sign": {
-                    "type": "string",
-                    "description": "An astrological sign like Taurus or Aquarius."
-                }
-            }),
-        },
-    });
-    tools.push(Tool {
-        tool_type: "function".into(),
-        function: Function {
-            name: "stop".into(),
-            description: Some("Emergency Stop the Conversation.  Only to be used when the user is requesting something dangerous.".into()),
-            parameters: json!({}),
-        },
-    });
-    tools.push(Tool {
-        tool_type: "function".into(),
-        function: Function {
-            name: "wikipedia_lookup".into(),
-            description: Some("Look up a wikipedia article and have its summary returned.".into()),
-            parameters: json!({
-                "title": {
-                    "type": "string",
-                    "description": "The title of the article to look up."
-                }
-            }),
-        },
-    });
-    eprintln!("Tools Loaded");
+    let conn_timeout = if config.timeout < 30 {
+        config.timeout
+    } else if config.timeout < 300 {
+        config.timeout / 2
+    } else {
+        config.timeout / 4
+    };
+    let http_client = ClientBuilder::new()
+        .user_agent("violet-rs/0.1")
+        .read_timeout(Duration::from_secs(config.timeout))
+        .connect_timeout(Duration::from_secs(conn_timeout))
+        .build()?;
     let date: String = Utc::now().date_naive().to_string();
     let system_prompt: String = format!("The current date is {date}.  {}", &config.system_prompt);
-    let mut req = ChatRequest {
-        model: config.model,
-        messages: vec![
-            Message::text(Role::System, &system_prompt),
-        ],
-        tools: Some(tools),
-        //temperature: Some(0.9),
-        ..Default::default()
-    };
     eprintln!("System Prompt is: {system_prompt}");
+    let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
+        .base_url(&config.base_url)
+        .build();
+    let violet = api.completion_model(&config.model)
+        .completions_api()
+        .into_agent_builder()
+        .preamble(&system_prompt)
+        .max_tokens(config.max_tokens)
+        .temperature(config.temp)
+        .build();
+    let audio_model = if let Some(model) = &config.audio_model {
+        model
+    } else {
+        "cosyvoice"
+    };
+    let _audio_voice = if let Some(voice) = &config.audio_voice {
+        voice
+    } else {
+        "english_female"
+    };
+    let _violet_voice = api.audio_generation_model(audio_model);
     eprintln!("Base Request Setup");
     eprintln!("Setup Finished");
     let mut s = String::new();
-    print!("user> ");
-    std::io::stdout().flush()?;
+    print!("> ");
+    let _ = std::io::stdout().flush();
     if let Err(e) = std::io::stdin().read_line(&mut s) {
-        eprintln!("could not read stdin: {e}");
-        return Ok(())
+        eprintln!("Error reading stdin: {e}");
     }
-    if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" {
-        return Ok(());
+    let mut history: Vec<Message> = Vec::new();
+    let mut uwu = true;
+    if "stop" == s.as_str().to_lowercase().trim() {
+        uwu = false;
     }
-    req.messages.push(Message::text(Role::User, &s));
-    let mut uwu: bool = true;
     while uwu {
-        let response = client.chat(&req).await?;
-        eprintln!("{:?}", &response);
-        for choice in response.choices.clone() {
-            if let Some(reason) = choice.finish_reason {
-                match reason.as_str() {
-                    "tool_calls" => {
-                        if let Some(calls) = choice.message.clone().tool_calls {
-                            for call in calls {
-                                match call.function.name.as_str() {
-                                    "get_horoscope" => {
-                                        eprintln!("get_horoscope!");
-                                        let v: Value = serde_json::from_str(call.function.arguments.as_str())?;
-                                        req.messages.push(choice.message.clone());
-                                        let v = v["sign"].as_str().unwrap_or_default();
-                                        let val: String = get_horoscope(v).await;
-                                        req.messages.push(Message::tool(val, call.id));
-                                        eprintln!("{:?}", &req);
-                                    },
-                                    "wikipedia_lookup" => {
-                                        eprintln!("wikipedia_lookup!");
-                                        let v: Value = serde_json::from_str(call.function.arguments.as_str())?;
-                                        req.messages.push(choice.message.clone());
-                                        let v = v["title"].as_str().unwrap_or_default();
-                                        let val: String = wikipedia_lookup(v).await;
-                                        req.messages.push(Message::tool(val, call.id));
-                                        eprintln!("{:?}", &req);
-                                    }
-                                    "stop" => {
-                                        println!("Agent has stopped the conversation.");
-                                        if let Some(reasoning) = response.reasoning_content {
-                                            eprintln!("Agent Stopped with Reasoning: {}", reasoning.as_str().trim());
-                                        } else {
-                                            eprintln!("Agent has not given reasoning for emergency stop.");
-                                        }
-                                        return Ok(());
-                                    }
-                                    _ => (),
-                                }
-                            }
-                        }
-                    },
-                    "stop" => {
-                        println!("agent> {}", &response.content.trim());
-                        req.messages.push(Message::text(Role::Assistant, &response.content.clone()));
-                        let mut s = String::new();
-                        print!("user> ");
-                        std::io::stdout().flush()?;
-                        if let Err(e) = std::io::stdin().read_line(&mut s) {
-                            eprintln!("could not read stdin: {e}");
-                            uwu = false;
-                            break;
-                        }
-                        if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" {
-                            uwu = false;
-                            break;
-                        }
-                        req.messages.push(Message::text(Role::User, &s));
-                    },
-                    _ => (),
-                }
-            }
+        let mut stream = violet
+            .stream_chat(&s, history.clone())
+            .await;
+        let res = stream_to_stdout(&mut stream).await?;
+        print!("\n");
+        //let vres = violet_voice
+        //    .audio_generation_request()
+        //    .text(res.response())
+        //    .voice(audio_voice)
+        //    .send()
+        //    .await?;
+        history.push(Message::user(s.clone()));
+        history.push(Message::assistant(res.response()));
+        print!("> ");
+        s = String::new();
+        let _ = std::io::stdout().flush();
+        if let Err(e) = std::io::stdin().read_line(&mut s) {
+            eprintln!("Error reading stdin: {e}");
+        }
+        if s.as_str().to_lowercase().trim() == "stop" {
+            break;
         }
     }
     Ok(())