Share via


DirectMLX

DirectMLX est une bibliothèque d’assistance d’en-tête uniquement C++ pour DirectML, destinée à faciliter la composition d’opérateurs individuels dans des graphes.

DirectMLX fournit des wrappers pratiques pour tous les types d’opérateurs DirectML (DML), ainsi que des surcharges d’opérateur intuitives, ce qui simplifie l’instanciation des opérateurs DML et les chaîne en graphes complexes.

Où trouver DirectMLX.h

DirectMLX.h est distribué en tant que logiciel open source sous licence MIT. La dernière version est disponible sur DirectML GitHub.

Configuration requise pour la version

DirectMLX nécessite DirectML version 1.4.0 ou ultérieure (voir historique des versions DirectML). Les anciennes versions de DirectML ne sont pas prises en charge.

DirectMLX.h nécessite un compilateur compatible C++11, y compris (mais non limité à) :

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10

Notez qu’un compilateur C++17 (ou version ultérieure) est l’option que nous vous recommandons. La compilation pour C++11 est possible, mais elle nécessite l’utilisation de bibliothèques tierces (telles que GSL et Abseil) pour remplacer les fonctionnalités de bibliothèque standard manquantes.

Si vous avez une configuration qui ne parvient pas à compiler DirectMLX.h, signalez un problème sur notre GitHub.

Utilisation de base

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

Voici un autre exemple, qui crée un graphe DirectML capable de calculer la formule quadratique.

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

Autres exemples

Vous trouverez des exemples complets à l’aide de DirectMLX sur le référentiel GitHub DirectML.

Options au moment de la compilation

DirectMLX prend en charge les #define au moment de la compilation pour personnaliser différentes parties de l’en-tête.

Option Description
DMLX_NO_EXCEPTIONS Si défini (#define), il provoque des erreurs qui entraînent un appel à std::abort lieu de lever une exception. Cela est défini par défaut si les exceptions ne sont pas disponibles (par exemple, si des exceptions ont été désactivées dans les options du compilateur).
DMLX_USE_WIL Si défini (#define), les exceptions sont levées à l’aide de types d’exceptions bibliothèque d’implémentation Windows. Sinon, les types d’exceptions standard (std::runtime_error par exemple) sont utilisés à la place. Cette option n’a aucun effet si DMLX_NO_EXCEPTIONS est défini.
DMLX_USE_ABSEIL Si défini (#define), il utilise Abseil comme substitut pour les types de bibliothèque standard indisponibles dans C++11. Ces types incluent absl::optional (à la place de std::optional), absl::Span (à la place de std::span) et absl::InlinedVector.
DMLX_USE_GSL Contrôle s’il faut utiliser GSL comme remplacement de std::span. Si défini (#define), les utilisations de std::span sont remplacées par gsl::span sur les compilateurs sans implémentations natives std::span. Sinon, un substitut inclus est fourni à la place. Notez que cette option est utilisée uniquement lors de la compilation sur un compilateur pré-C++20 sans prise en charge de std::span, et quand aucun autre substitut de bibliothèque standard (comme Abseil) n’est en cours d’utilisation.

Contrôle de la disposition du tenseur

Pour la plupart des opérateurs, DirectMLX calcule les propriétés des tenseurs de sortie de l’opérateur en votre nom. Par exemple, lors de l’exécution d’un dml::Reduce parmi les axes { 0, 2, 3 } à l’aide d’un tenseur d’entrée de tailles { 3, 4, 5, 6 }, DirectMLX calcule automatiquement les propriétés du tenseur de sortie, y compris la forme correcte de { 1, 4, 1, 1 }.

Toutefois, les autres propriétés d’un tenseur de sortie incluent Strides, TotalTensorSizeInBytes et GuaranteedBaseOffsetAlignment. Par défaut, DirectMLX définit ces propriétés de sorte que le tenseur n’ait aucun striding, aucun alignement de décalage de base garanti et une taille totale de tenseur en octets calculée par DMLCalcBufferTensorSize.

DirectMLX prend en charge la possibilité de personnaliser ces propriétés de tenseur de sortie à l’aide d’objets appelés stratégies de tenseur. Un TensorPolicy est un rappel personnalisable qui est appelé par DirectMLX et retourne des propriétés de tenseur de sortie en fonction du type de données calculé, des indicateurs et des tailles de tenseur.

Les stratégies de tenseur peuvent être définies sur l’objet dml::Graph et seront utilisées pour tous les opérateurs suivants sur ce graphe. Les stratégies de tenseur peuvent également être définies directement lors de la construction d’un TensorDesc.

La disposition des tenseurs produits par DirectMLX peut donc être contrôlée en définissant une TensorPolicy qui définit les strides appropriés sur ses tenseurs.

Exemple 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

Exemple 2

DirectMLX fournit également d’autres stratégies de tenseur intégrées. La stratégie InterleavedChannel, par exemple, est fournie en guise de commodité et peut être utilisée pour produire des tenseurs avec des strides telles qu’elles sont écrites dans l’ordre NHWC.

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

Voir aussi