# Source code for cbf_sdp.transmitters.spead2_transmitters

# -*- coding: utf-8 -*-
"""
Implementation for the SPEAD2 network transport

This module contains the logic to take ICD Payloads and transmit them using
the SPEAD protocol.
"""
import asyncio
from contextlib import AbstractAsyncContextManager
import logging
import math
import time
from overrides import overrides

import numpy as np
import spead2.send.asyncio

from realtime.receive.core.icd import Items, Payload


# True when the installed spead2 uses the "new" stream API (major version 3
# or later): stream constructors take a list of ``endpoints`` and no ``loop``
# argument. Compare the major version numerically with ">= 3" rather than
# testing for the string "3", so later major releases are not mistaken for the
# legacy 2.x API (``hostname``/``port``/``loop`` keyword arguments).
IS_SPEAD3 = int(spead2.__version__.split('.')[0]) >= 3

logger = logging.getLogger(__name__)


async def create_stream(thread_pool, target_host, port, config, buffer_size, transport_proto):
    """
    Create a spead2 asyncio send stream to a single ``target_host:port``.

    :param thread_pool: the spead2 ``ThreadPool`` doing the stream's I/O
    :param target_host: host name or address to send to
    :param port: port number to send to
    :param config: the spead2 ``StreamConfig`` for the stream
    :param buffer_size: socket buffer size, passed straight to spead2
    :param transport_proto: either ``'udp'`` or ``'tcp'``
    :return: a ``spead2.send.asyncio.UdpStream`` or (connected)
        ``spead2.send.asyncio.TcpStream``
    :raises ValueError: if ``transport_proto`` is neither udp nor tcp
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``
    if transport_proto not in ('udp', 'tcp'):
        raise ValueError(
            f"transport_protocol should be udp or tcp, not {transport_proto!r}")
    kwargs = {
        'thread_pool': thread_pool,
        'config': config,
        'buffer_size': buffer_size,
    }
    if IS_SPEAD3:
        # spead2 >= 3 takes a list of (host, port) endpoints
        kwargs['endpoints'] = [(target_host, port)]
    else:
        # Legacy spead2 2.x API
        kwargs['hostname'] = target_host
        kwargs['port'] = port
        kwargs['loop'] = asyncio.get_running_loop()
    if transport_proto == 'tcp':
        # TCP streams must connect before they can be used
        return await spead2.send.asyncio.TcpStream.connect(**kwargs)
    return spead2.send.asyncio.UdpStream(**kwargs)


def parse_endpoints(endpoints_spec):
    """
    Parse endpoint specifications.

    Each endpoint is a colon-separated host and port pair, and multiple
    endpoints are separated by commas. A port can be a single number or a
    range specified as "start-end", both inclusive. Whitespace around hosts
    and ports is ignored.

    For example ``"127.0.0.1:41000-41001, host2:9000"`` yields
    ``[("127.0.0.1", 41000), ("127.0.0.1", 41001), ("host2", 9000)]``.

    :param endpoints_spec: the comma-separated endpoint specification
    :return: a list of ``(host, port)`` tuples, in specification order
    :raises ValueError: if an endpoint is not a ``host:port`` pair, a port is
        not an integer, or a port range starts after it ends
    """
    endpoints = []
    for endpoint in endpoints_spec.split(','):
        parts = endpoint.split(':')
        if len(parts) != 2:
            raise ValueError(f'invalid endpoint, expected host:port: {endpoint!r}')
        host, port_spec = (part.strip() for part in parts)
        if '-' in port_spec:
            start, end = map(int, port_spec.split('-'))
            if start > end:
                raise ValueError(f'invalid port range: {start} > {end}')
            endpoints.extend((host, port) for port in range(start, end + 1))
        else:
            endpoints.append((host, int(port_spec)))
    return endpoints
class Spead2SenderPayload(Payload):
    """SPEAD2 payload following the CSP-SDP interface document"""

    def __init__(self, num_baselines=None, num_channels=None):
        """
        :param num_baselines: number of baselines in each heap - used for sizing
        :param num_channels: number of channels in each heap - used for sizing
        """
        super().__init__()
        # spead2.Flavour(version, item_pointer_bits, heap_address_bits,
        # bug_compat): SPEAD-64-48, no bug-compatibility flags
        self._item_group = spead2.send.ItemGroup(flavour=spead2.Flavour(4, 64, 48, 0))
        self._add_items(num_baselines, num_channels)
        self.baseline_count = num_baselines
        self.channel_count = num_channels
        # Defaults to every sample fully correlated; callers may overwrite it
        self.correlated_data_fraction = np.ones([num_channels, num_baselines])

    def _add_items(self, num_baselines, num_channels):
        """
        Adds all the items to the payload as defined by the ICD

        :param num_baselines: number of baselines in the HEAP - used for sizing
        :param num_channels: number of channels in the HEAP - used for sizing
        """
        vis_shape = (num_channels, num_baselines)
        for item in Items:
            item_desc = item.value
            # Only the correlator output is an array; all other items are
            # scalars (empty shape)
            shape = vis_shape if item == Items.CORRELATOR_OUTPUT_DATA else ()
            self._item_group.add_item(
                id=item_desc.id,
                name=item_desc.name,
                description="",
                shape=shape,
                format=None,
                dtype=item_desc.dtype,
            )
        # Pre-allocate the correlator output array; get_heap() fills its
        # fields in place
        vis = np.zeros(
            shape=vis_shape,
            dtype=Items.CORRELATOR_OUTPUT_DATA.value.dtype,
        )
        self._item_group[Items.CORRELATOR_OUTPUT_DATA.value.id].value = vis

    def get_heap(self):
        """
        Build a data heap from the payload's current attribute values.

        Scalar items are refreshed from the corresponding attributes, and the
        'TCI', 'FD' and 'VIS' fields of the correlator output array are
        overwritten in place from non-empty attributes. The heap carries data
        only; item descriptors go out with :meth:`get_start_heap`.
        """
        def set_item(item, value):
            self._item_group[item.value.id].value = value

        set_item(Items.BASELINE_COUNT, self.baseline_count)
        set_item(Items.CHANNEL_COUNT, self.channel_count)
        set_item(Items.CHANNEL_ID, self.channel_id)
        set_item(Items.HARDWARE_ID, self.hardware_id)
        set_item(Items.PHASE_BIN_ID, self.phase_bin_id)
        set_item(Items.PHASE_BIN_COUNT, self.phase_bin_count)
        set_item(Items.POLARISATION_ID, self.polarisation_id)
        set_item(Items.SCAN_ID, self.scan_id)
        set_item(Items.TIMESTAMP_COUNT, self.timestamp_count)
        set_item(Items.TIMESTAMP_FRACTION, self.timestamp_fraction)

        corr_out_data = self._item_group[Items.CORRELATOR_OUTPUT_DATA.value.id].value
        # Empty attributes leave the previous contents of the field untouched
        if len(self.time_centroid_indices):
            corr_out_data['TCI'] = self.time_centroid_indices
        if len(self.correlated_data_fraction):
            corr_out_data['FD'] = self.correlated_data_fraction
        if len(self.visibilities):
            corr_out_data['VIS'] = self.visibilities
        return self._item_group.get_heap(descriptors="none", data="all")

    def get_start_heap(self):
        """Return the start-of-stream heap, carrying all item descriptors."""
        start_heap = self._item_group.get_start()
        self._item_group.add_to_heap(start_heap, descriptors="all", data="none")
        return start_heap

    def get_end_heap(self):
        """Return the end-of-stream heap."""
        return self._item_group.get_end()
class transmitter(AbstractAsyncContextManager):
    """
    SPEAD2 transmitter

    This class uses the spead2 library to transmit visibilities over multiple
    spead2 streams. Each visibility set given to this class' `send` method is
    broken down by channel range (depending on the configuration parameters),
    and each channel range is sent through a different stream.

    `prepare` must be awaited before `send`. When used as an async context
    manager, `close` (which sends the end-of-stream heaps) runs on exit.
    """

    def __init__(self, config):
        self.config = config
        max_packet_size = int(config.get('max_packet_size', 1472))
        logger.info(
            'Creating StreamConfig with max_packet_size=%d', max_packet_size)
        self.stream_config = spead2.send.StreamConfig(
            max_packet_size=max_packet_size,
            rate=int(config.getfloat('rate', 1024 * 1024 * 1024)),
            burst_size=10,
            max_heaps=int(config.get('max_heaps', 1)),
        )
        self.channels_per_stream = int(config.get('channels_per_stream', 0))
        self.sender_threads = int(config.get('sender_threads', 1))
        self.num_streams = 0  # set by prepare()
        self.bytes_sent = 0
        self.heaps_sent = 0
        self.streams = []
        # One payload per stream, created by prepare(). Initialised here so
        # close() (and therefore __aexit__) works even if prepare() never ran.
        self.payloads = []
        # Number of channels carried by each stream, set by prepare()
        self._stream_channel_counts = []
        self._start_heap_sent = False

    async def prepare(self, num_baselines, num_channels):
        """
        Create the sending SPEAD streams

        :param num_baselines: number of baselines in each visibility set
        :param num_channels: total number of channels in each visibility set.
            They are split across streams in chunks of ``channels_per_stream``;
            the last stream takes the remainder when the split is uneven.
        :raises ValueError: if the transport protocol is not udp/tcp, the
            endpoints specification is invalid, or fewer endpoints are
            configured than streams are needed
        """
        start_time = time.perf_counter()
        if self.channels_per_stream == 0:
            self.num_streams = 1
            self.channels_per_stream = num_channels
        else:
            self.num_streams = math.ceil(num_channels / self.channels_per_stream)

        # Every stream carries channels_per_stream channels, except possibly
        # the last one, which only gets what remains of num_channels. Sizing
        # its payload accordingly avoids a shape mismatch in send().
        channel_counts = [
            min(self.channels_per_stream, num_channels - i * self.channels_per_stream)
            for i in range(self.num_streams)
        ]

        # Each stream uses a separate ItemGroup because Heaps created out of
        # ItemGroups can point to memory held by the ItemGroup; and since we
        # want different heaps sent through each of the streams we then need
        # independent ItemGroups
        payloads = [Spead2SenderPayload(num_baselines, count) for count in channel_counts]

        # Create the streams; they still share a single I/O threadpool
        thread_pool = spead2.ThreadPool(threads=self.sender_threads)
        config = self.config
        buffer_size = int(config.get('buffer_size', spead2.send.asyncio.UdpStream.DEFAULT_BUFFER_SIZE))
        transport_proto = config.get('transport_protocol', 'udp').lower()
        if transport_proto not in ('udp', 'tcp'):
            raise ValueError("transport_protocol should be udp or tcp")

        streams = []
        for i, (host, port) in enumerate(self._resolve_endpoints()):
            logger.debug("Sending stream %d to %s:%d", i, host, port)
            stream = await create_stream(
                thread_pool, host, port, self.stream_config, buffer_size, transport_proto
            )
            streams.append(stream)

        # Publish the new state only once every stream exists, so a failed
        # prepare() never leaves payloads and streams out of step for close()
        self.payloads = payloads
        self.streams = streams
        self._stream_channel_counts = channel_counts
        logger.info(
            'Created %d %s spead2 streams to send data for %d channels in %.3f [ms]',
            self.num_streams, transport_proto.upper(), num_channels,
            (time.perf_counter() - start_time) * 1000)

    def _resolve_endpoints(self):
        """Return one ``(host, port)`` pair for each of the ``num_streams`` streams."""
        config = self.config
        if 'endpoints' in config:
            endpoints = parse_endpoints(config['endpoints'])
            if len(endpoints) < self.num_streams:
                raise ValueError('missing endpoints for number of streams')
            return endpoints[:self.num_streams]
        target_host = config.get('target_host', '127.0.0.1')
        target_port = int(config.get('target_port_start', 41000))
        return [(target_host, target_port + i) for i in range(self.num_streams)]

    async def _send_heaps(self, heaps):
        """Send ``heaps[i]`` through ``streams[i]`` concurrently and update counters."""
        assert len(heaps) == len(self.streams)
        results = await asyncio.gather(
            *(stream.async_send_heap(heap) for heap, stream in zip(heaps, self.streams))
        )
        self.bytes_sent += sum(results)
        self.heaps_sent += len(heaps)

    async def send(self, scan_id: int, ts: int, ts_fraction: int, vis: np.ndarray):
        """
        Send a visibility set through all SPEAD2 streams

        The first call also sends each stream's start-of-stream heap.

        :param scan_id: the scan id
        :param ts: the integer part of the visibilities' timestamp
        :param ts_fraction: the fractional part of the visibilities' timestamp
        :param vis: the visibilities; sliced along the first axis into one
            channel range per stream
        """
        if not self._start_heap_sent:
            await self._send_heaps([payload.get_start_heap() for payload in self.payloads])
            self._start_heap_sent = True

        logger.debug('Sending heaps to %d spead2 streams', len(self.streams))
        heaps = []
        assert len(self.payloads) == len(self.streams)
        stream_chunks = zip(self.payloads, self._stream_channel_counts)
        for i, (payload, channel_count) in enumerate(stream_chunks):
            # When sending to multiple streams (e.g., 96) we could spend a long
            # time in this loop without yielding control back, so let's do that
            if i % 20 == 0:
                await asyncio.sleep(0)
            first_chan = self.channels_per_stream * i
            last_chan = first_chan + channel_count
            payload.scan_id = int(scan_id)
            payload.timestamp_count = ts
            payload.timestamp_fraction = ts_fraction
            payload.visibilities = vis[first_chan:last_chan]
            payload.channel_id = first_chan
            payload.channel_count = channel_count
            heaps.append(payload.get_heap())
        await self._send_heaps(heaps)

    async def close(self):
        """Sends the end-of-stream message"""
        await self._send_heaps([payload.get_end_heap() for payload in self.payloads])

    @overrides
    async def __aenter__(self):
        return self

    @overrides
    async def __aexit__(self, ext_type, exc, tb):
        await self.close()