Skip to content

Documentation for agricola Python API

agricola exports two functions which can be used in place of the recommended command line workflow:


agricola.step1.step1(datasets, Y, X, phenotypes, train_mask, test_mask, h2_prior, trait_type, loocv=False, B=1000, idx_sample=None, variants=None, level0_dir=None)

Perform agricola step 1

Parameters:

Name Type Description Default
datasets list[LancData]

A list of LancData objects (either single object or one per-chromosome)

required
Y ArrayLike

A (N, P) jax array of phenotypes

required
X Optional[ArrayLike]

A (N, C) jax array of covariates (no intercept)

required
phenotypes list[str]

A list of phenotype names, ordered as the columns of Y

required
train_mask ArrayLike

A (N, K) jax array indicating training set status for each set k in 1, ..., K

required
test_mask ArrayLike

A (N, K) jax array indicating test set status for each set k in 1, ..., K

required
h2_prior ArrayLike

A 1D jax array of prior values for snp heritability

required
trait_type str

Either "qt" or "bt"

required
loocv bool

A boolean indicating whether to perform LOOCV instead of standard cross validation. Ignored for trait_type="qt".

False
B int

The number of variants per block

1000
idx_sample Optional[ArrayLike]

An optional (N_sub,) jax array with indices of samples to include

None
variants Optional[list[str]]

A list of variant IDs to include in the analysis. If not provided, all variants are used

None
level0_dir Optional[str]

The directory where level 0 predictions are written

None

Returns:

Type Description
dict[str, DataFrame]

A dict where keys are chromosomes and values are (N, P) pandas DataFrames of level 1 predictions

Source code in src/agricola/step1.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def step1(
    datasets: list[LancData],
    Y: ArrayLike,
    X: Optional[ArrayLike],
    phenotypes: list[str],
    train_mask: ArrayLike,
    test_mask: ArrayLike,
    h2_prior: ArrayLike,
    trait_type: str,
    loocv: bool = False,
    B: int = 1000,
    idx_sample: Optional[ArrayLike] = None,
    variants: Optional[list[str]] = None,
    level0_dir: Optional[str] = None,
) -> dict[str, DataFrame]:
    """Perform agricola step 1

    Args:
        datasets: A list of LancData objects (either single object or one
            per-chromosome)
        Y: A (N, P) jax array of phenotypes
        X: A (N, C) jax array of covariates (no intercept)
        phenotypes: A list of phenotype names, ordered as the columns of Y
        train_mask: A (N, K) jax array indicating training set status for each set k in 1, ..., K
        test_mask: A (N, K) jax array indicating test set status for each set k in 1, ..., K
        h2_prior: A 1D jax array of prior values for snp heritability
        trait_type: Either "qt" or "bt"
        loocv: A boolean indicating whether to perform LOOCV instead of standard
            cross validation. Ignored for trait_type="qt".
        B: The number of variants per block
        idx_sample: An optional (N_sub,) jax array with indices of samples to include
        variants: A list of variant IDs to include in the analysis. If not provided, all variants are used
        level0_dir: The directory where level 0 predictions are written

    Returns:
        A dict where keys are chromosomes and values are (N, P) pandas DataFrames of level 1 predictions
    """
    ## When the caller gives no directory, create a throw-away one we own
    ## outright. tempfile.mkdtemp() is used deliberately instead of
    ## TemporaryDirectory: the latter carries a finalizer that would try to
    ## delete the directory again after the manual cleanup below, raising
    ## during garbage collection.
    rm_dir0 = level0_dir is None
    if level0_dir is None:
        level0_dir = tempfile.mkdtemp()

    os.makedirs(level0_dir, exist_ok=True)

    ## Level 0: per-block ridge predictions, written to disk per chromosome
    level0_files = level0(
        datasets,
        Y,
        X,
        phenotypes,
        train_mask,
        test_mask,
        h2_prior,
        B,
        idx_sample,
        variants,
        level0_dir,
    )

    ## Level 1: combine level 0 predictions into per-chromosome LOCO predictions
    step1_predictions = level1(
        level0_files,
        Y,
        X,
        phenotypes,
        train_mask,
        test_mask,
        h2_prior,
        trait_type,
        loocv,
    )

    ## Cleanup: only remove files/dir we created ourselves; the rmdir is
    ## guarded so a non-empty directory (e.g. holding unrelated files) is
    ## left in place.
    if rm_dir0:
        for files in level0_files.values():
            for file in files.values():
                os.remove(file)
        if os.path.isdir(level0_dir) and not os.listdir(level0_dir):
            os.rmdir(level0_dir)

    return step1_predictions

agricola.step2.step2(datasets, Y, X, step1_predictions, out_prefixes, phenotypes, trait_type='qt', test_type='score', chrom=None, B=1000, min_ac=1, idx_sample=None, variants=None, adjust_lanc=True, impute=False)

Perform agricola step 2

Parameters:

Name Type Description Default
datasets list[LancData]

A list of LancData objects (either single object or one per-chromosome)

required
Y ArrayLike

A (N, P) jax array of outcomes

required
X Optional[ArrayLike]

A (N, C) jax array of covariates

required
step1_predictions dict[str, DataFrame]

A dict with LOCO linear predictions from step 1. The values are (N, P) pandas DataFrames

required
out_prefixes list[str]

A list of prefixes for each dataset. Outputs will be written to {output_prefix}_{phenotype}.parquet

required
phenotypes list[str]

A list of phenotype names

required
trait_type str

Either "qt" or "bt"

'qt'
test_type str

Either "score" or "wald"

'score'
chrom Optional[str]

An optional chromosome identifier passed through to per-dataset testing; all chromosomes are analyzed when None

None
B int

The block size (max number of variants to read at once)

1000
min_ac int

The minimum allele count threshold

1
idx_sample Optional[ArrayLike]

An optional numpy array with ordered indices of samples (in the psam file) to retain

None
variants Optional[list[str]]

An optional list of variant IDs to retain

None
adjust_lanc bool

A boolean indicating whether to adjust tests for local ancestry

True
impute bool

Whether to impute the phenotype. Much faster, but only available for qt traits. If all phenotypes are non-missing, this is ignored.

False
Source code in src/agricola/step2.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
def step2(
    datasets: list[LancData],
    Y: ArrayLike,
    X: Optional[ArrayLike],
    step1_predictions: dict[str, pd.DataFrame],
    out_prefixes: list[str],
    phenotypes: list[str],
    trait_type: str = "qt",
    test_type: str = "score",
    chrom: Optional[str] = None,
    B: int = 1000,
    min_ac: int = 1,
    idx_sample: Optional[ArrayLike] = None,
    variants: Optional[list[str]] = None,
    adjust_lanc: bool = True,
    impute: bool = False,
) -> None:
    """Perform agricola step 2

    Args:
        datasets: A list of LancData objects (either single object or one
            per-chromosome)
        Y: A (N, P) jax array of outcomes
        X: A (N, C) jax array of covariates
        step1_predictions: A dict with LOCO linear predictions from step 1. The values are (N, P) pandas DataFrames
        out_prefixes: A list of prefixes for each dataset. Outputs will be written to {output_prefix}_{phenotype}.parquet
        phenotypes: A list of phenotype names
        trait_type: Either "qt" or "bt"
        test_type: Either "score" or "wald"
        chrom: An optional chromosome identifier passed through to per-dataset
            testing; all chromosomes are analyzed when None
        B: The block size (max number of variants to read at once)
        min_ac: The minimum allele count threshold
        idx_sample: An optional numpy array with ordered indices of samples (in
            the psam file) to retain
        variants: An optional list of variant IDs to retain
        adjust_lanc: A boolean indicating whether to adjust tests for local ancestry
        impute: Whether to impute the phenotype. Much faster, but only available
            for qt traits. If all phenotypes are non-missing, this is ignored.
    """
    ## Missingness mask M: all-ones when imputing (treat every entry as
    ## observed), otherwise 1.0 where Y is observed and 0.0 where it is NaN.
    if impute:
        M = jnp.ones(shape=jnp.asarray(Y).shape)
    else:
        M = (~jnp.isnan(jnp.asarray(Y))).astype(jnp.float32)

    Y, X, step1_predictions_np, idx_sample, test_type_enum, trait_type_enum = (
        validate_step2_inputs(
            datasets,
            Y,
            X,
            phenotypes,
            step1_predictions,
            out_prefixes,
            B,
            idx_sample,
            variants,
            test_type,
            trait_type,
        )
    )

    ## Adjust phenotype for covariates to match step 1: residualize Y against
    ## the column space of X via a QR projection, then standardize.
    if trait_type_enum == TraitType.QT:
        Q, _ = jnp.linalg.qr(X, mode="reduced")
        Y = stdize(Y - (Q @ (Q.T @ Y)))
        ## With no missing phenotypes the imputation path is equivalent and
        ## cheaper, so enable it regardless of the caller's choice.
        if (M == 1).all():
            impute = True
    else:
        if impute:
            raise ValueError("impute must be False for binary traits")

    ## Adjust covariates for per-phenotype missingness: center each covariate
    ## using only the observed samples of each phenotype, then zero out
    ## covariate rows for missing entries. Result has shape (N, C, P).
    X = X[:, :, None] - jnp.sum(X[:, :, None] * M[:, None, :], axis=0) / jnp.sum(
        M, axis=0
    )
    X = X * M[:, None, :]

    ## Run association testing per dataset; results are written to
    ## {out_prefix}_{phenotype}.parquet by _step2_dataset.
    for dataset, out_prefix in zip(datasets, out_prefixes):
        pgen_path = dataset.plink_prefix + ".pgen"
        desc = f"Getting step 2 results for file: {pgen_path}"
        _step2_dataset(
            dataset,
            Y,
            M,
            step1_predictions_np,
            X,
            idx_sample,
            out_prefix,
            phenotypes,
            trait_type_enum,
            test_type_enum,
            desc,
            chrom,
            B,
            min_ac,
            variants,
            adjust_lanc,
            impute,
        )