pytorch安裝 cuda版本總結

蜂鸟直升机 2024-04-24 09:48 1次浏览 0 条评论 taohigo.com

起因,我之前在2080ti上使用瞭pip安裝瞭pytorch1.9+cuda10.2,在2080ti上可以跑(2080ti上cuda版本為10.0),我將所有的環境直接使用scp原封不動挪到3090以後,跑不瞭瞭,報錯!具體報錯我忘瞭,總之是很離譜的報錯,好像是找不到對應的cuda還是啥

經過我大量資料查找和實驗,基本確定瞭一下

註意pytorch的cuda版本這個概念,理解這個概念很重要,pytorch的cuda版本是pytorch版本所可以使用的最高cuda版本,是可以高於服務器本地的cuda版本的

首先,安裝命令參考pytorch官網的命令,需要註意的是,pytorch官網選擇cuda版本的時候需要註意以下地方:

  1. 服務器本地的cuda版本的時候應該與gpu算力匹配。(算力有影響的,像是3090算力為8.6,必須裝cuda11以上的版本)。然後選擇的pytorch的cuda版本必須高於服務器本地的cuda版本。比如我之前在2080ti上使用瞭pip安裝瞭pytorch1.9+cuda10.2,在2080ti上可以跑(2080ti上cuda版本為10.0),我將所有的環境直接使用scp原封不動挪到3090就是因為3090本地的cuda版本是11.1,大於安裝的pytorch1.9+cuda10.2,所以會直接報錯!!!

2.服務器本地cuda版本應該與driver版本匹配,gpu driver版本通過nvidia-smi查看。這個因為大部分服務器本地都安裝有cuda,一般來說都是滿足的。這一點唯一要註意的就是(看底下)如果使用conda安裝的話,不會使用服務器本地的cudatoolkit。因此要註意conda 安裝的cudatoolkit版本號要與服務器中的driver版本匹配好,不能太高瞭

3.pytorch版本和pytorch的cuda版本不沖突,這個最簡單,隻要pytorch官網上存在對應的包,就不沖突,安裝就完事瞭。經過我的實驗,隻要pytorch的cuda版本稍微高於服務器本地環境的cuda版本是沒有問題的,可以正常跑的。

4. 服務器本地的cuda位置:在/usr/local/cuda有軟鏈接

5.關於nvidia-smi和nvcc -V顯示的cuda版本不一樣原因:

  • 第一是因為環境本地裝的cuda版本正確的應該是nvcc -V顯示的版本。nvcc在環境變量path中被定義,一般在/usr/local/cuda-10.0/bin這種。
  • nvidia-smi顯示的cuda版本是對應的driver版本所最高支持的cuda版本!!!也就是說,在第一步選cuda版本的時候你不可以選擇版本大於這裡顯示的,小於是可以的!

6.關於conda和pip安裝區別:

  • conda會安裝一個cudatoolkit,這是一個已經編譯完的各種可執行文件的集合庫,註意裡面並沒有gpu driver!!! conda裝的pytorch會直接調用這個cudatoolkit,而不使用服務器本地環境安裝的cuda。為什麼我這樣說呢,因為你conda uninstall cudatoolkit –force以後程序是跑不瞭的,會報錯。而直接conda uninstall cudatoolkit會刪除pytorch,torchvision這些包
  • 使用pip安裝pytorch以後不會給你安裝cudatoolkit,而是會使用服務器本地安裝好的cuda,因此需要保證pytorch的cuda版本號大於本地環境的cuda版本號,以便pytorch可以正常使用。經過我的驗證,使用pip安裝的時候pytorch版本號就算大於nvidia-smi版本號也沒關系,因為此時使用的還是服務器本地安裝的cuda。因此隻需要保證pytorch的cuda版本號大於服務器本地安裝cuda版本號即可。

7.驗證能夠正常使用:

import torch

torch.cuda.is_available()

x = torch.randn(2,3)

x.cuda()

  • 看版本號:
    • torch.version.cuda #查看pytorch的cuda版本號!!!!,通常可以比實際環境cuda版本大一些
    • import torch.utils.cpp_extension #通常跟nvcc -V顯示的版本一樣的 torch.utils.cpp_extension.CUDA_HOME 編譯c++的cuda程序時候使用的cuda位置

參考文獻:我覺得寫得非常好

marsggbo:顯卡,顯卡驅動,nvcc, cuda driver,cudatoolkit,cudnn到底是什麼?

Pytorch 使用不同版本的 cuda