Compare commits

..

10 Commits

2 changed files with 302 additions and 65 deletions
+1
View File
@@ -0,0 +1 @@
source gemini_env/bin/activate
+301 -65
View File
@@ -3,19 +3,88 @@ import argparse
import json import json
import os import os
import sys import sys
import urllib.request
from google import genai from google import genai
from google.genai import types from google.genai import types
PRICING_FILE = ".gemini_pricing.json"
def fetch_pricing_for_model(client, target_model, timeout=30):
    """Scrape the public pricing page and have a Gemini model extract the rates.

    The raw HTML of the pricing page is downloaded and embedded in a prompt
    that asks a small model to return the per-1M-token rates as JSON.

    Args:
        client: Object exposing ``models.generate_content(model=..., contents=...,
            config=...)``; the caller in this file passes a ``genai.Client``.
        target_model: Name of the model whose pricing should be extracted.
        timeout: Seconds to wait on the HTTP request for the pricing page.
            Without a timeout a stalled connection would hang the CLI forever.

    Returns:
        A dict of the form ``{target_model: {"input": ..., "cached": ...,
        "output": ..., ...}}`` when the model's pricing was extracted, or
        ``None`` when the page could not be fetched, the AI extraction failed,
        or the response did not contain a usable entry for ``target_model``.
    """
    url = "https://cloud.google.com/gemini-enterprise-agent-platform/generative-ai/pricing"
    print(f"Fetching live pricing for {target_model} from the web...")

    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        with urllib.request.urlopen(req, timeout=timeout) as response:
            # A single undecodable byte should not discard the whole page;
            # the text is only fed to a model for extraction.
            html = response.read().decode('utf-8', errors='replace')
    except Exception as e:
        # Best-effort: pricing is optional, so any fetch failure is a warning.
        print(f"Warning: Failed to fetch HTML from {url}: {e}", file=sys.stderr)
        return None

    # NOTE: the prompt body is intentionally flush-left so the text sent to the
    # model carries no incidental indentation.
    prompt = f"""
Extract the API pricing for the model '{target_model}' from the following HTML text.
Find the cost per 1 million tokens for input, cached content, and output.
Many models have a split tier where the cost increases if the prompt exceeds 200k tokens.
CRITICAL: If the model '{target_model}' is definitively NOT found in the HTML data, return an empty JSON object: {{}}
Otherwise, return ONLY a valid JSON object with this exact structure:
{{
"{target_model}": {{
"input": 0.00,
"cached": 0.00,
"output": 0.00,
"input_over_200k": 0.00,
"cached_over_200k": 0.00,
"output_over_200k": 0.00
}}
}}
If a tier value is not found, duplicate the base tier values.
HTML DATA:
{html}
"""

    config = types.GenerateContentConfig(
        response_mime_type="application/json",
        temperature=0.0
    )

    print("Parsing pricing data via background AI session...")
    try:
        # We use a fast, cheap model just for the parsing task
        res = client.models.generate_content(
            model="gemini-3.1-flash-lite",
            contents=prompt,
            config=config
        )
        new_data = json.loads(res.text)
    except Exception as e:
        print(f"Warning: Failed to extract pricing using AI: {e}", file=sys.stderr)
        return None

    # Strip the array wrapper if the AI returned a list instead of a pure dict
    if isinstance(new_data, list) and len(new_data) > 0:
        new_data = new_data[0]

    # Only report success when the payload really holds a rate table for the
    # requested model; an empty object means "model not found on the page",
    # and anything that is not a dict would break the caller's dict.update().
    if not isinstance(new_data, dict) or not isinstance(new_data.get(target_model), dict):
        print(f"Warning: No pricing entry for {target_model} found in AI response.", file=sys.stderr)
        return None

    print(f"Successfully retrieved pricing for {target_model}.")
    print(new_data)
    return new_data
def main(): def main():
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
parser.add_argument("-c", "--context", type=str, default=None, parser.add_argument("-c", "--context", type=str, default=None,
help="Path to context file. If omitted, files/caches are deleted after execution.") help="Path to context file. If omitted, transient mode is used.")
parser.add_argument("-f", "--files", nargs="+", default=[], parser.add_argument("-f", "--files", nargs="+", default=[],
help="Files to upload to the Gemini API") help="Files to upload to the Gemini API")
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
help="The model to use (default: gemini-3.1-flash-lite)") help="The model to use (default: gemini-3.1-flash-lite)")
parser.add_argument("-d", "--destroy", action="store_true", parser.add_argument("-d", "--destroy", action="store_true",
help="Destroy the cloud files and cache, and delete the local context file") help="Destroy cloud files/cache, and delete local context")
parser.add_argument("-x", "--clear-history", action="store_true",
help="Clear the conversation history without destroying files/caches")
parser.add_argument("--pricing", action="store_true",
help="Force update the pricing info for the specified model from the web")
parser.add_argument("-o", "--output", type=str, parser.add_argument("-o", "--output", type=str,
help="Direct the raw output to a specific file instead of stdout") help="Direct the raw output to a specific file instead of stdout")
parser.add_argument("-p", "--prompt", type=str, parser.add_argument("-p", "--prompt", type=str,
@@ -24,7 +93,6 @@ def main():
help="Positional arguments treated as the prompt if -p is omitted") help="Positional arguments treated as the prompt if -p is omitted")
args = parser.parse_args() args = parser.parse_args()
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
if not os.environ.get("GEMINI_API_KEY"): if not os.environ.get("GEMINI_API_KEY"):
@@ -32,26 +100,74 @@ def main():
sys.exit(1) sys.exit(1)
client = genai.Client() client = genai.Client()
context_data = {"file_ids": [], "cache_id": None}
# ---------------------------------------------------------
# PRICING CONFIGURATION
# ---------------------------------------------------------
pricing_data = {}
if os.path.exists(PRICING_FILE):
try:
with open(PRICING_FILE, "r") as f:
pricing_data = json.load(f)
except json.JSONDecodeError:
print(f"Warning: {PRICING_FILE} is corrupted. Starting fresh.", file=sys.stderr)
# Fetch pricing if forced, or if the model isn't currently tracked
if args.pricing or args.model not in pricing_data:
new_pricing = fetch_pricing_for_model(client, args.model)
if new_pricing and args.model in new_pricing:
pricing_data.update(new_pricing)
with open(PRICING_FILE, "w") as f:
json.dump(pricing_data, f, indent=4)
else:
print(f"Warning: Could not fetch pricing for {args.model}. Estimating cost at $0.00.", file=sys.stderr)
# Add a fallback zero-value so the script doesn't crash during cost calculation
pricing_data[args.model] = {"input": 0.0, "cached": 0.0, "output": 0.0}
# If the user only requested a pricing update and nothing else, exit cleanly
if args.pricing and not prompt_text and not args.files and not args.destroy and not args.clear_history:
return
# ---------------------------------------------------------
# STATE MANAGEMENT
# ---------------------------------------------------------
context_data = {"file_ids": [], "caches": {}, "history": []}
if args.context and os.path.exists(args.context): if args.context and os.path.exists(args.context):
try: try:
with open(args.context, "r") as f: with open(args.context, "r") as f:
context_data = json.load(f) loaded_data = json.load(f)
context_data["file_ids"] = loaded_data.get("file_ids", [])
context_data["caches"] = loaded_data.get("caches", {})
context_data["history"] = loaded_data.get("history", [])
except json.JSONDecodeError: except json.JSONDecodeError:
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr) print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
# ---------------------------------------------------------
# CLEAR HISTORY FLAG
# ---------------------------------------------------------
if args.clear_history:
context_data["history"] = []
if args.context:
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
print("Conversation history cleared.")
if not prompt_text and not args.files and not args.destroy:
return
# --------------------------------------------------------- # ---------------------------------------------------------
# DESTROY FLAG LOGIC # DESTROY FLAG LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
if args.destroy: if args.destroy:
print("Destroying server resources and local context...") print("Destroying server resources and local context...")
if context_data.get("cache_id"): for model_name, cache_info in context_data.get("caches", {}).items():
try: try:
client.caches.delete(name=context_data["cache_id"]) client.caches.delete(name=cache_info["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}") print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)
for file_id in context_data.get("file_ids", []): for file_id in context_data.get("file_ids", []):
try: try:
@@ -71,27 +187,34 @@ def main():
# --------------------------------------------------------- # ---------------------------------------------------------
# UPLOAD LOGIC # UPLOAD LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
new_files_added = False
if args.files: if args.files:
for file_path in args.files: for file_path in args.files:
if not os.path.exists(file_path): if not os.path.exists(file_path):
print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr) print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
continue continue
print(f"Uploading '{file_path}'...") # Sanitize the filename for HTTP headers (replace non-ASCII with underscores)
uploaded_file = client.files.upload(file=file_path) base_name = os.path.basename(file_path)
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'") safe_name = "".join([c if ord(c) < 128 else "_" for c in base_name])
print(f"Uploading '{file_path}'...", file=sys.stderr)
# Force the SDK to use our sanitized name for the upload display name
uploaded_file = client.files.upload(
file=file_path,
config={'display_name': safe_name}
)
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'", file=sys.stderr)
if uploaded_file.name not in context_data["file_ids"]: if uploaded_file.name not in context_data["file_ids"]:
context_data["file_ids"].append(uploaded_file.name) context_data["file_ids"].append(uploaded_file.name)
new_files_added = True
if args.context: if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
json.dump(context_data, f, indent=4) json.dump(context_data, f, indent=4)
# --------------------------------------------------------- # ---------------------------------------------------------
# CACHE CREATION LOGIC # CACHE VALIDATION & CREATION LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
system_instruction = ( system_instruction = (
"You are a hybrid data extraction tool. If a specific format or file format output is requested " "You are a hybrid data extraction tool. If a specific format or file format output is requested "
@@ -100,23 +223,47 @@ def main():
"and output raw data only. If not specific data format is suggested, you can answer with conversational text." "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
) )
# If new files were added but a cache already exists, we must destroy the stale cache.
if new_files_added and context_data.get("cache_id"):
print("New files detected. Destroying stale cache to rebuild...")
try:
client.caches.delete(name=context_data["cache_id"])
except Exception as e:
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
context_data["cache_id"] = None
cache_too_small = False cache_too_small = False
file_objects = [] file_objects = []
active_cache_id = None
if context_data.get("file_ids"): if context_data.get("file_ids"):
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]] file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
model_cache_info = context_data["caches"].get(args.model)
current_files_set = set(context_data["file_ids"])
rebuild_cache = False
if model_cache_info:
cached_files_set = set(model_cache_info.get("cached_file_ids", []))
# Check 1: Have the files changed?
if current_files_set != cached_files_set:
print(f"File list changed. Destroying stale {args.model} cache to rebuild...")
try:
client.caches.delete(name=model_cache_info["cache_id"])
except Exception as e:
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
rebuild_cache = True
else:
# Check 2: Does the cache still exist on the server?
print(f"Verifying existing cache for {args.model}: {model_cache_info['cache_id']}...")
try:
client.caches.get(name=model_cache_info['cache_id'])
print("Cache verified. Extending TTL by 60 minutes...")
client.caches.update(
name=model_cache_info['cache_id'],
config=types.UpdateCachedContentConfig(ttl="3600s")
)
active_cache_id = model_cache_info['cache_id']
except Exception as e:
print(f"Cache expired or not found. Flagging for rebuild. ({e})")
rebuild_cache = True
else:
rebuild_cache = True
if not context_data.get("cache_id"): if rebuild_cache:
print("Attempting to create Context Cache on Google's servers...") print(f"Attempting to create Context Cache for {args.model}...")
try: try:
cache = client.caches.create( cache = client.caches.create(
model=args.model, model=args.model,
@@ -126,7 +273,13 @@ def main():
ttl="3600s" ttl="3600s"
) )
) )
context_data["cache_id"] = cache.name
context_data["caches"][args.model] = {
"cache_id": cache.name,
"cached_file_ids": list(context_data["file_ids"])
}
active_cache_id = cache.name
if args.context: if args.context:
with open(args.context, "w") as f: with open(args.context, "w") as f:
json.dump(context_data, f, indent=4) json.dump(context_data, f, indent=4)
@@ -138,20 +291,9 @@ def main():
cache_too_small = True cache_too_small = True
else: else:
raise e raise e
elif not cache_too_small:
print(f"Loading existing cache: {context_data['cache_id']}")
print("Extending cache TTL by 60 minutes...")
try:
client.caches.update(
name=context_data["cache_id"],
config=types.UpdateCachedContentConfig(ttl="3600s")
)
except Exception as e:
print(f"Warning: Failed to update cache TTL. {e}")
# --------------------------------------------------------- # ---------------------------------------------------------
# GENERATION LOGIC # GENERATION LOGIC (WITH HISTORY)
# --------------------------------------------------------- # ---------------------------------------------------------
if prompt_text: if prompt_text:
config_kwargs = { config_kwargs = {
@@ -159,38 +301,132 @@ def main():
"temperature": 0.0 "temperature": 0.0
} }
generation_contents = [] if active_cache_id and not cache_too_small:
config_kwargs["cached_content"] = active_cache_id
if context_data.get("cache_id") and not cache_too_small:
config_kwargs["cached_content"] = context_data["cache_id"]
else: else:
generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction config_kwargs["system_instruction"] = system_instruction
# Build the chat payload using strict dictionary representations
api_contents = []
files_added_to_payload = False
for msg in context_data.get("history", []):
parts = []
# If we aren't using a cache, attach un-cached files to the very first user message
if msg["role"] == "user" and not files_added_to_payload and not active_cache_id and file_objects:
for f in file_objects:
parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
files_added_to_payload = True
parts.append({"text": msg["text"]})
api_contents.append({"role": msg["role"], "parts": parts})
current_parts = []
if not files_added_to_payload and not active_cache_id and file_objects:
for f in file_objects:
current_parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
files_added_to_payload = True
generation_contents.append(prompt_text) current_parts.append({"text": prompt_text})
api_contents.append({"role": "user", "parts": current_parts})
config = types.GenerateContentConfig(**config_kwargs) config = types.GenerateContentConfig(**config_kwargs)
print("Generating response (this may take a moment for large outputs)...") print("Generating response (this may take a moment for large outputs)...")
response_stream = client.models.generate_content_stream( full_response_text = ""
model=args.model, usage_metadata = None
contents=generation_contents, finish_reason_str = "UNKNOWN"
config=config
) try:
response_stream = client.models.generate_content_stream(
if args.output: model=args.model,
with open(args.output, "w") as f: contents=api_contents,
config=config
)
if args.output:
with open(args.output, "w") as f:
for chunk in response_stream:
if chunk.text:
f.write(chunk.text)
f.flush()
full_response_text += chunk.text
if chunk.usage_metadata:
usage_metadata = chunk.usage_metadata
if chunk.candidates and chunk.candidates[0].finish_reason:
finish_reason_str = chunk.candidates[0].finish_reason.name
print(f"Done! Raw output saved directly to {args.output}")
else:
print()
for chunk in response_stream: for chunk in response_stream:
if chunk.text: if chunk.text:
f.write(chunk.text) print(chunk.text, end="", flush=True)
f.flush() full_response_text += chunk.text
print(f"\nDone! Raw output saved directly to {args.output}") if chunk.usage_metadata:
else: usage_metadata = chunk.usage_metadata
print("-" * 40) if chunk.candidates and chunk.candidates[0].finish_reason:
for chunk in response_stream: finish_reason_str = chunk.candidates[0].finish_reason.name
if chunk.text: print("\n")
print(chunk.text, end="", flush=True)
print("\n" + "-" * 40) except Exception as e:
# Catch 404 Model Not Found or other critical generation errors gracefully
if "404" in str(e) and "NOT_FOUND" in str(e):
print(f"\n[Error] The model '{args.model}' does not exist or is not available.", file=sys.stderr)
else:
print(f"\n[API Error] {e}", file=sys.stderr)
# Exit cleanly so we don't calculate costs or save bad history
return
# ---------------------------------------------------------
# USAGE AND COST CALCULATION
# ---------------------------------------------------------
if usage_metadata:
prompt_tokens = usage_metadata.prompt_token_count or 0
output_tokens = usage_metadata.candidates_token_count or 0
cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
# Revert to max(0) to handle Google's padding discrepancy where prompt_tokens < cached_tokens
uncached_tokens = max(0, prompt_tokens - cached_tokens)
# Ensure the tier logic checks the absolute largest representation of the payload
total_input_tokens = max(prompt_tokens, cached_tokens)
# Fetch the rate dynamically from the parsed JSON or default to 0.0 if fetch failed
rates = pricing_data.get(args.model, {"input": 0.0, "cached": 0.0, "output": 0.0})
input_rate = rates.get("input", 0.0)
cached_rate = rates.get("cached", 0.0)
output_rate = rates.get("output", 0.0)
tier_label = "Base Tier"
# Apply 200k+ tier pricing if total prompt size exceeds 200k
if total_input_tokens > 200_000 and "input_over_200k" in rates:
if rates.get("input_over_200k", 0.0) > 0:
input_rate = rates["input_over_200k"]
tier_label = ">200k Tier"
if rates.get("cached_over_200k", 0.0) > 0:
cached_rate = rates["cached_over_200k"]
if rates.get("output_over_200k", 0.0) > 0:
output_rate = rates["output_over_200k"]
input_cost = (uncached_tokens / 1_000_000) * input_rate
cached_cost = (cached_tokens / 1_000_000) * cached_rate
output_cost = (output_tokens / 1_000_000) * output_rate
total_cost = input_cost + cached_cost + output_cost
print("\n[--- Execution Summary ---]")
print(f"Finish Reason: {finish_reason_str}")
print(f"Token Usage: Input: {uncached_tokens:,} | Cached: {cached_tokens:,} | Output: {output_tokens:,} ({tier_label})")
print(f"Est. Cost: ${total_cost:.6f} (Model: {args.model})")
context_data["history"].append({"role": "user", "text": prompt_text})
context_data["history"].append({"role": "model", "text": full_response_text})
if args.context:
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
finally: finally:
# --------------------------------------------------------- # ---------------------------------------------------------
@@ -198,10 +434,10 @@ def main():
# --------------------------------------------------------- # ---------------------------------------------------------
if not args.context and not args.destroy: if not args.context and not args.destroy:
print("\n[Transient Mode] Cleaning up resources...") print("\n[Transient Mode] Cleaning up resources...")
if context_data.get("cache_id"): for model_name, cache_info in context_data.get("caches", {}).items():
try: try:
client.caches.delete(name=context_data["cache_id"]) client.caches.delete(name=cache_info["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}") print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr) print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)