1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
use llm_connector::{LlmClient, types::{ChatRequest, Message, Role, Tool, Function}};
use serde_json::{json, Value};
use serde::Deserialize;
use std::fs::read_to_string;
use std::io::Write;
use chrono::Utc;
/// Runtime configuration loaded from `config.json` at startup.
///
/// Field names must match the JSON keys exactly (no `#[serde(rename)]` in use).
#[derive(Deserialize, Clone, Debug)]
struct Config {
// Base URL of the OpenAI-compatible endpoint (passed to `LlmClient::openai_with_config`).
base_url: String,
// API key for the endpoint.
key: String,
// Model identifier sent with every chat request.
model: String,
// System prompt; the current date is prepended to it in `main`.
system_prompt: String,
// Request timeout — unit (seconds vs. ms) is defined by llm_connector; TODO confirm.
timeout: u64,
}
/// Stubbed backend for the `get_horoscope` tool: returns a fixed, canned
/// prediction prefixed with the requested astrological sign.
///
/// Kept `async` so the call site can `.await` it like a real lookup would be.
async fn get_horoscope(sign: &str) -> String {
    let prediction = "Next Tuesday you will befriend a baby otter.";
    format!("{sign}: {prediction}")
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    eprintln!("Starting setup");
    eprintln!("Loading Config");
    let config = read_to_string("config.json")?;
    let config: Config = serde_json::from_str(&config)?;
    eprintln!("Config Loaded");
    let client = LlmClient::openai_with_config(
        &config.key,
        Some(&config.base_url),
        Some(config.timeout),
        None,
    )?;
    eprintln!("Config Setup");
    // Tool parameters must be a full JSON Schema: an object with "type",
    // "properties" and "required" — a bare property map is rejected or
    // misread by OpenAI-compatible backends.
    let tools: Vec<Tool> = vec![
        Tool {
            tool_type: "function".into(),
            function: Function {
                name: "get_horoscope".into(),
                description: Some("Get today's horoscope for an astrological sign.".into()),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "sign": {
                            "type": "string",
                            "description": "An astrological sign like Taurus or Aquarius."
                        }
                    },
                    "required": ["sign"]
                }),
            },
        },
        Tool {
            tool_type: "function".into(),
            function: Function {
                name: "stop".into(),
                description: Some("Emergency Stop the Conversation. Only to be used when the user is requesting something dangerous.".into()),
                // No arguments: an empty object schema, not an empty map.
                parameters: json!({ "type": "object", "properties": {} }),
            },
        },
    ];
    eprintln!("Tools Loaded");
    let date: String = Utc::now().date_naive().to_string();
    let system_prompt: String = format!("The current date is {date}. {}", &config.system_prompt);
    let mut req = ChatRequest {
        model: config.model,
        messages: vec![
            Message::text(Role::System, &system_prompt),
        ],
        tools: Some(tools),
        //temperature: Some(0.9),
        ..Default::default()
    };
    eprintln!("System Prompt is: {system_prompt}");
    eprintln!("Base Request Setup");
    eprintln!("Setup Finished");
    // Safety cap so a model that keeps requesting tools can't loop forever.
    const MAX_TOOL_ROUNDS: usize = 8;
    loop {
        let mut line = String::new();
        print!("user> ");
        std::io::stdout().flush()?;
        if let Err(e) = std::io::stdin().read_line(&mut line) {
            eprintln!("could not read stdin: {e}");
            break;
        }
        let input = line.trim();
        let lowered = input.to_lowercase();
        if lowered == "stop" || lowered == "quit" {
            return Ok(());
        }
        // Push the trimmed input — the raw read_line buffer carries a newline.
        req.messages.push(Message::text(Role::User, input));
        let mut response = client.chat(&req).await?;
        // Resolve tool calls until the model produces a plain text answer.
        // Only the first choice is inspected (requests are made with default n,
        // so there is exactly one).
        let mut rounds = 0;
        while rounds < MAX_TOOL_ROUNDS {
            let Some(choice) = response.choices.first() else { break };
            if choice.finish_reason.as_deref() != Some("tool_calls") {
                break;
            }
            let Some(calls) = choice.message.tool_calls.clone() else { break };
            // Record the assistant's tool-call message exactly ONCE, then
            // answer every call before re-contacting the model — pushing the
            // message per call (or re-requesting mid-way) corrupts history.
            req.messages.push(choice.message.clone());
            for call in calls {
                match call.function.name.as_str() {
                    "get_horoscope" => {
                        let args: Value = serde_json::from_str(call.function.arguments.as_str())?;
                        let sign = args["sign"].as_str().unwrap_or_default();
                        let val: String = get_horoscope(sign).await;
                        req.messages.push(Message::tool(val, call.id));
                    }
                    "stop" => {
                        println!("Agent has stopped the conversation.");
                        if let Some(reasoning) = response.reasoning_content.clone() {
                            eprintln!("Agent Stopped with Reasoning: {}", reasoning.trim());
                        } else {
                            eprintln!("Agent has not given reasoning for emergency stop.");
                        }
                        return Ok(());
                    }
                    other => {
                        // Still answer the call: an unanswered tool call in
                        // history makes the next request invalid.
                        eprintln!("model requested unknown tool: {other}");
                        req.messages.push(Message::tool(String::new(), call.id));
                    }
                }
            }
            response = client.chat(&req).await?;
            rounds += 1;
        }
        let reply = response.content.trim().to_string();
        println!("agent> {reply}");
        req.messages.push(Message::text(Role::Assistant, &reply));
    }
    Ok(())
}
|