05 P1B Full Data Extraction
Jupyter notebook from the Gene Function Ecological Agora project.
NB05: Phase 1B Full GTDB Data Extraction
Project: Gene Function Ecological Agora — Innovation Atlas Across the Bacterial Tree
Phase: 1B — Full GTDB scale (post-Phase-1A PASS_WITH_REVISION)
Purpose: Extract the full Phase 1B substrate (all bacterial GTDB representatives × targeted UniRef50 pool) for downstream null-model construction (NB06), atlas computation (NB07), and Phase 1B gate decision (NB08).
What changes from Phase 1A NB01 v3
Per RESEARCH_PLAN.md v2.3:
- Species: all bacterial GTDB representatives passing CheckM completeness ≥ 90 % and contamination ≤ 5 % (no 1K cap)
- UniRef sampling: targeted rather than uniform-random across the 1.5 M-cluster pool. Includes:
  - All control classes from Phase 1A (positive: AMR, TCS HK; negative: ribosomal, tRNA-synth, RNAP core)
  - All natural_expansion candidates from the full pool (UniRef50s with max-paralog ≥ 3 AND species ≥ 5)
  - All CAZyme UniRef50s (glycoside hydrolases + carbohydrate-binding modules + polysaccharide lyases), the Bacteroidota PUL hypothesis test target
  - All β-lactamase UniRef50s (HIGH 1: known cross-phylum-HGT positive control for consumer null; literature: Forsberg 2012, Bonomo 2017)
  - All class-I CRISPR-Cas UniRef50s (HIGH 1: known cross-tree-of-life HGT, Metcalf 2014)
  - Prevalence-stratified sample of remaining UniRef50s to fill atlas distribution
  - Note: Stage 8 down-samples any named class larger than PER_CLASS_CAP = 10,000, so "all" holds only for classes below that cap
- Output via Spark write to MinIO (s3a://) when driver-side pandas cannot hold the extract (Stage 9 tries the driver-side path first)
- Reproducible seed: RNG_SEED=42
Inputs from Phase 1A
- data/p1a_pilot_uniref50.tsv: pilot UniRef50 IDs, used as a presence check (sanity: are these still in the pool?)
- data/p1a_phase_gate_decision.json: confirms Phase 1A PASS_WITH_REVISION
Outputs
- data/p1b_full_species.tsv: all bacterial GTDB-rep species with taxonomy + quality + annotation density (~20K–28K rows)
- data/p1b_full_uniref50.tsv: UniRef50 pool with class labels and paralog/prevalence signal (IPR/GO/pathway enrichment is deferred; see Stage 9)
- Spark/MinIO: s3a://cdm-lake/tenant-general-warehouse/microbialdiscoveryforge/projects/gene_function_ecological_agora/data/p1b_full_extract.parquet: long-format (species, UniRef50, n_gene_clusters as the paralog count, n_uniref90_present, is_present); written only if the driver-side write in Stage 9 fails
- data/p1b_full_extract_local.parquet: local copy if size permits (≤2 GB target)
- data/p1b_full_extraction_log.json: extraction parameters + row counts + diagnostics
Strategy notes
- Performance: gene_genecluster_junction (1 B rows) and bakta_db_xrefs (572 M rows) are filtered by joining against a temp view of the selected IDs rather than a literal gtdb_species_clade_id IN (...) list (per the docs/performance.md anti-pattern guidance against large IN clauses)
- InterProScan: 833 M rows; filter by pilot gene_cluster_ids first, then by signature_acc IN (Pfam set) or LOWER(ipr_desc) LIKE '%pattern%'
- autoBroadcast: disabled at session start (per [bakta_reannotation] pitfall)
- Scaling: the Phase-1B-lite test parameter MAX_SPECIES restricts the run to a subset for verification when set to an integer; set to None for full scale
Estimated runtime at full scale: 20–60 min depending on cluster load. Stages 3–6 (joins on bakta_db_xrefs and interproscan_domains) are the bottleneck.
Setup
import os, json, time
from pathlib import Path
import numpy as np
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
try:
    spark  # noqa: F821
    print("Spark session: pre-injected (JupyterHub notebook)")
except NameError:
    from berdl_notebook_utils.setup_spark_session import get_spark_session
    spark = get_spark_session()
    print("Spark session: created via berdl_notebook_utils")
spark.sql("SET spark.sql.autoBroadcastJoinThreshold = -1")
PROJECT_ROOT = Path("/home/aparkin/BERIL-research-observatory/projects/gene_function_ecological_agora")
DATA_DIR = PROJECT_ROOT / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)
RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)
# Set to None for full-scale; an integer for testing (e.g., 5000)
MAX_SPECIES = None
PREVALENCE_FILL_TARGET = 30000 # number of additional UniRef50s sampled by prevalence stratification
MIN_CHECKM_COMPLETE = 90.0
MAX_CHECKM_CONTAM = 5.0
# Natural-expansion thresholds (same as Phase 1A v3)
NATURAL_MIN_PARALOG = 3
NATURAL_MIN_SPECIES = 5
log = {
    "seed": RNG_SEED,
    "max_species": MAX_SPECIES,
    "min_checkm_complete": MIN_CHECKM_COMPLETE,
    "max_checkm_contam": MAX_CHECKM_CONTAM,
    "natural_min_paralog": NATURAL_MIN_PARALOG,
    "natural_min_species": NATURAL_MIN_SPECIES,
    "prevalence_fill_target": PREVALENCE_FILL_TARGET,
    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
print(json.dumps(log, indent=2))
Spark session: created via berdl_notebook_utils
{
  "seed": 42,
  "max_species": null,
  "min_checkm_complete": 90.0,
  "max_checkm_contam": 5.0,
  "natural_min_paralog": 3,
  "natural_min_species": 5,
  "prevalence_fill_target": 30000,
  "timestamp_utc": "2026-04-26T23:29:50Z"
}
Stage 1: All bacterial GTDB representatives + quality filter
species_clade = spark.sql("""
SELECT
gtdb_species_clade_id,
representative_genome_id,
GTDB_species,
GTDB_taxonomy
FROM kbase_ke_pangenome.gtdb_species_clade
""")
species_clade = species_clade.withColumn("_taxa", F.split(F.col("GTDB_taxonomy"), ";"))
for i, rank in enumerate(["domain", "phylum", "class", "order", "family", "genus", "species_rank"]):
    species_clade = species_clade.withColumn(rank, F.expr(f"try_element_at(_taxa, {i + 1})"))
species_clade = species_clade.drop("_taxa").filter(F.col("domain") == "d__Bacteria")
pangenome_stats = spark.sql("""
SELECT gtdb_species_clade_id,
try_cast(no_genomes AS BIGINT) AS no_genomes,
try_cast(no_gene_clusters AS BIGINT) AS no_gene_clusters,
try_cast(no_core AS BIGINT) AS no_core,
try_cast(no_aux_genome AS BIGINT) AS no_aux_genome,
try_cast(no_singleton_gene_clusters AS BIGINT) AS no_singleton_gene_clusters
FROM kbase_ke_pangenome.pangenome
""")
metadata = spark.sql("""
SELECT accession AS representative_genome_id,
try_cast(checkm_completeness AS DOUBLE) AS checkm_completeness,
try_cast(checkm_contamination AS DOUBLE) AS checkm_contamination,
try_cast(genome_size AS BIGINT) AS genome_size,
try_cast(gc_percentage AS DOUBLE) AS gc_percentage,
try_cast(protein_count AS BIGINT) AS protein_count
FROM kbase_ke_pangenome.gtdb_metadata
""")
species_full = (
species_clade
.join(pangenome_stats, on="gtdb_species_clade_id", how="inner")
.join(metadata, on="representative_genome_id", how="left")
.filter(F.col("checkm_completeness") >= MIN_CHECKM_COMPLETE)
.filter(F.col("checkm_contamination") <= MAX_CHECKM_CONTAM)
)
if MAX_SPECIES is not None:
    print(f"MAX_SPECIES set to {MAX_SPECIES}; restricting for scaling test")
    # Stratified across phyla using a pandas sample on the small driver-side df
    species_pdf = species_full.toPandas()
    if len(species_pdf) > MAX_SPECIES:
        # Sample proportional-to-phylum-size with floor=5 per phylum
        # (rounding means the final total can differ slightly from MAX_SPECIES)
        phylum_sizes = species_pdf.groupby("phylum").size()
        alloc = {p: min(5, n) for p, n in phylum_sizes.items()}
        remaining = MAX_SPECIES - sum(alloc.values())
        if remaining > 0:
            eligible = phylum_sizes[phylum_sizes > 5]
            weights = eligible / eligible.sum()
            for p, w in weights.items():
                alloc[p] = min(int(alloc[p] + round(w * remaining)), int(phylum_sizes[p]))
        pieces = []
        for p, n in alloc.items():
            sub = species_pdf[species_pdf["phylum"] == p]
            if n > 0:
                pieces.append(sub.sample(n=min(n, len(sub)), random_state=RNG_SEED))
        species_pdf = pd.concat(pieces, ignore_index=True)
        species_full = spark.createDataFrame(species_pdf)
species_full.cache()
n_species = species_full.count()
log["n_species_after_quality"] = int(n_species)
print(f"Species after quality filter: {n_species:,}")
# show() prints the table itself and returns None, so it is not wrapped in print()
species_full.groupBy("phylum").count().orderBy(F.desc("count")).show(15)
Species after quality filter: 18,989
+--------------------+-----+
|              phylum|count|
+--------------------+-----+
|   p__Pseudomonadota| 6035|
|   p__Actinomycetota| 2603|
|     p__Bacteroidota| 2581|
|      p__Bacillota_A| 2531|
|        p__Bacillota| 1844|
|p__Verrucomicrobiota|  340|
|  p__Cyanobacteriota|  331|
| p__Campylobacterota|  252|
|  p__Planctomycetota|  235|
|    p__Chloroflexota|  234|
|    p__Spirochaetota|  228|
| p__Desulfobacterota|  226|
|  p__Acidobacteriota|  181|
|      p__Bacillota_C|  177|
|      p__Myxococcota|   81|
+--------------------+-----+
only showing top 15 rows
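The MAX_SPECIES branch in this stage gives every phylum a floor of 5 species and splits the remainder in proportion to phylum size. The sketch below reproduces only that allocation rule with plain dictionaries so it can be checked without Spark. The function name allocate_quota and the toy phylum sizes are hypothetical and do not appear in the notebook.

```python
def allocate_quota(phylum_sizes, max_species, floor=5):
    """Return {phylum: n_to_sample}: a per-phylum floor plus a proportional top-up."""
    # Every phylum first gets up to `floor` species (or all of them if it is smaller).
    alloc = {p: min(floor, n) for p, n in phylum_sizes.items()}
    remaining = max_species - sum(alloc.values())
    if remaining > 0:
        # Only phyla larger than the floor can absorb the remainder.
        eligible = {p: n for p, n in phylum_sizes.items() if n > floor}
        total = sum(eligible.values())
        for p, n in eligible.items():
            extra = round(remaining * n / total)
            alloc[p] = min(alloc[p] + extra, n)  # never exceed the phylum's size
    return alloc


# Toy phylum sizes (hypothetical), restricted to 100 species.
sizes = {"p__A": 600, "p__B": 300, "p__C": 3}
quota = allocate_quota(sizes, max_species=100)
print(quota)
```

Because each top-up is rounded independently, the allocated total can land a few species above or below the requested cap when the proportions are not exact.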
Stage 2: Pull all gene_clusters for these species
species_full.select("gtdb_species_clade_id").createOrReplaceTempView("p1b_species_view")
all_gene_clusters = spark.sql("""
SELECT gc.gene_cluster_id, gc.gtdb_species_clade_id,
gc.is_core, gc.is_auxiliary, gc.is_singleton
FROM kbase_ke_pangenome.gene_cluster gc
JOIN p1b_species_view sv ON gc.gtdb_species_clade_id = sv.gtdb_species_clade_id
""")
all_gene_clusters.cache()
n_gene_clusters = all_gene_clusters.count()
log["n_gene_clusters"] = int(n_gene_clusters)
print(f"Gene clusters in full Phase 1B species: {n_gene_clusters:,}")
all_gene_clusters.createOrReplaceTempView("p1b_gc_view")
Gene clusters in full Phase 1B species: 103,629,867
Stage 3: UniRef50 pool from bakta_db_xrefs
uniref_xrefs = spark.sql("""
SELECT bx.gene_cluster_id, bx.accession
FROM kbase_ke_pangenome.bakta_db_xrefs bx
JOIN p1b_gc_view pg ON bx.gene_cluster_id = pg.gene_cluster_id
WHERE bx.db = 'UniRef'
""")
uniref_split = (
uniref_xrefs
.withColumn("uniref_tier", F.regexp_extract(F.col("accession"), r"^(UniRef\d+)_", 1))
.withColumn("uniref_id", F.col("accession"))
)
uniref50 = uniref_split.filter(F.col("uniref_tier") == "UniRef50").select(
"gene_cluster_id", F.col("uniref_id").alias("uniref50_id")
)
uniref90 = uniref_split.filter(F.col("uniref_tier") == "UniRef90").select(
"gene_cluster_id", F.col("uniref_id").alias("uniref90_id")
)
uniref50.cache()
uniref90.cache()
log["n_uniref50_xref_rows"] = int(uniref50.count())
log["n_uniref90_xref_rows"] = int(uniref90.count())
print(f"UniRef50 xref rows: {log['n_uniref50_xref_rows']:,}")
print(f"UniRef90 xref rows: {log['n_uniref90_xref_rows']:,}")
UniRef50 xref rows: 84,175,091
UniRef90 xref rows: 68,381,626
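The tier split in this stage depends on the pattern ^(UniRef\d+)_ capturing the prefix of each accession. As a minimal sketch, the same pattern can be exercised with Python's re module. The accession strings are illustrative only, and the helper name uniref_tier is hypothetical. The empty-string fallback mirrors how Spark's regexp_extract reports a non-match.

```python
import re

TIER_RE = re.compile(r"^(UniRef\d+)_")


def uniref_tier(accession):
    """Return 'UniRef50', 'UniRef90', ... or '' when the prefix is absent."""
    match = TIER_RE.match(accession or "")
    return match.group(1) if match else ""


# Illustrative accession strings, not taken from the notebook's data.
examples = ["UniRef50_P0A7G6", "UniRef90_P0A7G6", "UniRef100_P0A7G6", "P0A7G6"]
tiers = [uniref_tier(a) for a in examples]
print(tiers)
```

Rows whose tier is neither UniRef50 nor UniRef90 (UniRef100 here, or anything without the prefix) drop out of both filtered frames.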
Stage 4: Compute UniRef50 paralog signal across full species set
# Per (uniref50_id, gtdb_species_clade_id) paralog count = number of distinct gene_clusters in species mapped to this UniRef50
uniref50_with_species = uniref50.join(
all_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"),
on="gene_cluster_id", how="inner",
)
uniref50_signal = (
uniref50_with_species
.groupBy("uniref50_id", "gtdb_species_clade_id")
.agg(F.countDistinct("gene_cluster_id").alias("n_clusters_in_species"))  # distinct clusters, so duplicate xref rows cannot inflate the count
.groupBy("uniref50_id")
.agg(
F.max("n_clusters_in_species").alias("max_paralog_in_any_species"),
F.countDistinct("gtdb_species_clade_id").alias("n_species_with_uref"),
)
)
uniref50_signal.cache()
n_uref50_total = uniref50_signal.count()
log["n_unique_uniref50_in_pool"] = int(n_uref50_total)
print(f"Distinct UniRef50 IDs in Phase 1B pool: {n_uref50_total:,}")
Distinct UniRef50 IDs in Phase 1B pool: 15,382,302
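The two statistics computed in this stage are the largest number of gene clusters a UniRef50 has in any one species and the number of species in which it occurs. The sketch below computes the same two numbers on a handful of tuples and applies the natural_expansion thresholds from the setup cell (max-paralog ≥ 3 and species ≥ 5). The function names and toy IDs are hypothetical.

```python
from collections import defaultdict


def paralog_signal(rows):
    """rows: iterable of (uniref50_id, species_id, gene_cluster_id) tuples."""
    clusters = defaultdict(set)  # (uniref50, species) -> distinct gene clusters
    for uref, species, gene_cluster in rows:
        clusters[(uref, species)].add(gene_cluster)
    signal = {}
    for (uref, _species), members in clusters.items():
        max_paralog, n_species = signal.get(uref, (0, 0))
        signal[uref] = (max(max_paralog, len(members)), n_species + 1)
    return signal  # uniref50 -> (max_paralog_in_any_species, n_species_with_uref)


def is_natural_expansion(max_paralog, n_species, min_paralog=3, min_species=5):
    return max_paralog >= min_paralog and n_species >= min_species


rows = [("U1", f"s{i}", f"gc{i}") for i in range(5)]  # one cluster in each of 5 species
rows += [("U1", "s0", "gc_x"), ("U1", "s0", "gc_y")]  # species s0 now holds 3 clusters
rows += [("U2", "s0", "gc_z")]                        # a singleton UniRef50
sig = paralog_signal(rows)
print(sig)
```

Here U1 meets both thresholds while U2 does not, which is the distinction the Stage 8 class assignment relies on.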
Stage 5: Pull eggNOG + InterProScan annotations + control flags
eggnog = spark.sql("""
SELECT query_name AS gene_cluster_id, COG_category, KEGG_ko, KEGG_Pathway, BRITE, PFAMs
FROM kbase_ke_pangenome.eggnog_mapper_annotations
""")
pilot_eggnog = eggnog.join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
pilot_eggnog.cache()
n_eggnog = pilot_eggnog.count()
log["n_eggnog_annotated_clusters"] = int(n_eggnog)
print(f"eggNOG-annotated clusters: {n_eggnog:,}")
ips_domains = spark.sql("""
SELECT gene_cluster_id, analysis, signature_acc, signature_desc, ipr_acc, ipr_desc
FROM kbase_ke_pangenome.interproscan_domains
""")
pilot_ips = ips_domains.join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
pilot_ips.cache()
n_ips = pilot_ips.count()
log["n_ips_domain_rows"] = int(n_ips)
print(f"InterProScan domain rows: {n_ips:,}")
ips_go = spark.sql("SELECT gene_cluster_id, go_id, go_source FROM kbase_ke_pangenome.interproscan_go")
pilot_ips_go = ips_go.join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
pilot_ips_go.cache()
log["n_ips_go_rows"] = int(pilot_ips_go.count())
ips_pathways = spark.sql("SELECT gene_cluster_id, pathway_db, pathway_id FROM kbase_ke_pangenome.interproscan_pathways")
pilot_ips_pathways = ips_pathways.join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner")
pilot_ips_pathways.cache()
log["n_ips_pathway_rows"] = int(pilot_ips_pathways.count())
eggNOG-annotated clusters: 74,450,297
InterProScan domain rows: 653,537,492
Stage 6: Identify control sets at gene-cluster level
This stage follows the InterProScan-based detection from Phase 1A NB01 v2 and adds the HIGH 1 known-HGT positive controls (β-lactamases, class-I CRISPR-Cas) plus CAZymes (the Bacteroidota PUL hypothesis target).
# --- Pfam accessions ---
TCS_HK_PFAM = {"PF00512", "PF07568", "PF07730", "PF06580", "PF02518", "PF13415", "PF13581"}
TCS_RR_PFAM = {"PF00072", "PF00196", "PF02954", "PF00486"}
BETA_LAC_PFAM = {"PF00144", "PF13354", "PF12706", "PF13483", "PF00768"}
CRISPR_CAS_PFAM = {"PF18557", "PF09704", "PF09827", "PF09455", "PF09659"}  # duplicate entries removed
# --- KEGG ---
RIBOSOMAL_KOS = {f"K{n:05d}" for n in (list(range(2860, 2900)) + list(range(2950, 2999)))}
TRNA_SYNTH_KOS = {f"K{n:05d}" for n in range(1866, 1891)}
RNAP_CORE_KOS = {"K03040", "K03043", "K03046"}
def has_any_pfam_name(pfam_str, names):
    """True if any Pfam name in the delimited string (',', ';' or space) is in `names`."""
    if pfam_str in (None, "", "-"):
        return False
    toks = {t.strip() for t in pfam_str.replace(";", ",").replace(" ", ",").split(",") if t.strip()}
    return bool(toks & names)


def has_any_ko(ko_str, kos):
    """True if any KO in the delimited string (optionally 'ko:'-prefixed) is in `kos`."""
    if ko_str in (None, "", "-"):
        return False
    toks = [t.replace("ko:", "").strip() for t in ko_str.replace(";", ",").split(",")]
    return any(t in kos for t in toks)
# F.udf defaults to a StringType return, so declare boolean explicitly
# instead of comparing the result against the string "true".
rib_udf = F.udf(lambda s: has_any_ko(s, RIBOSOMAL_KOS), "boolean")
trna_udf = F.udf(lambda s: has_any_ko(s, TRNA_SYNTH_KOS), "boolean")
rnap_udf = F.udf(lambda s: has_any_ko(s, RNAP_CORE_KOS), "boolean")
# Note: this name set also contains response-regulator domains (Response_reg, Trans_reg_C)
tcs_eggnog_udf = F.udf(
    lambda s: has_any_pfam_name(s, {"HisKA", "HisKA_2", "HisKA_3", "HWE_HK", "HATPase_c", "Response_reg", "Trans_reg_C"}),
    "boolean",
)
eggnog_flags = (
    pilot_eggnog
    .withColumn("egg_ribosomal", rib_udf(F.col("KEGG_ko")))
    .withColumn("egg_trna_synth", trna_udf(F.col("KEGG_ko")))
    .withColumn("egg_rnap_core", rnap_udf(F.col("KEGG_ko")))
    .withColumn("egg_tcs_hk", tcs_eggnog_udf(F.col("PFAMs")))
).select("gene_cluster_id", "egg_ribosomal", "egg_trna_synth", "egg_rnap_core", "egg_tcs_hk")
# InterProScan-based detections
def acc_filter_clause(accs):
    # Joins accessions for use inside IN ('...'); the caller supplies the outer quotes
    return "', '".join(sorted(accs))
ips_tcs = spark.sql(f"""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE analysis = 'Pfam' AND signature_acc IN ('{acc_filter_clause(TCS_HK_PFAM | TCS_RR_PFAM)}')
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_tcs_hk", F.lit(True))
ips_betalac = spark.sql(f"""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE (analysis = 'Pfam' AND signature_acc IN ('{acc_filter_clause(BETA_LAC_PFAM)}'))
OR LOWER(ipr_desc) LIKE '%beta-lactamase%'
OR LOWER(signature_desc) LIKE '%beta-lactamase%'
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_betalac", F.lit(True))
ips_crispr = spark.sql(f"""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE analysis = 'Pfam' AND signature_acc IN ('{acc_filter_clause(CRISPR_CAS_PFAM)}')
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_crispr", F.lit(True))
ips_ribosomal = spark.sql("""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE LOWER(ipr_desc) LIKE '%ribosomal protein%' OR LOWER(signature_desc) LIKE '%ribosomal protein%'
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_ribosomal", F.lit(True))
ips_trna = spark.sql("""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE LOWER(ipr_desc) LIKE '%aminoacyl-trna synthetase%'
OR LOWER(signature_desc) LIKE '%aminoacyl-trna synthetase%'
OR LOWER(signature_desc) LIKE '%trna ligase%'
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_trna_synth", F.lit(True))
ips_rnap = spark.sql("""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE (LOWER(ipr_desc) LIKE '%rna polymerase%alpha%' OR LOWER(signature_desc) LIKE '%rna polymerase%alpha%')
OR (LOWER(ipr_desc) LIKE '%rna polymerase%beta%' OR LOWER(signature_desc) LIKE '%rna polymerase%beta%')
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_rnap_core", F.lit(True))
# CAZymes — Bacteroidota PUL hypothesis target
ips_cazyme = spark.sql("""
SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.interproscan_domains
WHERE LOWER(ipr_desc) LIKE '%glycoside hydrolase%'
OR LOWER(signature_desc) LIKE '%glycoside hydrolase%'
OR LOWER(signature_desc) LIKE '%glycosyl hydrolase%'
OR LOWER(ipr_desc) LIKE '%carbohydrate-binding module%'
OR LOWER(signature_desc) LIKE '%carbohydrate-binding module%'
OR LOWER(ipr_desc) LIKE '%polysaccharide lyase%'
""").join(all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner").withColumn("ips_cazyme", F.lit(True))
# AMR via bakta_amr (broader curated set)
amr = spark.sql("SELECT DISTINCT gene_cluster_id FROM kbase_ke_pangenome.bakta_amr").join(
all_gene_clusters.select("gene_cluster_id"), on="gene_cluster_id", how="inner"
).withColumn("is_amr", F.lit(True))
# Combine all flags
cluster_flags = (
all_gene_clusters.select("gene_cluster_id")
.join(eggnog_flags, on="gene_cluster_id", how="left")
.join(ips_tcs, on="gene_cluster_id", how="left")
.join(ips_betalac, on="gene_cluster_id", how="left")
.join(ips_crispr, on="gene_cluster_id", how="left")
.join(ips_ribosomal, on="gene_cluster_id", how="left")
.join(ips_trna, on="gene_cluster_id", how="left")
.join(ips_rnap, on="gene_cluster_id", how="left")
.join(ips_cazyme, on="gene_cluster_id", how="left")
.join(amr, on="gene_cluster_id", how="left")
.na.fill({
"egg_ribosomal": False, "egg_trna_synth": False, "egg_rnap_core": False, "egg_tcs_hk": False,
"ips_tcs_hk": False, "ips_betalac": False, "ips_crispr": False,
"ips_ribosomal": False, "ips_trna_synth": False, "ips_rnap_core": False,
"ips_cazyme": False, "is_amr": False,
})
.withColumn("is_ribosomal", F.col("egg_ribosomal") | F.col("ips_ribosomal"))
.withColumn("is_trna_synth", F.col("egg_trna_synth") | F.col("ips_trna_synth"))
.withColumn("is_rnap_core", F.col("egg_rnap_core") | F.col("ips_rnap_core"))
.withColumn("is_tcs_hk", F.col("egg_tcs_hk") | F.col("ips_tcs_hk"))
.withColumn("is_betalac", F.col("ips_betalac"))
.withColumn("is_crispr_cas", F.col("ips_crispr"))
.withColumn("is_cazyme", F.col("ips_cazyme"))
.select("gene_cluster_id", "is_ribosomal", "is_trna_synth", "is_rnap_core",
"is_tcs_hk", "is_betalac", "is_crispr_cas", "is_cazyme", "is_amr")
)
cluster_flags.cache()
log["control_class_cluster_counts"] = (
cluster_flags.select(
F.sum(F.col("is_ribosomal").cast("int")).alias("ribosomal"),
F.sum(F.col("is_trna_synth").cast("int")).alias("trna_synth"),
F.sum(F.col("is_rnap_core").cast("int")).alias("rnap_core"),
F.sum(F.col("is_tcs_hk").cast("int")).alias("tcs_hk"),
F.sum(F.col("is_betalac").cast("int")).alias("betalac"),
F.sum(F.col("is_crispr_cas").cast("int")).alias("crispr_cas"),
F.sum(F.col("is_cazyme").cast("int")).alias("cazyme"),
F.sum(F.col("is_amr").cast("int")).alias("amr"),
).toPandas().iloc[0].to_dict()
)
print("Cluster-level control flag counts:")
print(json.dumps(log["control_class_cluster_counts"], indent=2, default=str))
Cluster-level control flag counts:
{
  "ribosomal": 1746105,
  "trna_synth": 698864,
  "rnap_core": 93356,
  "tcs_hk": 2047587,
  "betalac": 610325,
  "crispr_cas": 23777,
  "cazyme": 1219866,
  "amr": 74651
}
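The Pfam filters in this stage build their IN (...) lists by joining accession strings into SQL text. A hedged alternative, sketched below, validates each accession against the PF + five digits shape before interpolating it, so a typo fails loudly instead of silently matching nothing. The helper pfam_in_list is hypothetical and is not the notebook's acc_filter_clause. The accessions and table name are copied from the cells above.

```python
import re

PFAM_RE = re.compile(r"PF\d{5}")


def pfam_in_list(accessions):
    """Return a quoted, comma-separated SQL list after validating each accession."""
    accs = sorted(set(accessions))
    bad = [a for a in accs if not PFAM_RE.fullmatch(a)]
    if bad:
        raise ValueError(f"not Pfam accessions: {bad}")
    return ", ".join(f"'{a}'" for a in accs)


clause = pfam_in_list({"PF00512", "PF02518", "PF00072"})
sql = (
    "SELECT DISTINCT gene_cluster_id "
    "FROM kbase_ke_pangenome.interproscan_domains "
    f"WHERE analysis = 'Pfam' AND signature_acc IN ({clause})"
)
print(sql)
```

Unlike acc_filter_clause, this variant returns the accessions already quoted, so the surrounding query supplies only the parentheses.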
Stage 7: Aggregate flags + IPR annotation to UniRef50 level
uniref50_flags = (
uniref50
.join(cluster_flags, on="gene_cluster_id", how="left")
.na.fill({"is_ribosomal": False, "is_trna_synth": False, "is_rnap_core": False,
"is_tcs_hk": False, "is_betalac": False, "is_crispr_cas": False,
"is_cazyme": False, "is_amr": False})
.groupBy("uniref50_id")
.agg(
F.max(F.col("is_ribosomal").cast("int")).alias("is_ribosomal"),
F.max(F.col("is_trna_synth").cast("int")).alias("is_trna_synth"),
F.max(F.col("is_rnap_core").cast("int")).alias("is_rnap_core"),
F.max(F.col("is_tcs_hk").cast("int")).alias("is_tcs_hk"),
F.max(F.col("is_betalac").cast("int")).alias("is_betalac"),
F.max(F.col("is_crispr_cas").cast("int")).alias("is_crispr_cas"),
F.max(F.col("is_cazyme").cast("int")).alias("is_cazyme"),
F.max(F.col("is_amr").cast("int")).alias("is_amr"),
)
)
uniref50_flags.cache()
n_uref_flagged = uniref50_flags.count()
log["n_uniref50_flagged"] = int(n_uref_flagged)
print(f"UniRef50 IDs with control flags computed: {n_uref_flagged:,}")
UniRef50 IDs with control flags computed: 15,382,302
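Taking F.max over integer-cast booleans acts as a logical OR across the gene clusters that map to one UniRef50: a single flagged member is enough to flag the whole UniRef50. A minimal pure-Python sketch of that roll-up is below. The function name, the two-flag subset, and the toy IDs are hypothetical.

```python
from collections import defaultdict

FLAGS = ("is_cazyme", "is_amr")  # two of the notebook's eight flags, for brevity


def rollup_flags(cluster_to_uniref, cluster_flags, flag_names=FLAGS):
    """Flag a UniRef50 if ANY of its member gene clusters carries the flag."""
    out = defaultdict(lambda: dict.fromkeys(flag_names, 0))
    for gene_cluster_id, uniref50_id in cluster_to_uniref:
        flags = cluster_flags.get(gene_cluster_id, {})  # missing cluster -> all False
        row = out[uniref50_id]
        for name in flag_names:
            row[name] = max(row[name], int(bool(flags.get(name, False))))
    return dict(out)


pairs = [("gc1", "U1"), ("gc2", "U1"), ("gc3", "U2")]
cluster_flags = {"gc2": {"is_cazyme": True}}
rolled = rollup_flags(pairs, cluster_flags)
print(rolled)
```

One consequence of the any-of rule is that a large UniRef50 inherits a class label from a single annotated member, which is worth keeping in mind when reading the per-class counts in Stage 8.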
Stage 8: Targeted UniRef50 sampling: controls + natural_expansion + CAZymes + prevalence stratification
# Do the entire selection on the Spark side: collecting the full 1.5M+ UniRef50 pool to the driver
# exceeded driver.maxResultSize. Only the final selected ID list and counts are collected.
uref_pool_spark = uniref50_flags.join(uniref50_signal, on="uniref50_id", how="inner")
uref_pool_spark.cache()
# Class assignment via Spark expression (priority order matches primary_class fn)
uref_pool_spark = uref_pool_spark.withColumn(
"control_class",
F.when(F.col("is_betalac") == 1, "pos_betalac")
.when(F.col("is_crispr_cas") == 1, "pos_crispr_cas")
.when(F.col("is_amr") == 1, "pos_amr")
.when(F.col("is_tcs_hk") == 1, "pos_tcs_hk")
.when(F.col("is_ribosomal") == 1, "neg_ribosomal")
.when(F.col("is_trna_synth") == 1, "neg_trna_synth")
.when(F.col("is_rnap_core") == 1, "neg_rnap_core")
.when(F.col("is_cazyme") == 1, "hyp_cazyme")
.when(
(F.col("max_paralog_in_any_species") >= NATURAL_MIN_PARALOG)
& (F.col("n_species_with_uref") >= NATURAL_MIN_SPECIES),
"natural_expansion",
)
.otherwise("none")
)
# Class counts (small aggregation; safe to collect)
class_counts = uref_pool_spark.groupBy("control_class").count().toPandas()
log["control_class_uref_counts"] = dict(zip(class_counts["control_class"], class_counts["count"]))
print("UniRef50 pool by class (full pool):")
print(class_counts.sort_values("count", ascending=False).to_string(index=False))
# v2 fix: cap each class at PER_CLASS_CAP (some classes have 100K+ UniRefs;
# uncapped this blew driver.maxResultSize on the downstream extract). Sample within
# each class with prevalence stratification (proxy: bin by n_species_with_uref).
PER_CLASS_CAP = 10000
log["per_class_cap"] = PER_CLASS_CAP
print(f"\nApplying per-class cap of {PER_CLASS_CAP} UniRef50s with prevalence-stratified sampling within class")
# For each named class, do bucketed-fractional sampling on Spark side
named_classes_to_cap = [
"pos_betalac", "pos_tcs_hk", "pos_amr", "pos_crispr_cas",
"neg_ribosomal", "neg_trna_synth", "neg_rnap_core",
"hyp_cazyme", "natural_expansion",
]
from pyspark.sql.window import Window as _Win

named_capped_pieces = []
class_counts_dict = {row["control_class"]: int(row["count"]) for _, row in class_counts.iterrows()}
for cls in named_classes_to_cap:
    pool_n = class_counts_dict.get(cls, 0)
    if pool_n == 0:
        continue
    cls_pool = uref_pool_spark.filter(F.col("control_class") == cls)
    if pool_n <= PER_CLASS_CAP:
        named_capped_pieces.append(cls_pool)
        print(f" {cls:20s}: {pool_n:>7,} → kept all")
    else:
        # Prevalence-stratified sampling on n_species_with_uref.
        # sampleBy keeps each row independently, so the sampled size is only approximately PER_CLASS_CAP.
        binned = cls_pool.withColumn("_pb", F.ntile(10).over(_Win.orderBy(F.col("n_species_with_uref").asc())))
        bin_sizes_cls = binned.groupBy("_pb").count().toPandas()
        per_bin_target = PER_CLASS_CAP // 10
        fractions_cls = {int(r["_pb"]): min(1.0, per_bin_target / max(int(r["count"]), 1))
                         for _, r in bin_sizes_cls.iterrows()}
        sampled_cls = binned.sampleBy("_pb", fractions=fractions_cls, seed=RNG_SEED).drop("_pb")
        named_capped_pieces.append(sampled_cls)
        sampled_n = sampled_cls.count()
        print(f" {cls:20s}: {pool_n:>7,} → sampled {sampled_n:,}")
named_spark = named_capped_pieces[0]
for piece in named_capped_pieces[1:]:
    named_spark = named_spark.unionByName(piece)
named_spark.cache()
n_named = named_spark.count()
log["n_named_uref"] = int(n_named)
print(f"\nNamed-class UniRef50s after per-class cap: {n_named:,}")
# Prevalence-stratified sample on "none" class — Spark-side bucketing
# Use ntile() over n_species_with_uref to bin into 10 quantiles
from pyspark.sql.window import Window
unnamed_spark = uref_pool_spark.filter(F.col("control_class") == "none")
n_unnamed = unnamed_spark.count()
log["n_unnamed_uref_pool"] = int(n_unnamed)
print(f"Unnamed (\"none\"-class) pool: {n_unnamed:,}")
if n_unnamed > 0 and PREVALENCE_FILL_TARGET > 0:
    n_bins = 10
    per_bin = max(1, PREVALENCE_FILL_TARGET // n_bins)
    # ntile assigns 1..n_bins
    w = Window.orderBy(F.col("n_species_with_uref").asc())
    unnamed_binned = unnamed_spark.withColumn("_prev_bin", F.ntile(n_bins).over(w))
    # sampleBy fractional per bin
    bin_sizes = unnamed_binned.groupBy("_prev_bin").count().toPandas()
    fractions = {}
    for _, row in bin_sizes.iterrows():
        b = int(row["_prev_bin"])
        n = int(row["count"])
        fractions[b] = min(1.0, per_bin / max(n, 1))
    print(f"Sampling {per_bin} per bin → fractions {fractions}")
    unnamed_sampled = unnamed_binned.sampleBy("_prev_bin", fractions=fractions, seed=RNG_SEED).drop("_prev_bin")
else:
    unnamed_sampled = unnamed_spark.limit(0)
n_unnamed_sampled = unnamed_sampled.count()
log["n_unnamed_sampled"] = int(n_unnamed_sampled)
print(f"Unnamed sampled: {n_unnamed_sampled:,}")
# Union named + sampled-unnamed → final Phase 1B target set
p1b_uref_spark = named_spark.unionByName(unnamed_sampled)
p1b_uref_spark.cache()
n_p1b_uref = p1b_uref_spark.count()
log["n_p1b_uref50_target"] = int(n_p1b_uref)
print(f"\nPhase 1B target UniRef50 set: {n_p1b_uref:,}")
# Class composition (small aggregation; safe)
final_class_counts = p1b_uref_spark.groupBy("control_class").count().toPandas()
log["p1b_uref50_class_counts"] = dict(zip(final_class_counts["control_class"], final_class_counts["count"]))
print(final_class_counts.sort_values("count", ascending=False).to_string(index=False))
# The capped target set is ~100K rows × 12 columns, small enough to collect with toPandas().
# If the set grows substantially, write via Spark and read back instead to avoid driver-side OOM.
p1b_uref_pdf = p1b_uref_spark.toPandas()
UniRef50 pool by class (full pool):
    control_class    count
             none 14365873
       pos_tcs_hk   346644
natural_expansion   311445
       hyp_cazyme   160557
      pos_betalac    82064
    neg_ribosomal    70849
   neg_trna_synth    34644
          pos_amr     3656
    neg_rnap_core     3606
   pos_crispr_cas     2964
Applying per-class cap of 10000 UniRef50s with prevalence-stratified sampling within class
/usr/local/spark/python/pyspark/sql/connect/expressions.py:1091: UserWarning: WARN WindowExpression: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation. warnings.warn(
pos_betalac : 82,064 → sampled 9,917
pos_tcs_hk : 346,644 → sampled 10,089
pos_amr : 3,656 → kept all
pos_crispr_cas : 2,964 → kept all
neg_ribosomal : 70,849 → sampled 10,077
neg_trna_synth : 34,644 → sampled 10,070
neg_rnap_core : 3,606 → kept all
hyp_cazyme : 160,557 → sampled 9,993
natural_expansion : 311,445 → sampled 10,030
Named-class UniRef50s after per-class cap: 70,402
Unnamed ("none"-class) pool: 14,365,873
Sampling 3000 per bin → fractions {1: 0.0020882814000952256, 2: 0.0020882814000952256, 3: 0.0020882814000952256, 4: 0.002088282853735973, 5: 0.002088282853735973, 6: 0.002088282853735973, 7: 0.002088282853735973, 8: 0.002088282853735973, 9: 0.002088282853735973, 10: 0.002088282853735973}
Unnamed sampled: 29,790
Phase 1B target UniRef50 set: 100,192
    control_class count
             none 29790
       pos_tcs_hk 10089
    neg_ribosomal 10077
   neg_trna_synth 10070
natural_expansion 10030
       hyp_cazyme  9993
      pos_betalac  9917
          pos_amr  3656
    neg_rnap_core  3606
   pos_crispr_cas  2964
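The printed per-bin fractions for the "none" class are nearly identical, and a small calculation shows why. ntile splits the ordered rows into bins of (almost) equal size, and every bin gets the same target, so each bin ends up with essentially the same sampling fraction. The sketch below recomputes the bin sizes and fractions from the pool size and fill target reported above, assuming the standard SQL NTILE rule that the first n mod k bins receive one extra row. The helper names are hypothetical.

```python
def ntile_sizes(n_rows, n_bins):
    """Bin sizes under the NTILE rule: the first (n_rows % n_bins) bins get one extra row."""
    base, extra = divmod(n_rows, n_bins)
    return [base + 1 if b < extra else base for b in range(n_bins)]


def bin_fractions(n_rows, n_bins, target_total):
    """Per-bin sampling fractions for an equal per-bin target, keyed 1..n_bins."""
    per_bin = max(1, target_total // n_bins)
    sizes = ntile_sizes(n_rows, n_bins)
    return {b + 1: min(1.0, per_bin / max(size, 1)) for b, size in enumerate(sizes)}


N_UNNAMED = 14_365_873  # "none"-class pool size from the output above
FILL_TARGET = 30_000    # PREVALENCE_FILL_TARGET from the setup cell

sizes = ntile_sizes(N_UNNAMED, 10)
fractions = bin_fractions(N_UNNAMED, 10, FILL_TARGET)
expected_sample = sum(size * fractions[b + 1] for b, size in enumerate(sizes))
print(sizes[:4], expected_sample)
```

With equal fractions in every bin, each row has practically the same inclusion probability regardless of its prevalence. Because sampleBy keeps rows independently, the realized count (29,790 here) also scatters around the expected 30,000 rather than hitting it exactly. If stronger separation of rare and widespread UniRef50s were wanted, one option (an assumption, not something the notebook does) would be fixed-width or log-scale bins on n_species_with_uref, which would give bins of unequal size and therefore different fractions.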
Stage 9: Materialize species + UniRef50 metadata + extract
The extract is the (species, UniRef50) presence/paralog long-format table — the input to NB06's null-model construction. For Phase 1B at full scale this can be 10s–100s of millions of rows; we write parquet via Spark to MinIO and a local copy if size permits.
# Per-species annotated fraction of gene clusters (for D2)
ann_density = (
all_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id")
.join(
pilot_eggnog.select("gene_cluster_id").distinct().withColumn("is_annotated", F.lit(1)),
on="gene_cluster_id", how="left",
)
.na.fill({"is_annotated": 0})
.groupBy("gtdb_species_clade_id")
.agg(
F.count("*").alias("n_clusters_total"),
F.sum("is_annotated").alias("n_clusters_annotated"),
)
.withColumn("annotated_fraction", F.col("n_clusters_annotated") / F.col("n_clusters_total"))
.toPandas()
)
species_pdf = species_full.toPandas().merge(ann_density, on="gtdb_species_clade_id", how="left")
species_pdf.to_csv(DATA_DIR / "p1b_full_species.tsv", sep="\t", index=False)
print(f"Wrote p1b_full_species.tsv: {len(species_pdf):,} rows")
# UniRef50 metadata (no IPR/GO enrichment yet — defer to NB06 if needed; class flags + signal sufficient here)
p1b_uref_out = p1b_uref_pdf[[
"uniref50_id", "control_class",
"is_ribosomal", "is_trna_synth", "is_rnap_core", "is_tcs_hk",
"is_betalac", "is_crispr_cas", "is_cazyme", "is_amr",
"max_paralog_in_any_species", "n_species_with_uref",
]]
p1b_uref_out.to_csv(DATA_DIR / "p1b_full_uniref50.tsv", sep="\t", index=False)
print(f"Wrote p1b_full_uniref50.tsv: {len(p1b_uref_out):,} rows")
# Build the long-format extract on Spark, restricted to target UniRef50 set
target_uref_ids = p1b_uref_pdf["uniref50_id"].tolist()
spark.createDataFrame(pd.DataFrame({"uniref50_id": target_uref_ids})).createOrReplaceTempView("p1b_target_uref")
uniref50_filtered = uniref50.join(spark.table("p1b_target_uref"), on="uniref50_id", how="inner")
extract = (
uniref50_filtered
.join(all_gene_clusters.select("gene_cluster_id", "gtdb_species_clade_id"), on="gene_cluster_id")
.join(uniref90.select("gene_cluster_id", "uniref90_id"), on="gene_cluster_id", how="left")
.groupBy("gtdb_species_clade_id", "uniref50_id")
.agg(
F.countDistinct("uniref90_id").alias("n_uniref90_present"),
F.countDistinct("gene_cluster_id").alias("n_gene_clusters"),
)
.withColumn("is_present", (F.col("n_gene_clusters") > 0).cast("boolean"))
)
extract.cache()
n_extract = extract.count()
log["n_extract_rows"] = int(n_extract)
print(f"\nExtract rows: {n_extract:,}")
# Try driver-side parquet first (fast path); fall back to Spark write to MinIO if driver OOMs.
out_path_local = str(DATA_DIR / "p1b_full_extract_local.parquet")
out_path_minio = "s3a://cdm-lake/tenant-general-warehouse/microbialdiscoveryforge/projects/gene_function_ecological_agora/data/p1b_full_extract.parquet"
try:
    extract_pdf = extract.toPandas()
    extract_clean = pd.DataFrame({col: extract_pdf[col].to_numpy() for col in extract_pdf.columns})
    # Clear any stale output (a directory from a Spark write, or a file from a pandas write)
    if os.path.isdir(out_path_local):
        import shutil
        shutil.rmtree(out_path_local)
    elif os.path.isfile(out_path_local):
        os.remove(out_path_local)
    extract_clean.to_parquet(out_path_local, index=False)
    log["extract_parquet_local_path"] = out_path_local
    log["extract_parquet_size_mb"] = round(os.path.getsize(out_path_local) / 1e6, 1)
    print(f"Wrote p1b_full_extract_local.parquet ({log['extract_parquet_size_mb']} MB)")
except Exception as e_local:
    print(f"Driver-side parquet write failed ({e_local!r}); falling back to Spark write to MinIO")
    log["extract_local_write_error"] = repr(e_local)
    try:
        # Spark write to MinIO; works even when driver collect fails
        extract.write.mode("overwrite").parquet(out_path_minio)
        log["extract_parquet_minio_path"] = out_path_minio
        # Read back the row count to verify
        n_minio = spark.read.parquet(out_path_minio).count()
        log["extract_minio_row_count"] = int(n_minio)
        print(f"Wrote {out_path_minio} ({n_minio:,} rows verified by read-back)")
    except Exception as e_minio:
        print(f"MinIO write ALSO failed ({e_minio!r}); extract not materialized, downstream notebooks must rerun NB05 with smaller scope")
        log["extract_minio_write_error"] = repr(e_minio)
        log["extract_parquet_minio_path"] = None
log["completed_utc"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
with open(DATA_DIR / "p1b_full_extraction_log.json", "w") as f:
    json.dump(log, f, indent=2, default=str)
print("\nWrote p1b_full_extraction_log.json")
Wrote p1b_full_species.tsv: 18,989 rows
Wrote p1b_full_uniref50.tsv: 100,192 rows
Extract rows: 1,539,643
Wrote p1b_full_extract_local.parquet (14.0 MB)
Wrote p1b_full_extraction_log.json
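The extract is built by grouping observed (species, UniRef50) pairs, so it holds only pairs that are present. Absent pairs are implicit, and is_present is true on every row by construction. A quick calculation from the counts printed above shows how sparse the implied species × UniRef50 matrix is. The variable names are hypothetical, and the comment about zero-filling downstream is an assumption, not something this notebook states.

```python
# Counts copied from the cell outputs above.
n_species = 18_989
n_target_uniref50 = 100_192
n_extract_rows = 1_539_643

possible_pairs = n_species * n_target_uniref50
density = n_extract_rows / possible_pairs
mean_species_per_uniref50 = n_extract_rows / n_target_uniref50
mean_uniref50_per_species = n_extract_rows / n_species

print(f"possible pairs:             {possible_pairs:,}")
print(f"fraction of pairs present:  {density:.4%}")
print(f"mean species per UniRef50:  {mean_species_per_uniref50:.2f}")
print(f"mean UniRef50s per species: {mean_uniref50_per_species:.2f}")
# Assumption: a downstream consumer that needs explicit absences would have to
# zero-fill the missing pairs itself, ideally per UniRef50 rather than all at once.
```

At well under 1 % density, the long format is far smaller than a dense matrix would be, which is consistent with the 14.0 MB local parquet file.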
Summary
Phase 1B extraction complete. Outputs feed NB06 (null model construction at full scale with M1 rank-stratified parents) and NB07 (atlas + Bacteroidota PUL hypothesis test).
Class composition (see p1b_full_uniref50.tsv):
- Negative controls: ribosomal / tRNA-synth / RNAP core (dosage-constrained signature; M2 criterion)
- Positive controls (intra-phylum HGT): AMR + TCS HK
- Positive controls (cross-phylum HGT — HIGH 1): β-lactamase + class-I CRISPR-Cas
- Hypothesis target: CAZyme (Bacteroidota PUL test in NB07)
- Sanity positive: natural_expansion (paralog signal at full scale)
- Atlas baseline: prevalence-stratified random sample
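Each UniRef50 receives exactly one class label even when several flags are set, because the Stage 8 F.when chain is evaluated in order and the first match wins. The sketch below mirrors that priority order in plain Python. The function name and the example rows are hypothetical. The flag and label names are the ones used in Stage 8.

```python
# Priority order copied from the Stage 8 F.when chain: earlier entries win.
PRIORITY = [
    ("is_betalac", "pos_betalac"),
    ("is_crispr_cas", "pos_crispr_cas"),
    ("is_amr", "pos_amr"),
    ("is_tcs_hk", "pos_tcs_hk"),
    ("is_ribosomal", "neg_ribosomal"),
    ("is_trna_synth", "neg_trna_synth"),
    ("is_rnap_core", "neg_rnap_core"),
    ("is_cazyme", "hyp_cazyme"),
]


def control_class(row, min_paralog=3, min_species=5):
    """Return the single class label for one UniRef50 row (a dict of flags and signals)."""
    for flag, label in PRIORITY:
        if row.get(flag, 0) == 1:
            return label
    if (row.get("max_paralog_in_any_species", 0) >= min_paralog
            and row.get("n_species_with_uref", 0) >= min_species):
        return "natural_expansion"
    return "none"


both = control_class({"is_betalac": 1, "is_amr": 1})
expanding_cazyme = control_class(
    {"is_cazyme": 1, "max_paralog_in_any_species": 4, "n_species_with_uref": 9}
)
print(both, expanding_cazyme)
```

Two consequences follow from the ordering: a UniRef50 flagged as both β-lactamase and AMR is counted only under pos_betalac, and natural_expansion contains only UniRef50s that carry none of the named flags.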