Visión: CNNs y Transfer Learning
¿Qué es?
Una CNN (Convolutional Neural Network, red neuronal convolucional) es una red especializada en imágenes. En vez de aplastar la imagen en un vector, conserva su estructura 2D y desliza pequeños filtros que detectan bordes, texturas y, en capas profundas, formas completas. El transfer learning es reutilizar una CNN ya entrenada con millones de imágenes y adaptarla a tu problema, en lugar de entrenar desde cero.
¿Cómo funciona?
Una capa convolucional aplica filtros que recorren la imagen detectando patrones locales sin importar dónde aparezcan. Capas de pooling reducen la resolución y quedan los rasgos importantes. Apilando convoluciones, la red aprende una jerarquía: bordes → texturas → partes → objetos.
En transfer learning tomamos un modelo como ResNet18, ya entrenado en ImageNet. Sus primeras capas detectan rasgos universales que sirven para casi cualquier imagen. Congelamos esas capas y solo entrenamos una cabeza nueva para nuestras clases. Así obtenemos buena precisión con pocos datos y poco cómputo.
¿Para qué sirve?
Para clasificar imágenes, detectar objetos, control de calidad visual, diagnóstico médico. El transfer learning es lo que hace que esto sea viable sin un centro de datos: aprovechas el aprendizaje de otro.
La pieza del proyecto que construimos aquí
Aquí ofrecemos la rama de visión del proyecto final: un clasificador de imágenes por transfer learning. Es una alternativa al camino tabular; el patrón de entrenamiento es idéntico (forward, loss, backward, step), solo cambian los datos y el modelo. Usaremos un dataset pequeño (CIFAR-10) para que corra en cualquier máquina.
Paso 1: cargar imágenes con torchvision
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
# Las medias/desv. estándar con que se entrenó ResNet en ImageNet
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
transform = T.Compose([
T.Resize(224), # ResNet espera 224x224
T.ToTensor(),
normalize,
])
train_set = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform)
train_dl = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_dl = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)
Si no normalizas las imágenes con las mismas medias y desviaciones de ImageNet, el modelo preentrenado recibe valores en un rango que nunca vio y rinde mucho peor. La normalización no es opcional en transfer learning.
Paso 2: cargar ResNet18 y congelar el cuerpo
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
modelo = resnet18(weights=ResNet18_Weights.DEFAULT)
# Congelamos todas las capas preentrenadas
for p in modelo.parameters():
p.requires_grad = False
# Reemplazamos la cabeza por una nueva (CIFAR-10 = 10 clases)
n_clases = 10
modelo.fc = nn.Linear(modelo.fc.in_features, n_clases)
# La nueva fc ya tiene requires_grad=True por defecto
Congelar (requires_grad = False) significa que esas capas no se actualizan: conservan lo aprendido en ImageNet. Solo entrenamos la fc nueva, que es pequeña y rápida.
Paso 3: entrenar solo la cabeza
device = "cuda" if torch.cuda.is_available() else "cpu"
modelo = modelo.to(device)
loss_fn = nn.CrossEntropyLoss()
# Solo los parámetros que SÍ requieren gradiente
optimizer = torch.optim.Adam(
[p for p in modelo.parameters() if p.requires_grad], lr=1e-3)
for epoch in range(2): # pocas épocas bastan con transfer learning
modelo.train()
for xb, yb in train_dl:
xb, yb = xb.to(device), yb.to(device)
optimizer.zero_grad()
logits = modelo(xb)
loss = loss_fn(logits, yb)
loss.backward()
optimizer.step()
# Evaluación
modelo.eval()
correct = total = 0
with torch.no_grad():
for xb, yb in test_dl:
xb, yb = xb.to(device), yb.to(device)
pred = modelo(xb).argmax(dim=1)
correct += (pred == yb).sum().item()
total += yb.size(0)
print(f"epoch {epoch} test_acc {correct/total:.3f}")
Para multiclase usamos CrossEntropyLoss, que espera logits y etiquetas enteras (no one-hot). La predicción es el índice del logit más alto: argmax(dim=1).
Tech English: convolution (convolución), pooling, feature map (mapa de rasgos), pretrained (preentrenado), freeze (congelar), fine-tuning (ajuste fino).
Ejercicios
- Lee el efecto de congelar. Entrena el modelo (a) con el cuerpo congelado y (b) descongelando todo (
requires_grad = Truepara todos los parámetros) con un learning rate pequeño,1e-4. Compara accuracy y tiempo por época. ¿Compensó descongelar, dado el coste extra? Interpreta el resultado. - Inspecciona los errores. Toma un lote del test, predice y encuentra 5 imágenes mal clasificadas. Imprime la clase real y la predicha. ¿Los errores tienen sentido (por ejemplo, gato confundido con perro) o son absurdos? ¿Qué te dice eso sobre lo que el modelo "entiende"?