迎接多模態(tài)AI (Embracing Multimodal AI): How to Train Your Own CLIP
By the 神櫻AI team, under the guidance of Prof. 高煥堂
*** This article is excerpted from books by 高煥堂 ***

OpenAI found an abundant source of supervision already sitting on the internet: images paired with the texts that describe them. During pre-training, CLIP is shown an image and must predict which one text in our dataset is actually paired with it. Once trained, given an input sentence it can retrieve the most relevant images corresponding to that sentence; and, trained on such pairs, it can also be used directly as a (zero-shot) classifier. In short, CLIP learns <visual> concepts effectively from <natural language> supervision. That also means you can train a CLIP of your own, based on your own company's image library.

Basic architecture: an encoder extracts the features of each image (or each text) and maps them to a point in a shared latent space. At prediction time, a new input is encoded into a new point in that space, and a few matrix operations find the points lying near the <new point>; those neighbours are the predicted matches.

Take a CLIP for Chinese medicinal herbs (中藥材) as an example. A text encoder helps extract the features of each caption and maps each text (at first, arbitrarily) to a point in the latent space; an image encoder does the same for each photo. Training then begins, pulling each matching image/text pair of points together. Once training completes, the acquired knowledge is expressed in the model's parameters (its weights and biases, W&B).

Two directions of prediction follow. Prediction (1): from an image, find the text. Prediction (2): from a text, find the image. During training we can also observe how the similarities change, displaying the latest similarity values as an array; in the book's figure, the similarity with <提詞-2> (prompt 2) rises as training proceeds.

Step-1: check the development environment (and the packages to be imported)

Parameter definitions. Several values are cut off in the source; where a value is not visible, the defaults of the original tutorial this config derives from are used and marked as such:

## myConfig.py
import torch

image_path = "C:/Moein/AI/Datasets/Flicker-8k/Images"
captions_path = "C:/Moein/AI/Datasets/Flicker-8k"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

text_encoder_model = "distilbert-base-uncased"
text_embedding = 768
text_tokenizer = "distilbert-base-uncased"

pretrained = False   # for both image encoder and text encoder
trainable = False    # for both image encoder and text encoder

# for projection head; used for both image and text encoders
num_projection_layers = 1   # value not visible in the excerpt; tutorial default
projection_dim = 256        # the book changes this value; 256 is the original
dropout = 0.1               # not visible in the excerpt; tutorial default
temperature = 1.0           # not visible in the excerpt; tutorial default
image_embedding = 2048      # matches the 2048 ResNet50 features used below
lr = 1e-3                   # not visible in the excerpt; tutorial default
weight_decay = 1e-3         # not visible in the excerpt; tutorial default

The model definition follows. The listings below access the config as CFG, so an import such as `import myConfig as CFG` is assumed (the import lines are partially cut off in the source):

## clip_ex_001.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from transformers import DistilBertModel, DistilBertConfig
import myConfig as CFG   # assumed import; see note above

# in this myConfig the projection_dim has been changed (it was originally 256)

class ImageEncoder(nn.Module):
    """Encode images to a fixed size vector"""
    def __init__(self):
        super().__init__()
        # load the pretrained ResNet50 model
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # transfer learning: no gradients needed (the pretrained weights stay fixed)
        for param in self.model.parameters():
            param.requires_grad = False
        # add our own head; ResNet50's output becomes this head's input
        fc_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(fc_features, CFG.image_embedding),
            # nn.Softmax(dim=1)   # a classifier leftover; CLIP needs raw embeddings
        )

    def forward(self, x):
        return self.model(x)

class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model,
                 pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())
        for p in self.model.parameters():
            p.requires_grad = trainable
        # we are using the CLS token hidden embedding as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim,
                 projection_dim=CFG.projection_dim, dropout=CFG.dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # the body is partially cut off in the source; reconstructed as
        # linear -> GELU -> linear -> dropout -> residual -> layer norm
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        return self.layer_norm(x)

class CLIPModel(nn.Module):
    def __init__(self, temperature=CFG.temperature,
                 image_embedding=CFG.image_embedding,
                 text_embedding=CFG.text_embedding):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # contrastive loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0   # shape: (batch_size)
        return loss.mean()

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    return loss.mean()

model = CLIPModel()
# END
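Before moving on, Step-1's goal (confirming the environment and imports) can be checked by pushing one dummy batch through the model. This snippet is not from the book, just a minimal sketch assuming the myConfig and clip_ex_001.py reconstructed above:

# sketch: smoke-test the freshly built CLIPModel with random inputs
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
texts = tok(["a test caption", "another caption"], padding=True, return_tensors="pt")
dummy = {
    "image": torch.randn(2, 3, 224, 224),        # two fake RGB images
    "input_ids": texts["input_ids"],
    "attention_mask": texts["attention_mask"],
}
loss = model(dummy)        # forward() returns the mean contrastive loss
print(loss.item())         # a single finite float confirms the whole pipeline runs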
Step-2: prepare the training data

This CLIP builds on well-known pretrained models: ResNet50 for the images and DistilBERT for the texts. Next, we tweak the program a little. The code is as follows:

## clip_ex_002.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig
from transformers import AutoTokenizer
import myConfig as CFG   # assumed import, as before

# ImageEncoder, TextEncoder, ProjectionHead, CLIPModel and cross_entropy are
# identical to clip_ex_001.py. The one change sits in ImageEncoder's new head,
# which extracts the 2048 ResNet features and maps them to the image embedding,
# with the leftover Softmax now commented out:
#     self.model.fc = nn.Sequential(
#         nn.Linear(fc_features, CFG.image_embedding),   # extract the 2048 features
#         # nn.Softmax(dim=1)
#     )

model = CLIPModel()

# the seven herb photos sit in four class folders under root_path; the first
# im_list line is cut off in the source and reconstructed by analogy
root_path = 'c:/oopc/m_clip_data/train/'   # 'trai…' is truncated in the source; 'train/' assumed
im_list  = [os.path.join(root_path, 'class_01', i) for i in os.listdir(root_path + 'class_01')]
im_list += [os.path.join(root_path, 'class_02', i) for i in os.listdir(root_path + 'class_02')]
im_list += [os.path.join(root_path, 'class_03', i) for i in os.listdir(root_path + 'class_03')]
im_list += [os.path.join(root_path, 'class_04', i) for i in os.listdir(root_path + 'class_04')]

# convert the images to tensors
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
data_set = ImageFolder(root_path, transform=transform)
train_loader = DataLoader(data_set, batch_size=7, shuffle=False)

sequences = [   # the list header is cut off in the source; 'sequences = [' assumed
    "一棵人蔘",
    "一條韓國(guó)人蔘",
    "一盤(pán)枸杞",
    "這是銀川枸杞",
    "這是兩只靈芝",
    "許多靈芝",
    "玫瑰香菇",
]
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")

for batch_idx, (data, target) in enumerate(train_loader):
    batch["image"] = data   # the loop body is cut off in the source; reconstructed
# END

Note that batch_size=7 matches the seven captions, and shuffle=False keeps ImageFolder's sorted order, so the i-th image in the batch is treated as the positive pair of the i-th caption; the contrastive loss depends on exactly this alignment.
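It also helps to peek at what the tokenizer actually produced. The following check is not from the book, just a small sketch assuming the batch and tokenizer defined in clip_ex_002.py; bert-base-uncased splits Chinese text character by character, which is adequate for these short captions:

print(batch["input_ids"].shape)        # (7, seq_len): seven captions, padded to equal length
print(batch["attention_mask"].shape)   # same shape; 0 marks the padding positions
# per-character tokens, e.g. ['[CLS]', '一', '棵', ...]; rare characters may become [UNK]
print(tokenizer.convert_ids_to_tokens(batch["input_ids"][0]))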
Step-3: verify the <training step> code; the goal is to train for just one round and print the loss once.

## clip_ex_003.py
# Imports, model classes, cross_entropy and the data pipeline are identical to
# clip_ex_002.py; what is new is the optimizer and the training step itself.

model = CLIPModel()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=CFG.lr, weight_decay=CFG.weight_decay)

for batch_idx, (data, target) in enumerate(train_loader):
    batch["image"] = data
    optimizer.zero_grad()
    loss = model(batch)   # the forward pass returns the mean contrastive loss
    loss.backward()       # the call between zero_grad() and step() is cut off in the source; reconstructed
    optimizer.step()
    print('loss =', loss.item())
# END
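Because both pretrained encoders are frozen, the optimizer effectively updates only the new ResNet head and the two projection heads. A quick check, not from the book, assuming the model built above:

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f'trainable: {n_trainable:,} of {n_total:,} parameters')
# only the ImageEncoder's fc head and the two ProjectionHeads should be listed:
print([n for n, p in model.named_parameters() if p.requires_grad])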
Step-4: train for several epochs and check whether the loss value keeps falling.

## clip_ex_004_train.py
# Identical to clip_ex_003.py (same imports, model classes, optimizer and data
# pipeline); the only change is that the training step is wrapped in an epoch
# loop. The loop header is cut off in the source and is reconstructed here;
# the epoch count is not visible in the excerpt.

epochs = 100   # hypothetical value
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        batch["image"] = data
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
    print('epoch', epoch, 'loss =', loss.item())
# END
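The overview earlier promised watching the similarities change during training. A hedged sketch of that observation, run inside the epoch loop after optimizer.step(); it recomputes the embeddings explicitly, anticipating the externalized forward() of the next step:

with torch.no_grad():
    img_e = model.image_projection(model.image_encoder(batch["image"]))
    txt_e = model.text_projection(model.text_encoder(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]))
    simi = F.softmax(F.normalize(txt_e, dim=-1) @ F.normalize(img_e, dim=-1).T, dim=-1)
    # simi[1] is the row for the second caption (<提詞-2> in the book's figure);
    # over the epochs, the column of its own image should rise
    print('epoch', epoch, '提詞-2 row:', simi[1])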
Next, the loss computation, which so far has lived inside CLIPModel's forward() definition, is moved outside the class definition. forward() now simply returns the two embedding batches, and the training loop computes the loss itself.

## clip_ex_005_loss.py
# Imports, ImageEncoder, TextEncoder and ProjectionHead are unchanged from the
# previous listings; CLIPModel.forward() is shortened:

class CLIPModel(nn.Module):
    # __init__ is unchanged from clip_ex_001.py

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)
        return image_embeddings, text_embeddings

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    return loss.mean()

# ================================================
model = CLIPModel()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=CFG.lr, weight_decay=CFG.weight_decay)
# ================================================
# root_path, im_list, transform, data_set, train_loader, tokenizer, sequences
# and batch are built exactly as in clip_ex_002.py.

for batch_idx, (data, target) in enumerate(train_loader):
    batch["image"] = data
    # ====================================================
    image_embeddings, text_embeddings = model(batch)
    temperature = CFG.temperature
    logits = (text_embeddings @ image_embeddings.T) / temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax(
        (images_similarity + texts_similarity) / 2 * temperature, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss_v = (images_loss + texts_loss) / 2.0   # shape: (batch_size)
    optimizer.zero_grad()
    loss_v.mean().backward()   # backward() is cut off in the source; reconstructed
    optimizer.step()
# END
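To make the externalized loss concrete, here is a toy computation with a batch of two (hypothetical numbers, not from the book). Matching pairs dominate the diagonal of both the logits and the soft targets:

import torch
import torch.nn.functional as F

text_emb  = F.normalize(torch.tensor([[1.0, 0.0], [0.0, 1.0]]), dim=-1)
image_emb = F.normalize(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), dim=-1)
logits = text_emb @ image_emb.T   # rows = texts, columns = images
targets = F.softmax((image_emb @ image_emb.T + text_emb @ text_emb.T) / 2, dim=-1)
print(logits)    # the diagonal entries are the largest: each text matches "its" image
print(targets)   # the soft targets are likewise diagonal-heavy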
Training can be paused, and before pausing, the current parameter values can be exported. At the end of this stage the model's weights and biases (W&B) are saved to a *.pt file; the next program then continues training on top of them. This version also drops ImageFolder in favour of a hand-rolled Dataset built from a recursive glob.

## clip_ex_006_pt_im.py
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertConfig
from transformers import AutoTokenizer
from PIL import Image        # needed to open the globbed files; cut off in the source
import myConfig as CFG       # assumed import, as before
# The model classes now live in their own module, imported here as CL
# (the exact import line is cut off in the source; e.g. `import myCLIP as CL`).

def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == 'none':
        return loss
    return loss.mean()

# ================================================
model = CL.CLIPModel()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=CFG.lr, weight_decay=CFG.weight_decay)
# ================================================
# convert the images to tensors
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
img_tensor_list = []
filenames = glob.glob(path + '**/*.jpg', recursive=True)   # `path` is defined off-page in the source
for fna in filenames:
    image = Image.open(fna).convert('RGB')   # this line is cut off in the source; reconstructed
    img_tensor = transform(image)
    img_tensor_list.append(img_tensor)
print('\n', len(img_tensor_list))

class myDataset(Dataset):
    def __init__(self):
        super().__init__()
    def __getitem__(self, idx):
        return img_tensor_list[idx]   # the body is cut off in the source; reconstructed
    def __len__(self):
        return len(img_tensor_list)

mydset = myDataset()
# put the data_set into a DataLoader
train_loader = DataLoader(mydset, batch_size=7, shuffle=False)

checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
sequences = [   # the same seven captions as before
    "一棵人蔘", "一條韓國(guó)人蔘", "一盤(pán)枸杞", "這是銀川枸杞",
    "這是兩只靈芝", "許多靈芝", "玫瑰香菇",
]
batch = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")

print('訓(xùn)練', epochs, '回合...')   # `epochs` is defined off-page in the source
for epoch in range(epochs):        # the epoch-loop header is cut off; reconstructed
    for batch_idx, data in enumerate(train_loader):
        batch["image"] = data
        # ====================================================
        image_embeddings, text_embeddings = model(batch)
        temperature = CFG.temperature
        logits = (text_embeddings @ image_embeddings.T) / temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * temperature, dim=-1)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss_v = (images_loss + texts_loss) / 2.0   # shape: (batch_size)
        optimizer.zero_grad()
        loss_v.mean().backward()
        optimizer.step()

print('model saved')
torch.save(model, FILE)   # `FILE` is defined off-page; per the text below it points under c:/ox/
# END
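torch.save(model, FILE) pickles the entire module, which ties the saved file to the exact class definitions. A common alternative, sketched here under the same assumed FILE path, exports only the weights:

FILE = 'c:/ox/clip_111.pt'                  # path taken from the surrounding text
torch.save(model.state_dict(), FILE)        # export only the parameters (W&B)
# in the follow-up program:
model = CL.CLIPModel()                      # rebuild the architecture first
model.load_state_dict(torch.load(FILE))    # then pour the saved weights back in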
Each time training pauses, the parameters are exported to the clip_111*.pt file under the c:/ox/ directory. The next program loads that file and keeps training:

## clip_ex_007_pt.py
# Identical to clip_ex_006_pt_im.py except that, right after constructing the
# model, the previously saved checkpoint is loaded back in:

model = CL.CLIPModel()
model = torch.load(FILE)   # the argument is cut off in the source; the saved *.pt under c:/ox/
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=CFG.lr, weight_decay=CFG.weight_decay)
# ... the data pipeline, the training loop and the final torch.save(model, FILE)
# are the same as in clip_ex_006_pt_im.py.
# END

A.4 Observing the similarity matrix

After training, compute the similarity between each text and each image and display it as a matrix, as in the book's figure.

## clip_ex_008_simi.py
# The first part is the same as clip_ex_007_pt.py: load the saved model,
# continue training for some epochs (printing '繼續(xù)訓(xùn)練', epochs, '回合...'),
# and save it again. The new part is the prediction block at the end:

# =========== prediction ==================================
test_sequences = sequences.copy()
model.eval()
with torch.no_grad():
    t_batch = tokenizer(test_sequences, padding=True,
                        truncation=True, return_tensors="pt")
    t_text_features = model.text_encoder(
        input_ids=t_batch["input_ids"],
        attention_mask=t_batch["attention_mask"])   # argument cut off in the source; reconstructed
    t_text_embeddings = model.text_projection(t_text_features)
    t_image_features = model.image_encoder(data)    # `data` is the last image batch
    t_image_embeddings = model.image_projection(t_image_features)

    t_image_embeddings_n = F.normalize(t_image_embeddings, p=2, dim=-1)
    t_text_embeddings_n = F.normalize(t_text_embeddings, p=2, dim=-1)
    dot_similarity = t_text_embeddings_n @ t_image_embeddings_n.T
    simi = F.softmax(dot_similarity, dim=-1)   # cut off in the source; softmax over images assumed
    print(simi)
# END
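The raw simi tensor is just numbers; pairing each row with its caption makes the matrix of A.4 easier to read. A small sketch, assuming the simi and test_sequences from clip_ex_008_simi.py:

for cap, row in zip(test_sequences, simi):
    best = row.argmax().item()
    print(f'{cap}: most similar image = #{best}, '
          f'row = {[round(v, 3) for v in row.tolist()]}')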
A.5 Prediction (1): input one image and find the most similar (closest) sentence

For example, given the rose-mushroom photo, the program outputs: 玫瑰香菇.

## clip_ex_009_pred.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer

def cross_entropy(preds, targets, reduction='none'):
    # the excerpt breaks off at this point; the remainder of the listing
    # follows the prediction block of clip_ex_008_simi.py (see the sketch below)
    ...
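clip_ex_009_pred.py breaks off after its opening lines; based on clip_ex_008's prediction block, the rest plausibly looks like the sketch below. The test-image path is hypothetical, and Prediction (2) from the overview is the same computation with the roles of texts and images reversed:

from PIL import Image

model.eval()
with torch.no_grad():
    # encode one test image (hypothetical path)
    img = transform(Image.open('c:/oopc/m_clip_data/test/sample.jpg')
                    .convert('RGB')).unsqueeze(0)
    img_emb = F.normalize(model.image_projection(model.image_encoder(img)),
                          p=2, dim=-1)
    # encode every candidate caption
    t_batch = tokenizer(sequences, padding=True, truncation=True,
                        return_tensors="pt")
    txt_emb = F.normalize(
        model.text_projection(model.text_encoder(
            input_ids=t_batch["input_ids"],
            attention_mask=t_batch["attention_mask"])),
        p=2, dim=-1)
    simi = txt_emb @ img_emb.T                                   # shape: (7, 1)
    print('Prediction (1):', sequences[simi.argmax().item()])   # image -> closest caption
    # Prediction (2): encode a batch of images instead, and take the argmax over images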
