Compare commits
14 Commits
a0412541a8
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 82780d4481 | |||
| dad2265871 | |||
| 6fe273d9d2 | |||
| 919d717675 | |||
| ed3a861214 | |||
| 0f44ea375d | |||
| 7e6a2c782b | |||
| 16dbd9cd52 | |||
| 06d8b189dd | |||
| 5c588cbe4f | |||
| 9c2dd68a28 | |||
| 6bd3c6e3ab | |||
| 0b1037e4e1 | |||
| 5d4184f214 |
@@ -0,0 +1,165 @@
|
||||
import csv
|
||||
import argparse
|
||||
import datetime
|
||||
import re
|
||||
import sys
|
||||
|
||||
def clean_instrument_name(instrument_raw, grade_level):
    """Return a display-ready instrument name for a raw SECTION value.

    Rules, applied in order:

    1. A leading ordering prefix such as ``"03 - "`` is removed.
    2. Any "drum major" entry becomes ``"Senior Drum Major"`` when
       ``grade_level`` is "senior" (case-insensitive), otherwise
       ``"Junior Drum Major"``.
    3. Percussion entries with a parenthetical, e.g. ``"Percussion (Snare)"``,
       are reduced to the title-cased parenthetical; "snare" and
       "tenor"/"tenors" are expanded to "Snare Drum" / "Tenor Drums".
    4. Any other entry simply has its parentheticals removed.
    5. Everything else is returned stripped.

    Args:
        instrument_raw: Raw section text; ``None`` is treated as empty.
        grade_level: Grade label such as "Senior"; ``None`` is treated as empty.

    Returns:
        The cleaned name, which may be an empty string.
    """
    text = (instrument_raw or '').strip()

    # Only drop a leading *numeric* prefix. Splitting on the first hyphen
    # anywhere would mangle hyphenated names ("E-flat Clarinet").
    instrument = re.sub(r'^\d*\s*-\s*', '', text).strip()
    lower_inst = instrument.lower()

    # Drum majors are resolved first so they get their Junior/Senior title
    if 'drum major' in lower_inst:
        is_senior = (grade_level or '').strip().lower() == 'senior'
        return 'Senior Drum Major' if is_senior else 'Junior Drum Major'

    has_parens = '(' in instrument and ')' in instrument

    # 1. The Percussion Rule: keep what is inside the parentheses
    if 'percussion' in lower_inst and has_parens:
        match = re.search(r'\((.*?)\)', instrument)
        if match:
            extracted = match.group(1).strip().title()
            key = extracted.lower()

            # Only Snare and Tenors get an explicit "Drum(s)" suffix
            if key == 'snare':
                return 'Snare Drum'
            if key in ('tenor', 'tenors'):
                return 'Tenor Drums'
            return extracted

    # 2. The Winds/General Rule: drop the parentheticals
    elif has_parens:
        without_parens = re.sub(r'\(.*?\)', '', instrument)
        # Collapse the gap left behind by a mid-name parenthetical
        return ' '.join(without_parens.split())

    # 3. Default Rule
    return instrument
|
||||
|
||||
def get_category(instrument):
    """Map an instrument or role name to its contact-label group.

    Groups are tested in a fixed order and the first hit wins: Drum Majors,
    Woodwinds, Brass, Percussion, Colorguard. Anything unmatched is
    reported as "Leadership".

    Args:
        instrument: Instrument or role name; ``None`` is treated as empty.

    Returns:
        One of "Drum Majors", "Woodwinds", "Brass", "Percussion",
        "Colorguard" or "Leadership".
    """
    inst_lower = (instrument or '').lower()

    # Drum majors go first so the 'drum' percussion keyword cannot claim them
    if 'drum major' in inst_lower:
        return 'Drum Majors'

    # Order matters: "Baritone Sax" must hit Woodwinds before Brass, and
    # "Bass Clarinet" / "Bass Trombone" must be settled before Percussion.
    substring_groups = (
        ('Woodwinds', ('flute', 'clarinet', 'sax', 'oboe', 'bassoon', 'piccolo')),
        ('Brass', ('trumpet', 'mellophone', 'horn', 'trombone', 'baritone',
                   'euphonium', 'tuba', 'sousaphone')),
        ('Percussion', ('percussion', 'snare', 'tenor', 'drum', 'cymbal',
                        'marimba', 'vibraphone', 'xylophone', 'glockenspiel',
                        'timpani', 'bells', 'chimes', 'synth', 'electronics',
                        'aux', 'keyboard', 'front ensemble')),
    )
    for category, keywords in substring_groups:
        if any(keyword in inst_lower for keyword in keywords):
            return category

    # Short percussion names are matched as whole words only, so that roles
    # like "Ambassador" or "Squad Leader" are not mistaken for instruments.
    if re.search(r'\b(?:bass|quads?|quints?|pit)\b', inst_lower):
        return 'Percussion'

    if 'guard' in inst_lower or 'color' in inst_lower:
        return 'Colorguard'

    return 'Leadership'  # Fallback for other potential roles
|
||||
|
||||
def format_phone(phone_str):
    """Normalise a phone number for the contacts export.

    US numbers (10 digits, or 11 digits with a leading 1) are rewritten as
    ``XXX-XXX-XXXX``. Numbers carrying a non-US international prefix
    (``+`` followed by anything other than ``1``) and anything that does
    not look like a US number are returned unchanged apart from trimming.

    Args:
        phone_str: The raw value; any type accepted by ``str()``. ``None``
            is treated as "no number".

    Returns:
        The formatted number, or ``""`` when there is nothing to format.
    """
    if phone_str is None:
        return ""

    phone = str(phone_str).strip()
    if not phone:
        return ""

    # Spreadsheets often turn numbers into floats ("5551234567.0"); the
    # fractional zeros would otherwise be counted as extra digits.
    if re.fullmatch(r'\d+\.0+', phone):
        phone = phone.split('.', 1)[0]

    # Leave non-US international codes as they are
    if phone.startswith('+') and not phone.startswith('+1'):
        return phone

    # Standardize US numbers to XXX-XXX-XXXX
    digits = re.sub(r'\D', '', phone)
    if len(digits) == 11 and digits.startswith('1'):
        digits = digits[1:]  # drop the country code
    if len(digits) == 10:
        return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}"

    return phone
|
||||
|
||||
def main():
    """Convert a band leadership roster CSV into Google Contacts CSV on stdout.

    The input file is read with ``csv.DictReader`` and is expected to carry
    the columns NAME ("Last, First"), GRADE ("2026 (Senior)"), SECTION,
    EMAIL and PHONE. Rows without a name are skipped. One output row is
    written per remaining input row, using the Google Contacts header set.

    Command line:
        input_file   path to the roster CSV
        -y/--year    year used in labels and notes (defaults to current year)
    """
    parser = argparse.ArgumentParser(description="Convert Band Leadership CSV to Google Contacts format.")
    parser.add_argument("input_file", help="Path to the input CSV file")
    parser.add_argument("-y", "--year", type=int, default=datetime.datetime.now().year,
                        help="Target year for labels and grade calculation (defaults to current year)")
    args = parser.parse_args()

    target_year = args.year

    # Google Contacts Headers matching your expected output
    google_headers = [
        "First Name", "Middle Name", "Last Name", "Phonetic First Name", "Phonetic Middle Name",
        "Phonetic Last Name", "Name Prefix", "Name Suffix", "Nickname", "File As",
        "Organization Name", "Organization Title", "Organization Department", "Birthday", "Notes",
        "Photo", "Labels", "E-mail 1 - Label", "E-mail 1 - Value", "Phone 1 - Label",
        "Phone 1 - Value", "Phone 2 - Label", "Phone 2 - Value", "Phone 3 - Label", "Phone 3 - Value",
        "Address 1 - Label", "Address 1 - Formatted", "Address 1 - Street", "Address 1 - City",
        "Address 1 - PO Box", "Address 1 - Region", "Address 1 - Postal Code", "Address 1 - Country",
        "Address 1 - Extended Address"
    ]

    def cell(row, key):
        # DictReader fills missing trailing columns with None, so a plain
        # row.get(key, '') is not enough to guarantee a string.
        return (row.get(key) or '').strip()

    # 'utf-8-sig' drops the BOM that spreadsheet exports prepend (it would
    # otherwise rename the first header and every row would be skipped);
    # newline='' is what the csv module requires for correct quoting.
    with open(args.input_file, mode='r', encoding='utf-8-sig', newline='') as infile:
        reader = csv.DictReader(infile)
        # lineterminator='\n' ensures consistent newlines across OS when printing to stdout
        writer = csv.DictWriter(sys.stdout, fieldnames=google_headers, lineterminator='\n')
        writer.writeheader()

        for row in reader:
            # Skip empty rows or the trailing blank commas in the source file
            name_field = cell(row, 'NAME')
            if not name_field or name_field == ',':
                continue

            # Parse Name ("Last, First")
            name_parts = name_field.split(',')
            last_name = name_parts[0].strip()
            first_name = name_parts[1].strip() if len(name_parts) > 1 else ""

            # Parse Grade and formulate Notes FIRST (so we have grade_level for instruments)
            grade_raw = cell(row, 'GRADE')
            grade_match = re.match(r"(\d{4})\s*\((.*?)\)", grade_raw)
            grade_level = ""

            if grade_match:
                grad_year = grade_match.group(1)
                grade_level = grade_match.group(2).strip().capitalize()
                notes = f"Class of {grad_year}. A {grade_level} in {target_year}"
            else:
                notes = grade_raw

            # Parse and Clean Instrument
            section_raw = cell(row, 'SECTION')
            instrument = clean_instrument_name(section_raw, grade_level)
            if not instrument:
                instrument = "Unknown"

            # Determine Label
            category = get_category(instrument)
            if category == 'Leadership' and 'percussion' in section_raw.lower():
                # Cleaning reduces "Percussion (Bass)" to "Bass", which no
                # longer names its family; recover it from the raw section.
                category = 'Percussion'
            label = f"{target_year} {category} ::: {target_year} Marching Band ::: * myContacts"

            # Build the output row
            out_row = dict.fromkeys(google_headers, "")
            out_row["First Name"] = first_name
            out_row["Last Name"] = f"{last_name} ({instrument})"
            out_row["Notes"] = notes
            out_row["Labels"] = label

            # Map Email
            email = cell(row, 'EMAIL')
            if email:
                out_row["E-mail 1 - Label"] = "Home"
                out_row["E-mail 1 - Value"] = email

            # Map Phone
            phone = cell(row, 'PHONE')
            if phone:
                out_row["Phone 1 - Label"] = "Mobile"
                out_row["Phone 1 - Value"] = format_phone(phone)

            writer.writerow(out_row)
|
||||
|
||||
# Run the converter only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1 @@
|
||||
source gemini_env/bin/activate
|
||||
+288
-52
@@ -3,19 +3,88 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
PRICING_FILE = ".gemini_pricing.json"
|
||||
|
||||
def fetch_pricing_for_model(client, target_model):
    """Fetch the public pricing page and have a Gemini model extract rates.

    The page HTML is downloaded, embedded in a prompt, and sent to
    ``client.models.generate_content`` with a JSON response type; the JSON
    reply is parsed and returned.

    Args:
        client: Object exposing ``models.generate_content(model=, contents=,
            config=)`` (presumably a ``genai.Client`` -- confirm with caller).
        target_model: Name of the model whose prices are wanted.

    Returns:
        A dict of the form ``{target_model: {"input": ..., "cached": ...,
        "output": ..., "input_over_200k": ..., ...}}``, or ``None`` when the
        page cannot be fetched, the reply cannot be parsed, or the reply
        does not contain ``target_model``.
    """
    url = "https://cloud.google.com/gemini-enterprise-agent-platform/generative-ai/pricing"
    print(f"Fetching live pricing for {target_model} from the web...")

    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        # Without a timeout a stalled connection would hang the CLI forever
        with urllib.request.urlopen(req, timeout=30) as response:
            # A single undecodable byte should not discard the whole page
            html = response.read().decode('utf-8', errors='replace')
    except Exception as e:
        print(f"Warning: Failed to fetch HTML from {url}: {e}", file=sys.stderr)
        return None

    prompt = f"""
    Extract the API pricing for the model '{target_model}' from the following HTML text.
    Find the cost per 1 million tokens for input, cached content, and output.
    Many models have a split tier where the cost increases if the prompt exceeds 200k tokens.

    CRITICAL: If the model '{target_model}' is definitively NOT found in the HTML data, return an empty JSON object: {{}}

    Otherwise, return ONLY a valid JSON object with this exact structure:
    {{
        "{target_model}": {{
            "input": 0.00,
            "cached": 0.00,
            "output": 0.00,
            "input_over_200k": 0.00,
            "cached_over_200k": 0.00,
            "output_over_200k": 0.00
        }}
    }}
    If a tier value is not found, duplicate the base tier values.

    HTML DATA:
    {html}
    """

    config = types.GenerateContentConfig(
        response_mime_type="application/json",
        temperature=0.0
    )

    print("Parsing pricing data via background AI session...")
    try:
        # We use a fast, cheap model just for the parsing task
        res = client.models.generate_content(
            model="gemini-3.1-flash-lite",
            contents=prompt,
            config=config
        )
        new_data = json.loads(res.text)
    except Exception as e:
        print(f"Warning: Failed to extract pricing using AI: {e}", file=sys.stderr)
        return None

    # Strip the array wrapper if the AI returned a list instead of a pure dict
    if isinstance(new_data, list) and len(new_data) > 0:
        new_data = new_data[0]

    # An empty object (model not on the page) or any other shape is a miss,
    # not a success -- report it as such instead of claiming retrieval.
    if not isinstance(new_data, dict) or target_model not in new_data:
        print(f"Warning: Pricing for {target_model} was not found in the fetched page.", file=sys.stderr)
        return None

    print(f"Successfully retrieved pricing for {target_model}.")
    print(new_data)
    return new_data
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Gemini API CLI with File & Context Caching")
|
||||
parser.add_argument("-c", "--context", type=str, default=None,
|
||||
help="Path to context file. If omitted, files/caches are deleted after execution.")
|
||||
help="Path to context file. If omitted, transient mode is used.")
|
||||
parser.add_argument("-f", "--files", nargs="+", default=[],
|
||||
help="Files to upload to the Gemini API")
|
||||
parser.add_argument("-m", "--model", type=str, default="gemini-3.1-flash-lite",
|
||||
help="The model to use (default: gemini-3.1-flash-lite)")
|
||||
parser.add_argument("-d", "--destroy", action="store_true",
|
||||
help="Destroy the cloud files and cache, and delete the local context file")
|
||||
help="Destroy cloud files/cache, and delete local context")
|
||||
parser.add_argument("-x", "--clear-history", action="store_true",
|
||||
help="Clear the conversation history without destroying files/caches")
|
||||
parser.add_argument("--pricing", action="store_true",
|
||||
help="Force update the pricing info for the specified model from the web")
|
||||
parser.add_argument("-o", "--output", type=str,
|
||||
help="Direct the raw output to a specific file instead of stdout")
|
||||
parser.add_argument("-p", "--prompt", type=str,
|
||||
@@ -24,7 +93,6 @@ def main():
|
||||
help="Positional arguments treated as the prompt if -p is omitted")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt_text = args.prompt if args.prompt else " ".join(args.positional_prompt) if args.positional_prompt else None
|
||||
|
||||
if not os.environ.get("GEMINI_API_KEY"):
|
||||
@@ -32,26 +100,74 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
client = genai.Client()
|
||||
context_data = {"file_ids": [], "cache_id": None}
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# PRICING CONFIGURATION
|
||||
# ---------------------------------------------------------
|
||||
pricing_data = {}
|
||||
|
||||
if os.path.exists(PRICING_FILE):
|
||||
try:
|
||||
with open(PRICING_FILE, "r") as f:
|
||||
pricing_data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: {PRICING_FILE} is corrupted. Starting fresh.", file=sys.stderr)
|
||||
|
||||
# Fetch pricing if forced, or if the model isn't currently tracked
|
||||
if args.pricing or args.model not in pricing_data:
|
||||
new_pricing = fetch_pricing_for_model(client, args.model)
|
||||
|
||||
if new_pricing and args.model in new_pricing:
|
||||
pricing_data.update(new_pricing)
|
||||
with open(PRICING_FILE, "w") as f:
|
||||
json.dump(pricing_data, f, indent=4)
|
||||
else:
|
||||
print(f"Warning: Could not fetch pricing for {args.model}. Estimating cost at $0.00.", file=sys.stderr)
|
||||
# Add a fallback zero-value so the script doesn't crash during cost calculation
|
||||
pricing_data[args.model] = {"input": 0.0, "cached": 0.0, "output": 0.0}
|
||||
|
||||
# If the user only requested a pricing update and nothing else, exit cleanly
|
||||
if args.pricing and not prompt_text and not args.files and not args.destroy and not args.clear_history:
|
||||
return
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# STATE MANAGEMENT
|
||||
# ---------------------------------------------------------
|
||||
context_data = {"file_ids": [], "caches": {}, "history": []}
|
||||
|
||||
if args.context and os.path.exists(args.context):
|
||||
try:
|
||||
with open(args.context, "r") as f:
|
||||
context_data = json.load(f)
|
||||
loaded_data = json.load(f)
|
||||
context_data["file_ids"] = loaded_data.get("file_ids", [])
|
||||
context_data["caches"] = loaded_data.get("caches", {})
|
||||
context_data["history"] = loaded_data.get("history", [])
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# CLEAR HISTORY FLAG
|
||||
# ---------------------------------------------------------
|
||||
if args.clear_history:
|
||||
context_data["history"] = []
|
||||
if args.context:
|
||||
with open(args.context, "w") as f:
|
||||
json.dump(context_data, f, indent=4)
|
||||
print("Conversation history cleared.")
|
||||
if not prompt_text and not args.files and not args.destroy:
|
||||
return
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# DESTROY FLAG LOGIC
|
||||
# ---------------------------------------------------------
|
||||
if args.destroy:
|
||||
print("Destroying server resources and local context...")
|
||||
if context_data.get("cache_id"):
|
||||
for model_name, cache_info in context_data.get("caches", {}).items():
|
||||
try:
|
||||
client.caches.delete(name=context_data["cache_id"])
|
||||
print(f"Deleted cache: {context_data['cache_id']}")
|
||||
client.caches.delete(name=cache_info["cache_id"])
|
||||
print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
||||
print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)
|
||||
|
||||
for file_id in context_data.get("file_ids", []):
|
||||
try:
|
||||
@@ -71,27 +187,34 @@ def main():
|
||||
# ---------------------------------------------------------
|
||||
# UPLOAD LOGIC
|
||||
# ---------------------------------------------------------
|
||||
new_files_added = False
|
||||
if args.files:
|
||||
for file_path in args.files:
|
||||
if not os.path.exists(file_path):
|
||||
print(f"Warning: File '{file_path}' not found. Skipping.", file=sys.stderr)
|
||||
continue
|
||||
|
||||
print(f"Uploading '{file_path}'...")
|
||||
uploaded_file = client.files.upload(file=file_path)
|
||||
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'")
|
||||
# Sanitize the filename for HTTP headers (replace non-ASCII with underscores)
|
||||
base_name = os.path.basename(file_path)
|
||||
safe_name = "".join([c if ord(c) < 128 else "_" for c in base_name])
|
||||
|
||||
print(f"Uploading '{file_path}'...", file=sys.stderr)
|
||||
|
||||
# Force the SDK to use our sanitized name for the upload display name
|
||||
uploaded_file = client.files.upload(
|
||||
file=file_path,
|
||||
config={'display_name': safe_name}
|
||||
)
|
||||
print(f"Success: '{file_path}' uploaded as '{uploaded_file.name}'", file=sys.stderr)
|
||||
|
||||
if uploaded_file.name not in context_data["file_ids"]:
|
||||
context_data["file_ids"].append(uploaded_file.name)
|
||||
new_files_added = True
|
||||
|
||||
if args.context:
|
||||
with open(args.context, "w") as f:
|
||||
json.dump(context_data, f, indent=4)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# CACHE CREATION LOGIC
|
||||
# CACHE VALIDATION & CREATION LOGIC
|
||||
# ---------------------------------------------------------
|
||||
system_instruction = (
|
||||
"You are a hybrid data extraction tool. If a specific format or file format output is requested "
|
||||
@@ -100,23 +223,47 @@ def main():
|
||||
"and output raw data only. If not specific data format is suggested, you can answer with conversational text."
|
||||
)
|
||||
|
||||
# If new files were added but a cache already exists, we must destroy the stale cache.
|
||||
if new_files_added and context_data.get("cache_id"):
|
||||
print("New files detected. Destroying stale cache to rebuild...")
|
||||
try:
|
||||
client.caches.delete(name=context_data["cache_id"])
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
|
||||
context_data["cache_id"] = None
|
||||
|
||||
cache_too_small = False
|
||||
file_objects = []
|
||||
active_cache_id = None
|
||||
|
||||
if context_data.get("file_ids"):
|
||||
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
|
||||
|
||||
if not context_data.get("cache_id"):
|
||||
print("Attempting to create Context Cache on Google's servers...")
|
||||
model_cache_info = context_data["caches"].get(args.model)
|
||||
current_files_set = set(context_data["file_ids"])
|
||||
rebuild_cache = False
|
||||
|
||||
if model_cache_info:
|
||||
cached_files_set = set(model_cache_info.get("cached_file_ids", []))
|
||||
|
||||
# Check 1: Have the files changed?
|
||||
if current_files_set != cached_files_set:
|
||||
print(f"File list changed. Destroying stale {args.model} cache to rebuild...")
|
||||
try:
|
||||
client.caches.delete(name=model_cache_info["cache_id"])
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
|
||||
rebuild_cache = True
|
||||
else:
|
||||
# Check 2: Does the cache still exist on the server?
|
||||
print(f"Verifying existing cache for {args.model}: {model_cache_info['cache_id']}...")
|
||||
try:
|
||||
client.caches.get(name=model_cache_info['cache_id'])
|
||||
print("Cache verified. Extending TTL by 60 minutes...")
|
||||
client.caches.update(
|
||||
name=model_cache_info['cache_id'],
|
||||
config=types.UpdateCachedContentConfig(ttl="3600s")
|
||||
)
|
||||
active_cache_id = model_cache_info['cache_id']
|
||||
except Exception as e:
|
||||
print(f"Cache expired or not found. Flagging for rebuild. ({e})")
|
||||
rebuild_cache = True
|
||||
else:
|
||||
rebuild_cache = True
|
||||
|
||||
if rebuild_cache:
|
||||
print(f"Attempting to create Context Cache for {args.model}...")
|
||||
try:
|
||||
cache = client.caches.create(
|
||||
model=args.model,
|
||||
@@ -126,7 +273,13 @@ def main():
|
||||
ttl="3600s"
|
||||
)
|
||||
)
|
||||
context_data["cache_id"] = cache.name
|
||||
|
||||
context_data["caches"][args.model] = {
|
||||
"cache_id": cache.name,
|
||||
"cached_file_ids": list(context_data["file_ids"])
|
||||
}
|
||||
active_cache_id = cache.name
|
||||
|
||||
if args.context:
|
||||
with open(args.context, "w") as f:
|
||||
json.dump(context_data, f, indent=4)
|
||||
@@ -139,19 +292,8 @@ def main():
|
||||
else:
|
||||
raise e
|
||||
|
||||
elif not cache_too_small:
|
||||
print(f"Loading existing cache: {context_data['cache_id']}")
|
||||
print("Extending cache TTL by 60 minutes...")
|
||||
try:
|
||||
client.caches.update(
|
||||
name=context_data["cache_id"],
|
||||
config=types.UpdateCachedContentConfig(ttl="3600s")
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to update cache TTL. {e}")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# GENERATION LOGIC
|
||||
# GENERATION LOGIC (WITH HISTORY)
|
||||
# ---------------------------------------------------------
|
||||
if prompt_text:
|
||||
config_kwargs = {
|
||||
@@ -159,22 +301,47 @@ def main():
|
||||
"temperature": 0.0
|
||||
}
|
||||
|
||||
generation_contents = []
|
||||
|
||||
if context_data.get("cache_id") and not cache_too_small:
|
||||
config_kwargs["cached_content"] = context_data["cache_id"]
|
||||
if active_cache_id and not cache_too_small:
|
||||
config_kwargs["cached_content"] = active_cache_id
|
||||
else:
|
||||
generation_contents.extend(file_objects)
|
||||
config_kwargs["system_instruction"] = system_instruction
|
||||
|
||||
generation_contents.append(prompt_text)
|
||||
# Build the chat payload using strict dictionary representations
|
||||
api_contents = []
|
||||
files_added_to_payload = False
|
||||
|
||||
for msg in context_data.get("history", []):
|
||||
parts = []
|
||||
# If we aren't using a cache, attach un-cached files to the very first user message
|
||||
if msg["role"] == "user" and not files_added_to_payload and not active_cache_id and file_objects:
|
||||
for f in file_objects:
|
||||
parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
|
||||
files_added_to_payload = True
|
||||
|
||||
parts.append({"text": msg["text"]})
|
||||
api_contents.append({"role": msg["role"], "parts": parts})
|
||||
|
||||
current_parts = []
|
||||
if not files_added_to_payload and not active_cache_id and file_objects:
|
||||
for f in file_objects:
|
||||
current_parts.append({"file_data": {"file_uri": f.uri, "mime_type": f.mime_type}})
|
||||
files_added_to_payload = True
|
||||
|
||||
current_parts.append({"text": prompt_text})
|
||||
api_contents.append({"role": "user", "parts": current_parts})
|
||||
|
||||
config = types.GenerateContentConfig(**config_kwargs)
|
||||
|
||||
print("Generating response (this may take a moment for large outputs)...")
|
||||
|
||||
full_response_text = ""
|
||||
usage_metadata = None
|
||||
finish_reason_str = "UNKNOWN"
|
||||
|
||||
try:
|
||||
response_stream = client.models.generate_content_stream(
|
||||
model=args.model,
|
||||
contents=generation_contents,
|
||||
contents=api_contents,
|
||||
config=config
|
||||
)
|
||||
|
||||
@@ -184,13 +351,82 @@ def main():
|
||||
if chunk.text:
|
||||
f.write(chunk.text)
|
||||
f.flush()
|
||||
print(f"\nDone! Raw output saved directly to {args.output}")
|
||||
full_response_text += chunk.text
|
||||
if chunk.usage_metadata:
|
||||
usage_metadata = chunk.usage_metadata
|
||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||
finish_reason_str = chunk.candidates[0].finish_reason.name
|
||||
print(f"Done! Raw output saved directly to {args.output}")
|
||||
else:
|
||||
print("-" * 40)
|
||||
print()
|
||||
for chunk in response_stream:
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n" + "-" * 40)
|
||||
full_response_text += chunk.text
|
||||
if chunk.usage_metadata:
|
||||
usage_metadata = chunk.usage_metadata
|
||||
if chunk.candidates and chunk.candidates[0].finish_reason:
|
||||
finish_reason_str = chunk.candidates[0].finish_reason.name
|
||||
print("\n")
|
||||
|
||||
except Exception as e:
|
||||
# Catch 404 Model Not Found or other critical generation errors gracefully
|
||||
if "404" in str(e) and "NOT_FOUND" in str(e):
|
||||
print(f"\n[Error] The model '{args.model}' does not exist or is not available.", file=sys.stderr)
|
||||
else:
|
||||
print(f"\n[API Error] {e}", file=sys.stderr)
|
||||
|
||||
# Exit cleanly so we don't calculate costs or save bad history
|
||||
return
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# USAGE AND COST CALCULATION
|
||||
# ---------------------------------------------------------
|
||||
if usage_metadata:
|
||||
prompt_tokens = usage_metadata.prompt_token_count or 0
|
||||
output_tokens = usage_metadata.candidates_token_count or 0
|
||||
cached_tokens = getattr(usage_metadata, 'cached_content_token_count', 0) or 0
|
||||
|
||||
# Revert to max(0) to handle Google's padding discrepancy where prompt_tokens < cached_tokens
|
||||
uncached_tokens = max(0, prompt_tokens - cached_tokens)
|
||||
|
||||
# Ensure the tier logic checks the absolute largest representation of the payload
|
||||
total_input_tokens = max(prompt_tokens, cached_tokens)
|
||||
|
||||
# Fetch the rate dynamically from the parsed JSON or default to 0.0 if fetch failed
|
||||
rates = pricing_data.get(args.model, {"input": 0.0, "cached": 0.0, "output": 0.0})
|
||||
|
||||
input_rate = rates.get("input", 0.0)
|
||||
cached_rate = rates.get("cached", 0.0)
|
||||
output_rate = rates.get("output", 0.0)
|
||||
tier_label = "Base Tier"
|
||||
|
||||
# Apply 200k+ tier pricing if total prompt size exceeds 200k
|
||||
if total_input_tokens > 200_000 and "input_over_200k" in rates:
|
||||
if rates.get("input_over_200k", 0.0) > 0:
|
||||
input_rate = rates["input_over_200k"]
|
||||
tier_label = ">200k Tier"
|
||||
if rates.get("cached_over_200k", 0.0) > 0:
|
||||
cached_rate = rates["cached_over_200k"]
|
||||
if rates.get("output_over_200k", 0.0) > 0:
|
||||
output_rate = rates["output_over_200k"]
|
||||
|
||||
input_cost = (uncached_tokens / 1_000_000) * input_rate
|
||||
cached_cost = (cached_tokens / 1_000_000) * cached_rate
|
||||
output_cost = (output_tokens / 1_000_000) * output_rate
|
||||
total_cost = input_cost + cached_cost + output_cost
|
||||
|
||||
print("\n[--- Execution Summary ---]")
|
||||
print(f"Finish Reason: {finish_reason_str}")
|
||||
print(f"Token Usage: Input: {uncached_tokens:,} | Cached: {cached_tokens:,} | Output: {output_tokens:,} ({tier_label})")
|
||||
print(f"Est. Cost: ${total_cost:.6f} (Model: {args.model})")
|
||||
|
||||
context_data["history"].append({"role": "user", "text": prompt_text})
|
||||
context_data["history"].append({"role": "model", "text": full_response_text})
|
||||
|
||||
if args.context:
|
||||
with open(args.context, "w") as f:
|
||||
json.dump(context_data, f, indent=4)
|
||||
|
||||
finally:
|
||||
# ---------------------------------------------------------
|
||||
@@ -198,10 +434,10 @@ def main():
|
||||
# ---------------------------------------------------------
|
||||
if not args.context and not args.destroy:
|
||||
print("\n[Transient Mode] Cleaning up resources...")
|
||||
if context_data.get("cache_id"):
|
||||
for model_name, cache_info in context_data.get("caches", {}).items():
|
||||
try:
|
||||
client.caches.delete(name=context_data["cache_id"])
|
||||
print(f"Deleted cache: {context_data['cache_id']}")
|
||||
client.caches.delete(name=cache_info["cache_id"])
|
||||
print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user