Python collections module: defaultdict() example source code

The following code examples, extracted from open-source Python projects, illustrate how collections.defaultdict() is used in practice.
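
For reference, here is a minimal standalone sketch (not taken from any of the projects below) of the core behavior: a defaultdict calls its factory function to build a value the first time a missing key is accessed.

from collections import defaultdict

counts = defaultdict(int)            # int() supplies 0 for missing keys
for word in "the quick brown fox jumps over the lazy dog the".split():
    counts[word] += 1                # first access creates the entry; no KeyError
print(counts["the"])                 # 3
print(counts["missing"])             # 0, and "missing" is now a real key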

Project: rca-evaluation    Author: sieve-microservices
def draw(path, srv):
    filename = os.path.join(path, srv["preprocessed_filename"])
    df = pd.read_csv(filename, sep="\t", index_col='time', parse_dates=True)
    bins = defaultdict(list)
    for i, col in enumerate(df.columns):
        serie = df[col].dropna()
        # pd.algos.is_monotonic_float64 is a pandas-internal helper; newer
        # pandas exposes the same check as Series.is_monotonic_increasing
        if pd.algos.is_monotonic_float64(serie.values, False)[0]:
            serie = serie.diff()[1:]
        p_value = adfuller(serie, autolag='AIC')[1]
        if math.isnan(p_value):
            continue
        nearest = 0.05 * round(p_value / 0.05)
        bins[nearest].append(serie)
    for bin, members in bins.items():
        series = [serie.name for serie in members]
        if len(members) <= 10:
            columns = series
        else:
            columns = random.sample(series, 10)

        subset = df[columns]
        name = "%s_adf_confidence_%.2f.png" % (srv["name"], bin)
        print(name)
        axes = subset.plot(subplots=True)
        plt.savefig(os.path.join(path, name))
        plt.close("all")
Project: deep-summarization    Author: harpribot
def precook(s, n=4, out=False):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.

    :param s: sentence to process, as a whitespace-separated string
    :param n: maximum n-gram length to count
    :param out: unused
    :return: (number of words, defaultdict mapping n-gram tuples to counts)
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):  # the original Python 2 source used xrange
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return (len(words), counts)
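
A quick usage sketch for the function above (the call and printed values are illustrative, not from the project):

length, counts = precook("the cat sat on the mat", n=2)
print(length)                  # 6 words
print(counts[('the',)])        # 2 occurrences of the unigram "the"
print(counts[('the', 'cat')])  # 1 occurrence of the bigram "the cat"
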
Project: lang-reps    Author: chaitanyamalaviya
def add_unk(self, thresh=0, unk_string='<UNK>'):
        if unk_string in self.s2t.keys(): raise Exception("tried to add an UNK token that already existed")
        if self.unk is not None: raise Exception("already added an UNK token")
        strings = [unk_string]
        for token in self.tokens:
            if token.count >= thresh: strings.append(token.s)
        if self.START_TOK is not None and self.START_TOK not in strings: strings.append(self.START_TOK.s)
        if self.END_TOK is not None and self.END_TOK not in strings: strings.append(self.END_TOK.s)
        self.tokens = set([])
        self.strings = set([])
        self.i2t = defaultdict(lambda :self.unk)
        self.s2t = defaultdict(lambda :self.unk)
        for string in strings:
            self.add_string(string)
        self.unk = self.s2t[unk_string]
        if self.START_TOK is not None: self.START_TOK = self.s2t[self.START_TOK.s]
        if self.END_TOK is not None: self.END_TOK = self.s2t[self.END_TOK.s]
Project: mpiFFT4py    Author: spectralDNS
def __init__(self, N, L, comm, precision,
                 communication="Alltoall",
                 padsize=1.5,
                 threads=1,
                 planner_effort=defaultdict(lambda: "FFTW_MEASURE")):
        R2C.__init__(self, N, L, comm, precision,
                     communication=communication,
                     padsize=padsize, threads=threads,
                     planner_effort=planner_effort)
        # Reuse all shapes from the R2C transform, simply resizing the final complex z-dimension:
        self.Nf = N[2]
        self.Nfp = int(self.padsize*self.N[2]) # Independent complex wavenumbers in z-direction for padded array

        # Rename since there's no real space
        self.original_shape_padded = self.real_shape_padded
        self.original_shape = self.real_shape
        self.transformed_shape = self.complex_shape
        self.original_local_slice = self.real_local_slice
        self.transformed_local_slice = self.complex_local_slice
        self.ks = (fftfreq(N[2])*N[2]).astype(int)
Project: mpiFFT4py    Author: spectralDNS
def __init__(self, N, L, comm, precision, padsize=1.5, threads=1,
                 planner_effort=defaultdict(lambda : "FFTW_MEASURE")):
        self.N = N         # The global size of the problem
        self.L = L
        assert len(L) == 2
        assert len(N) == 2
        self.comm = comm
        self.float, self.complex, self.mpitype = float, complex, mpitype = datatypes(precision)
        self.num_processes = comm.Get_size()
        self.rank = comm.Get_rank()
        self.padsize = padsize
        self.threads = threads
        self.planner_effort = planner_effort
        # Each cpu gets ownership of Np indices
        self.Np = N // self.num_processes
        self.Nf = N[1]//2+1
        self.Npf = self.Np[1]//2+1 if self.rank+1 == self.num_processes else self.Np[1]//2
        self.Nfp = int(padsize*self.N[1]/2+1)
        self.ks = (fftfreq(N[0])*N[0]).astype(int)
        self.dealias = zeros(0)
        self.work_arrays = work_arrays()
Project: ubidump    Author: nlitsme
def scanblocks(self):
        """
        creates map of volid + lnum => physical lnum
        """
        self.vmap = defaultdict(lambda : defaultdict(int))
        for lnum in range(self.maxlebs):

            try:
                ec = UbiEcHeader()
                hdr = self.readblock(lnum, 0, ec.hdrsize)
                ec.parse(hdr)

                vid = UbiVidHead()
                viddata = self.readblock(lnum, ec.vid_hdr_ofs, vid.hdrsize)
                vid.parse(viddata)

                self.vmap[vid.vol_id][vid.lnum] = lnum
            except:
                pass
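
scanblocks() uses the nested two-level pattern: the outer defaultdict maps a volume id to an inner defaultdict mapping logical block numbers to physical ones. A standalone sketch of the same pattern (the numbers are made up):

from collections import defaultdict

vmap = defaultdict(lambda: defaultdict(int))
vmap[3][17] = 42         # volume 3, logical block 17 -> physical block 42
print(vmap[3][17])       # 42
print(vmap[3][99])       # 0: the inner defaultdict fills in int()
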
Project: pytorch-semseg    Author: meetshah1995
def __init__(self, root, split="train_aug", 
                 is_transform=False, img_size=512, augmentations=None):
        self.root = root
        self.split = split
        self.is_transform = is_transform
        self.augmentations = augmentations
        self.n_classes = 21
        self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        self.mean = np.array([104.00699, 116.66877, 122.67892])
        self.files = collections.defaultdict(list)

        for split in ["train", "val", "trainval"]:
            file_list = tuple(open(root + '/ImageSets/Segmentation/' + split + '.txt', 'r'))
            file_list = [id_.rstrip() for id_ in file_list]
            self.files[split] = file_list

        if not os.path.isdir(self.root + '/SegmentationClass/pre_encoded'):
            self.setup(pre_encode=True)
        else:
            self.setup(pre_encode=False)
Project: abodepy    Author: MisterWil
def __init__(self, abode, reconnect_hours=12):
        """Init event subscription class."""
        self._abode = abode
        self._thread = None
        self._socketio = None
        self._running = False

        # Setup callback dicts
        self._device_callbacks = collections.defaultdict(list)
        self._event_callbacks = collections.defaultdict(list)
        self._timeline_callbacks = collections.defaultdict(list)

        # Default "sane" values
        self._ping_interval = 25.0
        self._ping_timeout = 60.0
        self._last_pong = None
        self._max_connection_time = reconnect_hours * 3600
        self._connection_time = None
Project: picoCTF    Author: picoCTF
def get_team_member_solve_stats(eligible=True):
    db = api.api.common.get_conn()
    teams = api.team.get_all_teams(show_ineligible=(not eligible))
    user_breakdowns = {}
    for t in teams:
        uid_map = defaultdict(lambda: defaultdict(int))
        members = api.team.get_team_members(tid=t['tid'], show_disabled=False)
        subs = db.submissions.find({'tid': t['tid']})
        for sub in subs:
            uid = sub['uid']
            uid_map[uid]['submits'] += 1
            if uid_map[uid]['times'] == 0:  # int() default means no timestamp list yet
                uid_map[uid]['times'] = list()
            uid_map[uid]['times'].append(sub['timestamp'])
            if sub['correct']:
                uid_map[uid]['correct'] += 1
                uid_map[uid][sub['category']] += 1
            else:
                uid_map[uid]['incorrect'] += 1
        user_breakdowns[t['tid']] = uid_map
        for member in members:
            if member['uid'] not in uid_map:
                uid_map[member['uid']] = None  # members with no submissions at all
    return user_breakdowns
Project: picoCTF    Author: picoCTF
def get_team_participation_percentage(eligible=True, user_breakdown=None):
    if user_breakdown is None:
        user_breakdown = get_team_member_solve_stats(eligible)
    team_size_any = defaultdict(list)
    team_size_correct = defaultdict(list)
    for tid, breakdown in user_breakdown.items():
        count_any = 0
        count_correct = 0
        for uid, work in breakdown.items():
            if work is not None:
                count_any += 1
                if work['correct'] > 0:
                    count_correct += 1
        team_size_any[len(breakdown.keys())].append(count_any)
        team_size_correct[len(breakdown.keys())].append(count_correct)
    return {x: statistics.mean(y) for x, y in team_size_any.items()}, \
           {x: statistics.mean(y) for x, y in team_size_correct.items()}
Project: OMW    Author: globalwordnet
def fetch_src_id_pos_stats(src_id):
        src_pos_stats = dd(lambda: dd(int))
        pos = fetch_pos()
        r = query_omw_direct("""
        SELECT pos_id, count(distinct s.ss_id),
        count(distinct s.w_id), count(distinct s.id)
        FROM s JOIN s_src
        ON s.id=s_src.s_id
        JOIN ss ON s.ss_id=ss.id
        WHERE s_src.src_id=? group by pos_id""", (src_id,))
        for (p, ss, w, s) in r:
            ps = pos['id'][p]
            src_pos_stats[ps]['synsets'] = ss
            src_pos_stats[ps]['words'] = w
            src_pos_stats[ps]['senses'] = s
        return src_pos_stats
Project: OMW    Author: globalwordnet
def fetch_ssrel_stats(src_id):
        constitutive = ['instance_hyponym','instance_hypernym',
                         'hypernym', 'hyponym',
                         'synonym', 'antonym',
                         'mero_part', 'holo_part',
                         'mero_member', 'holo_member',
                         'mero_substance', 'holo_substance' ]
        src_ssrel_stats = dd(int)
        ssrl=fetch_ssrel()
        for r in query_omw("""
        SELECT  ssrel_id, count(ssrel_id)
        FROM sslink JOIN sslink_src
        ON sslink.id=sslink_src.sslink_id
        WHERE sslink_src.src_id=?
        GROUP BY ssrel_id""", [src_id]):
            link = ssrl['id'][r['ssrel_id']]
            src_ssrel_stats[link[0]] = r['count(ssrel_id)']
            src_ssrel_stats['TOTAL'] += r['count(ssrel_id)']
            if link[0] in constitutive:
                src_ssrel_stats['CONSTITUATIVE'] += r['count(ssrel_id)']

        return src_ssrel_stats
Project: OMW    Author: globalwordnet
def f_rate_summary(ili_ids):
        """
        This function takes a list of ili ids and returns a dictionary with the
        cumulative ratings filtered by the ids.
        """
        counts = dd(lambda: dd(int))
        rates = fetch_rate_id(ili_ids)
        up_who = dd(list)
        down_who = dd(list)
        for key, value in rates.items():
            for (r, u, t) in rates[key]:
                if r == 1:
                    counts[int(key)]['up'] += 1
                    up_who[int(key)].append(u)
                elif r == -1:
                    counts[int(key)]['down'] += 1
                    down_who[int(key)].append(u)

        return counts, up_who, down_who
Project: OMW    Author: globalwordnet
def fetch_comment_id(ili_ids, u=None):
        """
        This function takes a list of ili ids and, optionally a username.
        It returns a dictionary with the comments filtered by the ids and,
        if provided, for that specific user.
        """
        comments = dd(list)
        ili_list = (",".join("?" for s in ili_ids), ili_ids)
        if u:
            for r in query_omw("""SELECT id, ili_id, com, u, t
                                  FROM ili_com
                                  WHERE ili_id in ({})
                                  AND   u = ?""".format(ili_list[0]),
                               ili_list[1]+[u]):
                comments[r['ili_id']].append((r['com'], r['u'], r['t']))
        else:
            for r in query_omw("""SELECT id, ili_id, com, u, t
                                  FROM ili_com
                                  WHERE ili_id in ({})
                               """.format(ili_list[0]),
                               ili_list[1]):
                comments[r['ili_id']].append((r['com'], r['u'], r['t']))

        return comments
Project: django_pipedrive    Author: MasAval
def create_modifications(cls, instance, previous, current):

        prev = defaultdict(lambda: None, previous)
        curr = defaultdict(lambda: None, current)

        # Compute difference between previous and current
        diffkeys = set([k for k in prev if prev[k] != curr[k]])
        in_previous_not_current = set([k for k in prev if k not in curr])
        in_current_not_previous = set([k for k in curr if k not in prev])

        diffkeys = diffkeys.union(in_previous_not_current).union(in_current_not_previous)
        current_datetime = timezone.now()

        for key in diffkeys:
            FieldModification.objects.create(
                field_name=key,
                previous_value=prev[key],
                current_value=curr[key],
                content_object=instance,
                created=current_datetime,
            )
Project: seq2seq    Author: google
def load(model_dir):
    """ Loads options from the given model directory.

    Args:
      model_dir: Path to the model directory.
    """
    with gfile.GFile(TrainOptions.path(model_dir), "rb") as file:
      options_dict = json.loads(file.read().decode("utf-8"))
    options_dict = defaultdict(None, options_dict)

    return TrainOptions(
        model_class=options_dict["model_class"],
        model_params=options_dict["model_params"])
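
Note the subtlety in load() above: when the first argument to defaultdict is None, there is no default_factory, so missing keys raise KeyError exactly as with a plain dict. A small sketch:

from collections import defaultdict

options = defaultdict(None, {"model_class": "Seq2Seq"})
print(options["model_class"])    # "Seq2Seq"
try:
    options["missing_key"]       # no factory, so this raises
except KeyError:
    print("KeyError, just like a plain dict")
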
Project: zipline-chinese    Author: zhanghan1990
def cancel_all(self, sid):
        """
        Cancel all open orders for a given sid.
        """
        # (sadly) open_orders is a defaultdict, so this will always succeed.
        orders = self.open_orders[sid]

        # We're making a copy here because `cancel` mutates the list of open
        # orders in place.  The right thing to do here would be to make
        # self.open_orders no longer a defaultdict.  If we do that, then we
        # should just remove the orders once here and be done with the matter.
        for order in orders[:]:
            self.cancel(order.id)

        assert not orders
        del self.open_orders[sid]
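
The comment in cancel_all() points at a common defaultdict gotcha: merely reading self.open_orders[sid] inserts an empty list as a side effect. A standalone illustration:

from collections import defaultdict

open_orders = defaultdict(list)
orders = open_orders["unknown-sid"]   # lookup alone inserts []
print(len(open_orders))               # 1, even though nothing was appended
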
Project: pbtk    Author: marin-m
def assert_installed(win=None, modules=[], binaries=[]):
    missing = defaultdict(list)
    for items, what, func in ((modules, 'modules', find_spec),
                              (binaries, 'binaries', which)):
        for item in items:
            if not func(item):
                missing[what].append(item)
    if missing:
        msg = []
        for subject, names in missing.items():
            if len(names) == 1:
                subject = {'modules': 'module', 'binaries': 'binary'}[subject]
            msg.append('%s "%s"' % (subject, '", "'.join(names)))
        msg = 'You are missing the %s for this.' % ' and '.join(msg)
        if win:
            from PyQt5.QtWidgets import QMessageBox
            QMessageBox.warning(win, ' ', msg)
        else:
            raise ImportError(msg)
    return not missing
Project: pbtk    Author: marin-m
def load_endpoints(self):
        self.choose_endpoint.endpoints.clear()

        for name in listdir(str(BASE_PATH / 'endpoints')):
            if name.endswith('.json'):
                item = QListWidgetItem(name.split('.json')[0], self.choose_endpoint.endpoints)
                item.setFlags(item.flags() & ~Qt.ItemIsEnabled)

                pb_msg_to_endpoints = defaultdict(list)
                with open(str(BASE_PATH / 'endpoints' / name)) as fd:
                    for endpoint in load(fd, object_pairs_hook=OrderedDict):
                        pb_msg_to_endpoints[endpoint['request']['proto_msg'].split('.')[-1]].append(endpoint)

                for pb_msg, endpoints in pb_msg_to_endpoints.items():
                    item = QListWidgetItem(' ' * 4 + pb_msg, self.choose_endpoint.endpoints)
                    item.setFlags(item.flags() & ~Qt.ItemIsEnabled)

                    for endpoint in endpoints:
                        path_and_qs = '/' + endpoint['request']['url'].split('/', 3).pop()
                        item = QListWidgetItem(' ' * 8 + path_and_qs, self.choose_endpoint.endpoints)
                        item.setData(Qt.UserRole, endpoint)

        self.set_view(self.choose_endpoint)
Project: binf-scripts    Author: lazappi
def group_filenames(filenames, pat, sep):
    """
    Group files based on their merged file names.

    Args:
        filenames: List of filename strings to group.
        pat: Merging pattern to use for grouping.
        sep: String separating filename sections.

    Returns:
        Dictionary with group names as keys and lists of original filenames as
        values.
    """

    groups = defaultdict(list)

    for filename in filenames:
        group = merge_filename(filename, pat, sep)
        groups[group].append(filename)

    return groups
Project: cellranger    Author: 10XGenomics
def split_genes_by_genomes(genes, genomes):
    """ Returns a list of lists [genome1, genome2, ...]
    where genome1 = [gene1,gene2,...].
    Args:
      genes - list of Gene tuples
      genomes - list of genome names, e.g. ['hg19', 'mm10']
    """
    assert len(genomes) > 0

    if len(genomes) == 1:
        return [genes]

    d = collections.defaultdict(list)
    for gene in genes:
        genome = get_genome_from_str(gene.id, genomes)
        d[genome].append(gene)

    genes_per_genome = []
    for genome in genomes:
        genes_per_genome.append(d[genome])
    return genes_per_genome
Project: kinect-2-libras    Author: inessadl
def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.  """
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    # (the original Python 2 source called .itervalues() here)
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)
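
The final dict(head_nodes) conversion is deliberate: it freezes the mapping, so a later lookup of an unknown node type raises KeyError instead of silently inserting an empty list. Sketch:

from collections import defaultdict

head_nodes = defaultdict(list)
head_nodes["expr"].append("fixer_a")
frozen = dict(head_nodes)     # same contents, but no auto-insert
print("stmt" in frozen)       # False; frozen["stmt"] would raise KeyError
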
Project: cbapi-python    Author: carbonblack
def main():
    parser = build_cli_parser()
    args = parser.parse_args()
    c = get_cb_response_object(args)

    hostname_user_pairs = defaultdict(ItemCount)
    username_activity = defaultdict(ItemCount)

    for proc in c.select(Process).where("process_name:explorer.exe"):
        hostname_user_pairs[proc.hostname].add(proc.username)
        username_activity[proc.username].add(proc.hostname)

    for hostname, user_activity in iteritems(hostname_user_pairs):
        print("For host {0:s}:".format(hostname))
        for username, count in user_activity.report():
            print("  %-20s: logged in %d times" % (username, count))
Project: cbapi-python    Author: carbonblack
def main():
    parser = build_cli_parser("Delete duplicate computers")
    parser.add_argument("--dry-run", "-d", help="perform a dry run, don't actually delete the computers",
                        action="store_true", dest="dry_run")

    args = parser.parse_args()
    p = get_cb_protection_object(args)

    computer_list = defaultdict(list)
    for computer in p.select(Computer).where("deleted:false"):
        computer_list[computer.name].append({"id": computer.id, "offline": computer.daysOffline})

    for computer_name, computer_ids in iteritems(computer_list):
        if len(computer_ids) > 1:
            sorted_computers = sorted(computer_ids, key=lambda x: x["offline"], reverse=True)
            for computer_id in sorted_computers[:-1]:
                if computer_id["offline"] > 0:
                    print("deleting computer id %d (offline %d days, hostname %s)" % (computer_id["id"],
                                                                                      computer_id["offline"],
                                                                                      computer_name))
                    if not args.dry_run:
                        print("deleting from server...")
                        p.select(Computer, computer_id["id"]).delete()
Project: board-games-app    Author: sampathweb
def generate_bot_move(self):
        """Returns the computer selected row, col
        """
        selections = defaultdict(list)
        if self.bot_level == 1:  # Easy - Pick any one from valid_choices list
            selected_item = random.choice(self.game_choices)
        elif self.bot_level == 2:  # Hard - Try to block the player from winning
            for win_set in self.winning_combos:
                rem_items = list(win_set - self.player_a_choices - self.player_b_choices)
                selections[len(rem_items)].append(rem_items)
            if selections.get(1):
                selected_item = random.choice(random.choice(selections[1]))
            elif selections.get(2):
                selected_item = random.choice(random.choice(selections[2]))
            else:
                selected_item = random.choice(random.choice(selections[3]))
        return selected_item
Project: tts-bug-bounty-dashboard    Author: 18F
def report_on_perm_differences(self, program_list):
        perms = defaultdict(dict)

        for program in program_list:
            program_name = program['data']['attributes']['handle']
            for member in program['data']['relationships']['members']['data']:
                username = member['relationships']['user']['data']['attributes']['username']
                permissions = member['attributes']['permissions']
                perms[username][program_name] = set(permissions)

        for user in perms:
            handled = False
            for program in perms[user]:
                other_programs = set(perms[user].keys()) - set([program])
                for other_program in other_programs:
                    if perms[user][program] != perms[user][other_program]:
                        self.stdout.write(f'Mismatching perms for {user}:')
                        self.stdout.write(f'    {program}: {perms[user][program]}')
                        self.stdout.write(f'    {other_program}: {perms[user][other_program]}')
                        handled = True
                if handled:
                    break
Project: conec    Author: cod3licious
def build_windex(self, sentences, wordlist=[]):
        """
        go through all the sentences and get an overview of all used words and their frequencies
        """
        # get an overview of the vocabulary
        vocab = defaultdict(int)
        total_words = 0
        for sentence_no, sentence in enumerate(sentences):
            if not sentence_no % self.progress:
                print("PROGRESS: at sentence #%i, processed %i words and %i unique words" % (sentence_no, sum(vocab.values()), len(vocab)))
            for word in sentence:
                vocab[word] += 1
        print("collected %i unique words from a corpus of %i words and %i sentences" % (len(vocab), sum(vocab.values()), sentence_no + 1))
        # assign a unique index to each word and remove all words with freq < min_count
        self.wcounts, self.word2index, self.index2word = {}, {}, []
        if not wordlist:
            wordlist = [word for word, c in vocab.items() if c >= self.min_count]
        for word in wordlist:
            self.word2index[word] = len(self.word2index)
            self.index2word.append(word)
            self.wcounts[word] = vocab[word]
Project: kubey    Author: bradrf
def ctl_each(obj, command, arguments):
    '''Invoke any kubectl command directly for each pod matched and collate the output.'''
    width, height = click.get_terminal_size()
    kubectl = obj.kubey.kubectl
    collector = tabular.RowCollector()
    ns_pods = defaultdict(list)
    for pod in obj.kubey.each_pod(obj.maximum):
        ns_pods[pod.namespace].append(pod)
    for ns, pods in ns_pods.items():
        args = ['-n', ns] + list(arguments) + [p.name for p in pods]
        kubectl.call_table_rows(collector.handler_for(ns), command, *args)
    kubectl.wait()
    if collector.rows:
        click.echo(tabular.tabulate(obj, sorted(collector.rows), collector.headers))
    if kubectl.final_rc != 0:
        click.get_current_context().exit(kubectl.final_rc)
Project: concierge    Author: 9seconds
def fix_star_host(root):
    star_host = None

    for host in root.childs:
        if host.name == "*":
            LOG.debug("Detected known '*' host.")
            star_host = host
            break
    else:
        LOG.debug("Add new '*' host.")
        star_host = root.add_host("*")

    values = collections.defaultdict(set)
    values.update(root.values)
    values.update(star_host.values)
    star_host.values = values
    star_host.trackable = True
    root.values.clear()

    return root
Project: simple_rl    Author: david-abel
def __init__(self, init_predicate, term_predicate, policy, name="o", term_prob=0.01):
        '''
        Args:
            init_predicate (S --> {0,1})
            term_predicate (S --> {0,1})
            policy (S --> A)
        '''
        self.init_predicate = init_predicate
        self.term_predicate = term_predicate
        self.term_flag = False
        self.name = name
        self.term_prob = term_prob

        if type(policy) is defaultdict or type(policy) is dict:
            self.policy_dict = dict(policy)
            self.policy = self.policy_from_dict
        else:
            self.policy = policy
Project: simple_rl    Author: david-abel
def __init__(self, mdp_prob_dict, horizon=0):
        '''
        Args:
            mdp_prob_dict (dict):
                Key (MDP)
                Val (float): Represents the probability with which the MDP is sampled.

        Notes:
            @mdp_prob_dict can also be a list, in which case the uniform distribution is used.
        '''
        if type(mdp_prob_dict) == list or len(mdp_prob_dict.values()) == 0:
            # Assume uniform if no probabilities are provided.
            mdp_prob = 1.0 / len(mdp_prob_dict)  # len() works for both the list and the dict form
            new_dict = defaultdict(float)
            for mdp in mdp_prob_dict:
                new_dict[mdp] = mdp_prob
            mdp_prob_dict = new_dict

        self.horizon = horizon
        self.mdp_prob_dict = mdp_prob_dict
Project: simple_rl    Author: david-abel
def _compute_matrix_from_trans_func(self):
        if self.has_computed_matrix:
            self._compute_reachable_state_space()
            # We've already run this, just return.
            return

        # trans_dict[state][action][next_state] = estimated transition probability
        self.trans_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))

        for s in self.get_states():
            for a in self.actions:
                for sample in range(self.sample_rate):  # xrange in the original Python 2 source
                    s_prime = self.transition_func(s, a)
                    self.trans_dict[s][a][s_prime] += 1.0 / self.sample_rate

        self.has_computed_matrix = True
Project: simple_rl    Author: david-abel
def __init__(self, num_states=5, num_rand_trans=5, gamma=0.99):
        '''
        Args:
            num_states (int) [optional]: Number of states in the Random MDP.
            num_rand_trans (int) [optional]: Number of possible next states.

        Summary:
            Each state-action pair picks @num_rand_trans possible states and has a uniform distribution
            over them for transitions. Rewards are also chosen randomly.
        '''
        MDP.__init__(self, RandomMDP.ACTIONS, self._transition_func, self._reward_func, init_state=RandomState(1), gamma=gamma)
        # assert(num_rand_trans <= num_states)
        self.num_rand_trans = num_rand_trans
        self.num_states = num_states
        self._reward_s_a = (random.choice(range(self.num_states)), random.choice(RandomMDP.ACTIONS))
        self._transitions = defaultdict(lambda: defaultdict(str))
Project: otRebuilder    Author: Pal3love
def test_results_unchanged(self):
        """Tests that the results of conversion haven't changed since the time
        of this test's writing. Useful as a quick check whenever one modifies
        the conversion algorithm.
        """

        expected = {
            2: 6,
            3: 26,
            4: 82,
            5: 232,
            6: 360,
            7: 266,
            8: 28}

        results = collections.defaultdict(int)
        for spline in self.single_splines:
            n = len(spline) - 2
            results[n] += 1
        self.assertEqual(results, expected)
        self.results.append(('single spline lengths', results))
Project: lang-reps    Author: chaitanyamalaviya
def __init__(self):
        self.tokens = set([])
        self.strings = set([])
        self.s2t = defaultdict(Token.not_found)
        self.i2t = defaultdict(Token.not_found)
        self.unk = None
        self.START_TOK = None
        self.END_TOK = None
Project: lang-reps    Author: chaitanyamalaviya
def load(cls, filename):
        with open(filename, "r") as f:
            info_dict = pickle.load(f)
            v = Vocab()
            v.tokens = info_dict["tokens"]
            v.strings = info_dict["strings"]
            v.unk = info_dict["unk"]
            v.START_TOK = info_dict["START_TOK"]
            v.END_TOK = info_dict["END_TOK"]
            defaultf = (lambda :v.unk) if (v.unk is not None) else Token.not_found
            v.s2t = defaultdict(defaultf, info_dict["s2t"])
            v.i2t = defaultdict(defaultf, info_dict["i2t"])
            return v
Project: flash_services    Author: textbook
def format_data(self, name, data):
        counts = defaultdict(int)
        for issue in data:
            if issue.get('pull_request') is not None:
                counts['{}-pull-requests'.format(issue['state'])] += 1
            else:
                counts['{}-issues'.format(issue['state'])] += 1
        half_life = self.half_life(data)
        return dict(
            halflife=naturaldelta(half_life),
            health=self.health_summary(half_life),
            issues=counts,
            name=name,
        )
Project: flash_services    Author: textbook
def story_summary(stories):
        """Get a summary of stories in each state.

        Arguments:
          stories (:py:class:`list`): A list of stories.

        Returns:
          :py:class:`collections.defaultdict`: Summary of points by
            story state.

        """
        result = defaultdict(int)
        for story in stories:
            result[story['current_state']] += int(story.get('estimate', 0))
        return result
Project: topically-driven-language-model    Author: jhlau
def gen_vocab(dummy_symbols, corpus, stopwords, vocab_minfreq, vocab_maxfreq, verbose):
    idxvocab = []
    vocabxid = defaultdict(int)
    vocab_freq = defaultdict(int)
    for line_id, line in enumerate(codecs.open(corpus, "r", "utf-8")):
        for word in line.strip().split():
            vocab_freq[word] += 1
        if line_id % 1000 == 0 and verbose:
            sys.stdout.write(str(line_id) + " processed\r")
            sys.stdout.flush()

    #add in dummy symbols into vocab
    for s in dummy_symbols:
        update_vocab(s, idxvocab, vocabxid)

    #remove low frequency words
    for w, f in sorted(vocab_freq.items(), key=operator.itemgetter(1), reverse=True):
        if f < vocab_minfreq:
            break
        else:
            update_vocab(w, idxvocab, vocabxid)

    #ignore stopwords, frequent words and symbols for the document input for topic model
    stopwords = set([item.strip().lower() for item in open(stopwords)])
    freqwords = set([item[0] for item in sorted(vocab_freq.items(), key=operator.itemgetter(1), \
        reverse=True)[:int(float(len(vocab_freq))*vocab_maxfreq)]]) #ignore top N% most frequent words for topic model
    alpha_check = re.compile("[a-zA-Z]")
    symbols = set([ w for w in vocabxid.keys() if ((alpha_check.search(w) == None) or w.startswith("'")) ])
    ignore = stopwords | freqwords | symbols | set(dummy_symbols) | set(["n't"])
    ignore = set([vocabxid[w] for w in ignore if w in vocabxid])

    return idxvocab, vocabxid, ignore
Project: mpiFFT4py    Author: spectralDNS
def __init__(self, N, L, comm, precision,
                 communication="Alltoallw",
                 padsize=1.5,
                 threads=1,
                 planner_effort=defaultdict(lambda: "FFTW_MEASURE")):
        assert len(L) == 3
        assert len(N) == 3
        self.N = N
        self.Nf = N[2]//2+1          # Independent complex wavenumbers in z-direction
        self.Nfp = int(padsize*N[2]//2+1) # Independent complex wavenumbers in z-direction for padded array
        self.comm = comm
        self.float, self.complex, self.mpitype = datatypes(precision)
        self.communication = communication
        self.num_processes = comm.Get_size()
        self.rank = comm.Get_rank()
        self.Np = N // self.num_processes
        self.L = L.astype(self.float)
        self.dealias = np.zeros(0)
        self.padsize = padsize
        self.threads = threads
        self.planner_effort = planner_effort
        self.work_arrays = work_arrays()
        if self.num_processes not in [2**i for i in range(int(np.log2(N[0]))+1)]:
            raise IOError("Number of cpus must be in ",
                          [2**i for i in range(int(np.log2(N[0]))+1)])
        self._subarraysA = []
        self._subarraysB = []
        self._counts_displs = 0
        self._subarraysA_pad = []
        self._subarraysB_pad = []
Project: mpiFFT4py    Author: spectralDNS
def __init__(self, N, L, comm, precision, P1=None, communication='Alltoall',
                 padsize=1.5, threads=1,
                 planner_effort=defaultdict(lambda: "FFTW_MEASURE")):
        R2CY.__init__(self, N, L, comm, precision, P1=P1, communication=communication,
                      padsize=padsize, threads=threads, planner_effort=planner_effort)
        self.N2f = self.N2[2]//2 if self.comm1_rank < self.P2-1 else self.N2[2]//2+1
        if self.communication == 'AlltoallN':
            self.N2f = self.N2[2]//2
        if self.communication == 'Alltoallw':
            q = _subsize(self.Nf, self.P2, self.comm1_rank)
            self.N2f = q
Project: xr-telemetry-m2m-web    Author: cisco
def render_POST(self, request):
        """
        Handle a request from the client.
        """
        script_env = {
            method: api_method(request, method)
            for method in request.sdata.api.fns
        }

        # Make get do auto-formatting for convenience, even though this
        # breaks if you try to use literal '{}' named arguments
        # @@@ reconsider whether this is at all a good idea
        def get_with_formatting(path, *args):
            return api_method(request, 'get')(path.format(*args))
        script_env['get'] = get_with_formatting

        script_env['re'] = re
        script_env['dumps'] = dumps
        script_env['defaultdict'] = defaultdict
        script_env['OrderedDict'] = OrderedDict

        buf = []
        def dummy_print(*args):
            if len(args) == 1 and (isinstance(args[0], list) or isinstance(args[0], dict)):
                buf.append(dumps(args[0], indent=4))
            else:
                buf.append(' '.join(map(str, args)))
        script_env['print'] = dummy_print

        def run_script(script):
            try:
                exec(script, script_env)  # Python 3 call form; the original Python 2 source used the exec statement
            except:
                exception_info = sys.exc_info()
                buf.extend(traceback.format_exception(*exception_info))
            request.sdata.log('got reply {}'.format(buf))
            request.sdata.add_to_push_queue('script', text=dumps(buf))

        script = request.args['script'][0]
        reactor.callInThread(run_script, script)
Project: devops-playground    Author: jerrywardlow
def get_instances_by_region(self, region):
        ''' Makes an AWS EC2 API call to get the list of instances in a
        particular region '''

        try:
            conn = self.connect(region)
            reservations = []
            if self.ec2_instance_filters:
                for filter_key, filter_values in self.ec2_instance_filters.items():
                    reservations.extend(conn.get_all_instances(filters = { filter_key : filter_values }))
            else:
                reservations = conn.get_all_instances()

            # Pull the tags back in a second step
            # AWS are on record as saying that the tags fetched in the first `get_all_instances` request are not
            # reliable and may be missing, and the only way to guarantee they are there is by calling `get_all_tags`
            instance_ids = []
            for reservation in reservations:
                instance_ids.extend([instance.id for instance in reservation.instances])

            max_filter_value = 199
            tags = []
            for i in range(0, len(instance_ids), max_filter_value):
                tags.extend(conn.get_all_tags(filters={'resource-type': 'instance', 'resource-id': instance_ids[i:i+max_filter_value]}))

            tags_by_instance_id = defaultdict(dict)
            for tag in tags:
                tags_by_instance_id[tag.res_id][tag.name] = tag.value

            for reservation in reservations:
                for instance in reservation.instances:
                    instance.tags = tags_by_instance_id[instance.id]
                    self.add_instance(instance, region)

        except boto.exception.BotoServerError as e:
            if e.error_code == 'AuthFailure':
                error = self.get_auth_error_message()
            else:
                backend = 'Eucalyptus' if self.eucalyptus else 'AWS'
                error = "Error connecting to %s backend.\n%s" % (backend, e.message)
            self.fail_with_error(error, 'getting EC2 instances')
Project: aws-consolidated-admin    Author: awslabs
def lambda_handler(event, context):
    status_counts = collections.defaultdict(lambda: 0)

    for wf in event['Workflows']:
        resp = sfn.describe_execution(
            executionArn=wf['ExecutionArn'])

        status = resp['status']
        wf['Status'] = status
        status_counts[status] += 1

        if 'stopDate' in resp:
            wf['StoppedAt'] = resp['stopDate'].isoformat()

    if status_counts['RUNNING'] > 0:
        event['Status'] = 'RUNNING'
    else:
        if status_counts['SUCCEEDED'] == len(event['Workflows']):
            event['Status'] = 'SUCCEEDED'
        else:
            event['Status'] = 'FAILED'

    return event
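
defaultdict(lambda: 0) above is equivalent to the more idiomatic defaultdict(int); both yield 0 for a status that has not been counted yet:

import collections

status_counts = collections.defaultdict(int)   # same effect as lambda: 0
status_counts["RUNNING"] += 1
print(status_counts["RUNNING"])   # 1
print(status_counts["FAILED"])    # 0
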
Project: python-    Author: secondtonone1
def _convert_extras_requirements(self):
        """
        Convert requirements in `extras_require` of the form
        `"extra": ["barbazquux; {marker}"]` to
        `"extra:{marker}": ["barbazquux"]`.
        """
        spec_ext_reqs = getattr(self, 'extras_require', None) or {}
        self._tmp_extras_require = defaultdict(list)
        for section, v in spec_ext_reqs.items():
            # Do not strip empty sections.
            self._tmp_extras_require[section]
            for r in pkg_resources.parse_requirements(v):
                suffix = self._suffix_for(r)
                self._tmp_extras_require[section + suffix].append(r)
Project: python-    Author: secondtonone1
def configuration_to_dict(handlers):
    """Returns configuration data gathered by given handlers as a dict.

    :param list[ConfigHandler] handlers: Handlers list,
        usually from parse_configuration()

    :rtype: dict
    """
    config_dict = defaultdict(dict)

    for handler in handlers:

        obj_alias = handler.section_prefix
        target_obj = handler.target_obj

        for option in handler.set_options:
            getter = getattr(target_obj, 'get_%s' % option, None)

            if getter is None:
                value = getattr(target_obj, option)

            else:
                value = getter()

            config_dict[obj_alias][option] = value

    return config_dict
Project: python-    Author: secondtonone1
def draw(self, screen):
        help_string1 = '(W)Up (S)Down (A)Left (D)Right'
        help_string2 = '     (R)Restart (Q)Exit'
        gameover_string = '           GAME OVER'
        win_string = '          YOU WIN!'
        def cast(string):
            screen.addstr(string + '\n')

        def draw_hor_separator():
            line = '+' + ('+------' * self.width + '+')[1:]
            separator = defaultdict(lambda: line)
            if not hasattr(draw_hor_separator, "counter"):
                draw_hor_separator.counter = 0
            cast(separator[draw_hor_separator.counter])
            draw_hor_separator.counter += 1

        def draw_row(row):
            cast(''.join('|{: ^5} '.format(num) if num > 0 else '|      ' for num in row) + '|')

        screen.clear()
        cast('SCORE: ' + str(self.score))
        if 0 != self.highscore:
            cast('HIGHSCORE: ' + str(self.highscore))
        for row in self.field:
            draw_hor_separator()
            draw_row(row)
        draw_hor_separator()
        if self.is_win():
            cast(win_string)
        else:
            if self.is_gameover():
                cast(gameover_string)
            else:
                cast(help_string1)
        cast(help_string2)
Project: db-import    Author: antismash
def test_parse_ripp_core():
    '''Test parse_ripp_core'''
    params = defaultdict(lambda: None)
    notes = [
        'Totally unrelated: nonsense',
        'monoisotopic mass: 3333.6',
        'molecular weight: 3336.0',
        'alternative weights: 3354.0; 3372.1; 3390.1; 3408.1',
        'number of bridges: 5',
        'predicted core seq: ITSISLCTPGCKTGALMGCNMKTATCHCSIHVSK',
        'predicted class: Class-I',
        'score: 26.70',
    ]

    expected = {
        'peptide_sequence': 'ITSISLCTPGCKTGALMGCNMKTATCHCSIHVSK',
        'molecular_weight': '3336.0',
        'monoisotopic_mass': '3333.6',
        'alternative_weights': '3354.0; 3372.1; 3390.1; 3408.1',
        'bridges': '5',
        'class': 'Class-I',
        'score': '26.70',
    }

    fake = FakeFeature()
    fake.qualifiers['note'] = notes

    parse_ripp_core(fake, params)
    assert params == expected
Project: db-import    Author: antismash
def get_lineage(taxid):
    '''Get the full lineage for a taxid from Entrez'''
    handle = Entrez.efetch(db="taxonomy", id=taxid, retmode="xml")
    records = Entrez.read(handle)
    lineage = defaultdict(lambda: 'Unclassified')
    for entry in records[0]['LineageEx']:
        if entry['Rank'] == 'no rank':
            continue
        lineage[entry['Rank']] = entry['ScientificName'].split(' ')[-1]

    return lineage
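
get_lineage() uses a constant-string factory so callers can ask for any taxonomic rank without membership checks. A standalone sketch of the pattern (the names are illustrative):

from collections import defaultdict

lineage = defaultdict(lambda: 'Unclassified')
lineage['genus'] = 'Streptomyces'
print(lineage['genus'])      # Streptomyces
print(lineage['phylum'])     # Unclassified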