diff --git a/pkgs/development/libraries/science/math/magma/default.nix b/pkgs/development/libraries/science/math/magma/default.nix index 38700c963bf..2079ace021b 100644 --- a/pkgs/development/libraries/science/math/magma/default.nix +++ b/pkgs/development/libraries/science/math/magma/default.nix @@ -50,4 +50,6 @@ in stdenv.mkDerivation { platforms = platforms.unix; maintainers = with maintainers; [ tbenst ]; }; + + passthru.cudatoolkit = cudatoolkit; } diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index 9bdead706d7..4635f813c3a 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -25,15 +25,11 @@ assert !openMPISupport || openmpi != null; assert !cudaSupport || cudatoolkit != null; assert cudnn == null || cudatoolkit != null; assert !cudaSupport || (let majorIs = lib.versions.major cudatoolkit.version; - in majorIs == "9" || majorIs == "10"); + in majorIs == "9" || majorIs == "10" || majorIs == "11"); -let - hasDependency = dep: pkg: lib.lists.any (inp: inp == dep) pkg.buildInputs; - matchesCudatoolkit = hasDependency cudatoolkit; -in # confirm that cudatoolkits are sync'd across dependencies -assert !(openMPISupport && cudaSupport) || matchesCudatoolkit openmpi; -assert !cudaSupport || matchesCudatoolkit magma; +assert !(openMPISupport && cudaSupport) || openmpi.cudatoolkit == cudatoolkit; +assert !cudaSupport || magma.cudatoolkit == cudatoolkit; let cudatoolkit_joined = symlinkJoin {