Working on the gemini CLI
This commit is contained in:
+59
-41
@@ -3,9 +3,8 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import datetime
|
from google import genai
|
||||||
import google.generativeai as genai
|
from google.genai import types
|
||||||
from google.generativeai import caching
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
|
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
|
||||||
@@ -13,8 +12,8 @@ def main():
|
|||||||
help="Path to the context file (defaults to .gemini_context.json)")
|
help="Path to the context file (defaults to .gemini_context.json)")
|
||||||
parser.add_argument("-f", "--files", nargs="+", default=[],
|
parser.add_argument("-f", "--files", nargs="+", default=[],
|
||||||
help="Files to upload to the Gemini API")
|
help="Files to upload to the Gemini API")
|
||||||
parser.add_argument("-m", "--model", type=str, default="models/gemini-1.5-pro-001",
|
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
|
||||||
help="The model to use (default: models/gemini-1.5-pro-001)")
|
help="The model to use (default: gemini-3.1-flash-lite)")
|
||||||
parser.add_argument("-d", "--destroy", action="store_true",
|
parser.add_argument("-d", "--destroy", action="store_true",
|
||||||
help="Destroy the cloud files and cache, and delete the local context file")
|
help="Destroy the cloud files and cache, and delete the local context file")
|
||||||
parser.add_argument("-o", "--output", type=str,
|
parser.add_argument("-o", "--output", type=str,
|
||||||
@@ -29,12 +28,12 @@ def main():
|
|||||||
# Determine the prompt text
|
# Determine the prompt text
|
||||||
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
|
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
|
||||||
|
|
||||||
# Authenticate
|
# Authenticate: The new Client() automatically looks for the GEMINI_API_KEY environment variable.
|
||||||
api_key = os.environ.get("GEMINI_API_KEY")
|
if not os.environ.get("GEMINI_API_KEY"):
|
||||||
if not api_key:
|
|
||||||
print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
|
print("Error: GEMINI_API_KEY environment variable is not set.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
genai.configure(api_key=api_key)
|
|
||||||
|
client = genai.Client()
|
||||||
|
|
||||||
context_data = {"file_ids": [], "cache_id": None}
|
context_data = {"file_ids": [], "cache_id": None}
|
||||||
|
|
||||||
@@ -53,14 +52,14 @@ def main():
|
|||||||
print("Destroying server resources and local context...")
|
print("Destroying server resources and local context...")
|
||||||
if context_data.get("cache_id"):
|
if context_data.get("cache_id"):
|
||||||
try:
|
try:
|
||||||
caching.CachedContent.get(context_data["cache_id"]).delete()
|
client.caches.delete(name=context_data["cache_id"])
|
||||||
print(f"Deleted cache: {context_data['cache_id']}")
|
print(f"Deleted cache: {context_data['cache_id']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
||||||
|
|
||||||
for file_id in context_data.get("file_ids", []):
|
for file_id in context_data.get("file_ids", []):
|
||||||
try:
|
try:
|
||||||
genai.delete_file(file_id)
|
client.files.delete(name=file_id)
|
||||||
print(f"Deleted file: {file_id}")
|
print(f"Deleted file: {file_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
|
print(f"Warning: Failed to delete file '{file_id}'. {e}", file=sys.stderr)
|
||||||
@@ -82,7 +81,7 @@ def main():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Uploading '{file_path}'...")
|
print(f"Uploading '{file_path}'...")
|
||||||
uploaded_file = genai.upload_file(file_path)
|
uploaded_file = client.files.upload(file=file_path)
|
||||||
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
|
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
|
||||||
|
|
||||||
if uploaded_file.name not in context_data["file_ids"]:
|
if uploaded_file.name not in context_data["file_ids"]:
|
||||||
@@ -94,68 +93,87 @@ def main():
|
|||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# CACHE CREATION AND TTL EXTENSION LOGIC
|
# CACHE CREATION AND TTL EXTENSION LOGIC
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
cache = None
|
|
||||||
system_instruction = (
|
system_instruction = (
|
||||||
"You are a strict data extraction tool. Output exactly what is requested "
|
# "You are a strict data extraction tool. Output exactly what is requested "
|
||||||
"(e.g., CSV). Never use markdown formatting blocks (like ```csv). "
|
# "(e.g., CSV). Never use markdown formatting blocks (like ```csv). "
|
||||||
"Never include conversational text, greetings, or explanations. Output raw data only."
|
# "Never include conversational text, greetings, or explanations. Output raw data only."
|
||||||
|
"You are a hybrid data extraction tool. If a specific format or file format output is requested "
|
||||||
|
"(e.g., CSV), then output exactly what was requested and never use markdown formatting blocks (like ```csv). "
|
||||||
|
"If a specific format was requested, never include conversational text, greetings, or explanations; "
|
||||||
|
"and output raw data only. If not specific data format is suggested, you can answer with conversational text."
|
||||||
)
|
)
|
||||||
|
|
||||||
if context_data.get("file_ids"):
|
if context_data.get("file_ids"):
|
||||||
if not context_data.get("cache_id"):
|
if not context_data.get("cache_id"):
|
||||||
print("Creating Context Cache on Google's servers...")
|
print("Creating Context Cache on Google's servers...")
|
||||||
file_objects = [genai.get_file(f_id) for f_id in context_data["file_ids"]]
|
|
||||||
|
|
||||||
cache = caching.CachedContent.create(
|
# Retrieve the file objects
|
||||||
|
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
|
||||||
|
|
||||||
|
# Create the server-side cache (Set to expire in 3600 seconds / 60 minutes)
|
||||||
|
cache = client.caches.create(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
system_instruction=system_instruction,
|
config=types.CreateCachedContentConfig(
|
||||||
contents=file_objects,
|
contents=file_objects,
|
||||||
ttl=datetime.timedelta(minutes=60)
|
system_instruction=system_instruction,
|
||||||
|
ttl="3600s"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
context_data["cache_id"] = cache.name
|
context_data["cache_id"] = cache.name
|
||||||
with open(args.context, "w") as f:
|
with open(args.context, "w") as f:
|
||||||
json.dump(context_data, f, indent=4)
|
json.dump(context_data, f, indent=4)
|
||||||
print(f"Context Cache created: {cache.name}")
|
print(f"Context Cache created: {cache.name}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"Loading existing cache: {context_data['cache_id']}")
|
print(f"Loading existing cache: {context_data['cache_id']}")
|
||||||
cache = caching.CachedContent.get(context_data["cache_id"])
|
|
||||||
print("Extending cache TTL by 60 minutes...")
|
print("Extending cache TTL by 60 minutes...")
|
||||||
cache.update(ttl=datetime.timedelta(minutes=60))
|
client.caches.update(
|
||||||
|
name=context_data["cache_id"],
|
||||||
|
config=types.UpdateCachedContentConfig(
|
||||||
|
ttl="3600s"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# GENERATION LOGIC
|
# GENERATION LOGIC
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
if prompt_text:
|
if prompt_text:
|
||||||
generation_config = {
|
config_kwargs = {
|
||||||
"max_output_tokens": 65536,
|
"max_output_tokens": 65536,
|
||||||
"temperature": 0.0
|
"temperature": 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
if cache:
|
# The new SDK passes the cache ID directly into the generation config.
|
||||||
model = genai.GenerativeModel.from_cached_content(
|
# If caching is used, the system_instruction must be in the cache, not here.
|
||||||
cached_content=cache,
|
if context_data.get("cache_id"):
|
||||||
generation_config=generation_config
|
config_kwargs["cached_content"] = context_data["cache_id"]
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Fallback if no files were uploaded
|
config_kwargs["system_instruction"] = system_instruction
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name=args.model,
|
config = types.GenerateContentConfig(**config_kwargs)
|
||||||
system_instruction=system_instruction,
|
|
||||||
generation_config=generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Generating response (this may take a moment for large outputs)...")
|
print("Generating response (this may take a moment for large outputs)...")
|
||||||
|
|
||||||
response = model.generate_content(prompt_text)
|
# We use stream generation so it writes immediately, avoiding memory bottlenecks
|
||||||
|
response_stream = client.models.generate_content_stream(
|
||||||
|
model=args.model,
|
||||||
|
contents=prompt_text,
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
with open(args.output, "w") as f:
|
with open(args.output, "w") as f:
|
||||||
f.write(response.text)
|
for chunk in response_stream:
|
||||||
print(f"Done! Raw output saved directly to {args.output}")
|
if chunk.text:
|
||||||
|
f.write(chunk.text)
|
||||||
|
f.flush() # Stream direct to the disk in real-time
|
||||||
|
print(f"\nDone! Raw output saved directly to {args.output}")
|
||||||
else:
|
else:
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
print(response.text)
|
for chunk in response_stream:
|
||||||
print("-" * 40)
|
if chunk.text:
|
||||||
|
print(chunk.text, end="", flush=True)
|
||||||
|
print("\n" + "-" * 40)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user