Trying to get token count right
This commit is contained in:
+16
-8
@@ -324,7 +324,7 @@ def main():
|
|||||||
|
|
||||||
config = types.GenerateContentConfig(**config_kwargs)
|
config = types.GenerateContentConfig(**config_kwargs)
|
||||||
|
|
||||||
print("Generating response (this may take a moment for large outputs)...\n")
|
print("Generating response (this may take a moment for large outputs)...")
|
||||||
|
|
||||||
full_response_text = ""
|
full_response_text = ""
|
||||||
usage_metadata = None
|
usage_metadata = None
|
||||||
@@ -348,8 +348,9 @@ def main():
|
|||||||
usage_metadata = chunk.usage_metadata
|
usage_metadata = chunk.usage_metadata
|
||||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||||
finish_reason_str = chunk.candidates[0].finish_reason.name
|
finish_reason_str = chunk.candidates[0].finish_reason.name
|
||||||
print(f"Done! Raw output saved directly to {args.output}", file=sys.stderr)
|
print(f"Done! Raw output saved directly to {args.output}")
|
||||||
else:
|
else:
|
||||||
|
print()
|
||||||
for chunk in response_stream:
|
for chunk in response_stream:
|
||||||
if chunk.text:
|
if chunk.text:
|
||||||
print(chunk.text, end="", flush=True)
|
print(chunk.text, end="", flush=True)
|
||||||
@@ -358,10 +359,10 @@ def main():
|
|||||||
usage_metadata = chunk.usage_metadata
|
usage_metadata = chunk.usage_metadata
|
||||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||||
finish_reason_str = chunk.candidates[0].finish_reason.name
|
finish_reason_str = chunk.candidates[0].finish_reason.name
|
||||||
print()
|
print("\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch the 404 Model Not Found error (or other API failures) gracefully
|
# Catch 404 Model Not Found or other critical generation errors gracefully
|
||||||
if "404" in str(e) and "NOT_FOUND" in str(e):
|
if "404" in str(e) and "NOT_FOUND" in str(e):
|
||||||
print(f"\n[Error] The model '{args.model}' does not exist or is not available.", file=sys.stderr)
|
print(f"\n[Error] The model '{args.model}' does not exist or is not available.", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
@@ -378,16 +379,23 @@ def main():
|
|||||||
output_tokens = usage_metadata.candidates_token_count or 0
|
output_tokens = usage_metadata.candidates_token_count or 0
|
||||||
cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
|
cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
|
||||||
|
|
||||||
uncached_tokens = max(0, prompt_tokens - cached_tokens)
|
# Handle API variations where prompt_token_count might strictly be the uncached tokens
|
||||||
rates = pricing_data.get(args.model, {})
|
if prompt_tokens >= cached_tokens:
|
||||||
|
uncached_tokens = prompt_tokens - cached_tokens
|
||||||
|
else:
|
||||||
|
uncached_tokens = prompt_tokens
|
||||||
|
|
||||||
|
total_input_tokens = uncached_tokens + cached_tokens
|
||||||
|
|
||||||
|
rates = pricing_data.get(args.model, {"input": 0.0, "cached": 0.0, "output": 0.0})
|
||||||
|
|
||||||
input_rate = rates.get("input", 0.0)
|
input_rate = rates.get("input", 0.0)
|
||||||
cached_rate = rates.get("cached", 0.0)
|
cached_rate = rates.get("cached", 0.0)
|
||||||
output_rate = rates.get("output", 0.0)
|
output_rate = rates.get("output", 0.0)
|
||||||
tier_label = "Base Tier"
|
tier_label = "Base Tier"
|
||||||
|
|
||||||
# Check if prompt exceeded 200k tokens to apply tier pricing
|
# Apply 200k+ tier pricing if total prompt size exceeds 200k
|
||||||
if prompt_tokens > 200_000 and "input_over_200k" in rates:
|
if total_input_tokens > 200_000 and "input_over_200k" in rates:
|
||||||
if rates.get("input_over_200k", 0.0) > 0:
|
if rates.get("input_over_200k", 0.0) > 0:
|
||||||
input_rate = rates["input_over_200k"]
|
input_rate = rates["input_over_200k"]
|
||||||
tier_label = ">200k Tier"
|
tier_label = ">200k Tier"
|
||||||
|
|||||||
Reference in New Issue
Block a user