AIモデルは高精度化する一方で、モデルサイズや計算コストも増大しています。エッジ端末やリアルタイム処理では**「軽くて速い」**が重要。そこで今日は、PyTorchとTransformerを使って軽量化するコツを紹介します。
軽量化の主なアプローチ
- パッチサイズの調整(画像)
- 小パッチ(8×8)は精度◎だが計算量増
- 大パッチ(32×32)は軽量だが細部が粗くなる
- タスクに応じてバランスを取る
- 特徴次元の削減(音声)
- Melスペクトログラムは 80〜128バンドで十分なケース多数
- Conv層で圧縮し、Transformerに渡す
- 蒸留(Knowledge Distillation)
- 大モデルで学習 → 小モデルに知識を継承
- 混合精度(AMP)推論
torch.cuda.amp
で半精度化しつつ精度を保つ
簡易コード例(画像+音声)
import torch
from torchvision.models import vit_b_16
vit = vit_b_16(weights=None) # ViT backbone
vit.heads = torch.nn.Identity()
# 画像: 16×16パッチ, 音声: 80band
img_emb = vit(images) # (B, 768)
aud_emb = audio_encoder(audio) # (B, 256)
h = torch.cat([img_emb, aud_emb], dim=-1) # (B, 1024)
h = torch.nn.Linear(1024, 256)(h)
補足:下の実行サンプルは今回のデータ特性(黒背景+微小点)に合わせたTransformer不使用の安全モードです。ViT を使う場合は上の「ViT 版に切り替える最小例」を参考に差し替えてください。
たとえ話
Transformerを軽くするのは、登山で荷物を減らす工夫に似ています。必要な道具は残しつつ、無駄な重りをそぎ落とす。これにより、長時間の稼働や省エネが可能になります。
運用Tips
- モデル軽量化は精度低下とのトレードオフ。必ず評価セットで比較
- パッチやバンド数の削減は段階的に試す
- 異常検知はMahalanobis距離やkNNと組み合わせて精度補完
おまけ:コピペで試せる最小スクリプト(画像のみ)
内容:パッチ平均 × 対角マハラノビスで異常スコア化(Transformer 不使用)。学習なし・CPUでOK。※今回の想定データ(真っ黒な背景にごく小さな色点)では、ViT より ローレベル特徴のほうが強く、軽くて堅牢です。
フォルダ構成
data/
normal_images/ # 正常だけ数十枚〜
test_images/ # 判定したい画像(正常/異常混在OK)
サンプルの位置づけ
- 記事前半では ViT(Transformer) を使った軽量化や融合の設計指針を解説しています。
- 末尾のサンプルは「黒背景+微小な点/線」を想定した ローレベル重視の簡易手法。小さな画素変化を確実に拾うため、パッチ平均+対角マハラノビスを採用しています。
- 課題により最適な特徴量は変わります。複雑背景やテクスチャ差を扱う場合は ViT/PaDiM/PatchCore など パッチ単位の距離を用いる構成が有効です。
使い分けガイド
- ViT 向き:テクスチャ/形状/文脈の違い、自然画像や複雑背景、部品の向きや模様の乱れ。→ Self-Supervised(MAE/SimCLR)+ マハラノビス、もしくは PaDiM/PatchCore(パッチごとに距離)を検討。
- パッチ平均ベース向き:均一背景に孤立した画素・細線・ドット欠陥など“点”を拾いたいケース。高速・堅牢で調整が容易。
- 中間案:ViT を使いつつ パッチ単位の距離で判定(PaDiM系)→ 微小欠陥にも強く、本文の Transformer 方針とも整合。
ViT 版に切り替える最小例(参考)
import timm, torch
model = timm.create_model('deit_tiny_patch16_224', pretrained=True, num_classes=0).eval()
# img -> z = model(preprocess(img)) # (D,)
# 正常zから μ, Σ を推定 → マハラノビス距離で判定(本文と同じ手順)
Colab 用:画像アップロード補助セル
Python(そのままコピペで実行)
下の3セルを順に実行してください。
セル1:関連ライブラリのインストールと定数の設定
#@title Setup (folders & params)
from pathlib import Path
from PIL import Image
import torch, torch.nn.functional as F
import torchvision.transforms.functional as TF
ROOT = Path('data')
NORMAL_DIR = ROOT/'normal_images'
TEST_DIR = ROOT/'test_images'
NORMAL_DIR.mkdir(parents=True, exist_ok=True)
TEST_DIR.mkdir(parents=True, exist_ok=True)
# ---- 安全モードのパラメータ ----
MAX_SIDE = 256 # 最長辺を縮小(巨大画像でのクラッシュ回避)
PATCH = 8 # 4: 点に超敏感 / 8: 標準 / 12-16: 鈍感で誤検知減
EPS = 1e-6 # 数値安定
セル2:正常系画像のアップロード(10枚以上推奨)
#@title Upload NORMAL images
from google.colab import files
print("Upload NORMAL images (multi-select OK)")
uploaded = files.upload()
for name, data in uploaded.items():
with open(NORMAL_DIR/name, "wb") as f:
f.write(data)
print(f"Saved {len(uploaded)} files -> {NORMAL_DIR}")
セル3:異常系、正常系が混在しているテストしたい画像をアップロード
from google.colab import files
print('Upload TEST images (optional). ない場合は Cancel upload を押してください。')
uploaded = files.upload() # キャンセル可能
for name, data in uploaded.items():
with open('data/test_images/'+name, 'wb') as f:
f.write(data)
print(f'Saved {len(uploaded)} files -> data/test_images')
セル4:スコア計算
#@title Detect anomalies (safe-mode, diagonal Mahalanobis)
import torch
from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as TF
def load_gray01(path):
img = Image.open(path).convert('L')
# 最長辺を縮小(アスペクト比維持)
img.thumbnail((MAX_SIDE, MAX_SIDE), Image.BICUBIC)
t = TF.to_tensor(img) # (1,H,W) in [0,1]
# H,W を PATCH の倍数に切り下げ
_, H, W = t.shape
H2 = (H//PATCH)*PATCH; W2 = (W//PATCH)*PATCH
return t[:, :H2, :W2]
def patch_mean_vec(gray): # gray: (1,H,W)
x = gray.unsqueeze(0) # (1,1,H,W)
cols = F.unfold(x, kernel_size=PATCH, stride=PATCH) # (1, p*p, NumP)
pm = cols.mean(dim=1).squeeze(0) # (NumP,)
return pm
def folder_vecs(folder):
files = [p for p in Path(folder).glob('*') if p.is_file()]
Z, names = [], []
for p in sorted(files):
try:
g = load_gray01(str(p))
z = patch_mean_vec(g).float()
Z.append(z); names.append(str(p))
except Exception as e:
print(f"[skip] {p} -> {e}")
if not Z:
raise RuntimeError(f"No readable images in {folder}")
# サイズ違い対策:最小次元に揃える
D = min(z.numel() for z in Z)
Z = torch.stack([z[:D] for z in Z]) # (N,D)
return Z, names
# 1) 正常セットから μ と 分散(対角)を推定
Z_norm, _ = folder_vecs(NORMAL_DIR)
mu = Z_norm.mean(0)
var = Z_norm.var(0, unbiased=True) + EPS # (D,)
def diag_maha(Z):
d = Z - mu
return torch.sqrt((d*d/var).sum(-1))
# 2) 閾値(正常の99%分位)。全て0のケースに備えてフォールバック
score_norm = diag_maha(Z_norm)
thr = torch.quantile(score_norm, 0.99).item()
Z_test, test_names = folder_vecs(TEST_DIR)
score_test = diag_maha(Z_test)
if thr <= 0.0:
pos = score_test[score_test > 0]
thr = float(pos.min()) if pos.numel() > 0 else 1e-6 # “ゼロより少し上”に
# 3) 結果表示
print(f"threshold (99th of normal or fallback) = {thr:.8f}")
k = min(10, len(score_test))
val, idx = torch.topk(score_test, k)
anom = (score_test > thr)
print(f"Predicted ANOMALY: {int(anom.sum())} / {len(anom)}")
for s, i in zip(val.tolist(), idx.tolist()):
label = 'ANOMALY' if s > thr else 'normal' # 同点は normal
print(f"{s:.8f}\t{test_names[i]}\t{label}")
# 参考:分布のレンジ
print("normal score range:", float(score_norm.min()), "->", float(score_norm.max()))
print("test score range:", float(score_test.min()), "->", float(score_test.max()))
サンプルコードを実行した結果
使用する画像はまず、正常系は「ペイント」アプリでキャンバスを真っ黒に塗りつぶしたものを使用しました。

異常系は真っ黒なキャンバスに白で線を引くなどして準備

結果は用意した4枚の異常系のうち3枚を異常系として検知できました。

検知できなかった異常系は白い点を3pixelつけただけのもの。
これは正常系と判断して差し支えないのかも。実務においてはどこまで厳密にするかは調整次第ということですね。
