Making the cached content model-aware
This commit is contained in:
+49
-28
@@ -32,12 +32,20 @@ def main():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
client = genai.Client()
|
client = genai.Client()
|
||||||
context_data = {"file_ids": [], "cache_id": None}
|
|
||||||
|
# Updated structure to track caches and their associated file states by model
|
||||||
|
context_data = {"file_ids": [], "caches": {}}
|
||||||
|
|
||||||
if args.context and os.path.exists(args.context):
|
if args.context and os.path.exists(args.context):
|
||||||
try:
|
try:
|
||||||
with open(args.context, "r") as f:
|
with open(args.context, "r") as f:
|
||||||
context_data = json.load(f)
|
loaded_data = json.load(f)
|
||||||
|
context_data["file_ids"] = loaded_data.get("file_ids", [])
|
||||||
|
context_data["caches"] = loaded_data.get("caches", {})
|
||||||
|
|
||||||
|
# Handle migration from old format
|
||||||
|
if "cache_id" in loaded_data and loaded_data["cache_id"]:
|
||||||
|
print("Warning: Old context format detected. Please use -d to destroy and start fresh, or unexpected cache behavior may occur.", file=sys.stderr)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
|
print(f"Warning: Could not parse {args.context}. Starting fresh.", file=sys.stderr)
|
||||||
|
|
||||||
@@ -46,12 +54,12 @@ def main():
|
|||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
if args.destroy:
|
if args.destroy:
|
||||||
print("Destroying server resources and local context...")
|
print("Destroying server resources and local context...")
|
||||||
if context_data.get("cache_id"):
|
for model_name, cache_info in context_data.get("caches", {}).items():
|
||||||
try:
|
try:
|
||||||
client.caches.delete(name=context_data["cache_id"])
|
client.caches.delete(name=cache_info["cache_id"])
|
||||||
print(f"Deleted cache: {context_data['cache_id']}")
|
print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
print(f"Warning: Failed to delete cache for {model_name}. {e}", file=sys.stderr)
|
||||||
|
|
||||||
for file_id in context_data.get("file_ids", []):
|
for file_id in context_data.get("file_ids", []):
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +79,6 @@ def main():
|
|||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
# UPLOAD LOGIC
|
# UPLOAD LOGIC
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
new_files_added = False
|
|
||||||
if args.files:
|
if args.files:
|
||||||
for file_path in args.files:
|
for file_path in args.files:
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
@@ -84,7 +91,6 @@ def main():
|
|||||||
|
|
||||||
if uploaded_file.name not in context_data["file_ids"]:
|
if uploaded_file.name not in context_data["file_ids"]:
|
||||||
context_data["file_ids"].append(uploaded_file.name)
|
context_data["file_ids"].append(uploaded_file.name)
|
||||||
new_files_added = True
|
|
||||||
|
|
||||||
if args.context:
|
if args.context:
|
||||||
with open(args.context, "w") as f:
|
with open(args.context, "w") as f:
|
||||||
@@ -100,23 +106,31 @@ def main():
|
|||||||
"and output raw data only. If not specific data format is suggested, you can answer with conversational text."
|
"and output raw data only. If not specific data format is suggested, you can answer with conversational text."
|
||||||
)
|
)
|
||||||
|
|
||||||
# If new files were added but a cache already exists, we must destroy the stale cache.
|
|
||||||
if new_files_added and context_data.get("cache_id"):
|
|
||||||
print("New files detected. Destroying stale cache to rebuild...")
|
|
||||||
try:
|
|
||||||
client.caches.delete(name=context_data["cache_id"])
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
|
|
||||||
context_data["cache_id"] = None
|
|
||||||
|
|
||||||
cache_too_small = False
|
cache_too_small = False
|
||||||
file_objects = []
|
file_objects = []
|
||||||
|
active_cache_id = None
|
||||||
|
|
||||||
if context_data.get("file_ids"):
|
if context_data.get("file_ids"):
|
||||||
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
|
file_objects = [client.files.get(name=f_id) for f_id in context_data["file_ids"]]
|
||||||
|
|
||||||
|
model_cache_info = context_data["caches"].get(args.model)
|
||||||
|
current_files_set = set(context_data["file_ids"])
|
||||||
|
rebuild_cache = False
|
||||||
|
|
||||||
|
if model_cache_info:
|
||||||
|
cached_files_set = set(model_cache_info.get("cached_file_ids", []))
|
||||||
|
if current_files_set != cached_files_set:
|
||||||
|
print(f"File list changed. Destroying stale {args.model} cache to rebuild...")
|
||||||
|
try:
|
||||||
|
client.caches.delete(name=model_cache_info["cache_id"])
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not delete stale cache. {e}", file=sys.stderr)
|
||||||
|
rebuild_cache = True
|
||||||
|
else:
|
||||||
|
rebuild_cache = True
|
||||||
|
|
||||||
if not context_data.get("cache_id"):
|
if rebuild_cache:
|
||||||
print("Attempting to create Context Cache on Google's servers...")
|
print(f"Attempting to create Context Cache for {args.model}...")
|
||||||
try:
|
try:
|
||||||
cache = client.caches.create(
|
cache = client.caches.create(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
@@ -126,7 +140,13 @@ def main():
|
|||||||
ttl="3600s"
|
ttl="3600s"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
context_data["cache_id"] = cache.name
|
|
||||||
|
context_data["caches"][args.model] = {
|
||||||
|
"cache_id": cache.name,
|
||||||
|
"cached_file_ids": list(context_data["file_ids"])
|
||||||
|
}
|
||||||
|
active_cache_id = cache.name
|
||||||
|
|
||||||
if args.context:
|
if args.context:
|
||||||
with open(args.context, "w") as f:
|
with open(args.context, "w") as f:
|
||||||
json.dump(context_data, f, indent=4)
|
json.dump(context_data, f, indent=4)
|
||||||
@@ -139,12 +159,13 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
elif not cache_too_small:
|
elif not cache_too_small and model_cache_info:
|
||||||
print(f"Loading existing cache: {context_data['cache_id']}")
|
active_cache_id = model_cache_info["cache_id"]
|
||||||
|
print(f"Loading existing cache for {args.model}: {active_cache_id}")
|
||||||
print("Extending cache TTL by 60 minutes...")
|
print("Extending cache TTL by 60 minutes...")
|
||||||
try:
|
try:
|
||||||
client.caches.update(
|
client.caches.update(
|
||||||
name=context_data["cache_id"],
|
name=active_cache_id,
|
||||||
config=types.UpdateCachedContentConfig(ttl="3600s")
|
config=types.UpdateCachedContentConfig(ttl="3600s")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -161,8 +182,8 @@ def main():
|
|||||||
|
|
||||||
generation_contents = []
|
generation_contents = []
|
||||||
|
|
||||||
if context_data.get("cache_id") and not cache_too_small:
|
if active_cache_id and not cache_too_small:
|
||||||
config_kwargs["cached_content"] = context_data["cache_id"]
|
config_kwargs["cached_content"] = active_cache_id
|
||||||
else:
|
else:
|
||||||
generation_contents.extend(file_objects)
|
generation_contents.extend(file_objects)
|
||||||
config_kwargs["system_instruction"] = system_instruction
|
config_kwargs["system_instruction"] = system_instruction
|
||||||
@@ -198,10 +219,10 @@ def main():
|
|||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
if not args.context and not args.destroy:
|
if not args.context and not args.destroy:
|
||||||
print("\n[Transient Mode] Cleaning up resources...")
|
print("\n[Transient Mode] Cleaning up resources...")
|
||||||
if context_data.get("cache_id"):
|
for model_name, cache_info in context_data.get("caches", {}).items():
|
||||||
try:
|
try:
|
||||||
client.caches.delete(name=context_data["cache_id"])
|
client.caches.delete(name=cache_info["cache_id"])
|
||||||
print(f"Deleted cache: {context_data['cache_id']}")
|
print(f"Deleted cache for {model_name}: {cache_info['cache_id']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
print(f"Warning: Failed to delete cache. {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user