#!/usr/bin/env python3
"""Gemini API CLI with file uploads, context caching, history and cost estimates.

State (uploaded file ids, per-model cache ids, conversation history) is kept in
an optional JSON context file; without one the script runs in transient mode
and deletes the remote resources it knows about before exiting.
"""
import argparse
import json
import os
import sys
import urllib.request

from google import genai
from google.genai import types

PRICING_FILE = ".gemini_pricing.json"
PRICING_URL = "https://cloud.google.com/gemini-enterprise-agent-platform/generative-ai/pricing"
# Model used only to turn the pricing page HTML into JSON.
PRICING_PARSER_MODEL = "gemini-3.1-flash-lite"
# TTL applied when a context cache is created or refreshed (60 minutes).
CACHE_TTL = "3600s"
_RATE_KEYS = ("input", "cached", "output")


def _zero_rates():
    """Return a fresh all-zero rate table (a new dict on every call)."""
    return {key: 0.0 for key in _RATE_KEYS}


def _normalize_rates(entry):
    """Coerce a pricing entry to ``{"input", "cached", "output"}`` floats.

    Missing or null values become 0.0. Returns None when *entry* is not a
    dict or one of its values cannot be converted to a float.
    """
    if not isinstance(entry, dict):
        return None
    rates = {}
    for key in _RATE_KEYS:
        try:
            rates[key] = float(entry.get(key) or 0.0)
        except (TypeError, ValueError):
            return None
    return rates


def fetch_pricing_for_model(client, target_model, timeout=30.0):
    """Download the pricing page and have a model extract *target_model*'s rates.

    Args:
        client: google-genai client used for the extraction request.
        target_model: Model name whose per-million-token prices are wanted.
        timeout: Seconds to wait for the pricing page before giving up.

    Returns:
        ``{target_model: {"input": float, "cached": float, "output": float}}``
        on success, or None when the page cannot be fetched, the model call
        fails, or the response does not hold usable numbers for the model.
        Every failure is reported on stderr; nothing is raised.
    """
    url = PRICING_URL
    print(f"Fetching live pricing for {target_model} from the web...")

    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        # A timeout keeps an unresponsive server from hanging the CLI forever.
        with urllib.request.urlopen(req, timeout=timeout) as response:
            # A stray invalid byte should not make the whole page unusable.
            html = response.read().decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Warning: Failed to fetch HTML from {url}: {e}", file=sys.stderr)
        return None

    prompt = f"""
    Extract the API pricing for the model '{target_model}' from the following HTML text.
    Find the cost per 1 million input tokens, 1 million cached content tokens, and 1 million output tokens.

    Return ONLY a valid JSON object with this exact structure:
    {{
        "{target_model}": {{
            "input": 0.00,
            "cached": 0.00,
            "output": 0.00
        }}
    }}
    If a value is not found, use 0.00.

    HTML DATA:
    {html}
    """

    config = types.GenerateContentConfig(
        response_mime_type="application/json",
        temperature=0.0
    )

    print("Parsing pricing data via background AI session...")
    try:
        # We use a fast, cheap model just for the parsing task
        res = client.models.generate_content(
            model=PRICING_PARSER_MODEL,
            contents=prompt,
            config=config
        )
        new_data = json.loads(res.text)
    except Exception as e:
        print(f"Warning: Failed to extract pricing using AI: {e}", file=sys.stderr)
        return None

    # Strip the array wrapper if the AI returned a list instead of a pure dict
    if isinstance(new_data, list) and new_data:
        new_data = new_data[0]

    # Validate here so the cost calculation never sees missing or non-numeric rates.
    rates = None
    if isinstance(new_data, dict):
        rates = _normalize_rates(new_data.get(target_model))
    if rates is None:
        print(
            f"Warning: Failed to extract pricing using AI: "
            f"no usable pricing entry for {target_model} in the response",
            file=sys.stderr,
        )
        return None

    result = {target_model: rates}
    print(f"Successfully retrieved pricing for {target_model}.")
    print(result)
    return result


def _build_parser():
    """Create the command-line argument parser."""
    parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
    parser.add_argument("-c", "--context", type=str, default=None,
                        help="Path to context file. If omitted, transient mode is used.")
    parser.add_argument("-f", "--files", nargs="+", default=[],
                        help="Files to upload to the Gemini API")
    parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
                        help="The model to use (default: gemini-3.1-flash-lite)")
    parser.add_argument("-d", "--destroy", action="store_true",
                        help="Destroy cloud files/cache, and delete local context")
    parser.add_argument("-x", "--clear-history", action="store_true",
                        help="Clear the conversation history without destroying files/caches")
    parser.add_argument("--pricing", action="store_true",
                        help="Force update the pricing info for the specified model from the web")
    parser.add_argument("-o", "--output", type=str,
                        help="Direct the raw output to a specific file instead of stdout")
    parser.add_argument("-p", "--prompt", type=str, help="The prompt to send to the AI")
    parser.add_argument("positional_prompt", nargs=argparse.REMAINDER,
                        help="Positional arguments treated as the prompt if -p is omitted")
    return parser


def _load_pricing(path):
    """Load the pricing table from *path*; return {} if it is absent or unusable."""
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        data = None
    if not isinstance(data, dict):
        print(f"Warning: {path} is corrupted. Starting fresh.", file=sys.stderr)
        return {}
    return data


def _load_context(path):
    """Load saved state from *path*, falling back to an empty state.

    The result always has the keys "file_ids", "caches" and "history".
    """
    context_data = {"file_ids": [], "caches": {}, "history": []}
    if not path or not os.path.exists(path):
        return context_data
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded_data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        loaded_data = None
    if not isinstance(loaded_data, dict):
        print(f"Warning: Could not parse {path}. Starting fresh.", file=sys.stderr)
        return context_data
    context_data["file_ids"] = loaded_data.get("file_ids", [])
    context_data["caches"] = loaded_data.get("caches", {})
    context_data["history"] = loaded_data.get("history", [])
    return context_data


def _save_context(path, context_data):
    """Write *context_data* to *path* as JSON; do nothing in transient mode."""
    if not path:
        return
    with open(path, "w", encoding="utf-8") as f:
        json.dump(context_data, f, indent=4)


def _delete_remote_resources(client, context_data, name_model_on_failure=True):
    """Best-effort deletion of every cache and file recorded in *context_data*.

    Failures are reported on stderr and do not stop the remaining deletions.
    *name_model_on_failure* selects whether the cache failure warning names
    the model (the destroy path does, the transient cleanup path does not).
    """
    for model_name, cache_info in context_data.get("caches", {}).items():
        try:
            client.caches.delete(name=cache_info["cache_id"])
            print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
        except Exception as e:
            if name_model_on_failure:
                print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)
            else:
                print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)

    for file_id in context_data.get("file_ids", []):
        try:
            client.files.delete(name=file_id)
            print(f"Deleted file: {file_id}")
        except Exception as e:
            print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)


def _upload_files(client, file_paths, context_data, context_path):
    """Upload each existing path and record its remote name in *context_data*."""
    for file_path in file_paths:
        if not os.path.exists(file_path):
            print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
            continue

        print(f"Uploading '{file_path}'...")
        uploaded_file = client.files.upload(file=file_path)
        print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")

        if uploaded_file.name not in context_data["file_ids"]:
            context_data["file_ids"].append(uploaded_file.name)
        # Persist after every upload so a later failure cannot orphan the
        # files that were already uploaded.
        _save_context(context_path, context_data)


def _resolve_cache(client, model, context_data, context_path, file_objects, system_instruction):
    """Return a usable context cache id for *model*, or None if caching is unavailable.

    An existing cache is reused (and its TTL refreshed) when it covers exactly
    the current file list and can still be fetched; otherwise a new cache is
    created. None is returned when the server rejects the content as too small
    to cache. Any other creation error propagates to the caller.
    """
    model_cache_info = context_data["caches"].get(model)
    if model_cache_info:
        cache_id = model_cache_info.get("cache_id")
        cached_files_set = set(model_cache_info.get("cached_file_ids", []))

        # Check 1: Have the files changed?
        if set(context_data["file_ids"]) != cached_files_set:
            print(f"File list changed. Destroying stale {model} cache to rebuild...")
            try:
                client.caches.delete(name=cache_id)
            except Exception as e:
                print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
        else:
            # Check 2: Does the cache still exist on the server?
            print(f"Verifying existing cache for {model}: {cache_id}...")
            try:
                client.caches.get(name=cache_id)
                print("Cache verified. Extending TTL by 60 minutes...")
                client.caches.update(
                    name=cache_id,
                    config=types.UpdateCachedContentConfig(ttl=CACHE_TTL)
                )
                return cache_id
            except Exception as e:
                print(f"Cache expired or not found. Flagging for rebuild. ({e})")

        # Forget the invalidated entry so a "too small" rebuild does not leave
        # a dead cache id to be re-checked on every later run.
        context_data["caches"].pop(model, None)

    print(f"Attempting to create Context Cache for {model}...")
    try:
        cache = client.caches.create(
            model=model,
            config=types.CreateCachedContentConfig(
                contents=file_objects,
                system_instruction=system_instruction,
                ttl=CACHE_TTL
            )
        )
    except Exception as e:
        message = str(e)
        if "too small" in message.lower() or "1024" in message:
            print("Notice: Files are too small for server-side caching (under 1024 tokens). "
                  "Falling back to standard processing.")
            return None
        raise

    context_data["caches"][model] = {
        "cache_id": cache.name,
        "cached_file_ids": list(context_data["file_ids"])
    }
    _save_context(context_path, context_data)
    print(f"Context Cache created: {cache.name}")
    return cache.name


def _build_contents(history, prompt_text, file_objects, attach_files):
    """Build the chat payload from saved history plus the current prompt.

    When *attach_files* is true the files are attached once, to the first
    user message in the payload (the current prompt if the history holds no
    user message).
    """
    pending_file_parts = []
    if attach_files:
        pending_file_parts = [
            {"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}}
            for f in file_objects
        ]

    api_contents = []
    for msg in history:
        parts = []
        if msg["role"] == "user" and pending_file_parts:
            parts.extend(pending_file_parts)
            pending_file_parts = []
        parts.append({"text": msg["text"]})
        api_contents.append({"role": msg["role"], "parts": parts})

    # Add the current prompt (carrying the files if nothing has taken them yet)
    api_contents.append({"role": "user", "parts": pending_file_parts + [{"text": prompt_text}]})
    return api_contents


def _consume_stream(response_stream, emit):
    """Drain a streaming response, passing each text chunk to *emit*.

    Returns ``(full_text, usage_metadata, finish_reason_name)`` where the last
    two hold the most recent values seen in the stream (None / "UNKNOWN" if
    the stream never provided them).
    """
    pieces = []
    usage_metadata = None
    finish_reason_str = "UNKNOWN"
    for chunk in response_stream:
        text = chunk.text
        if text:
            emit(text)
            pieces.append(text)
        if chunk.usage_metadata:
            usage_metadata = chunk.usage_metadata
        if chunk.candidates and chunk.candidates[0].finish_reason:
            finish_reason_str = chunk.candidates[0].finish_reason.name
    return "".join(pieces), usage_metadata, finish_reason_str


def _print_usage_summary(usage_metadata, rates, finish_reason_str, model):
    """Print token counts and the estimated cost for one generation call."""
    prompt_tokens = usage_metadata.prompt_token_count or 0
    output_tokens = usage_metadata.candidates_token_count or 0
    cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
    uncached_tokens = max(0, prompt_tokens - cached_tokens)

    input_cost = (uncached_tokens / 1_000_000) * rates["input"]
    cached_cost = (cached_tokens / 1_000_000) * rates["cached"]
    output_cost = (output_tokens / 1_000_000) * rates["output"]
    total_cost = input_cost + cached_cost + output_cost

    print("\n[--- Execution Summary ---]")
    print(f"Finish Reason: {finish_reason_str}")
    print(f"Token Usage: Input: {uncached_tokens:,} | Cached: {cached_tokens:,} | Output: {output_tokens:,}")
    print(f"Est. Cost: ${total_cost:.6f} (Model: {model})")


def main():
    """Run the CLI: parse arguments, manage pricing/state, upload, cache, generate.

    Exits with status 1 when GEMINI_API_KEY is not set. In transient mode
    (no --context) the remote caches and files recorded during the run are
    deleted before returning, even if an error occurs.
    """
    args = _build_parser().parse_args()

    if args.prompt:
        prompt_text = args.prompt
    elif args.positional_prompt:
        prompt_text = " ".join(args.positional_prompt)
    else:
        prompt_text = None

    if not os.environ.get("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    client = genai.Client()

    # ---------------------------------------------------------
    # PRICING CONFIGURATION
    # ---------------------------------------------------------
    pricing_data = _load_pricing(PRICING_FILE)

    # Rates are only needed for the cost summary after a generation, so do not
    # pay for a web fetch plus a model call on destroy/clear/upload-only runs.
    will_generate = bool(prompt_text) and not args.destroy
    if args.pricing or (will_generate and args.model not in pricing_data):
        new_pricing = fetch_pricing_for_model(client, args.model)
        if new_pricing and args.model in new_pricing:
            pricing_data.update(new_pricing)
            with open(PRICING_FILE, "w", encoding="utf-8") as f:
                json.dump(pricing_data, f, indent=4)
        else:
            print(f"Warning: Could not fetch pricing for {args.model}. Estimating cost at $0.00.",
                  file=sys.stderr)
            # Add a fallback zero-value so the script doesn't crash during cost calculation
            pricing_data[args.model] = _zero_rates()

    # If the user only requested a pricing update and nothing else, exit cleanly
    if args.pricing and not prompt_text and not args.files and not args.destroy and not args.clear_history:
        return

    # ---------------------------------------------------------
    # STATE MANAGEMENT
    # ---------------------------------------------------------
    context_data = _load_context(args.context)

    # ---------------------------------------------------------
    # CLEAR HISTORY FLAG
    # ---------------------------------------------------------
    if args.clear_history:
        context_data["history"] = []
        _save_context(args.context, context_data)
        print("Conversation history cleared.")
        if not prompt_text and not args.files and not args.destroy:
            return

    # ---------------------------------------------------------
    # DESTROY FLAG LOGIC
    # ---------------------------------------------------------
    if args.destroy:
        print("Destroying server resources and local context...")
        _delete_remote_resources(client, context_data, name_model_on_failure=True)

        if args.context and os.path.exists(args.context):
            os.remove(args.context)
            print(f"Deleted local context file: {args.context}")

        print("Cleanup complete.")
        return

    try:
        # ---------------------------------------------------------
        # UPLOAD LOGIC
        # ---------------------------------------------------------
        if args.files:
            _upload_files(client, args.files, context_data, args.context)

        # ---------------------------------------------------------
        # CACHE VALIDATION & CREATION LOGIC
        # ---------------------------------------------------------
        system_instruction = (
            "You are a hybrid data extraction tool. If a specific format or file format output is requested "
            "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). "
            "If a specific format was requested, never include conversational text, greetings, or explanations; "
            "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
        )

        file_objects = []
        active_cache_id = None

        if context_data.get("file_ids"):
            file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
            active_cache_id = _resolve_cache(
                client, args.model, context_data, args.context, file_objects, system_instruction
            )

        # ---------------------------------------------------------
        # GENERATION LOGIC (WITH HISTORY)
        # ---------------------------------------------------------
        if prompt_text:
            config_kwargs = {
                "max_output_tokens": 65536,
                "temperature": 0.0
            }

            # The cache already carries the system instruction and the files;
            # without a cache both have to be sent with the request.
            if active_cache_id:
                config_kwargs["cached_content"] = active_cache_id
            else:
                config_kwargs["system_instruction"] = system_instruction

            api_contents = _build_contents(
                context_data.get("history", []),
                prompt_text,
                file_objects,
                attach_files=not active_cache_id,
            )

            config = types.GenerateContentConfig(**config_kwargs)

            print("Generating response (this may take a moment for large outputs)...\n")

            response_stream = client.models.generate_content_stream(
                model=args.model,
                contents=api_contents,
                config=config
            )

            if args.output:
                # Explicit UTF-8 so non-ASCII model output cannot fail on a
                # platform whose default encoding is narrower.
                with open(args.output, "w", encoding="utf-8") as out_file:
                    def write_chunk(text):
                        out_file.write(text)
                        out_file.flush()

                    full_response_text, usage_metadata, finish_reason_str = _consume_stream(
                        response_stream, write_chunk
                    )
                print(f"Done! Raw output saved directly to {args.output}")
            else:
                full_response_text, usage_metadata, finish_reason_str = _consume_stream(
                    response_stream, lambda text: print(text, end="", flush=True)
                )
                print()

            # ---------------------------------------------------------
            # USAGE AND COST CALCULATION
            # ---------------------------------------------------------
            if usage_metadata:
                # Stored entries may be malformed (e.g. written by an earlier
                # run); fall back to zero rates instead of crashing.
                rates = _normalize_rates(pricing_data.get(args.model)) or _zero_rates()
                _print_usage_summary(usage_metadata, rates, finish_reason_str, args.model)

            context_data["history"].append({"role": "user", "text": prompt_text})
            context_data["history"].append({"role": "model", "text": full_response_text})

            _save_context(args.context, context_data)

    finally:
        # ---------------------------------------------------------
        # TRANSIENT MODE CLEANUP
        # ---------------------------------------------------------
        # The destroy path returned above, so only the context flag matters here.
        if not args.context:
            print("\n[Transient Mode] Cleaning up resources...")
            _delete_remote_resources(client, context_data, name_model_on_failure=False)


if __name__ == "__main__":
    main()