Working on the gemini CLI

This commit is contained in:
2026-06-03 15:10:24 -07:00
parent d2209fd209
commit 8d9fe389c6
+131 -119
View File
@@ -8,8 +8,8 @@ from google.genai import types
def main(): def main():
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching") parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
parser.add_argument("-c", "--context", type=str, default=".gemini_context.json", parser.add_argument("-c", "--context", type=str, default=None,
help="Path to the context file (defaults to .gemini_context.json)") help="Path to context file. If omitted, files/caches are deleted after execution.")
parser.add_argument("-f", "--files", nargs="+", default=[], parser.add_argument("-f", "--files", nargs="+", default=[],
help="Files to upload to the Gemini API") help="Files to upload to the Gemini API")
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite", parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
@@ -25,20 +25,17 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Determine the prompt text
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
# Authenticate: The new Client() automatically looks for the GEMINI_API_KEY environment variable.
if not os.environ.get("GEMINI_API_KEY"): if not os.environ.get("GEMINI_API_KEY"):
print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr) print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
sys.exit(1) sys.exit(1)
client = genai.Client() client = genai.Client()
context_data = {"file_ids": [], "cache_id": None} context_data = {"file_ids": [], "cache_id": None}
# Load existing context if it exists # Load existing context if a context file is specified and exists
if os.path.exists(args.context): if args.context and os.path.exists(args.context):
try: try:
with open(args.context, "r") as f: with open(args.context, "r") as f:
context_data = json.load(f) context_data = json.load(f)
@@ -64,36 +61,39 @@ def main():
except Exception as e: except Exception as e:
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr) print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
if os.path.exists(args.context): if args.context and os.path.exists(args.context):
os.remove(args.context) os.remove(args.context)
print(f"Deleted local context file: {args.context}") print(f"Deleted local context file: {args.context}")
print("Cleanup complete.") print("Cleanup complete.")
return # Exit after destroying return
# --------------------------------------------------------- # Wrap the main execution in a try/finally to guarantee cleanup on transient runs
# UPLOAD LOGIC try:
# --------------------------------------------------------- # ---------------------------------------------------------
if args.files: # UPLOAD LOGIC
for file_path in args.files: # ---------------------------------------------------------
if not os.path.exists(file_path): if args.files:
print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr) for file_path in args.files:
continue if not os.path.exists(file_path):
print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
print(f"Uploading '{file_path}'...") continue
uploaded_file = client.files.upload(file=file_path)
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'") print(f"Uploading '{file_path}'...")
uploaded_file = client.files.upload(file=file_path)
if uploaded_file.name not in context_data["file_ids"]: print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
context_data["file_ids"].append(uploaded_file.name)
if uploaded_file.name not in context_data["file_ids"]:
context_data["file_ids"].append(uploaded_file.name)
with open(args.context, "w") as f: if args.context:
json.dump(context_data, f, indent=4) with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
# --------------------------------------------------------- # ---------------------------------------------------------
# CACHE CREATION AND TTL EXTENSION LOGIC # CACHE CREATION LOGIC
# --------------------------------------------------------- # ---------------------------------------------------------
system_instruction = ( system_instruction = (
# "You are a strict data extraction tool. Output exactly what is requested " # "You are a strict data extraction tool. Output exactly what is requested "
# "(e.g., CSV). Never use markdown formatting blocks (like ```csv). " # "(e.g., CSV). Never use markdown formatting blocks (like ```csv). "
# "Never include conversational text, greetings, or explanations. Output raw data only." # "Never include conversational text, greetings, or explanations. Output raw data only."
@@ -101,98 +101,110 @@ def main():
"(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). " "(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). "
"If a specific format was requested, never include conversational text, greetings, or explanations; " "If a specific format was requested, never include conversational text, greetings, or explanations; "
"and output raw data only. If not specific data format is suggested, you can answer with conversational text." "and output raw data only. If not specific data format is suggested, you can answer with conversational text."
)
cache_too_small = False
file_objects = []
if context_data.get("file_ids"):
# We always need the file objects just in case caching fails
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
if not context_data.get("cache_id"):
print("Attempting to create Context Cache on Google's servers...")
try:
# Attempt to create the cache
cache = client.caches.create(
model=args.model,
config=types.CreateCachedContentConfig(
contents=file_objects,
system_instruction=system_instruction,
ttl="3600s"
)
)
context_data["cache_id"] = cache.name
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
print(f"Context Cache created: {cache.name}")
except Exception as e:
# Catch the specific size error and fall back
if "too small" in str(e).lower() or "1024" in str(e):
print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.")
cache_too_small = True
else:
raise e # Reraise if it's a different error (like authentication)
elif not cache_too_small:
print(f"Loading existing cache: {context_data['cache_id']}")
print("Extending cache TTL by 60 minutes...")
try:
client.caches.update(
name=context_data["cache_id"],
config=types.UpdateCachedContentConfig(
ttl="3600s"
)
)
except Exception as e:
print(f"Warning: Failed to update cache TTL. It may have expired. {e}")
# ---------------------------------------------------------
# GENERATION LOGIC
# ---------------------------------------------------------
if prompt_text:
config_kwargs = {
"max_output_tokens": 65536,
"temperature": 0.0
}
# Setup contents array
generation_contents = []
if context_data.get("cache_id") and not cache_too_small:
# If we successfully cached, we just pass the cache ID in the config
config_kwargs["cached_content"] = context_data["cache_id"]
else:
# If we didn't cache (or it was too small), pass the files and system instruction directly
generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction
generation_contents.append(prompt_text)
config = types.GenerateContentConfig(**config_kwargs)
print("Generating response (this may take a moment for large outputs)...")
response_stream = client.models.generate_content_stream(
model=args.model,
contents=generation_contents,
config=config
) )
if args.output: cache_too_small = False
with open(args.output, "w") as f: file_objects = []
if context_data.get("file_ids"):
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
if not context_data.get("cache_id"):
print("Attempting to create Context Cache on Google's servers...")
try:
cache = client.caches.create(
model=args.model,
config=types.CreateCachedContentConfig(
contents=file_objects,
system_instruction=system_instruction,
ttl="3600s"
)
)
context_data["cache_id"] = cache.name
if args.context:
with open(args.context, "w") as f:
json.dump(context_data, f, indent=4)
print(f"Context Cache created: {cache.name}")
except Exception as e:
if "too small" in str(e).lower() or "1024" in str(e):
print("Notice: Files are too small for server-side caching (under 1024 tokens). Falling back to standard processing.")
cache_too_small = True
else:
raise e
elif not cache_too_small:
print(f"Loading existing cache: {context_data['cache_id']}")
print("Extending cache TTL by 60 minutes...")
try:
client.caches.update(
name=context_data["cache_id"],
config=types.UpdateCachedContentConfig(ttl="3600s")
)
except Exception as e:
print(f"Warning: Failed to update cache TTL. {e}")
# ---------------------------------------------------------
# GENERATION LOGIC
# ---------------------------------------------------------
if prompt_text:
config_kwargs = {
"max_output_tokens": 65536,
"temperature": 0.0
}
generation_contents = []
if context_data.get("cache_id") and not cache_too_small:
config_kwargs["cached_content"] = context_data["cache_id"]
else:
generation_contents.extend(file_objects)
config_kwargs["system_instruction"] = system_instruction
generation_contents.append(prompt_text)
config = types.GenerateContentConfig(**config_kwargs)
print("Generating response (this may take a moment for large outputs)...")
response_stream = client.models.generate_content_stream(
model=args.model,
contents=generation_contents,
config=config
)
if args.output:
with open(args.output, "w") as f:
for chunk in response_stream:
if chunk.text:
f.write(chunk.text)
f.flush()
print(f"\nDone! Raw output saved directly to {args.output}")
else:
print("-" * 40)
for chunk in response_stream: for chunk in response_stream:
if chunk.text: if chunk.text:
f.write(chunk.text) print(chunk.text, end="", flush=True)
f.flush() print("\n" + "-" * 40)
print(f"\nDone! Raw output saved directly to {args.output}")
else: finally:
print("-" * 40) # ---------------------------------------------------------
for chunk in response_stream: # TRANSIENT MODE CLEANUP
if chunk.text: # ---------------------------------------------------------
print(chunk.text, end="", flush=True) if not args.context and not args.destroy:
print("\n" + "-" * 40) print("\n[Transient Mode] Cleaning up resources...")
if context_data.get("cache_id"):
try:
client.caches.delete(name=context_data["cache_id"])
print(f"Deleted cache: {context_data['cache_id']}")
except Exception as e:
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
for file_id in context_data.get("file_ids", []):
try:
client.files.delete(name=file_id)
print(f"Deleted file: {file_id}")
except Exception as e:
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":
main() main()