Skip to content

Commit

Permalink
feat!: use the array api in the scanner
Browse files — browse the repository at this point in the history
  • Loading branch information
amaanq committed Mar 13, 2024
1 parent de04ef1 commit 673bc22
Showing 1 changed file with 83 additions and 146 deletions.
229 changes: 83 additions & 146 deletions src/scanner.c
Original file line number Diff line number Diff line change
@@ -1,74 +1,11 @@
#include "tree_sitter/array.h"
#include "tree_sitter/parser.h"

#include <assert.h>
#include <ctype.h>
#include <string.h>
#include <wctype.h>

/*
 * NOTE(review): MAX and the VEC_* macros below appear to be the removed
 * (pre-commit) side of this diff: the hunk header shrinks these 74 lines to 11,
 * and the later code calls array_push / array_pop / array_back / array_delete.
 *
 * MAX yields the larger of a and b; each argument can be evaluated twice.
 */
#define MAX(a, b) ((a) > (b) ? (a) : (b))

/*
 * Reallocate vec.data to hold _cap elements and record the new capacity.
 * Allocation failure is only caught by assert(). The expansion is a run of bare
 * statements that declares `tmp` in the caller's scope, so it is only usable
 * inside its own braces (as VEC_PUSH does).
 */
#define VEC_RESIZE(vec, _cap) \
void *tmp = realloc((vec).data, (_cap) * sizeof((vec).data[0])); \
assert(tmp != NULL); \
(vec).data = tmp; \
assert((vec).data != NULL); \
(vec).cap = (_cap);

/*
 * Append el, first growing to MAX(16, len * 2) elements when cap == len.
 * Not wrapped in do { } while (0): the final store sits outside the `if`, so
 * the macro must be used as a standalone statement.
 */
#define VEC_PUSH(vec, el) \
if ((vec).cap == (vec).len) { \
VEC_RESIZE((vec), MAX(16, (vec).len * 2)); \
} \
(vec).data[(vec).len++] = (el);

/* Drop the last element by decrementing len; len is not checked for zero. */
#define VEC_POP(vec) \
{ (vec).len--; }

/* Lvalue for the last element; indexes data[len - 1], so len must be non-zero. */
#define VEC_BACK(vec) ((vec).data[(vec).len - 1])

/* Free the backing buffer and null the pointer; len and cap are left as-is. */
#define VEC_FREE(vec) \
{ \
if ((vec).data != NULL) \
free((vec).data); \
(vec).data = NULL; \
}

/*
 * Call STRING_FREE on the `word` member of every element, then set len to 0.
 * NOTE(review): no struct visible in this view has a `word` member — confirm
 * whether this macro was ever usable with Heredoc elements.
 */
#define VEC_CLEAR(vec) \
{ \
for (uint32_t i = 0; i < (vec).len; i++) { \
STRING_FREE((vec).data[i].word); \
} \
(vec).len = 0; \
}

/*
 * NOTE(review): the STRING_* macros below appear to be the removed (pre-commit)
 * side of this diff; later code uses array_push / array_reserve / array_delete
 * and reset_string instead.
 *
 * STRING_RESIZE reallocates vec.data to _cap + 1 bytes, zeroes everything from
 * index len to the end of the new buffer, and records _cap as the capacity.
 * Like VEC_RESIZE it declares `tmp` in the caller's scope.
 */
#define STRING_RESIZE(vec, _cap) \
void *tmp = realloc((vec).data, ((_cap) + 1) * sizeof((vec).data[0])); \
assert(tmp != NULL); \
(vec).data = tmp; \
memset((vec).data + (vec).len, 0, (((_cap) + 1) - (vec).len) * sizeof((vec).data[0])); \
(vec).cap = (_cap);

/* Resize to _cap only when the current capacity is smaller. */
#define STRING_GROW(vec, _cap) \
if ((vec).cap < (_cap)) { \
STRING_RESIZE((vec), (_cap)); \
}

/*
 * Append el, first growing to MAX(16, len * 2) when cap == len.
 * Same statement-only restriction as VEC_PUSH: the store is outside the `if`.
 */
#define STRING_PUSH(vec, el) \
if ((vec).cap == (vec).len) { \
STRING_RESIZE((vec), MAX(16, (vec).len * 2)); \
} \
(vec).data[(vec).len++] = (el);

/*
 * Free data and null the pointer. The expansion is an unbraced `if` followed by
 * a separate assignment, so it is only safe inside an enclosing block.
 */
#define STRING_FREE(vec) \
if ((vec).data != NULL) \
free((vec).data); \
(vec).data = NULL;

/*
 * Set len to 0 and zero `cap` bytes of data.
 * NOTE(review): data is not checked before the memset; after STRING_FREE it is
 * NULL while cap is unchanged.
 */
#define STRING_CLEAR(vec) \
{ \
(vec).len = 0; \
memset((vec).data, 0, (vec).cap * sizeof(char)); \
}

enum TokenType {
HEREDOC_START,
SIMPLE_HEREDOC_BODY,
Expand Down Expand Up @@ -101,13 +38,7 @@ enum TokenType {
ERROR_RECOVERY,
};

/*
 * NOTE(review): the struct and string_new() below appear to be the removed
 * side of the diff; the final `typedef Array(char) String;` is the replacement.
 * Both versions cannot coexist in one translation unit.
 */
typedef struct {
uint32_t cap;
uint32_t len;
char *data;
} String;

/* Old constructor: empty string with cap 16 backed by a zeroed 17-byte buffer. */
static String string_new() { return (String){.cap = 16, .len = 0, .data = calloc(1, sizeof(char) * 17)}; }
/* New definition: String is the Array macro instantiated for char; the rest of this file reads its `contents` and `size` members. */
typedef Array(char) String;

typedef struct {
bool is_raw;
Expand All @@ -117,33 +48,20 @@ typedef struct {
String current_leading_word;
} Heredoc;

/*
 * Build a Heredoc with all flags false and both strings empty.
 *
 * NOTE(review): this span interleaves two versions. The function form (using
 * string_new()) appears to be the removed side; the `#define heredoc_new()`
 * form (using array_new()) is the replacement. The closing `};`,
 * `return heredoc;` and `}` lines are shared or left over from the function.
 */
static Heredoc heredoc_new() {
Heredoc heredoc = {
.is_raw = false,
.started = false,
.allows_indent = false,
.delimiter = string_new(),
.current_leading_word = string_new(),
/*
 * Macro form. The body ends in `};`, so the expansion carries its own
 * semicolon and is only usable as a declaration initializer, e.g.
 * `Heredoc heredoc = heredoc_new();` as done elsewhere in this file.
 */
#define heredoc_new() \
{ \
.is_raw = false, \
.started = false, \
.allows_indent = false, \
.delimiter = array_new(), \
.current_leading_word = array_new(), \
};
return heredoc;
}

/*
 * Hand-rolled growable array of Heredoc: element count, capacity, and buffer.
 * NOTE(review): appears to be the removed side of the diff; Scanner now
 * declares `Array(Heredoc) heredocs` instead.
 */
typedef struct {
uint32_t len;
uint32_t cap;
Heredoc *data;
} heredoc_vec;

/*
 * Create an empty heredoc_vec with room for one zeroed Heredoc (len 0, cap 1).
 * The calloc result is not checked.
 * NOTE(review): appears to be the removed side of the diff; the create
 * function now calls array_init() instead.
 */
static heredoc_vec vec_new() {
heredoc_vec vec = {0, 0, NULL};
vec.data = calloc(1, sizeof(Heredoc));
vec.cap = 1;
return vec;
}

/*
 * External scanner state. serialize() writes last_glob_paren_depth, the two
 * ext_* flags and the heredoc list, and deserialize() reads them back.
 *
 * NOTE(review): both sides of the diff are shown here. `heredoc_vec heredocs`
 * appears to be the removed member; the two ext_* flags and
 * `Array(Heredoc) heredocs` are the added ones.
 */
typedef struct {
uint8_t last_glob_paren_depth;
heredoc_vec heredocs;
bool ext_was_in_double_quote;
bool ext_saw_outside_quote;
Array(Heredoc) heredocs;
} Scanner;

/* Move the lexer forward one character, passing `false` as the skip argument. */
static inline void advance(TSLexer *lexer) {
    lexer->advance(lexer, false);
}
Expand All @@ -152,39 +70,48 @@ static inline void skip(TSLexer *lexer) { lexer->advance(lexer, true); }

/* Report whether the ERROR_RECOVERY entry of valid_symbols is set. */
static inline bool in_error_recovery(const bool *valid_symbols) {
    const bool recovering = valid_symbols[ERROR_RECOVERY];
    return recovering;
}

/*
 * Empty a String in place: zero the bytes currently in use, then clear it.
 * Nothing is touched when the string is already empty, so memset never runs
 * on a string that has no bytes in use.
 */
static inline void reset_string(String *string) {
    if (string->size == 0) {
        return;
    }
    memset(string->contents, 0, string->size);
    array_clear(string);
}

/*
 * Return a Heredoc to its initial state: clear the three flags and empty the
 * delimiter. current_leading_word is not touched here.
 *
 * NOTE(review): the STRING_CLEAR line appears to be the removed side of the
 * diff; the reset_string() line is its replacement.
 */
static inline void reset_heredoc(Heredoc *heredoc) {
heredoc->is_raw = false;
heredoc->started = false;
heredoc->allows_indent = false;
STRING_CLEAR(heredoc->delimiter);
reset_string(&heredoc->delimiter);
}

/*
 * Reset every heredoc tracked by the scanner via reset_heredoc(). The list
 * itself keeps its length; only the entries are cleared.
 *
 * NOTE(review): two loop headers and two bodies are shown. The `.len` / `.data`
 * pair appears to be the removed side; the `.size` / array_get() pair is the
 * replacement, and only one closing brace pair belongs to each version.
 */
static inline void reset(Scanner *scanner) {
for (uint32_t i = 0; i < scanner->heredocs.len; i++) {
reset_heredoc(&scanner->heredocs.data[i]);
for (uint32_t i = 0; i < scanner->heredocs.size; i++) {
reset_heredoc(array_get(&scanner->heredocs, i));
}
}

/*
 * Write the scanner state into `buffer` and return the number of bytes used,
 * or 0 when the heredoc data would not fit.
 *
 * Layout written by the `.size` / `->` lines: one byte each for
 * last_glob_paren_depth, ext_was_in_double_quote, ext_saw_outside_quote and the
 * heredoc count; then per heredoc three flag bytes, a uint32_t delimiter
 * length, and the delimiter bytes.
 *
 * NOTE(review): this span interleaves both sides of the diff. Lines using
 * `.len`, `.data` and the by-value `Heredoc heredoc` appear to be the removed
 * side. The "Check notice" text further down is CI annotation output captured
 * with the page, not part of the function.
 */
static unsigned serialize(Scanner *scanner, char *buffer) {
uint32_t size = 0;

buffer[size++] = (char)scanner->last_glob_paren_depth;
buffer[size++] = (char)scanner->heredocs.len;
buffer[size++] = (char)scanner->ext_was_in_double_quote;
buffer[size++] = (char)scanner->ext_saw_outside_quote;
/* NOTE(review): the count is narrowed to one byte, so more than 255 heredocs would not round-trip. */
buffer[size++] = (char)scanner->heredocs.size;

for (uint32_t i = 0; i < scanner->heredocs.len; i++) {
Heredoc heredoc = scanner->heredocs.data[i];
if (heredoc.delimiter.len + 3 + size >= TREE_SITTER_SERIALIZATION_BUFFER_SIZE) {
for (uint32_t i = 0; i < scanner->heredocs.size; i++) {
Heredoc *heredoc = array_get(&scanner->heredocs, i);
/*
 * NOTE(review): the check reserves 3 bytes of header, but the code below
 * writes 3 flag bytes plus sizeof(uint32_t) for the length, so a write can
 * run past the limit this test is meant to enforce.
 */
if (heredoc->delimiter.size + 3 + size >= TREE_SITTER_SERIALIZATION_BUFFER_SIZE) {
return 0;
}

buffer[size++] = (char)heredoc.is_raw;
buffer[size++] = (char)heredoc.started;
buffer[size++] = (char)heredoc.allows_indent;
buffer[size++] = (char)heredoc->is_raw;
buffer[size++] = (char)heredoc->started;
buffer[size++] = (char)heredoc->allows_indent;

memcpy(&buffer[size], &heredoc.delimiter.len, sizeof(uint32_t));
memcpy(&buffer[size], &heredoc->delimiter.size, sizeof(uint32_t));
size += sizeof(uint32_t);
memcpy(&buffer[size], heredoc.delimiter.data, heredoc.delimiter.len);
size += heredoc.delimiter.len;
/*
 * NOTE(review): for an empty delimiter `contents` can be NULL, which is what
 * the sanitizer notice below reports; guarding on delimiter.size > 0 would
 * avoid passing NULL to memcpy.
 */
memcpy(&buffer[size], heredoc->delimiter.contents, heredoc->delimiter.size);

Check notice on line 113 in src/scanner.c

View workflow job for this annotation

GitHub Actions / Parser fuzzing

Sanitizer

null pointer passed as argument 2, which is declared to never be null

Check notice on line 113 in src/scanner.c

View workflow job for this annotation

GitHub Actions / Parser fuzzing

Sanitizer

null pointer passed as argument 2, which is declared to never be null
size += heredoc->delimiter.size;
}
return size;
}
Expand All @@ -195,27 +122,29 @@ static void deserialize(Scanner *scanner, const char *buffer, unsigned length) {
} else {
uint32_t size = 0;
scanner->last_glob_paren_depth = buffer[size++];
scanner->ext_was_in_double_quote = buffer[size++];
scanner->ext_saw_outside_quote = buffer[size++];
uint32_t heredoc_count = (unsigned char)buffer[size++];
for (uint32_t i = 0; i < heredoc_count; i++) {
Heredoc *heredoc = NULL;
if (i < scanner->heredocs.len) {
heredoc = &scanner->heredocs.data[i];
if (i < scanner->heredocs.size) {
heredoc = array_get(&scanner->heredocs, i);
} else {
Heredoc new_heredoc = heredoc_new();
VEC_PUSH(scanner->heredocs, new_heredoc);
heredoc = &VEC_BACK(scanner->heredocs);
array_push(&scanner->heredocs, new_heredoc);
heredoc = array_back(&scanner->heredocs);
}

heredoc->is_raw = buffer[size++];
heredoc->started = buffer[size++];
heredoc->allows_indent = buffer[size++];

memcpy(&heredoc->delimiter.len, &buffer[size], sizeof(uint32_t));
memcpy(&heredoc->delimiter.size, &buffer[size], sizeof(uint32_t));
size += sizeof(uint32_t);
STRING_GROW(heredoc->delimiter, heredoc->delimiter.len);
array_reserve(&heredoc->delimiter, heredoc->delimiter.size);

memcpy(heredoc->delimiter.data, &buffer[size], heredoc->delimiter.len);
size += heredoc->delimiter.len;
memcpy(heredoc->delimiter.contents, &buffer[size], heredoc->delimiter.size);

Check notice on line 146 in src/scanner.c

View workflow job for this annotation

GitHub Actions / Parser fuzzing

Sanitizer

null pointer passed as argument 1, which is declared to never be null

Check notice on line 146 in src/scanner.c

View workflow job for this annotation

GitHub Actions / Parser fuzzing

Sanitizer

null pointer passed as argument 1, which is declared to never be null
size += heredoc->delimiter.size;
}
assert(size == length);
}
Expand Down Expand Up @@ -247,9 +176,10 @@ static bool advance_word(TSLexer *lexer, String *unquoted_word) {
}
}
empty = false;
STRING_PUSH(*unquoted_word, lexer->lookahead);
array_push(unquoted_word, lexer->lookahead);
advance(lexer);
}
array_push(unquoted_word, '\0');

if (quote && lexer->lookahead == quote) {
advance(lexer);
Expand Down Expand Up @@ -282,29 +212,37 @@ static bool scan_heredoc_start(Heredoc *heredoc, TSLexer *lexer) {
heredoc->is_raw = lexer->lookahead == '\'' || lexer->lookahead == '"' || lexer->lookahead == '\\';

bool found_delimiter = advance_word(lexer, &heredoc->delimiter);
if (!found_delimiter)
STRING_CLEAR(heredoc->delimiter);
if (!found_delimiter) {
reset_string(&heredoc->delimiter);
return false;
}
return found_delimiter;
}

/*
 * Check whether the text at the lexer's position matches the heredoc
 * delimiter. Matching characters are consumed and copied into
 * current_leading_word, which is then compared with the delimiter via strcmp.
 * In the `.size` / `.contents` version an empty delimiter never matches.
 *
 * NOTE(review): this span interleaves both sides of the diff. The STRING_CLEAR
 * line, the first `while` loop (using `.data` / `.len` / STRING_PUSH) and the
 * one-line strcmp return appear to be the removed side.
 */
static bool scan_heredoc_end_identifier(Heredoc *heredoc, TSLexer *lexer) {
STRING_CLEAR(heredoc->current_leading_word);
reset_string(&heredoc->current_leading_word);
// Scan the first 'n' characters on this line, to see if they match the
// heredoc delimiter
int32_t size = 0;
while (lexer->lookahead != '\0' && lexer->lookahead != '\n' &&
((int32_t)heredoc->delimiter.data[size++]) == lexer->lookahead &&
heredoc->current_leading_word.len < heredoc->delimiter.len) {
STRING_PUSH(heredoc->current_leading_word, lexer->lookahead);
advance(lexer);
/*
 * NOTE(review): in the loop below the delimiter is indexed with `size` before
 * the `current_leading_word.size < delimiter.size` bound is tested; it relies
 * on the lookahead checks and the delimiter's stored '\0' to stop first.
 */
if (heredoc->delimiter.size > 0) {
while (lexer->lookahead != '\0' && lexer->lookahead != '\n' &&
(int32_t)*array_get(&heredoc->delimiter, size) == lexer->lookahead &&
heredoc->current_leading_word.size < heredoc->delimiter.size) {
array_push(&heredoc->current_leading_word, lexer->lookahead);
advance(lexer);
size++;
}
}
return strcmp(heredoc->current_leading_word.data, heredoc->delimiter.data) == 0;
// Terminate the collected word so it can be passed to strcmp.
array_push(&heredoc->current_leading_word, '\0');
return heredoc->delimiter.size == 0
? false
: strcmp(heredoc->current_leading_word.contents, heredoc->delimiter.contents) == 0;
}

static bool scan_heredoc_content(Scanner *scanner, TSLexer *lexer, enum TokenType middle_type,
enum TokenType end_type) {
bool did_advance = false;
Heredoc *heredoc = &VEC_BACK(scanner->heredocs);
Heredoc *heredoc = array_back(&scanner->heredocs);

for (;;) {
switch (lexer->lookahead) {
Expand Down Expand Up @@ -364,7 +302,7 @@ static bool scan_heredoc_content(Scanner *scanner, TSLexer *lexer, enum TokenTyp
lexer->mark_end(lexer);
if (scan_heredoc_end_identifier(heredoc, lexer)) {
if (lexer->result_symbol == HEREDOC_END) {
VEC_POP(scanner->heredocs);
array_pop(&scanner->heredocs);
}
return true;
}
Expand All @@ -376,7 +314,6 @@ static bool scan_heredoc_content(Scanner *scanner, TSLexer *lexer, enum TokenTyp
// an alternative is to check the starting column of the
// heredoc body and track that statefully
while (iswspace(lexer->lookahead)) {
/* did_advance ? advance(lexer) : skip(lexer); */
if (did_advance) {
advance(lexer);
} else {
Expand Down Expand Up @@ -494,29 +431,29 @@ static bool scan(Scanner *scanner, TSLexer *lexer, const bool *valid_symbols) {
}
}

if ((valid_symbols[HEREDOC_BODY_BEGINNING] || valid_symbols[SIMPLE_HEREDOC_BODY]) && scanner->heredocs.len > 0 &&
!VEC_BACK(scanner->heredocs).started && !in_error_recovery(valid_symbols)) {
if ((valid_symbols[HEREDOC_BODY_BEGINNING] || valid_symbols[SIMPLE_HEREDOC_BODY]) && scanner->heredocs.size > 0 &&
!array_back(&scanner->heredocs)->started && !in_error_recovery(valid_symbols)) {
return scan_heredoc_content(scanner, lexer, HEREDOC_BODY_BEGINNING, SIMPLE_HEREDOC_BODY);
}

if (valid_symbols[HEREDOC_END] && scanner->heredocs.len > 0) {
Heredoc *heredoc = &VEC_BACK(scanner->heredocs);
if (valid_symbols[HEREDOC_END] && scanner->heredocs.size > 0) {
Heredoc *heredoc = array_back(&scanner->heredocs);
if (scan_heredoc_end_identifier(heredoc, lexer)) {
STRING_FREE(heredoc->current_leading_word);
STRING_FREE(heredoc->delimiter);
VEC_POP(scanner->heredocs);
array_delete(&heredoc->current_leading_word);
array_delete(&heredoc->delimiter);
array_pop(&scanner->heredocs);
lexer->result_symbol = HEREDOC_END;
return true;
}
}

if (valid_symbols[HEREDOC_CONTENT] && scanner->heredocs.len > 0 && VEC_BACK(scanner->heredocs).started &&
if (valid_symbols[HEREDOC_CONTENT] && scanner->heredocs.size > 0 && array_back(&scanner->heredocs)->started &&
!in_error_recovery(valid_symbols)) {
return scan_heredoc_content(scanner, lexer, HEREDOC_CONTENT, HEREDOC_END);
}

if (valid_symbols[HEREDOC_START] && !in_error_recovery(valid_symbols) && scanner->heredocs.len > 0) {
return scan_heredoc_start(&VEC_BACK(scanner->heredocs), lexer);
if (valid_symbols[HEREDOC_START] && !in_error_recovery(valid_symbols) && scanner->heredocs.size > 0) {
return scan_heredoc_start(array_back(&scanner->heredocs), lexer);
}

if (valid_symbols[TEST_OPERATOR] && !valid_symbols[EXPANSION_WORD]) {
Expand Down Expand Up @@ -653,13 +590,13 @@ static bool scan(Scanner *scanner, TSLexer *lexer, const bool *valid_symbols) {
advance(lexer);
Heredoc heredoc = heredoc_new();
heredoc.allows_indent = true;
VEC_PUSH(scanner->heredocs, heredoc);
array_push(&scanner->heredocs, heredoc);
lexer->result_symbol = HEREDOC_ARROW_DASH;
} else if (lexer->lookahead == '<' || lexer->lookahead == '=') {
return false;
} else {
Heredoc heredoc = heredoc_new();
VEC_PUSH(scanner->heredocs, heredoc);
array_push(&scanner->heredocs, heredoc);
lexer->result_symbol = HEREDOC_ARROW;
}
return true;
Expand Down Expand Up @@ -1245,7 +1182,7 @@ static bool scan(Scanner *scanner, TSLexer *lexer, const bool *valid_symbols) {

/*
 * Allocate a zeroed Scanner, initialise its heredoc list, and return it as an
 * opaque pointer. The calloc result is used without a NULL check.
 *
 * NOTE(review): the vec_new() line appears to be the removed side of the diff;
 * array_init() is its replacement.
 */
void *tree_sitter_bash_external_scanner_create() {
Scanner *scanner = calloc(1, sizeof(Scanner));
scanner->heredocs = vec_new();
array_init(&scanner->heredocs);
return scanner;
}

Expand All @@ -1266,11 +1203,11 @@ void tree_sitter_bash_external_scanner_deserialize(void *payload, const char *st

/*
 * Release everything owned by the scanner: both strings of each heredoc, then
 * the heredoc list, then the Scanner itself.
 *
 * NOTE(review): this span interleaves both sides of the diff. The loop using
 * `.len` / `.data` / STRING_FREE and the VEC_FREE line appear to be the removed
 * side; the array_get() / array_delete() lines are the replacement.
 */
void tree_sitter_bash_external_scanner_destroy(void *payload) {
Scanner *scanner = (Scanner *)payload;
for (size_t i = 0; i < scanner->heredocs.len; i++) {
Heredoc *heredoc = &scanner->heredocs.data[i];
STRING_FREE(heredoc->current_leading_word);
STRING_FREE(heredoc->delimiter);
for (size_t i = 0; i < scanner->heredocs.size; i++) {
Heredoc *heredoc = array_get(&scanner->heredocs, i);
array_delete(&heredoc->current_leading_word);
array_delete(&heredoc->delimiter);
}
VEC_FREE(scanner->heredocs);
array_delete(&scanner->heredocs);
free(scanner);
}

0 comments on commit 673bc22

Please sign in to comment.