1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
|
use chrono::Utc;
use reqwest::ClientBuilder;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::io::{Write, Cursor};
//use base64::{Engine, engine::general_purpose::STANDARD};
use serde_json::json;
use rodio::Decoder;
use rig::{
agent::stream_to_stdout,
prelude::*,
providers::openai,
streaming::StreamingChat,
message::{Message, Image, ImageMediaType, DocumentSourceKind, ImageDetail},
client::audio_generation::AudioGenerationClient,
audio_generation::AudioGenerationModel,
};
/// User settings for violet, loaded and persisted through `confy` in `main`.
///
/// `#[serde(default)]` lets a config file that omits a field (for example one
/// written by an older version) still load: each missing field takes its value
/// from `Config::default()` instead of failing deserialization.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(default)]
struct Config {
    /// Base URL of the OpenAI-compatible API.
    base_url: String,
    /// API key used for every request to `base_url`.
    key: String,
    /// Chat completion model name.
    model: String,
    /// Text-to-speech model name; `main` uses `tts-1` when this is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    audio_model: Option<String>,
    /// Text-to-speech voice; `main` uses `alloy` when this is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    audio_voice: Option<String>,
    /// System prompt; `main` prefixes it with the current date.
    system_prompt: String,
    /// HTTP read timeout in seconds; the connect timeout is derived from it.
    timeout: u64,
    /// Maximum number of tokens requested per reply.
    max_tokens: u64,
    /// Sampling temperature passed to the completion model.
    temp: f64,
}
impl std::default::Default for Config {
fn default() -> Self {
Self {
base_url: String::from("https://api.openai.com/v1"),
key: String::from("sk-..."),
model: String::from("gpt-4o"),
audio_model: None,
audio_voice: None,
system_prompt: String::from("You are a helpful assistant!"),
timeout: 30,
max_tokens: 4096,
temp: 0.4,
}
}
}
/// Prints the `> ` prompt and reads the next non-blank line from stdin.
///
/// Returns `None` when the chat should end: stdin reached EOF, reading failed,
/// or the user typed `stop` (case-insensitive, surrounding whitespace ignored).
/// The returned line has its trailing whitespace/newline stripped.
fn read_prompt() -> Option<String> {
    let stdin = std::io::stdin();
    let mut line = String::new();
    loop {
        print!("> ");
        let _ = std::io::stdout().flush();
        line.clear();
        match stdin.read_line(&mut line) {
            // EOF (Ctrl-D, closed pipe): previously this left the input empty and
            // the chat loop kept sending empty prompts forever.
            Ok(0) => return None,
            Ok(_) => {}
            Err(e) => {
                eprintln!("Error reading stdin: {e}");
                return None;
            }
        }
        let trimmed = line.trim();
        if trimmed.to_lowercase() == "stop" {
            return None;
        }
        // Blank lines just re-prompt instead of sending an empty message.
        if !trimmed.is_empty() {
            return Some(line.trim_end().to_string());
        }
    }
}

/// Interactive voice chat loop.
///
/// Loads the `violet` config via confy, builds an OpenAI-compatible chat agent
/// and a text-to-speech model from it, then repeatedly reads a prompt from
/// stdin, streams the reply to stdout and plays it back as speech. Typing
/// `stop` or closing stdin ends the session.
///
/// # Errors
/// Returns an error if the config cannot be loaded, the HTTP client or audio
/// output cannot be set up, or streaming a chat reply fails. Speech generation
/// and decoding failures are reported on stderr and do not end the session.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    eprintln!("Starting setup");
    eprintln!("Loading Config");
    let config: Config = confy::load("violet", Some("violet"))?;
    // Must pass the same config name as `confy::load`; with `None` confy falls
    // back to its default config name and reports a different file.
    let config_path = confy::get_configuration_file_path("violet", Some("violet"))?;
    eprintln!("Config file location: {}", config_path.display());
    eprintln!("Config Loaded");

    // The connect timeout shrinks relative to the read timeout as the latter grows.
    let conn_timeout = if config.timeout < 30 {
        config.timeout
    } else if config.timeout < 300 {
        config.timeout / 2
    } else {
        config.timeout / 4
    };
    let http_client = ClientBuilder::new()
        .user_agent("violet-rs/0.1")
        .read_timeout(Duration::from_secs(config.timeout))
        .connect_timeout(Duration::from_secs(conn_timeout))
        .build()?;

    let date: String = Utc::now().date_naive().to_string();
    let system_prompt: String = format!("The current date is {date}.\n\n{}", &config.system_prompt);
    eprintln!("System Prompt is: {system_prompt}");

    let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
        .base_url(&config.base_url)
        .build();
    let violet = api
        .completion_model(&config.model)
        .completions_api()
        .into_agent_builder()
        .preamble(&system_prompt)
        .max_tokens(config.max_tokens)
        .temperature(config.temp)
        .build();

    let audio_model = config.audio_model.as_deref().unwrap_or("tts-1");
    let audio_voice = config.audio_voice.as_deref().unwrap_or("alloy");
    let violet_voice = api.audio_generation_model(audio_model);
    eprintln!("Base Request Setup");

    let mut history: Vec<Message> = Vec::new();
    eprintln!("Getting Audio Device");
    let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
    // NOTE(review): `_sink` is never used; replies are added directly to the
    // mixer below. Confirm whether they were meant to be queued on this sink.
    let _sink = rodio::Sink::connect_new(&stream_handle.mixer());
    eprintln!("Setup Finished");

    while let Some(prompt) = read_prompt() {
        let mut stream = violet.stream_chat(&prompt, history.clone()).await;
        let res = stream_to_stdout(&mut stream).await?;
        println!();

        // Record the turn before speaking so a speech failure cannot lose it.
        history.push(Message::user(prompt.clone()));
        history.push(Message::assistant(res.response()));

        let speech = violet_voice
            .audio_generation_request()
            .text(res.response())
            .voice(audio_voice)
            .additional_params(json!({"response_format": "mp3"}))
            .send()
            .await;
        // Speech is best-effort: report problems and keep the chat going.
        match speech {
            Ok(speech) => match Decoder::new(Cursor::new(speech.audio)) {
                Ok(source) => stream_handle.mixer().add(source),
                Err(e) => eprintln!("Error decoding speech audio: {e}"),
            },
            Err(e) => eprintln!("Error generating speech: {e}"),
        }
    }
    Ok(())
}
|