AccelModel Class
Abstract base class for accel models.
Accelerated models are neural networks that can be accelerated using dedicated hardware.
To add a new model, implement a subclass. model_folder_name, version, check_point_uri and save_name should all be given by the user. weight_path and is_frozen should be exposed at least for quantized versions.
Inheritance
AccelModel
Constructor
AccelModel(model_base_path, model_folder_name, version, check_point_uri, save_name, is_frozen=False, weight_path=None)
Parameters
Name | Description |
---|---|
model_base_path | Required. The base path to store all models in. Generally given by the user. |
model_folder_name | Required. The path on disk to store all versions of this model. |
version | Required. The version of this model. |
check_point_uri | Required. The URI the model is downloaded from if it is not already on disk. |
save_name | Required. The name the checkpoint is saved under, used to load the metagraph. |
is_frozen | Whether the model should be frozen when it is loaded. Freezing removes the variables from tf.GraphKeys.TRAINABLE_VARIABLES. Default value: False |
weight_path | A custom path to load weights from, instead of the default path on disk. Used in retraining scenarios. Default value: None |
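The subclassing pattern described above can be sketched as follows. This is a minimal illustration, not the SDK's implementation: `AccelModelStandIn` is a hypothetical stand-in that only mirrors the documented constructor signature so the example runs without the SDK or TensorFlow installed, and `MyFeaturizer`, its folder name, version, URI and tensor dimensions are all made up. It also assumes that the subclass author fixes the folder name, version, URI and save name, and that versions are stored under `<base>/<folder>/<version>`; neither is confirmed by this page.

```python
import os
from abc import ABC, abstractmethod


class AccelModelStandIn(ABC):
    """Hypothetical stand-in mirroring the documented constructor signature."""

    def __init__(self, model_base_path, model_folder_name, version,
                 check_point_uri, save_name, is_frozen=False, weight_path=None):
        # Assumption: each version lives under <base>/<folder>/<version>.
        self.model_path = os.path.join(model_base_path, model_folder_name, version)
        self.model_version = version
        self.check_point_uri = check_point_uri
        self.save_name = save_name
        self.is_frozen = is_frozen
        self.weight_path = weight_path

    @abstractmethod
    def get_input_dims(self, index=0):
        """Return the dimensions of the nth input tensor."""

    @abstractmethod
    def get_output_dims(self, index=0):
        """Return the dimensions of the nth output tensor."""


class MyFeaturizer(AccelModelStandIn):
    """Hypothetical subclass that exposes base path, is_frozen and weight_path."""

    # Made-up dimensions; None marks an unknown batch size.
    _INPUT_DIMS = [[None, 224, 224, 3]]
    _OUTPUT_DIMS = [[None, 2048]]

    def __init__(self, model_base_path, is_frozen=False, weight_path=None):
        super().__init__(
            model_base_path,
            model_folder_name="my_featurizer",
            version="1.0",
            check_point_uri="https://example.invalid/my_featurizer.zip",
            save_name="my_featurizer",
            is_frozen=is_frozen,
            weight_path=weight_path,
        )

    def get_input_dims(self, index=0):
        return self._INPUT_DIMS[index]

    def get_output_dims(self, index=0):
        return self._OUTPUT_DIMS[index]


model = MyFeaturizer("models", is_frozen=True)
print(model.model_path, model.get_input_dims(), model.get_output_dims())
```

Because both dimension methods are abstract, the stand-in base class cannot be instantiated directly; only a subclass that implements them can.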
Methods
Name | Description |
---|---|
get_default_classifier | Import a frozen, default ImageNet classifier for the model into the current graph. |
get_input_dims | Get nth model input tensor dimensions. |
get_output_dims | Get nth model output tensor dimensions. |
import_graph_def | Import the graph definition corresponding to this model. Imports accelerated model into currently active graph. |
restore_weights | Restore the weights of the model into the specific session. |
save_weights | Save the weights of the model from a specific session into a specific path. |
get_default_classifier
Import a frozen, default ImageNet classifier for the model into the current graph.
get_default_classifier(input_tensor, prefix='classifier')
Parameters
Name | Description |
---|---|
prefix | Namespace to load the classifier into. Default value: classifier |
input_tensor | Required. The input feature tensor for the classifier. Expected shape: [?, 2048]. |
model_dir | Required. The directory to download the classifier into. Used as a local cache. |
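Since the classifier expects a feature tensor of shape [?, 2048], it can help to check a tensor's static shape before wiring it in. The helper below is a hypothetical convenience, not part of this class; it assumes the shape is available as a plain list in which `None` stands for the unknown batch dimension.

```python
def is_valid_classifier_input(shape, feature_size=2048):
    """Return True if a static shape matches [?, feature_size]."""
    # Rank must be 2; the batch dimension (shape[0]) may be anything, including None.
    return len(shape) == 2 and shape[1] == feature_size


print(is_valid_classifier_input([None, 2048]))        # batch of feature vectors
print(is_valid_classifier_input([None, 7, 7, 2048]))  # unpooled feature map
```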
get_input_dims
Get nth model input tensor dimensions.
abstract get_input_dims(index=0)
Parameters
Name | Description |
---|---|
index | Index (n) of the input tensor to query. Default value: 0 |
get_output_dims
Get nth model output tensor dimensions.
abstract get_output_dims(index=0)
Parameters
Name | Description |
---|---|
index | Index (n) of the output tensor to query. Default value: 0 |
import_graph_def
Import the graph definition corresponding to this model.
Imports accelerated model into currently active graph.
import_graph_def(input_tensor=None, is_training=True)
Parameters
Name | Description |
---|---|
input_tensor | Tensor that replaces the accelerated model's input (must match the expected shape and dtype). Default value: None |
is_training | Boolean indicating whether the imported graph is intended for training. Default value: True |
Returns
Type | Description |
---|---|
Either a single output tensor or a list of output tensors (if there is more than one). |
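Because the return value is either one tensor or a list of tensors, calling code that should work with any model may want to normalize it first. The sketch below shows one way to do that; `as_tensor_list` is a hypothetical helper, not part of this class, and the strings stand in for real tensors so the example runs without TensorFlow. Accepting tuples as well as lists is an assumption.

```python
def as_tensor_list(outputs):
    """Normalize a single tensor or a sequence of tensors to a list."""
    # Assumption: multiple outputs arrive as a list (tuples are accepted too).
    if isinstance(outputs, (list, tuple)):
        return list(outputs)
    # Anything else is treated as a single tensor.
    return [outputs]


# Placeholder strings stand in for tensors returned by import_graph_def.
single = as_tensor_list("features:0")
several = as_tensor_list(["boxes:0", "scores:0"])
print(single, several)
```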
restore_weights
Restore the weights of the model into the specific session.
restore_weights(session)
Parameters
Name | Description |
---|---|
session | Required. Type: tf.Session. The session to load the weights into. |
save_weights
Save the weights of the model from a specific session into a specific path.
save_weights(path, session=None)
Parameters
Name | Description |
---|---|
path | Required. Path of the checkpoint to save the weights into. |
session | Type: tf.Session. Session to save weights from. Default value: None |
Attributes
input_tensor_list
List of names of the input tensors of this model.
model_path
Path to directory that contains the model.
model_ref
Name that refers to the model - used for writing the model_def.
model_version
Model Version.
output_tensor_list
List of names of the output tensors of this model.