A simple toolbox (script) for visualizing transformer attention
This is a super simple toolbox (script) for visualizing transformer attention maps ✌
1. How do you prepare your attention matrix?
Just convert it to a NumPy array like this 👇
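# build a list of attention matrices shaped like PyTorch multi-head output
# assumption: make_attention_map_mh is provided by transformer_attention_visualization
import numpy as np
from transformer_attention_visualization import *

token_num = 6
case_num = 3
layer_num = 2
head_num = 4
# one array per layer: 3 cases x 4 heads x 6 x 6 tokens (identical for every case)
attention_map_mhml = [np.stack([make_attention_map_mh(head_num, token_num)] * case_num, axis=0) for _ in range(layer_num)]
for attention_map in attention_map_mhml:
    print(attention_map.shape)
# >>> (3, 4, 6, 6)
# >>> (3, 4, 6, 6)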
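The helper `make_attention_map_mh` is not defined in this section. The sketch below is a hypothetical, NumPy-only stand-in (the function body and its `seed` parameter are assumptions, not the toolbox's code) that produces one `(head_num, token_num, token_num)` array per call, with each row softmax-normalized so it sums to 1 like real attention weights. Stacking it per case and per layer reproduces the `(case_num, head_num, token_num, token_num)` layout printed above.

```python
import numpy as np


def make_attention_map_mh(head_num, token_num, seed=None):
    """Hypothetical stand-in: random multi-head attention of shape
    (head_num, token_num, token_num); every row sums to 1."""
    rng = np.random.default_rng(seed)
    logits = rng.standard_normal((head_num, token_num, token_num))
    # numerically stable softmax over the last (key) axis
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


token_num, case_num, layer_num, head_num = 6, 3, 2, 4

# one array per layer, each shaped (case_num, head_num, token_num, token_num);
# the same map is repeated for every case, and each layer uses its own seed
attention_map_mhml = [
    np.stack([make_attention_map_mh(head_num, token_num, seed=layer_idx)] * case_num, axis=0)
    for layer_idx in range(layer_num)
]

for layer_map in attention_map_mhml:
    print(layer_map.shape)  # (3, 4, 6, 6)
```

With a real model, Hugging Face Transformers returns attentions (when `output_attentions=True`) as a tuple holding one tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)` per layer, so converting each one with `.detach().cpu().numpy()` should yield the same per-layer layout.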
2. Just try the following lines of code 👇
# import the toolbox functions
import numpy as np
from transformer_attention_visualization import *
# build canvas
scale = 3
canvas = np.zeros([120 * scale, 60 * scale], dtype=np.float64)
# build a list of attention matrices shaped like PyTorch multi-head output
token_num =