l2gv2.align.geo package
Submodules
l2gv2.align.geo.geoalign module
Alignment using PyTorch Geometric.
- class l2gv2.align.geo.geoalign.GeoAlignmentProblem(patches: list[Patch], patch_edges=None, min_overlap=None, copy_data=True, self_loops=False, verbose=False, num_epochs: int = 1000, learning_rate: float = 0.001, model_type: str = 'affine', device: str = 'cpu')
Bases: AlignmentProblem
Alignment problem using PyTorch Geometric.
- align_patches(scale=False)
Align the patches.
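Conceptually, aligning patches means applying a learned transformation x W_i + b_i to each patch's local coordinates and then reconciling nodes that appear in several patches. The pure-Python sketch below illustrates one common way to do the final step: average each node's transformed coordinates over the patches that contain it. It is an illustration under assumptions, not the GeoAlignmentProblem implementation; the function `combine_aligned_patches`, the `(node_ids, coords)` patch layout and the averaging rule are all hypothetical.

```python
# Hypothetical sketch: apply per-patch affine maps, then average per node.
# Not the l2gv2 implementation; names and data layout are assumptions.

Matrix = list[list[float]]


def apply_affine(row: list[float], w: Matrix, b: list[float]) -> list[float]:
    """Return the row vector row @ w + b (row-vector convention x W + b)."""
    return [sum(row[m] * w[m][k] for m in range(len(row))) + b[k] for k in range(len(b))]


def combine_aligned_patches(patches, transforms):
    """Average each node's transformed coordinates over all patches containing it.

    patches: list of (node_ids, coords) pairs; coords[r] is the embedding of
             node_ids[r] in that patch's local frame (assumed layout).
    transforms: list of (W, b) pairs, one per patch (assumed already learned).
    """
    sums: dict[int, list[float]] = {}
    counts: dict[int, int] = {}
    for (node_ids, coords), (w, b) in zip(patches, transforms):
        for node, row in zip(node_ids, coords):
            aligned = apply_affine(row, w, b)
            if node not in sums:
                sums[node] = [0.0] * len(aligned)
                counts[node] = 0
            sums[node] = [s + a for s, a in zip(sums[node], aligned)]
            counts[node] += 1
    return {node: [s / counts[node] for s in total] for node, total in sums.items()}
```

For example, if patch 1 stores node coordinates in a frame rotated by 90 degrees and shifted, the transform `([[0.0, 1.0], [-1.0, 0.0]], [3.0, -2.0])` maps them back into patch 0's frame, and a node shared by both patches ends up with identical contributions, so its average equals either one.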
l2gv2.align.geo.model module
Model for aligning patch embeddings.
l2gv2.align.geo.train module
Train the model on the patch embeddings.
- l2gv2.align.geo.train.patchgraph_mse_loss(transformed_emb)
Custom loss function that sums the squared norms of the differences between the transformed embedding pairs in the dictionary.
- Parameters:
transformed_emb – Dictionary with keys (i, j) and values (X W_i + b_i, Y W_j + b_j), where W_i and b_i are the transformation weight and bias for patch i
- Returns:
Total loss, the sum of the squared differences
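The loss above can be written as L = Σ over pairs (i, j) of ‖(X W_i + b_i) − (Y W_j + b_j)‖², summed over every row shared by the two patches. A minimal framework-free sketch, assuming each dictionary value holds the two transformed embeddings as equally long lists of rows (the library presumably operates on PyTorch tensors; the name `patchgraph_mse_loss_sketch` is hypothetical):

```python
# Hypothetical, framework-free sketch of the pairwise alignment loss.
Rows = list[list[float]]


def patchgraph_mse_loss_sketch(transformed: dict[tuple[int, int], tuple[Rows, Rows]]) -> float:
    """Sum of squared differences between transformed overlap embeddings.

    transformed[(i, j)] = (X W_i + b_i, Y W_j + b_j), each a list of rows
    (assumed layout: row r of both entries describes the same node).
    """
    total = 0.0
    for (i, j), (emb_i, emb_j) in transformed.items():
        if len(emb_i) != len(emb_j):
            raise ValueError(f"pair {(i, j)}: embeddings differ in length")
        for row_i, row_j in zip(emb_i, emb_j):
            # Squared Euclidean norm of the difference for one shared node.
            total += sum((a - b) ** 2 for a, b in zip(row_i, row_j))
    return total


example = {
    (0, 1): ([[1.0, 2.0]], [[1.0, 2.0]]),  # perfectly aligned: contributes 0
    (1, 2): ([[0.0, 3.0]], [[0.0, 0.0]]),  # off by 3 in one coordinate: contributes 9
}
print(patchgraph_mse_loss_sketch(example))  # 9.0
```

Because the loss is a plain sum rather than a mean, its scale grows with the number and size of the overlaps, which matters when picking a learning rate.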
- l2gv2.align.geo.train.train_alignment_model(patch_intersections, n_patches, device='cpu', num_epochs=100, learning_rate=0.05, model_type='affine', verbose=True)
Train the alignment model on the patch embeddings.
- Parameters:
patch_intersections – torch.Tensor patch embeddings on the intersections between pairs of patches
n_patches – int number of patches
device – str device to run the model on
num_epochs – int number of epochs to train the model
learning_rate – float learning rate for the optimizer
model_type – str type of transformation model to fit (default 'affine')
verbose – bool whether to report training progress
- Returns:
model – the trained alignment model; loss_hist – list of recorded training losses
- Return type:
tuple (model, list)
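A minimal gradient-descent sketch of what training such an alignment model involves. Everything here is an assumption for illustration: an affine map x W_p + b_p per patch, the summed squared-difference loss described for patchgraph_mse_loss, full-batch plain gradient descent in place of whatever optimizer the library uses, patch 0 frozen at the identity so the maps cannot all collapse to zero, and a hypothetical overlap layout `{(i, j): [(x, y), ...]}` in which x and y are one node's coordinates in patches i and j. The function name `train_affine_alignment_sketch` is hypothetical; its defaults mirror the documented num_epochs=100 and learning_rate=0.05.

```python
# Hypothetical sketch (not the l2gv2 implementation): fit one affine map
# (W_p, b_p) per patch by full-batch gradient descent on
#   sum over overlaps (i, j) and shared nodes of ||(x W_i + b_i) - (y W_j + b_j)||^2.
# Patch 0 is frozen at the identity (an assumption) so the maps cannot collapse to zero.


def train_affine_alignment_sketch(overlaps, n_patches, dim, num_epochs=100, learning_rate=0.05):
    """overlaps: {(i, j): [(x, y), ...]}, x and y being one node's coordinates
    in patches i and j. Returns (weights, biases, loss_hist)."""
    weights = [[[1.0 if m == k else 0.0 for k in range(dim)] for m in range(dim)]
               for _ in range(n_patches)]
    biases = [[0.0] * dim for _ in range(n_patches)]

    def transform(p, row):
        # Row-vector convention: row @ W_p + b_p.
        return [sum(row[m] * weights[p][m][k] for m in range(dim)) + biases[p][k]
                for k in range(dim)]

    loss_hist = []
    for _ in range(num_epochs):
        grad_w = [[[0.0] * dim for _ in range(dim)] for _ in range(n_patches)]
        grad_b = [[0.0] * dim for _ in range(n_patches)]
        loss = 0.0
        for (i, j), pairs in overlaps.items():
            for x, y in pairs:
                r = [a - c for a, c in zip(transform(i, x), transform(j, y))]
                loss += sum(v * v for v in r)
                for k in range(dim):
                    # dL/dW_i[m][k] = 2 r_k x_m and dL/dW_j[m][k] = -2 r_k y_m; biases likewise.
                    for m in range(dim):
                        grad_w[i][m][k] += 2 * r[k] * x[m]
                        grad_w[j][m][k] -= 2 * r[k] * y[m]
                    grad_b[i][k] += 2 * r[k]
                    grad_b[j][k] -= 2 * r[k]
        loss_hist.append(loss)
        for p in range(1, n_patches):  # patch 0 is never updated
            for m in range(dim):
                for k in range(dim):
                    weights[p][m][k] -= learning_rate * grad_w[p][m][k]
            for k in range(dim):
                biases[p][k] -= learning_rate * grad_b[p][k]
    return weights, biases, loss_hist


# Usage: patch 1 sees five nodes in a frame rotated by 90 degrees and shifted;
# the map back into patch 0's frame is x = y @ [[0, 1], [-1, 0]] + [3, -2].
ys = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
xs = [[3.0 - y1, y0 - 2.0] for y0, y1 in ys]
weights, biases, loss_hist = train_affine_alignment_sketch(
    {(0, 1): list(zip(xs, ys))}, n_patches=2, dim=2)
print(loss_hist[0], loss_hist[-1])  # the loss should shrink towards 0
```

After 100 epochs, weights[1] and biases[1] should be close to [[0, 1], [-1, 0]] and [3, -2]. Freezing one patch is just one way to remove the global freedom of moving every patch by the same transformation; the library may handle this differently.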
Module contents
- class l2gv2.align.geo.GeoAlignmentProblem(patches: list[Patch], patch_edges=None, min_overlap=None, copy_data=True, self_loops=False, verbose=False, num_epochs: int = 1000, learning_rate: float = 0.001, model_type: str = 'affine', device: str = 'cpu')
Bases: AlignmentProblem
Alignment problem using PyTorch Geometric.
- align_patches(scale=False)
Align the patches.
- class l2gv2.align.geo.OrthogonalModel(dim, n_patches, device)
Bases: Module
Model for aligning patch embeddings.
- forward(patch_intersection)
Forward pass.
- Parameters:
patch_intersection – list of tuples
- Returns:
list of transformed embeddings
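Judging by its name, OrthogonalModel presumably restricts each patch's map to an orthogonal matrix (possibly plus a translation), which preserves distances within a patch. The sketch below shows that idea in two dimensions by parameterizing each patch's rotation with a single angle, so orthogonality holds by construction. Assumptions: the 2-D restriction, the angle parameterization, the `(i, j, X, Y)` tuple layout for `patch_intersection`, and the class name `RotationAlignmentSketch` are all hypothetical; the real model takes `dim` as a parameter.

```python
import math

# Hypothetical 2-D sketch of an orthogonal alignment model: each patch p gets a
# rotation angle theta_p and a translation b_p; R(theta) is orthogonal by construction.


class RotationAlignmentSketch:
    def __init__(self, n_patches: int):
        self.thetas = [0.0] * n_patches  # rotation angle per patch
        self.biases = [[0.0, 0.0] for _ in range(n_patches)]  # translation per patch

    def rotation(self, p: int) -> list[list[float]]:
        c, s = math.cos(self.thetas[p]), math.sin(self.thetas[p])
        return [[c, s], [-s, c]]  # R^T R = I for every angle

    def apply(self, p: int, rows: list[list[float]]) -> list[list[float]]:
        # Row-vector convention: x @ R_p + b_p.
        r, b = self.rotation(p), self.biases[p]
        return [[x[0] * r[0][k] + x[1] * r[1][k] + b[k] for k in range(2)] for x in rows]

    def forward(self, patch_intersection):
        """patch_intersection: list of (i, j, X, Y) tuples (assumed layout).
        Returns a list of (X R_i + b_i, Y R_j + b_j) pairs."""
        return [(self.apply(i, xs), self.apply(j, ys)) for i, j, xs, ys in patch_intersection]


# Usage: rotate patch 1 by a quarter turn and shift it by (1, 0).
model = RotationAlignmentSketch(n_patches=2)
model.thetas[1] = math.pi / 2
model.biases[1] = [1.0, 0.0]
print(model.forward([(0, 1, [[1.0, 0.0]], [[1.0, 0.0]])]))
```

The angle parameterization keeps the search unconstrained, which is convenient for gradient descent. In higher dimensions a different parameterization is needed; PyTorch, for instance, provides `torch.nn.utils.parametrizations.orthogonal` for constraining a layer's weight to be orthogonal.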
- l2gv2.align.geo.patchgraph_mse_loss(transformed_emb)
Custom loss function that sums the squared norms of the differences between the transformed embedding pairs in the dictionary.
- Parameters:
transformed_emb – Dictionary with keys (i, j) and values (X W_i + b_i, Y W_j + b_j), where W_i and b_i are the transformation weight and bias for patch i
- Returns:
Total loss, the sum of the squared differences
- l2gv2.align.geo.train_alignment_model(patch_intersections, n_patches, device='cpu', num_epochs=100, learning_rate=0.05, model_type='affine', verbose=True)
Train the alignment model on the patch embeddings.
- Parameters:
patch_intersections – torch.Tensor patch embeddings on the intersections between pairs of patches
n_patches – int number of patches
device – str device to run the model on
num_epochs – int number of epochs to train the model
learning_rate – float learning rate for the optimizer
model_type – str type of transformation model to fit (default 'affine')
verbose – bool whether to report training progress
- Returns:
model – the trained alignment model; loss_hist – list of recorded training losses
- Return type:
tuple (model, list)