01 P1A Pilot Data Extraction
Jupyter notebook from the Gene Function Ecological Agora project.
NB01 — Phase 1A Pilot Data Extraction
Project: Gene Function Ecological Agora — Innovation Atlas Across the Bacterial Tree
Phase: 1A — Pilot (1,000 species × 1,000 UniRef50 clusters)
Purpose: Extract the pilot subset for null-model calibration, control validation, and Alm 2006 TCS reproduction before scaling to the full GTDB substrate at Phase 1B.
Stages
- GTDB taxonomy + species representatives — pull rank scaffold; identify GTDB representative genome per species.
- Stratified 1K-species pilot sample — proportional-to-phylum-size with floor of 5 species per phylum (forces CPR/DPANN representation against cultivation bias).
- Pilot species gene clusters — restrict to species-specific gene_cluster_ids for the 1K pilot species.
- UniRef50/UniRef90 mappings — pull bakta_db_xrefs filtered to pilot clusters; split by UniRef tier from the accession prefix.
- eggNOG annotations — pull COG_category, KEGG_ko, KEGG BRITE, PFAMs for pilot clusters (the projection-rail and control-membership inputs).
- Control sets — define positive controls (AMR / CRISPR-Cas / Alm 2006 TCS HKs) and negative controls (ribosomal / tRNA synthetase / RNAP core) at the UniRef50 level.
- Stratified 1K-UniRef50 pilot sample — COG-stratified with explicit force-inclusion of all positive- and negative-control UniRef50s.
- Pre-flight coverage check — confirm ≥ 80 % of pilot species have ≥ 1 entry in each control set; flag any sparse control set before downstream NB02.
- Per-genome annotation density — D2 nuisance covariate input for the OLS residualization.
- Materialize pilot extract — write all outputs to data/.
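The pre-flight coverage check listed above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the notebook's Stage 8 code: the function names `control_coverage` and `sparse_controls`, the input shape (species mapped to the set of control classes with at least one entry), and the toy data are all hypothetical. Only the 80 % threshold comes from the stage description.

```python
def control_coverage(species_controls, control_classes):
    """Fraction of species with at least one entry in each control class."""
    n = len(species_controls)
    if n == 0:
        return {c: 0.0 for c in control_classes}
    return {
        c: sum(1 for present in species_controls.values() if c in present) / n
        for c in control_classes
    }


def sparse_controls(coverage, threshold=0.8):
    """Control classes whose species coverage falls below the threshold."""
    return sorted(c for c, frac in coverage.items() if frac < threshold)


# Hypothetical toy input: species -> control classes observed in that species.
toy = {
    "sp1": {"neg_ribosomal", "pos_tcs_hk"},
    "sp2": {"neg_ribosomal"},
    "sp3": {"neg_ribosomal", "pos_tcs_hk"},
    "sp4": {"neg_ribosomal", "pos_tcs_hk", "pos_amr"},
    "sp5": {"neg_ribosomal", "pos_tcs_hk"},
}
coverage = control_coverage(toy, ["neg_ribosomal", "pos_tcs_hk", "pos_amr"])
flagged = sparse_controls(coverage)
print(coverage)
print("Sparse control sets:", flagged)
```

In this toy case pos_tcs_hk sits exactly at the threshold and passes, while pos_amr is flagged as sparse.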
Outputs (all under projects/gene_function_ecological_agora/data/)
- p1a_pilot_species.tsv — 1 row per pilot species rep; columns: gtdb_species_clade_id, representative_genome_id, domain, phylum, class, order, family, genus, species, no_genomes, checkm_completeness, checkm_contamination, genome_size, gc_percentage, protein_count, annotated_fraction.
- p1a_pilot_uniref50.tsv — 1 row per pilot UniRef50; columns: uniref50_id, dominant_cog_category, dominant_kegg_ko, dominant_brite_b, n_unique_uniref90, n_gene_clusters_supporting, control_class (pos_amr / pos_crispr_cas / pos_tcs_hk / neg_ribosomal / neg_trna_synth / neg_rnap_core / none).
- p1a_pilot_extract.parquet — long-format presence/copy-number matrix; one row per (gtdb_species_clade_id, uniref50_id); columns: n_uniref90_present, n_gene_clusters, is_present. Used as the primary input to NB02 null-model construction.
- p1a_preflight_coverage.tsv — control-set coverage diagnostic.
- p1a_extraction_log.txt — extraction parameters, row counts, sampling seed.
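Pre-registered seed
All stratified sampling is seeded with RNG_SEED = 42 for reproducibility: Setup creates numpy.random.default_rng(42), and the Stage 2 pandas sample call passes random_state=RNG_SEED. Re-running NB01 with the same seed produces the same pilot subset.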
Pitfalls observed (will be added to docs/pitfalls.md if novel)
- bakta_db_xrefs.db value scheme verified at runtime in Stage 4 (not assumed from schema docs).
- Pre-flight check (Stage 8) catches sparse-control failure modes before NB02 wastes compute on the null model.
Setup
import os, json, time
from pathlib import Path
import numpy as np
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
# Spark session — robust to JupyterHub notebook (auto-injected) vs CLI/script context
try:
    spark  # noqa: F821 — set by JupyterHub kernel startup
    print("Spark session: pre-injected (JupyterHub notebook context)")
except NameError:
    from berdl_notebook_utils.setup_spark_session import get_spark_session
    spark = get_spark_session()
    print("Spark session: created via berdl_notebook_utils")
spark.sql("SET spark.sql.autoBroadcastJoinThreshold = -1") # per docs/pitfalls.md, required for bakta_db_xrefs ↔ kbase_uniprot joins
PROJECT_ROOT = Path("/home/aparkin/BERIL-research-observatory/projects/gene_function_ecological_agora")
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)
RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)
TARGET_N_SPECIES = 1000
TARGET_N_UNIREF50 = 1000
PHYLUM_FLOOR = 5
log = {
    "seed": RNG_SEED,
    "target_n_species": TARGET_N_SPECIES,
    "target_n_uniref50": TARGET_N_UNIREF50,
    "phylum_floor": PHYLUM_FLOOR,
    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
print(json.dumps(log, indent=2))
Spark session: created via berdl_notebook_utils
{
  "seed": 42,
  "target_n_species": 1000,
  "target_n_uniref50": 1000,
  "phylum_floor": 5,
  "timestamp_utc": "2026-04-26T21:27:08Z"
}
Stage 1 — GTDB taxonomy and species representatives
# Pull species clade table — has the representative_genome_id we need for D1 dedup
species_clade = spark.sql("""
SELECT
gtdb_species_clade_id,
representative_genome_id,
GTDB_species,
GTDB_taxonomy
FROM kbase_ke_pangenome.gtdb_species_clade
""")
# Parse the GTDB_taxonomy string (d__;p__;c__;o__;f__;g__;s__) into ranks.
# Some clades have <7 ranks (truncated entries); use try_element_at to return NULL on missing rank.
species_clade = species_clade.withColumn("_taxa", F.split(F.col("GTDB_taxonomy"), ";"))
for i, rank in enumerate(["domain", "phylum", "class", "order", "family", "genus", "species_rank"]):
    species_clade = species_clade.withColumn(rank, F.expr(f"try_element_at(_taxa, {i + 1})"))
species_clade = species_clade.drop("_taxa")
# Restrict to bacteria — archaea explicitly out of scope per the project constraint
species_clade = species_clade.filter(F.col("domain") == "d__Bacteria")
# Pull pangenome stats and gtdb_metadata for the representative genomes.
# Many BERDL numeric columns are stored as STRING (per docs/pitfalls.md) — use try_cast
# at pull time to coerce to numeric and return NULL on bad input rather than failing.
pangenome_stats = spark.sql("""
SELECT gtdb_species_clade_id,
try_cast(no_genomes AS BIGINT) AS no_genomes,
try_cast(no_gene_clusters AS BIGINT) AS no_gene_clusters,
try_cast(no_core AS BIGINT) AS no_core,
try_cast(no_aux_genome AS BIGINT) AS no_aux_genome,
try_cast(no_singleton_gene_clusters AS BIGINT) AS no_singleton_gene_clusters
FROM kbase_ke_pangenome.pangenome
""")
metadata = spark.sql("""
SELECT accession AS representative_genome_id,
try_cast(checkm_completeness AS DOUBLE) AS checkm_completeness,
try_cast(checkm_contamination AS DOUBLE) AS checkm_contamination,
try_cast(genome_size AS BIGINT) AS genome_size,
try_cast(gc_percentage AS DOUBLE) AS gc_percentage,
try_cast(protein_count AS BIGINT) AS protein_count
FROM kbase_ke_pangenome.gtdb_metadata
""")
species_full = (
species_clade
.join(pangenome_stats, on="gtdb_species_clade_id", how="inner")
.join(metadata, on="representative_genome_id", how="left")
)
n_total = species_full.count()
log["n_bacteria_species_total"] = n_total
print(f"Bacteria species clades with pangenome stats: {n_total:,}")
phylum_counts = (
species_full.groupBy("phylum").count().orderBy(F.desc("count")).toPandas()
)
log["n_phyla"] = int(len(phylum_counts))
print(f"Distinct phyla: {len(phylum_counts)}")
phylum_counts.head(15)
Bacteria species clades with pangenome stats: 26,525
Distinct phyla: 125
| | phylum | count |
|---|---|---|
| 0 | p__Pseudomonadota | 7456 |
| 1 | p__Bacillota_A | 3917 |
| 2 | p__Bacteroidota | 3629 |
| 3 | p__Actinomycetota | 3172 |
| 4 | p__Bacillota | 2146 |
| 5 | p__Patescibacteria | 981 |
| 6 | p__Verrucomicrobiota | 553 |
| 7 | p__Cyanobacteriota | 469 |
| 8 | p__Chloroflexota | 457 |
| 9 | p__Planctomycetota | 348 |
| 10 | p__Desulfobacterota | 319 |
| 11 | p__Spirochaetota | 296 |
| 12 | p__Acidobacteriota | 283 |
| 13 | p__Campylobacterota | 271 |
| 14 | p__Bacillota_C | 235 |
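The rank parsing in Stage 1 (split the GTDB_taxonomy string on ";" and return NULL for any missing rank) can be mirrored in plain Python to check the expected behaviour on truncated entries. This is a sketch only: `parse_gtdb_taxonomy` is a hypothetical helper, and the two taxonomy strings are illustrative inputs rather than rows pulled from the table.

```python
RANKS = ["domain", "phylum", "class", "order", "family", "genus", "species_rank"]


def parse_gtdb_taxonomy(taxonomy):
    """Split a 'd__;p__;...;s__' string into ranks; missing ranks become None."""
    parts = taxonomy.split(";") if taxonomy else []
    return {
        rank: (parts[i] if i < len(parts) else None)
        for i, rank in enumerate(RANKS)
    }


full = parse_gtdb_taxonomy(
    "d__Bacteria;p__Pseudomonadota;c__Gammaproteobacteria;o__Enterobacterales;"
    "f__Enterobacteriaceae;g__Escherichia;s__Escherichia coli"
)
truncated = parse_gtdb_taxonomy("d__Bacteria;p__Patescibacteria")

print(full["phylum"], full["species_rank"])
print(truncated["phylum"], truncated["class"])
```

A truncated entry keeps its domain and phylum, so it still passes the bacteria filter and is counted in the per-phylum totals; only the lower ranks are empty.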
Stage 2 — Stratified 1K-species pilot sample
Proportional-to-phylum-size with a floor of 5 species per phylum. The floor guarantees CPR / DPANN-equivalent (small, poorly-cultivated) phyla are represented against the cultivation bias. Phyla with fewer than 5 species are taken in full.
# Quality filter — drop species reps with poor CheckM (D1 quality requirement)
species_quality = species_full.filter(
(F.col("checkm_completeness") >= 90)
& (F.col("checkm_contamination") <= 5)
).toPandas()
log["n_species_after_quality_filter"] = int(len(species_quality))
print(f"After CheckM ≥90 %, contamination ≤5 %: {len(species_quality):,} species")
# Allocate species per phylum — proportional with floor
phylum_sizes = species_quality.groupby("phylum").size().sort_values(ascending=False)
n_phyla = len(phylum_sizes)
# Floor allocation
alloc = {p: min(PHYLUM_FLOOR, n) for p, n in phylum_sizes.items()}
remaining = TARGET_N_SPECIES - sum(alloc.values())
# Proportional allocation of the residual to phyla larger than the floor
if remaining > 0:
    eligible = phylum_sizes[phylum_sizes > PHYLUM_FLOOR]
    weights = eligible / eligible.sum()
    proportional = (weights * remaining).round().astype(int)
    # Adjust rounding so totals match exactly
    while proportional.sum() < remaining:
        idx = (weights * remaining - proportional).idxmax()
        proportional[idx] += 1
    while proportional.sum() > remaining:
        idx = (proportional - weights * remaining).idxmax()
        proportional[idx] -= 1
    for p, add in proportional.items():
        alloc[p] = min(alloc[p] + int(add), int(phylum_sizes[p]))
# Sample
sampled_pieces = []
for phylum, n_to_sample in alloc.items():
    pool = species_quality[species_quality["phylum"] == phylum]
    if len(pool) == 0 or n_to_sample == 0:
        continue
    if n_to_sample >= len(pool):
        sampled_pieces.append(pool)
    else:
        sampled_pieces.append(pool.sample(n=n_to_sample, random_state=RNG_SEED))
pilot_species = pd.concat(sampled_pieces, ignore_index=True)
log["n_pilot_species"] = int(len(pilot_species))
log["n_pilot_phyla"] = int(pilot_species["phylum"].nunique())
print(f"Pilot species sampled: {len(pilot_species):,} across {pilot_species['phylum'].nunique()} phyla")
pilot_species.groupby("phylum").size().sort_values(ascending=False).head(20)
After CheckM ≥90 %, contamination ≤5 %: 18,989 species
Pilot species sampled: 1,000 across 110 phyla
phylum
p__Pseudomonadota        197
p__Actinomycetota         88
p__Bacteroidota           87
p__Bacillota_A            86
p__Bacillota              64
p__Cyanobacteriota        16
p__Verrucomicrobiota      16
p__Planctomycetota        13
p__Campylobacterota       13
p__Chloroflexota          13
p__Spirochaetota          12
p__Desulfobacterota       12
p__Acidobacteriota        11
p__Bacillota_C            11
p__Myxococcota             8
p__Fusobacteriota          7
p__Elusimicrobiota         7
p__Gemmatimonadota         7
p__Desulfobacterota_F      7
p__Bdellovibrionota        7
dtype: int64
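The floor-plus-proportional allocation used in Stage 2 can be traced by hand on a small case. The sketch below is a plain-Python restatement of the same steps (floor, proportional share of the residual among phyla larger than the floor, rounding correction, cap at phylum size). The function name `allocate` and the four toy phylum sizes are hypothetical and chosen only so the arithmetic is easy to follow.

```python
def allocate(phylum_sizes, target, floor):
    """Floor allocation, then proportional allocation of the residual."""
    alloc = {p: min(floor, n) for p, n in phylum_sizes.items()}
    remaining = target - sum(alloc.values())
    eligible = {p: n for p, n in phylum_sizes.items() if n > floor}
    if remaining > 0 and eligible:
        total = sum(eligible.values())
        exact = {p: n / total * remaining for p, n in eligible.items()}
        extra = {p: round(x) for p, x in exact.items()}
        # Correct rounding drift so the extras sum to the residual exactly.
        while sum(extra.values()) < remaining:
            p = max(exact, key=lambda k: exact[k] - extra[k])
            extra[p] += 1
        while sum(extra.values()) > remaining:
            p = max(exact, key=lambda k: extra[k] - exact[k])
            extra[p] -= 1
        for p, add in extra.items():
            # Never allocate more species than the phylum contains.
            alloc[p] = min(alloc[p] + add, phylum_sizes[p])
    return alloc


# Hypothetical phylum sizes; target of 100 species with a floor of 5.
toy_sizes = {"p__A": 600, "p__B": 300, "p__C": 90, "p__D": 3}
toy_alloc = allocate(toy_sizes, target=100, floor=5)
print(toy_alloc, sum(toy_alloc.values()))
```

Here the floor step assigns 5 + 5 + 5 + 3 = 18 species, leaving 82 to split 600 : 300 : 90 among the three phyla above the floor. The small phylum p__D is taken in full.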
Stage 3 — Pilot species gene clusters
Pull the gene_cluster_ids for the pilot species. Gene clusters are species-specific: gtdb_species_clade_id is a foreign key on gene_cluster.
pilot_clade_ids = pilot_species["gtdb_species_clade_id"].tolist()
# Push the clade IDs to the cluster as a temp view (avoids large IN clause)
spark.createDataFrame(
pd.DataFrame({"gtdb_species_clade_id": pilot_clade_ids})
).createOrReplaceTempView("pilot_species_view")
pilot_gene_clusters = spark.sql("""
SELECT gc.gene_cluster_id,
gc.gtdb_species_clade_id,
gc.is_core,
gc.is_auxiliary,
gc.is_singleton
FROM kbase_ke_pangenome.gene_cluster gc
JOIN pilot_species_view ps
ON gc.gtdb_species_clade_id = ps.gtdb_species_clade_id
""")
pilot_gene_clusters.cache()
n_pilot_clusters = pilot_gene_clusters.count()
log["n_pilot_gene_clusters"] = int(n_pilot_clusters)
print(f"Gene clusters in pilot species: {n_pilot_clusters:,}")
Gene clusters in pilot species: 4,885,206
Stage 4 — bakta_db_xrefs UniRef50/UniRef90 mappings
First inspect distinct db values in bakta_db_xrefs to confirm how UniRef tiers are encoded. Then pull the mapping filtered to pilot gene clusters.
# Stage 4a — discovery: what are the distinct db values?
db_values = spark.sql("""
SELECT db, COUNT(*) AS n FROM kbase_ke_pangenome.bakta_db_xrefs GROUP BY db ORDER BY n DESC
""").toPandas()
print("Distinct db values in bakta_db_xrefs:")
print(db_values)
log["bakta_db_xrefs_db_values"] = db_values.to_dict(orient="records")
Distinct db values in bakta_db_xrefs:
db n
0 UniRef 242260603
1 SO 102373648
2 UniParc 61464352
3 RefSeq 50046460
4 GO 45014616
5 COG 20112326
6 PFAM 18807208
7 KEGG 16748620
8 EC 15177756
9 BlastRules 223876
10 NCBIFam 42742
11 NCBIProtein 40266
12 VFDB 39150
13 IS 24854
# Stage 4b — pull UniRef mappings for pilot clusters
# Select the db value(s) dynamically from the Stage 4a result: any db whose name contains "uniref".
# The schema docs suggest a single "UniRef" value with the tier embedded in `accession`, e.g. "UniRef50_P12345";
# Stage 4a confirms this, so the tier is split from the accession prefix in Stage 4c.
pilot_gene_clusters.createOrReplaceTempView("pilot_gc_view")
# Try the value(s) most likely to encode UniRef
uniref_db_candidates = [v for v in db_values["db"].tolist() if isinstance(v, str) and "uniref" in v.lower()]
print(f"Candidate db values for UniRef: {uniref_db_candidates}")
if not uniref_db_candidates:
    raise RuntimeError("No UniRef-like db value found in bakta_db_xrefs — investigate before continuing.")
uniref_db_filter = "', '".join(uniref_db_candidates)
uniref_xrefs = spark.sql(f"""
SELECT bx.gene_cluster_id, bx.db, bx.accession
FROM kbase_ke_pangenome.bakta_db_xrefs bx
JOIN pilot_gc_view pg ON bx.gene_cluster_id = pg.gene_cluster_id
WHERE bx.db IN ('{uniref_db_filter}')
""")
uniref_xrefs.cache()
n_uniref_rows = uniref_xrefs.count()
log["n_uniref_xref_rows"] = int(n_uniref_rows)
print(f"UniRef cross-reference rows for pilot clusters: {n_uniref_rows:,}")
Candidate db values for UniRef: ['UniRef']
UniRef cross-reference rows for pilot clusters: 9,177,608
# Stage 4c — split UniRef tiers from the accession prefix
uniref_split = (
uniref_xrefs
.withColumn("uniref_tier", F.regexp_extract(F.col("accession"), r"^(UniRef\d+)_", 1))
.withColumn("uniref_id", F.col("accession"))
)
tier_counts = uniref_split.groupBy("uniref_tier").count().toPandas()
print("UniRef tier distribution (pilot subset):")
print(tier_counts)
log["uniref_tier_counts"] = tier_counts.to_dict(orient="records")
uniref50 = uniref_split.filter(F.col("uniref_tier") == "UniRef50").select(
"gene_cluster_id", F.col("uniref_id").alias("uniref50_id")
)
uniref90 = uniref_split.filter(F.col("uniref_tier") == "UniRef90").select(
"gene_cluster_id", F.col("uniref_id").alias("uniref90_id")
)
uniref50.cache(); uniref90.cache()
log["n_uniref50_xref_rows"] = int(uniref50.count())
log["n_uniref90_xref_rows"] = int(uniref90.count())
UniRef tier distribution (pilot subset):
  uniref_tier    count
0    UniRef50  3842755
1    UniRef90  2976743
2   UniRef100  2358110
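The prefix split in Stage 4c can be checked on a few accessions without Spark. The sketch below applies the same regular expression with Python's `re` module. `uniref_tier` is a hypothetical helper and the accession list is illustrative. Returning an empty string on no match mirrors what Spark's regexp_extract does.

```python
import re
from collections import Counter

TIER_RE = re.compile(r"^(UniRef\d+)_")


def uniref_tier(accession):
    """Return the UniRef tier prefix, or '' when the accession has none."""
    m = TIER_RE.match(accession or "")
    return m.group(1) if m else ""


# Illustrative accessions; the last one has no UniRef prefix.
accessions = [
    "UniRef50_P12345",
    "UniRef90_P12345",
    "UniRef100_P12345",
    "UniRef50_A0A009M360",
    "P12345",
]
tiers = Counter(uniref_tier(a) for a in accessions)
print(tiers)
```

Because the pattern captures all digits before the underscore, UniRef100 accessions form their own tier instead of being folded into another one, which is why three tiers appear in the distribution.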
Stage 5 — eggNOG annotations for COG / KEGG / BRITE / PFAMs
These are the projection-rail and control-membership inputs. We pull them per pilot gene_cluster_id.
eggnog = spark.sql("""
SELECT query_name AS gene_cluster_id,
COG_category, KEGG_ko, KEGG_Pathway, BRITE, PFAMs
FROM kbase_ke_pangenome.eggnog_mapper_annotations
""")
pilot_eggnog = (
eggnog.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
)
pilot_eggnog.cache()
n_pilot_eggnog = pilot_eggnog.count()
log["n_pilot_eggnog_annotated_clusters"] = int(n_pilot_eggnog)
print(f"Pilot gene clusters with eggNOG annotation: {n_pilot_eggnog:,} / {n_pilot_clusters:,} ({100*n_pilot_eggnog/max(n_pilot_clusters,1):.1f} %)")
# v2: Also pull InterProScan domains (the authoritative Pfam annotation source on BERDL).
# 833 M rows total; 146 M Pfam hits across 132.5 M cluster reps (83.8 % coverage).
# We restrict to pilot gene_clusters only and to analysis='Pfam' for the headline control detection.
ips_domains = spark.sql("""
SELECT gene_cluster_id, analysis, signature_acc, signature_desc, ipr_acc, ipr_desc
FROM kbase_ke_pangenome.interproscan_domains
""")
pilot_ips = (
ips_domains.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
)
pilot_ips.cache()
n_pilot_ips = pilot_ips.count()
log["n_pilot_ips_domain_rows"] = int(n_pilot_ips)
n_pilot_ips_clusters = pilot_ips.select("gene_cluster_id").distinct().count()
log["n_pilot_ips_distinct_clusters"] = int(n_pilot_ips_clusters)
print(f"Pilot InterProScan domain rows: {n_pilot_ips:,}; distinct clusters with any IPS hit: {n_pilot_ips_clusters:,} ({100*n_pilot_ips_clusters/max(n_pilot_clusters,1):.1f} %)")
# InterProScan GO terms — for Phase 2 cross-validation of regulatory vs metabolic at GO BP level
ips_go = spark.sql("""
SELECT gene_cluster_id, go_id, go_source, n_supporting_analyses
FROM kbase_ke_pangenome.interproscan_go
""")
pilot_ips_go = (
ips_go.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
)
pilot_ips_go.cache()
log["n_pilot_ips_go_rows"] = int(pilot_ips_go.count())
print(f"Pilot InterProScan GO rows: {log['n_pilot_ips_go_rows']:,}")
# InterProScan pathways — MetaCyc + KEGG, alternative to eggNOG KEGG_Pathway
ips_pathways = spark.sql("""
SELECT gene_cluster_id, pathway_db, pathway_id, n_supporting_analyses
FROM kbase_ke_pangenome.interproscan_pathways
""")
pilot_ips_pathways = (
ips_pathways.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
)
pilot_ips_pathways.cache()
log["n_pilot_ips_pathway_rows"] = int(pilot_ips_pathways.count())
print(f"Pilot InterProScan pathway rows: {log['n_pilot_ips_pathway_rows']:,}")
Pilot gene clusters with eggNOG annotation: 3,417,381 / 4,885,206 (70.0 %)
Pilot InterProScan domain rows: 31,291,050; distinct clusters with any IPS hit: 4,114,075 (84.2 %)
Pilot InterProScan GO rows: 9,865,309
Pilot InterProScan pathway rows: 10,537,763
Stage 6 — Define positive and negative control sets
Positive controls (expected: detectable acquisition signal in clades known for HGT):
- pos_amr — Antimicrobial-resistance genes via bakta_amr.gene_cluster_id join.
- pos_crispr_cas — CRISPR-Cas via Cas-family Pfams (PFAMs containing any of the canonical Cas accessions).
- pos_tcs_hk — Alm 2006 two-component-system histidine kinases via PFAMs containing PF00512 (HisKA) or PF00072 (Response_reg).
Negative controls (expected: low producer + low participation, on-diagonal):
- neg_ribosomal — KEGG_ko in the ribosome KO set.
- neg_trna_synth — KEGG_ko in the aminoacyl-tRNA-synthetase KO set.
- neg_rnap_core — RNA polymerase core via KEGG_ko matching K03040 / K03043 / K03046.
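We assign control class at the UniRef50 level by aggregating gene-cluster-level membership. As implemented in Stage 7, a UniRef50 carries a control flag if any of its mapped pilot gene clusters carries it (F.max over the cluster-level flags). A UniRef50 with more than one flag is assigned a single class by fixed precedence (neg_ribosomal, neg_trna_synth, neg_rnap_core, pos_tcs_hk, pos_amr); UniRef50s with no flag get none.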
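How the aggregation rule treats a mixed UniRef50 is easiest to see on a toy case. The sketch below computes, per control class, the fraction of a UniRef50's gene clusters that carry the class, then reports which classes pass a majority threshold (0.5) and which pass an any-member rule. The helper names, the `min_fraction` parameter, and the toy cluster labels are all hypothetical. This is an illustration of the rule choice, not the notebook's Spark aggregation.

```python
def membership_fractions(cluster_labels, classes):
    """Per class, the fraction of gene clusters carrying that class."""
    n = len(cluster_labels)
    return {
        c: (sum(1 for labels in cluster_labels if c in labels) / n if n else 0.0)
        for c in classes
    }


def classes_passing(fractions, min_fraction):
    """Classes with at least one member and a fraction >= min_fraction."""
    return sorted(c for c, f in fractions.items() if f > 0 and f >= min_fraction)


CLASSES = ["neg_ribosomal", "pos_tcs_hk", "pos_amr"]

# Hypothetical UniRef50 with four gene clusters; one carries two labels, one none.
toy_uniref50 = [{"pos_tcs_hk"}, {"pos_tcs_hk"}, {"pos_tcs_hk", "pos_amr"}, set()]

fractions = membership_fractions(toy_uniref50, CLASSES)
majority = classes_passing(fractions, min_fraction=0.5)
any_member = classes_passing(fractions, min_fraction=0.0)
print(fractions)
print("majority rule:", majority)
print("any-member rule:", any_member)
```

Under the majority threshold only pos_tcs_hk survives. Under the any-member rule the same UniRef50 also carries pos_amr, so a tie-breaking order between classes is needed.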
# v2 control detection: union of eggNOG-name-based and InterProScan-accession-based.
# InterProScan signature_acc is the authoritative Pfam ID; eggNOG PFAMs has domain NAMES
# (per docs/pitfalls.md [snipe_defense_system]) and is fragile but cheap. Take the UNION
# so we maximize recall on each control class.
# --- Pfam accessions for InterProScan-based detection (analysis='Pfam') ---
# TCS histidine kinases — HisKA, HisKA_2, HisKA_3, HWE_HK, HATPase_c, His_kinase, HK_dimer
TCS_HK_PFAM_ACC = {"PF00512", "PF07568", "PF07730", "PF06580", "PF02518", "PF13415", "PF13581"}
# Response regulator + output domains
TCS_RR_PFAM_ACC = {"PF00072", "PF00196", "PF02954", "PF00486"}
# Pfam domain names for eggNOG-based detection (fallback / cross-validation)
TCS_HK_NAMES = {"HisKA", "HisKA_2", "HisKA_3", "HWE_HK", "HATPase_c"}
TCS_RR_NAMES = {"Response_reg", "Trans_reg_C"}
# --- KEGG KO sets for negative controls ---
RIBOSOMAL_KOS = {f"K{n:05d}" for n in (
list(range(2860, 2900)) +
list(range(2950, 2999))
)}
TRNA_SYNTH_KOS = {f"K{n:05d}" for n in range(1866, 1891)}
RNAP_CORE_KOS = {"K03040", "K03043", "K03046"}
# --- Pre-compute eggNOG-based control flags (cheap UDF approach) ---
def has_any_pfam_name(pfam_str, name_set):
    if pfam_str is None or pfam_str in ("-", ""):
        return False
    tokens = pfam_str.replace(";", ",").replace(" ", ",").split(",")
    tokens = {t.strip() for t in tokens if t.strip()}
    return bool(tokens & name_set)
def has_any_ko(ko_str, ko_set):
    if ko_str is None or ko_str in ("-", ""):
        return False
    tokens = [t.replace("ko:", "").strip() for t in ko_str.replace(";", ",").split(",")]
    return any(t in ko_set for t in tokens)
# Declare a boolean return type so each UDF yields a BooleanType column directly;
# the default StringType return type would stringify the Python bool.
ribosomal_udf = F.udf(lambda s: has_any_ko(s, RIBOSOMAL_KOS), "boolean")
trna_synth_udf = F.udf(lambda s: has_any_ko(s, TRNA_SYNTH_KOS), "boolean")
rnap_udf = F.udf(lambda s: has_any_ko(s, RNAP_CORE_KOS), "boolean")
tcs_hk_eggnog_udf = F.udf(lambda s: has_any_pfam_name(s, TCS_HK_NAMES | TCS_RR_NAMES), "boolean")
eggnog_flags = (
    pilot_eggnog
    .withColumn("egg_ribosomal", ribosomal_udf(F.col("KEGG_ko")))
    .withColumn("egg_trna_synth", trna_synth_udf(F.col("KEGG_ko")))
    .withColumn("egg_rnap_core", rnap_udf(F.col("KEGG_ko")))
    .withColumn("egg_tcs_hk", tcs_hk_eggnog_udf(F.col("PFAMs")))
).select("gene_cluster_id", "egg_ribosomal", "egg_trna_synth", "egg_rnap_core", "egg_tcs_hk")
# --- InterProScan-based detection (the v2 authoritative path) ---
# Pfam-only TCS HK: filter to analysis='Pfam' + signature_acc in TCS Pfam set
tcs_acc_filter = "', '".join(sorted(TCS_HK_PFAM_ACC | TCS_RR_PFAM_ACC))
ips_tcs_clusters = spark.sql(f"""
SELECT DISTINCT gene_cluster_id
FROM kbase_ke_pangenome.interproscan_domains
WHERE analysis = 'Pfam'
AND signature_acc IN ('{tcs_acc_filter}')
""")
ips_tcs_pilot = (
ips_tcs_clusters
.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
.withColumn("ips_tcs_hk", F.lit(True))
)
log["n_pilot_ips_tcs_hk_clusters"] = int(ips_tcs_pilot.count())
# Ribosomal proteins — InterPro descriptions (broad pattern catches all subunits)
ips_ribosomal = spark.sql("""
SELECT DISTINCT gene_cluster_id
FROM kbase_ke_pangenome.interproscan_domains
WHERE (LOWER(ipr_desc) LIKE '%ribosomal protein%'
OR LOWER(signature_desc) LIKE '%ribosomal protein%')
""")
ips_ribosomal_pilot = (
ips_ribosomal
.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
.withColumn("ips_ribosomal", F.lit(True))
)
log["n_pilot_ips_ribosomal_clusters"] = int(ips_ribosomal_pilot.count())
# tRNA synthetases
ips_trna_synth = spark.sql("""
SELECT DISTINCT gene_cluster_id
FROM kbase_ke_pangenome.interproscan_domains
WHERE (LOWER(ipr_desc) LIKE '%aminoacyl-trna synthetase%'
OR LOWER(ipr_desc) LIKE '%aminoacyl trna synthetase%'
OR LOWER(signature_desc) LIKE '%aminoacyl-trna synthetase%'
OR LOWER(signature_desc) LIKE '%trna ligase%')
""")
ips_trna_synth_pilot = (
ips_trna_synth
.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
.withColumn("ips_trna_synth", F.lit(True))
)
log["n_pilot_ips_trna_synth_clusters"] = int(ips_trna_synth_pilot.count())
# RNAP core (alpha, beta, beta-prime)
ips_rnap_core = spark.sql("""
SELECT DISTINCT gene_cluster_id
FROM kbase_ke_pangenome.interproscan_domains
WHERE (
(LOWER(ipr_desc) LIKE '%rna polymerase%alpha%' OR LOWER(signature_desc) LIKE '%rna polymerase%alpha%')
OR (LOWER(ipr_desc) LIKE '%rna polymerase%beta%' OR LOWER(signature_desc) LIKE '%rna polymerase%beta%')
)
AND LOWER(ipr_desc) NOT LIKE '%dna-directed%accessory%'
""")
ips_rnap_pilot = (
ips_rnap_core
.join(pilot_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
.withColumn("ips_rnap_core", F.lit(True))
)
log["n_pilot_ips_rnap_clusters"] = int(ips_rnap_pilot.count())
# --- AMR via bakta_amr (unchanged from v1) ---
amr = spark.sql("SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.bakta_amr").withColumn("is_amr", F.lit(True))
# --- Combine: take UNION of eggNOG and InterProScan flags per control class ---
cluster_flags = (
pilot_gene_clusters.select("gene_cluster_id")
.join(eggnog_flags, on="gene_cluster_id", how="left")
.join(ips_tcs_pilot, on="gene_cluster_id", how="left")
.join(ips_ribosomal_pilot, on="gene_cluster_id", how="left")
.join(ips_trna_synth_pilot, on="gene_cluster_id", how="left")
.join(ips_rnap_pilot, on="gene_cluster_id", how="left")
.join(amr, on="gene_cluster_id", how="left")
.na.fill({
"egg_ribosomal": False, "egg_trna_synth": False, "egg_rnap_core": False, "egg_tcs_hk": False,
"ips_tcs_hk": False, "ips_ribosomal": False, "ips_trna_synth": False, "ips_rnap_core": False,
"is_amr": False,
})
.withColumn("is_ribosomal", F.col("egg_ribosomal") | F.col("ips_ribosomal"))
.withColumn("is_trna_synth", F.col("egg_trna_synth") | F.col("ips_trna_synth"))
.withColumn("is_rnap_core", F.col("egg_rnap_core") | F.col("ips_rnap_core"))
.withColumn("is_tcs_hk", F.col("egg_tcs_hk") | F.col("ips_tcs_hk"))
.select("gene_cluster_id", "is_ribosomal", "is_trna_synth", "is_rnap_core", "is_tcs_hk", "is_amr",
"egg_ribosomal", "egg_trna_synth", "egg_rnap_core", "egg_tcs_hk",
"ips_ribosomal", "ips_trna_synth", "ips_rnap_core", "ips_tcs_hk")
)
cluster_flags.cache()
print("Cluster-level control-flag counts (pilot subset; eggNOG vs InterProScan vs UNION):")
cluster_flags.select(
F.sum(F.col("egg_ribosomal").cast("int")).alias("egg_rib"),
F.sum(F.col("ips_ribosomal").cast("int")).alias("ips_rib"),
F.sum(F.col("is_ribosomal").cast("int")).alias("union_rib"),
F.sum(F.col("egg_trna_synth").cast("int")).alias("egg_trna"),
F.sum(F.col("ips_trna_synth").cast("int")).alias("ips_trna"),
F.sum(F.col("is_trna_synth").cast("int")).alias("union_trna"),
F.sum(F.col("egg_rnap_core").cast("int")).alias("egg_rnap"),
F.sum(F.col("ips_rnap_core").cast("int")).alias("ips_rnap"),
F.sum(F.col("is_rnap_core").cast("int")).alias("union_rnap"),
F.sum(F.col("egg_tcs_hk").cast("int")).alias("egg_tcs"),
F.sum(F.col("ips_tcs_hk").cast("int")).alias("ips_tcs"),
F.sum(F.col("is_tcs_hk").cast("int")).alias("union_tcs"),
F.sum(F.col("is_amr").cast("int")).alias("amr"),
).show(truncate=False)
Cluster-level control-flag counts (pilot subset; eggNOG vs InterProScan vs UNION):
+-------+-------+---------+--------+--------+----------+--------+--------+----------+-------+-------+---------+----+
|egg_rib|ips_rib|union_rib|egg_trna|ips_trna|union_trna|egg_rnap|ips_rnap|union_rnap|egg_tcs|ips_tcs|union_tcs|amr |
+-------+-------+---------+--------+--------+----------+--------+--------+----------+-------+-------+---------+----+
|37244  |88694  |89832    |24317   |34960   |35794     |4103    |4553    |4718      |77327  |92101  |99055    |2650|
+-------+-------+---------+--------+--------+----------+--------+--------+----------+-------+-------+---------+----+
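The eggNOG-side helpers in Stage 6 match whole tokens, not substrings, which matters for Pfam names that share a prefix (HisKA and HisKA_2, for example). The sketch below repeats the two helpers as plain Python and runs them on a few annotation strings. The input strings are illustrative examples of the comma-separated, "ko:"-prefixed style and are not rows from the table.

```python
def has_any_pfam_name(pfam_str, name_set):
    """True if any comma/semicolon/space-separated token is in name_set."""
    if pfam_str is None or pfam_str in ("-", ""):
        return False
    tokens = pfam_str.replace(";", ",").replace(" ", ",").split(",")
    tokens = {t.strip() for t in tokens if t.strip()}
    return bool(tokens & name_set)


def has_any_ko(ko_str, ko_set):
    """True if any KO token (with the 'ko:' prefix stripped) is in ko_set."""
    if ko_str is None or ko_str in ("-", ""):
        return False
    tokens = [t.replace("ko:", "").strip() for t in ko_str.replace(";", ",").split(",")]
    return any(t in ko_set for t in tokens)


RNAP_CORE_KOS = {"K03040", "K03043", "K03046"}

checks = {
    "rnap_hit": has_any_ko("ko:K03040,ko:K03043", RNAP_CORE_KOS),
    "rnap_miss": has_any_ko("ko:K16137", RNAP_CORE_KOS),
    "missing_annotation": has_any_ko("-", RNAP_CORE_KOS),
    "pfam_exact": has_any_pfam_name("HisKA,HATPase_c", {"HisKA"}),
    "pfam_prefix_only": has_any_pfam_name("HisKA_2", {"HisKA"}),
}
print(checks)
```

The last check returns False: a cluster annotated only with HisKA_2 is matched only if that exact name is in the set, which is why TCS_HK_NAMES lists each variant separately.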
Stage 7 — Aggregate to UniRef50 level + COG-stratified sample with control inclusion
# Aggregate cluster-level control flags + COG / KEGG annotation up to UniRef50.
# Per UniRef50: dominant COG (mode), dominant KEGG_ko (mode), dominant BRITE B-level (mode),
# dominant InterPro IPR description (mode), dominant GO BP, dominant MetaCyc pathway, control flag (any).
uniref50_eggnog = (
uniref50
.join(pilot_eggnog, on="gene_cluster_id", how="inner")
.join(cluster_flags, on="gene_cluster_id", how="left")
.na.fill({"is_ribosomal": False, "is_trna_synth": False, "is_rnap_core": False,
"is_tcs_hk": False, "is_amr": False})
)
uniref50_eggnog.cache()
uniref50_agg = uniref50_eggnog.groupBy("uniref50_id").agg(
F.count("*").alias("n_gene_clusters_supporting"),
F.collect_list("COG_category").alias("cogs"),
F.collect_list("KEGG_ko").alias("kos"),
F.collect_list("BRITE").alias("brites"),
F.max(F.col("is_ribosomal").cast("int")).alias("is_ribosomal"),
F.max(F.col("is_trna_synth").cast("int")).alias("is_trna_synth"),
F.max(F.col("is_rnap_core").cast("int")).alias("is_rnap_core"),
F.max(F.col("is_tcs_hk").cast("int")).alias("is_tcs_hk"),
F.max(F.col("is_amr").cast("int")).alias("is_amr"),
)
def _mode(lst):
    if not lst:
        return None
    cleaned = [x for x in lst if x not in (None, "", "-")]
    if not cleaned:
        return None
    counts = {}
    for x in cleaned:
        first = x[0] if isinstance(x, str) else x
        counts[first] = counts.get(first, 0) + 1
    return max(counts.items(), key=lambda kv: kv[1])[0]
def _mode_str(lst):
    """Mode for full strings (not first-letter) — for IPR desc, GO id, pathway id."""
    if not lst:
        return None
    cleaned = [x for x in lst if x not in (None, "", "-")]
    if not cleaned:
        return None
    counts = {}
    for x in cleaned:
        counts[x] = counts.get(x, 0) + 1
    return max(counts.items(), key=lambda kv: kv[1])[0]
_mode_udf = F.udf(_mode, StringType())
_mode_str_udf = F.udf(_mode_str, StringType())
uniref50_summary = (
uniref50_agg
.withColumn("dominant_cog_category", _mode_udf(F.col("cogs")))
.withColumn("dominant_kegg_ko", _mode_str_udf(F.col("kos")))
.withColumn("dominant_brite_b", _mode_str_udf(F.col("brites")))
.drop("cogs", "kos", "brites")
)
# v2: enrich with InterPro IPR description (dominant per UniRef50) — taken from interproscan_domains
ips_uniref50 = (
uniref50
.join(pilot_ips.filter(F.col("ipr_desc").isNotNull() & (F.col("ipr_desc") != "")), on="gene_cluster_id", how="inner")
.groupBy("uniref50_id")
.agg(F.collect_list("ipr_desc").alias("ipr_descs"))
.withColumn("dominant_ipr_desc", _mode_str_udf(F.col("ipr_descs")))
.drop("ipr_descs")
)
uniref50_summary = uniref50_summary.join(ips_uniref50, on="uniref50_id", how="left")
# v2: enrich with dominant GO BP id (go_source IN ('InterPro', 'PANTHER'); no BP filter — all GO)
ips_go_uniref50 = (
uniref50
.join(pilot_ips_go, on="gene_cluster_id", how="inner")
.groupBy("uniref50_id")
.agg(F.collect_list("go_id").alias("go_ids"))
.withColumn("dominant_go_id", _mode_str_udf(F.col("go_ids")))
.drop("go_ids")
)
uniref50_summary = uniref50_summary.join(ips_go_uniref50, on="uniref50_id", how="left")
# v2: enrich with dominant MetaCyc pathway (preferred over KEGG since interproscan_pathways favors MetaCyc)
ips_pathway_uniref50 = (
uniref50
.join(pilot_ips_pathways, on="gene_cluster_id", how="inner")
.groupBy("uniref50_id")
.agg(
F.collect_list("pathway_id").alias("pathway_ids"),
F.collect_list("pathway_db").alias("pathway_dbs"),
)
.withColumn("dominant_pathway_id", _mode_str_udf(F.col("pathway_ids")))
.withColumn("dominant_pathway_db", _mode_str_udf(F.col("pathway_dbs")))
.drop("pathway_ids", "pathway_dbs")
)
uniref50_summary = uniref50_summary.join(ips_pathway_uniref50, on="uniref50_id", how="left")
uniref50_summary.cache()
n_uniref50_total = uniref50_summary.count()
log["n_unique_uniref50_in_pilot_pool"] = int(n_uniref50_total)
print(f"Distinct UniRef50 IDs in pilot pool: {n_uniref50_total:,}")
uniref50_summary.show(5, truncate=80)
Distinct UniRef50 IDs in pilot pool: 1,548,871
+-------------------+--------------------------+------------+-------------+------------+---------+------+---------------------+----------------+-----------------------------------------------+-------------------------------------------------------+--------------+-------------------+-------------------+
| uniref50_id|n_gene_clusters_supporting|is_ribosomal|is_trna_synth|is_rnap_core|is_tcs_hk|is_amr|dominant_cog_category|dominant_kegg_ko| dominant_brite_b| dominant_ipr_desc|dominant_go_id|dominant_pathway_id|dominant_pathway_db|
+-------------------+--------------------------+------------+-------------+------------+---------+------+---------------------+----------------+-----------------------------------------------+-------------------------------------------------------+--------------+-------------------+-------------------+
|UniRef50_A0A009M360| 1| 0| 0| 0| 0| 0| K| ko:K16137| ko00000,ko03000| DNA-binding HTH domain, TetR-type| GO:0003677| NULL| NULL|
|UniRef50_A0A011NUG4| 2| 0| 0| 0| 0| 0| O| NULL| NULL| HSP20-like chaperone| GO:0009408| NULL| NULL|
|UniRef50_A0A011NXE7| 5| 0| 0| 0| 0| 0| L| ko:K06896| ko00000,ko00001,ko01000| CBM21 (carbohydrate binding type-21) domain| GO:0003824| NULL| NULL|
|UniRef50_A0A011Q210| 1| 0| 0| 0| 0| 0| E| ko:K04487|ko00000,ko00001,ko01000,ko02048,ko03016,ko03029|Pyridoxal phosphate-dependent transferase, small domain| NULL| NULL| NULL|
|UniRef50_A0A011Q8H8| 1| 0| 0| 0| 0| 0| NULL| NULL| NULL| NULL| NULL| NULL| NULL|
+-------------------+--------------------------+------------+-------------+------------+---------+------+---------------------+----------------+-----------------------------------------------+-------------------------------------------------------+--------------+-------------------+-------------------+
only showing top 5 rows
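The two mode helpers in Stage 7 differ in one way that affects the dominant_cog_category column: the COG mode counts only the first letter of each value, so a multi-letter category such as "KT" is counted as "K". The sketch below restates both helpers under hypothetical names (`first_letter_mode`, `full_string_mode`) and runs them on made-up lists to make the difference concrete.

```python
def _clean(values):
    """Drop missing-annotation markers before counting."""
    return [x for x in (values or []) if x not in (None, "", "-")]


def first_letter_mode(values):
    """Most frequent first character (COG-style single-letter category)."""
    counts = {}
    for x in _clean(values):
        key = x[0] if isinstance(x, str) else x
        counts[key] = counts.get(key, 0) + 1
    return max(counts.items(), key=lambda kv: kv[1])[0] if counts else None


def full_string_mode(values):
    """Most frequent full string (KO, BRITE, IPR description, GO id)."""
    counts = {}
    for x in _clean(values):
        counts[x] = counts.get(x, 0) + 1
    return max(counts.items(), key=lambda kv: kv[1])[0] if counts else None


cogs = ["KT", "K", "T", "-", None]  # hypothetical COG_category values
cog_mode = first_letter_mode(cogs)
cog_full = full_string_mode(["KT", "K", "KT"])
print(cog_mode, cog_full, first_letter_mode(["-", None]))
```

On ties, `max` returns the first key that reached the top count, so the result depends on list order. Since collect_list does not guarantee an order, tied UniRef50s may not resolve identically across runs.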
# Pull pilot pool to driver for sampling (≤ low millions of rows; toPandas is fine here)
uniref50_pool = uniref50_summary.toPandas()
# Assign control_class
def _control_class(row):
    if row["is_ribosomal"]: return "neg_ribosomal"
    if row["is_trna_synth"]: return "neg_trna_synth"
    if row["is_rnap_core"]: return "neg_rnap_core"
    if row["is_tcs_hk"]: return "pos_tcs_hk"
    if row["is_amr"]: return "pos_amr"
    return "none"
uniref50_pool["control_class"] = uniref50_pool.apply(_control_class, axis=1)
control_pool = uniref50_pool[uniref50_pool["control_class"] != "none"]
noncontrol_pool = uniref50_pool[uniref50_pool["control_class"] == "none"]
print("Control-pool counts before sampling:")
print(control_pool["control_class"].value_counts())
log["control_pool_class_counts"] = control_pool["control_class"].value_counts().to_dict()
# v2 cap: each control class → at most CONTROL_CAP_PER_CLASS UniRef50s, COG-stratified within class.
CONTROL_CAP_PER_CLASS = 200
log["control_cap_per_class"] = CONTROL_CAP_PER_CLASS
# v3 (2026-04-26 post-NB02 sparsity audit): add a 6th class "natural_expansion" — UniRef50s
# with documented paralog signal in pilot species. These are UniRefs where:
# max paralog count across any pilot species ≥ 3 (within-species expansion)
# AND number of pilot species containing the UniRef ≥ 5 (cross-species breadth)
# Provides a denser, signal-rich substrate for null-model validation in NB02.
NATURAL_EXPANSION_CAP = 200
NATURAL_MIN_PARALOG = 3
NATURAL_MIN_SPECIES = 5
log["natural_expansion_cap"] = NATURAL_EXPANSION_CAP
log["natural_min_paralog"] = NATURAL_MIN_PARALOG
log["natural_min_species"] = NATURAL_MIN_SPECIES
def _cog_stratified_sample(df, n, seed):
    """COG-stratified sample of size n from df (proportional allocation per COG category)."""
    if n <= 0 or len(df) == 0:
        return df.iloc[0:0]
    if n >= len(df):
        return df
    # Share of each COG category in df; rows with no category are pooled as "NA".
    cog_dist = df["dominant_cog_category"].fillna("NA").value_counts(normalize=True)
    alloc = (cog_dist * n).round().astype(int)
    # Rounding can leave the total off by a few: top up the most under-allocated category ...
    while alloc.sum() < n:
        idx = (cog_dist * n - alloc).idxmax()
        alloc[idx] += 1
    # ... or trim the most over-allocated one, until the quotas sum to exactly n.
    while alloc.sum() > n:
        idx = (alloc - cog_dist * n).idxmax()
        alloc[idx] -= 1
    pieces = []
    for cog, k in alloc.items():
        sub = df[df["dominant_cog_category"].fillna("NA") == cog]
        if k >= len(sub):
            pieces.append(sub)
        elif k > 0:
            pieces.append(sub.sample(n=k, random_state=seed))
    return pd.concat(pieces, ignore_index=True) if pieces else df.iloc[0:0]
# --- Build natural_expansion candidate pool (Spark) ---
# Compute per-(uniref50, species) paralog count, then aggregate to per-uniref signal
paralog_signal = (
    uniref50_eggnog
    .join(pilot_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"), on="gene_cluster_id")
    .groupBy("uniref50_id", "gtdb_species_clade_id")
    .agg(F.count("*").alias("n_clusters_in_species"))
    .groupBy("uniref50_id")
    .agg(
        F.max("n_clusters_in_species").alias("max_paralog_in_any_species"),
        F.countDistinct("gtdb_species_clade_id").alias("n_species_with_uref"),
    )
    .filter(
        (F.col("max_paralog_in_any_species") >= NATURAL_MIN_PARALOG)
        & (F.col("n_species_with_uref") >= NATURAL_MIN_SPECIES)
    )
)
natural_pool_pdf = paralog_signal.toPandas()
log["n_natural_expansion_pool"] = int(len(natural_pool_pdf))
print(f"Natural-expansion pool (max_paralog ≥ {NATURAL_MIN_PARALOG} AND n_species ≥ {NATURAL_MIN_SPECIES}): {len(natural_pool_pdf):,} UniRef50s")
# Mark the natural-expansion subset of uniref50_pool — exclude UniRefs already in control classes
natural_candidate_ids = set(natural_pool_pdf["uniref50_id"].tolist())
natural_pool = uniref50_pool[
    (uniref50_pool["uniref50_id"].isin(natural_candidate_ids))
    & (uniref50_pool["control_class"] == "none")
].copy()
natural_pool["control_class"] = "natural_expansion"
log["n_natural_expansion_after_dedup"] = int(len(natural_pool))
print(f" after deduplicating against existing controls: {len(natural_pool):,} candidates")
# --- Sample within each control class (cap at 200) ---
control_sampled_pieces = []
for cls in sorted(control_pool["control_class"].unique()):
    cls_pool = control_pool[control_pool["control_class"] == cls]
    cls_sampled = _cog_stratified_sample(cls_pool, min(CONTROL_CAP_PER_CLASS, len(cls_pool)), seed=RNG_SEED)
    control_sampled_pieces.append(cls_sampled)
control_sampled = pd.concat(control_sampled_pieces, ignore_index=True) if control_sampled_pieces else pd.DataFrame()
# --- Sample natural_expansion class ---
natural_sampled = _cog_stratified_sample(natural_pool, min(NATURAL_EXPANSION_CAP, len(natural_pool)), seed=RNG_SEED)
n_control = len(control_sampled)
n_natural = len(natural_sampled)
log["control_sampled_class_counts"] = control_sampled["control_class"].value_counts().to_dict() if n_control > 0 else {}
log["n_natural_expansion_sampled"] = int(n_natural)
# --- Fill remainder of TARGET_N_UNIREF50 from non-control non-natural pool with COG stratification ---
# (This ensures we still hit the 1000 target if controls + natural < 1000; with current caps,
# 5×200 + 200 = 1200 already, so n_remaining is typically 0 and we keep the larger pilot.)
already_picked_ids = set(control_sampled["uniref50_id"]).union(set(natural_sampled["uniref50_id"]))
remaining_pool = noncontrol_pool[~noncontrol_pool["uniref50_id"].isin(already_picked_ids) & (~noncontrol_pool["uniref50_id"].isin(natural_candidate_ids))]
n_remaining = max(0, TARGET_N_UNIREF50 - n_control - n_natural)
noncontrol_sampled = _cog_stratified_sample(remaining_pool, n_remaining, seed=RNG_SEED)
pilot_uniref50 = pd.concat([control_sampled, natural_sampled, noncontrol_sampled], ignore_index=True)
log["n_pilot_uniref50"] = int(len(pilot_uniref50))
log["pilot_uniref50_class_counts"] = pilot_uniref50["control_class"].value_counts().to_dict()
print(f"\nPilot UniRef50s: {len(pilot_uniref50):,} "
      f"(controls: {n_control:,}; natural_expansion: {n_natural:,}; non-control: {len(noncontrol_sampled):,})")
print(pilot_uniref50["control_class"].value_counts())
Control-pool counts before sampling:
control_class
pos_tcs_hk        43217
neg_ribosomal     19389
neg_trna_synth     8514
pos_amr             958
neg_rnap_core       770
Name: count, dtype: int64
Natural-expansion pool (max_paralog ≥ 3 AND n_species ≥ 5): 5,537 UniRef50s
  after deduplicating against existing controls: 5,239 candidates

Pilot UniRef50s: 1,200 (controls: 1,000; natural_expansion: 200; non-control: 0)
control_class
neg_ribosomal        200
neg_rnap_core        200
neg_trna_synth       200
pos_amr              200
pos_tcs_hk           200
natural_expansion    200
Name: count, dtype: int64
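The quota step inside `_cog_stratified_sample` can be checked in isolation. The following is a minimal sketch on made-up category labels, not the notebook's data; `allocate_quota` and `toy` are hypothetical names introduced only for this illustration. It shows that proportional quotas, after the round-then-adjust loops, sum to exactly the requested sample size.

```python
import pandas as pd


def allocate_quota(labels: pd.Series, n: int) -> pd.Series:
    """Proportional per-category quotas that sum to exactly n."""
    # Missing labels are pooled into a single "NA" stratum, as in the notebook.
    share = labels.fillna("NA").value_counts(normalize=True)
    exact = share * n
    quota = exact.round().astype(int)
    # Top up the category that is furthest below its exact share.
    while quota.sum() < n:
        quota[(exact - quota).idxmax()] += 1
    # Trim the category that is furthest above its exact share.
    while quota.sum() > n:
        quota[(quota - exact).idxmax()] -= 1
    return quota


# Toy labels: 50 K, 30 L, 15 E, 5 missing (illustrative only).
toy = pd.Series(["K"] * 50 + ["L"] * 30 + ["E"] * 15 + [None] * 5)
quota = allocate_quota(toy, 10)
print(quota.to_dict())
print("total:", int(quota.sum()))
```

Because the quotas are proportional, small categories can receive a quota of zero; this is one reason the notebook samples each control class separately rather than relying on stratification alone to keep rare classes in the pilot.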
Stage 8 — Pre-flight control coverage check
Confirm ≥ 80 % of pilot species have ≥ 1 entry in each control set. If a control set is sparse (<80 %), Phase 1A null-model validation downstream will be unreliable for that control. Flag any failures here.
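# Pre-flight: report TWO coverage metrics per control class.
#   pool_coverage    = % of pilot species with ≥1 UniRef50 in the FULL unsampled control pool
#                      (biological question: "is this control class biologically present in pilot species?")
#   sampled_coverage = % of pilot species with ≥1 UniRef50 in the 200-cap SAMPLED subset
#                      (statistical question: "will null-model fitting have enough per-species data?")
# Pool coverage is the canonical pre-flight; sampled coverage is informational.
# NOTE: natural_expansion is labeled only in natural_pool (a copy), never in uniref50_pool,
# so it has no pool rows below: its n_species_pool is 0 by construction and its
# pool_passes_80pct flag is not meaningful. Read sampled coverage for that class.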
pilot_uniref50_ids = pilot_uniref50["uniref50_id"].tolist()
# v2 fix: rebuild from raw lists (avoid pyarrow ChunkedArray after pd.concat)
pilot_uniref50_spark = spark.createDataFrame(
    pd.DataFrame({
        "uniref50_id": pilot_uniref50["uniref50_id"].tolist(),
        "control_class": pilot_uniref50["control_class"].tolist(),
    })
)
# --- POOL coverage: presence of any pool UniRef50 in any pilot species, by class ---
pool_classes = uniref50_pool[["uniref50_id", "control_class"]].copy()
pool_classes_spark = spark.createDataFrame(
    pd.DataFrame({
        "uniref50_id": pool_classes["uniref50_id"].tolist(),
        "control_class_pool": pool_classes["control_class"].tolist(),
    })
)
pool_presence = (
    uniref50.join(pool_classes_spark, on="uniref50_id", how="inner")
    .filter(F.col("control_class_pool") != "none")
    .join(pilot_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"), on="gene_cluster_id", how="inner")
    .select("gtdb_species_clade_id", "control_class_pool")
    .distinct()
)
pool_coverage = (
    pool_presence.groupBy("control_class_pool")
    .agg(F.countDistinct("gtdb_species_clade_id").alias("n_species_pool"))
    .toPandas()
    .rename(columns={"control_class_pool": "control_class"})
)
# --- SAMPLED coverage: presence of sampled-200 UniRef50s in pilot species ---
sampled_presence = (
    uniref50.join(pilot_uniref50_spark, on="uniref50_id", how="inner")
    .join(pilot_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"), on="gene_cluster_id", how="inner")
    .select("gtdb_species_clade_id", "control_class")
    .distinct()
)
sampled_coverage = (
    sampled_presence.groupBy("control_class")
    .agg(F.countDistinct("gtdb_species_clade_id").alias("n_species_sampled"))
    .toPandas()
)
coverage = pool_coverage.merge(sampled_coverage, on="control_class", how="outer").fillna(0)
coverage["n_species_pool"] = coverage["n_species_pool"].astype(int)
coverage["n_species_sampled"] = coverage["n_species_sampled"].astype(int)
coverage["n_pilot_species"] = log["n_pilot_species"]
coverage["pool_coverage_fraction"] = coverage["n_species_pool"] / coverage["n_pilot_species"]
coverage["sampled_coverage_fraction"] = coverage["n_species_sampled"] / coverage["n_pilot_species"]
coverage["pool_passes_80pct"] = coverage["pool_coverage_fraction"] >= 0.80
print("Pre-flight coverage (pool = biological; sampled = methodological):")
print(coverage.to_string(index=False))
log["preflight_coverage"] = coverage.to_dict(orient="records")
preflight_failed = coverage[~coverage["pool_passes_80pct"]]
if len(preflight_failed) > 0:
    print("\n*** PRE-FLIGHT POOL-COVERAGE WARNING — biological control sparsity ***")
    print(preflight_failed[["control_class", "pool_coverage_fraction"]])
    print("These control classes are not present in ≥80 % of pilot species at the biological level.")
    print("Acceptable for AMR (clade-restricted) but suspicious for housekeeping negatives.")
Pre-flight coverage (pool = biological; sampled = methodological):
control_class n_species_pool n_species_sampled n_pilot_species pool_coverage_fraction sampled_coverage_fraction pool_passes_80pct
natural_expansion 0 736 1000 0.000 0.736 False
neg_ribosomal 1000 732 1000 1.000 0.732 True
neg_rnap_core 1000 712 1000 1.000 0.712 True
neg_trna_synth 1000 559 1000 1.000 0.559 True
pos_amr 688 331 1000 0.688 0.331 False
pos_tcs_hk 996 264 1000 0.996 0.264 True
*** PRE-FLIGHT POOL-COVERAGE WARNING — biological control sparsity ***
control_class pool_coverage_fraction
0 natural_expansion 0.000
4 pos_amr 0.688
These control classes are not present in ≥80 % of pilot species at the biological level.
Acceptable for AMR (clade-restricted) but suspicious for housekeeping negatives.
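The coverage arithmetic above can be reproduced on a few rows of toy data. This is a hedged sketch in plain pandas, not the notebook's Spark pipeline: the five species, the `presence` table, and the `expected_classes` list are all invented for illustration. It also shows how a class that never appears in the labeled presence table ends up with a count of 0 once missing entries are filled, which is the mechanism behind the `natural_expansion` row with `n_species_pool = 0` in the output above.

```python
import pandas as pd

# Hypothetical toy data: one row per (species, control class) actually observed.
presence = pd.DataFrame({
    "gtdb_species_clade_id": ["s1", "s2", "s3", "s4", "s5", "s1", "s2", "s3"],
    "control_class": ["neg_ribosomal"] * 5 + ["pos_amr"] * 3,
})
n_species = 5
# natural_expansion is expected but has no labeled rows in `presence`.
expected_classes = ["neg_ribosomal", "pos_amr", "natural_expansion"]

coverage = (
    presence.drop_duplicates()
    .groupby("control_class")["gtdb_species_clade_id"]
    .nunique()
    .reindex(expected_classes, fill_value=0)  # absent class -> 0 species
    .rename("n_species")
    .to_frame()
)
coverage["coverage_fraction"] = coverage["n_species"] / n_species
coverage["passes_80pct"] = coverage["coverage_fraction"] >= 0.80
print(coverage)
```

A zero produced this way reflects missing labels rather than biological absence, so it should be read differently from a genuinely sparse class such as `pos_amr`.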
Stage 9 — Per-genome annotation density (D2 input)
For each pilot representative genome: the fraction of its gene clusters with at least one eggNOG annotation. This is the D2 nuisance covariate that NB02 regresses out of the producer / participation scores (the OLS residualization).
# Per-species annotated fraction = (annotated clusters in species) / (total clusters in species)
# Uses pilot gene clusters joined to eggnog presence
pilot_eggnog_ids = pilot_eggnog.select("gene_cluster_id").distinct()
ann_density = (
    pilot_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id")
    .join(pilot_eggnog_ids.withColumn("is_annotated", F.lit(1)), on="gene_cluster_id", how="left")
    .na.fill({"is_annotated": 0})
    .groupBy("gtdb_species_clade_id")
    .agg(
        F.count("*").alias("n_clusters_total"),
        F.sum("is_annotated").alias("n_clusters_annotated"),
    )
    .withColumn("annotated_fraction", F.col("n_clusters_annotated") / F.col("n_clusters_total"))
).toPandas()
pilot_species = pilot_species.merge(ann_density, on="gtdb_species_clade_id", how="left")
print("Annotated fraction summary:")
print(pilot_species["annotated_fraction"].describe())
Annotated fraction summary:
count    1000.000000
mean        0.724550
std         0.126357
min         0.276753
25%         0.635311
50%         0.729249
75%         0.829652
max         0.949606
Name: annotated_fraction, dtype: float64
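The same left-join, fill, and aggregate pattern can be traced by hand on a tiny example. The sketch below uses pandas instead of Spark and entirely made-up cluster and species IDs; it is meant only to make the definition of `annotated_fraction` concrete.

```python
import pandas as pd

# Hypothetical toy inputs: five gene clusters across two species.
clusters = pd.DataFrame({
    "gene_cluster_id": ["g1", "g2", "g3", "g4", "g5"],
    "gtdb_species_clade_id": ["sA", "sA", "sA", "sB", "sB"],
})
# Clusters with at least one annotation row (distinct IDs only).
annotated = pd.DataFrame({"gene_cluster_id": ["g1", "g2", "g4"]}).assign(is_annotated=1)

density = (
    clusters.merge(annotated, on="gene_cluster_id", how="left")
    .fillna({"is_annotated": 0})  # unmatched clusters count as unannotated
    .groupby("gtdb_species_clade_id")
    .agg(
        n_clusters_total=("gene_cluster_id", "count"),
        n_clusters_annotated=("is_annotated", "sum"),
    )
)
density["annotated_fraction"] = density["n_clusters_annotated"] / density["n_clusters_total"]
print(density)
```

Deduplicating the annotated IDs before the join matters: if a cluster had several annotation rows, the left join would otherwise inflate both the numerator and the denominator.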
Stage 10 — Materialize outputs
# 10a — pilot species
species_out = pilot_species[[
    "gtdb_species_clade_id", "representative_genome_id",
    "domain", "phylum", "class", "order", "family", "genus", "species_rank",
    "no_genomes", "no_gene_clusters", "no_core", "no_aux_genome", "no_singleton_gene_clusters",
    "checkm_completeness", "checkm_contamination", "genome_size", "gc_percentage", "protein_count",
    "n_clusters_total", "n_clusters_annotated", "annotated_fraction",
]]
species_out.to_csv(DATA_DIR / "p1a_pilot_species.tsv", sep="\t", index=False)
print(f"Wrote p1a_pilot_species.tsv: {len(species_out):,} rows")
# 10b — pilot UniRef50
uniref50_cols = [
    "uniref50_id", "dominant_cog_category", "dominant_kegg_ko", "dominant_brite_b",
    "dominant_ipr_desc", "dominant_go_id", "dominant_pathway_id", "dominant_pathway_db",
    "n_gene_clusters_supporting", "control_class",
    "is_ribosomal", "is_trna_synth", "is_rnap_core",
    "is_tcs_hk", "is_amr",
]
# Tolerate missing optional cols if upstream join produced no row for some UniRefs
for c_ in uniref50_cols:
    if c_ not in pilot_uniref50.columns:
        pilot_uniref50[c_] = None
uniref50_out = pilot_uniref50[uniref50_cols]
uniref50_out.to_csv(DATA_DIR / "p1a_pilot_uniref50.tsv", sep="\t", index=False)
print(f"Wrote p1a_pilot_uniref50.tsv: {len(uniref50_out):,} rows")
# 10c — pilot extract: long-format (species × UniRef50) with copy-number.
# Compute via Spark, then materialize to driver via toPandas() and write parquet locally.
# (Spark Connect cannot write to driver-local paths reliably; collecting to driver and using
# pandas.to_parquet avoids the executor-vs-driver filesystem split.)
uniref50_pilot_view = spark.createDataFrame(
    pd.DataFrame({"uniref50_id": pilot_uniref50_ids})
)
uniref50_filtered = uniref50.join(uniref50_pilot_view, on="uniref50_id", how="inner")
extract = (
    uniref50_filtered
    .join(pilot_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"), on="gene_cluster_id")
    .join(uniref90.select("gene_cluster_id", "uniref90_id"), on="gene_cluster_id", how="left")
    .groupBy("gtdb_species_clade_id", "uniref50_id")
    .agg(
        F.countDistinct("uniref90_id").alias("n_uniref90_present"),
        F.countDistinct("gene_cluster_id").alias("n_gene_clusters"),
    )
    # Always True here: only observed (species, UniRef50) pairs survive the joins,
    # so absent pairs are implicit (no row) rather than is_present = False.
    .withColumn("is_present", (F.col("n_gene_clusters") > 0).cast("boolean"))
)
extract_pdf = extract.toPandas()
# Spark Connect attaches PlanMetrics in pandas DataFrame's `attrs`; pyarrow tries to
# embed those attrs as JSON metadata in the parquet file and fails. Strip attrs by
# rebuilding the DataFrame from raw column values.
extract_pdf_clean = pd.DataFrame(
    {col: extract_pdf[col].to_numpy() for col in extract_pdf.columns}
)
extract_out_path = str(DATA_DIR / "p1a_pilot_extract.parquet")
# Defensive cleanup: a prior failed Spark write may have left a part-file directory
# at this path. Single-file pandas writes will fail with IsADirectoryError otherwise.
import shutil as _sh
if os.path.isdir(extract_out_path):
    _sh.rmtree(extract_out_path)
elif os.path.isfile(extract_out_path):
    os.remove(extract_out_path)
extract_pdf_clean.to_parquet(extract_out_path, index=False)
log["n_extract_rows"] = int(len(extract_pdf_clean))
print(f"Wrote p1a_pilot_extract.parquet: {len(extract_pdf_clean):,} rows")
# 10d — pre-flight coverage diagnostic
coverage.to_csv(DATA_DIR / "p1a_preflight_coverage.tsv", sep="\t", index=False)
print("Wrote p1a_preflight_coverage.tsv")
# 10e — extraction log
log["completed_utc"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
with open(DATA_DIR / "p1a_extraction_log.json", "w") as f:
    json.dump(log, f, indent=2, default=str)
print("Wrote p1a_extraction_log.json")
print("\nAll Stage-10 outputs materialized.")
Wrote p1a_pilot_species.tsv: 1,000 rows
Wrote p1a_pilot_uniref50.tsv: 1,200 rows
Wrote p1a_pilot_extract.parquet: 6,638 rows
Wrote p1a_preflight_coverage.tsv
Wrote p1a_extraction_log.json

All Stage-10 outputs materialized.
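The `attrs`-stripping step in 10c can be seen on its own. In this minimal sketch the `raw` frame and the `metrics` payload are stand-ins invented for illustration (they are not the actual Spark Connect objects); the point is only that rebuilding a DataFrame from raw column arrays carries the data over while leaving `attrs` empty.

```python
import pandas as pd

# Stand-in for a frame returned by toPandas(); the attrs payload is hypothetical.
raw = pd.DataFrame({"uniref50_id": ["U1", "U2"], "n_gene_clusters": [1, 3]})
raw.attrs["metrics"] = object()  # an arbitrary object with no JSON representation

# Rebuild column by column from plain arrays: values are kept, attrs are not.
clean = pd.DataFrame({col: raw[col].to_numpy() for col in raw.columns})

print("raw attrs keys:", list(raw.attrs))
print("clean attrs:", clean.attrs)
```

An alternative with a similar effect would be to reset the metadata in place by assigning an empty dict to `attrs`; the rebuild used in the notebook has the side benefit of also dropping any non-NumPy-backed column storage.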
Summary
Phase 1A pilot extract is complete. Outputs are written under data/:
- p1a_pilot_species.tsv: 1K species reps with taxonomy + quality + annotation density
- p1a_pilot_uniref50.tsv: 1,200 UniRef50s (five control classes plus natural_expansion, 200 each) with COG / KEGG / BRITE + control-class labels
- p1a_pilot_extract.parquet: long-format (species × UniRef50) presence/copy-number matrix
- p1a_preflight_coverage.tsv: control-set coverage diagnostic with 80 % pass/fail flag
- p1a_extraction_log.json: full extraction log: seed, sampling parameters, all row counts, timestamps, discovered bakta_db_xrefs.db values
Next: NB02 (02_p1a_null_model_construction.ipynb) reads p1a_pilot_extract.parquet to construct and calibrate the producer null (clade-matched neutral-family) and consumer null (phyletic-distribution permutation), then validates them against the Alm 2006 TCS HK reproduction at pilot scale.
If the pre-flight coverage check above flagged any control set as <80 % covered, NB02's gate decision will need to factor that in — a control set that is too sparse cannot anchor null-model validation.
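Because the extract is stored in long format, it holds only the (species, UniRef50) pairs that were actually observed: 6,638 rows out of 1,000 × 1,200 possible pairs. How NB02 consumes the file is not shown in this notebook; the sketch below is one plausible way, under that assumption, to expand a long table into a dense species × UniRef50 copy-number matrix with explicit zeros. The toy rows and the `all_species` / `all_uniref50` lists are invented for illustration.

```python
import pandas as pd

# Hypothetical long-format rows in the same shape as p1a_pilot_extract.parquet.
extract = pd.DataFrame({
    "gtdb_species_clade_id": ["sA", "sA", "sB"],
    "uniref50_id": ["U1", "U2", "U1"],
    "n_gene_clusters": [2, 1, 4],
})
# Full axes come from the species and UniRef50 tables, not from the extract itself,
# so species or UniRef50s with no observed pair still get a row or column.
all_species = ["sA", "sB", "sC"]
all_uniref50 = ["U1", "U2", "U3"]

matrix = (
    extract.pivot(index="gtdb_species_clade_id", columns="uniref50_id", values="n_gene_clusters")
    .reindex(index=all_species, columns=all_uniref50)
    .fillna(0)  # unobserved pair -> copy number 0
    .astype(int)
)
print(matrix)
```

Taking the axes from p1a_pilot_species.tsv and p1a_pilot_uniref50.tsv, rather than from the extract, keeps all-zero species and UniRef50s in the matrix, which matters when a sparse control class is being evaluated.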