Стохастическое усреднение веса (SWA) Гауссова реализацияPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Стохастическое усреднение веса (SWA) Гауссова реализация

Сообщение Anonymous »

Это часть домашнего задания, и меня это замучило... Цель состоит в том, чтобы реализовать SWA-Gaussian, предложенный Уэсли Дж. Мэддоксом на NeurIPS 2019. Алгоритм следующий:
введите описание изображения здесь
Я пробовал все, чтобы реализовать это, но всегда не мог пройти критерий в тестовом наборе данных... Буду признателен, если кто-нибудь поможет мне проверить мой код и посмотреть, все ли не так... Я знаю, это долго и раздражает... Если вы все еще хотите помочь, просто посмотрите код под знаком #TODO... Это будет моя последняя попытка...
from util import draw_reliability_diagram, cost_function, setup_seeds, calc_calibration_curve

USE_PRETRAINED_MODEL = True
"""
If `USE_PRETRAINED_MODEL` is `True`, then MAP inference uses provided pretrained weights.
"""

class InferenceType(enum.Enum):
"""
Inference mode switch for your implementation.
`MAP` simply predicts the most likely class using pretrained MAP weights.
`SWAG_DIAGONAL` and `SWAG_FULL` correspond to SWAG-diagonal and the full SWAG method, respectively.
"""
MAP = 0
SWAG_DIAGONAL = 1
SWAG_FULL = 2

class SWAGInference(object):
"""
Your implementation of SWA-Gaussian.
This class is used to run and evaluate your solution.
You must preserve all methods and signatures of this class.
However, you can add new methods if you want.

We provide basic functionality and some helper methods.
You can pass all baselines by only modifying methods marked with TODO.
However, we encourage you to skim other methods in order to gain a better understanding of SWAG.
"""

def __init__(
self,
train_xs: torch.Tensor,
model_dir: pathlib.Path,
# TODO(1): change inference_mode to InferenceType.SWAG_DIAGONAL
# TODO(2): change inference_mode to InferenceType.SWAG_FULL
inference_mode: InferenceType = InferenceType.SWAG_FULL,
# TODO(2): optionally add/tweak hyperparameters
swag_training_epochs: int = 24,
swag_lr: float = 0.045, # original: 0.045
swag_update_interval: int = 1,
max_rank_deviation_matrix: int = 15,
num_bma_samples: int = 30,
):
"""
:param train_xs: Training images (for storage only)
:param model_dir: Path to directory containing pretrained MAP weights
:param inference_mode: Control which inference mode (MAP, SWAG-diagonal, full SWAG) to use
:param swag_training_epochs: Total number of gradient descent epochs for SWAG
:param swag_lr: Learning rate for SWAG gradient descent
:param swag_update_interval: Frequency (in epochs) for updating SWAG statistics during gradient descent
:param max_rank_deviation_matrix: Rank of deviation matrix for full SWAG
:param num_bma_samples: Number of networks to sample for Bayesian model averaging during prediction
"""

self.model_dir = model_dir
self.inference_mode = inference_mode
self.swag_training_epochs = swag_training_epochs
self.swag_lr = swag_lr
self.swag_update_interval = swag_update_interval
self.max_rank_deviation_matrix = max_rank_deviation_matrix
self.num_bma_samples = num_bma_samples

# Network used to perform SWAG.
# Note that all operations in this class modify this network IN-PLACE!
self.network = CNN(in_channels=3, out_classes=6)

# Store training dataset to recalculate batch normalization statistics during SWAG inference
self.training_dataset = torch.utils.data.TensorDataset(train_xs)

# SWAG-diagonal
# TODO(1): create attributes for SWAG-diagonal
# Hint: self._create_weight_copy() creates an all-zero copy of the weights
# as a dictionary that maps from weight name to values.
# Hint: you never need to consider the full vector of weights,
# but can always act on per-layer weights (in the format that _create_weight_copy() returns)
self.iteration_num = 0
self.parameter_swa = self._create_weight_copy()
self.parameter_swa_2ndmoment = self._create_weight_copy()
# END(1)

# Full SWAG
# TODO(2): create attributes for SWAG-full
# Hint: check collections.deque
if self.inference_mode == InferenceType.SWAG_FULL:
self.parameter_D = collections.deque()
for _ in range(self.max_rank_deviation_matrix):
self.parameter_D.append(self._create_weight_copy()) # would contain weight_copy for self.max_rank_deviation_matrix times
# END(2)

# Calibration, prediction, and other attributes
# TODO(2): create additional attributes, e.g., for calibration
self._calibration_threshold = None # this is an example, feel free to be creative

def update_swag_statistics(self) -> None:
"""
Update SWAG statistics with the current weights of self.network.
"""
self.iteration_num = self.iteration_num + 1
if self.iteration_num % self.swag_update_interval != 0:
return

# Create a copy of the current network weights
copied_params = {name: param.detach() for name, param in self.network.named_parameters()}

# Full SWAG
if self.inference_mode == InferenceType.SWAG_FULL:
# TODO(2): update full SWAG attributes for weight `name` using `copied_params` and `param`
self.parameter_D.popleft()
self.parameter_D.append({})
# raise NotImplementedError("Update full SWAG statistics")
# END(2)

# SWAG-diagonal
for name, param in copied_params.items():
# TODO(1): update SWAG-diagonal attributes for weight `name` using `copied_params` and `param`
update_num = self.iteration_num // self.swag_update_interval
for name, param in self.network.named_parameters():
self.parameter_swa[name] = (update_num * self.parameter_swa[name] + param) / (update_num + 1)
self.parameter_swa_2ndmoment[name] = (update_num * self.parameter_swa_2ndmoment[name] + param**2) / (update_num + 1)
if self.inference_mode == InferenceType.SWAG_FULL:
self.parameter_D[-1][name] = param - self.parameter_swa[name]
# raise NotImplementedError("Update SWAG-diagonal statistics")
# END(1)

def fit_swag_model(self, loader: torch.utils.data.DataLoader) -> None:
"""
Fit SWAG on top of the pretrained network self.network.
This method should perform gradient descent with occasional SWAG updates
by calling self.update_swag_statistics().
"""

# We use SGD with momentum and weight decay to perform SWA.
# See the paper on how weight decay corresponds to a type of prior.
# Feel free to play around with optimization hyperparameters.
optimizer = torch.optim.SGD(
self.network.parameters(),
lr=self.initial_lr,
momentum=0.9,
nesterov=False,
weight_decay=1e-4, # original: 1e-4
)
loss_fn = torch.nn.CrossEntropyLoss(
reduction="mean",
)
# TODO(2): Update SWAGScheduler instantiation if you decided to implement a custom schedule.
# By default, this scheduler just keeps the initial learning rate given to `optimizer`.
lr_scheduler = SWAGScheduler(
optimizer,
epochs=self.swag_training_epochs,
steps_per_epoch=len(loader),
initial_lr=self.initial_lr,
final_lr=self.final_lr
)

# TODO(1): Perform initialization for SWAG fitting
self.parameter_swa = {name: param.detach() for name, param in self.network.named_parameters()} # Init with pretrained parameters
self.parameter_swa_2ndmoment = {name: param.detach()**2 for name, param in self.network.named_parameters()} # Init with pretrained parameters
# raise NotImplementedError("Initialize SWAG fitting")
# END(1)

self.network.train()
with tqdm.trange(self.swag_training_epochs, desc="Running gradient descent for SWA") as pbar:
progress_dict = {}
for epoch in pbar:
avg_loss = 0.0
avg_accuracy = 0.0
num_samples = 0
for batch_images, batch_snow_labels, batch_cloud_labels, batch_labels in loader:
optimizer.zero_grad()
predictions = self.network(batch_images)
batch_loss = loss_fn(input=predictions, target=batch_labels)
batch_loss.backward()
optimizer.step()
progress_dict["lr"] = lr_scheduler.get_last_lr()[0]
lr_scheduler.step()

# Calculate cumulative average training loss and accuracy
avg_loss = (batch_images.size(0) * batch_loss.item() + num_samples * avg_loss) / (
num_samples + batch_images.size(0)
)
avg_accuracy = (
torch.sum(predictions.argmax(dim=-1) == batch_labels).item()
+ num_samples * avg_accuracy
) / (num_samples + batch_images.size(0))
num_samples += batch_images.size(0)
progress_dict["avg. epoch loss"] = avg_loss
progress_dict["avg. epoch accuracy"] = avg_accuracy
pbar.set_postfix(progress_dict)

# TODO(1): Implement periodic SWAG updates using the attributes defined in __init__
# raise NotImplementedError("Periodically update SWAG statistics")
self.update_swag_statistics()
# END(1)

def apply_calibration(self, validation_data: torch.utils.data.Dataset) -> None:
"""
Calibrate your predictions using a small validation set.
validation_data contains well-defined and ambiguous samples,
where you can identify the latter by having label -1.
"""
if self.inference_mode == InferenceType.MAP:
# In MAP mode, simply predict argmax and do nothing else
self._calibration_threshold = 0.0
return

# TODO(1): pick a prediction threshold, either constant or adaptive.
# The provided value should suffice to pass the easy baseline.
self._calibration_threshold = 2.0 / 3.0

# TODO(2): perform additional calibration if desired.
# Feel free to remove or change the prediction threshold.
val_images, val_snow_labels, val_cloud_labels, val_labels = validation_data.tensors
assert val_images.size() == (140, 3, 60, 60) # N x C x H x W
assert val_labels.size() == (140,)
assert val_snow_labels.size() == (140,)
assert val_cloud_labels.size() == (140,)

def predict_probabilities_swag(self, loader: torch.utils.data.DataLoader) -> torch.Tensor:
"""
Perform Bayesian model averaging using your SWAG statistics and predict
probabilities for all samples in the loader.
Outputs should be a Nx6 tensor, where N is the number of samples in loader,
and all rows of the output should sum to 1.
That is, output row i column j should be your predicted p(y=j | x_i).
"""

self.network.eval()

# Perform Bayesian model averaging:
# Instead of sampling self.num_bma_samples networks (using self.sample_parameters())
# for each datapoint, you can save time by sampling self.num_bma_samples networks,
# and perform inference with each network on all samples in loader.
model_predictions = []
for _ in tqdm.trange(self.num_bma_samples, desc="Performing Bayesian model averaging"):
# TODO(1): Sample new parameters for self.network from the SWAG approximate posterior
# raise NotImplementedError("Sample network parameters")
self.sample_parameters()
# END(1)

# TODO(1): Perform inference for all samples in `loader` using current model sample,
# and add the predictions to model_predictions
# raise NotImplementedError("Perform inference using current model")
model_predictions.append(self.predict_probabilities_map(loader))
# END(1)

assert len(model_predictions) == self.num_bma_samples
assert all(
isinstance(sample_predictions, torch.Tensor)
and sample_predictions.dim() == 2 # N x C
and sample_predictions.size(1) == 6
for sample_predictions in model_predictions
)

# TODO(1): Average predictions from different model samples into bma_probabilities
# raise NotImplementedError("Aggregate predictions from model samples")
bma_probabilities = torch.stack(model_predictions, dim=0).mean(dim=0)
# END(1)

assert bma_probabilities.dim() == 2 and bma_probabilities.size(1) == 6 # N x C
return bma_probabilities

def sample_parameters(self) -> None:
"""
Sample a new network from the approximate SWAG posterior.
For simplicity, this method directly modifies self.network in-place.
Hence, after calling this method, self.network corresponds to a new posterior sample.
"""

# Instead of acting on a full vector of parameters, all operations can be done on per-layer parameters.
for name, param in self.network.named_parameters():
# SWAG-diagonal part
z_diag = torch.randn(param.size())
# TODO(1): Sample parameter values for SWAG-diagonal
# raise NotImplementedError("Sample parameter for SWAG-diagonal")
mean_weights = self.parameter_swa[name]
std_weights = torch.clamp(self.parameter_swa_2ndmoment[name] - self.parameter_swa[name] ** 2, min=0) ** 0.5
# END(1)
assert mean_weights.size() == param.size() and std_weights.size() == param.size()

# Diagonal part
sampled_weight = mean_weights + std_weights * z_diag

# Full SWAG part
if self.inference_mode == InferenceType.SWAG_FULL:
# TODO(2): Sample parameter values for full SWAG
# raise NotImplementedError("Sample parameter for full SWAG")
matrix_D = torch.stack([item[name] for item in self.parameter_D], dim=-1)
z_cov = torch.randn(self.max_rank_deviation_matrix)
sampled_weight = mean_weights + 0.5**0.5 * std_weights * z_diag + 0.5**0.5 * matrix_D @ z_cov / (self.max_rank_deviation_matrix - 1)**0.5
# END(2)

# Modify weight value in-place; directly changing self.network
param.data = sampled_weight

# TODO(1): Don't forget to update batch normalization statistics using self._update_batchnorm_statistics()
# in the appropriate place!
# raise NotImplementedError("Update batch normalization statistics for newly sampled network")
self._update_batchnorm_statistics()
# END(1)

def predict_labels(self, predicted_probabilities: torch.Tensor) -> torch.Tensor:
"""
Predict labels in {0, 1, 2, 3, 4, 5} or "don't know" as -1
based on your model's predicted probabilities.
The parameter predicted_probabilities is an Nx6 tensor containing predicted probabilities
as returned by predict_probabilities(...).
The output should be a N-dimensional long tensor, containing values in {-1, 0, 1, 2, 3, 4, 5}.
"""

# label_probabilities contains the per-row maximum values in predicted_probabilities,
# max_likelihood_labels the corresponding column index (equivalent to class).
label_probabilities, max_likelihood_labels = torch.max(predicted_probabilities, dim=-1)
num_samples, num_classes = predicted_probabilities.size()
assert label_probabilities.size() == (num_samples,) and max_likelihood_labels.size() == (num_samples,)

# A model without uncertainty awareness might simply predict the most likely label per sample:
# return max_likelihood_labels

# A bit better: use a threshold to decide whether to return a label or "don't know" (label -1)
# TODO(2): implement a different decision rule if desired
return torch.where(
label_probabilities >= self._calibration_threshold,
max_likelihood_labels,
torch.ones_like(max_likelihood_labels) * -1,
)

def _create_weight_copy(self) -> typing.Dict[str, torch.Tensor]:
"""Create an all-zero copy of the network weights as a dictionary that maps name -> weight"""
return {
name: torch.zeros_like(param, requires_grad=False)
for name, param in self.network.named_parameters()
}

def fit(
self,
loader: torch.utils.data.DataLoader,
) -> None:
"""
Perform full SWAG fitting procedure.
If `PRETRAINED_WEIGHTS_FILE` is `True`, this method skips the MAP inference part,
and uses pretrained weights instead.
"""

# MAP inference to obtain initial weights
PRETRAINED_WEIGHTS_FILE = self.model_dir / "map_weights.pt"
if USE_PRETRAINED_MODEL:
self.network.load_state_dict(torch.load(PRETRAINED_WEIGHTS_FILE))
print("Loaded pretrained MAP weights from", PRETRAINED_WEIGHTS_FILE)
else:
self.fit_map_model(loader)

# SWAG
if self.inference_mode in (InferenceType.SWAG_DIAGONAL, InferenceType.SWAG_FULL):
self.fit_swag_model(loader)

def fit_map_model(self, loader: torch.utils.data.DataLoader) -> None:
"""
MAP inference procedure to obtain initial weights of self.network.
This is the exact procedure that was used to obtain the pretrained weights we provide.
"""
map_training_epochs = 140
initial_learning_rate = 0.01
reduced_learning_rate = 0.0001
start_decay_epoch = 50
decay_factor = reduced_learning_rate / initial_learning_rate

# Create optimizer, loss, and a learning rate scheduler that aids convergence
optimizer = torch.optim.SGD(
self.network.parameters(),
lr=initial_learning_rate,
momentum=0.9,
nesterov=False,
weight_decay=1e-4,
)
loss_fn = torch.nn.CrossEntropyLoss(
reduction="mean",
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
[
torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0),
torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1.0,
end_factor=decay_factor,
total_iters=(map_training_epochs - start_decay_epoch) * len(loader),
),
],
milestones=[start_decay_epoch * len(loader)],
)

# Put network into training mode
# Batch normalization layers are only updated if the network is in training mode,
# and are replaced by a moving average if the network is in evaluation mode.
self.network.train()
with tqdm.trange(map_training_epochs, desc="Fitting initial MAP weights") as pbar:
progress_dict = {}
# Perform the specified number of MAP epochs
for epoch in pbar:
avg_loss = 0.0
avg_accuracy = 0.0
num_samples = 0
# Iterate over batches of randomly shuffled training data
for batch_images, _, _, batch_labels in loader:
# Training step
optimizer.zero_grad()
predictions = self.network(batch_images)
batch_loss = loss_fn(input=predictions, target=batch_labels)
batch_loss.backward()
optimizer.step()

# Save learning rate that was used for step, and calculate new one
progress_dict["lr"] = lr_scheduler.get_last_lr()[0]
with warnings.catch_warnings():
# Suppress annoying warning (that we cannot control) inside PyTorch
warnings.simplefilter("ignore")
lr_scheduler.step()

# Calculate cumulative average training loss and accuracy
avg_loss = (batch_images.size(0) * batch_loss.item() + num_samples * avg_loss) / (
num_samples + batch_images.size(0)
)
avg_accuracy = (
torch.sum(predictions.argmax(dim=-1) == batch_labels).item()
+ num_samples * avg_accuracy
) / (num_samples + batch_images.size(0))
num_samples += batch_images.size(0)

progress_dict["avg. epoch loss"] = avg_loss
progress_dict["avg. epoch accuracy"] = avg_accuracy
pbar.set_postfix(progress_dict)

def predict_probabilities(self, xs: torch.Tensor) -> torch.Tensor:
"""
Predict class probabilities for the given images xs.
This method returns an NxC float tensor,
where row i column j corresponds to the probability that y_i is class j.

This method uses different strategies depending on self.inference_mode.
"""
self.network = self.network.eval()

# Create a loader that we can deterministically iterate many times if necessary
loader = torch.utils.data.DataLoader(
torch.utils.data.TensorDataset(xs),
batch_size=32,
shuffle=False,
num_workers=0,
drop_last=False,
)

with torch.no_grad(): # save memory by not tracking gradients
if self.inference_mode == InferenceType.MAP:
return self.predict_probabilities_map(loader)
else:
return self.predict_probabilities_swag(loader)

def predict_probabilities_map(self, loader: torch.utils.data.DataLoader) -> torch.Tensor:
"""
Predict probabilities assuming that self.network is a MAP estimate.
"""
all_predictions = []
for (batch_images,) in loader:
all_predictions.append(self.network(batch_images))

all_predictions = torch.cat(all_predictions)
return torch.softmax(all_predictions, dim=-1)

def _update_batchnorm_statistics(self) -> None:
"""
Reset and fit batch normalization statistics using the training dataset self.training_dataset.

Batch normalization usually uses an exponential moving average, controlled by the `momentum` parameter.
However, we are not training but want the statistics for the full training dataset.
Hence, setting `momentum` to `None` tracks a cumulative average instead.
The following code stores original `momentum` values, sets all to `None`,
and restores the previous hyperparameters after updating batchnorm statistics.
"""

original_momentum_values = dict()
for module in self.network.modules():
# Only need to handle batchnorm modules
if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
continue

# Store old momentum value before removing it
original_momentum_values[module] = module.momentum
module.momentum = None

# Reset batch normalization statistics
module.reset_running_stats()

loader = torch.utils.data.DataLoader(
self.training_dataset,
batch_size=32,
shuffle=False,
num_workers=0,
drop_last=False,
)

self.network.train()
for (batch_images,) in loader:
self.network(batch_images)
self.network.eval()

# Restore old `momentum` hyperparameter values
for module, momentum in original_momentum_values.items():
module.momentum = momentum

def evaluate(
swag_inference: SWAGInference,
eval_dataset: torch.utils.data.Dataset,
extended_evaluation: bool,
output_location: pathlib.Path,
) -> None:

print("Evaluating model on validation data")

# We ignore is_snow and is_cloud here, but feel free to use them as well
images, snow_labels, cloud_labels, labels = eval_dataset.tensors

# Predict class probabilities on test data,
# most likely classes (according to the max predicted probability),
# and classes as predicted by your SWAG implementation.
all_pred_probabilities = swag_inference.predict_probabilities(images)
max_pred_probabilities, argmax_pred_labels = torch.max(all_pred_probabilities, dim=-1)
predicted_labels = swag_inference.predict_labels(all_pred_probabilities)

# Create a mask that ignores ambiguous samples (those with class -1)
non_ambiguous_mask = labels != -1

# Calculate three kinds of accuracy:
# 1. Overall accuracy, counting "don't know" (-1) as its own class
# 2. Accuracy on all samples that have a known label. Predicting -1 on those counts as wrong here.
# 3. Accuracy on all samples that have a known label w.r.t. the class with the highest predicted probability.
overall_accuracy = torch.mean((predicted_labels == labels).float()).item()
non_ambiguous_accuracy = torch.mean((predicted_labels[non_ambiguous_mask] == labels[non_ambiguous_mask]).float()).item()
non_ambiguous_argmax_accuracy = torch.mean(
(argmax_pred_labels[non_ambiguous_mask] == labels[non_ambiguous_mask]).float()
).item()
print(f"Accuracy (raw): {overall_accuracy:.4f}")
print(f"Accuracy (non-ambiguous only, your predictions): {non_ambiguous_accuracy:.4f}")
print(f"Accuracy (non-ambiguous only, predicting most-likely class): {non_ambiguous_argmax_accuracy:.4f}")

# Determine which threshold would yield the smallest cost on the validation data
# Note that this threshold does not necessarily generalize to the test set!
# However, it can help you judge your method's calibration.
threshold_values = [0.0] + list(torch.unique(max_pred_probabilities, sorted=True))
costs = []
for threshold in threshold_values:
thresholded_predictions = torch.where(max_pred_probabilities torch.Tensor:
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.pool1(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.pool2(x)
x = self.layer5(x)

# Average features over both spatial dimensions, and remove the now superfluous dimensions
x = self.global_pool(x).squeeze(-1).squeeze(-1)

log_softmax = self.linear(x)

return log_softmax



Подробнее здесь: https://stackoverflow.com/questions/791 ... ementation
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»