diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index e65626e7787..7d80e756c77 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -1,4 +1,4 @@ -{ buildPythonPackage, pythonOlder, +{ fetchurl, buildPythonPackage, pythonOlder, cudaSupport ? false, cudatoolkit ? null, cudnn ? null, fetchFromGitHub, lib, numpy, pyyaml, cffi, typing, cmake, hypothesis, linkFarm, symlinkJoin, @@ -36,6 +36,14 @@ in buildPythonPackage rec { sha256 = "076cpbig4sywn9vv674c0xdg832sdrd5pk1d0725pjkm436kpvlm"; }; + patches = + [ # Skips two tests that are only meant to run on multi GPUs + (fetchurl { + url = "https://github.com/pytorch/pytorch/commit/bfa666eb0deebac21b03486e26642fd70d66e478.patch"; + sha256 = "1fgblcj02gjc0y62svwc5gnml879q3x2z7m69c9gax79dpr37s9i"; + }) + ]; + preConfigure = lib.optionalString cudaSupport '' export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++ '' + lib.optionalString (cudaSupport && cudnn != null) ''