about summary refs log tree commit diff stats
path: root/src/main.rs
blob: a9ff48404b5d84c690cf2701dab769ffc2a5650b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Function}};
use serde_json::{json, Value};
use serde::Deserialize;
use std::fs::read_to_string;
use std::io::Write;
use chrono::Utc;

/// Runtime settings deserialized from `config.json` when the program starts.
#[derive(Deserialize, Clone, Debug)]
struct Config {
    /// Endpoint base URL, passed as `Some(&base_url)` to `LlmClient::openai_with_config`.
    base_url: String,
    /// Credential passed as the first argument to `LlmClient::openai_with_config`
    /// (presumably the provider API key).
    key: String,
    /// Model identifier placed in every `ChatRequest`.
    model: String,
    /// Operator-supplied system prompt; the current UTC date is prepended to it at startup.
    system_prompt: String,
    /// Client timeout passed as `Some(timeout)` to `LlmClient::openai_with_config`.
    /// NOTE(review): the unit is not visible here — presumably seconds; confirm against llm_connector.
    timeout: u64,
}

/// Produces a fixed, canned horoscope line for `sign`.
///
/// The sign is echoed verbatim (no validation), followed by `": "` and a constant prediction.
async fn get_horoscope(sign: &str) -> String {
    let mut horoscope = String::from(sign);
    horoscope.push_str(": Next Tuesday you will befriend a baby otter.");
    horoscope
}

/// Upper bound on consecutive tool-call rounds serviced for a single user turn, so a model
/// that keeps requesting tools cannot loop (and keep issuing API requests) indefinitely.
const MAX_TOOL_ROUNDS: usize = 8;

/// Result of servicing one tool call requested by the model.
enum ToolOutcome {
    /// Text returned to the model as the tool's result message.
    Reply(String),
    /// The model invoked the emergency `stop` tool; the session must end.
    Stop,
}

/// Builds a JSON Schema `object` describing a function with one required string argument.
///
/// Function tools take their `parameters` as a JSON Schema object (`type` / `properties` /
/// `required`); a bare map of property definitions is not a valid schema.
fn string_param_schema(name: &str, description: &str) -> Value {
    json!({
        "type": "object",
        "properties": {
            name: {
                "type": "string",
                "description": description
            }
        },
        "required": [name]
    })
}

/// Wraps a function definition in the `Tool` envelope used by `ChatRequest::tools`.
fn function_tool(name: &str, description: &str, parameters: Value) -> Tool {
    Tool {
        tool_type: "function".into(),
        function: Function {
            name: name.into(),
            description: Some(description.into()),
            parameters,
        },
    }
}

/// The tools advertised to the model on every request.
fn build_tools() -> Vec<Tool> {
    vec![
        function_tool(
            "get_horoscope",
            "Get today's horoscope for an astrological sign.",
            string_param_schema("sign", "An astrological sign like Taurus or Aquarius."),
        ),
        function_tool(
            "stop",
            "Emergency Stop the Conversation.  Only to be used when the user is requesting something dangerous.",
            // A function with no arguments still needs an (empty) object schema.
            json!({ "type": "object", "properties": {} }),
        ),
        // TODO: `wikipedia_lookup` is advertised but not implemented yet; `run_tool` answers
        // it with an error reply so the conversation history stays well-formed.
        function_tool(
            "wikipedia_lookup",
            "Look up a wikipedia article and have its summary returned.",
            string_param_schema("title", "The title of the article to look up."),
        ),
    ]
}

/// Executes the tool `name` with the JSON-encoded `arguments` supplied by the model.
///
/// Malformed arguments and unknown/unimplemented tools yield an error `Reply` instead of
/// aborting the program: each tool call in an assistant message must be answered with a
/// tool message before the chat can continue.
async fn run_tool(name: &str, arguments: &str) -> ToolOutcome {
    match name {
        "get_horoscope" => {
            let args: Value = match serde_json::from_str(arguments) {
                Ok(args) => args,
                Err(e) => {
                    return ToolOutcome::Reply(format!(
                        "error: invalid arguments for get_horoscope: {e}"
                    ))
                }
            };
            match args["sign"].as_str() {
                Some(sign) => ToolOutcome::Reply(get_horoscope(sign).await),
                None => ToolOutcome::Reply(
                    "error: get_horoscope requires a string `sign` argument".to_string(),
                ),
            }
        }
        "stop" => ToolOutcome::Stop,
        other => ToolOutcome::Reply(format!("error: tool `{other}` is not available")),
    }
}

/// Interactive chat loop.
///
/// Loads `config.json`, then sends each non-empty stdin line to the model. Tool calls the
/// model makes are serviced (up to `MAX_TOOL_ROUNDS` rounds) before its reply is printed.
/// The loop exits on `stop`/`quit`, end of input, a stdin read error, or when the model
/// calls the `stop` tool.
///
/// # Errors
/// Returns an error if the config cannot be read or parsed, the client cannot be built,
/// stdout cannot be flushed, or a chat request fails.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    eprintln!("Starting setup");
    eprintln!("Loading Config");
    let config = read_to_string("config.json")?;
    let config: Config = serde_json::from_str(&config)?;
    eprintln!("Config Loaded");
    let client = LlmClient::openai_with_config(
        &config.key,
        Some(&config.base_url),
        Some(config.timeout),
        None,
    )?;
    eprintln!("Config Setup");
    let tools = build_tools();
    eprintln!("Tools Loaded");
    let date: String = Utc::now().date_naive().to_string();
    let system_prompt: String = format!("The current date is {date}.  {}", &config.system_prompt);
    let mut req = ChatRequest {
        model: config.model,
        messages: vec![Message::text(Role::System, &system_prompt)],
        tools: Some(tools),
        //temperature: Some(0.9),
        ..Default::default()
    };
    eprintln!("System Prompt is: {system_prompt}");
    eprintln!("Base Request Setup");
    eprintln!("Setup Finished");

    let stdin = std::io::stdin();
    // One input buffer, reused for every prompt.
    let mut line = String::new();
    loop {
        line.clear();
        print!("user> ");
        std::io::stdout().flush()?;
        match stdin.read_line(&mut line) {
            // End of input (Ctrl-D / closed pipe). Without this check an empty message
            // would be sent to the model over and over.
            Ok(0) => break,
            Ok(_) => {}
            Err(e) => {
                eprintln!("could not read stdin: {e}");
                break;
            }
        }
        // Send the text without its trailing newline; skip blank lines entirely.
        let input = line.trim();
        if input.is_empty() {
            continue;
        }
        if matches!(input.to_lowercase().as_str(), "stop" | "quit") {
            return Ok(());
        }
        req.messages.push(Message::text(Role::User, input));
        let mut response = client.chat(&req).await?;

        // Service tool calls until the model answers in plain text (or the round cap is hit).
        let mut rounds = 0;
        loop {
            let requested = response
                .choices
                .iter()
                .find(|choice| choice.finish_reason.as_deref() == Some("tool_calls"))
                .map(|choice| choice.message.clone());
            let assistant_message = match requested {
                Some(message) => message,
                None => break,
            };
            let calls = match assistant_message.tool_calls.clone() {
                Some(calls) if !calls.is_empty() => calls,
                _ => break,
            };
            if rounds == MAX_TOOL_ROUNDS {
                eprintln!("Giving up after {MAX_TOOL_ROUNDS} consecutive tool-call rounds.");
                break;
            }
            rounds += 1;

            // Record the assistant message carrying the calls exactly once, then answer
            // every call id with its own tool message.
            req.messages.push(assistant_message);
            for call in calls {
                match run_tool(&call.function.name, &call.function.arguments).await {
                    ToolOutcome::Reply(result) => {
                        req.messages.push(Message::tool(result, call.id));
                    }
                    ToolOutcome::Stop => {
                        println!("Agent has stopped the conversation.");
                        if let Some(reasoning) = &response.reasoning_content {
                            eprintln!("Agent Stopped with Reasoning: {}", reasoning.trim());
                        } else {
                            eprintln!("Agent has not given reasoning for emergency stop.");
                        }
                        return Ok(());
                    }
                }
            }
            response = client.chat(&req).await?;
        }

        println!("agent> {}", response.content.trim());
        req.messages.push(Message::text(Role::Assistant, &response.content));
    }
    Ok(())
}