cbam代码
CBAM是一种用于增强卷积神经网络的注意力机制。它可以增强网络对重要特征的感知,从而提高网络的性能。
pythonimport tensorflow as tf
class ChannelAttention(tf.keras.layers.Layer):
def __init__(self, ratio=8):
super(ChannelAttention, self).__init__()
self.ratio = ratio
def build(self, input_shape):
self.channels = input_shape[-1]
self.dense_layer1 = tf.keras.layers.Dense(units=self.channels // self.ratio,
activation='relu',
kernel_initializer='he_normal',
use_bias=True)
self.dense_layer2 = tf.keras.layers.Dense(units=self.channels,
kernel_initializer='he_normal',
use_bias=True)
def call(self, inputs):
avg_pool = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
max_pool = tf.reduce_max(inputs, axis=[1, 2], keepdims=True)
avg_pool = self.dense_layer2(self.dense_layer1(avg_pool))
max_pool = self.dense_layer2(self.dense_layer1(max_pool))
attention = tf.nn.sigmoid(avg_pool + max_pool)
return inputs * attention
class SpatialAttention(tf.keras.layers.Layer):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
self.kernel_size = kernel_size
def build(self, input_shape):
self.conv = tf.keras.layers.Conv2D(filters=1,
kernel_size=self.kernel_size,
padding='same',
kernel_initializer='he_normal',
use_bias=False)
def call(self, inputs):
avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
combined = tf.concat([avg_pool, max_pool], axis=-1)
attention = self.conv(combined)
attention = tf.nn.sigmoid(attention)
return inputs * attention
class CBAM(tf.keras.layers.Layer):
def __init__(self, ratio=8, kernel_size=7):
super(CBAM, self).__init__()
self.channel_attention = ChannelAttention(ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def call(self, inputs):
outputs = self.channel_attention(inputs)
outputs = self.spatial_attention(outputs)
return outputs
这个代码实现了 CBAM 模块,包括通道注意力和空间注意力。通道注意力通过对特征图在通道维度上进行平均池化和最大池化操作,学习通道间的关系,然后使用全连接层学习通道间的重要性。空间注意力通过对特征图在空间维度上进行平均池化和最大池化操作,然后使用卷积层学习空间上的注意力权重。最后将通道注意力和空间注意力结合起来,得到最终的 CBAM 模块输出。
pythonimport tensorflow as tf
# 构建一个简单的 CNN 模型
def build_cnn_model(input_shape, num_classes):
inputs = tf.keras.layers.Input(shape=input_shape)
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
x = tf.keras.layers.MaxPooling2D((2, 2))(x)
x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = tf.keras.layers.MaxPooling2D((2, 2))(x)
# 在这里添加 CBAM 模块
x = CBAM()(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
return model
# 定义模型参数
input_shape = (224, 224, 3)
num_classes = 10
# 构建 CNN 模型并加入 CBAM 模块
model = build_cnn_model(input_shape, num_classes)
# 编译模型
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 打印模型结构
model.summary()
在上面的示例中,首先定义了一个简单的 CNN 模型。然后在模型的构建过程中,通过调用 CBAM(),将 CBAM 模块集成到了 CNN 模型中。最后编译模型并打印模型结构。
你可以根据需要调整 CNN 模型的结构、参数和 CBAM 模块的参数,以满足你的实际需求。