mirror of
https://github.com/block/goose.git
synced 2026-07-03 14:15:10 +02:00
fix(cli): report cumulative total_tokens in stream-json/json output (#8910)
Signed-off-by: Trinity <trinity@multica.ai> Signed-off-by: Bright Zheng <bzqzheng@gmail.com> Signed-off-by: Douwe Osinga <douwe@squareup.com> Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
@@ -1501,10 +1501,9 @@ impl CliSession {
|
||||
.await
|
||||
}
|
||||
|
||||
// Get the session's total token usage
|
||||
pub async fn get_total_token_usage(&self) -> Result<Option<i32>> {
|
||||
let metadata = self.get_session().await?;
|
||||
Ok(metadata.total_tokens)
|
||||
Ok(metadata.accumulated_total_tokens)
|
||||
}
|
||||
|
||||
/// Display enhanced context usage with session totals
|
||||
|
||||
@@ -1119,6 +1119,121 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
mod cumulative_token_tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use goose::agents::{AgentConfig, SessionConfig};
|
||||
use goose::config::permission::PermissionManager;
|
||||
use goose::config::GooseMode;
|
||||
use goose::conversation::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::{
|
||||
stream_from_single_message, MessageStream, Provider, ProviderUsage, Usage,
|
||||
};
|
||||
use goose::providers::errors::ProviderError;
|
||||
use goose::session::session_manager::SessionType;
|
||||
use goose::session::SessionManager;
|
||||
use rmcp::model::Tool;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct FixedUsageProvider {
|
||||
input_tokens: i32,
|
||||
output_tokens: i32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for FixedUsageProvider {
|
||||
async fn stream(
|
||||
&self,
|
||||
_model_config: &ModelConfig,
|
||||
_session_id: &str,
|
||||
_system_prompt: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let total = self.input_tokens + self.output_tokens;
|
||||
let usage = ProviderUsage::new(
|
||||
"mock-model".to_string(),
|
||||
Usage::new(
|
||||
Some(self.input_tokens),
|
||||
Some(self.output_tokens),
|
||||
Some(total),
|
||||
),
|
||||
);
|
||||
let message = Message::assistant().with_text("Hello");
|
||||
Ok(stream_from_single_message(message, usage))
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
ModelConfig::new("mock-model").unwrap()
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
"fixed-usage-mock"
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_turn(agent: &Agent, session_id: &str, text: &str) -> Result<()> {
|
||||
let session_config = SessionConfig {
|
||||
id: session_id.to_string(),
|
||||
schedule_id: None,
|
||||
max_turns: Some(1),
|
||||
retry_config: None,
|
||||
};
|
||||
let stream = agent
|
||||
.reply(Message::user().with_text(text), session_config, None)
|
||||
.await?;
|
||||
tokio::pin!(stream);
|
||||
while let Some(event) = stream.next().await {
|
||||
let _ = event?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_accumulated_total_tokens_across_multiple_turns() -> Result<()> {
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
|
||||
let config = AgentConfig::new(
|
||||
session_manager.clone(),
|
||||
PermissionManager::instance(),
|
||||
None,
|
||||
GooseMode::Auto,
|
||||
true,
|
||||
GoosePlatform::GooseCli,
|
||||
);
|
||||
let agent = Agent::with_config(config);
|
||||
let provider = Arc::new(FixedUsageProvider {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
});
|
||||
|
||||
let session = session_manager
|
||||
.create_session(
|
||||
PathBuf::default(),
|
||||
"cumulative-token-test".to_string(),
|
||||
SessionType::Hidden,
|
||||
GooseMode::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let session_id = session.id.clone();
|
||||
agent.update_provider(provider.clone(), &session_id).await?;
|
||||
|
||||
run_turn(&agent, &session_id, "Turn 1").await?;
|
||||
let after_1 = session_manager.get_session(&session_id, false).await?;
|
||||
assert_eq!(after_1.accumulated_total_tokens, Some(15));
|
||||
|
||||
run_turn(&agent, &session_id, "Turn 2").await?;
|
||||
let after_2 = session_manager.get_session(&session_id, false).await?;
|
||||
assert_eq!(after_2.accumulated_total_tokens, Some(30));
|
||||
assert_eq!(after_2.total_tokens, Some(15));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
mod frontend_extension_tests {
|
||||
use super::*;
|
||||
use goose::agents::{AgentConfig, ExtensionConfig};
|
||||
|
||||
Reference in New Issue
Block a user