Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(air): introduce explicit types for generation numbers [fixes VM-261] #530

Merged
merged 17 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions air/src/execution_step/boxed_value/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ use crate::ExecutionError;
use crate::JValue;
use crate::UncatchableError;

use air_interpreter_data::GenerationIdx;
use air_trace_handler::merger::ValueSource;
use air_trace_handler::TraceHandler;

use std::convert::TryFrom;

/// Streams are CRDT-like append only data structures. They are guaranteed to have the same order
/// of values on each peer.
#[derive(Debug, Default, Clone)]
Expand All @@ -38,17 +41,16 @@ pub struct Stream {
}

impl Stream {
pub(crate) fn from_generations_count(previous_count: usize, current_count: usize) -> Self {
let last_generation_count = 1;
pub(crate) fn from_generations_count(previous_count: GenerationIdx, current_count: GenerationIdx) -> Self {
let last_generation_count = GenerationIdx::from(1);
// TODO: bubble up an overflow error instead of expect
let overall_count = previous_count
.checked_add(current_count)
.and_then(|value| value.checked_add(last_generation_count))
.expect("it shouldn't overflow");

Self {
values: vec![vec![]; overall_count],
previous_gens_count: previous_count,
values: vec![vec![]; overall_count.into()],
previous_gens_count: previous_count.into(),
}
}

Expand All @@ -68,11 +70,13 @@ impl Stream {
value: ValueAggregate,
generation: Generation,
source: ValueSource,
) -> ExecutionResult<u32> {
) -> ExecutionResult<GenerationIdx> {
let generation_number = match (generation, source) {
(Generation::Last, _) => self.values.len() - 1,
(Generation::Nth(previous_gen), ValueSource::PreviousData) => previous_gen as usize,
(Generation::Nth(current_gen), ValueSource::CurrentData) => self.previous_gens_count + current_gen as usize,
(Generation::Nth(previous_gen), ValueSource::PreviousData) => previous_gen.into(),
(Generation::Nth(current_gen), ValueSource::CurrentData) => {
self.previous_gens_count + usize::from(current_gen)
}
};

if generation_number >= self.values.len() {
Expand All @@ -84,9 +88,10 @@ impl Stream {
}

self.values[generation_number].push(value);
Ok(generation_number as u32)
Ok(generation_number.into())
}

// TODO: remove this function
pub(crate) fn generations_count(&self) -> usize {
// the last generation could be empty due to the logic of from_generations_count ctor
if self.values.last().unwrap().is_empty() {
Expand All @@ -96,14 +101,15 @@ impl Stream {
}
}

pub(crate) fn last_non_empty_generation(&self) -> usize {
pub(crate) fn last_non_empty_generation(&self) -> GenerationIdx {
self.values
.iter()
.rposition(|generation| !generation.is_empty())
// it's safe to add + 1 here, because this function is called when
// there is a new state was added with add_new_generation_if_non_empty
.map(|non_empty_gens| non_empty_gens + 1)
.unwrap_or_else(|| self.generations_count())
.into()
}

/// Add a new empty generation if the latest isn't empty.
Expand Down Expand Up @@ -133,10 +139,13 @@ impl Stream {
should_remove_generation
}

pub(crate) fn elements_count(&self, generation: Generation) -> Option<usize> {
pub(crate) fn generation_elements_count(&self, generation: Generation) -> Option<usize> {
match generation {
Generation::Nth(generation) if generation as usize > self.generations_count() => None,
Generation::Nth(generation) => Some(self.values.iter().take(generation as usize).map(|v| v.len()).sum()),
Generation::Nth(generation) if generation > self.generations_count() => None,
Generation::Nth(generation) => {
let elements_count = generation.into();
Some(self.values.iter().take(elements_count).map(|v| v.len()).sum())
}
Generation::Last => Some(self.values.iter().map(|v| v.len()).sum()),
}
}
Expand All @@ -160,14 +169,14 @@ impl Stream {

pub(crate) fn iter(&self, generation: Generation) -> Option<StreamIter<'_>> {
let iter: Box<dyn Iterator<Item = &ValueAggregate>> = match generation {
Generation::Nth(generation) if generation as usize >= self.generations_count() => return None,
Generation::Nth(generation) if generation >= self.generations_count() => return None,
Generation::Nth(generation) => {
Box::new(self.values.iter().take(generation as usize + 1).flat_map(|v| v.iter()))
Box::new(self.values.iter().take(generation.next().into()).flat_map(|v| v.iter()))
}
Generation::Last => Box::new(self.values.iter().flat_map(|v| v.iter())),
};
// unwrap is safe here, because generation's been already checked
let len = self.elements_count(generation).unwrap();
let len = self.generation_elements_count(generation).unwrap();

let iter = StreamIter { iter, len };

Expand All @@ -179,39 +188,39 @@ impl Stream {
return None;
}

let generations_count = self.generations_count() as u32 - 1;
let generations_count = self.generations_count() - 1;
let (start, end) = match (start, end) {
(Generation::Nth(start), Generation::Nth(end)) => (start, end),
(Generation::Nth(start), Generation::Last) => (start, generations_count),
(Generation::Last, Generation::Nth(end)) => (generations_count, end),
(Generation::Nth(start), Generation::Nth(end)) => (usize::from(start), usize::from(end)),
(Generation::Nth(start), Generation::Last) => (start.into(), generations_count),
(Generation::Last, Generation::Nth(end)) => (generations_count, end.into()),
(Generation::Last, Generation::Last) => (generations_count, generations_count),
};

if start > end || end > generations_count {
return None;
}

let len = (end - start) as usize + 1;
let len = end - start + 1;
let iter: Box<dyn Iterator<Item = &[ValueAggregate]>> =
Box::new(self.values.iter().skip(start as usize).take(len).map(|v| v.as_slice()));
Box::new(self.values.iter().skip(start).take(len).map(|v| v.as_slice()));
let iter = StreamSliceIter { iter, len };

Some(iter)
}

/// Removes empty generations updating data and returns final generation count.
pub(crate) fn compactify(mut self, trace_ctx: &mut TraceHandler) -> ExecutionResult<usize> {
pub(crate) fn compactify(mut self, trace_ctx: &mut TraceHandler) -> ExecutionResult<GenerationIdx> {
self.remove_empty_generations();

for (generation, values) in self.values.iter().enumerate() {
for value in values.iter() {
trace_ctx
.update_generation(value.trace_pos, generation as u32)
.update_generation(value.trace_pos, generation.into())
.map_err(|e| ExecutionError::Uncatchable(UncatchableError::GenerationCompatificationError(e)))?;
}
}

Ok(self.values.len())
let last_generation_idx = self.values.len();
Ok(last_generation_idx.into())
}

/// Removes empty generations from current values.
Expand All @@ -223,7 +232,19 @@ impl Stream {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Generation {
Last,
Nth(u32),
Nth(GenerationIdx),
}

impl Generation {
pub fn last() -> Self {
Self::Last
}

pub fn nth(generation_id: u32) -> Self {
mikevoronov marked this conversation as resolved.
Show resolved Hide resolved
let generation_id = usize::try_from(generation_id).unwrap();
let generation_idx = GenerationIdx::from(generation_id);
Self::Nth(generation_idx)
}
}

pub(crate) struct StreamIter<'result> {
Expand Down Expand Up @@ -313,22 +334,22 @@ mod test {
fn test_slice_iter() {
let value_1 = ValueAggregate::new(Rc::new(json!("value")), <_>::default(), 1.into());
let value_2 = ValueAggregate::new(Rc::new(json!("value")), <_>::default(), 1.into());
let mut stream = Stream::from_generations_count(2, 0);
let mut stream = Stream::from_generations_count(2.into(), 0.into());

stream
.add_value(value_1, Generation::Nth(0), ValueSource::PreviousData)
.add_value(value_1, Generation::nth(0), ValueSource::PreviousData)
.unwrap();
stream
.add_value(value_2, Generation::Nth(1), ValueSource::PreviousData)
.add_value(value_2, Generation::nth(1), ValueSource::PreviousData)
.unwrap();

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(1)).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(1)).unwrap();
assert_eq!(slice.len, 2);

let slice = stream.slice_iter(Generation::Nth(0), Generation::Last).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::Last).unwrap();
assert_eq!(slice.len, 2);

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(0)).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(0)).unwrap();
assert_eq!(slice.len, 1);

let slice = stream.slice_iter(Generation::Last, Generation::Last).unwrap();
Expand All @@ -337,15 +358,15 @@ mod test {

#[test]
fn test_slice_on_empty_stream() {
let stream = Stream::from_generations_count(2, 0);
let stream = Stream::from_generations_count(2.into(), 0.into());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(1));
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(1));
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Last);
let slice = stream.slice_iter(Generation::nth(0), Generation::Last);
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(0));
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(0));
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Last, Generation::Last);
Expand All @@ -356,13 +377,13 @@ mod test {
fn generation_from_current_data() {
let value_1 = ValueAggregate::new(Rc::new(json!("value_1")), <_>::default(), 1.into());
let value_2 = ValueAggregate::new(Rc::new(json!("value_2")), <_>::default(), 2.into());
let mut stream = Stream::from_generations_count(5, 5);
let mut stream = Stream::from_generations_count(5.into(), 5.into());

stream
.add_value(value_1.clone(), Generation::Nth(2), ValueSource::CurrentData)
.add_value(value_1.clone(), Generation::nth(2), ValueSource::CurrentData)
.unwrap();
stream
.add_value(value_2.clone(), Generation::Nth(4), ValueSource::PreviousData)
.add_value(value_2.clone(), Generation::nth(4), ValueSource::PreviousData)
.unwrap();

let generations_count = stream.generations_count();
Expand Down
31 changes: 20 additions & 11 deletions air/src/execution_step/execution_context/streams_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ mod utils;
use crate::execution_step::ExecutionResult;
use crate::execution_step::Stream;
use crate::ExecutionError;

use stream_descriptor::*;
pub(crate) use stream_value_descriptor::StreamValueDescriptor;

use air_interpreter_data::GenerationIdx;
use air_interpreter_data::GlobalStreamGens;
use air_interpreter_data::RestrictedStreamGens;
use air_parser::ast::Span;
Expand Down Expand Up @@ -82,7 +84,10 @@ impl Streams {
.and_then(|descriptors| find_closest_mut(descriptors.iter_mut(), position))
}

pub(crate) fn add_stream_value(&mut self, value_descriptor: StreamValueDescriptor<'_>) -> ExecutionResult<u32> {
pub(crate) fn add_stream_value(
&mut self,
value_descriptor: StreamValueDescriptor<'_>,
) -> ExecutionResult<GenerationIdx> {
let StreamValueDescriptor {
value,
name,
Expand All @@ -105,17 +110,16 @@ impl Streams {
let descriptor = StreamDescriptor::global(stream);
self.streams.insert(name.to_string(), vec![descriptor]);
let generation = 0;
Ok(generation)
Ok(generation.into())
}
}
}

pub(crate) fn meet_scope_start(&mut self, name: impl Into<String>, span: Span, iteration: u32) {
pub(crate) fn meet_scope_start(&mut self, name: impl Into<String>, span: Span, iteration: usize) {
let name = name.into();
let (prev_gens_count, current_gens_count) =
self.stream_generation_from_data(&name, span.left, iteration as usize);
let (prev_gens_count, current_gens_count) = self.stream_generation_from_data(&name, span.left, iteration);

let new_stream = Stream::from_generations_count(prev_gens_count as usize, current_gens_count as usize);
let new_stream = Stream::from_generations_count(prev_gens_count, current_gens_count);
let new_descriptor = StreamDescriptor::restricted(new_stream, span);
match self.streams.entry(name) {
Occupied(mut entry) => {
Expand Down Expand Up @@ -143,7 +147,7 @@ impl Streams {
}
let gens_count = last_descriptor.stream.compactify(trace_ctx)?;

self.collect_stream_generation(name, position, gens_count as u32);
self.collect_stream_generation(name, position, gens_count);
Ok(())
}

Expand All @@ -164,14 +168,19 @@ impl Streams {
// of the execution
let stream = descriptors.pop().unwrap().stream;
let gens_count = stream.compactify(trace_ctx)?;
Ok((name, gens_count as u32))
Ok((name, gens_count))
})
.collect::<Result<GlobalStreamGens, _>>()?;

Ok((global_streams, self.new_restricted_stream_gens))
}

fn stream_generation_from_data(&self, name: &str, position: AirPos, iteration: usize) -> (u32, u32) {
fn stream_generation_from_data(
&self,
name: &str,
position: AirPos,
iteration: usize,
) -> (GenerationIdx, GenerationIdx) {
let previous_generation =
Self::restricted_stream_generation(&self.previous_restricted_stream_gens, name, position, iteration)
.unwrap_or_default();
Expand All @@ -187,14 +196,14 @@ impl Streams {
name: &str,
position: AirPos,
iteration: usize,
) -> Option<u32> {
) -> Option<GenerationIdx> {
restricted_stream_gens
.get(name)
.and_then(|scopes| scopes.get(&position).and_then(|iterations| iterations.get(iteration)))
.copied()
}

fn collect_stream_generation(&mut self, name: String, position: AirPos, generation: u32) {
fn collect_stream_generation(&mut self, name: String, position: AirPos, generation: GenerationIdx) {
match self.new_restricted_stream_gens.entry(name) {
Occupied(mut streams) => match streams.get_mut().entry(position) {
Occupied(mut iterations) => iterations.get_mut().push(generation),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(super) fn merge_global_streams(
.iter()
.map(|(stream_name, &prev_gens_count)| {
let current_gens_count = current_global_streams.get(stream_name).cloned().unwrap_or_default();
let global_stream = Stream::from_generations_count(prev_gens_count as usize, current_gens_count as usize);
let global_stream = Stream::from_generations_count(prev_gens_count, current_gens_count);
let descriptor = StreamDescriptor::global(global_stream);
(stream_name.to_string(), vec![descriptor])
})
Expand All @@ -40,7 +40,7 @@ pub(super) fn merge_global_streams(
continue;
}

let global_stream = Stream::from_generations_count(0, current_gens_count as usize);
let global_stream = Stream::from_generations_count(0.into(), current_gens_count);
let descriptor = StreamDescriptor::global(global_stream);
global_streams.insert(stream_name, vec![descriptor]);
}
Expand Down
5 changes: 3 additions & 2 deletions air/src/execution_step/instructions/ap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::SecurityTetraplet;
use apply_to_arguments::*;
use utils::*;

use air_interpreter_data::GenerationIdx;
use air_parser::ast;
use air_parser::ast::Ap;
use air_trace_handler::merger::MergerApResult;
Expand Down Expand Up @@ -75,7 +76,7 @@ fn populate_context<'ctx>(
merger_ap_result: &MergerApResult,
result: ValueAggregate,
exec_ctx: &mut ExecutionCtx<'ctx>,
) -> ExecutionResult<Option<u32>> {
) -> ExecutionResult<Option<GenerationIdx>> {
match ap_result {
ast::ApResult::Scalar(scalar) => exec_ctx.scalars.set_scalar_value(scalar.name, result).map(|_| None),
ast::ApResult::Stream(stream) => {
Expand All @@ -85,7 +86,7 @@ fn populate_context<'ctx>(
}
}

fn maybe_update_trace(maybe_generation: Option<u32>, trace_ctx: &mut TraceHandler) {
fn maybe_update_trace(maybe_generation: Option<GenerationIdx>, trace_ctx: &mut TraceHandler) {
use air_interpreter_data::ApResult;

if let Some(generation) = maybe_generation {
Expand Down
Loading