about summary refs log tree commit diff stats
path: root/src/main.rs
blob: b7ff8673cf01d5ca084ff572580db9cb73ac708e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use chrono::Utc;
use reqwest::ClientBuilder;
use serde::Deserialize;
use std::fs::read_to_string;
use std::time::Duration;
use std::io::Write;

use rig::{
    agent::stream_to_stdout,
    prelude::*,
    providers::openai,
    streaming::StreamingChat,
    message::Message,
    client::audio_generation::AudioGenerationClient,
    audio_generation::AudioGenerationModel,
};

/// Runtime settings, deserialized from `config.json` by `main`.
///
/// All fields except the two `Option` ones must be present in the JSON;
/// missing `audio_model` / `audio_voice` deserialize as `None` and `main`
/// substitutes its own defaults.
#[derive(Deserialize, Clone, Debug)]
struct Config {
    /// Base URL handed to the OpenAI provider client builder.
    base_url: String,
    /// API key handed to the OpenAI provider client builder.
    key: String,
    /// Name of the completion model the chat agent is built on.
    model: String,
    /// Audio-generation model name; `main` falls back to "cosyvoice".
    audio_model: Option<String>,
    /// Voice for audio generation; `main` falls back to "english_female".
    /// NOTE(review): not consumed yet — only referenced by commented-out code.
    audio_voice: Option<String>,
    /// Agent preamble; `main` prefixes it with the current date.
    system_prompt: String,
    /// Timeout in seconds: used as the HTTP read timeout, and the connect
    /// timeout is derived from it in `main`.
    timeout: u64,
    /// Maximum tokens per completion, passed to the agent builder.
    max_tokens: u64,
    /// Sampling temperature, passed to the agent builder.
    temp: f64,
}

/// Runs an interactive, streaming chat session with the configured model.
///
/// Reads `config.json` from the current working directory, builds an HTTP
/// client and an OpenAI-provider agent from it, then repeatedly reads a line
/// from stdin, streams the model's reply to stdout, and appends both turns to
/// the chat history sent with the next request. The session ends when the
/// user enters `stop` (case-insensitive), when stdin reaches end-of-file, or
/// when reading stdin fails.
///
/// # Errors
///
/// Returns an error if `config.json` cannot be read or parsed, if the HTTP
/// client cannot be built, or if streaming a reply to stdout fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    eprintln!("Starting setup");
    eprintln!("Loading Config");
    // Name the file in the error: the bare io/serde errors do not mention it.
    let config = read_to_string("config.json")
        .map_err(|e| format!("failed to read config.json: {e}"))?;
    let config: Config = serde_json::from_str(&config)
        .map_err(|e| format!("failed to parse config.json: {e}"))?;
    eprintln!("Config Loaded");
    // Connect timeout (seconds) is derived from the configured timeout:
    // used as-is below 30, halved below 300, quartered otherwise.
    let conn_timeout = if config.timeout < 30 {
        config.timeout
    } else if config.timeout < 300 {
        config.timeout / 2
    } else {
        config.timeout / 4
    };
    let http_client = ClientBuilder::new()
        .user_agent("violet-rs/0.1")
        .read_timeout(Duration::from_secs(config.timeout))
        .connect_timeout(Duration::from_secs(conn_timeout))
        .build()?;
    let date: String = Utc::now().date_naive().to_string();
    let system_prompt: String = format!("The current date is {date}.  {}", &config.system_prompt);
    eprintln!("System Prompt is: {system_prompt}");
    let api = openai::ClientBuilder::new_with_client(&config.key, http_client)
        .base_url(&config.base_url)
        .build();
    let violet = api.completion_model(&config.model)
        .completions_api()
        .into_agent_builder()
        .preamble(&system_prompt)
        .max_tokens(config.max_tokens)
        .temperature(config.temp)
        .build();
    // `as_deref` gives `&str` on both paths; an `if let` yielding `&String`
    // in one arm and a `&str` literal in the other does not type-check.
    let audio_model = config.audio_model.as_deref().unwrap_or("cosyvoice");
    let _audio_voice = config.audio_voice.as_deref().unwrap_or("english_female");
    // Audio output is not wired up yet; the model is built but unused.
    let _violet_voice = api.audio_generation_model(audio_model);
    eprintln!("Base Request Setup");
    eprintln!("Setup Finished");
    let mut history: Vec<Message> = Vec::new();
    while let Some(prompt) = read_prompt() {
        // `stream_chat` takes the history by value, so pass a copy and keep
        // the original to extend after this turn.
        let mut stream = violet
            .stream_chat(&prompt, history.clone())
            .await;
        let res = stream_to_stdout(&mut stream).await?;
        println!();
        // TODO: speak the reply once audio generation is enabled:
        //let vres = violet_voice
        //    .audio_generation_request()
        //    .text(res.response())
        //    .voice(audio_voice)
        //    .send()
        //    .await?;
        history.push(Message::user(prompt));
        history.push(Message::assistant(res.response()));
    }
    Ok(())
}

/// Prints the `> ` prompt and reads one line of user input from stdin.
///
/// Returns `None` when the session should end: on end-of-file, on a read
/// error (reported to stderr), or when the user enters `stop`
/// (case-insensitive, surrounding whitespace ignored). Otherwise returns the
/// line with its trailing `\n` / `\r\n` removed.
fn read_prompt() -> Option<String> {
    print!("> ");
    // `print!` does not flush, so force the prompt out before blocking on
    // stdin. A failed flush only affects the prompt, so it is ignored.
    let _ = std::io::stdout().flush();
    let mut line = String::new();
    match std::io::stdin().read_line(&mut line) {
        // EOF: without this, an empty line would be sent to the model forever.
        Ok(0) => None,
        Ok(_) => {
            let line = line.trim_end_matches(|c| c == '\n' || c == '\r');
            if line.trim().eq_ignore_ascii_case("stop") {
                None
            } else {
                Some(line.to_owned())
            }
        }
        Err(e) => {
            eprintln!("Error reading stdin: {e}");
            None
        }
    }
}