范文健康探索娱乐情感热点
投稿投诉
热点动态
科技财经
情感日志
励志美文
娱乐时尚
游戏搞笑
探索旅游
历史星座
健康养生
美丽育儿
范文作文
教案论文

使用JAX实现完整的VisionTransformer

  本文将展示如何使用JAX/Flax实现Vision Transformer (ViT),以及如何使用JAX/Flax训练ViT。Vision Transformer
  在实现Vision Transformer时,首先要记住这张图。
  以下是论文描述的ViT执行过程。
  从输入图像中提取补丁图像,并将其转换为平面向量。
  投影到 Transformer Encoder 来处理的维度
  预先添加一个可学习的嵌入([class]标记),并添加一个位置嵌入。
  由 Transformer Encoder 进行编码处理
  使用[class]令牌作为输出,输入到MLP进行分类。细节实现
  下面,我们将使用JAX/Flax创建每个模块。
  1、图像到展平的图像补丁
  下面的代码从输入图像中提取图像补丁。这个过程通过卷积来实现,内核大小为patch_size * patch_size, stride为patch_size * patch_size,以避免重复。class Patches(nn.Module): patch_size: int embed_dim: int def setup(self): self.conv = nn.Conv( features=self.embed_dim, kernel_size=(self.patch_size, self.patch_size), strides=(self.patch_size, self.patch_size), padding="VALID" ) def __call__(self, images): patches = self.conv(images) b, h, w, c = patches.shape patches = jnp.reshape(patches, (b, h*w, c)) return patches
  2和3、对展平补丁块的线性投影/添加[CLS]标记/位置嵌入
  Transformer Encoder 对所有层使用相同的尺寸大小hidden_dim。上面创建的补丁块向量被投影到hidden_dim维度向量上。与BERT一样,有一个CLS令牌被添加到序列的开头,还增加了一个可学习的位置嵌入来保存位置信息。class PatchEncoder(nn.Module): hidden_dim: int @nn.compact def __call__(self, x): assert x.ndim == 3 n, seq_len, _ = x.shape # Hidden dim x = nn.Dense(self.hidden_dim)(x) # Add cls token cls = self.param("cls_token", nn.initializers.zeros, (1, 1, self.hidden_dim)) cls = jnp.tile(cls, (n, 1, 1)) x = jnp.concatenate([cls, x], axis=1) # Add position embedding pos_embed = self.param( "position_embedding",  nn.initializers.normal(stddev=0.02), # From BERT (1, seq_len + 1, self.hidden_dim) ) return x + pos_embed
  4、Transformer encoder
  如上图所示,编码器由多头自注意(MSA)和MLP交替层组成。Norm层 (LN)在MSA和MLP块之前,残差连接在块之后。class TransformerEncoder(nn.Module): embed_dim: int hidden_dim: int n_heads: int drop_p: float mlp_dim: int def setup(self): self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p) self.mlp = MLP(self.mlp_dim, self.drop_p) self.layer_norm = nn.LayerNorm(epsilon=1e-6)  def __call__(self, inputs, train=True): # Attention Block x = self.layer_norm(inputs) x = self.mha(x, train) x = inputs + x # MLP block y = self.layer_norm(x) y = self.mlp(y, train) return x + y
  MLP是一个两层网络。激活函数是GELU。本文将Dropout应用于Dense层之后。class MLP(nn.Module): mlp_dim: int drop_p: float out_dim: Optional[int] = None @nn.compact def __call__(self, inputs, train=True): actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense(features=self.mlp_dim)(inputs) x = nn.gelu(x) x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x) x = nn.Dense(features=actual_out_dim)(x) x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x) return x
  多头自注意(MSA)
  qkv的形式应为[B, N, T, D],如Single Head中计算权重和注意力后,应输出回原维度[B, T, C=N*D]。class MultiHeadSelfAttention(nn.Module): hidden_dim: int n_heads: int drop_p: float def setup(self): self.q_net = nn.Dense(self.hidden_dim) self.k_net = nn.Dense(self.hidden_dim) self.v_net = nn.Dense(self.hidden_dim) self.proj_net = nn.Dense(self.hidden_dim) self.att_drop = nn.Dropout(self.drop_p) self.proj_drop = nn.Dropout(self.drop_p) def __call__(self, x, train=True): B, T, C = x.shape # batch_size, seq_length, hidden_dim N, D = self.n_heads, C // self.n_heads # num_heads, head_dim q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # (B, N, T, D) k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # weights (B, N, T, T) weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D) normalized_weights = nn.softmax(weights, axis=-1) # attention (B, N, T, D) attention = jnp.matmul(normalized_weights, v) attention = self.att_drop(attention, deterministic=not train) # gather heads attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N*D) # project out = self.proj_drop(self.proj_net(attention), deterministic=not train) return out
  5、使用CLS嵌入进行分类
  最后MLP头(分类头)。class ViT(nn.Module): patch_size: int embed_dim: int hidden_dim: int n_heads: int drop_p: float num_layers: int mlp_dim: int num_classes: int def setup(self): self.patch_extracter = Patches(self.patch_size, self.embed_dim) self.patch_encoder = PatchEncoder(self.hidden_dim) self.dropout = nn.Dropout(self.drop_p) self.transformer_encoder = TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim) self.cls_head = nn.Dense(features=self.num_classes) def __call__(self, x, train=True): x = self.patch_extracter(x) x = self.patch_encoder(x) x = self.dropout(x, deterministic=not train) for i in range(self.num_layers): x = self.transformer_encoder(x, train) # MLP head x = x[:, 0] # [CLS] token x = self.cls_head(x) return x使用JAX/Flax训练
  现在已经创建了模型,下面就是使用JAX/Flax来训练。
  数据集
  这里我们直接使用 torchvision的CIFAR10.
  首先是一些工具函数def image_to_numpy(img): img = np.array(img, dtype=np.float32) img = (img / 255. - DATA_MEANS) / DATA_STD return img def numpy_collate(batch): if isinstance(batch[0], np.ndarray): return np.stack(batch) elif isinstance(batch[0], (tuple, list)): transposed = zip(*batch) return [numpy_collate(samples) for samples in transposed] else: return np.array(batch)
  然后是训练和测试的dataloadertest_transform = image_to_numpy train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO), image_to_numpy ]) # Validation set should not use the augmentation. train_dataset = CIFAR10("data", train=True, transform=train_transform, download=True) val_dataset = CIFAR10("data", train=True, transform=test_transform, download=True) train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED)) _, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED)) test_set = CIFAR10("data", train=False, transform=test_transform, download=True) train_loader = torch.utils.data.DataLoader( train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, ) val_loader = torch.utils.data.DataLoader( val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, ) test_loader = torch.utils.data.DataLoader( test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, )
  初始化模型
  初始化ViT模型def initialize_model( seed=42, patch_size=16, embed_dim=192, hidden_dim=192, n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10 ): main_rng = jax.random.PRNGKey(seed) x = jnp.ones(shape=(5, 32, 32, 3)) # ViT model = ViT( patch_size=patch_size, embed_dim=embed_dim, hidden_dim=hidden_dim, n_heads=n_heads, drop_p=drop_p, num_layers=num_layers, mlp_dim=mlp_dim, num_classes=num_classes ) main_rng, init_rng, drop_rng = random.split(main_rng, 3) params = model.init({"params": init_rng, "dropout": drop_rng}, x, train=True)["params"] return model, params, main_rng vit_model, vit_params, vit_rng = initialize_model()
  创建TrainState
  在Flax中常见的模式是创建管理训练的状态的类,包括轮次、优化器状态和模型参数等等。还可以通过在apply_fn中指定apply_fn来减少学习循环中的函数参数列表,apply_fn对应于模型的前向传播。def create_train_state( model, params, learning_rate ): optimizer = optax.adam(learning_rate) return train_state.TrainState.create( apply_fn=model.apply, tx=optimizer, params=params )  state = create_train_state(vit_model, vit_params, 3e-4)
  循环训练def train_model(train_loader, val_loader, state, rng, num_epochs=100): best_eval = 0.0 for epoch_idx in tqdm(range(1, num_epochs + 1)): state, rng = train_epoch(train_loader, epoch_idx, state, rng) if epoch_idx % 1 == 0: eval_acc = eval_model(val_loader, state, rng) logger.add_scalar("val/acc", eval_acc, global_step=epoch_idx) if eval_acc >= best_eval: best_eval = eval_acc save_model(state, step=epoch_idx) logger.flush() # Evaluate after training test_acc = eval_model(test_loader, state, rng) print(f"test_acc: {test_acc}")  def train_epoch(train_loader, epoch_idx, state, rng): metrics = defaultdict(list) for batch in tqdm(train_loader, desc="Training", leave=False): state, rng, loss, acc = train_step(state, rng, batch) metrics["loss"].append(loss) metrics["acc"].append(acc) for key in metrics.keys(): arg_val = np.stack(jax.device_get(metrics[key])).mean() logger.add_scalar("train/" + key, arg_val, global_step=epoch_idx) print(f"[epoch {epoch_idx}] {key}: {arg_val}") return state, rng
  验证def eval_model(data_loader, state, rng): # Test model on all images of a data loader and return avg loss correct_class, count = 0, 0 for batch in data_loader: rng, acc = eval_step(state, rng, batch) correct_class += acc * batch[0].shape[0] count += batch[0].shape[0] eval_acc = (correct_class / count).item() return eval_acc
  训练步骤
  在train_step中定义损失函数,计算模型参数的梯度,并根据梯度更新参数;在value_and_gradients方法中,计算状态的梯度。在apply_gradients中,更新TrainState。交叉熵损失是通过apply_fn(与model.apply相同)计算logits来计算的,apply_fn是在创建TrainState时指定的。@jax.jit def train_step(state, rng, batch): loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True) # Get loss, gradients for loss, and other outputs of loss function (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params) # Update parameters and batch statistics state = state.apply_gradients(grads=grads) return state, rng, loss, acc
  计算损失def calculate_loss(params, state, rng, batch, train): imgs, labels = batch rng, drop_rng = random.split(rng) logits = state.apply_fn({"params": params}, imgs, train=train, rngs={"dropout": drop_rng}) loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean() acc = (logits.argmax(axis=-1) == labels).mean() return loss, (acc, rng)结果
  训练结果如下所示。在Colab pro的标准GPU上,训练时间约为1.5小时。
  test_acc: 0.7704000473022461
  作者:satojkovic

顾炎武为什么能提出天下兴亡,匹夫有责天下兴亡,匹夫有责,这句爱国名言可谓家喻户晓影响深远,其源于明末清初著名思想家顾炎武的日知录正始。这句话言简意赅地指出个人应对国家兴亡承担责任,对激发中华儿女的爱国精神和责任意识发房地产曙光已现,关键在于软着陆展望2023告别三高,进入分化时代。文中国企业家记者李艳艳编辑周春林头图摄影曾靖郁亮说,曙光已现。2022年12月16日的万科临时股东大会上,万科董事局主席郁亮现身。他谈到行业变化,称时隔20警惕大面积的复阳重现现象再次发生虽然此轮疫情经过近一个月的大面积扩散,大部分市民都已经经历了第一次阳性过程,保守估计已有90以上人员都已杨过了。可以说这次几乎无一幸免,都未能逃过此劫,这次传染速度如此之快,是大家海软近百年办校擦亮党建品牌全校大思政格局逐渐形成新海南客户端南海网南国都市报1月6日消息(记者苏桂除)1月5日,海南软件职业技术学院(以下简称海软)党委副书记刘明鹏教授做客新海南客户端校长在线访谈,就建校百年历史的海软在教育教学百年一遇!美国国会瘫痪了作者丨吴斌编辑丨和佳图源丨新华社2023年刚开始,美国民众就无奈地迎来一场众议院议长难产的百年一遇闹剧。美东时间1月4日,深陷领导权乱斗的美国众议院又三次拒绝了共和党人麦卡锡(Ke母乳喂养好处多,更利于神经发育,宝妈要知道镜子随着宝宝的出生,孩子带给我们的快乐是语无伦次的。可是有的宝妈由于各种不同原因在母乳还是奶粉喂养上纠结。其实,母乳的优势是毋庸置疑的,他们不但能够给宝宝带来更强的免疫力更充沛的营7岁孩子成语量特别大,只因妈妈教得好,学成语真不难!小孩子学习成语,自有价值和意义。成语虽然只是短短的几个字,却含有一个故事,一个道理,一段引申含义。小孩子学会成语,能提升语言表达力感染力,提升文化修养。同时,还能丰富写作素材词汇,童书,知道怎么办2022年读过的童书,特别喜欢的就有卧底机器人,我叫小朵,我是如假包换100的人类。你该不会信了吧?其实我是一个高科技机器人。你千万不要泄露这个秘密,轻而易举,激起我的期待,一页一儿戲儿戏,在大人语境中视为办事轻率不认真不负责任。但在儿童世界,那是全身心很认真快乐无比的事情。从儿戏里,孩子们获得了怎样的快乐,学到了怎样的处人之道,可能并像大人想像的那样简单。儿戏一起读绘本DiaryofaWorm引言此文旨在协助家长在家陪孩子一起读英文绘本。会陆续上传50本初级绘本的图文讲解,收集在一起读绘本的合集里,便于查阅。DiaryofaWorm蚯蚓的日记Myreportcard。我孕妇如果长期咳嗽不止,原来还能吃这种水果!最近关于如何止咳的话题真的频上热搜,特别是孕妇咳嗽剧烈就更尴尬小编作为一个孕妇在阳了之后也是感同深受,连续多日的咳嗽打喷嚏导致漏尿肚子发紧,哪怕是在自己老公面前都觉得特别尴尬甚至还
20几岁真的很痛苦,急于求成且一事无成hi,我是不晚,点击上方关注我,每天为你分享成长干货我今年20岁了,还算年轻。今天这篇是我写的第20篇文章了,也算在写作这条路上坚持了一个月吧。文章不公开的时候写的很开心,公开发表分享上个月选股回顾12月12日夜间选的股,有头条截图为证,聊一聊启发一下共勉那次总共选了四只股,新华百货,百合花,伊力特,云南铜业。新华百货这个就牛叉了,连续四板涨停,我居然忘记吹牛皮了呲牙。选股后2022IPO最赚钱中介机构排行榜本文来源时代商学院作者彭晨雨来源时代商学院作者彭晨雨编辑郑少娜2022年,在注册制的深入推进下,哪家中介机构啃下了A股IPO的最大一块蛋糕?据同花顺iFinD数据,2022年,全国被仙女说面膜上了一课面膜只能补水肯定不止我这么认为,昨晚我和他们争论面膜到底能不能拯救熬夜脸,我说绝不可能!有人就拿了张仙女说面膜说要好好给我上一课!她让我记住我那张熬夜脸。敷的时候感觉跟其他面膜没什终于明白高启强为啥对大嫂念念不忘了,看完这个你就知道了这组片子出自2022年时装男士九月刊大嫂原来早就有大嫂风范,可塑性极强且肢体表现力一绝,镜头下尽显酷飒霸气无论是性感穿搭还是复古西装造型统统都能hold得住,尽显高级质感别说高启强回家不知何时,我们把自己划归为城里人,于是,回老家就变成了一个温暖的话题。回老家有很多种感情理由,但有一种情怀不变,就是想回去小时候的家,继续在那个家里被称呼为儿子或女儿,闻着孩童时的小王子我看了四十四次日落一场旅行,是七个星球的流转。小王子犹如透亮的镜子,照出了荒唐的成人世界。字里行间都能击中人心,可能正在指向我们,或者是指向即将成为这样大人的我们。这本书是作者献给从前当过孩子的那个短短的五句话,说透了人性!(精辟)作者拾壹言01序言人活一世,最难的不是赚钱,而是学会说透人性。如果你不会说话,那么可以先去学会闭嘴。因为,说出一句话,要花上很多力气,即使是很小的一件事,也需要花费许多时间来完成。清荷札记阅尽千帆,初心依旧阅尽千帆,初心依旧作者清荷札记时光不语,亲情恒久,犹如人生中意味深长的祈念,让爱静静地待在我的平淡日常中,人生一辈子,都是因果。生活的意味,仿佛能在庸常的光尘中,早已经成为我心中期这世界很喧嚣,做你自己就好云朵也是一种风景,眼泪也是一种品尝。风吹过,浪埋过,岩石依旧在泰戈尔1。总是假装很乖,可我知道,有时候,你也会累2。我明白,颤抖着的指尖是你满腹的心事3。笑一个吧!就当是对自己的奖当你25岁,身边的人都结婚了,你会恐慌吗?头号解忧馆很喜欢这句话有人22岁就开始相夫教子,有人32岁了还在追求梦想,有人为了车房拼了一辈子,而有一些人,买一辆摩托车就走遍大好河山。生活是自己的,千万别为难自己,有什么样的能