Skip to content

Documentation for detector.response

antenna_pattern(u_hat, v_hat, arm_direction)

Calculate the antenna pattern for a given source direction.

Parameters:

Name Type Description Default
u_hat ArrayLike

First unit vector in the transverse plane of the incoming signal.

required
v_hat ArrayLike

Second unit vector in the transverse plane of the incoming signal.

required
arm_direction ArrayLike

Unit vector pointing along the arm from the emitter to the receiver.

required

Returns:

Type Description
array

Plus and cross antenna pattern functions.

Source code in src/jax_gw/detector/response.py (lines 19–43)
def antenna_pattern(
    u_hat: ArrayLike,
    v_hat: ArrayLike,
    arm_direction: ArrayLike,
) -> Array:
    """Calculate the antenna pattern for a given source direction.

    With ``a = arm_direction`` this evaluates

    * plus:  ``(a . u_hat)**2 - (a . v_hat)**2``
    * cross: ``2 * (a . u_hat) * (a . v_hat)``

    and stacks the two along a new trailing axis.

    Parameters
    ----------
    u_hat : ArrayLike
        First unit vector in the transverse plane of the incoming signal.
    v_hat : ArrayLike
        Second unit vector in the transverse plane of the incoming signal.
    arm_direction : ArrayLike
        Unit vector pointing along the arm from the emitter to the receiver.

    Returns
    -------
    jnp.array
        Plus and cross antenna pattern functions, stacked on the last axis
        (``[..., 0]`` is plus, ``[..., 1]`` is cross).
    """
    # Project the arm onto each transverse basis vector once and reuse the
    # projections for both polarisations.
    arm_dot_u = jnp.dot(arm_direction, u_hat)
    arm_dot_v = jnp.dot(arm_direction, v_hat)

    ksi_plus = arm_dot_u**2 - arm_dot_v**2
    ksi_cross = 2 * arm_dot_u * arm_dot_v

    return jnp.stack([ksi_plus, ksi_cross], axis=-1)

get_differential_strain_response(path_response, path_idx_1, path_idx_2, cumul_path_separations)

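Calculate the strain response from the difference in the responses of two photon paths of equal cumulative length, i.e.

$$R_{\mathrm{diff}} = \frac{R[\mathrm{path\_idx\_1}] - R[\mathrm{path\_idx\_2}]}{L_{\mathrm{tot}} / c}$$

where $L_{\mathrm{tot}}$ is the mean of the two paths' total cumulative lengths and $c$ is the speed of light.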

Parameters:

Name Type Description Default
path_response ArrayLike

Timing response function for a collection of photon paths.

required
path_idx_1 int

Index of the first photon path.

required
path_idx_2 int

Index of the second photon path.

required
cumul_path_separations ArrayLike

Cumulative path lengths for the photon paths.

required

Returns:

Type Description
array

Michelson strain response for the two chosen photon paths.

Source code in src/jax_gw/detector/response.py (lines 244–282)
def get_differential_strain_response(
    path_response: ArrayLike,
    path_idx_1: int,
    path_idx_2: int,
    cumul_path_separations: ArrayLike,
):
    """Strain response from the difference of two equal-length photon paths.

    Evaluates ``(R[path_idx_1] - R[path_idx_2]) / (L_tot / c)``, where
    ``L_tot`` is the mean of the final cumulative separations of the two
    paths.

    Parameters
    ----------
    path_response : ArrayLike
        Timing response function for a collection of photon paths.
    path_idx_1 : int
        Index of the first photon path.
    path_idx_2 : int
        Index of the second photon path.
    cumul_path_separations : ArrayLike
        Cumulative path lengths for the photon paths.

    Returns
    -------
    jnp.array
        Michelson strain response for the two chosen photon paths.
    """
    # Average of the two paths' final cumulative separations, turned into a
    # light-travel time (one value per entry of the leading axis).
    mean_length = 0.5 * (
        cumul_path_separations[:, path_idx_1, -1]
        + cumul_path_separations[:, path_idx_2, -1]
    )
    travel_time = mean_length / C_IN_AU_PER_S

    response_gap = path_response[path_idx_1] - path_response[path_idx_2]

    # Broadcast the travel time over the three trailing axes of the response.
    return response_gap / travel_time[:, None, None, None]

get_pairwise_differential_strain_response(path_response, cumul_path_separations)

Calculate the strain response from the difference in the responses of two successive photon paths of equal cumulative length, i.e.

$$R_{\mathrm{diff}} = \frac{R[\mathrm{path\_idx\_1}] - R[\mathrm{path\_idx\_2}]}{L_{\mathrm{tot}} / c}$$

where path_idx_1 runs over the even indices and path_idx_2 = path_idx_1 + 1 is the odd index that follows it.

Parameters:

Name Type Description Default
path_response ArrayLike

Timing response function for a collection of photon paths.

required
cumul_path_separations ArrayLike

Cumulative path lengths for the photon paths.

required

Returns:

Type Description
array

Michelson strain response for all successive photon pairs.

Source code in src/jax_gw/detector/response.py (lines 285–317)
def get_pairwise_differential_strain_response(
    path_response: ArrayLike,
    cumul_path_separations: ArrayLike,
):
    """Strain response for every (even, odd) pair of successive photon paths.

    For each pair ``(2j, 2j + 1)`` this evaluates
    ``(R[2j] - R[2j + 1]) / (L_tot / c)``, where ``L_tot`` is the mean of
    the final cumulative separations of the two paths.

    Parameters
    ----------
    path_response : ArrayLike
        Timing response function for a collection of photon paths.
    cumul_path_separations : ArrayLike
        Cumulative path lengths for the photon paths.

    Returns
    -------
    jnp.array
        Michelson strain response for all successive photon pairs.
    """
    even_lengths = cumul_path_separations[:, 0::2, -1]
    odd_lengths = cumul_path_separations[:, 1::2, -1]
    pair_time = (0.5 * (even_lengths + odd_lengths)) / C_IN_AU_PER_S
    # Reverse the axes so the pair axis leads, lining up with
    # path_response[0::2] below.
    pair_time = jnp.transpose(pair_time)

    response_gap = path_response[0::2] - path_response[1::2]

    return response_gap / pair_time[..., None, None, None]

get_path_response(paths, freqs, arm_lengths, response)

Calculate the timing response function for a collection of photon paths.

Parameters:

Name Type Description Default
paths ArrayLike

Spacecraft indices for photon paths in shape (N_paths, N_depth).

required
freqs ArrayLike

Frequencies of the gravitational wave in shape (N_freq,).

required
arm_lengths ArrayLike

Arm lengths of the spacecraft in shape (N_pair, N_times).

required
response ArrayLike

Response function for the given source direction

required

Returns:

Type Description
tuple

Tuple (path_responses, cumul_path_separations): the timing response of each photon path and the cumulative path separations used to build it.

Source code in src/jax_gw/detector/response.py (lines 190–241)
def get_path_response(
    paths: ArrayLike,
    freqs: ArrayLike,
    arm_lengths: ArrayLike,
    response: ArrayLike,
):
    """Calculate the timing response function for a collection of photon paths.

    Parameters
    ----------
    paths : ArrayLike
        Spacecraft indices for photon paths in shape (N_paths, N_depth).
    freqs : ArrayLike
        Frequencies of the gravitational wave in shape (N_freq,).
    arm_lengths : ArrayLike
        Arm lengths of the spacecraft in shape (N_pair, N_times).
    response : ArrayLike
        Response function for the given source direction. Its leading axis
        is the spacecraft-pair axis (length N_pair).

    Returns
    -------
    tuple[jnp.array, jnp.array]
        ``(path_responses, cumul_path_separations)``: the timing response of
        each photon path (path axis first, per the einsum output ``kjmin``)
        and the cumulative path separations used to build the phase delays.
    """
    indices = path_from_indices(paths)
    N_pair = response.shape[0]
    # N_pair = N * (N - 1), so N = 1/2 + sqrt(N_pair + 1/4).
    # Plain Python arithmetic keeps N a static int. jnp.sqrt would be staged
    # out under jit, and round() on the resulting tracer would fail.
    N = round((N_pair + 0.25) ** 0.5 + 0.5)

    # Convert each index tuple along the last axis into a flat pair index via
    # flat_index (presumably (emitter, receiver) -> pair index; confirm).
    flat_indices = jnp.apply_along_axis(
        lambda pair: flat_index(*pair, N),
        axis=-1,
        arr=indices,
    )

    cumul_path_separations = get_cumulative_path_separations(flat_indices, arm_lengths)
    # remove the last element of each path, as it does not appear in emitter phases
    reduced_cumul_path_separations = cumul_path_separations[..., :-1]

    # jnp.outer flattens its inputs, so restore freqs.shape + separations.shape
    # before converting the separations to light-travel times.
    cumul_path_phases = -2 * jnp.pi * jnp.outer(freqs, reduced_cumul_path_separations)
    cumul_path_phases = (
        cumul_path_phases.reshape(freqs.shape + reduced_cumul_path_separations.shape)
        / C_IN_AU_PER_S
    )
    cumul_path_exp = jnp.exp(1j * cumul_path_phases)

    # Weight each pair response along a path by its accumulated phase factor
    # and sum over the path-depth index ``l`` (the only index not in the output).
    path_responses = jnp.einsum(
        "ijkl,kljmin->kjmin", cumul_path_exp, response[flat_indices]
    )

    return path_responses, cumul_path_separations

response_function(k_hat, freq, receiver_positions, full_transfer, antennae)

Calculate the timing response function for a given source direction.

Parameters:

Name Type Description Default
k_hat ArrayLike

Unit vector pointing in the direction of propagation of the incoming signal.

required
freq ArrayLike

Frequencies of the gravitational wave.

required
receiver_positions ArrayLike

Positions of the receivers.

required
full_transfer ArrayLike

Transfer function for the given source direction. Shape: (N_sky, N_freq, N_pair, N_times)

required
antennae ArrayLike

Plus and cross antenna pattern functions. Shape: (N_sky, N_pair, N_times, N_pol)

required

Returns:

Type Description
array

Response function for the given source direction.

Source code in src/jax_gw/detector/response.py (lines 96–143)
def response_function(
    k_hat: ArrayLike,
    freq: ArrayLike,
    receiver_positions: ArrayLike,
    full_transfer: ArrayLike,
    antennae: ArrayLike,
) -> Array:
    """Calculate the timing response function for a given source direction.

    The transfer function is multiplied by the phase factor
    ``exp(-2j * pi * f * (x_receiver . k_hat) / c)``. It is then contracted
    with the antenna patterns over the leading axis, and the result is
    scaled by 1/2.

    Parameters
    ----------
    k_hat : ArrayLike
        Unit vector pointing in the direction of propagation of the incoming
        signal.
    freq : ArrayLike
        Frequencies of the gravitational wave.
    receiver_positions : ArrayLike
        Positions of the receivers.
    full_transfer : ArrayLike
        Transfer function for the given source direction.
        Shape: (N_sky, N_freq, N_pair, N_times)
    antennae : ArrayLike
        Plus and cross antenna pattern functions.
        Shape: (N_sky, N_pair, N_times, N_pol)

    Returns
    -------
    jnp.array
        Response function for the given source direction.

    Raises
    ------
    TypeError
        If ``freq`` is not a JAX array, or is a scalar.
    """
    freq_is_array = isinstance(freq, (jnp.ndarray, Array))
    if not freq_is_array or jnp.isscalar(freq):
        raise TypeError(f"freq must be an array, got {type(freq)}")

    # Light-travel delay of each receiver position projected on k_hat.
    receiver_delay = jnp.dot(receiver_positions, k_hat) / C_IN_AU_PER_S
    phase = 2 * jnp.pi * jnp.outer(freq, receiver_delay).reshape(
        freq.shape + receiver_delay.shape
    )
    # The axis contributed by k_hat ends up last; move it to the front to
    # match the documented (N_sky, ...) layout of full_transfer.
    phase = jnp.moveaxis(phase, -1, 0)

    shifted_transfer = full_transfer * jnp.exp(-1j * phase)

    # response function assuming no time delay
    return 0.5 * jnp.einsum("ij...,i...k->...ijk", shifted_transfer, antennae)

response_pipe(orbits, freqs, sky_basis)

Calculate the response function for a given source direction.

Source code in src/jax_gw/detector/response.py (lines 146–173)
def response_pipe(
    orbits,
    freqs,
    sky_basis,
):
    """Calculate the response function for a given source direction.

    Builds the arm vectors and receiver positions from ``orbits``. It then
    evaluates the transfer function and antenna patterns for the sky basis
    ``(k_hat, u_hat, v_hat)`` and combines them with ``response_function``.
    Returns the tuple ``(response, antennae)``.
    """
    k_hat, u_hat, v_hat = sky_basis

    # Arm vectors between spacecraft pairs, their lengths and unit directions.
    arms = flatten_pairs(get_separations(orbits))
    arm_lengths = get_arm_lengths(arms)
    arm_directions = arms / arm_lengths[..., None]

    receiver_positions = flatten_pairs(get_receiver_positions(orbits))

    full_transfer = jitted_vmapped_transfer_function(k_hat, freqs, arms)
    antennae = sky_vmapped_antenna_pattern(u_hat, v_hat, arm_directions)

    response = response_function(
        k_hat.T, freqs, receiver_positions, full_transfer, antennae
    )
    return response, antennae

transfer_function(k_hat, freq, arms)

Calculate the transfer function for a given source direction.

Parameters:

Name Type Description Default
k_hat ArrayLike

Unit vector pointing in the direction of propagation of the incoming signal.

required
freq ArrayLike

Frequencies of the gravitational wave.

required
arms ArrayLike

Arm configuration of the spacecraft.

required

Returns:

Type Description
array

Transfer function for the given source direction.

Source code in src/jax_gw/detector/response.py (lines 51–84)
def transfer_function(
    k_hat: ArrayLike,
    freq: ArrayLike,
    arms: ArrayLike,
) -> Array:
    """Calculate the transfer function for a given source direction.

    With ``L = |arm|`` and ``phi = pi * f * (L - arm . k_hat) / c``, this
    returns ``(L / c) * sin(phi) / phi * exp(-1j * phi)``.

    Parameters
    ----------
    k_hat : ArrayLike
        Unit vector pointing in the direction of propagation of the incoming
        signal.
    freq : ArrayLike
        Frequencies of the gravitational wave.
    arms : ArrayLike
        Arm configuration of the spacecraft.

    Returns
    -------
    jnp.array
        Transfer function for the given source direction.

    Raises
    ------
    TypeError
        If ``freq`` is not a JAX array, or is a scalar.
    """
    freq_is_array = isinstance(freq, (jnp.ndarray, Array))
    if not freq_is_array or jnp.isscalar(freq):
        raise TypeError(f"freq must be an array, got {type(freq)}")

    arm_length = get_arm_lengths(arms)
    light_time = arm_length / C_IN_AU_PER_S
    # (L - arm . k_hat) / c
    lag = (arm_length - jnp.dot(arms, k_hat)) / C_IN_AU_PER_S

    # jnp.outer flattens its inputs, so restore the freq.shape + lag.shape layout.
    phi = jnp.pi * jnp.outer(freq, lag).reshape(freq.shape + lag.shape)

    # jnp.sinc is the normalised sinc, so sinc(phi / pi) == sin(phi) / phi.
    return light_time * jnp.sinc(phi / jnp.pi) * jnp.exp(-1j * phi)