about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/main.rs69
1 files changed, 48 insertions, 21 deletions
diff --git a/src/main.rs b/src/main.rs
index a9ff484..943104d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -18,6 +18,10 @@ async fn get_horoscope(sign: &str) -> String {
     format!("{sign}: Next Tuesday you will befriend a baby otter.")
 }
 
+async fn wikipedia_lookup(title: &str) -> String {
+    format!("{title}: This article is the article on {title} and is quite interesting.  {title} has several toes.  Many more toes than a {title} should have.")
+}
+
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     eprintln!("Starting setup");
@@ -82,39 +86,50 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     eprintln!("System Prompt is: {system_prompt}");
     eprintln!("Base Request Setup");
     eprintln!("Setup Finished");
-    loop {
-        let mut s = String::new();
-        print!("user> ");
-        std::io::stdout().flush()?;
-        if let Err(e) = std::io::stdin().read_line(&mut s) {
-            eprintln!("could not read stdin: {e}");
-            break;
-        }
-        if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" {
-            return Ok(());
-        }
-        req.messages.push(Message::text(Role::User, &s));
-        let mut response = client.chat(&req).await?;
+    let mut s = String::new();
+    print!("user> ");
+    std::io::stdout().flush()?;
+    if let Err(e) = std::io::stdin().read_line(&mut s) {
+        eprintln!("could not read stdin: {e}");
+        return Ok(())
+    }
+    if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" {
+        return Ok(());
+    }
+    req.messages.push(Message::text(Role::User, &s));
+    let mut uwu: bool = true;
+    while uwu {
+        let response = client.chat(&req).await?;
+        eprintln!("{:?}", &response);
         for choice in response.choices.clone() {
             if let Some(reason) = choice.finish_reason {
                 match reason.as_str() {
                     "tool_calls" => {
-                        let mut need_call = false;
                         if let Some(calls) = choice.message.clone().tool_calls {
                             for call in calls {
                                 match call.function.name.as_str() {
                                     "get_horoscope" => {
+                                        eprintln!("get_horoscope!");
                                         let v: Value = serde_json::from_str(call.function.arguments.as_str())?;
                                         req.messages.push(choice.message.clone());
                                         let v = v["sign"].as_str().unwrap_or_default();
                                         let val: String = get_horoscope(v).await;
                                         req.messages.push(Message::tool(val, call.id));
-                                        need_call = true;
+                                        eprintln!("{:?}", &req);
                                     },
+                                    "wikipedia_lookup" => {
+                                        eprintln!("wikipedia_lookup!");
+                                        let v: Value = serde_json::from_str(call.function.arguments.as_str())?;
+                                        req.messages.push(choice.message.clone());
+                                        let v = v["title"].as_str().unwrap_or_default();
+                                        let val: String = wikipedia_lookup(v).await;
+                                        req.messages.push(Message::tool(val, call.id));
+                                        eprintln!("{:?}", &req);
+                                    }
                                     "stop" => {
                                         println!("Agent has stopped the conversation.");
                                         if let Some(reasoning) = response.reasoning_content {
-                                            eprintln!("Agent Stopped with Reasoning: {}", reasoning.as_str().trim_start().trim_end());
+                                            eprintln!("Agent Stopped with Reasoning: {}", reasoning.as_str().trim());
                                         } else {
                                             eprintln!("Agent has not given reasoning for emergency stop.");
                                         }
@@ -123,17 +138,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                                     _ => (),
                                 }
                             }
-                            if need_call {
-                                response = client.chat(&req).await?;
-                            }
                         }
                     },
+                    "stop" => {
+                        println!("agent> {}", &response.content.trim());
+                        req.messages.push(Message::text(Role::Assistant, &response.content.clone()));
+                        let mut s = String::new();
+                        print!("user> ");
+                        std::io::stdout().flush()?;
+                        if let Err(e) = std::io::stdin().read_line(&mut s) {
+                            eprintln!("could not read stdin: {e}");
+                            uwu = false;
+                            break;
+                        }
+                        if s.as_str().trim().to_lowercase() == "stop" || s.as_str().trim().to_lowercase() == "quit" {
+                            uwu = false;
+                            break;
+                        }
+                        req.messages.push(Message::text(Role::User, &s));
+                    },
                     _ => (),
                 }
             }
         }
-        println!("agent> {}", &response.content.trim_start().trim_end());
-        req.messages.push(Message::text(Role::Assistant, &response.content.clone()));
     }
     Ok(())
 }