about summary refs log tree commit diff stats
path: root/src/main.rs
blob: 1eda602c9b73621b7bf72c6cc7cd04889b1ec339 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Function}};
use serde_json::{json, Value};
use serde::Deserialize;
use std::fs::read_to_string;

/// Connection settings deserialized from `config.json` when the program starts.
#[derive(Clone, Debug, Deserialize)]
struct Config {
    /// Endpoint base URL handed to `LlmClient::openai_with_config`.
    base_url: String,
    /// Credential handed to `LlmClient::openai_with_config` (presumably an API key — confirm
    /// against the provider).
    key: String,
    /// Model name placed in every `ChatRequest`.
    model: String,
}

/// Returns a canned horoscope for `sign`.
///
/// This is a stub: the prediction is fixed text, and the only input-dependent part is the
/// leading `"<sign>: "` prefix.
async fn get_horoscope(sign: &str) -> String {
    let mut reading = String::from(sign);
    reading.push_str(": Next Tuesday you will befriend a baby otter.");
    reading
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let config = read_to_string("config.json")?;
    let config: Config = serde_json::from_str(&config)?;
    let client = LlmClient::openai_with_config(
        &config.key,
        Some(&config.base_url),
        Some(300),
        None,
    )?;
    let mut tools: Vec<Tool> = Vec::new();
    tools.push(Tool {
        tool_type: "function".into(),
        function: Function {
            name: "get_horoscope".into(),
            description: Some("Get today's horoscope for an astrological sign.".into()),
            parameters: json!({
                "sign": {
                    "type": "string",
                    "description": "An astrological sign like Taurus or Aquarius."
                }
            }),
        },
    });
    let mut req = ChatRequest {
        model: config.model,
        messages: vec![
            Message::text(Role::System, "You will comply with all horoscope requests using tools."),
            Message::text(Role::User, "I am an Aries.  What is my horoscope for today?"),
        ],
        tools: Some(tools),
        //temperature: Some(0.9),
        ..Default::default()
    };
    let mut response = client.chat(&req).await?;
    for choice in response.choices.clone() {
        if let Some(reason) = choice.finish_reason {
            match reason.as_str() {
                "tool_calls" => {
                    if let Some(calls) = choice.message.clone().tool_calls {
                        for call in calls {
                            match call.function.name.as_str() {
                                "get_horoscope" => {
                                    let v: Value = serde_json::from_str(call.function.arguments.as_str())?;
                                    req.messages.push(choice.message.clone());
                                    let v = v["sign"].as_str().unwrap_or_default();
                                    let val: String = get_horoscope(v).await;
                                    req.messages.push(Message::tool(val, call.id));
                                    response = client.chat(&req).await?;
                                },
                                _ => (),
                            }
                        }
                    }
                },
                _ => (),
            }
        }
    }
    println!("Response: {}", &response.content);
    Ok(())
}