From 7e6a2c782b0e7a77c7c30aef00e5cde4775326b4 Mon Sep 17 00:00:00 2001 From: Abijah Date: Thu, 4 Jun 2026 17:38:56 -0700 Subject: [PATCH] Trying to get token count right --- gemini/gemini.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/gemini/gemini.py b/gemini/gemini.py index bec079c..1b4c008 100644 --- a/gemini/gemini.py +++ b/gemini/gemini.py @@ -324,7 +324,7 @@ def main(): config = types.GenerateContentConfig(**config_kwargs) - print("Generating response (this may take a moment for large outputs)...\n") + print("Generating response (this may take a moment for large outputs)...") full_response_text = "" usage_metadata = None @@ -348,8 +348,9 @@ def main(): usage_metadata = chunk.usage_metadata if chunk.candidates and chunk.candidates[0].finish_reason: finish_reason_str = chunk.candidates[0].finish_reason.name - print(f"Done! Raw output saved directly to {args.output}", file=sys.stderr) + print(f"Done! Raw output saved directly to {args.output}") else: + print() for chunk in response_stream: if chunk.text: print(chunk.text, end="", flush=True) @@ -358,10 +359,10 @@ def main(): usage_metadata = chunk.usage_metadata if chunk.candidates and chunk.candidates[0].finish_reason: finish_reason_str = chunk.candidates[0].finish_reason.name - print() - + print("\n") + except Exception as e: - # Catch the 404 Model Not Found error (or other API failures) gracefully + # Catch 404 Model Not Found or other critical generation errors gracefully if "404" in str(e) and "NOT_FOUND" in str(e): print(f"\n[Error] The model '{args.model}' does not exist or is not available.", file=sys.stderr) else: @@ -378,16 +379,23 @@ def main(): output_tokens = usage_metadata.candidates_token_count or 0 cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0 - uncached_tokens = max(0, prompt_tokens - cached_tokens) - rates = pricing_data.get(args.model, {}) + # Handle API variations where prompt_token_count might strictly be the uncached tokens + if prompt_tokens >= cached_tokens: + uncached_tokens = prompt_tokens - cached_tokens + else: + uncached_tokens = prompt_tokens + + total_input_tokens = uncached_tokens + cached_tokens + + rates = pricing_data.get(args.model, {"input": 0.0, "cached": 0.0, "output": 0.0}) input_rate = rates.get("input", 0.0) cached_rate = rates.get("cached", 0.0) output_rate = rates.get("output", 0.0) tier_label = "Base Tier" - # Check if prompt exceeded 200k tokens to apply tier pricing - if prompt_tokens > 200_000 and "input_over_200k" in rates: + # Apply 200k+ tier pricing if total prompt size exceeds 200k + if total_input_tokens > 200_000 and "input_over_200k" in rates: if rates.get("input_over_200k", 0.0) > 0: input_rate = rates["input_over_200k"] tier_label = ">200k Tier"