Skip to content

Documentation for detector.orbits

axial_tilt(equatorial_coords, earth_tilt)

Rotate a vector by the angle earth_tilt around the x-axis. This converts from equatorial to ecliptic coordinates when earth_tilt is Earth's (positive) axial tilt.

Parameters:

Name Type Description Default
equatorial_coords ArrayLike

Vector in equatorial coordinates.

required
earth_tilt float

Angle to rotate around the x-axis.

required

Returns:

Type Description
Array

Vector in ecliptic coordinates.

Source code in src/jax_gw/detector/orbits.py (lines 107–130)
def axial_tilt(equatorial_coords: ArrayLike, earth_tilt: float) -> Array:
    """Apply the rotation about the x-axis that maps equatorial to ecliptic axes.

    The rotation matrix is ``[[1, 0, 0], [0, cos t, sin t], [0, -sin t, cos t]]``
    with ``t = earth_tilt``. With ``t`` equal to Earth's (positive) axial tilt this
    converts equatorial coordinates to ecliptic ones; passing ``-t`` applies the
    inverse rotation.

    Parameters
    ----------
    equatorial_coords : ArrayLike
        Vector(s) in equatorial coordinates whose first axis holds (x, y, z),
        e.g. shape (3,) or (3, N_steps).
    earth_tilt : float
        Rotation angle about the x-axis, in radians.

    Returns
    -------
    Array
        Rotated vector(s) (ecliptic coordinates for a positive Earth tilt),
        computed as ``jnp.dot(matrix, equatorial_coords)``.
    """
    cos_tilt = jnp.cos(earth_tilt)
    sin_tilt = jnp.sin(earth_tilt)
    tilt_matrix = jnp.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, cos_tilt, sin_tilt],
            [0.0, -sin_tilt, cos_tilt],
        ]
    )
    return jnp.dot(tilt_matrix, equatorial_coords)

create_cartwheel_arm_lengths(ecc, r, N, times, freq=1.0)

Create the scalar separations for a cartwheel orbit.

Parameters:

Name Type Description Default
ecc float

Eccentricity of the orbit.

required
r float

Radius of the orbit of the guiding center.

required
N int

Number of spacecraft.

required
times array

Times at which to evaluate the orbit.

required

Returns:

Type Description
array

Separations. Dimensions: (len(times), N, N).

Source code in src/jax_gw/detector/orbits.py (lines 217–272)
def create_cartwheel_arm_lengths(
    ecc: float,
    r: float,
    N: int,
    times: ArrayLike,
    freq: float = 1.0,
    timeshift: float = 0.0,
) -> Array:
    """Create the scalar separations (arm lengths) for a cartwheel orbit.

    Uses first-order-in-``ecc`` corrections around the nominal arm length
    ``L = 2 sqrt(3) ecc r``. The corrections are functions of
    ``alpha - timeshift``, where ``alpha = 2 pi freq t + kappa`` and
    ``kappa`` = -20 degrees.

    Parameters
    ----------
    ecc : float
        Eccentricity of the orbit.
    r : float
        Radius of the orbit of the guiding center.
    N : int
        Number of spacecraft. Only ``N == 3`` is supported.
    times : ArrayLike
        Times at which to evaluate the orbit (1D array, or a scalar).
    freq : float
        Orbital frequency; the orbital phase advances as ``2 pi freq times``.
    timeshift : float
        Constellation phase offset ``lambda`` in radians, entering as
        ``alpha - timeshift``. Defaults to 0.

    Returns
    -------
    Array
        Symmetric separations with a zero diagonal.
        Dimensions: (len(times), N, N) (just (N, N) for scalar ``times``).
    """
    assert N == 3, "cartwheel arm lengths are only implemented for N == 3"
    # nominal arm length of the triangle
    L = 2.0 * jnp.sqrt(3) * ecc * r

    # the guiding center trails the Earth by 20 degrees
    kappa_orbit = -20.0 / 360.0 * 2 * jnp.pi
    alpha = 2.0 * jnp.pi * freq * times + kappa_orbit

    # complex phasor of (alpha - lambda); its powers/shifts give the needed cos/sin
    exp_1_n1 = jnp.exp(1j * (alpha - timeshift))
    cos_1_n1 = jnp.real(exp_1_n1)
    cos_3_n3 = jnp.real(exp_1_n1**3)
    sin_1_n1_pi6 = jnp.imag(exp_1_n1 * jnp.exp(1j * jnp.pi / 6.0))
    sin_1_n1_npi6 = jnp.imag(exp_1_n1 * jnp.exp(-1j * jnp.pi / 6.0))

    arm_12 = L * (1 + ecc / 32.0 * (15.0 * sin_1_n1_pi6 - cos_3_n3))
    arm_13 = L * (1 - ecc / 32.0 * (15.0 * sin_1_n1_npi6 + cos_3_n3))
    arm_23 = L * (1 - ecc / 32.0 * (15.0 * cos_1_n1 + cos_3_n3))

    # assemble d[..., i, j] as a symmetric matrix with zero diagonal; the time
    # axis (if any) stays in front because the new axes are appended at the end
    zero = jnp.zeros_like(arm_12)
    return jnp.stack(
        [
            jnp.stack([zero, arm_12, arm_13], axis=-1),
            jnp.stack([arm_12, zero, arm_23], axis=-1),
            jnp.stack([arm_13, arm_23, zero], axis=-1),
        ],
        axis=-2,
    )

create_cartwheel_orbit(ecc, r, N, times, timeshift=0, freq=1.0)

Create a cartwheel orbit.

Parameters:

Name Type Description Default
ecc float

Eccentricity of the orbits.

required
r float

Radius of the orbit of the guiding center. Units: AU.

required
N int

Number of spacecraft.

required
times ArrayLike

Times at which to evaluate the orbit. Units: years.

required

Returns:

Type Description
Array

Orbit. Dimensions: (N, 3, len(times)). Units: AU.

Source code in src/jax_gw/detector/orbits.py (lines 161–214)
def create_cartwheel_orbit(
    ecc: float,
    r: float,
    N: int,
    times: ArrayLike,
    timeshift: float = 0,
    freq: float = 1.0,
) -> Array:
    """Create a cartwheel orbit.

    Each spacecraft ``n = 0, ..., N-1`` follows an expansion to second order in
    ``ecc`` around a guiding center on a circle of radius ``r``. The orbital
    phase is ``alpha = 2 pi freq t + kappa`` (``kappa`` = -20 degrees) and the
    constellation phase is ``beta_n = 2 pi n / N + timeshift``.

    Parameters
    ----------
    ecc : float
        Eccentricity of the orbits.
    r : float
        Radius of the orbit of the guiding center. Units: AU.
    N : int
        Number of spacecraft.
    times : ArrayLike
        Times at which to evaluate the orbit. Units: years.
    timeshift : float
        Phase offset added to every ``beta_n``, in radians.
    freq : float
        Orbital frequency of the guiding center, in 1/years (1.0 = one orbit
        per year).

    Returns
    -------
    Array
        Orbit. Dimensions: (N, 3, len(times)). Units: AU.
    """
    # the guiding center is 20 degrees behind the Earth
    kappa = -20.0 / 360.0 * 2 * jnp.pi
    alpha = 2.0 * jnp.pi * freq * times + kappa
    # column of per-spacecraft phases so everything broadcasts to (N, len(times))
    beta = jnp.arange(N)[:, jnp.newaxis] * 2.0 * jnp.pi / N + timeshift

    # phasors exp(i (a * alpha - b * beta)); real/imaginary parts give cos/sin
    phasor_alpha = jnp.exp(1j * (alpha))
    phasor_2a_b = jnp.exp(1j * (2 * alpha - beta))
    phasor_beta = jnp.exp(1j * (beta))
    phasor_3a_2b = jnp.exp(1j * (3 * alpha - 2 * beta))
    phasor_a_2b = jnp.exp(1j * (alpha - 2 * beta))

    # in-plane motion: circle plus first- and second-order eccentricity terms
    in_plane = (
        r * phasor_alpha
        + 0.5 * r * ecc * (phasor_2a_b - 3.0 * phasor_beta)
        + 0.125 * r * ecc**2 * (3.0 * phasor_3a_2b - 10.0 * phasor_alpha)
    )
    # the (alpha - 2 beta) term enters x and y with opposite signs
    mixed = 0.125 * r * ecc**2 * 5.0 * phasor_a_2b
    x = jnp.real(in_plane - mixed)
    y = jnp.imag(in_plane + mixed)

    # out-of-plane motion with amplitude sqrt(3) * ecc * r
    amplitude = jnp.sqrt(3) * ecc * r
    relative_phasor = jnp.exp(1j * (alpha - beta))
    cos_rel = jnp.real(relative_phasor)
    sin_rel = jnp.imag(relative_phasor)
    z = -amplitude * cos_rel + amplitude * ecc * (1 + sin_rel**2)

    return jnp.stack([x, y, z], axis=1)

create_circular_orbit_xy(r, f_orb, times)

Create an orbit around the Sun with x and y Arms.

Parameters:

Name Type Description Default
r float

Radius of the orbit.

required
times array

Times at which to evaluate the orbit.

required

Returns:

Type Description
array

Orbit. Dimensions: (3, len(times)).

Source code in src/jax_gw/detector/orbits.py (lines 34–54)
def create_circular_orbit_xy(r: float, f_orb: float, times: ArrayLike) -> Array:
    """Create a circular orbit around the Sun in the ecliptic (x-y) plane.

    The position is ``r (cos(2 pi f_orb t), sin(2 pi f_orb t), 0)``, i.e. the
    orbit starts on the +x axis at ``t = 0`` and moves towards +y.

    Parameters
    ----------
    r : float
        Radius of the orbit.
    f_orb : float
        Orbital frequency, in inverse units of ``times``.
    times : ArrayLike
        Times at which to evaluate the orbit.

    Returns
    -------
    Array
        Orbit with the Cartesian components on the leading axis.
        Dimensions: (3, len(times)).
    """
    # for now let's assume a circular orbit on the ecliptic plane
    phase = 2.0 * jnp.pi * f_orb * times
    x = r * jnp.cos(phase)
    y = r * jnp.sin(phase)
    z = jnp.zeros_like(x)

    return jnp.stack([x, y, z], axis=0)

earthbound_ifo_pipeline(lat, lon, times, r, L_arm, psi=0, beta_arm=jnp.pi / 2)

Create the orbits of the three points (vertex and the two arm ends) of an Earthbound interferometer. Assumes a circular orbit around the Sun and models the Earth as a sphere; the angle between the arms is set by beta_arm.

Parameters:

Name Type Description Default
lat float

Earth latitude in radians. Zero is the equator, +pi/2 is the North pole.

required
lon float

Earth longitude in radians. Zero is the Greenwich meridian, positive is East.

required
times array

Times at which to evaluate the orbit in years.

required
r float

Radius of the orbit in AU.

required
L_arm float

Length of the arms in km.

required
psi float

Angle between the X arm and local East in radians. Positive North of East. The Y arm is rotated from the X arm by an additional beta_arm (pi/2 by default).

0
beta_arm float

Angle between the X and Y arms in radians.

pi / 2

Returns:

Type Description
array

Orbits of the N=3 points defining the interferometer. Dimensions: (N, 3, len(times)).

Source code in src/jax_gw/detector/orbits.py (lines 487–590)
def earthbound_ifo_pipeline(
    lat: float,
    lon: float,
    times: ArrayLike,
    r: float,
    L_arm: float,
    psi: float = 0,
    beta_arm: float = jnp.pi / 2,
) -> Array:
    """Create the orbits of the three points defining an Earthbound interferometer.

    The vertex (beam splitter) sits on the surface of a spherical Earth whose
    center follows a circular orbit of radius ``r`` in the ecliptic plane. The
    Earth spins about its axis, tilted by ``EARTH_TILT``, and carries the vertex
    and both arm end points with it. The opening angle between the arms is
    ``beta_arm`` (not restricted to pi/2).

    Parameters
    ----------
    lat : float
        Earth latitude in radians. Zero is the equator, +pi/2 is the North pole.
        Near the poles local East is ill-defined (its defining cross product
        with the polar axis vanishes), so results there are unreliable.
    lon : float
        Earth longitude in radians. Zero is the Greenwich meridian, positive is East.
    times : ArrayLike
        1D array of times at which to evaluate the orbit, in years.
    r : float
        Radius of the orbit in AU.
    L_arm : float
        Length of the arms in km.
    psi : float
        Angle between the X arm and local East in radians. Positive North of East.
    beta_arm : float
        Angle between the X and Y arms in radians. The Y arm is the X arm rotated
        by ``beta_arm`` about the local vertical, in the same sense as ``psi``.

    Returns
    -------
    Array
        Positions of the N=3 points ``[vertex, X-arm end, Y-arm end]`` in ecliptic
        coordinates. Dimensions: (N, 3, len(times)). Units: AU.
    """
    FREQ_CENTER_ORBIT = 1  # in 1/year
    # NOTE(review): 365.25 rotations per year is the solar-day rate; the spin
    # relative to a non-rotating frame is the sidereal rate (~366.25 per year).
    # Kept unchanged here -- confirm which one is intended.
    FREQ_ROTATION = 365.25  # in 1/year
    R_EARTH_KM = 6371.0
    # 1 AU = 149.597871e9 m, so this is the number of Earth radii in one AU
    EARTH_RADII_PER_AU = (149.597871 * 1e9) / (R_EARTH_KM * 1e3)

    def rotation_about_axis(axis, angle):
        """Rodrigues' rotation matrix for ``angle`` about the unit vector ``axis``."""
        k_matrix = jnp.array(
            [
                [0.0, -axis[2], axis[1]],
                [axis[2], 0.0, -axis[0]],
                [-axis[1], axis[0], 0.0],
            ]
        )
        return (
            jnp.eye(3)
            + jnp.sin(angle) * k_matrix
            + (1 - jnp.cos(angle)) * k_matrix @ k_matrix
        )

    # motion of the Earth's center around the Sun, in AU
    r_orbital = create_circular_orbit_xy(r, FREQ_CENTER_ORBIT, times)

    # vertex position at t=0 on the unit sphere (Earth radii), equatorial coordinates
    r_detector_initial_equatorial = lat_lon_to_cartesian(lat, lon)

    # spin about the equatorial z-axis, then tilt into ecliptic coordinates.
    # NOTE(review): equatorial_timeshift rotates vectors by -angle about z
    # (clockwise seen from the North pole), while create_circular_orbit_xy moves
    # counterclockwise -- confirm the intended sense of the Earth's spin.
    hour_angle = 2.0 * jnp.pi * FREQ_ROTATION * times
    r_detector = equatorial_timeshift(r_detector_initial_equatorial, hour_angle)
    r_detector = axial_tilt(r_detector, EARTH_TILT)

    # local East unit direction at the vertex
    north_pole_equatorial = jnp.array([0.0, 0.0, 1.0])
    local_east = jnp.cross(north_pole_equatorial, r_detector_initial_equatorial)
    local_east = local_east / jnp.linalg.norm(local_east)

    # arm directions: rotate about the local vertical (the unit position vector),
    # East by psi gives X, then X by beta_arm gives Y
    x_arm_direction = rotation_about_axis(r_detector_initial_equatorial, psi) @ local_east
    y_arm_direction = (
        rotation_about_axis(r_detector_initial_equatorial, beta_arm) @ x_arm_direction
    )

    # arm vectors in Earth radii, converted from equatorial to ecliptic coordinates
    arm_length = L_arm / R_EARTH_KM
    x_arm_ecliptic_initial = axial_tilt(arm_length * x_arm_direction, +EARTH_TILT)
    y_arm_ecliptic_initial = axial_tilt(arm_length * y_arm_direction, +EARTH_TILT)

    # let the arms co-rotate with the Earth like a solid body
    x_arm = ecliptic_timeshift(x_arm_ecliptic_initial, hour_angle, EARTH_TILT)
    y_arm = ecliptic_timeshift(y_arm_ecliptic_initial, hour_angle, EARTH_TILT)

    # Earth radii -> AU, then offset everything by the orbital motion
    r_detector = jnp.array(r_detector, dtype=jnp.float64)
    r_beam_splitter = r_orbital + r_detector / EARTH_RADII_PER_AU

    x_arm = r_beam_splitter + jnp.array(x_arm, dtype=jnp.float64) / EARTH_RADII_PER_AU
    y_arm = r_beam_splitter + jnp.array(y_arm, dtype=jnp.float64) / EARTH_RADII_PER_AU

    return jnp.stack([r_beam_splitter, x_arm, y_arm], axis=0)

ecliptic_timeshift(ecliptic_coords, angle, tilt)

Rotate a vector in ecliptic coordinates by an angle angle around the z-axis of equatorial coordinates. Shift ecliptic coordinates to a hour angle later.

Parameters:

Name Type Description Default
ecliptic_coords ArrayLike

Vector in ecliptic coordinates.

required
angle ArrayLike

Angle to rotate around the z-axis of equatorial coordinates.

required
tilt float

Angle to rotate around the x-axis.

required

Returns:

Type Description
Array

Vector in ecliptic coordinates at time shifted by angle.

Source code in src/jax_gw/detector/orbits.py (lines 133–158)
def ecliptic_timeshift(
    ecliptic_coords: ArrayLike, angle: ArrayLike, tilt: float
) -> Array:
    """Rotate a vector given in ecliptic coordinates about the equatorial z-axis.

    Shifts ecliptic coordinates to an hour ``angle`` later. The vector is
    converted to equatorial coordinates (``axial_tilt`` with ``-tilt``), rotated
    with ``equatorial_timeshift``, and converted back to ecliptic coordinates
    (``axial_tilt`` with ``tilt``).

    Parameters
    ----------
    ecliptic_coords : ArrayLike
        Vector in ecliptic coordinates.
    angle : ArrayLike
        Hour angle(s) in radians to rotate by about the z-axis of equatorial
        coordinates; an array of angles yields one vector per angle.
    tilt : float
        Axial tilt in radians between the equatorial and ecliptic frames
        (rotation angle about the x-axis).

    Returns
    -------
    Array
        Vector in ecliptic coordinates at the time shifted by ``angle``.
    """
    # ecliptic -> equatorial (inverse tilt)
    equatorial_initial = axial_tilt(ecliptic_coords, -tilt)
    # rotate about the equatorial z-axis
    equatorial_shifted = equatorial_timeshift(equatorial_initial, angle)
    # equatorial -> ecliptic
    return axial_tilt(equatorial_shifted, tilt)

equatorial_timeshift(equatorial_coords, angle)

Rotate a vector by an angle angle around the z-axis of equatorial coordinates. Shift equatorial coordinates to a hour angle later.

Parameters:

Name Type Description Default
equatorial_coords ArrayLike

Vector in equatorial coordinates.

required
angle ArrayLike

Angle to rotate around the z-axis.

required

Returns:

Type Description
Array

Vector in equatorial coordinates at time shifted by angle.

Source code in src/jax_gw/detector/orbits.py (lines 81–104)
def equatorial_timeshift(equatorial_coords: ArrayLike, angle: ArrayLike) -> Array:
    """Shift equatorial coordinates to an hour `angle` later.

    The x and y components are rotated about the equatorial z-axis as
    ``x' = x cos(angle) + y sin(angle)`` and ``y' = -x sin(angle) + y cos(angle)``,
    i.e. the vector is rotated by ``-angle`` (clockwise seen from +z). The
    z component is unchanged but broadcast to the shape of the rotated components.

    Parameters
    ----------
    equatorial_coords : ArrayLike
        Vector in equatorial coordinates; its first axis holds (x, y, z).
    angle : ArrayLike
        Rotation angle(s) in radians; an array gives one rotated vector per angle.

    Returns
    -------
    Array
        Rotated coordinates stacked along a new leading axis of length 3.
    """
    x_in, y_in, z_in = equatorial_coords
    c = jnp.cos(angle)
    s = jnp.sin(angle)

    x_out = c * x_in + s * y_in
    y_out = c * y_in - s * x_in
    z_out = z_in * jnp.ones_like(x_out)
    return jnp.stack((x_out, y_out, z_out), axis=0)

flat_index(i, j, N)

Calculate the flat index for a pair of indices i, j.

Parameters:

Name Type Description Default
i int32

First index.

required
j int32

Second index.

required
N int

Possible values for i (or j).

required

Returns:

Type Description
int32

Flat index for the pair of indices corresponding to the pair of spacecraft.

Source code in src/jax_gw/detector/orbits.py (lines 377–402)
def flat_index(i: jnp.int32, j: jnp.int32, N: int) -> jnp.int32:
    """Calculate the flat index for an ordered pair of spacecraft indices (i, j).

    Pairs are enumerated by their smaller index first; within that, by
    increasing larger index; and each pair ``(lo, hi)`` is immediately followed
    by its reverse ``(hi, lo)``. For ``N = 3`` this gives
    (0, 1) -> 0, (1, 0) -> 1, (0, 2) -> 2, (2, 0) -> 3, (1, 2) -> 4, (2, 1) -> 5.
    Only meaningful for ``i != j``.

    Parameters
    ----------
    i : jnp.int32
        First index.
    j : jnp.int32
        Second index.
    N : int
        Possible values for i (or j).

    Returns
    -------
    jnp.int32
        Flat index in ``[0, N * (N - 1))`` for the pair of spacecraft.
    """
    lo, hi = jnp.minimum(i, j), jnp.maximum(i, j)
    # number of entries belonging to all smaller lower indices
    block_offset = lo * (2 * N - lo - 1)
    # two entries (forward, reverse) per partner between lo and hi
    pair_offset = 2 * (hi - lo - 1)
    # False (0) for the forward pair i < j, True (1) for the reverse pair
    direction = jnp.greater(i, j)
    return direction + pair_offset + block_offset

flat_to_matrix_indices(N)

Calculate the (N*(N-1), 2) matrix of flat indices for a given number of spacecraft.

Parameters:

Name Type Description Default
N int

Number of spacecraft.

required

Returns:

Type Description
array

Matrix of flat indices.

Source code in src/jax_gw/detector/orbits.py (lines 405–431)
def flat_to_matrix_indices(
    N: int,
) -> Array:
    """Calculate the (N*(N-1), 2) matrix of flat indices for a given number of
    spacecraft.

    Row ``k`` holds the ordered pair ``(i, j)`` with ``flat_index(i, j, N) == k``:
    (0, 1), (1, 0), (0, 2), (2, 0), ..., (N-2, N-1), (N-1, N-2).

    Parameters
    ----------
    N : int
        Number of spacecraft.

    Returns
    -------
    Array
        int32 matrix of pairs with shape (N * (N - 1), 2).
    """
    # Enumerate directly in flat_index order: for each lower index, every higher
    # partner, each pair followed by its reverse. Plain Python enumeration avoids
    # one jitted call and one functional scatter update per pair.
    pairs = [
        pair
        for low in range(N)
        for high in range(low + 1, N)
        for pair in ((low, high), (high, low))
    ]
    # reshape keeps the (0, 2) shape when there are no pairs (N < 2)
    return jnp.array(pairs, dtype=jnp.int32).reshape(-1, 2)

flatten_pairs(matrix_form)

Flatten the separations or receiver positions from a pair of indices to a single dimension of length N * (N - 1).

Parameters:

Name Type Description Default
matrix_form array

Separations or receiver positions in matrix form.

required

Returns:

Type Description
array

Flattened separations or receiver positions with shape (N * (N - 1), N_steps, ...); the pair axis comes first.

Source code in src/jax_gw/detector/orbits.py (lines 434–461)
@jax.jit
def flatten_pairs(
    matrix_form: ArrayLike,
) -> Array:
    """Flatten the separations or receiver positions from a pair of indices
    to a single dimension of length N * (N - 1).

    Pairs are ordered as the rows of ``flat_to_matrix_indices(N)``, i.e.
    (0, 1), (1, 0), (0, 2), (2, 0), ...

    Parameters
    ----------
    matrix_form : ArrayLike
        Array of shape ``(N_steps, N1, N2, ...)`` indexed by receiver on axis 1
        and by emitter on axis 2. One of ``N1``/``N2`` may be 1 (a broadcast
        axis, e.g. from ``get_receiver_positions``/``get_emitter_positions``);
        that single entry is then reused for every pair.

    Returns
    -------
    Array
        Flattened array of shape ``(N * (N - 1), N_steps, ...)`` with
        ``N = max(N1, N2)``. The pair axis is the leading axis because the
        per-pair selection is vmapped with ``out_axes=0``.
    """
    n_receivers, n_emitters = matrix_form.shape[1], matrix_form.shape[2]
    N = max(n_receivers, n_emitters)

    def select_pair(i, j):
        # clamp to the last valid index so a length-1 (broadcast) axis is reused
        # explicitly, rather than relying on JAX clamping out-of-bounds gathers
        return matrix_form[
            :, jnp.minimum(i, n_receivers - 1), jnp.minimum(j, n_emitters - 1), ...
        ]

    indices = flat_to_matrix_indices(N)
    receivers, emitters = indices[:, 0], indices[:, 1]
    return jax.vmap(select_pair, in_axes=(0, 0), out_axes=0)(receivers, emitters)

get_arm_lengths(separations)

Calculate the arm lengths from the vector separations.

Parameters:

Name Type Description Default
separations ndarray

Vector separations. Last dimension must be of length 3.

required

Returns:

Type Description
ndarray

Arm lengths in shape (N_steps, N, N).

Source code in src/jax_gw/detector/orbits.py (lines 301–316)
def get_arm_lengths(separations: ArrayLike) -> Array:
    """Compute arm lengths as the Euclidean norms of the vector separations.

    Parameters
    ----------
    separations : ArrayLike
        Vector separations whose last axis (length 3) holds the Cartesian
        components, e.g. ``(N_steps, N, N, 3)``.

    Returns
    -------
    Array
        Norms over the last axis; shape ``(N_steps, N, N)`` for the input above.
    """
    return jnp.linalg.norm(separations, axis=-1)

get_emitter_positions(position)

Calculate the emitter positions of the spacecraft for a given arm. Since the separation matrix is defined as r[i, j] = r[i] - r[j], the emitter positions must be calculated as e_pos[i, j, 3, N...] = pos[j, 3, N...].

Parameters:

Name Type Description Default
position array

Position of the spacecraft. Shape (N, 3, ...).

required

Returns:

Type Description
array

Emitter spacecraft positions with shape (..., 1, N, 3), broadcastable to the shape of separations.

Source code in src/jax_gw/detector/orbits.py (lines 351–374)
def get_emitter_positions(
    position: ArrayLike,
) -> Array:
    """Calculate the emitter positions of the spacecraft for a collection of arms.

    Since the separation matrix is defined as ``r[i, j] = r[i] - r[j]``, the
    emitter of arm ``(i, j)`` is spacecraft ``j``. A singleton axis is inserted
    for the receiver index ``i``, and the ``(i, j, xyz)`` axes are moved behind
    any trailing axes. An input of shape ``(N, 3, N_steps)`` therefore becomes
    ``(N_steps, 1, N, 3)``, which broadcasts against separations of shape
    ``(N_steps, N, N, 3)``.

    Parameters
    ----------
    position : ArrayLike
        Position of the spacecraft. Shape (N, 3, ...).

    Returns
    -------
    Array
        Emitter spacecraft positions with shape ``(..., 1, N, 3)``, broadcastable
        to the shape of separations (not equal to it).
    """
    # insert a singleton receiver axis i: (N, 3, ...) -> (1, N, 3, ...)
    pos = position[jnp.newaxis, :, :, ...]
    # move (i, j, xyz) behind any trailing (e.g. time) axes: -> (..., 1, N, 3)
    return jnp.moveaxis(pos, [0, 1, 2], [-3, -2, -1])

get_receiver_positions(position)

Calculate the receiver positions of the spacecraft for a collection of arms. Since the separation matrix is defined as r[i, j] = r[i] - r[j], the receiver positions must be defined via r_pos[i, j, 3, N...] = pos[i, 3, N...]. Positions has shape (N, 3, N_steps) while separations has shape (N_steps, N, N, 3). Therefore, we need to create a newaxis for the j index, and finally move the time axis to the front.

Parameters:

Name Type Description Default
position ArrayLike

Position of the spacecraft. Shape (N, 3, ...).

required

Returns:

Type Description
Array

Receiver spacecraft positions for the arm. Shape broadcastable to the shape of separations.

Source code in src/jax_gw/detector/orbits.py (lines 319–348)
def get_receiver_positions(
    position: ArrayLike,
) -> Array:
    """Calculate the receiver positions of the spacecraft for a collection of arms.

    With separations defined as ``r[i, j] = r[i] - r[j]``, arm ``(i, j)`` is
    received by spacecraft ``i``, so ``r_pos[..., i, j, :] = pos[i, :, ...]``.
    A singleton axis is added for the emitter index ``j`` and the
    ``(i, j, xyz)`` axes are moved behind any trailing (time) axes. An input of
    shape ``(N, 3, N_steps)`` therefore becomes ``(N_steps, N, 1, 3)``.

    Parameters
    ----------
    position : ArrayLike
        Position of the spacecraft. Shape (N, 3, ...).

    Returns
    -------
    Array
        Receiver spacecraft positions for the arm. Shape broadcastable to
        the shape of separations.
    """
    return jnp.moveaxis(
        position[:, jnp.newaxis, :, ...],  # (N, 1, 3, ...)
        (0, 1, 2),
        (-3, -2, -1),
    )

get_separations(orbits)

Calculate the vector separations between the spacecraft.

r_{ij} = r_i - r_j

Parameters:

Name Type Description Default
orbits ArrayLike

Array of shape (N, 3, N_steps) containing the orbits of the N spacecraft.

required

Returns:

Type Description
Array

Vector separations. Dimensions: (N_steps, N, N, 3).

Source code in src/jax_gw/detector/orbits.py (lines 275–298)
def get_separations(orbits: ArrayLike) -> Array:
    """Calculate the vector separations between the spacecraft.

    `r_{ij} = r_i - r_j`

    Parameters
    ----------
    orbits : ArrayLike
        Array of shape `(N, 3, N_steps)` containing the orbits of the N
        spacecraft.

    Returns
    -------
    Array
        Vector separations in the default floating-point dtype.
        Dimensions: `(N_steps, N, N, 3)`.
    """
    # broadcast (N, 1, 3, N_steps) - (1, N, 3, N_steps) -> (N, N, 3, N_steps):
    # diff[i, j] = orbits[i] - orbits[j] for all pairs in one vectorized op
    diff = orbits[:, jnp.newaxis, :, :] - orbits[jnp.newaxis, :, :, :]
    # move the time axis to the front: (N_steps, N, N, 3)
    separations = jnp.transpose(diff, (3, 0, 1, 2))
    # always return the default float dtype (e.g. for integer input)
    return separations.astype(jnp.result_type(float))

get_vertex_angle(orbits)

Get the angle between the two arms of the interferometer, using the positions at the first time step.

Source code in src/jax_gw/detector/orbits.py (lines 20–31)
def get_vertex_angle(orbits: ArrayLike) -> float:
    """Get the angle between the two arms of the interferometer.

    Uses the positions at the first time step: the X arm is
    ``orbits[1] - orbits[0]`` and the Y arm is ``orbits[2] - orbits[0]``.

    Parameters
    ----------
    orbits : ArrayLike
        Positions of the points ``[vertex, X-arm end, Y-arm end, ...]``;
        shape (N, 3, N_steps) with N >= 3.

    Returns
    -------
    float
        Angle between the two arms in radians, in ``[0, pi]``.
    """
    x_arm = (orbits[1] - orbits[0])[:, 0]
    y_arm = (orbits[2] - orbits[0])[:, 0]
    cos_angle = jnp.dot(x_arm, y_arm) / (
        jnp.linalg.norm(x_arm) * jnp.linalg.norm(y_arm)
    )
    # rounding can push |cos| slightly above 1 for (anti)parallel arms, which
    # would make arccos return NaN
    vertex_angle = jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))
    return vertex_angle.item()

lat_lon_to_cartesian(lat, lon, r=1)

Convert latitude and longitude to equatorial cartesian coordinates.

Parameters:

Name Type Description Default
lat float

Latitude in radians.

required
lon float

Longitude in radians.

required
r float

Radius.

1

Returns:

Type Description
array

Equatorial cartesian coordinates.

Source code in src/jax_gw/detector/orbits.py (lines 57–78)
def lat_lon_to_cartesian(lat: float, lon: float, r: float = 1) -> Array:
    """Convert latitude and longitude to equatorial Cartesian coordinates.

    Computes ``r (cos(lat) cos(lon), cos(lat) sin(lon), sin(lat))``.

    Parameters
    ----------
    lat : float
        Latitude in radians.
    lon : float
        Longitude in radians.
    r : float
        Radius.

    Returns
    -------
    Array
        Cartesian coordinates stacked along a new leading axis of length 3.
    """
    cos_lat = jnp.cos(lat)
    components = (
        r * cos_lat * jnp.cos(lon),
        r * cos_lat * jnp.sin(lon),
        r * jnp.sin(lat),
    )
    return jnp.stack(components, axis=0)

path_from_indices(indices)

Convert an array of N_depth+1 visited spacecraft indices into a path of N_depth segments, where each segment is the (start, end) pair of spacecraft indices.

Parameters:

Name Type Description Default
indices array

Array of indices of spacecraft and length N_depth+1.

required

Returns:

Type Description
array

Array of shape (..., N_depth, 2); entry [..., k, :] holds the start and end spacecraft indices of segment k of the path.

Source code in src/jax_gw/detector/orbits.py (lines 464–484)
def path_from_indices(indices: ArrayLike) -> Array:
    """Convert a sequence of visited spacecraft into the segments of a path.

    ``indices[..., k]`` and ``indices[..., k + 1]`` are the start and end of
    segment ``k``, so ``N_depth + 1`` visited spacecraft give ``N_depth``
    segments.

    Parameters
    ----------
    indices : ArrayLike
        Spacecraft indices with shape ``(..., N_depth + 1)``; leading axes
        (e.g. one per path, ``N_paths``) are preserved.

    Returns
    -------
    Array
        Array of shape ``(..., N_depth, 2)`` where ``[..., k, 0]`` is the start
        and ``[..., k, 1]`` the end spacecraft of segment ``k``.
    """
    indices = jnp.asarray(indices)
    # pair each visited spacecraft with the next one along the last axis
    return jnp.stack([indices[..., :-1], indices[..., 1:]], axis=-1)