Skip to content

Commit

Permalink
feat(air): introduce explicit types for generation numbers (#530)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: vms <[email protected]>
Co-authored-by: Anatolios Laskaris <[email protected]>
  • Loading branch information
3 people authored Apr 10, 2023
1 parent 3027f0b commit d62fa6f
Show file tree
Hide file tree
Showing 21 changed files with 225 additions and 95 deletions.
102 changes: 62 additions & 40 deletions air/src/execution_step/boxed_value/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::ExecutionError;
use crate::JValue;
use crate::UncatchableError;

use air_interpreter_data::GenerationIdx;
use air_trace_handler::merger::ValueSource;
use air_trace_handler::TraceHandler;

Expand All @@ -38,17 +39,16 @@ pub struct Stream {
}

impl Stream {
pub(crate) fn from_generations_count(previous_count: usize, current_count: usize) -> Self {
let last_generation_count = 1;
pub(crate) fn from_generations_count(previous_count: GenerationIdx, current_count: GenerationIdx) -> Self {
let last_generation_count = GenerationIdx::from(1);
// TODO: bubble up an overflow error instead of expect
let overall_count = previous_count
.checked_add(current_count)
.and_then(|value| value.checked_add(last_generation_count))
.expect("it shouldn't overflow");

Self {
values: vec![vec![]; overall_count],
previous_gens_count: previous_count,
values: vec![vec![]; overall_count.into()],
previous_gens_count: previous_count.into(),
}
}

Expand All @@ -68,11 +68,13 @@ impl Stream {
value: ValueAggregate,
generation: Generation,
source: ValueSource,
) -> ExecutionResult<u32> {
) -> ExecutionResult<GenerationIdx> {
let generation_number = match (generation, source) {
(Generation::Last, _) => self.values.len() - 1,
(Generation::Nth(previous_gen), ValueSource::PreviousData) => previous_gen as usize,
(Generation::Nth(current_gen), ValueSource::CurrentData) => self.previous_gens_count + current_gen as usize,
(Generation::Nth(previous_gen), ValueSource::PreviousData) => previous_gen.into(),
(Generation::Nth(current_gen), ValueSource::CurrentData) => {
self.previous_gens_count + usize::from(current_gen)
}
};

if generation_number >= self.values.len() {
Expand All @@ -84,9 +86,10 @@ impl Stream {
}

self.values[generation_number].push(value);
Ok(generation_number as u32)
Ok(generation_number.into())
}

// TODO: remove this function
pub(crate) fn generations_count(&self) -> usize {
// the last generation could be empty due to the logic of from_generations_count ctor
if self.values.last().unwrap().is_empty() {
Expand All @@ -96,14 +99,15 @@ impl Stream {
}
}

pub(crate) fn last_non_empty_generation(&self) -> usize {
pub(crate) fn last_non_empty_generation(&self) -> GenerationIdx {
self.values
.iter()
.rposition(|generation| !generation.is_empty())
// it's safe to add + 1 here, because this function is called when
// there is a new state was added with add_new_generation_if_non_empty
.map(|non_empty_gens| non_empty_gens + 1)
.unwrap_or_else(|| self.generations_count())
.into()
}

/// Add a new empty generation if the latest isn't empty.
Expand Down Expand Up @@ -133,10 +137,13 @@ impl Stream {
should_remove_generation
}

pub(crate) fn elements_count(&self, generation: Generation) -> Option<usize> {
pub(crate) fn generation_elements_count(&self, generation: Generation) -> Option<usize> {
match generation {
Generation::Nth(generation) if generation as usize > self.generations_count() => None,
Generation::Nth(generation) => Some(self.values.iter().take(generation as usize).map(|v| v.len()).sum()),
Generation::Nth(generation) if generation > self.generations_count() => None,
Generation::Nth(generation) => {
let elements_count = generation.into();
Some(self.values.iter().take(elements_count).map(|v| v.len()).sum())
}
Generation::Last => Some(self.values.iter().map(|v| v.len()).sum()),
}
}
Expand All @@ -160,14 +167,14 @@ impl Stream {

pub(crate) fn iter(&self, generation: Generation) -> Option<StreamIter<'_>> {
let iter: Box<dyn Iterator<Item = &ValueAggregate>> = match generation {
Generation::Nth(generation) if generation as usize >= self.generations_count() => return None,
Generation::Nth(generation) if generation >= self.generations_count() => return None,
Generation::Nth(generation) => {
Box::new(self.values.iter().take(generation as usize + 1).flat_map(|v| v.iter()))
Box::new(self.values.iter().take(generation.next().into()).flat_map(|v| v.iter()))
}
Generation::Last => Box::new(self.values.iter().flat_map(|v| v.iter())),
};
// unwrap is safe here, because generation's been already checked
let len = self.elements_count(generation).unwrap();
let len = self.generation_elements_count(generation).unwrap();

let iter = StreamIter { iter, len };

Expand All @@ -179,39 +186,39 @@ impl Stream {
return None;
}

let generations_count = self.generations_count() as u32 - 1;
let generations_count = self.generations_count() - 1;
let (start, end) = match (start, end) {
(Generation::Nth(start), Generation::Nth(end)) => (start, end),
(Generation::Nth(start), Generation::Last) => (start, generations_count),
(Generation::Last, Generation::Nth(end)) => (generations_count, end),
(Generation::Nth(start), Generation::Nth(end)) => (usize::from(start), usize::from(end)),
(Generation::Nth(start), Generation::Last) => (start.into(), generations_count),
(Generation::Last, Generation::Nth(end)) => (generations_count, end.into()),
(Generation::Last, Generation::Last) => (generations_count, generations_count),
};

if start > end || end > generations_count {
return None;
}

let len = (end - start) as usize + 1;
let len = end - start + 1;
let iter: Box<dyn Iterator<Item = &[ValueAggregate]>> =
Box::new(self.values.iter().skip(start as usize).take(len).map(|v| v.as_slice()));
Box::new(self.values.iter().skip(start).take(len).map(|v| v.as_slice()));
let iter = StreamSliceIter { iter, len };

Some(iter)
}

/// Removes empty generations updating data and returns final generation count.
pub(crate) fn compactify(mut self, trace_ctx: &mut TraceHandler) -> ExecutionResult<usize> {
pub(crate) fn compactify(mut self, trace_ctx: &mut TraceHandler) -> ExecutionResult<GenerationIdx> {
self.remove_empty_generations();

for (generation, values) in self.values.iter().enumerate() {
for value in values.iter() {
trace_ctx
.update_generation(value.trace_pos, generation as u32)
.update_generation(value.trace_pos, generation.into())
.map_err(|e| ExecutionError::Uncatchable(UncatchableError::GenerationCompatificationError(e)))?;
}
}

Ok(self.values.len())
let last_generation_idx = self.values.len();
Ok(last_generation_idx.into())
}

/// Removes empty generations from current values.
Expand All @@ -223,7 +230,22 @@ impl Stream {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Generation {
Last,
Nth(u32),
Nth(GenerationIdx),
}

impl Generation {
pub fn last() -> Self {
Self::Last
}

#[cfg(test)]
pub fn nth(generation_id: u32) -> Self {
use std::convert::TryFrom;

let generation_id = usize::try_from(generation_id).unwrap();
let generation_idx = GenerationIdx::from(generation_id);
Self::Nth(generation_idx)
}
}

pub(crate) struct StreamIter<'result> {
Expand Down Expand Up @@ -313,22 +335,22 @@ mod test {
fn test_slice_iter() {
let value_1 = ValueAggregate::new(Rc::new(json!("value")), <_>::default(), 1.into());
let value_2 = ValueAggregate::new(Rc::new(json!("value")), <_>::default(), 1.into());
let mut stream = Stream::from_generations_count(2, 0);
let mut stream = Stream::from_generations_count(2.into(), 0.into());

stream
.add_value(value_1, Generation::Nth(0), ValueSource::PreviousData)
.add_value(value_1, Generation::nth(0), ValueSource::PreviousData)
.unwrap();
stream
.add_value(value_2, Generation::Nth(1), ValueSource::PreviousData)
.add_value(value_2, Generation::nth(1), ValueSource::PreviousData)
.unwrap();

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(1)).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(1)).unwrap();
assert_eq!(slice.len, 2);

let slice = stream.slice_iter(Generation::Nth(0), Generation::Last).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::Last).unwrap();
assert_eq!(slice.len, 2);

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(0)).unwrap();
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(0)).unwrap();
assert_eq!(slice.len, 1);

let slice = stream.slice_iter(Generation::Last, Generation::Last).unwrap();
Expand All @@ -337,15 +359,15 @@ mod test {

#[test]
fn test_slice_on_empty_stream() {
let stream = Stream::from_generations_count(2, 0);
let stream = Stream::from_generations_count(2.into(), 0.into());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(1));
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(1));
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Last);
let slice = stream.slice_iter(Generation::nth(0), Generation::Last);
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Nth(0), Generation::Nth(0));
let slice = stream.slice_iter(Generation::nth(0), Generation::nth(0));
assert!(slice.is_none());

let slice = stream.slice_iter(Generation::Last, Generation::Last);
Expand All @@ -356,13 +378,13 @@ mod test {
fn generation_from_current_data() {
let value_1 = ValueAggregate::new(Rc::new(json!("value_1")), <_>::default(), 1.into());
let value_2 = ValueAggregate::new(Rc::new(json!("value_2")), <_>::default(), 2.into());
let mut stream = Stream::from_generations_count(5, 5);
let mut stream = Stream::from_generations_count(5.into(), 5.into());

stream
.add_value(value_1.clone(), Generation::Nth(2), ValueSource::CurrentData)
.add_value(value_1.clone(), Generation::nth(2), ValueSource::CurrentData)
.unwrap();
stream
.add_value(value_2.clone(), Generation::Nth(4), ValueSource::PreviousData)
.add_value(value_2.clone(), Generation::nth(4), ValueSource::PreviousData)
.unwrap();

let generations_count = stream.generations_count();
Expand Down
31 changes: 20 additions & 11 deletions air/src/execution_step/execution_context/streams_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ mod utils;
use crate::execution_step::ExecutionResult;
use crate::execution_step::Stream;
use crate::ExecutionError;

use stream_descriptor::*;
pub(crate) use stream_value_descriptor::StreamValueDescriptor;

use air_interpreter_data::GenerationIdx;
use air_interpreter_data::GlobalStreamGens;
use air_interpreter_data::RestrictedStreamGens;
use air_parser::ast::Span;
Expand Down Expand Up @@ -82,7 +84,10 @@ impl Streams {
.and_then(|descriptors| find_closest_mut(descriptors.iter_mut(), position))
}

pub(crate) fn add_stream_value(&mut self, value_descriptor: StreamValueDescriptor<'_>) -> ExecutionResult<u32> {
pub(crate) fn add_stream_value(
&mut self,
value_descriptor: StreamValueDescriptor<'_>,
) -> ExecutionResult<GenerationIdx> {
let StreamValueDescriptor {
value,
name,
Expand All @@ -105,17 +110,16 @@ impl Streams {
let descriptor = StreamDescriptor::global(stream);
self.streams.insert(name.to_string(), vec![descriptor]);
let generation = 0;
Ok(generation)
Ok(generation.into())
}
}
}

pub(crate) fn meet_scope_start(&mut self, name: impl Into<String>, span: Span, iteration: u32) {
pub(crate) fn meet_scope_start(&mut self, name: impl Into<String>, span: Span, iteration: usize) {
let name = name.into();
let (prev_gens_count, current_gens_count) =
self.stream_generation_from_data(&name, span.left, iteration as usize);
let (prev_gens_count, current_gens_count) = self.stream_generation_from_data(&name, span.left, iteration);

let new_stream = Stream::from_generations_count(prev_gens_count as usize, current_gens_count as usize);
let new_stream = Stream::from_generations_count(prev_gens_count, current_gens_count);
let new_descriptor = StreamDescriptor::restricted(new_stream, span);
match self.streams.entry(name) {
Occupied(mut entry) => {
Expand Down Expand Up @@ -143,7 +147,7 @@ impl Streams {
}
let gens_count = last_descriptor.stream.compactify(trace_ctx)?;

self.collect_stream_generation(name, position, gens_count as u32);
self.collect_stream_generation(name, position, gens_count);
Ok(())
}

Expand All @@ -164,14 +168,19 @@ impl Streams {
// of the execution
let stream = descriptors.pop().unwrap().stream;
let gens_count = stream.compactify(trace_ctx)?;
Ok((name, gens_count as u32))
Ok((name, gens_count))
})
.collect::<Result<GlobalStreamGens, _>>()?;

Ok((global_streams, self.new_restricted_stream_gens))
}

fn stream_generation_from_data(&self, name: &str, position: AirPos, iteration: usize) -> (u32, u32) {
fn stream_generation_from_data(
&self,
name: &str,
position: AirPos,
iteration: usize,
) -> (GenerationIdx, GenerationIdx) {
let previous_generation =
Self::restricted_stream_generation(&self.previous_restricted_stream_gens, name, position, iteration)
.unwrap_or_default();
Expand All @@ -187,14 +196,14 @@ impl Streams {
name: &str,
position: AirPos,
iteration: usize,
) -> Option<u32> {
) -> Option<GenerationIdx> {
restricted_stream_gens
.get(name)
.and_then(|scopes| scopes.get(&position).and_then(|iterations| iterations.get(iteration)))
.copied()
}

fn collect_stream_generation(&mut self, name: String, position: AirPos, generation: u32) {
fn collect_stream_generation(&mut self, name: String, position: AirPos, generation: GenerationIdx) {
match self.new_restricted_stream_gens.entry(name) {
Occupied(mut streams) => match streams.get_mut().entry(position) {
Occupied(mut iterations) => iterations.get_mut().push(generation),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(super) fn merge_global_streams(
.iter()
.map(|(stream_name, &prev_gens_count)| {
let current_gens_count = current_global_streams.get(stream_name).cloned().unwrap_or_default();
let global_stream = Stream::from_generations_count(prev_gens_count as usize, current_gens_count as usize);
let global_stream = Stream::from_generations_count(prev_gens_count, current_gens_count);
let descriptor = StreamDescriptor::global(global_stream);
(stream_name.to_string(), vec![descriptor])
})
Expand All @@ -40,7 +40,7 @@ pub(super) fn merge_global_streams(
continue;
}

let global_stream = Stream::from_generations_count(0, current_gens_count as usize);
let global_stream = Stream::from_generations_count(0.into(), current_gens_count);
let descriptor = StreamDescriptor::global(global_stream);
global_streams.insert(stream_name, vec![descriptor]);
}
Expand Down
5 changes: 3 additions & 2 deletions air/src/execution_step/instructions/ap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::SecurityTetraplet;
use apply_to_arguments::*;
use utils::*;

use air_interpreter_data::GenerationIdx;
use air_parser::ast;
use air_parser::ast::Ap;
use air_trace_handler::merger::MergerApResult;
Expand Down Expand Up @@ -75,7 +76,7 @@ fn populate_context<'ctx>(
merger_ap_result: &MergerApResult,
result: ValueAggregate,
exec_ctx: &mut ExecutionCtx<'ctx>,
) -> ExecutionResult<Option<u32>> {
) -> ExecutionResult<Option<GenerationIdx>> {
match ap_result {
ast::ApResult::Scalar(scalar) => exec_ctx.scalars.set_scalar_value(scalar.name, result).map(|_| None),
ast::ApResult::Stream(stream) => {
Expand All @@ -85,7 +86,7 @@ fn populate_context<'ctx>(
}
}

fn maybe_update_trace(maybe_generation: Option<u32>, trace_ctx: &mut TraceHandler) {
fn maybe_update_trace(maybe_generation: Option<GenerationIdx>, trace_ctx: &mut TraceHandler) {
use air_interpreter_data::ApResult;

if let Some(generation) = maybe_generation {
Expand Down
Loading

0 comments on commit d62fa6f

Please sign in to comment.