Python torchvision.models 模块,__dict__ 实例源码

我们从Python开源项目中,提取了以下5个代码示例,用于说明如何使用torchvision.models.__dict__(注意:__dict__ 是模块的命名空间字典而不是可调用对象,应通过 models.__dict__[name] 按名称取出模型构造函数)

项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_defined_model(name, num_classes):
    """Build architecture ``name`` with ``num_classes`` outputs, seeded from pretrained weights.

    The pretrained state dict is downloaded from ``model_urls[name]``. Every
    entry reported by ``diff_states`` is overwritten in that state dict with
    the value from the freshly initialised model, and the patched state dict
    is then loaded into the model.

    NOTE(review): ``diff_states`` is defined elsewhere; presumably it yields
    ``(key, value)`` pairs for the model entries that do not match the
    pretrained state (e.g. a classifier resized by ``num_classes``) -- confirm
    against its definition.

    Args:
        name: Key into both ``models.__dict__`` and ``model_urls``.
        num_classes: Number of output classes for the new model.

    Returns:
        A ``(model, diff)`` tuple: the model with the merged weights loaded
        and the list of ``(key, value)`` entries taken from the fresh model.

    Raises:
        AssertionError: If the patched state still differs from the model.
    """
    if name == 'densenet169':
        # Densenets don't (yet) pass on num_classes, hack it in for 169.
        # Built directly so no factory-made model is created and thrown away.
        model = torchvision.models.DenseNet(num_init_features=64, growth_rate=32,
                                            block_config=(6, 12, 32, 32), num_classes=num_classes)
    else:
        model = models.__dict__[name](num_classes=num_classes)

    pretrained_state = model_zoo.load_url(model_urls[name])

    # Diff
    diff = list(diff_states(model.state_dict(), pretrained_state))
    print("Replacing the following state from initialized", name, ":", \
          [d[0] for d in diff])

    # Distinct loop variable so the ``name`` argument is not clobbered.
    for key, value in diff:
        pretrained_state[key] = value

    # Internal invariant: after patching, nothing may differ any more.
    assert len(list(diff_states(model.state_dict(), pretrained_state))) == 0

    # Merge
    model.load_state_dict(pretrained_state)
    return model, diff
项目:vsepp    作者:fartashf    | 项目源码 | 文件源码
def get_cnn(self, arch, pretrained):
        """Instantiate the torchvision model ``arch`` and wrap it in DataParallel on CUDA.

        ``pretrained`` selects between the pre-trained constructor call and a
        plain one. For architectures whose name starts with ``alexnet`` or
        ``vgg`` only the ``features`` sub-module is wrapped in
        ``nn.DataParallel``; every other architecture is wrapped as a whole.
        """
        if pretrained:
            banner = "=> using pre-trained model '{}'"
            ctor_kwargs = {'pretrained': True}
        else:
            banner = "=> creating model '{}'"
            ctor_kwargs = {}

        # Announce first, then look the constructor up by name and call it.
        print(banner.format(arch))
        model = models.__dict__[arch](**ctor_kwargs)

        if arch.startswith(('alexnet', 'vgg')):
            # Only the feature extractor is parallelised for these families.
            model.features = nn.DataParallel(model.features)
            model.cuda()
            return model

        return nn.DataParallel(model).cuda()
项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_defined_model(path, num_classes, name):
    """Create architecture ``name`` and load the checkpoint stored at ``path`` into it.

    The checkpoint must hold its weights under the ``'state_dict'`` key. Every
    ``"module."`` fragment is removed from the parameter names before loading.

    NOTE(review): ``diff_states`` is defined elsewhere; presumably it yields
    indexable entries (key first) for parameters that do not match -- confirm.

    Args:
        path: Location of the checkpoint passed to ``torch.load``.
        num_classes: Number of output classes for the new model.
        name: Key into ``models.__dict__`` selecting the architecture.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        AssertionError: If ``diff_states`` reports any mismatching layer.
    """
    model = models.__dict__[name](num_classes=num_classes)
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(path)

    # Rebuild the state dict with all "module." fragments dropped from the keys.
    new_pretrained_state = OrderedDict(
        (key.replace("module.", ""), tensor)
        for key, tensor in checkpoint['state_dict'].items()
    )

    # Diff
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])

    assert not diff

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model


#Load the model
项目:DeepLearning_PlantDiseases    作者:MarkoArsenovic    | 项目源码 | 文件源码
def load_defined_model(path, num_classes, name):
    """Create architecture ``name`` and load the checkpoint stored at ``path`` into it.

    The checkpoint must hold its weights under the ``'state_dict'`` key. Every
    ``"module."`` fragment is removed from the parameter names before loading.

    NOTE(review): ``diff_states`` is defined elsewhere; presumably it yields
    indexable entries (key first) for parameters that do not match -- confirm.

    Args:
        path: Location of the checkpoint passed to ``torch.load``.
        num_classes: Number of output classes for the new model.
        name: Key into ``models.__dict__`` selecting the architecture.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        AssertionError: If ``diff_states`` reports any mismatching layer.
    """
    model = models.__dict__[name](num_classes=num_classes)
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(path)

    # Rebuild the state dict with all "module." fragments dropped from the keys.
    new_pretrained_state = OrderedDict(
        (key.replace("module.", ""), tensor)
        for key, tensor in checkpoint['state_dict'].items()
    )

    # Diff
    diff = list(diff_states(model.state_dict(), new_pretrained_state))
    if diff:
        print("Mismatch in these layers :", name, ":", [d[0] for d in diff])

    assert not diff

    # Merge
    model.load_state_dict(new_pretrained_state)
    return model


#Load the model
项目:emu    作者:mlosch    | 项目源码 | 文件源码
def _nn_forward_hook(self, module, input, output, name=''):
        """Forward-hook callback that stores a copy of a module's output in ``self.blobs``.

        A list or tuple output is stored as a list holding a clone of each
        element's ``.data``; any other output is stored as a single clone of
        its ``.data``. Tuples and list subclasses are accepted as well as
        plain lists, since they previously fell through to ``output.data``
        and raised ``AttributeError``.

        Args:
            module: The module that produced ``output`` (unused here).
            input: The module's input (unused here).
            output: The module's output: one object exposing ``.data``, or a
                list/tuple of such objects.
            name: Key under which the copy is stored in ``self.blobs``.
        """
        if isinstance(output, (list, tuple)):
            self.blobs[name] = [o.data.clone() for o in output]
        else:
            self.blobs[name] = output.data.clone()

    # @staticmethod
    # def _load_model_config(model_def):
    #     if isinstance(model_def, torch.nn.Module):
    #
    #     elif '.' not in os.path.basename(model_def):
    #         import torchvision.models as models
    #         if model_def not in models.__dict__:
    #             raise KeyError('Model {} does not exist in pytorchs model zoo.'.format(model_def))
    #         print('Loading model {} from pytorch model zoo'.format(model_def))
    #         return models.__dict__[model_def](pretrained=True)
    #     else:
    #         print('Loading model from {}'.format(model_def))
    #         if model_def.endswith('.t7'):
    #             return load_legacy_model(model_def)
    #         else:
    #             return torch.load(model_def)
    #
    #
    #     if type(model_cfg) == str:
    #         if not os.path.exists(model_cfg):
    #             try:
    #                 class_ = getattr(applications, model_cfg)
    #                 return class_(weights=model_weights)
    #             except AttributeError:
    #                 available_mdls = [attr for attr in dir(applications) if callable(getattr(applications, attr))]
    #                 raise ValueError('Could not load pretrained model with key {}. '
    #                                  'Available models: {}'.format(model_cfg, ', '.join(available_mdls)))
    #
    #         with open(model_cfg, 'r') as fileh:
    #             try:
    #                 return model_from_json(fileh)
    #             except ValueError:
    #                 pass
    #
    #             try:
    #                 return model_from_yaml(fileh)
    #             except ValueError:
    #                 pass
    #
    #         raise ValueError('Could not load model from configuration file {}. '
    #                          'Make sure the path is correct and the file format is yaml or json.'.format(model_cfg))
    #     elif type(model_cfg) == dict:
    #         return Model.from_config(model_cfg)
    #     elif type(model_cfg) == list:
    #         return Sequential.from_config(model_cfg)
    #
    #     raise ValueError('Could not load model from configuration object of type {}.'.format(type(model_cfg)))