about summary refs log tree commit diff stats
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs19
1 files changed, 18 insertions, 1 deletions
diff --git a/src/main.rs b/src/main.rs
index 12cb491..a9ff484 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -54,6 +54,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             parameters: json!({}),
         },
     });
+    tools.push(Tool {
+        tool_type: "function".into(),
+        function: Function {
+            name: "wikipedia_lookup".into(),
+            description: Some("Look up a wikipedia article and have its summary returned.".into()),
+            parameters: json!({
+                "title": {
+                    "type": "string",
+                    "description": "The title of the article to look up."
+                }
+            }),
+        },
+    });
     eprintln!("Tools Loaded");
     let date: String = Utc::now().date_naive().to_string();
     let system_prompt: String = format!("The current date is {date}.  {}", &config.system_prompt);
@@ -86,6 +99,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             if let Some(reason) = choice.finish_reason {
                 match reason.as_str() {
                     "tool_calls" => {
+                        let mut need_call = false;
                         if let Some(calls) = choice.message.clone().tool_calls {
                             for call in calls {
                                 match call.function.name.as_str() {
@@ -95,7 +109,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                                         let v = v["sign"].as_str().unwrap_or_default();
                                         let val: String = get_horoscope(v).await;
                                         req.messages.push(Message::tool(val, call.id));
-                                        response = client.chat(&req).await?;
+                                        need_call = true;
                                     },
                                     "stop" => {
                                         println!("Agent has stopped the conversation.");
@@ -109,6 +123,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                                     _ => (),
                                 }
                             }
+                            if need_call {
+                                response = client.chat(&req).await?;
+                            }
                         }
                     },
                     _ => (),