Files
monorepo/gemini/gemini.py
T
2026-06-03 15:10:05 -07:00

161 lines
6.6 KiB
Python

#!/usr/bin/env python3
import argparse
import datetime
import json
import os
import sys
import time

import google.generativeai as genai
from google.generativeai import caching
# A context cache lives this long after each create/update call; reusing the
# cache resets its expiry to CACHE_TTL from the time of the update.
CACHE_TTL_MINUTES = 60
CACHE_TTL = datetime.timedelta(minutes=CACHE_TTL_MINUTES)

# Seconds to wait between polls while an uploaded file is still PROCESSING.
FILE_POLL_SECONDS = 2

# System instruction for every request: baked into the cache when one is
# created, passed directly to the model otherwise.
SYSTEM_INSTRUCTION = (
    "You are a strict data extraction tool. Output exactly what is requested "
    "(e.g., CSV). Never use markdown formatting blocks (like ```csv). "
    "Never include conversational text, greetings, or explanations. Output raw data only."
)


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
    parser.add_argument("-c", "--context", type=str, default=".gemini_context.json",
                        help="Path to the context file (defaults to .gemini_context.json)")
    parser.add_argument("-f", "--files", nargs="+", default=[],
                        help="Files to upload to the Gemini API")
    parser.add_argument("-m", "--model", type=str, default="models/gemini-1.5-pro-001",
                        help="The model to use (default: models/gemini-1.5-pro-001)")
    parser.add_argument("-d", "--destroy", action="store_true",
                        help="Destroy the cloud files and cache, and delete the local context file")
    parser.add_argument("-o", "--output", type=str,
                        help="Direct the raw output to a specific file instead of stdout")
    parser.add_argument("-p", "--prompt", type=str,
                        help="The prompt to send to the AI")
    # NOTE: REMAINDER captures every argument after the first positional word,
    # option flags included ("hi -o out.txt" becomes the prompt "hi -o out.txt").
    # Likewise -f (nargs="+") swallows any following bare words as file names.
    # Put options before a positional prompt, and use -p together with -f.
    parser.add_argument("positional_prompt", nargs=argparse.REMAINDER,
                        help="Positional arguments treated as the prompt if -p is omitted")
    return parser.parse_args(argv)


def _resolve_prompt(args):
    """Return -p if given, else the positional words joined by spaces, else None."""
    if args.prompt:
        return args.prompt
    if args.positional_prompt:
        return " ".join(args.positional_prompt)
    return None


def _load_context(path):
    """Load the local context file, always returning a dict with a list ``file_ids``.

    Missing, unreadable, or malformed files yield a fresh
    ``{"file_ids": [], "cache_id": None}`` (with a warning for the latter two).
    """
    context_data = {"file_ids": [], "cache_id": None}
    if not os.path.exists(path):
        return context_data
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except OSError as e:
        print(f"Warning: Could not read {path} ({e}). Starting fresh.", file=sys.stderr)
        return context_data
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        print(f"Warning: Could not parse {path}. Starting fresh.", file=sys.stderr)
        return context_data
    if not isinstance(loaded, dict):
        print(f"Warning: Could not parse {path}. Starting fresh.", file=sys.stderr)
        return context_data
    # Keep whatever was stored, but guarantee the keys the script mutates.
    context_data.update(loaded)
    if not isinstance(context_data.get("file_ids"), list):
        context_data["file_ids"] = []
    return context_data


def _save_context(path, context_data):
    """Write *context_data* to *path* as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(context_data, f, indent=4)


def _destroy(context_data, context_path):
    """Best-effort delete of the remote cache and files, then the local context file."""
    print("Destroying server resources and local context...")
    cache_id = context_data.get("cache_id")
    if cache_id:
        try:
            caching.CachedContent.get(cache_id).delete()
            print(f"Deleted cache: {cache_id}")
        except Exception as e:
            print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
    for file_id in context_data.get("file_ids", []):
        try:
            genai.delete_file(file_id)
            print(f"Deleted file: {file_id}")
        except Exception as e:
            print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
    if os.path.exists(context_path):
        os.remove(context_path)
        print(f"Deleted local context file: {context_path}")
    print("Cleanup complete.")


def _upload_files(file_paths, context_data, context_path):
    """Upload each existing path and record its remote name.

    Returns the number of new file ids added to ``context_data["file_ids"]``.
    """
    added = 0
    for file_path in file_paths:
        if not os.path.exists(file_path):
            print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
            continue
        print(f"Uploading '{file_path}'...")
        uploaded_file = genai.upload_file(file_path)
        print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
        if uploaded_file.name not in context_data["file_ids"]:
            context_data["file_ids"].append(uploaded_file.name)
            added += 1
        # Persist after every upload so a later failure doesn't lose track of
        # files that already exist on the server.
        _save_context(context_path, context_data)
    return added


def _discard_cache(context_data, context_path):
    """Delete the recorded cache (best effort) and forget its id."""
    cache_id = context_data.get("cache_id")
    if not cache_id:
        return
    try:
        caching.CachedContent.get(cache_id).delete()
        print(f"Deleted stale cache: {cache_id}")
    except Exception as e:
        print(f"Warning: Failed to delete stale cache '{cache_id}'. {e}", file=sys.stderr)
    context_data["cache_id"] = None
    _save_context(context_path, context_data)


def _wait_until_active(file_obj):
    """Poll *file_obj* until it leaves PROCESSING; raise RuntimeError if it FAILED."""
    while file_obj.state.name == "PROCESSING":
        time.sleep(FILE_POLL_SECONDS)
        file_obj = genai.get_file(file_obj.name)
    if file_obj.state.name == "FAILED":
        raise RuntimeError(f"File '{file_obj.name}' failed processing on the server.")
    return file_obj


def _load_existing_cache(context_data, context_path):
    """Fetch the recorded cache and reset its TTL; return None if it can't be reused."""
    cache_id = context_data.get("cache_id")
    if not cache_id:
        return None
    print(f"Loading existing cache: {cache_id}")
    try:
        cache = caching.CachedContent.get(cache_id)
        print(f"Extending cache TTL by {CACHE_TTL_MINUTES} minutes...")
        cache.update(ttl=CACHE_TTL)
    except Exception as e:
        # Caches expire CACHE_TTL after their last update, so a missing cache
        # is routine after an idle period: forget it and let the caller rebuild.
        print(f"Warning: Could not reuse cache '{cache_id}' (it may have expired); "
              f"recreating it. {e}", file=sys.stderr)
        context_data["cache_id"] = None
        _save_context(context_path, context_data)
        return None
    return cache


def _prepare_context(context_data, context_path, model_name):
    """Return ``(cache, inline_files)`` covering the recorded uploaded files.

    Reuses the recorded cache when possible, otherwise creates one. If creation
    fails, ``cache`` is None and ``inline_files`` holds the file objects so they
    can be sent with the prompt. With no recorded files, returns ``(None, [])``.
    """
    if not context_data["file_ids"]:
        return None, []
    cache = _load_existing_cache(context_data, context_path)
    if cache is not None:
        return cache, []
    print("Creating Context Cache on Google's servers...")
    # NOTE: per the Files API docs, uploads are only retained for a limited
    # time (48 hours); once gone, get_file raises and `-d` plus a re-upload
    # is needed.
    file_objects = [_wait_until_active(genai.get_file(f_id))
                    for f_id in context_data["file_ids"]]
    try:
        cache = caching.CachedContent.create(
            model=model_name,
            system_instruction=SYSTEM_INSTRUCTION,
            contents=file_objects,
            ttl=CACHE_TTL,
        )
    except Exception as e:
        # Per the Gemini caching docs, creation is rejected e.g. when the
        # content is below the model's minimum cacheable token count; sending
        # the files inline still lets the prompt see them.
        print(f"Warning: Could not create context cache; sending files with the prompt "
              f"instead. {e}", file=sys.stderr)
        return None, file_objects
    context_data["cache_id"] = cache.name
    _save_context(context_path, context_data)
    print(f"Context Cache created: {cache.name}")
    return cache, []


def _generate(prompt_text, model_name, cache, inline_files, output_path):
    """Send *prompt_text* to the model and write the text to *output_path* or stdout.

    Exits with status 1 if the response carries no text (e.g. it was blocked).
    """
    generation_config = {
        # NOTE(review): many models cap output far below this (gemini-1.5-pro
        # documents 8192) — confirm the API accepts it for the chosen model.
        "max_output_tokens": 65536,
        "temperature": 0.0,
    }
    if cache:
        # The cache pins the model and system instruction it was created with,
        # so -m is not consulted on this path.
        model = genai.GenerativeModel.from_cached_content(
            cached_content=cache,
            generation_config=generation_config,
        )
    else:
        model = genai.GenerativeModel(
            model_name=model_name,
            system_instruction=SYSTEM_INSTRUCTION,
            generation_config=generation_config,
        )
    print("Generating response (this may take a moment for large outputs)...")
    contents = [*inline_files, prompt_text] if inline_files else prompt_text
    response = model.generate_content(contents)
    try:
        text = response.text
    except ValueError as e:
        # The .text accessor raises when no candidate contains a text part,
        # e.g. when the prompt or the output was blocked.
        print(f"Error: The model returned no text. {e}", file=sys.stderr)
        sys.exit(1)
    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(text)
        print(f"Done! Raw output saved directly to {output_path}")
    else:
        print("-" * 40)
        print(text)
        print("-" * 40)


def main():
    """Run the Gemini CLI.

    Parses arguments, authenticates from ``GEMINI_API_KEY`` and loads the local
    context file. With ``-d`` it deletes the cache, the uploaded files and the
    context file, then returns. Otherwise it uploads any ``-f`` files, reuses
    or creates the context cache, and answers the prompt if one was given.

    Exits with status 1 when ``GEMINI_API_KEY`` is unset or the model returns
    no text.
    """
    args = _parse_args()
    prompt_text = _resolve_prompt(args)

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    genai.configure(api_key=api_key)

    context_data = _load_context(args.context)

    if args.destroy:
        _destroy(context_data, args.context)
        return  # Exit after destroying

    if _upload_files(args.files, context_data, args.context):
        # A cache's contents are fixed at creation (only its TTL can be
        # updated), so a cache built before these uploads would silently omit
        # them. Drop it so it is rebuilt with every recorded file.
        _discard_cache(context_data, args.context)

    cache, inline_files = _prepare_context(context_data, args.context, args.model)

    if prompt_text:
        _generate(prompt_text, args.model, cache, inline_files, args.output)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()