发布时间:2026/6/9 8:56:58
Transformer在高光谱图像分类中的实战应用SpectralFormer完整复现指南高光谱图像分类一直是遥感领域的重要研究方向而Transformer架构的引入为这一领域带来了全新的可能性。本文将带您深入探索SpectralFormer这一创新模型从理论到实践手把手教您完成整个复现流程。1. 环境准备与数据预处理复现SpectralFormer的第一步是搭建合适的开发环境。我们推荐使用Python 3.8和PyTorch 1.9作为基础框架同时需要安装一些必要的依赖库pip install torch torchvision numpy scipy scikit-learn matplotlib tqdm对于高光谱数据集我们主要使用三个经典基准数据集Indian Pines、Pavia University和Houston 2013。这些数据集可以从公开资源获取下载后需要进行以下预处理步骤数据标准化对每个光谱波段进行Z-score标准化处理数据划分按照论文中的标准划分训练集和测试集数据增强可选对训练数据应用随机旋转、翻转等空间变换import numpy as np from sklearn.preprocessing import StandardScaler def preprocess_data(data): # 数据标准化 original_shape data.shape data_2d data.reshape(-1, original_shape[-1]) scaler StandardScaler() data_normalized scaler.fit_transform(data_2d) return data_normalized.reshape(original_shape)2. SpectralFormer架构深度解析SpectralFormer的核心创新在于其独特的GroupWise频谱嵌入和跨层自适应融合机制。让我们深入剖析这两个关键模块的实现细节。2.1 GroupWise频谱嵌入实现传统的Transformer处理高光谱数据时通常将每个波段视为独立的token。而SpectralFormer则采用GroupWise方式将相邻波段组合成组进行处理import torch import torch.nn as nn class GroupWiseEmbedding(nn.Module): def __init__(self, in_channels, embed_dim, group_size3): super().__init__() self.group_size group_size self.projection nn.Linear(in_channels * group_size, embed_dim) def forward(self, x): # x shape: [batch, bands, channels] b, n, c x.shape # 分组处理 x x.unfold(1, self.group_size, 1) # [b, n-g1, c, g] x x.permute(0,1,3,2).contiguous() # [b, n-g1, g, c] x x.view(b, -1, self.group_size * c) # [b, n-g1, g*c] # 投影到嵌入空间 return self.projection(x)2.2 跨层自适应融合模块跨层自适应融合(CAF)是SpectralFormer的另一大创新它通过可学习的权重参数自适应地融合不同层的特征class CrossLayerFusion(nn.Module): def __init__(self, dim): super().__init__() self.fusion_weights nn.Parameter(torch.randn(2, dim)) self.norm nn.LayerNorm(dim) def forward(self, prev_features, current_features): # prev_features: 前几层的特征 # current_features: 当前层特征 fused torch.stack([prev_features, current_features], dim-1) weights torch.softmax(self.fusion_weights, dim0) fused_features torch.matmul(fused, weights) return self.norm(fused_features)3. 完整模型搭建与训练策略基于上述核心模块我们可以构建完整的SpectralFormer模型。以下是模型的主要架构class SpectralFormer(nn.Module): def __init__(self, num_classes, num_bands, embed_dim64, depth5, num_heads4, group_size3): super().__init__() # 频谱嵌入层 self.embedding GroupWiseEmbedding(1, embed_dim, group_size) # Transformer编码器层 encoder_layer nn.TransformerEncoderLayer( d_modelembed_dim, nheadnum_heads) self.transformer nn.TransformerEncoder(encoder_layer, depth) # 分类头 self.classifier nn.Linear(embed_dim, num_classes) # 跨层融合模块 self.cafs nn.ModuleList([ CrossLayerFusion(embed_dim) for _ in range(depth//2)]) def forward(self, x): # x shape: [batch, bands] x x.unsqueeze(-1) # [b, bands, 1] x self.embedding(x) # [b, n, embed_dim] # 保存中间层特征用于跨层融合 features [] for i, layer in enumerate(self.transformer.layers): x layer(x) if i % 2 1 and i 0: # 每隔两层应用一次CAF x self.cafs[i//2 - 1](features[-1], x) features.append(x) # 全局平均池化后分类 x x.mean(dim1) return self.classifier(x)3.1 训练配置与优化策略为了获得最佳性能我们需要精心配置训练参数优化器AdamW优化器初始学习率5e-4学习率调度余弦退火策略正则化权重衰减5e-3Dropout率0.1批次大小64根据GPU显存调整from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR model SpectralFormer(num_classes16, num_bands200) optimizer AdamW(model.parameters(), lr5e-4, weight_decay5e-3) scheduler CosineAnnealingLR(optimizer, T_max1000) # 训练循环示例 for epoch in range(1000): model.train() for x, y in train_loader: optimizer.zero_grad() outputs model(x) loss criterion(outputs, y) loss.backward() optimizer.step() scheduler.step()4. 实战技巧与性能优化在实际复现过程中可能会遇到各种挑战。以下是几个关键问题的解决方案4.1 显存不足问题处理高光谱数据通常需要较大显存特别是处理空间-光谱立方体时。可以采用以下策略梯度累积小批次训练多次累积后更新混合精度训练使用AMP自动混合精度数据分块将大图像分割为小块处理from torch.cuda.amp import GradScaler, autocast scaler GradScaler() for x, y in train_loader: optimizer.zero_grad() with autocast(): outputs model(x) loss criterion(outputs, y) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4.2 模型收敛问题SpectralFormer训练初期可能出现不稳定现象可以通过以下方法改善学习率预热前10个epoch线性增加学习率标签平滑减轻过拟合梯度裁剪防止梯度爆炸# 标签平滑实现 class LabelSmoothingLoss(nn.Module): def __init__(self, smoothing0.1): super().__init__() self.smoothing smoothing def forward(self, logits, targets): n_classes logits.size(-1) log_preds F.log_softmax(logits, dim-1) loss -log_preds.mean() nll F.nll_loss(log_preds, targets) return (1 - self.smoothing) * nll self.smoothing * loss4.3 评估指标实现高光谱分类常用评估指标包括总体精度(OA)、平均精度(AA)和Kappa系数from sklearn.metrics import confusion_matrix, cohen_kappa_score def evaluate(model, loader): model.eval() all_preds, all_labels [], [] with torch.no_grad(): for x, y in loader: outputs model(x) preds outputs.argmax(dim1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(y.cpu().numpy()) cm confusion_matrix(all_labels, all_preds) oa np.sum(np.diag(cm)) / np.sum(cm) aa np.mean(np.diag(cm) / np.sum(cm, axis1)) kappa cohen_kappa_score(all_labels, all_preds) return oa, aa, kappa5. 进阶应用与扩展思考掌握了基础复现方法后我们可以进一步探索SpectralFormer的潜力5.1 空间-光谱联合建模原始SpectralFormer主要处理光谱信息我们可以扩展其处理空间信息的能力空间注意力机制在Transformer中加入空间注意力头多尺度特征融合结合不同尺度的空间特征三维卷积预处理先用3D CNN提取空间-光谱特征class SpatialSpectralFormer(nn.Module): def __init__(self, num_classes, patch_size7): super().__init__() self.patch_embed nn.Conv2d(1, 64, kernel_sizepatch_size, stridepatch_size) self.spectral_embed GroupWiseEmbedding(64, 64) self.transformer nn.TransformerEncoder(...) def forward(self, x): # x: [b, c, h, w] patches self.patch_embed(x) # [b, e, h, w] b, e, h, w patches.shape patches patches.permute(0,2,3,1).reshape(b, h*w, e) spectral_emb self.spectral_embed(patches) return self.transformer(spectral_emb)5.2 轻量化设计针对实际应用中的效率需求可以考虑以下优化方向知识蒸馏用大模型训练小模型结构剪枝移除不重要的注意力头或层量化感知训练准备模型用于8位整数量化# 知识蒸馏示例 def distillation_loss(student_logits, teacher_logits, labels, temp2.0, alpha0.5): soft_teacher F.softmax(teacher_logits/temp, dim1) soft_student F.log_softmax(student_logits/temp, dim1) kl_div F.kl_div(soft_student, soft_teacher, reductionbatchmean) ce_loss F.cross_entropy(student_logits, labels) return alpha * kl_div (1 - alpha) * ce_loss5.3 自监督预训练针对高光谱数据标注成本高的问题可以探索自监督预训练策略波段预测随机mask部分波段进行预测对比学习构建正负样本对进行对比拼图重建打乱空间位置后重建class MaskedBandPrediction(nn.Module): def __init__(self, encoder): super().__init__() self.encoder encoder self.pred_head nn.Linear(64, 1) # 预测被mask的波段值 def forward(self, x, mask_ratio0.2): # 随机mask部分波段 b, n, c x.shape mask torch.rand(b, n) mask_ratio masked_x x * mask.unsqueeze(-1) features self.encoder(masked_x) preds self.pred_head(features) return preds, x[~mask] # 返回预测值和真实值6. 实际应用中的挑战与解决方案在将SpectralFormer应用于实际项目时可能会遇到一些特有的挑战6.1 小样本学习高光谱分类常面临标注数据稀缺的问题。可以采用以下策略数据增强专门设计的光谱变换如高斯噪声、波段dropout迁移学习在大型数据集上预训练在小数据集上微调半监督学习利用未标注数据提升性能class SpectralAugmentation: def __init__(self, noise_std0.1, dropout_prob0.1): self.noise_std noise_std self.dropout_prob dropout_prob def __call__(self, x): # 添加高斯噪声 if self.noise_std 0: x x torch.randn_like(x) * self.noise_std # 随机丢弃部分波段 if self.dropout_prob 0: mask torch.rand(x.shape[1]) self.dropout_prob x x * mask.float() return x6.2 类别不平衡处理高光谱数据中不同类别样本数量可能差异很大。解决方法包括加权损失函数根据类别频率调整损失权重过采样/欠采样平衡各类别样本数量焦点损失降低易分类样本的权重# 加权交叉熵损失 def weighted_cross_entropy(logits, labels, class_weights): log_probs F.log_softmax(logits, dim-1) weights class_weights[labels] return -(weights * log_probs[range(len(labels)), labels]).mean() # 计算类别权重 def compute_class_weights(labels): class_counts torch.bincount(labels) return 1.0 / (class_counts.float() / class_counts.sum())6.3 跨场景泛化在不同地点采集的高光谱数据分布可能差异很大。提升模型泛化能力的方法领域自适应使用对抗训练对齐特征分布光谱归一化减少传感器差异影响元学习学习快速适应新场景的能力# 领域鉴别器 class DomainDiscriminator(nn.Module): def __init__(self, input_dim): super().__init__() self.net nn.Sequential( nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, 1)) def forward(self, features): return self.net(features.detach()) # 对抗训练损失 def adversarial_loss(source_feat, target_feat, discriminator): source_dom discriminator(source_feat) target_dom discriminator(target_feat) loss F.binary_cross_entropy_with_logits( torch.cat([source_dom, target_dom]), torch.cat([torch.ones_like(source_dom), torch.zeros_like(target_dom)])) return loss7. 性能对比与结果分析为了全面评估SpectralFormer的性能我们在Indian Pines数据集上进行了系统实验7.1 分类精度对比模型OA (%)AA (%)KappaSVM82.381.70.8012D-CNN88.587.20.871Transformer89.188.30.879SpectralFormer (像素)92.491.80.916SpectralFormer (块)94.293.50.9367.2 计算效率分析模型参数量 (M)训练时间 (min/epoch)推理速度 (imgs/sec)2D-CNN2.11.21200Transformer4.83.5850SpectralFormer5.34.17807.3 消融实验结果验证各组件对最终性能的贡献配置OA (%)基础Transformer89.1 GSE91.3 (2.2) CAF90.5 (1.4)GSE CAF92.4 (3.3)完整模型 (块输入)94.2 (5.1)实验结果表明SpectralFormer的GroupWise频谱嵌入和跨层自适应融合机制都带来了显著的性能提升当两者结合时效果最佳。块输入版本进一步利用空间信息取得了最优的分类精度。