Skip to content

Documentation for detector.pixel

Pixelization scheme module for jax-gw.

This module contains functions for calculating the sky geometry.

flat_to_matrix_sky_indices(N_theta, N_phi)

Calculate the (N_theta*N_phi, 2) matrix that maps each flat sky index to its (theta index, phi index) pair for a given sky resolution.

Parameters:

Name Type Description Default
N_theta int

Number of ecliptic thetas.

required
N_phi int

Number of ecliptic phis.

required

Returns:

Type Description
ArrayLike

Matrix of flat indices.

Source code in src/jax_gw/detector/pixel.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def flat_to_matrix_sky_indices(N_theta: int, N_phi: int):
    """Build the lookup table from flat sky index to (theta, phi) index pair.

    Row ``k`` of the result is ``(k // N_phi, k % N_phi)`` for every flat index
    ``k`` in ``range(N_theta * N_phi)``, i.e. the row-major inverse of
    ``flatten_sky``.

    Parameters
    ----------
    N_theta : int
        Number of ecliptic thetas.
    N_phi : int
        Number of ecliptic phis.

    Returns
    -------
    ArrayLike
        Integer array of shape ``(N_theta * N_phi, 2)``; column 0 holds the
        theta index and column 1 the phi index.
    """
    flat_index = jnp.arange(N_theta * N_phi)
    # Vectorised quotient/remainder instead of a Python loop.
    theta_index, phi_index = jnp.divmod(flat_index, N_phi)
    return jnp.column_stack((theta_index, phi_index))

flatten_sky(i_theta, j_phi, N_phi)

Flatten the sky coordinates into a single index.

Parameters:

Name Type Description Default
i_theta int

Index of the ecliptic theta.

required
j_phi int

Index of the ecliptic phi.

required
N_phi int

Number of ecliptic phis.

required

Returns:

Type Description
int

Flattened index.

Source code in src/jax_gw/detector/pixel.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def flatten_sky(i_theta: int, j_phi: int, N_phi: int) -> int:
    """Map a (theta index, phi index) pair to its row-major flat sky index.

    Parameters
    ----------
    i_theta : int
        Index of the ecliptic theta.
    j_phi : int
        Index of the ecliptic phi.
    N_phi : int
        Number of ecliptic phis.

    Returns
    -------
    int
        Flattened index ``i_theta * N_phi + j_phi``.
    """
    # Row-major layout: each theta ring occupies N_phi consecutive slots.
    ring_offset = i_theta * N_phi
    return ring_offset + j_phi

get_directional_basis(ecl_theta, ecl_phi)

Calculate the directional basis for a given source direction.

Parameters:

Name Type Description Default
ecl_theta ArrayLike

Ecliptic polar angle (colatitude) theta of the source, in radians; theta = 0 points along the +z axis.

required
ecl_phi ArrayLike

Ecliptic azimuthal angle (longitude) phi of the source, in radians.

required

Returns:

Type Description
array

Directional basis k_hat, u_hat, v_hat, where k_hat is the direction of the incoming signal, u_hat is same as theta_hat, and v_hat is same as phi_hat.

Note that k, u, v are not a right-handed coordinate system, but -k, u, v is.

Source code in src/jax_gw/detector/pixel.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def get_directional_basis(ecl_theta: ArrayLike, ecl_phi: ArrayLike) -> Array:
    """Calculate the directional basis for a given source direction.

    Parameters
    ----------
    ecl_theta : ArrayLike
        Ecliptic polar angle (colatitude) theta of the source, in radians;
        theta = 0 points along the +z axis.
    ecl_phi : ArrayLike
        Ecliptic azimuthal angle (longitude) phi of the source, in radians.
        Must be broadcast-compatible with ``ecl_theta``.

    Returns
    -------
    jnp.array
        Array of shape ``(3, *batch_shape, 3)``, where ``batch_shape`` is the
        broadcast shape of ``ecl_theta`` and ``ecl_phi``. Index 0 of the first
        axis is k_hat, index 1 is u_hat and index 2 is v_hat; the last axis holds
        Cartesian (x, y, z) components. k_hat is the direction of the incoming
        signal (minus the radial unit vector), u_hat is the same as theta_hat,
        and v_hat is the same as phi_hat.

        Note that k, u, v are not a right-handed coordinate system, but -k, u, v is.
    """
    # Broadcast up front so every component stacked below has the same shape.
    # Without this, inputs such as (N, 1) and (1, M) fail in jnp.stack, because
    # e.g. ``-sin_theta`` keeps theta's shape while ``cos_theta * cos_phi`` is (N, M).
    ecl_theta, ecl_phi = jnp.broadcast_arrays(ecl_theta, ecl_phi)

    cos_theta = jnp.cos(ecl_theta)
    sin_theta = jnp.sin(ecl_theta)
    cos_phi = jnp.cos(ecl_phi)
    sin_phi = jnp.sin(ecl_phi)
    zero_element = jnp.zeros_like(cos_theta)

    # Incoming-signal direction: minus the radial unit vector.
    k_hat = -jnp.stack(
        [sin_theta * cos_phi, sin_theta * sin_phi, cos_theta],
        axis=-1,
    )
    # u_hat is theta_hat, v_hat is phi_hat
    u_hat = jnp.stack([cos_theta * cos_phi, cos_theta * sin_phi, -sin_theta], axis=-1)
    v_hat = jnp.stack([-sin_phi, cos_phi, zero_element], axis=-1)

    return jnp.stack([k_hat, u_hat, v_hat], axis=0)

get_solid_angle_theta_phi(theta, phi, N_theta, N_phi)

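Get the solid angle (sky area) of the pixel centred on a given theta and phi in a pixelated sphere. Assumes linear spacing in theta and phi, with the theta grid including both poles (theta spacing pi / (N_theta - 1), phi spacing 2 pi / N_phi).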

Parameters:

Name Type Description Default
theta float

Ecliptic theta.

required
phi float

Ecliptic phi.

required
N_theta int

Number of theta bins.

required
N_phi int

Number of phi bins.

required

Returns:

Type Description
float

Sky area associated with theta and phi.

Source code in src/jax_gw/detector/pixel.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def get_solid_angle_theta_phi(theta, phi, N_theta, N_phi):
    """Return the solid angle of the pixel centred on (theta, phi).

    The grid is assumed linear in both angles: phi bins of width
    ``2 * pi / N_phi`` and theta samples spaced ``pi / (N_theta - 1)`` apart,
    i.e. a theta grid that includes both poles. The theta extent of the pixel
    is clamped to ``[0, pi]``, so polar pixels cover half a bin.

    Parameters
    ----------
    theta : float
        Ecliptic theta (polar angle), in radians.
    phi : float
        Ecliptic phi, in radians.
    N_theta : int
        Number of theta bins.
    N_phi : int
        Number of phi bins.

    Returns
    -------
    float
        Sky area (steradians) associated with theta and phi; broadcasts over
        array-valued ``theta`` and ``phi``.
    """
    half_phi_bin = (2 * jnp.pi / N_phi) / 2
    half_theta_bin = (jnp.pi / (N_theta - 1)) / 2

    # Width computed as upper edge minus lower edge so the result broadcasts
    # against ``phi`` as well as ``theta``.
    phi_extent = (phi + half_phi_bin) - (phi - half_phi_bin)

    theta_low = jnp.maximum(theta - half_theta_bin, 0)
    theta_high = jnp.minimum(theta + half_theta_bin, jnp.pi)

    # Area of a spherical band segment: dphi * (cos(theta_low) - cos(theta_high)).
    return phi_extent * (jnp.cos(theta_low) - jnp.cos(theta_high))

pixel_to_lm(data_omega, axis, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)

Convert a pixelated map to a spherical harmonic map.

Source code in src/jax_gw/detector/pixel.py
182
183
184
185
186
187
188
189
190
191
def pixel_to_lm(
    data_omega, axis, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values
):
    """Convert a pixelated map to a spherical harmonic map.

    For every (l, m) entry, computes the solid-angle-weighted sum over sky pixels
    of ``sph_harm_values * data_omega``. The pixel weights come from
    ``get_solid_angle_theta_phi``.

    Parameters
    ----------
    data_omega : ArrayLike
        Pixelated map whose flattened sky lives on axis ``axis``.
    axis : int
        Sky axis of ``data_omega``; it is moved to the end before broadcasting.
    N_theta : int
        Number of theta bins, forwarded to ``get_solid_angle_theta_phi``.
    N_phi : int
        Number of phi bins, forwarded to ``get_solid_angle_theta_phi``.
    ecl_thetas, ecl_phis : ArrayLike
        Pixel-centre angles, forwarded to ``get_solid_angle_theta_phi``;
        presumably one entry per sky pixel (TODO confirm against callers).
    sph_harm_values : ArrayLike
        Spherical harmonics evaluated on the pixels. Assumed to have the l and
        m axes immediately before a trailing sky axis, e.g. (n_l, n_m, n_pix)
        (TODO confirm). NOTE(review): no complex conjugation is applied here.
        If the usual projection a_lm = integral of f * conj(Y_lm) dOmega is
        intended, callers must pass conjugated values.

    Returns
    -------
    jnp.array
        The broadcast product with the trailing sky axis summed out: the
        remaining axes of ``data_omega`` followed by the l and m axes.
    """
    sky_last = jnp.moveaxis(data_omega, axis, -1)
    # Insert singleton (l, m) axes just before the sky axis so the map
    # broadcasts against ``sph_harm_values``.
    sky_last = sky_last[..., jnp.newaxis, jnp.newaxis, :]
    pixel_area = get_solid_angle_theta_phi(ecl_thetas, ecl_phis, N_theta, N_phi)
    weighted = (sph_harm_values * sky_last) * pixel_area
    return weighted.sum(axis=-1)

unflatten_sky(index, N_phi)

Unflatten the sky coordinates from a single index.

Parameters:

Name Type Description Default
index int

Flattened index.

required
N_phi int

Number of ecliptic phis.

required

Returns:

Type Description
tuple

Unflattened sky coordinates as a plain (i_theta, j_phi) pair.

Source code in src/jax_gw/detector/pixel.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def unflatten_sky(index: int, N_phi: int):
    """Unflatten the sky coordinates from a single index.

    Inverse of ``flatten_sky`` for the row-major layout
    ``index = i_theta * N_phi + j_phi``.

    Parameters
    ----------
    index : int
        Flattened index (an integer array also works, element-wise).
    N_phi : int
        Number of ecliptic phis.

    Returns
    -------
    tuple
        Plain ``(i_theta, j_phi)`` pair (not a NamedTuple): the theta index
        ``index // N_phi`` and the phi index ``index % N_phi``.
    """
    i_theta = index // N_phi
    j_phi = index % N_phi

    return i_theta, j_phi

unflatten_sky_axis(matrix, axis, N_theta, N_phi)

Unflatten the axis of a matrix that corresponds to the sky coordinates.

Shape is converted from (...N, N_theta*N_phi, M...) to (...N, N_theta, N_phi, M...).

Parameters:

Name Type Description Default
matrix ArrayLike

Matrix to unflatten.

required
axis int

Axis to unflatten.

required
N_theta int

Number of ecliptic thetas.

required
N_phi int

Number of ecliptic phis.

required

Returns:

Type Description
array

Unflattened matrix.

Source code in src/jax_gw/detector/pixel.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def unflatten_sky_axis(matrix, axis: int, N_theta: int, N_phi: int) -> Array:
    """Unflatten the axis of a matrix that corresponds to the sky coordinates.

    Shape is converted from (...N, N_theta*N_phi, M...) to (...N, N_theta, N_phi, M...),
    using the row-major layout of ``flatten_sky`` (flat index
    ``i_theta * N_phi + j_phi``).

    Parameters
    ----------
    matrix : ArrayLike
        Matrix to unflatten.
    axis : int
        Axis to unflatten; negative values count from the end.
    N_theta : int
        Number of ecliptic thetas.
    N_phi : int
        Number of ecliptic phis.

    Returns
    -------
    jnp.array
        Unflattened matrix.

    Raises
    ------
    ValueError
        If ``axis`` is out of bounds, or if that axis does not have length
        ``N_theta * N_phi``.
    """
    matrix = jnp.asarray(matrix)
    ndim = matrix.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}")
    axis = axis % ndim

    n_pix = matrix.shape[axis]
    if n_pix != N_theta * N_phi:
        # A gather (jnp.take) would silently truncate or NaN-fill here.
        raise ValueError(
            f"sky axis has length {n_pix}, expected N_theta * N_phi = {N_theta * N_phi}"
        )

    # Row-major split of the sky axis: a plain reshape, no gather or copy needed.
    new_shape = matrix.shape[:axis] + (N_theta, N_phi) + matrix.shape[axis + 1 :]
    return jnp.reshape(matrix, new_shape)