Files
monorepo/gemini/gemini.py
T

280 lines
13 KiB
Python

#!/usr/bin/env python3
import argparse
import json
import os
import sys
from google import genai
from google.genai import types
# System prompt baked into every context cache and sent directly on uncached runs.
_SYSTEM_INSTRUCTION = (
    "You are a hybrid data extraction tool. If a specific format or file format output is requested "
    "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). "
    "If a specific format was requested, never include conversational text, greetings, or explanations; "
    "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
)

# Lifetime requested for a context cache, both on creation and on each refresh.
_CACHE_TTL = "3600s"


def _build_parser():
    """Return the argument parser for this CLI."""
    parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
    parser.add_argument("-c", "--context", type=str, default=None,
                        help="Path to context file. If omitted, transient mode is used.")
    parser.add_argument("-f", "--files", nargs="+", default=[],
                        help="Files to upload to the Gemini API")
    parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
                        help="The model to use (default: gemini-3.1-flash-lite)")
    parser.add_argument("-d", "--destroy", action="store_true",
                        help="Destroy cloud files/cache, and delete local context")
    parser.add_argument("-x", "--clear-history", action="store_true",
                        help="Clear the conversation history without destroying files/caches")
    parser.add_argument("-o", "--output", type=str,
                        help="Direct the raw output to a specific file instead of stdout")
    parser.add_argument("-p", "--prompt", type=str,
                        help="The prompt to send to the AI")
    parser.add_argument("positional_prompt", nargs=argparse.REMAINDER,
                        help="Positional arguments treated as the prompt if -p is omitted")
    return parser


def _load_context(path):
    """Load the persisted context from *path*, or return a fresh one.

    The result always has list ``file_ids``, dict ``caches`` and list
    ``history``. A missing path, unparseable JSON, or a JSON document that is
    not an object all yield a fresh context; a key whose stored value has the
    wrong type falls back to its empty default.
    """
    context_data = {"file_ids": [], "caches": {}, "history": []}
    if not path or not os.path.exists(path):
        return context_data
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded_data = json.load(f)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        loaded_data = None
    if not isinstance(loaded_data, dict):
        print(f"Warning: Could not parse {path}. Starting fresh.", file=sys.stderr)
        return context_data
    for key in ("file_ids", "caches", "history"):
        value = loaded_data.get(key)
        if isinstance(value, type(context_data[key])):
            context_data[key] = value
    return context_data


def _save_context(path, context_data):
    """Persist *context_data* as JSON to *path*; a no-op in transient mode (path is None).

    Writes to a sibling temp file and swaps it in with ``os.replace`` so an
    interrupted write cannot leave a truncated context file behind (which would
    lose track of uploaded files and caches that still cost money).
    """
    if not path:
        return
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(context_data, f, indent=4)
    os.replace(tmp_path, path)


def _is_missing_resource(exc):
    """Return True if *exc* looks like the API reporting an expired/unknown resource."""
    # google-genai API errors carry the HTTP status in `.code`; the Gemini API
    # answers 403 PERMISSION_DENIED (as well as 404) for files and caches that
    # no longer exist.
    return getattr(exc, "code", None) in (403, 404)


def _delete_remote_resources(client, context_data, *, name_model_on_failure=True):
    """Best-effort delete of every cache and uploaded file recorded in *context_data*.

    Failures are reported on stderr and do not stop the remaining deletions.
    *name_model_on_failure* controls whether the cache-failure warning names
    the model.
    """
    for model_name, cache_info in context_data.get("caches", {}).items():
        try:
            client.caches.delete(name=cache_info["cache_id"])
            print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
        except Exception as e:
            target = f" for {model_name}" if name_model_on_failure else ""
            print(f"Warning: Failed to delete cache{target}. {e}", file=sys.stderr)
    for file_id in context_data.get("file_ids", []):
        try:
            client.files.delete(name=file_id)
            print(f"Deleted file: {file_id}")
        except Exception as e:
            print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)


def _upload_files(client, paths, context_data, context_path):
    """Upload each existing local path and record its remote id in *context_data*."""
    for file_path in paths:
        if not os.path.exists(file_path):
            print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
            continue
        print(f"Uploading '{file_path}'...")
        uploaded_file = client.files.upload(file=file_path)
        print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
        if uploaded_file.name not in context_data["file_ids"]:
            context_data["file_ids"].append(uploaded_file.name)
        # Persist after every upload so a failure later in the batch does not
        # leave already-uploaded files untracked.
        _save_context(context_path, context_data)


def _resolve_files(client, context_data, context_path):
    """Fetch File handles for the tracked ids, forgetting ids the server no longer has.

    Uploaded files expire server-side; a stale id would otherwise make every
    later run crash. Any error other than "missing resource" is re-raised.
    """
    file_objects = []
    live_ids = []
    for file_id in context_data["file_ids"]:
        try:
            file_object = client.files.get(name=file_id)
        except Exception as e:
            if not _is_missing_resource(e):
                raise
            print(f"Warning: File '{file_id}' is no longer available on the server. "
                  "Dropping it from the context.", file=sys.stderr)
            continue
        file_objects.append(file_object)
        live_ids.append(file_id)
    if len(live_ids) != len(context_data["file_ids"]):
        context_data["file_ids"] = live_ids
        _save_context(context_path, context_data)
    return file_objects


def _prepare_cache(client, model, context_data, file_objects, context_path):
    """Return a usable context-cache id for *model*, or None to run uncached.

    Reuses the recorded cache when it covers exactly the current file set
    (refreshing its TTL). Otherwise, or when the server reports the cache as
    gone, the old entry is dropped and a new cache is created. Returns None when
    the API refuses to cache because the content is below its token minimum;
    any other creation error is re-raised.
    """
    cache_info = context_data["caches"].get(model)
    if cache_info:
        cache_id = cache_info["cache_id"]
        if set(cache_info.get("cached_file_ids", [])) == set(context_data["file_ids"]):
            print(f"Loading existing cache for {model}: {cache_id}")
            print("Extending cache TTL by 60 minutes...")
            try:
                client.caches.update(
                    name=cache_id,
                    config=types.UpdateCachedContentConfig(ttl=_CACHE_TTL),
                )
                return cache_id
            except Exception as e:
                print(f"Warning: Failed to update cache TTL. {e}", file=sys.stderr)
                if not _is_missing_resource(e):
                    # The cache may still be alive (e.g. a transient error);
                    # a failed TTL refresh alone is not fatal.
                    return cache_id
            # The cache expired server-side; generating against it would fail.
            print(f"Cache {cache_id} no longer exists. Rebuilding...", file=sys.stderr)
        else:
            print(f"File list changed. Destroying stale {model} cache to rebuild...")
            try:
                client.caches.delete(name=cache_id)
            except Exception as e:
                print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
        # Forget the dead cache before rebuilding so neither the context file
        # nor transient cleanup keeps pointing at it.
        del context_data["caches"][model]
        _save_context(context_path, context_data)

    print(f"Attempting to create Context Cache for {model}...")
    try:
        cache = client.caches.create(
            model=model,
            config=types.CreateCachedContentConfig(
                contents=file_objects,
                system_instruction=_SYSTEM_INSTRUCTION,
                ttl=_CACHE_TTL,
            ),
        )
    except Exception as e:
        message = str(e)
        if "too small" in message.lower() or "1024" in message:
            print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.")
            return None
        raise
    context_data["caches"][model] = {
        "cache_id": cache.name,
        "cached_file_ids": list(context_data["file_ids"]),
    }
    _save_context(context_path, context_data)
    print(f"Context Cache created: {cache.name}")
    return cache.name


def _build_contents(history, prompt_text, file_objects, attach_files):
    """Build the `contents` payload from saved history plus the new prompt.

    When *attach_files* is true (no cache in use), the file references are
    placed on the first user turn, whether from history or the new prompt, so
    the model sees them exactly once.
    """
    pending_file_parts = [
        {"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}}
        for f in file_objects
    ] if attach_files else []
    contents = []
    for msg in history:
        parts = []
        if pending_file_parts and msg["role"] == "user":
            parts.extend(pending_file_parts)
            pending_file_parts = []
        parts.append({"text": msg["text"]})
        contents.append({"role": msg["role"], "parts": parts})
    contents.append({"role": "user", "parts": [*pending_file_parts, {"text": prompt_text}]})
    return contents


def _stream_response(response_stream, output_path):
    """Stream response text to *output_path* (UTF-8) or stdout and return the full text."""
    pieces = []
    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            for chunk in response_stream:
                text = chunk.text
                if text:
                    f.write(text)
                    f.flush()
                    pieces.append(text)
        print(f"\nDone! Raw output saved directly to {output_path}")
    else:
        print("-" * 40)
        for chunk in response_stream:
            text = chunk.text
            if text:
                print(text, end="", flush=True)
                pieces.append(text)
        print("\n" + "-" * 40)
    return "".join(pieces)


def _generate(client, args, prompt_text, context_data, file_objects, active_cache_id):
    """Send history plus *prompt_text* to the model, stream the reply, and record the turn."""
    config_kwargs = {
        "max_output_tokens": 65536,
        "temperature": 0.0,
    }
    if active_cache_id:
        # The cache was created with the files and the system instruction.
        config_kwargs["cached_content"] = active_cache_id
    else:
        config_kwargs["system_instruction"] = _SYSTEM_INSTRUCTION
    api_contents = _build_contents(
        context_data["history"], prompt_text, file_objects, attach_files=not active_cache_id
    )
    config = types.GenerateContentConfig(**config_kwargs)
    print("Generating response (this may take a moment for large outputs)...")
    response_stream = client.models.generate_content_stream(
        model=args.model,
        contents=api_contents,
        config=config,
    )
    full_response_text = _stream_response(response_stream, args.output)
    context_data["history"].append({"role": "user", "text": prompt_text})
    context_data["history"].append({"role": "model", "text": full_response_text})
    _save_context(args.context, context_data)


def main():
    """Run the Gemini CLI: manage uploads, context caches and history, then answer one prompt.

    With ``-c`` the uploaded file ids, per-model cache ids and conversation
    history persist in a JSON context file. Without it (transient mode)
    everything created on the server during this run is deleted on exit.
    Exits with status 1 if ``GEMINI_API_KEY`` is not set.
    """
    args = _build_parser().parse_args()
    if args.prompt:
        prompt_text = args.prompt
    elif args.positional_prompt:
        prompt_text = " ".join(args.positional_prompt)
    else:
        prompt_text = None

    if not os.environ.get("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    client = genai.Client()
    context_data = _load_context(args.context)

    if args.clear_history:
        context_data["history"] = []
        _save_context(args.context, context_data)
        print("Conversation history cleared.")
        if not prompt_text and not args.files and not args.destroy:
            return

    if args.destroy:
        print("Destroying server resources and local context...")
        _delete_remote_resources(client, context_data)
        if args.context and os.path.exists(args.context):
            os.remove(args.context)
            print(f"Deleted local context file: {args.context}")
        print("Cleanup complete.")
        return

    try:
        _upload_files(client, args.files, context_data, args.context)
        file_objects = (
            _resolve_files(client, context_data, args.context)
            if context_data["file_ids"] else []
        )
        active_cache_id = (
            _prepare_cache(client, args.model, context_data, file_objects, args.context)
            if file_objects else None
        )
        if prompt_text:
            _generate(client, args, prompt_text, context_data, file_objects, active_cache_id)
    finally:
        # Transient mode: nothing is persisted, so server-side resources created
        # during this run must not outlive it (-d already returned above).
        if not args.context:
            print("\n[Transient Mode] Cleaning up resources...")
            _delete_remote_resources(client, context_data, name_model_on_failure=False)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()