407 lines
18 KiB
Python
407 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import urllib.request
|
|
from google import genai
|
|
from google.genai import types
|
|
|
|
# Local JSON file caching per-model token prices (per 1M tokens), keyed by model
# name. It is a relative path, so it is read/written in the current working directory.
PRICING_FILE = ".gemini_pricing.json"
|
|
|
|
# Keys every stored pricing entry must provide (prices per 1 million tokens).
_PRICE_KEYS = ("input", "cached", "output")


def _normalize_pricing(raw, target_model):
    """Coerce the model's parsed JSON reply into ``{"input", "cached", "output"}`` floats.

    Returns None when the reply cannot be interpreted as a pricing entry.
    """
    # The model sometimes wraps the object in a list.
    if isinstance(raw, list):
        raw = raw[0] if raw else None
    if not isinstance(raw, dict):
        return None

    entry = raw.get(target_model)
    if entry is None:
        if all(key in raw for key in _PRICE_KEYS):
            # The model skipped the outer {target_model: ...} wrapper.
            entry = raw
        elif len(raw) == 1:
            # The model used a different key for the single entry it returned.
            entry = next(iter(raw.values()))
    if not isinstance(entry, dict):
        return None

    try:
        return {key: float(entry.get(key) or 0.0) for key in _PRICE_KEYS}
    except (TypeError, ValueError):
        return None


def fetch_pricing_for_model(
    client,
    target_model,
    *,
    url="https://cloud.google.com/gemini-enterprise-agent-platform/generative-ai/pricing",
    parser_model="gemini-3.1-flash-lite",
    timeout=30,
):
    """Scrape the per-million-token prices for *target_model* from a pricing web page.

    The page at *url* is downloaded and handed to *parser_model* (called through
    *client*), which is asked to extract the input, cached-content and output prices.

    Args:
        client: A ``genai.Client`` used for the extraction call.
        target_model: Model name whose prices should be extracted.
        url: Pricing page to download.
        parser_model: Model used to parse the page.
        timeout: Seconds to wait for the page download before giving up.

    Returns:
        ``{target_model: {"input": float, "cached": float, "output": float}}`` on
        success, or None if the page could not be downloaded or the model's reply
        could not be turned into that shape. Failures are reported on stderr.
    """
    print(f"Fetching live pricing for {target_model} from the web...")

    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        # Without a timeout a stalled server would block the whole CLI indefinitely.
        with urllib.request.urlopen(req, timeout=timeout) as response:
            # Tolerate stray non-UTF-8 bytes instead of failing the whole fetch.
            page_html = response.read().decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Warning: Failed to fetch HTML from {url}: {e}", file=sys.stderr)
        return None

    prompt = f"""
Extract the API pricing for the model '{target_model}' from the following HTML text.
Find the cost per 1 million input tokens, 1 million cached content tokens, and 1 million output tokens.
Return ONLY a valid JSON object with this exact structure:
{{
    "{target_model}": {{
        "input": 0.00,
        "cached": 0.00,
        "output": 0.00
    }}
}}
If a value is not found, use 0.00.

HTML DATA:
{page_html}
"""

    config = types.GenerateContentConfig(
        response_mime_type="application/json",
        temperature=0.0
    )

    print("Parsing pricing data via background AI session...")
    try:
        # A fast, cheap model is enough for this extraction task.
        res = client.models.generate_content(
            model=parser_model,
            contents=prompt,
            config=config
        )
        raw = json.loads(res.text)
    except Exception as e:
        print(f"Warning: Failed to extract pricing using AI: {e}", file=sys.stderr)
        return None

    # Validate before returning: callers index these keys and do float math on them.
    rates = _normalize_pricing(raw, target_model)
    if rates is None:
        print(f"Warning: Failed to extract pricing using AI: unexpected reply {raw!r}", file=sys.stderr)
        return None

    new_data = {target_model: rates}
    print(f"Successfully retrieved pricing for {target_model}.")
    print(new_data)
    return new_data
|
|
|
|
# System instruction sent with every request: stored in the context cache when one
# is active, otherwise passed inline on each uncached request.
SYSTEM_INSTRUCTION = (
    "You are a hybrid data extraction tool. If a specific format or file format output is requested "
    "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). "
    "If a specific format was requested, never include conversational text, greetings, or explanations; "
    "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
)

# TTL applied to newly created caches and when refreshing a reused cache.
CACHE_TTL = "3600s"

# Fallback rates for a model with no known pricing, so the cost math never fails.
_ZERO_RATES = {"input": 0.0, "cached": 0.0, "output": 0.0}


def _build_arg_parser():
    """Return the argument parser for this CLI."""
    parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
    parser.add_argument("-c", "--context", type=str, default=None,
                        help="Path to context file. If omitted, transient mode is used.")
    parser.add_argument("-f", "--files", nargs="+", default=[],
                        help="Files to upload to the Gemini API")
    parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
                        help="The model to use (default: gemini-3.1-flash-lite)")
    parser.add_argument("-d", "--destroy", action="store_true",
                        help="Destroy cloud files/cache, and delete local context")
    parser.add_argument("-x", "--clear-history", action="store_true",
                        help="Clear the conversation history without destroying files/caches")
    parser.add_argument("--pricing", action="store_true",
                        help="Force update the pricing info for the specified model from the web")
    parser.add_argument("-o", "--output", type=str,
                        help="Direct the raw output to a specific file instead of stdout")
    parser.add_argument("-p", "--prompt", type=str,
                        help="The prompt to send to the AI")
    parser.add_argument("positional_prompt", nargs=argparse.REMAINDER,
                        help="Positional arguments treated as the prompt if -p is omitted")
    return parser


def _save_context(context_path, context_data):
    """Write *context_data* to *context_path* as indented JSON; no-op in transient mode."""
    if not context_path:
        return
    with open(context_path, "w", encoding="utf-8") as f:
        json.dump(context_data, f, indent=4)


def _load_context(context_path):
    """Return the saved session state, or a fresh one if absent or unreadable.

    The state always has ``file_ids`` (uploaded file names), ``caches``
    (model name -> {"cache_id", "cached_file_ids"}) and ``history``
    (list of {"role", "text"} turns).
    """
    fresh = {"file_ids": [], "caches": {}, "history": []}
    if not context_path or not os.path.exists(context_path):
        return fresh
    try:
        with open(context_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        loaded = None
    if not isinstance(loaded, dict):
        print(f"Warning: Could not parse {context_path}. Starting fresh.", file=sys.stderr)
        return fresh
    return {key: loaded.get(key, default) for key, default in fresh.items()}


def _load_pricing():
    """Return the locally stored pricing table (empty if missing or corrupted)."""
    if not os.path.exists(PRICING_FILE):
        return {}
    try:
        with open(PRICING_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except ValueError:
        data = None
    if not isinstance(data, dict):
        print(f"Warning: {PRICING_FILE} is corrupted. Starting fresh.", file=sys.stderr)
        return {}
    return data


def _refresh_pricing(client, model, force, pricing_data):
    """Fetch and persist pricing for *model* when *force* is set or it is unknown.

    Mutates *pricing_data* in place. If the fetch fails, previously stored rates
    are kept; only a model with no stored rates falls back to zeros (in memory).
    """
    if not force and model in pricing_data:
        return
    new_pricing = fetch_pricing_for_model(client, model)
    if isinstance(new_pricing, dict) and isinstance(new_pricing.get(model), dict):
        pricing_data[model] = new_pricing[model]
        with open(PRICING_FILE, "w", encoding="utf-8") as f:
            json.dump(pricing_data, f, indent=4)
    elif model in pricing_data:
        print(f"Warning: Could not fetch pricing for {model}. Keeping previously stored rates.",
              file=sys.stderr)
    else:
        print(f"Warning: Could not fetch pricing for {model}. Estimating cost at $0.00.", file=sys.stderr)
        pricing_data[model] = dict(_ZERO_RATES)


def _delete_remote_resources(client, context_data):
    """Best-effort deletion of every cache and uploaded file tracked in *context_data*."""
    for model_name, cache_info in context_data.get("caches", {}).items():
        try:
            client.caches.delete(name=cache_info["cache_id"])
            print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
        except Exception as e:
            print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)

    for file_id in context_data.get("file_ids", []):
        try:
            client.files.delete(name=file_id)
            print(f"Deleted file: {file_id}")
        except Exception as e:
            print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)


def _upload_files(client, file_paths, context_data, context_path):
    """Upload each regular file in *file_paths* and record its server-side name.

    The context is saved after every upload so already-uploaded files stay
    tracked (and can later be destroyed) even if a later upload fails.
    """
    for file_path in file_paths:
        if not os.path.isfile(file_path):
            print(f"Warning: File '{file_path}' not found or not a regular file. Skipping.", file=sys.stderr)
            continue
        print(f"Uploading '{file_path}'...")
        uploaded_file = client.files.upload(file=file_path)
        print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
        if uploaded_file.name not in context_data["file_ids"]:
            context_data["file_ids"].append(uploaded_file.name)
        _save_context(context_path, context_data)


def _resolve_file_objects(client, context_data, context_path):
    """Return File handles for every tracked file id.

    Ids the server rejects with HTTP 403/404 (e.g. expired uploads) are dropped
    from *context_data* with a warning instead of aborting every future run.
    Any other error propagates.
    """
    file_objects = []
    kept_ids = []
    for file_id in context_data["file_ids"]:
        try:
            file_objects.append(client.files.get(name=file_id))
        except Exception as e:
            # NOTE(review): relies on the SDK's APIError exposing the HTTP status as `.code`.
            if getattr(e, "code", None) not in (403, 404):
                raise
            print(f"Warning: File '{file_id}' is no longer available ({e}). "
                  "Dropping it from the context; re-upload it with -f if needed.", file=sys.stderr)
            continue
        kept_ids.append(file_id)

    if len(kept_ids) != len(context_data["file_ids"]):
        context_data["file_ids"] = kept_ids
        _save_context(context_path, context_data)
    return file_objects


def _prepare_cache(client, model, context_data, file_objects, context_path):
    """Return the name of a usable context cache for *model*, or None.

    An existing cache is reused (and its TTL refreshed) only if it was built from
    exactly the currently tracked files; otherwise it is deleted and a new one is
    created. Returns None when there are no files to cache or the server rejects
    them as too small; callers must then send files and system instruction inline.
    """
    caches = context_data["caches"]
    cache_info = caches.get(model)

    if cache_info:
        cache_id = cache_info["cache_id"]
        if set(cache_info.get("cached_file_ids", [])) != set(context_data["file_ids"]):
            print(f"File list changed. Destroying stale {model} cache to rebuild...")
            try:
                client.caches.delete(name=cache_id)
            except Exception as e:
                print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
            # Forget the stale id so a failed rebuild cannot leave it behind.
            caches.pop(model, None)
            _save_context(context_path, context_data)
        else:
            print(f"Verifying existing cache for {model}: {cache_id}...")
            try:
                client.caches.get(name=cache_id)
                print("Cache verified. Extending TTL by 60 minutes...")
                client.caches.update(
                    name=cache_id,
                    config=types.UpdateCachedContentConfig(ttl=CACHE_TTL),
                )
                return cache_id
            except Exception as e:
                print(f"Cache expired or not found. Flagging for rebuild. ({e})")
                caches.pop(model, None)
                _save_context(context_path, context_data)

    if not file_objects:
        # A cache would only hold the system instruction; send it inline instead.
        return None

    print(f"Attempting to create Context Cache for {model}...")
    try:
        cache = client.caches.create(
            model=model,
            config=types.CreateCachedContentConfig(
                contents=file_objects,
                system_instruction=SYSTEM_INSTRUCTION,
                ttl=CACHE_TTL,
            ),
        )
    except Exception as e:
        message = str(e).lower()
        if "too small" in message or "1024" in message:
            print("Notice: Files are too small for server-side caching "
                  "(below the model's minimum token count). Falling back to standard processing.")
            return None
        raise

    caches[model] = {
        "cache_id": cache.name,
        "cached_file_ids": list(context_data["file_ids"]),
    }
    _save_context(context_path, context_data)
    print(f"Context Cache created: {cache.name}")
    return cache.name


def _build_contents(history, prompt_text, inline_files):
    """Build the API ``contents`` list from stored *history* plus the new prompt.

    References to *inline_files* (used when no cache holds them) are attached
    to the first user turn, which is the new prompt if there is no history yet.
    """
    file_parts = [{"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}} for f in inline_files]
    contents = []
    for msg in history:
        parts = []
        if file_parts and msg["role"] == "user":
            parts.extend(file_parts)
            file_parts = []
        parts.append({"text": msg["text"]})
        contents.append({"role": msg["role"], "parts": parts})
    contents.append({"role": "user", "parts": [*file_parts, {"text": prompt_text}]})
    return contents


def _consume_stream(response_stream, write):
    """Drain a streaming response, passing each text chunk to *write*.

    Returns ``(full_text, last_usage_metadata_or_None, finish_reason_name)``.
    """
    pieces = []
    usage_metadata = None
    finish_reason = "UNKNOWN"
    for chunk in response_stream:
        text = chunk.text
        if text:
            write(text)
            pieces.append(text)
        if chunk.usage_metadata:
            usage_metadata = chunk.usage_metadata
        if chunk.candidates and chunk.candidates[0].finish_reason:
            reason = chunk.candidates[0].finish_reason
            finish_reason = getattr(reason, "name", str(reason))
    return "".join(pieces), usage_metadata, finish_reason


def _rate(rates, key):
    """Return ``rates[key]`` as a float, treating missing or malformed values as 0.0."""
    try:
        return float(rates.get(key) or 0.0)
    except (TypeError, ValueError):
        return 0.0


def _print_summary(usage_metadata, finish_reason, model, pricing_data):
    """Print token usage and an estimated cost for one generation."""
    prompt_tokens = usage_metadata.prompt_token_count or 0
    output_tokens = usage_metadata.candidates_token_count or 0
    cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
    # Thinking tokens are billed at the output rate but are not included in
    # candidates_token_count, so they must be added separately.
    thought_tokens = getattr(usage_metadata, 'thoughts_token_count', 0) or 0
    # prompt_token_count includes the cached tokens; bill only the remainder at the input rate.
    uncached_tokens = max(0, prompt_tokens - cached_tokens)

    rates = pricing_data.get(model)
    if not isinstance(rates, dict):
        rates = _ZERO_RATES

    input_cost = (uncached_tokens / 1_000_000) * _rate(rates, "input")
    cached_cost = (cached_tokens / 1_000_000) * _rate(rates, "cached")
    output_cost = ((output_tokens + thought_tokens) / 1_000_000) * _rate(rates, "output")
    total_cost = input_cost + cached_cost + output_cost

    print("\n[--- Execution Summary ---]")
    print(f"Finish Reason: {finish_reason}")
    print(f"Token Usage: Input: {uncached_tokens:,} | Cached: {cached_tokens:,} | "
          f"Output: {output_tokens:,} | Thinking: {thought_tokens:,}")
    print(f"Est. Cost: ${total_cost:.6f} (Model: {model})")


def _generate(client, args, prompt_text, context_data, file_objects, active_cache_id, pricing_data):
    """Send the prompt with history, stream the reply, and record the exchange."""
    config_kwargs = {
        "max_output_tokens": 65536,
        "temperature": 0.0,
    }
    if active_cache_id:
        # The cache already contains the files and the system instruction.
        config_kwargs["cached_content"] = active_cache_id
        inline_files = []
    else:
        config_kwargs["system_instruction"] = SYSTEM_INSTRUCTION
        inline_files = file_objects

    contents = _build_contents(context_data.get("history", []), prompt_text, inline_files)
    config = types.GenerateContentConfig(**config_kwargs)

    print("Generating response (this may take a moment for large outputs)...\n")
    response_stream = client.models.generate_content_stream(
        model=args.model,
        contents=contents,
        config=config,
    )

    if args.output:
        # Explicit UTF-8: model output is arbitrary Unicode, which the platform
        # default encoding may not be able to represent.
        with open(args.output, "w", encoding="utf-8") as out:
            def write_to_file(text):
                out.write(text)
                out.flush()

            full_text, usage_metadata, finish_reason = _consume_stream(response_stream, write_to_file)
        print(f"Done! Raw output saved directly to {args.output}")
    else:
        full_text, usage_metadata, finish_reason = _consume_stream(
            response_stream, lambda text: print(text, end="", flush=True))
        print()

    if usage_metadata:
        _print_summary(usage_metadata, finish_reason, args.model, pricing_data)

    if not full_text:
        # Storing an empty model turn would send an empty text part on every later request.
        print(f"Warning: The model returned no text (finish reason: {finish_reason}); "
              "this turn was not added to the history.", file=sys.stderr)
        return

    context_data["history"].append({"role": "user", "text": prompt_text})
    context_data["history"].append({"role": "model", "text": full_text})
    _save_context(args.context, context_data)


def main():
    """Run the CLI: pricing, history/destroy handling, upload, caching and generation."""
    args = _build_arg_parser().parse_args()
    prompt_text = args.prompt or (" ".join(args.positional_prompt) if args.positional_prompt else None)

    if not os.environ.get("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    client = genai.Client()

    # ---------------------------------------------------------
    # PRICING CONFIGURATION
    # ---------------------------------------------------------
    # Pricing only feeds the cost summary of a generated response, so skip the web
    # fetch + model call for destroy/clear/upload-only runs unless it is forced.
    pricing_data = _load_pricing()
    if args.pricing or (prompt_text and not args.destroy):
        _refresh_pricing(client, args.model, args.pricing, pricing_data)

    # If the user only requested a pricing update and nothing else, exit cleanly.
    if args.pricing and not (prompt_text or args.files or args.destroy or args.clear_history):
        return

    # ---------------------------------------------------------
    # STATE MANAGEMENT
    # ---------------------------------------------------------
    context_data = _load_context(args.context)

    if args.clear_history:
        context_data["history"] = []
        _save_context(args.context, context_data)
        print("Conversation history cleared.")
        if not (prompt_text or args.files or args.destroy):
            return

    if args.destroy:
        print("Destroying server resources and local context...")
        _delete_remote_resources(client, context_data)
        if args.context and os.path.exists(args.context):
            os.remove(args.context)
            print(f"Deleted local context file: {args.context}")
        print("Cleanup complete.")
        return

    try:
        if args.files:
            _upload_files(client, args.files, context_data, args.context)

        file_objects = []
        active_cache_id = None
        # File handles and the cache only matter if something will use them:
        # a prompt in this run, or a persistent context for later runs.
        if prompt_text or args.context:
            file_objects = _resolve_file_objects(client, context_data, args.context)
            active_cache_id = _prepare_cache(client, args.model, context_data, file_objects, args.context)

        if prompt_text:
            _generate(client, args, prompt_text, context_data, file_objects, active_cache_id, pricing_data)
    finally:
        # Transient mode: nothing persists locally, so remove everything created remotely.
        if not args.context:
            print("\n[Transient Mode] Cleaning up resources...")
            _delete_remote_resources(client, context_data)
|
|
|
|
if __name__ == "__main__":
    # Start the CLI only when run as a script, not when the module is imported.
    main()
|