Source code for tests.test_uda

import unittest

from feldera import PipelineBuilder
from tests import TEST_CLIENT


[docs] class TestUDA(unittest.TestCase):
[docs] def test_local(self): sql = """ CREATE LINEAR AGGREGATE I128_SUM(s BINARY(16)) RETURNS BINARY(16); CREATE TABLE T(x BINARY(16)); CREATE MATERIALIZED VIEW V AS SELECT I128_SUM(x) AS S, COUNT(*) AS C FROM T; """ toml = """ i256 = { version = "0.2.2", features = ["num-traits"] } num-traits = "0.2.19" """ udfs = """ use i256::I256; use feldera_sqllib::*; use crate::{AddAssignByRef, AddByRef, HasZero, MulByRef, SizeOf, Tup3}; use derive_more::Add; use num_traits::Zero; use rkyv::Fallible; use std::ops::{Add, AddAssign}; #[derive(Add, Clone, Debug, Default, PartialOrd, Ord, Eq, PartialEq, Hash)] pub struct I256Wrapper { pub data: I256, } impl SizeOf for I256Wrapper { fn size_of_children(&self, context: &mut size_of::Context) {} } impl From<[u8; 32]> for I256Wrapper { fn from(value: [u8; 32]) -> Self { Self { data: I256::from_be_bytes(value) } } } impl From<&[u8]> for I256Wrapper { fn from(value: &[u8]) -> Self { let mut padded = [0u8; 32]; // If original value is negative, pad with sign if value[0] & 0x80 != 0 { padded.fill(0xff); } let len = value.len(); if len > 32 { panic!("Slice larger than target"); } padded[32-len..].copy_from_slice(&value[..len]); Self { data: I256::from_be_bytes(padded) } } } impl MulByRef<Weight> for I256Wrapper { type Output = Self; fn mul_by_ref(&self, other: &Weight) -> Self::Output { println!("Mul {:?} by {}", self, other); Self { data: self.data.checked_mul_i64(*other) .expect("Overflow during multiplication"), } } } impl HasZero for I256Wrapper { fn zero() -> Self { Self { data: I256::zero() } } fn is_zero(&self) -> bool { self.data.is_zero() } } impl AddByRef for I256Wrapper { fn add_by_ref(&self, other: &Self) -> Self { Self { data: self.data.add(other.data) } } } impl AddAssignByRef<Self> for I256Wrapper { fn add_assign_by_ref(&mut self, other: &Self) { self.data += other.data } } #[repr(C)] #[derive(Debug, Copy, Clone, PartialOrd, Ord, Eq, PartialEq)] pub struct ArchivedI256Wrapper { pub bytes: [u8; 32], } impl rkyv::Archive for 
I256Wrapper { type Archived = ArchivedI256Wrapper; type Resolver = (); #[inline] unsafe fn resolve(&self, pos: usize, _: Self::Resolver, out: *mut Self::Archived) { out.write(ArchivedI256Wrapper { bytes: self.data.to_be_bytes(), }); } } impl<S: Fallible + ?Sized> rkyv::Serialize<S> for I256Wrapper { #[inline] fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> { Ok(()) } } impl<D: Fallible + ?Sized> rkyv::Deserialize<I256Wrapper, D> for ArchivedI256Wrapper { #[inline] fn deserialize(&self, _: &mut D) -> Result<I256Wrapper, D::Error> { Ok(I256Wrapper::from(self.bytes)) } } pub type i128_sum_accumulator_type = Tup3<I256Wrapper, i64, i64>; pub fn i128_sum_map(val: Option<ByteArray>) -> i128_sum_accumulator_type { match val { None => Tup3::new(I256Wrapper::zero(), 0, 1), Some(val) => Tup3::new( I256Wrapper::from(val.as_slice()), 1, 1, ), } } pub fn i128_sum_post(val: i128_sum_accumulator_type) -> Option<ByteArray> { if val.1 == 0 { None } else { // Check for overflow if val.0.data < I256::from(i128::MIN) || val.0.data > I256::from(i128::MAX) { panic!("Result of aggregation {} does not fit in 128 bits", val.0.data); } Some(ByteArray::new(&val.0.data.to_be_bytes()[16..])) } } """ pipeline = PipelineBuilder( TEST_CLIENT, name="test_uda", sql=sql, udf_rust=udfs, udf_toml=toml ).create_or_replace() pipeline.start() pipeline.input_json( "t", [ { "insert": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], } } ], update_format="insert_delete", ) pipeline.wait_for_idle() output = list(pipeline.query("SELECT * FROM V;")) assert output == [{"s": "00000000000000000000000000000001", "c": 1}] # Insert -1 pipeline.input_json( "t", [ { "insert": { "x": [ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ], } } ], update_format="insert_delete", ) pipeline.wait_for_idle() output = list(pipeline.query("SELECT * FROM V;")) assert output == [{"s": "00000000000000000000000000000000", "c": 2}] pipeline.input_json( "t", [ { 
"insert": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], } } ], update_format="insert_delete", ) output = list(pipeline.query("SELECT * FROM V;")) assert output == [{"s": "00000000000000000000000000000002", "c": 3}] pipeline.input_json( "t", [ { "insert": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3], } } ], update_format="insert_delete", ) output = list(pipeline.query("SELECT * FROM V;")) assert output == [{"s": "00000000000000000000000000000005", "c": 4}] pipeline.input_json( "t", [ { "delete": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], } }, { "delete": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], } }, { "delete": { "x": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3], } }, { "delete": { "x": [ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 1, ], } }, ], update_format="insert_delete", ) output = list(pipeline.query("SELECT * FROM V;")) assert output == [{"s": None, "c": 0}] pipeline.stop(force=True)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()