Do training on Vertex AI TPU (#7)
* add tpu training

* add docker container for TPU

* add docker container for TPU

* add docker container for TPU

* add job args and fix linting

* fix linting issues

* fix bug in encoding function

* modify conflicting accelerator type name

* Add train on tpu parts

* fix tpu issues

* modify xm parallel trial runs

* separate functions into separate modules

* update unit test

* update unit test

* update unit test and xm runs

* apply reversible encoding

* fix type annotation

* add unit testing for DataEncoder

* add description to functions

* Restructure code and cleanup

* Add method description to encoding methods
panford committed Dec 7, 2023
1 parent 949e305 commit c46c9e5
Showing 8 changed files with 454 additions and 65 deletions.
200 changes: 200 additions & 0 deletions src/skai/model/data.py
@@ -177,6 +177,206 @@ def apply_batch(dataloader, batch_size):
return dataloader


class DataEncoder:
  def __init__(self):
    string_label_categories: dict[bytes, int] = {
        b'bad_example': 0,
        b'destroyed': 1,
        b'major_damage': 2,
        b'minor_damage': 3,
        b'no_damage': 4,
    }
    self.label_to_int_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            list(string_label_categories.keys()),
            list(string_label_categories.values())),
        default_value=-1)
    self.int_to_label_table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            list(string_label_categories.values()),
            list(string_label_categories.keys())),
        default_value='unknown')
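    # Illustrative lookups, assuming the tables above (added for clarity,
    # not part of the commit):
    #   label_to_int_table.lookup(tf.constant([b'no_damage'])) -> [4]
    #   int_to_label_table.lookup(tf.constant([4]))            -> [b'no_damage']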

  def encode_example_ids(self, dataloader: Dataloader) -> Dataloader:
    """Encodes example IDs from hexadecimal strings to integers in a Dataloader.

    Example IDs are hexadecimal strings, e.g.
    b0b947f423a1c77ac948c76f63fa8209. Each ID is encoded by parsing it as a
    base-16 integer. This yields a long integer (up to 39 decimal digits for
    a 32-character hex string) that does not fit in any TensorFlow integer
    dtype. The long integer is therefore broken into smaller segments, e.g.
    [2323, 9023, 3403], using integer division and modulo operations, which
    are reversible. The segments are (pre-)padded to the same size for all
    examples in a batch, and the segment count before padding is appended,
    e.g. [0, 0, 2323, 9023, 3403, 3].

    Args:
      dataloader: The Dataloader containing example IDs to be encoded.

    Returns:
      The modified Dataloader with encoded example IDs.
    """
    return self._apply_map_to_features(
        dataloader, self._convert_hex_strings_to_int, 'example_id')

  def encode_string_labels(self, dataloader: Dataloader) -> Dataloader:
    """Encodes string labels to numerical values.

    A static hash table maps each unique label to an integer for lookup.

    Args:
      dataloader: The Dataloader.

    Returns:
      The Dataloader with string labels encoded.
    """
    return self._apply_map_to_features(
        dataloader, self._convert_label_to_int, 'string_label')

  def decode_example_ids(self, inputs: tf.Tensor | Dataloader):
    """Decodes example IDs from integer segments back to hexadecimal strings.

    Args:
      inputs: A batch of data or a Dataloader containing encoded example IDs.

    Returns:
      The modified batch or Dataloader with decoded example IDs.
    """
    if isinstance(inputs, Dataloader):
      return self._apply_map_to_features(
          inputs, self._convert_int_to_hex_strings, 'example_id')
    else:
      return self._convert_int_to_hex_strings(inputs)

  def decode_string_labels(self, inputs: tf.Tensor | Dataloader):
    """Decodes string labels by looking up strings from integers in a batch.

    Args:
      inputs: A batch of data or a Dataloader containing encoded string labels.

    Returns:
      The modified batch or Dataloader with decoded string labels.
    """
    if isinstance(inputs, Dataloader):
      return self._apply_map_to_features(
          inputs, self._convert_int_to_label, 'string_label')
    else:
      return self._convert_int_to_label(inputs)

  def _convert_hex_strings_to_int(self, hex_strings):
    """Converts hex strings to padded integer segments.

    Parsing a hex string yields an integer too large for any TensorFlow
    integer dtype, so the long integer is broken into segments using modulo
    arithmetic and the segments are padded to the same length.
    """
    segment_size = 4

    def split_long_integer(number):
      segments = []
      while number > 0:
        # Extract the last `segment_size` decimal digits.
        segment = number % (10 ** segment_size)
        segments.append(segment)
        # Drop the last `segment_size` decimal digits.
        number //= 10 ** segment_size
      return segments

    output = []
    for hex_string in hex_strings:
      integer = int(hex_string.numpy(), 16)
      short_integers = split_long_integer(integer)
      # Append the segment count so decoding can strip the padding.
      short_integers += [len(short_integers)]
      output.append(short_integers)
    padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
        output, padding='pre')
    return padded_sequences

  def _convert_int_to_hex_strings(self, segments):
    """Converts integer segments back to a long integer and decodes it to
    its hex string representation.
    """

    def combine_segments(segments, segment_size=4):
      # The last element stores how many segments carry data; everything
      # before them is padding.
      list_size = segments[-1]
      segments_to_decode = segments[-(list_size + 1):-1]

      number = 0
      for i, segment in enumerate(segments_to_decode):
        number += segment * (10 ** (i * segment_size))
      return number

    def long_integer_to_string(integer):
      # Format as a 32-character, zero-padded hex string.
      strings = f'{integer:032x}'
      return tf.compat.as_bytes(strings, encoding='utf-8')

    output = []
    segment_size = 4
    for segment in segments:
      long_integer = combine_segments(segment.numpy().tolist(), segment_size)
      output.append(long_integer_to_string(long_integer))
    return tf.convert_to_tensor(output)

  def _convert_label_to_int(self, string_labels):
    """Looks up integer values from string labels."""
    return self.label_to_int_table.lookup(string_labels)

  def _convert_int_to_label(self, int_labels):
    """Looks up string labels from integer keys."""
    return self.int_to_label_table.lookup(int_labels)

  def _process_per_batch(self, batch, map_fn, feature):
    """Applies a map function to each batch of data."""
    for idx, examples in enumerate(batch):
      processed = map_fn(examples[feature])
      examples[feature] = processed

      if idx == 0:
        transformed_batch = tf.data.Dataset.from_tensor_slices(examples)
        continue
      # Dataset.concatenate returns a new dataset; reassign to accumulate.
      transformed_batch = transformed_batch.concatenate(
          tf.data.Dataset.from_tensor_slices(examples))
    return transformed_batch

  def _apply_map_to_features(
      self,
      dataloader: Dataloader,
      map_fn: collections.abc.Callable[[tf.Tensor], tf.Tensor],
      feature: str):
    """Applies a map function to one feature of a Dataloader.

    Args:
      dataloader: The Dataloader to apply the map function to.
      map_fn: The mapping function to apply.
      feature: The name of the feature to transform.

    Returns:
      The modified Dataloader.
    """
    # Reads the private _batch_size so batching can be reapplied below.
    batch_size = dataloader.train_splits[0]._batch_size.numpy()

    dataloader.train_splits = [
        self._process_per_batch(data, map_fn, feature)
        for data in dataloader.train_splits
    ]
    dataloader.val_splits = [
        self._process_per_batch(data, map_fn, feature)
        for data in dataloader.val_splits
    ]
    num_splits = len(dataloader.train_splits)
    train_ds = gather_data_splits(
        list(range(num_splits)), dataloader.train_splits)
    val_ds = gather_data_splits(list(range(num_splits)), dataloader.val_splits)
    dataloader.train_ds = train_ds
    dataloader.eval_ds['val'] = val_ds
    for k, v in dataloader.eval_ds.items():
      if k != 'val':
        dataloader.eval_ds[k] = self._process_per_batch(v, map_fn, feature)
    dataloader = apply_batch(dataloader, batch_size)
    return dataloader


def gather_data_splits(
slice_idx: list[int],
dataset: tf.data.Dataset | list[tf.data.Dataset]) -> tf.data.Dataset:
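For reference, a minimal standalone sketch (not part of the commit) of the reversible encoding that DataEncoder implements; segment_size=4 and the 32-character hex width match the values hard-coded above:

def split_long_integer(number, segment_size=4):
  # Least-significant segments first, mirroring DataEncoder's helper.
  segments = []
  while number > 0:
    segments.append(number % 10**segment_size)
    number //= 10**segment_size
  return segments

def combine_segments(segments, segment_size=4):
  # Inverse of split_long_integer.
  return sum(s * 10**(i * segment_size) for i, s in enumerate(segments))

example_id = 'b0b947f423a1c77ac948c76f63fa8209'
segments = split_long_integer(int(example_id, 16))
restored = combine_segments(segments)
assert f'{restored:032x}' == example_id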
137 changes: 136 additions & 1 deletion src/skai/model/data_test.py
@@ -17,7 +17,7 @@
import os
import tempfile
from typing import List

from absl.testing import absltest
import numpy as np
from skai.model import data
import tensorflow as tf
@@ -135,5 +135,140 @@ def setUpClass(cls):
]


def _create_test_data_with_hex_strings():
examples_dir = _make_temp_dir()
labeled_train_path = os.path.join(
examples_dir, 'train_labeled_examples.tfrecord')
labeled_test_path = os.path.join(
examples_dir, 'test_labeled_examples.tfrecord')
unlabeled_path = os.path.join(
examples_dir, 'unlabeled_examples.tfrecord')

_write_tfrecord([
_make_example('b0b947f423a1c77ac948c76f63fa8209', 0, 0, 'A0', 0, 'no_damage', 64, 256),
_make_example('5fb3fc48db76805c169e8dc667c3f266', 0, 1, 'A1', 0, 'no_damage', 64, 256),
_make_example('21bdfdb3f65974473d4a19f05871449d', 0, 2, 'A2', 1, 'major_damage', 64, 256),
], labeled_train_path)

_write_tfrecord([
_make_example('a564b943bdebd4936ce0fd135cc19fbf', 1, 0, 'B0', 0, 'no_damage', 64, 256),
], labeled_test_path)

_write_tfrecord([
_make_example('3a8e68680d3ec6d1013d11f492a2d7d5', 2, 0, 'C0', -1, 'bad_example', 64, 256),
_make_example('1004dc994ff1888052aa3ff4be5e55cf', 2, 1, 'C1', -1, 'bad_example', 64, 256),
_make_example('4b49276f4f10856b9e8a57fad78ee593', 2, 2, 'C2', -1, 'bad_example', 64, 256),
_make_example('97a9600f1e418132af93ea03f4264ad2', 2, 3, 'C3', -1, 'bad_example', 64, 256),
], unlabeled_path)

  return labeled_train_path, labeled_test_path, unlabeled_path


class TestDataEncoder(absltest.TestCase):
  def setUp(self):
    self.data_encoder = data.DataEncoder()
    (self.labeled_train_path,
     self.labeled_test_path,
     self.unlabeled_path) = _create_test_data_with_hex_strings()

dataset_builder = data.get_dataset('skai')
kwargs = {
'labeled_train_pattern': self.labeled_train_path,
'unlabeled_train_pattern': self.unlabeled_path,
'validation_pattern': self.labeled_test_path,
'use_post_disaster_only': False,
'load_small_images': True,
'data_dir': _make_temp_dir(),
}

dataloader = dataset_builder(
1,
initial_sample_proportion=1,
subgroup_ids=(),
subgroup_proportions=(),
**kwargs
)
self.dataloader = data.apply_batch(dataloader, 2)

def test_encode_example_ids_returns_dataloader(self):
    # Check if the encode_example_ids method correctly returns a dataloader
encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader)
self.assertIsInstance(encoded_dataloader, data.Dataloader)

  def test_encode_example_ids_encodes_strings_to_int(self):
    # Check if the example IDs are correctly encoded to ints
    encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader)
    dataset = encoded_dataloader.train_splits[0]
    encoded_example_ids = list(
        dataset.map(lambda x: x['example_id']).as_numpy_iterator())[0]
    self.assertIsInstance(encoded_example_ids, np.ndarray)
    self.assertTrue(np.issubdtype(encoded_example_ids.dtype, np.integer))

def test_encode_string_labels_returns_dataloader(self):
    # Check if the encode_string_labels method correctly returns a dataloader
encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader)
self.assertIsInstance(encoded_dataloader, data.Dataloader)

  def test_encode_string_labels_encodes_strings_to_int(self):
    # Check if the string labels are correctly encoded to ints
    encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader)
    dataset = encoded_dataloader.train_splits[0]  # pick one split and evaluate
    encoded_string_label = list(
        dataset.map(lambda x: x['string_label']).as_numpy_iterator())[0]
    self.assertIsInstance(encoded_string_label, np.ndarray)
    self.assertTrue(np.issubdtype(encoded_string_label.dtype, np.integer))

def test_decode_example_ids_returns_dataloader(self):
encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader)
decoded_data = self.data_encoder.decode_example_ids(encoded_dataloader)
self.assertIsInstance(decoded_data, data.Dataloader)

  def test_decode_int_label_decodes_int_to_string(self):
    # Check if the integer labels are correctly decoded back to strings
    encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader)
    decoded_dataloader = self.data_encoder.decode_string_labels(
        encoded_dataloader)
    dataset = decoded_dataloader.train_splits[0]
    decoded_int_label = list(
        dataset.map(lambda x: x['string_label']).as_numpy_iterator())[0]
    self.assertIsInstance(decoded_int_label, np.ndarray)
    self.assertTrue(np.issubdtype(decoded_int_label.dtype, np.str_) or
                    np.issubdtype(decoded_int_label.dtype, object))

  def test_decode_example_id_outputs_matches_inputs(self):
    all_example_ids = []
    dataset_true = self.dataloader.train_splits[0]
    true_id_list = list(
        dataset_true.map(lambda x: x['example_id']).as_numpy_iterator())
    for example_ids in true_id_list:
      all_example_ids += example_ids.tolist()

    encoded_dataloader = self.data_encoder.encode_example_ids(self.dataloader)
    decoded_dataloader = self.data_encoder.decode_example_ids(
        encoded_dataloader)

    all_decoded_ids = []
    dataset_decoded = decoded_dataloader.train_splits[0]
    decoded_id_list = list(
        dataset_decoded.map(lambda x: x['example_id']).as_numpy_iterator())
    for decoded_ids in decoded_id_list:
      all_decoded_ids += decoded_ids.tolist()
    self.assertItemsEqual(all_example_ids[:len(all_decoded_ids)],
                          all_decoded_ids)

  def test_decode_string_label_outputs_matches_inputs(self):
    all_string_labels = []
    dataset_true = self.dataloader.train_splits[0]
    true_labels_list = list(
        dataset_true.map(lambda x: x['string_label']).as_numpy_iterator())
    for string_labels in true_labels_list:
      all_string_labels += string_labels.tolist()

    encoded_dataloader = self.data_encoder.encode_string_labels(self.dataloader)
    decoded_dataloader = self.data_encoder.decode_string_labels(
        encoded_dataloader)

    all_decoded_labels = []
    dataset_decoded = decoded_dataloader.train_splits[0]
    decoded_labels_list = list(
        dataset_decoded.map(lambda x: x['string_label']).as_numpy_iterator())
    for string_labels in decoded_labels_list:
      all_decoded_labels += string_labels.tolist()
    self.assertItemsEqual(all_string_labels[:len(all_decoded_labels)],
                          all_decoded_labels)


if __name__ == '__main__':
tfds.testing.test_main()
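For orientation, a minimal usage sketch of the new encoder, assuming a dataloader built as in setUp above (illustrative only, not part of the commit):

from skai.model import data

encoder = data.DataEncoder()

# TPU training cannot consume string tensors, so encode both string
# features up front...
encoded = encoder.encode_example_ids(dataloader)
encoded = encoder.encode_string_labels(encoded)

# ...and recover the original hex IDs and labels afterwards.
decoded = encoder.decode_example_ids(encoded)
decoded = encoder.decode_string_labels(decoded)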
